diff --git a/config/external/moz.build b/config/external/moz.build index 4e9888f36503..5096a1ca00e5 100644 --- a/config/external/moz.build +++ b/config/external/moz.build @@ -55,6 +55,9 @@ if CONFIG["CPU_ARCH"] == "arm": if CONFIG["MOZ_FFVPX"]: external_dirs += ["media/ffvpx"] +if CONFIG["MOZ_JXL"]: + external_dirs += ["media/libjxl", "media/highway"] + external_dirs += [ "media/kiss_fft", "media/libcubeb", diff --git a/media/highway/moz.build b/media/highway/moz.build new file mode 100644 index 000000000000..72b86ca34484 --- /dev/null +++ b/media/highway/moz.build @@ -0,0 +1,41 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +LOCAL_INCLUDES += [ + "/third_party/highway/", +] + +SOURCES += [ + "/third_party/highway/contrib/image/image.cc", + "/third_party/highway/hwy/aligned_allocator.cc", + "/third_party/highway/hwy/targets.cc", +] + +EXPORTS.hwy += [ + "/third_party/highway/hwy/aligned_allocator.h", + "/third_party/highway/hwy/base.h", + "/third_party/highway/hwy/cache_control.h", + "/third_party/highway/hwy/foreach_target.h", + "/third_party/highway/hwy/highway.h", + "/third_party/highway/hwy/targets.h", +] + +EXPORTS.hwy.ops += [ + "/third_party/highway/hwy/ops/arm_neon-inl.h", + "/third_party/highway/hwy/ops/rvv-inl.h", + "/third_party/highway/hwy/ops/scalar-inl.h", + "/third_party/highway/hwy/ops/set_macros-inl.h", + "/third_party/highway/hwy/ops/shared-inl.h", + "/third_party/highway/hwy/ops/wasm_128-inl.h", + "/third_party/highway/hwy/ops/x86_128-inl.h", + "/third_party/highway/hwy/ops/x86_256-inl.h", + "/third_party/highway/hwy/ops/x86_512-inl.h", +] + +FINAL_LIBRARY = "gkmedias" + +# We allow warnings for third-party code that can be updated from upstream. +AllowCompilerWarnings() diff --git a/media/highway/moz.yaml b/media/highway/moz.yaml new file mode 100644 index 000000000000..dce13ab1c74d --- /dev/null +++ b/media/highway/moz.yaml @@ -0,0 +1,43 @@ +# Version of this schema +schema: 1 + +bugzilla: + # Bugzilla product and component for this directory and subdirectories + product: Core + component: "ImageLib" + +# Document the source of externally hosted code +origin: + + # Short name of the package/library + name: highway + + description: Performance-portable, length-agnostic SIMD with runtime dispatch + + # Full URL for the package's homepage/etc + # Usually different from repository url + url: https://github.com/google/highway + + # Human-readable identifier for this version/release + # Generally "version NNN", "tag SSS", "bookmark SSS" + release: commit ca1a57c342cd815053abfcffa29b44eaead4f20b (2021-04-15T17:54:35Z). 
+ + # Revision to pull in + # Must be a long or short commit SHA (long preferred) + revision: ca1a57c342cd815053abfcffa29b44eaead4f20b + + # The package's license, where possible using the mnemonic from + # https://spdx.org/licenses/ + # Multiple licenses can be specified (as a YAML list) + # A "LICENSE" file must exist containing the full license text + license: Apache-2.0 + + license-file: LICENSE + +vendoring: + url: https://github.com/google/highway.git + source-hosting: github + vendor-directory: third_party/jpeg-xl/third_party/highway + + exclude: + - g3doc/ diff --git a/media/libjxl/include/jxl/jxl_export.h b/media/libjxl/include/jxl/jxl_export.h new file mode 100644 index 000000000000..295e45974c9d --- /dev/null +++ b/media/libjxl/include/jxl/jxl_export.h @@ -0,0 +1,12 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef JXL_EXPORT_H +#define JXL_EXPORT_H + +#define JXL_EXPORT + +#endif /* JXL_EXPORT_H */ diff --git a/media/libjxl/include/jxl/jxl_threads_export.h b/media/libjxl/include/jxl/jxl_threads_export.h new file mode 100644 index 000000000000..b08aabe76084 --- /dev/null +++ b/media/libjxl/include/jxl/jxl_threads_export.h @@ -0,0 +1,12 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef JXL_THREADS_EXPORT_H +#define JXL_THREADS_EXPORT_H + +#define JXL_THREADS_EXPORT + +#endif /* JXL_THREADS_EXPORT_H */ diff --git a/media/libjxl/moz.build b/media/libjxl/moz.build new file mode 100644 index 000000000000..a3ed5b9cc94a --- /dev/null +++ b/media/libjxl/moz.build @@ -0,0 +1,114 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +LOCAL_INCLUDES += [ + "./include/", + "/third_party/jpeg-xl/", + "/third_party/jpeg-xl/lib/include/", +] + +SOURCES += [ + "/third_party/jpeg-xl/lib/jxl/ac_strategy.cc", + "/third_party/jpeg-xl/lib/jxl/alpha.cc", + "/third_party/jpeg-xl/lib/jxl/ans_common.cc", + "/third_party/jpeg-xl/lib/jxl/aux_out.cc", + "/third_party/jpeg-xl/lib/jxl/base/cache_aligned.cc", + "/third_party/jpeg-xl/lib/jxl/base/data_parallel.cc", + "/third_party/jpeg-xl/lib/jxl/base/descriptive_statistics.cc", + "/third_party/jpeg-xl/lib/jxl/base/padded_bytes.cc", + "/third_party/jpeg-xl/lib/jxl/base/status.cc", + "/third_party/jpeg-xl/lib/jxl/base/time.cc", + "/third_party/jpeg-xl/lib/jxl/blending.cc", + "/third_party/jpeg-xl/lib/jxl/chroma_from_luma.cc", + "/third_party/jpeg-xl/lib/jxl/coeff_order.cc", + "/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc", + "/third_party/jpeg-xl/lib/jxl/color_management.cc", + "/third_party/jpeg-xl/lib/jxl/compressed_dc.cc", + "/third_party/jpeg-xl/lib/jxl/convolve.cc", + "/third_party/jpeg-xl/lib/jxl/dct_scales.cc", + "/third_party/jpeg-xl/lib/jxl/dec_ans.cc", + "/third_party/jpeg-xl/lib/jxl/dec_context_map.cc", + "/third_party/jpeg-xl/lib/jxl/dec_external_image.cc", + "/third_party/jpeg-xl/lib/jxl/dec_frame.cc", + "/third_party/jpeg-xl/lib/jxl/dec_group.cc", + "/third_party/jpeg-xl/lib/jxl/dec_group_border.cc", + "/third_party/jpeg-xl/lib/jxl/dec_huffman.cc", + "/third_party/jpeg-xl/lib/jxl/dec_modular.cc", + "/third_party/jpeg-xl/lib/jxl/dec_noise.cc", + "/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc", + "/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc", + "/third_party/jpeg-xl/lib/jxl/dec_upsample.cc", + "/third_party/jpeg-xl/lib/jxl/dec_xyb.cc", + "/third_party/jpeg-xl/lib/jxl/decode.cc", + "/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc", + "/third_party/jpeg-xl/lib/jxl/entropy_coder.cc", + "/third_party/jpeg-xl/lib/jxl/epf.cc", + "/third_party/jpeg-xl/lib/jxl/fields.cc", + "/third_party/jpeg-xl/lib/jxl/filters.cc", + "/third_party/jpeg-xl/lib/jxl/frame_header.cc", + "/third_party/jpeg-xl/lib/jxl/gauss_blur.cc", + "/third_party/jpeg-xl/lib/jxl/headers.cc", + "/third_party/jpeg-xl/lib/jxl/huffman_table.cc", + "/third_party/jpeg-xl/lib/jxl/icc_codec.cc", + "/third_party/jpeg-xl/lib/jxl/icc_codec_common.cc", + "/third_party/jpeg-xl/lib/jxl/image.cc", + "/third_party/jpeg-xl/lib/jxl/image_bundle.cc", + "/third_party/jpeg-xl/lib/jxl/image_metadata.cc", + "/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.cc", + "/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc", + "/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.cc", + "/third_party/jpeg-xl/lib/jxl/loop_filter.cc", + "/third_party/jpeg-xl/lib/jxl/luminance.cc", + "/third_party/jpeg-xl/lib/jxl/memory_manager_internal.cc", + "/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc", + "/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc", + "/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc", + "/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc", + "/third_party/jpeg-xl/lib/jxl/opsin_params.cc", + "/third_party/jpeg-xl/lib/jxl/passes_state.cc", + "/third_party/jpeg-xl/lib/jxl/quant_weights.cc", + "/third_party/jpeg-xl/lib/jxl/quantizer.cc", + "/third_party/jpeg-xl/lib/jxl/splines.cc", + "/third_party/jpeg-xl/lib/jxl/toc.cc", +] + +SOURCES += [ + "/third_party/jpeg-xl/lib/threads/thread_parallel_runner.cc", + "/third_party/jpeg-xl/lib/threads/thread_parallel_runner_internal.cc", +] + +DEFINES["JPEGXL_MAJOR_VERSION"] = "0" +DEFINES["JPEGXL_MINOR_VERSION"] = "0" +DEFINES["JPEGXL_PATCH_VERSION"] = "0" 
+
+EXPORTS.jxl += [
+    "./include/jxl/jxl_export.h",
+    "./include/jxl/jxl_threads_export.h",
+    "/third_party/jpeg-xl/lib/include/jxl/butteraugli.h",
+    "/third_party/jpeg-xl/lib/include/jxl/butteraugli_cxx.h",
+    "/third_party/jpeg-xl/lib/include/jxl/codestream_header.h",
+    "/third_party/jpeg-xl/lib/include/jxl/color_encoding.h",
+    "/third_party/jpeg-xl/lib/include/jxl/decode.h",
+    "/third_party/jpeg-xl/lib/include/jxl/decode_cxx.h",
+    "/third_party/jpeg-xl/lib/include/jxl/encode.h",
+    "/third_party/jpeg-xl/lib/include/jxl/encode_cxx.h",
+    "/third_party/jpeg-xl/lib/include/jxl/memory_manager.h",
+    "/third_party/jpeg-xl/lib/include/jxl/parallel_runner.h",
+    "/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner.h",
+    "/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner_cxx.h",
+    "/third_party/jpeg-xl/lib/include/jxl/types.h",
+]
+
+FINAL_LIBRARY = "gkmedias"
+
+# We allow warnings for third-party code that can be updated from upstream.
+AllowCompilerWarnings()
+
+# Clang 5.0 has a compiler bug that prevents building in C++17 mode.
+# See https://gitlab.com/wg1/jpeg-xl/-/issues/227
+# This should be okay since we are using the C API.
+if CONFIG["CC_TYPE"] == "clang":
+    CXXFLAGS += ["-std=c++11"]
diff --git a/media/libjxl/moz.yaml b/media/libjxl/moz.yaml
new file mode 100644
index 000000000000..289b2cbf9048
--- /dev/null
+++ b/media/libjxl/moz.yaml
@@ -0,0 +1,53 @@
+# Version of this schema
+schema: 1
+
+bugzilla:
+  # Bugzilla product and component for this directory and subdirectories
+  product: Core
+  component: "ImageLib"
+
+# Document the source of externally hosted code
+origin:
+
+  # Short name of the package/library
+  name: jpeg-xl
+
+  description: JPEG XL image format reference implementation
+
+  # Full URL for the package's homepage/etc
+  # Usually different from repository url
+  url: https://gitlab.com/wg1/jpeg-xl
+
+  # Human-readable identifier for this version/release
+  # Generally "version NNN", "tag SSS", "bookmark SSS"
+  release: commit 9a8f5195e4d1c45112fd65f184ebe115f4163ba2 (2021-05-04T13:15:00.000+02:00).
+
+  # Revision to pull in
+  # Must be a long or short commit SHA (long preferred)
+  # NOTE(krosylight): Update highway together when updating this!
+  revision: 9a8f5195e4d1c45112fd65f184ebe115f4163ba2
+
+  # The package's license, where possible using the mnemonic from
+  # https://spdx.org/licenses/
+  # Multiple licenses can be specified (as a YAML list)
+  # A "LICENSE" file must exist containing the full license text
+  license: Apache-2.0
+
+  license-file: LICENSE
+
+updatebot:
+  maintainer-phab: saschanaz
+  maintainer-bz: krosylight@mozilla.com
+  tasks:
+    - type: vendoring
+      enabled: True
+
+vendoring:
+  url: https://gitlab.com/wg1/jpeg-xl.git
+  source-hosting: gitlab
+  vendor-directory: third_party/jpeg-xl
+
+  exclude:
+    - doc/
+    - third_party/testdata/
+    - tools/
diff --git a/third_party/highway/CMakeLists.txt b/third_party/highway/CMakeLists.txt
new file mode 100644
index 000000000000..a8bfb0a13354
--- /dev/null
+++ b/third_party/highway/CMakeLists.txt
@@ -0,0 +1,300 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_minimum_required(VERSION 3.10)
+
+# Set PIE flags for POSITION_INDEPENDENT_CODE targets, added in 3.14.
+if(POLICY CMP0083)
+  cmake_policy(SET CMP0083 NEW)
+endif()
+
+project(hwy VERSION 0.1)
+
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_EXTENSIONS OFF)
+set(CMAKE_CXX_STANDARD_REQUIRED YES)
+
+# Enable PIE binaries by default if supported.
+include(CheckPIESupported OPTIONAL RESULT_VARIABLE CHECK_PIE_SUPPORTED)
+if(CHECK_PIE_SUPPORTED)
+  check_pie_supported(LANGUAGES CXX)
+  if(CMAKE_CXX_LINK_PIE_SUPPORTED)
+    set(CMAKE_POSITION_INDEPENDENT_CODE TRUE)
+  endif()
+endif()
+
+include(GNUInstallDirs)
+
+if (NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE RelWithDebInfo)
+endif()
+
+include(CheckCXXSourceCompiles)
+check_cxx_source_compiles(
+   "int main() {
+      #if !defined(__EMSCRIPTEN__)
+      static_assert(false, \"__EMSCRIPTEN__ is not defined\");
+      #endif
+      return 0;
+    }"
+  HWY_EMSCRIPTEN
+)
+
+set(HWY_SOURCES
+    contrib/image/image.cc
+    contrib/image/image.h
+    contrib/math/math-inl.h
+    hwy/aligned_allocator.cc
+    hwy/aligned_allocator.h
+    hwy/base.h
+    hwy/cache_control.h
+    hwy/foreach_target.h
+    hwy/highway.h
+    hwy/nanobenchmark.cc
+    hwy/nanobenchmark.h
+    hwy/ops/arm_neon-inl.h
+    hwy/ops/scalar-inl.h
+    hwy/ops/set_macros-inl.h
+    hwy/ops/shared-inl.h
+    hwy/ops/wasm_128-inl.h
+    hwy/ops/x86_128-inl.h
+    hwy/ops/x86_256-inl.h
+    hwy/ops/x86_512-inl.h
+    hwy/targets.cc
+    hwy/targets.h
+    hwy/tests/test_util-inl.h
+)
+
+if (MSVC)
+  # TODO(janwas): add flags
+else()
+  set(HWY_FLAGS
+    # Avoid changing binaries based on the current time and date.
+    -Wno-builtin-macro-redefined
+    -D__DATE__="redacted"
+    -D__TIMESTAMP__="redacted"
+    -D__TIME__="redacted"
+
+    # Optimizations
+    -fmerge-all-constants
+
+    # Warnings
+    -Wall
+    -Wextra
+    -Wformat-security
+    -Wno-unused-function
+    -Wnon-virtual-dtor
+    -Woverloaded-virtual
+    -Wvla
+  )
+
+  if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+    list(APPEND HWY_FLAGS
+      -Wc++2a-extensions
+      -Wfloat-overflow-conversion
+      -Wfloat-zero-conversion
+      -Wfor-loop-analysis
+      -Wgnu-redeclared-enum
+      -Winfinite-recursion
+      -Wself-assign
+      -Wstring-conversion
+      -Wtautological-overlap-compare
+      -Wthread-safety-analysis
+      -Wundefined-func-template
+
+      -fno-cxx-exceptions
+      -fno-slp-vectorize
+      -fno-vectorize
+
+      # Use color in messages
+      -fdiagnostics-show-option -fcolor-diagnostics
+    )
+  endif()
+
+  if (WIN32)
+    list(APPEND HWY_FLAGS
+      -Wno-c++98-compat-pedantic
+      -Wno-cast-align
+      -Wno-double-promotion
+      -Wno-float-equal
+      -Wno-format-nonliteral
+      -Wno-global-constructors
+      -Wno-language-extension-token
+      -Wno-missing-prototypes
+      -Wno-shadow
+      -Wno-shadow-field-in-constructor
+      -Wno-sign-conversion
+      -Wno-unused-member-function
+      -Wno-unused-template
+      -Wno-used-but-marked-unused
+      -Wno-zero-as-null-pointer-constant
+    )
+  else()
+    list(APPEND HWY_FLAGS
+      -fmath-errno
+      -fno-exceptions
+    )
+  endif()
+endif()
+
+add_library(hwy STATIC ${HWY_SOURCES})
+target_compile_options(hwy PRIVATE ${HWY_FLAGS})
+set_property(TARGET hwy PROPERTY POSITION_INDEPENDENT_CODE ON)
+target_include_directories(hwy PUBLIC ${CMAKE_CURRENT_LIST_DIR})
+
+# -------------------------------------------------------- install library
+install(TARGETS hwy
+  DESTINATION "${CMAKE_INSTALL_LIBDIR}")
+# Install all the headers keeping the relative path to the current directory
+# when installing them.
+foreach (source ${HWY_SOURCES})
+  if ("${source}" MATCHES "\.h$")
+    get_filename_component(dirname "${source}" DIRECTORY)
+    install(FILES "${source}"
+      DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dirname}")
+  endif()
+endforeach()
+
+# Add a pkg-config file for libhwy and the test library.
+set(HWY_LIBRARY_VERSION "${CMAKE_PROJECT_VERSION}")
+foreach (pc libhwy.pc libhwy-test.pc)
+  configure_file("${CMAKE_CURRENT_SOURCE_DIR}/${pc}.in" "${pc}" @ONLY)
+  install(FILES "${CMAKE_CURRENT_BINARY_DIR}/${pc}"
+    DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig")
+endforeach()
+
+# -------------------------------------------------------- hwy_list_targets
+# Generate a tool to print the compiled-in targets as defined by the current
+# flags. This tool will print to stderr at build time, after building hwy.
+add_executable(hwy_list_targets hwy/tests/list_targets.cc)
+target_compile_options(hwy_list_targets PRIVATE ${HWY_FLAGS})
+target_include_directories(hwy_list_targets PRIVATE
+  $<TARGET_PROPERTY:hwy,INCLUDE_DIRECTORIES>)
+# TARGET_FILE always returns the path to the executable. A bare target name
+# cannot always be run (due to the lack of a '.\' prefix), so the effective
+# command must contain the full path and the emulator prefix (if any).
+add_custom_command(TARGET hwy POST_BUILD
+    COMMAND ${CMAKE_CROSSCOMPILING_EMULATOR} $<TARGET_FILE:hwy_list_targets> || (exit 0))
+
+# -------------------------------------------------------- Examples
+
+# Avoids mismatch between GTest's static CRT and our dynamic.
+set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+
+# Programming exercise with integrated benchmark
+add_executable(hwy_benchmark hwy/examples/benchmark.cc)
+target_sources(hwy_benchmark PRIVATE
+    hwy/nanobenchmark.cc
+    hwy/nanobenchmark.h)
+# Try adding either -DHWY_COMPILE_ONLY_SCALAR or -DHWY_COMPILE_ONLY_STATIC to
+# observe the difference in targets printed.
+target_compile_options(hwy_benchmark PRIVATE ${HWY_FLAGS})
+target_link_libraries(hwy_benchmark hwy)
+set_target_properties(hwy_benchmark
+    PROPERTIES RUNTIME_OUTPUT_DIRECTORY "examples/")
+
+# -------------------------------------------------------- Tests
+
+include(CTest)
+
+if(BUILD_TESTING)
+enable_testing()
+include(GoogleTest)
+
+set(HWY_SYSTEM_GTEST OFF CACHE BOOL "Use pre-installed googletest?")
+if(HWY_SYSTEM_GTEST)
+find_package(GTest REQUIRED)
+else()
+# Download and unpack googletest at configure time
+configure_file(CMakeLists.txt.in googletest-download/CMakeLists.txt)
+execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
+  RESULT_VARIABLE result
+  WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download )
+if(result)
+  message(FATAL_ERROR "CMake step for googletest failed: ${result}")
+endif()
+execute_process(COMMAND ${CMAKE_COMMAND} --build .
+  RESULT_VARIABLE result
+  WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download )
+if(result)
+  message(FATAL_ERROR "Build step for googletest failed: ${result}")
+endif()
+
+# Prevent overriding the parent project's compiler/linker
+# settings on Windows
+set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+
+# Add googletest directly to our build. This defines
+# the gtest and gtest_main targets.
+add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/googletest-src
+                 ${CMAKE_CURRENT_BINARY_DIR}/googletest-build
+                 EXCLUDE_FROM_ALL)
+
+# The gtest/gtest_main targets carry header search path
+# dependencies automatically when using CMake 2.8.11 or
+# later. Otherwise we have to add them here ourselves.
+if (CMAKE_VERSION VERSION_LESS 2.8.11)
+  include_directories("${gtest_SOURCE_DIR}/include")
+endif()
+endif()  # HWY_SYSTEM_GTEST
+
+set(HWY_TEST_FILES
+  contrib/image/image_test.cc
+  # contrib/math/math_test.cc
+  hwy/aligned_allocator_test.cc
+  hwy/base_test.cc
+  hwy/highway_test.cc
+  hwy/targets_test.cc
+  hwy/examples/skeleton_test.cc
+  hwy/tests/arithmetic_test.cc
+  hwy/tests/combine_test.cc
+  hwy/tests/compare_test.cc
+  hwy/tests/convert_test.cc
+  hwy/tests/logical_test.cc
+  hwy/tests/memory_test.cc
+  hwy/tests/swizzle_test.cc
+  hwy/tests/test_util_test.cc
+)
+
+file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests)
+foreach (TESTFILE IN LISTS HWY_TEST_FILES)
+  # The TESTNAME is the name without the extension or directory.
+  get_filename_component(TESTNAME ${TESTFILE} NAME_WE)
+  add_executable(${TESTNAME} ${TESTFILE})
+  target_compile_options(${TESTNAME} PRIVATE ${HWY_FLAGS})
+
+  if(HWY_SYSTEM_GTEST)
+    target_link_libraries(${TESTNAME} hwy GTest::GTest GTest::Main)
+  else()
+    target_link_libraries(${TESTNAME} hwy gtest gtest_main)
+  endif()
+  # Output test targets in the test directory.
+  set_target_properties(${TESTNAME} PROPERTIES PREFIX "tests/")
+
+  if (HWY_EMSCRIPTEN)
+    set_target_properties(${TESTNAME} PROPERTIES LINK_FLAGS "-s SINGLE_FILE=1")
+  endif()
+
+  if(${CMAKE_VERSION} VERSION_LESS "3.10.3")
+    gtest_discover_tests(${TESTNAME} TIMEOUT 60)
+  else ()
+    gtest_discover_tests(${TESTNAME} DISCOVERY_TIMEOUT 60)
+  endif ()
+endforeach ()
+
+# The skeleton test uses the skeleton library code.
+target_sources(skeleton_test PRIVATE hwy/examples/skeleton.cc)
+
+endif()  # BUILD_TESTING
diff --git a/third_party/highway/CMakeLists.txt.in b/third_party/highway/CMakeLists.txt.in
new file mode 100644
index 000000000000..f98ccb4ac978
--- /dev/null
+++ b/third_party/highway/CMakeLists.txt.in
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 2.8.2)
+
+project(googletest-download NONE)
+
+include(ExternalProject)
+ExternalProject_Add(googletest
+  GIT_REPOSITORY    https://github.com/google/googletest.git
+  GIT_TAG           master
+  SOURCE_DIR        "${CMAKE_CURRENT_BINARY_DIR}/googletest-src"
+  BINARY_DIR        "${CMAKE_CURRENT_BINARY_DIR}/googletest-build"
+  CONFIGURE_COMMAND ""
+  BUILD_COMMAND     ""
+  INSTALL_COMMAND   ""
+  TEST_COMMAND      ""
+)
\ No newline at end of file
diff --git a/third_party/highway/CONTRIBUTING b/third_party/highway/CONTRIBUTING
new file mode 100644
index 000000000000..8b7d4d2537e5
--- /dev/null
+++ b/third_party/highway/CONTRIBUTING
@@ -0,0 +1,33 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Testing
+
+This repository is used by JPEG XL, so major API changes will require
+coordination.
Please get in touch with us beforehand, e.g. by raising an issue. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). diff --git a/third_party/highway/LICENSE b/third_party/highway/LICENSE new file mode 100644 index 000000000000..f49a4e16e68b --- /dev/null +++ b/third_party/highway/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/third_party/highway/Makefile b/third_party/highway/Makefile new file mode 100644 index 000000000000..34aafd095e99 --- /dev/null +++ b/third_party/highway/Makefile @@ -0,0 +1,44 @@ +.DELETE_ON_ERROR: + +OBJS := aligned_allocator.o nanobenchmark.o targets.o +IMAGE_OBJS := image.o +TEST_NAMES := arithmetic_test combine_test compare_test convert_test hwy_test logical_test memory_test swizzle_test +TESTS := $(foreach i, $(TEST_NAMES), bin/$(i)) +ROOT_TEST_NAMES := aligned_allocator_test nanobenchmark_test +ROOT_TESTS := $(foreach i, $(ROOT_TEST_NAMES), bin/$(i)) + +CXXFLAGS += -I. -fmerge-all-constants -std=c++17 -O2 \ + -Wno-builtin-macro-redefined -D__DATE__="redacted" \ + -D__TIMESTAMP__="redacted" -D__TIME__="redacted" \ + -Wall -Wextra -Wformat-security -Wno-unused-function \ + -Wnon-virtual-dtor -Woverloaded-virtual -Wvla + +.PHONY: all +all: $(TESTS) $(ROOT_TESTS) benchmark test + +.PHONY: clean +clean: ; @rm -rf $(OBJS) bin/ benchmark.o + +$(OBJS): %.o: hwy/%.cc + $(CXX) -c $(CXXFLAGS) $< -o $@ + +$(IMAGE_OBJS): %.o: contrib/image/%.cc + $(CXX) -c $(CXXFLAGS) $< -o $@ + +benchmark: $(OBJS) hwy/examples/benchmark.cc + mkdir -p bin && $(CXX) $(CXXFLAGS) $^ -o bin/benchmark + +bin/%: hwy/%.cc $(OBJS) + mkdir -p bin && $(CXX) $(CXXFLAGS) $< $(OBJS) -o $@ -lgtest -lgtest_main -lpthread + +bin/%: hwy/tests/%.cc $(OBJS) + mkdir -p bin && $(CXX) $(CXXFLAGS) $< $(OBJS) -o $@ -lgtest -lgtest_main -lpthread + +bin/%: contrib/image/%.cc $(OBJS) + mkdir -p bin && $(CXX) $(CXXFLAGS) $< $(OBJS) -o $@ -lgtest -lgtest_main -lpthread +bin/%: contrib/math/%.cc $(OBJS) + mkdir -p bin && $(CXX) $(CXXFLAGS) $< $(OBJS) -o $@ -lgtest -lgtest_main -lpthread + +.PHONY: test +test: $(TESTS) $(ROOT_TESTS) + for name in $^; do echo ---------------------$$name && $$name; done diff --git a/third_party/highway/README.md b/third_party/highway/README.md new file mode 100644 index 000000000000..87fac045ff3a --- /dev/null +++ b/third_party/highway/README.md @@ -0,0 +1,361 @@ +# Efficient and performance-portable SIMD + +Highway is a C++ library for SIMD (Single Instruction, Multiple Data), i.e. +applying the same operation to 'lanes'. + +## Why Highway? + +- more portable (same source code) than platform-specific intrinsics, +- works on a wider range of compilers than compiler-specific vector extensions, +- more dependable than autovectorization, +- easier to write/maintain than assembly language, +- supports **runtime dispatch**, +- supports **variable-length vector** architectures. + +## Current status + +Supported targets: scalar, SSE4, AVX2, AVX-512, NEON (ARMv7 and v8), WASM SIMD. +A port to RVV is in progress. 
+ +Version 0.11 is considered stable enough to use in other projects, and is +expected to remain backwards compatible unless serious issues are discovered +while implementing SVE/RVV targets. After these targets are added, Highway will +reach version 1.0. + +Continuous integration tests build with a recent version of Clang (running on +x86 and QEMU for ARM) and MSVC from VS2015 (running on x86). Also periodically +tested on x86 with Clang 7-11 and GCC 8, 9 and 10.2.1. + +The `contrib` directory contains SIMD-related utilities: an image class with +aligned rows, and a math library (16 functions already implemented, mostly +trigonometry). + +## Installation + +This project uses cmake to generate and build. In a Debian-based system you can +install it via: + +```bash +sudo apt install cmake +``` + +Highway's unit tests use [googletest](https://github.com/google/googletest). +By default, Highway's CMake downloads this dependency at configuration time. +You can disable this by setting the `HWY_SYSTEM_GTEST` CMake variable to ON and +installing gtest separately: + +```bash +sudo apt install libgtest-dev +``` + +To build and test the library the standard cmake workflow can be used: + +```bash +mkdir -p build && cd build +cmake .. +make -j && make test +``` + +Or you can run `run_tests.sh` (`run_tests.bat` on Windows). + +To test on all the attainable targets for your platform, use +`cmake .. -DCMAKE_CXX_FLAGS="-DHWY_COMPILE_ALL_ATTAINABLE"`. Otherwise, the +default configuration skips baseline targets (e.g. scalar) that are superseded +by another baseline target. + +## Quick start + +You can use the `benchmark` inside examples/ as a starting point. + +A [quick-reference page](g3doc/quick_reference.md) briefly lists all operations +and their parameters, and the [instruction_matrix][instmtx] indicates the +number of instructions per operation. + +We recommend using full SIMD vectors whenever possible for maximum performance +portability. To obtain them, pass a `HWY_FULL(float)` tag to functions such as +`Zero/Set/Load`. There is also the option of a vector of up to `N` (a power of +two) lanes: `HWY_CAPPED(T, N)`. 128-bit vectors are guaranteed to be available +for lanes of type `T` if `HWY_TARGET != HWY_SCALAR` and `N == 16 / sizeof(T)`. + +Functions using Highway must be inside a namespace `namespace HWY_NAMESPACE {` +(possibly nested in one or more other namespaces defined by the project), and +additionally either prefixed with `HWY_ATTR`, or residing between +`HWY_BEFORE_NAMESPACE()` and `HWY_AFTER_NAMESPACE()`. + +* For static dispatch, `HWY_TARGET` will be the best available target among + `HWY_BASELINE_TARGETS`, i.e. those allowed for use by the compiler (see + [quick-reference](g3doc/quick_reference.md)). Functions inside `HWY_NAMESPACE` + can be called using `HWY_STATIC_DISPATCH(func)(args)` within the same module + they are defined in. You can call the function from other modules by + wrapping it in a regular function and declaring the regular function in a + header. + +* For dynamic dispatch, a table of function pointers is generated via the + `HWY_EXPORT` macro that is used by `HWY_DYNAMIC_DISPATCH(func)(args)` to + call the best function pointer for the current CPU supported targets. A + module is automatically compiled for each target in `HWY_TARGETS` (see + [quick-reference](g3doc/quick_reference.md)) if `HWY_TARGET_INCLUDE` is + defined and foreach_target.h is included. 
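[Editor's note] The dispatch pattern described above appears concretely in `contrib/image/image.cc`, vendored later in this patch: `GetVectorSize` is compiled once per target and invoked through `HWY_DYNAMIC_DISPATCH`. As a hedged illustration only (the file name `mylib.cc`, namespace `mylib`, and function names are hypothetical; only the Highway macros and `Lanes`/`HWY_FULL` are real API), a minimal translation unit combining both halves might look like this:

```cpp
// mylib.cc -- hypothetical example modeled on contrib/image/image.cc.
#include <stddef.h>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "mylib.cc"  // this file is re-included per target
#include "hwy/foreach_target.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace mylib {
namespace HWY_NAMESPACE {  // expands to a distinct name for each target

// Compiled once per target in HWY_TARGETS.
size_t VectorBytes() { return Lanes(HWY_FULL(uint8_t)()); }

}  // namespace HWY_NAMESPACE
}  // namespace mylib
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // the remainder is compiled only once, not per target
namespace mylib {

HWY_EXPORT(VectorBytes);  // generates the table of per-target pointers

size_t CurrentVectorBytes() {
  // Selects the best table entry for the CPU detected at runtime.
  return HWY_DYNAMIC_DISPATCH(VectorBytes)();
}

}  // namespace mylib
#endif  // HWY_ONCE
```

Within a single module, `HWY_STATIC_DISPATCH(VectorBytes)()` would instead call the instance compiled for the baseline target directly, with no runtime selection.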
+
+## Strip-mining loops
+
+To vectorize a loop, "strip-mining" transforms it into an outer loop and inner
+loop with number of iterations matching the preferred vector width.
+
+In this section, let `T` denote the element type, `d = HWY_FULL(T)`, `count` the
+number of elements to process, and `N = Lanes(d)` the number of lanes in a full
+vector. Assume the loop body is given as a function
+`template <class D> void LoopBody(D d, size_t max_n)`.
+
+Highway offers several ways to express loops where `N` need not divide `count`:
+
+*   Ensure all inputs/outputs are padded. Then the loop is simply
+
+    ```
+    for (size_t i = 0; i < count; i += N) LoopBody(d, 0);
+    ```
+    Here, the template parameter and second function argument are not needed.
+
+    This is the preferred option, unless `N` is in the thousands and vector
+    operations are pipelined with long latencies. This was the case for
+    supercomputers in the 90s, but nowadays ALUs are cheap and we see most
+    implementations split vectors into 1, 2 or 4 parts, so there is little cost
+    to processing entire vectors even if we do not need all their lanes. Indeed
+    this avoids the (potentially large) cost of predication or partial
+    loads/stores on older targets, and does not duplicate code.
+
+*   Process whole vectors as above, followed by a scalar loop:
+
+    ```
+    size_t i = 0;
+    for (; i + N <= count; i += N) LoopBody(d, 0);
+    for (; i < count; ++i) LoopBody(HWY_CAPPED(T, 1)(), 0);
+    ```
+    The template parameter and second function argument are again not needed.
+
+    This avoids duplicating code, and is reasonable if `count` is large.
+    Otherwise, multiple iterations may be slower than one `LoopBody` variant
+    with masking, especially because the `HWY_SCALAR` target selected by
+    `HWY_CAPPED(T, 1)` is slower for some operations due to workarounds for
+    undefined behavior in C++.
+
+*   Process whole vectors as above, followed by a single call to a modified
+    `LoopBody` with masking:
+
+    ```
+    size_t i = 0;
+    for (; i + N <= count; i += N) {
+      LoopBody(d, 0);
+    }
+    if (i < count) {
+      LoopBody(d, count - i);
+    }
+    ```
+    Now the template parameter and second function argument can be used inside
+    `LoopBody` to replace `Load/Store` of full aligned vectors with
+    `LoadN/StoreN(n)` that affect no more than `1 <= n <= N` aligned elements
+    (pending implementation).
+
+    This is a good default when it is infeasible to ensure vectors are padded.
+    In contrast to the scalar loop, only a single final iteration is needed.
+
+## Design philosophy
+
+*   Performance is important but not the sole consideration. Anyone who goes to
+    the trouble of using SIMD clearly cares about speed. However, portability,
+    maintainability and readability also matter, otherwise we would write in
+    assembly. We aim for performance within 10-20% of a hand-written assembly
+    implementation on the development platform.
+
+*   The guiding principles of C++ are "pay only for what you use" and "leave no
+    room for a lower-level language below C++". We apply these by defining a
+    SIMD API that ensures operation costs are visible, predictable and minimal.
+
+*   Performance portability is important, i.e. the API should be efficient on
+    all target platforms. Unfortunately, common idioms for one platform can be
+    inefficient on others. For example: summing lanes horizontally versus
+    shuffling. Documenting which operations are expensive does not prevent
+    their use, as evidenced by widespread use of `HADDPS`.
Performance acceptance + tests may detect large regressions, but do not help choose the approach + during initial development. Analysis tools can warn about some potential + inefficiencies, but likely not all. We instead provide [a carefully chosen + set of vector types and operations that are efficient on all target + platforms][instmtx] (PPC8, SSE4/AVX2+, ARMv8). + +* Future SIMD hardware features are difficult to predict. For example, AVX2 + came with surprising semantics (almost no interaction between 128-bit + blocks) and AVX-512 added two kinds of predicates (writemask and zeromask). + To ensure the API reflects hardware realities, we suggest a flexible + approach that adds new operations as they become commonly available, with + scalar fallbacks where not supported. + +* Masking is not yet widely supported on current CPUs. It is difficult to + define an interface that provides access to all platform features while + retaining performance portability. The P0214R5 proposal lacks support for + AVX-512/ARM SVE zeromasks. We suggest limiting usage of masks to the + `IfThen[Zero]Else[Zero]` functions until the community has gained more + experience with them. + +* "Width-agnostic" SIMD is more future-proof than user-specified fixed sizes. + For example, valarray-like code can iterate over a 1D array with a + library-specified vector width. This will result in better code when vector + sizes increase, and matches the direction taken by + [ARM SVE](https://alastairreid.github.io/papers/sve-ieee-micro-2017.pdf) and + RiscV V as well as Agner Fog's + [ForwardCom instruction set proposal](https://goo.gl/CFizWu). However, some + applications may require fixed sizes, so we also guarantee support for + 128-bit vectors in each instruction set. + +* The API and its implementation should be usable and efficient with commonly + used compilers, including MSVC. For example, we write `ShiftLeft<3>(v)` + instead of `v << 3` because MSVC 2017 (ARM64) does not propagate the literal + (https://godbolt.org/g/rKx5Ga). Highway requires function-specific + target attributes, supported by GCC 4.9 / Clang 3.9 / MSVC 2015. + +* Efficient and safe runtime dispatch is important. Modules such as image or + video codecs are typically embedded into larger applications such as + browsers, so they cannot require separate binaries for each CPU. Libraries + also cannot predict whether the application already uses AVX2 (and pays the + frequency throttling cost), so this decision must be left to the + application. Using only the lowest-common denominator instructions + sacrifices too much performance. + Therefore, we provide code paths for multiple instruction sets and choose + the most suitable at runtime. To reduce overhead, dispatch should be hoisted + to higher layers instead of checking inside every low-level function. + Highway supports inlining functions in the same file or in *-inl.h headers. + We generate all code paths from the same source to reduce implementation- + and debugging cost. + +* Not every CPU need be supported. For example, pre-SSE4.1 CPUs are + increasingly rare and the AVX instruction set is limited to floating-point + operations. To reduce code size and compile time, we provide specializations + for SSE4, AVX2 and AVX-512 instruction sets on x86, plus a scalar fallback. + +* Access to platform-specific intrinsics is necessary for acceptance in + performance-critical projects. 
We provide conversions to and from intrinsics + to allow utilizing specialized platform-specific functionality, and simplify + incremental porting of existing code. + +* The core API should be compact and easy to learn. We provide only the few + dozen operations which are necessary and sufficient for most of the 150+ + SIMD applications we examined. + +## Prior API designs + +The author has been writing SIMD code since 2002: first via assembly language, +then intrinsics, later Intel's `F32vec4` wrapper, followed by three generations +of custom vector classes. The first used macros to generate the classes, which +reduces duplication but also readability. The second used templates instead. +The third (used in highwayhash and PIK) added support for AVX2 and runtime +dispatch. The current design (used in JPEG XL) enables code generation for +multiple platforms and/or instruction sets from the same source, and improves +runtime dispatch. + +## Differences versus [P0214R5 proposal](https://goo.gl/zKW4SA) + +1. Adding widely used and portable operations such as `AndNot`, `AverageRound`, + bit-shift by immediates and `IfThenElse`. + +1. Designing the API to avoid or minimize overhead on AVX2/AVX-512 caused by + crossing 128-bit 'block' boundaries. + +1. Avoiding the need for non-native vectors. By contrast, P0214R5's `simd_cast` + returns `fixed_size<>` vectors which are more expensive to access because + they reside on the stack. We can avoid this plus additional overhead on + ARM/AVX2 by defining width-expanding operations as functions of a vector + part, e.g. promoting half a vector of `uint8_t` lanes to one full vector of + `uint16_t`, or demoting full vectors to half vectors with half-width lanes. + +1. Guaranteeing access to the underlying intrinsic vector type. This ensures + all platform-specific capabilities can be used. P0214R5 instead only + 'encourages' implementations to provide access. + +1. Enabling safe runtime dispatch and inlining in the same binary. P0214R5 is + based on the Vc library, which does not provide assistance for linking + multiple instruction sets into the same binary. The Vc documentation + suggests compiling separate executables for each instruction set or using + GCC's ifunc (indirect functions). The latter is compiler-specific and risks + crashes due to ODR violations when compiling the same function with + different compiler flags. We solve this problem via target-specific + namespaces and attributes (see HOWTO section below). We also permit a mix of + static target selection and runtime dispatch for hotspots that may benefit + from newer instruction sets if available. + +1. Using built-in PPC vector types without a wrapper class. This leads to much + better code generation with GCC 6.3: https://godbolt.org/z/pd2PNP. + By contrast, P0214R5 requires a wrapper. We avoid this by using only the + member operators provided by the PPC vectors; all other functions and + typedefs are non-members. 2019-04 update: Clang power64le does not have + this issue, so we simplified get_part(d, v) to GetLane(v). + +1. Omitting inefficient or non-performance-portable operations such as `hmax`, + `operator[]`, and unsupported integer comparisons. Applications can often + replace these operations at lower cost than emulating that exact behavior. + +1. Omitting `long double` types: these are not commonly available in hardware. + +1. Ensuring signed integer overflow has well-defined semantics (wraparound). + +1. 
Simple header-only implementation and less than a tenth of the size of the
+    Vc library from which P0214 was derived (98,000 lines in
+    https://github.com/VcDevel/Vc according to the gloc Chrome extension).
+
+1.  Avoiding hidden performance costs. P0214R5 allows implicit conversions from
+    integer to float, which costs 3-4 cycles on x86. We make these conversions
+    explicit to ensure their cost is visible.
+
+## Other related work
+
+*   [Neat SIMD](http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=7568423)
+    adopts a similar approach with interchangeable vector/scalar types and
+    a compact interface. It allows access to the underlying intrinsics, but
+    does not appear to be designed for other platforms than x86.
+
+*   UME::SIMD ([code](https://goo.gl/yPeVZx), [paper](https://goo.gl/2xpZrk))
+    also adopts an explicit vectorization model with vector classes.
+    However, it exposes the union of all platform capabilities, which makes the
+    API harder to learn (209-page spec) and implement (the estimated LOC count
+    is [500K](https://goo.gl/1THFRi)). The API is less performance-portable
+    because it allows applications to use operations that are inefficient on
+    other platforms.
+
+*   Inastemp ([code](https://goo.gl/hg3USM), [paper](https://goo.gl/YcTU7S))
+    is a vector library for scientific computing with some innovative features:
+    automatic FLOPS counting, and "if/else branches" using lambda functions.
+    It supports IBM Power8, but only provides float and double types.
+
+## Overloaded function API
+
+Most C++ vector APIs rely on class templates. However, the ARM SVE vector
+type is sizeless and cannot be wrapped in a class. We instead rely on overloaded
+functions. Overloading based on vector types is also undesirable because SVE
+vectors cannot be default-constructed. We instead use a dedicated 'descriptor'
+type `Simd<T, N>` for overloading, abbreviated to `D` for template arguments and
+`d` in lvalues.
+
+Note that generic function templates are possible (see highway.h).
+
+## Masks
+
+AVX-512 introduced a major change to the SIMD interface: special mask registers
+(one bit per lane) that serve as predicates. It would be expensive to force
+AVX-512 implementations to conform to the prior model of full vectors with lanes
+set to all one or all zero bits. We instead provide a Mask type that emulates
+a subset of this functionality on other platforms at zero cost.
+
+Masks are returned by comparisons and `TestBit`; they serve as the input to
+`IfThen*`. We provide conversions between masks and vector lanes. For clarity
+and safety, we use FF..FF as the definition of true. To also benefit from
+x86 instructions that only require the sign bit of floating-point inputs to be
+set, we provide a special `ZeroIfNegative` function.
+
+## Additional resources
+
+*   [Highway introduction (slides)][intro]
+*   [Overview of instructions per operation on different architectures][instmtx]
+
+[intro]: g3doc/highway_intro.pdf
+[instmtx]: g3doc/instruction_matrix.pdf
+
+This is not an officially supported Google product.
+Contact: janwas@google.com
diff --git a/third_party/highway/contrib/image/image.cc b/third_party/highway/contrib/image/image.cc
new file mode 100644
index 000000000000..f77ad8c3a6bb
--- /dev/null
+++ b/third_party/highway/contrib/image/image.cc
@@ -0,0 +1,145 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "contrib/image/image.h"
+
+#include <cstddef>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "contrib/image/image.cc"
+
+#include <algorithm>  // swap
+
+#include "hwy/foreach_target.h"
+#include "hwy/highway.h"
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+size_t GetVectorSize() { return Lanes(HWY_FULL(uint8_t)()); }
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(GetVectorSize);  // Local function.
+}  // namespace
+
+size_t ImageBase::VectorSize() {
+  // Do not cache result - must return the current value, which may be greater
+  // than the first call if it was subject to DisableTargets!
+  return HWY_DYNAMIC_DISPATCH(GetVectorSize)();
+}
+
+size_t ImageBase::BytesPerRow(const size_t xsize, const size_t sizeof_t) {
+  const size_t vec_size = VectorSize();
+  size_t valid_bytes = xsize * sizeof_t;
+
+  // Allow unaligned accesses starting at the last valid value - this may raise
+  // msan errors unless the user calls InitializePaddingForUnalignedAccesses.
+  // Skip for the scalar case because no extra lanes will be loaded.
+  if (vec_size != 1) {
+    HWY_DASSERT(vec_size >= sizeof_t);
+    valid_bytes += vec_size - sizeof_t;
+  }
+
+  // Round up to vector and cache line size.
+  const size_t align = std::max(vec_size, HWY_ALIGNMENT);
+  size_t bytes_per_row = RoundUpTo(valid_bytes, align);
+
+  // During the lengthy window before writes are committed to memory, CPUs
+  // guard against read after write hazards by checking the address, but
+  // only the lower 11 bits. We avoid a false dependency between writes to
+  // consecutive rows by ensuring their sizes are not multiples of 2 KiB.
+  // Avoid2K prevents the same problem for the planes of an Image3.
+  if (bytes_per_row % HWY_ALIGNMENT == 0) {
+    bytes_per_row += align;
+  }
+
+  HWY_DASSERT(bytes_per_row % align == 0);
+  return bytes_per_row;
+}
+
+ImageBase::ImageBase(const size_t xsize, const size_t ysize,
+                     const size_t sizeof_t)
+    : xsize_(static_cast<uint32_t>(xsize)),
+      ysize_(static_cast<uint32_t>(ysize)),
+      bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {
+  HWY_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8);
+
+  bytes_per_row_ = 0;
+  // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate
+  // if nonzero, because "zero" bytes still have padding/bookkeeping overhead.
+  if (xsize != 0 && ysize != 0) {
+    bytes_per_row_ = BytesPerRow(xsize, sizeof_t);
+    bytes_ = AllocateAligned<uint8_t>(bytes_per_row_ * ysize);
+    HWY_ASSERT(bytes_.get() != nullptr);
+    InitializePadding(sizeof_t, Padding::kRoundUp);
+  }
+}
+
+ImageBase::ImageBase(const size_t xsize, const size_t ysize,
+                     const size_t bytes_per_row, void* const aligned)
+    : xsize_(static_cast<uint32_t>(xsize)),
+      ysize_(static_cast<uint32_t>(ysize)),
+      bytes_per_row_(bytes_per_row),
+      bytes_(static_cast<uint8_t*>(aligned),
+             AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {
+  const size_t vec_size = VectorSize();
+  HWY_ASSERT(bytes_per_row % vec_size == 0);
+  HWY_ASSERT(reinterpret_cast<uintptr_t>(aligned) % vec_size == 0);
+}
+
+void ImageBase::InitializePadding(const size_t sizeof_t, Padding padding) {
+#if defined(MEMORY_SANITIZER) || HWY_IDE
+  if (xsize_ == 0 || ysize_ == 0) return;
+
+  const size_t vec_size = VectorSize();  // Bytes, independent of sizeof_t!
+  if (vec_size == 1) return;  // Scalar mode: no padding needed
+
+  const size_t valid_size = xsize_ * sizeof_t;
+  const size_t initialize_size = padding == Padding::kRoundUp
+                                     ? RoundUpTo(valid_size, vec_size)
+                                     : valid_size + vec_size - sizeof_t;
+  if (valid_size == initialize_size) return;
+
+  for (size_t y = 0; y < ysize_; ++y) {
+    uint8_t* HWY_RESTRICT row = static_cast<uint8_t*>(VoidRow(y));
+#if defined(__clang__) && (__clang_major__ <= 6)
+    // There's a bug in msan in clang-6 when handling AVX2 operations. This
+    // workaround allows tests to pass on msan, although it is slower and
+    // prevents msan warnings from uninitialized images.
+    memset(row, 0, initialize_size);
+#else
+    memset(row + valid_size, 0, initialize_size - valid_size);
+#endif  // clang6
+  }
+#else
+  (void)sizeof_t;
+  (void)padding;
+#endif  // MEMORY_SANITIZER
+}
+
+void ImageBase::Swap(ImageBase& other) {
+  std::swap(xsize_, other.xsize_);
+  std::swap(ysize_, other.ysize_);
+  std::swap(bytes_per_row_, other.bytes_per_row_);
+  std::swap(bytes_, other.bytes_);
+}
+
+}  // namespace hwy
+#endif  // HWY_ONCE
diff --git a/third_party/highway/contrib/image/image.h b/third_party/highway/contrib/image/image.h
new file mode 100644
index 000000000000..e939778a0e79
--- /dev/null
+++ b/third_party/highway/contrib/image/image.h
@@ -0,0 +1,468 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HIGHWAY_CONTRIB_IMAGE_IMAGE_H_
+#define HIGHWAY_CONTRIB_IMAGE_IMAGE_H_
+
+// SIMD/multicore-friendly planar image representation with row accessors.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <cstddef>
+#include <utility>  // std::move
+
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+
+namespace hwy {
+
+// Type-independent parts of Image<> - reduces code duplication and facilitates
+// moving member function implementations to cc file.
+struct ImageBase {
+  // Returns required alignment in bytes for externally allocated memory.
+  static size_t VectorSize();
+
+  // Returns distance [bytes] between the start of two consecutive rows, a
+  // multiple of VectorSize but NOT kAlias (see implementation).
+  static size_t BytesPerRow(const size_t xsize, const size_t sizeof_t);
+
+  // No allocation (for output params or unused images)
+  ImageBase()
+      : xsize_(0),
+        ysize_(0),
+        bytes_per_row_(0),
+        bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {}
+
+  // Allocates memory (this is the common case)
+  ImageBase(size_t xsize, size_t ysize, size_t sizeof_t);
+
+  // References but does not take ownership of external memory. Useful for
+  // interoperability with other libraries. `aligned` must be aligned to a
+  // multiple of VectorSize() and `bytes_per_row` must also be a multiple of
+  // VectorSize() or preferably equal to BytesPerRow().
+  ImageBase(size_t xsize, size_t ysize, size_t bytes_per_row, void* aligned);
+
+  // Copy construction/assignment is forbidden to avoid inadvertent copies,
+  // which can be very expensive. Use CopyImageTo() instead.
+  ImageBase(const ImageBase& other) = delete;
+  ImageBase& operator=(const ImageBase& other) = delete;
+
+  // Move constructor (required for returning Image from function)
+  ImageBase(ImageBase&& other) noexcept = default;
+
+  // Move assignment (required for std::vector)
+  ImageBase& operator=(ImageBase&& other) noexcept = default;
+
+  void Swap(ImageBase& other);
+
+  // Useful for pre-allocating image with some padding for alignment purposes
+  // and later reporting the actual valid dimensions. Caller is responsible
+  // for ensuring xsize/ysize are <= the original dimensions.
+  void ShrinkTo(const size_t xsize, const size_t ysize) {
+    xsize_ = static_cast<uint32_t>(xsize);
+    ysize_ = static_cast<uint32_t>(ysize);
+    // NOTE: we can't recompute bytes_per_row for more compact storage and
+    // better locality because that would invalidate the image contents.
+  }
+
+  // How many pixels.
+  HWY_INLINE size_t xsize() const { return xsize_; }
+  HWY_INLINE size_t ysize() const { return ysize_; }
+
+  // NOTE: do not use this for copying rows - the valid xsize may be much less.
+  HWY_INLINE size_t bytes_per_row() const { return bytes_per_row_; }
+
+  // Raw access to byte contents, for interfacing with other libraries.
+  // Unsigned char instead of char to avoid surprises (sign extension).
+  HWY_INLINE uint8_t* bytes() {
+    void* p = bytes_.get();
+    return static_cast<uint8_t*>(HWY_ASSUME_ALIGNED(p, 64));
+  }
+  HWY_INLINE const uint8_t* bytes() const {
+    const void* p = bytes_.get();
+    return static_cast<const uint8_t*>(HWY_ASSUME_ALIGNED(p, 64));
+  }
+
+ protected:
+  // Returns pointer to the start of a row.
+  HWY_INLINE void* VoidRow(const size_t y) const {
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+    if (y >= ysize_) {
+      HWY_ABORT("Row(%zu) >= %u\n", y, ysize_);
+    }
+#endif
+
+    void* row = bytes_.get() + y * bytes_per_row_;
+    return HWY_ASSUME_ALIGNED(row, 64);
+  }
+
+  enum class Padding {
+    // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default.
+    kRoundUp,
+    // Allow LoadU(d, row + x) for x <= xsize() - 1. This requires an extra
+    // vector to be initialized. If done by default, this would suppress
+    // legitimate msan warnings. We therefore require users to explicitly call
+    // InitializePadding before using unaligned loads (e.g. convolution).
+    kUnaligned
+  };
+
+  // Initializes the minimum bytes required to suppress msan warnings from
+  // legitimate (according to Padding mode) vector loads/stores on the right
+  // border, where some lanes are uninitialized and assumed to be unused.
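+  // The allocating constructor already calls this with kRoundUp; callers that
+  // use unaligned loads near the right border (e.g. convolution) must instead
+  // call InitializePaddingForUnalignedAccesses() on the derived Image<T>.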
+  void InitializePadding(size_t sizeof_t, Padding padding);
+
+  // (Members are non-const to enable assignment during move-assignment.)
+  uint32_t xsize_;  // In valid pixels, not including any padding.
+  uint32_t ysize_;
+  size_t bytes_per_row_;  // Includes padding.
+  AlignedFreeUniquePtr<uint8_t[]> bytes_;
+};
+
+// Single channel, aligned rows separated by padding. T must be POD.
+//
+// 'Single channel' (one 2D array per channel) simplifies vectorization
+// (repeating the same operation on multiple adjacent components) without the
+// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients
+// can easily iterate over all components in a row and Image requires no
+// knowledge of the pixel format beyond the component type "T".
+//
+// 'Aligned' means each row is aligned to the L1 cache line size. This prevents
+// false sharing between two threads operating on adjacent rows.
+//
+// 'Padding' is still relevant because vectors could potentially be larger than
+// a cache line. By rounding up row sizes to the vector size, we allow
+// reading/writing ALIGNED vectors whose first lane is a valid sample. This
+// avoids needing a separate loop to handle remaining unaligned lanes.
+//
+// This image layout could also be achieved with a vector and a row accessor
+// function, but a class wrapper with support for "deleter" allows wrapping
+// existing memory allocated by clients without copying the pixels. It also
+// provides convenient accessors for xsize/ysize, which shortens function
+// argument lists. Supports move-construction so it can be stored in containers.
+template <typename ComponentType>
+class Image : public ImageBase {
+ public:
+  using T = ComponentType;
+
+  Image() = default;
+  Image(const size_t xsize, const size_t ysize)
+      : ImageBase(xsize, ysize, sizeof(T)) {}
+  Image(const size_t xsize, const size_t ysize, size_t bytes_per_row,
+        void* aligned)
+      : ImageBase(xsize, ysize, bytes_per_row, aligned) {}
+
+  void InitializePaddingForUnalignedAccesses() {
+    InitializePadding(sizeof(T), Padding::kUnaligned);
+  }
+
+  HWY_INLINE const T* ConstRow(const size_t y) const {
+    return static_cast<const T*>(VoidRow(y));
+  }
+  HWY_INLINE const T* ConstRow(const size_t y) {
+    return static_cast<const T*>(VoidRow(y));
+  }
+
+  // Returns pointer to non-const. This allows passing const Image* parameters
+  // when the callee is only supposed to fill the pixels, as opposed to
+  // allocating or resizing the image.
+  HWY_INLINE T* MutableRow(const size_t y) const {
+    return static_cast<T*>(VoidRow(y));
+  }
+  HWY_INLINE T* MutableRow(const size_t y) {
+    return static_cast<T*>(VoidRow(y));
+  }
+
+  // Returns number of pixels (some of which are padding) per row. Useful for
+  // computing other rows via pointer arithmetic. WARNING: this must
+  // NOT be used to determine xsize.
+  HWY_INLINE intptr_t PixelsPerRow() const {
+    return static_cast<intptr_t>(bytes_per_row_ / sizeof(T));
+  }
+};
+
+using ImageF = Image<float>;
+
+// A bundle of 3 same-sized images. To fill an existing Image3 using
+// single-channel producers, we also need access to each const Image*. Const
+// prevents breaking the same-size invariant, while still allowing pixels to be
+// changed via MutableRow.
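+//
+// A minimal usage sketch (Image3F is the float alias defined below):
+//
+//   Image3F rgb(xsize, ysize);
+//   for (size_t y = 0; y < rgb.ysize(); ++y) {
+//     float* HWY_RESTRICT row_r = rgb.MutablePlaneRow(/*c=*/0, y);
+//     // ... write row_r[0, xsize) ...
+//   }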
+template <typename ComponentType>
+class Image3 {
+ public:
+  using T = ComponentType;
+  using ImageT = Image<T>;
+  static constexpr size_t kNumPlanes = 3;
+
+  Image3() : planes_{ImageT(), ImageT(), ImageT()} {}
+
+  Image3(const size_t xsize, const size_t ysize)
+      : planes_{ImageT(xsize, ysize), ImageT(xsize, ysize),
+                ImageT(xsize, ysize)} {}
+
+  Image3(Image3&& other) noexcept {
+    for (size_t i = 0; i < kNumPlanes; i++) {
+      planes_[i] = std::move(other.planes_[i]);
+    }
+  }
+
+  Image3(ImageT&& plane0, ImageT&& plane1, ImageT&& plane2) {
+    if (!SameSize(plane0, plane1) || !SameSize(plane0, plane2)) {
+      HWY_ABORT("Not same size: %zu x %zu, %zu x %zu, %zu x %zu\n",
+                plane0.xsize(), plane0.ysize(), plane1.xsize(), plane1.ysize(),
+                plane2.xsize(), plane2.ysize());
+    }
+    planes_[0] = std::move(plane0);
+    planes_[1] = std::move(plane1);
+    planes_[2] = std::move(plane2);
+  }
+
+  // Copy construction/assignment is forbidden to avoid inadvertent copies,
+  // which can be very expensive. Use CopyImageTo instead.
+  Image3(const Image3& other) = delete;
+  Image3& operator=(const Image3& other) = delete;
+
+  Image3& operator=(Image3&& other) noexcept {
+    for (size_t i = 0; i < kNumPlanes; i++) {
+      planes_[i] = std::move(other.planes_[i]);
+    }
+    return *this;
+  }
+
+  HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const {
+    return static_cast<const T*>(VoidPlaneRow(c, y));
+  }
+  HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) {
+    return static_cast<const T*>(VoidPlaneRow(c, y));
+  }
+
+  HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) const {
+    return static_cast<T*>(VoidPlaneRow(c, y));
+  }
+  HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) {
+    return static_cast<T*>(VoidPlaneRow(c, y));
+  }
+
+  HWY_INLINE const ImageT& Plane(size_t idx) const { return planes_[idx]; }
+
+  void Swap(Image3& other) {
+    for (size_t c = 0; c < 3; ++c) {
+      other.planes_[c].Swap(planes_[c]);
+    }
+  }
+
+  void ShrinkTo(const size_t xsize, const size_t ysize) {
+    for (ImageT& plane : planes_) {
+      plane.ShrinkTo(xsize, ysize);
+    }
+  }
+
+  // Sizes of all three images are guaranteed to be equal.
+  HWY_INLINE size_t xsize() const { return planes_[0].xsize(); }
+  HWY_INLINE size_t ysize() const { return planes_[0].ysize(); }
+  // Returns offset [bytes] from one row to the next row of the same plane.
+  // WARNING: this must NOT be used to determine xsize, nor for copying rows -
+  // the valid xsize may be much less.
+  HWY_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); }
+  // Returns number of pixels (some of which are padding) per row. Useful for
+  // computing other rows via pointer arithmetic. WARNING: this must NOT be used
+  // to determine xsize.
+  HWY_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); }
+
+ private:
+  // Returns pointer to the start of a row.
+  HWY_INLINE void* VoidPlaneRow(const size_t c, const size_t y) const {
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+    if (c >= kNumPlanes || y >= ysize()) {
+      HWY_ABORT("PlaneRow(%zu, %zu) >= %zu\n", c, y, ysize());
+    }
+#endif
+    // Use the first plane's stride because the compiler might not realize they
+    // are all equal. Thus we only need a single multiplication for all planes.
+    const size_t row_offset = y * planes_[0].bytes_per_row();
+    const void* row = planes_[c].bytes() + row_offset;
+    return static_cast<void*>(HWY_ASSUME_ALIGNED(row, HWY_ALIGNMENT));
+  }
+
+ private:
+  ImageT planes_[kNumPlanes];
+};
+
+using Image3F = Image3<float>;
+
+// Rectangular region in image(s). Factoring this out of Image instead of
+// shifting the pointer by x0/y0 allows this to apply to multiple images with
+// different resolutions. Can compare size via SameSize(rect1, rect2).
+class Rect {
+ public:
+  // Most windows are xsize_max * ysize_max, except those on the borders where
+  // begin + size_max > end.
+  constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize_max,
+                 size_t ysize_max, size_t xend, size_t yend)
+      : x0_(xbegin),
+        y0_(ybegin),
+        xsize_(ClampedSize(xbegin, xsize_max, xend)),
+        ysize_(ClampedSize(ybegin, ysize_max, yend)) {}
+
+  // Construct with origin and known size (typically from another Rect).
+  constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize, size_t ysize)
+      : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {}
+
+  // Construct a rect that covers a whole image.
+  template <typename T>
+  explicit Rect(const Image<T>& image)
+      : Rect(0, 0, image.xsize(), image.ysize()) {}
+
+  Rect() : Rect(0, 0, 0, 0) {}
+
+  Rect(const Rect&) = default;
+  Rect& operator=(const Rect&) = default;
+
+  Rect Subrect(size_t xbegin, size_t ybegin, size_t xsize_max,
+               size_t ysize_max) {
+    return Rect(x0_ + xbegin, y0_ + ybegin, xsize_max, ysize_max, x0_ + xsize_,
+                y0_ + ysize_);
+  }
+
+  template <typename T>
+  const T* ConstRow(const Image<T>* image, size_t y) const {
+    return image->ConstRow(y + y0_) + x0_;
+  }
+
+  template <typename T>
+  T* MutableRow(const Image<T>* image, size_t y) const {
+    return image->MutableRow(y + y0_) + x0_;
+  }
+
+  template <typename T>
+  const T* ConstPlaneRow(const Image3<T>& image, size_t c, size_t y) const {
+    return image.ConstPlaneRow(c, y + y0_) + x0_;
+  }
+
+  template <typename T>
+  T* MutablePlaneRow(Image3<T>* image, const size_t c, size_t y) const {
+    return image->MutablePlaneRow(c, y + y0_) + x0_;
+  }
+
+  // Returns true if this Rect fully resides in the given image. ImageT could be
+  // Image<T> or Image3<T>; however if ImageT is Rect, results are nonsensical.
+  template <class ImageT>
+  bool IsInside(const ImageT& image) const {
+    return (x0_ + xsize_ <= image.xsize()) && (y0_ + ysize_ <= image.ysize());
+  }
+
+  size_t x0() const { return x0_; }
+  size_t y0() const { return y0_; }
+  size_t xsize() const { return xsize_; }
+  size_t ysize() const { return ysize_; }
+
+ private:
+  // Returns size_max, or whatever is left in [begin, end).
+  static constexpr size_t ClampedSize(size_t begin, size_t size_max,
+                                      size_t end) {
+    return (begin + size_max <= end) ? size_max
+                                     : (end > begin ? end - begin : 0);
+  }
+
+  size_t x0_;
+  size_t y0_;
+
+  size_t xsize_;
+  size_t ysize_;
+};
+
+// Works for any image-like input type(s).
+template <class Image1, class Image2>
+HWY_MAYBE_UNUSED bool SameSize(const Image1& image1, const Image2& image2) {
+  return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize();
+}
+
+// Mirrors out of bounds coordinates and returns valid coordinates unchanged.
+// We assume the radius (distance outside the image) is small compared to the
+// image size, otherwise this might not terminate.
+// The mirror is outside the last column (border pixel is also replicated).
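+// For example, with xsize = 4: x = -2 maps to 1, x = -1 maps to 0, x = 4 maps
+// to 3 and x = 5 maps to 2.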
+static HWY_INLINE HWY_MAYBE_UNUSED size_t Mirror(int64_t x,
+                                                 const int64_t xsize) {
+  HWY_DASSERT(xsize != 0);
+
+  // TODO(janwas): replace with branchless version
+  while (x < 0 || x >= xsize) {
+    if (x < 0) {
+      x = -x - 1;
+    } else {
+      x = 2 * xsize - 1 - x;
+    }
+  }
+  return static_cast<size_t>(x);
+}
+
+// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size):
+
+// Mirrors (repeating the edge pixel once). Useful for convolutions.
+struct WrapMirror {
+  HWY_INLINE size_t operator()(const int64_t coord, const size_t size) const {
+    return Mirror(coord, static_cast<int64_t>(size));
+  }
+};
+
+// Returns the same coordinate, for when we know "coord" is already valid (e.g.
+// interior of an image).
+struct WrapUnchanged {
+  HWY_INLINE size_t operator()(const int64_t coord, size_t /*size*/) const {
+    return static_cast<size_t>(coord);
+  }
+};
+
+// Similar to Wrap* but for row pointers (reduces Row() multiplications).
+
+class WrapRowMirror {
+ public:
+  template <class View>
+  WrapRowMirror(const View& image, size_t ysize)
+      : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {}
+
+  const float* operator()(const float* const HWY_RESTRICT row,
+                          const int64_t stride) const {
+    if (row < first_row_) {
+      const int64_t num_before = first_row_ - row;
+      // Mirrored; one row before => row 0, two before = row 1, ...
+      return first_row_ + num_before - stride;
+    }
+    if (row > last_row_) {
+      const int64_t num_after = row - last_row_;
+      // Mirrored; one row after => last row, two after = last - 1, ...
+      return last_row_ - num_after + stride;
+    }
+    return row;
+  }
+
+ private:
+  const float* const HWY_RESTRICT first_row_;
+  const float* const HWY_RESTRICT last_row_;
+};
+
+struct WrapRowUnchanged {
+  HWY_INLINE const float* operator()(const float* const HWY_RESTRICT row,
+                                     int64_t /*stride*/) const {
+    return row;
+  }
+};
+
+}  // namespace hwy
+
+#endif  // HIGHWAY_CONTRIB_IMAGE_IMAGE_H_
diff --git a/third_party/highway/contrib/image/image_test.cc b/third_party/highway/contrib/image/image_test.cc
new file mode 100644
index 000000000000..0b87260a4d64
--- /dev/null
+++ b/third_party/highway/contrib/image/image_test.cc
@@ -0,0 +1,151 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "contrib/image/image.h"
+
+#include <cstddef>
+
+#include "hwy/base.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "contrib/image/image_test.cc"
+#include "hwy/foreach_target.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+
+#include <random>
+#include <utility>
+
+#include "hwy/highway.h"
+#include "hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+// Ensure we can always write full aligned vectors.
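+// (Padding::kRoundUp guarantees that Store(v, d, row + x) is valid for
+// x = 0, Lanes(d), 2 * Lanes(d), ... even when xsize is not a whole multiple
+// of the vector size.)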
+struct TestAlignedT {
+  template <typename T>
+  void operator()(T /*unused*/) const {
+    std::mt19937 rng(129);
+    std::uniform_int_distribution<int> dist(0, 16);
+    const HWY_FULL(T) d;
+
+    for (size_t ysize = 1; ysize < 4; ++ysize) {
+      for (size_t xsize = 1; xsize < 64; ++xsize) {
+        Image<T> img(xsize, ysize);
+
+        for (size_t y = 0; y < ysize; ++y) {
+          T* HWY_RESTRICT row = img.MutableRow(y);
+          for (size_t x = 0; x < xsize; x += Lanes(d)) {
+            const auto values = Iota(d, dist(rng));
+            Store(values, d, row + x);
+          }
+        }
+
+        // Sanity check to prevent optimizing out the writes
+        const auto x = std::uniform_int_distribution<size_t>(0, xsize - 1)(rng);
+        const auto y = std::uniform_int_distribution<size_t>(0, ysize - 1)(rng);
+        HWY_ASSERT(img.ConstRow(y)[x] < 16 + Lanes(d));
+      }
+    }
+  }
+};
+
+void TestAligned() { ForUnsignedTypes(TestAlignedT()); }
+
+// Ensure we can write an unaligned vector starting at the last valid value.
+struct TestUnalignedT {
+  template <typename T>
+  void operator()(T /*unused*/) const {
+    std::mt19937 rng(129);
+    std::uniform_int_distribution<int> dist(0, 3);
+    const HWY_FULL(T) d;
+
+    for (size_t ysize = 1; ysize < 4; ++ysize) {
+      for (size_t xsize = 1; xsize < 128; ++xsize) {
+        Image<T> img(xsize, ysize);
+        img.InitializePaddingForUnalignedAccesses();
+
+// This test reads padding, which only works if it was initialized,
+// which only happens in MSAN builds.
+#if defined(MEMORY_SANITIZER) || HWY_IDE
+        // Initialize only the valid samples
+        for (size_t y = 0; y < ysize; ++y) {
+          T* HWY_RESTRICT row = img.MutableRow(y);
+          for (size_t x = 0; x < xsize; ++x) {
+            row[x] = 1 << dist(rng);
+          }
+        }
+
+        // Read padding bits
+        auto accum = Zero(d);
+        for (size_t y = 0; y < ysize; ++y) {
+          T* HWY_RESTRICT row = img.MutableRow(y);
+          for (size_t x = 0; x < xsize; ++x) {
+            accum |= LoadU(d, row + x);
+          }
+        }
+
+        // Ensure padding was zero
+        const size_t N = Lanes(d);
+        auto lanes = AllocateAligned<T>(N);
+        Store(accum, d, lanes.get());
+        for (size_t i = 0; i < N; ++i) {
+          HWY_ASSERT(lanes[i] < 16);
+        }
+#else  // Check that writing padding does not overwrite valid samples
+        // Initialize only the valid samples
+        for (size_t y = 0; y < ysize; ++y) {
+          T* HWY_RESTRICT row = img.MutableRow(y);
+          for (size_t x = 0; x < xsize; ++x) {
+            row[x] = static_cast<T>(x);
+          }
+        }
+
+        // Zero padding and rightmost sample
+        for (size_t y = 0; y < ysize; ++y) {
+          T* HWY_RESTRICT row = img.MutableRow(y);
+          StoreU(Zero(d), d, row + xsize - 1);
+        }
+
+        // Ensure no samples except the rightmost were overwritten
+        for (size_t y = 0; y < ysize; ++y) {
+          T* HWY_RESTRICT row = img.MutableRow(y);
+          for (size_t x = 0; x < xsize - 1; ++x) {
+            HWY_ASSERT_EQ(static_cast<T>(x), row[x]);
+          }
+        }
+#endif
+      }
+    }
+  }
+};
+
+void TestUnaligned() { ForUnsignedTypes(TestUnalignedT()); }
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+HWY_BEFORE_TEST(ImageTest);
+HWY_EXPORT_AND_TEST_P(ImageTest, TestAligned);
+HWY_EXPORT_AND_TEST_P(ImageTest, TestUnaligned);
+}  // namespace hwy
+#endif
diff --git a/third_party/highway/contrib/math/math-inl.h b/third_party/highway/contrib/math/math-inl.h
new file mode 100644
index 000000000000..a3b269e486ff
--- /dev/null
+++ b/third_party/highway/contrib/math/math-inl.h
@@ -0,0 +1,1192 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_CONTRIB_MATH_MATH_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_CONTRIB_MATH_MATH_INL_H_ +#undef HIGHWAY_CONTRIB_MATH_MATH_INL_H_ +#else +#define HIGHWAY_CONTRIB_MATH_MATH_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +/** + * Highway SIMD version of std::acos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc cosine of 'x' + */ +template +HWY_INLINE V Acos(const D d, V x); +template +HWY_NOINLINE V CallAcos(const D d, V x) { + return Acos(d, x); +} + +/** + * Highway SIMD version of std::acosh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[1, +FLT_MAX], float64[1, +DBL_MAX] + * @return hyperbolic arc cosine of 'x' + */ +template +HWY_INLINE V Acosh(const D d, V x); +template +HWY_NOINLINE V CallAcosh(const D d, V x) { + return Acosh(d, x); +} + +/** + * Highway SIMD version of std::asin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc sine of 'x' + */ +template +HWY_INLINE V Asin(const D d, V x); +template +HWY_NOINLINE V CallAsin(const D d, V x) { + return Asin(d, x); +} + +/** + * Highway SIMD version of std::asinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic arc sine of 'x' + */ +template +HWY_INLINE V Asinh(const D d, V x); +template +HWY_NOINLINE V CallAsinh(const D d, V x) { + return Asinh(d, x); +} + +/** + * Highway SIMD version of std::atan(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return arc tangent of 'x' + */ +template +HWY_INLINE V Atan(const D d, V x); +template +HWY_NOINLINE V CallAtan(const D d, V x) { + return Atan(d, x); +} + +/** + * Highway SIMD version of std::atanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: (-1, +1) + * @return hyperbolic arc tangent of 'x' + */ +template +HWY_INLINE V Atanh(const D d, V x); +template +HWY_NOINLINE V CallAtanh(const D d, V x) { + return Atanh(d, x); +} + +/** + * Highway SIMD version of std::cos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return cosine of 'x' + */ +template +HWY_INLINE V Cos(const D d, V x); +template +HWY_NOINLINE V CallCos(const D d, V x) { + return Cos(d, x); +} + +/** + * Highway SIMD version of std::exp(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 1 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x + */ +template +HWY_INLINE V Exp(const D d, V x); +template +HWY_NOINLINE V CallExp(const D d, V x) { + return Exp(d, x); +} + +/** + * Highway SIMD version of std::expm1(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x - 1 + */ +template +HWY_INLINE V Expm1(const D d, V x); +template +HWY_NOINLINE V CallExpm1(const D d, V x) { + return Expm1(d, x); +} + +/** + * Highway SIMD version of std::log(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return natural logarithm of 'x' + */ +template +HWY_INLINE V Log(const D d, V x); +template +HWY_NOINLINE V CallLog(const D d, V x) { + return Log(d, x); +} + +/** + * Highway SIMD version of std::log10(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 10 logarithm of 'x' + */ +template +HWY_INLINE V Log10(const D d, V x); +template +HWY_NOINLINE V CallLog10(const D d, V x) { + return Log10(d, x); +} + +/** + * Highway SIMD version of std::log1p(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[0, +FLT_MAX], float64[0, +DBL_MAX] + * @return log(1 + x) + */ +template +HWY_INLINE V Log1p(const D d, V x); +template +HWY_NOINLINE V CallLog1p(const D d, V x) { + return Log1p(d, x); +} + +/** + * Highway SIMD version of std::log2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 2 logarithm of 'x' + */ +template +HWY_INLINE V Log2(const D d, V x); +template +HWY_NOINLINE V CallLog2(const D d, V x) { + return Log2(d, x); +} + +/** + * Highway SIMD version of std::sin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return sine of 'x' + */ +template +HWY_INLINE V Sin(const D d, V x); +template +HWY_NOINLINE V CallSin(const D d, V x) { + return Sin(d, x); +} + +/** + * Highway SIMD version of std::sinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-88.7228, +88.7228], float64[-709, +709] + * @return hyperbolic sine of 'x' + */ +template +HWY_INLINE V Sinh(const D d, V x); +template +HWY_NOINLINE V CallSinh(const D d, V x) { + return Sinh(d, x); +} + +/** + * Highway SIMD version of std::tanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic tangent of 'x' + */ +template +HWY_INLINE V Tanh(const D d, V x); +template +HWY_NOINLINE V CallTanh(const D d, V x) { + return Tanh(d, x); +} + +//////////////////////////////////////////////////////////////////////////////// +// Implementation +//////////////////////////////////////////////////////////////////////////////// +namespace impl { + +// Estrin's Scheme is a faster method for evaluating large polynomials on +// super scalar architectures. It works by factoring the Horner's Method +// polynomial into power of two sub-trees that can be evaluated in parallel. 
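+// For example, Estrin(x, c0, c1, c2, c3) evaluates
+// (c0 + c1*x) + x2*(c2 + c3*x) with x2 = x*x, so the two inner MulAdds are
+// independent and can issue in parallel, unlike a serial Horner chain.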
+// Wikipedia Link: https://en.wikipedia.org/wiki/Estrin%27s_scheme +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1) { + return MulAdd(c1, x, c0); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2) { + T x2(x * x); + return MulAdd(x2, c2, MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3) { + T x2(x * x); + return MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4) { + T x2(x * x), x4(x2 * x2); + return MulAdd(x4, c4, MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5) { + T x2(x * x), x4(x2 * x2); + return MulAdd(x4, MulAdd(c5, x, c4), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6) { + T x2(x * x), x4(x2 * x2); + return MulAdd(x4, MulAdd(x2, c6, MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7) { + T x2(x * x), x4(x2 * x2); + return MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd(x8, c8, + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd(x8, MulAdd(c9, x, c8), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd(x8, MulAdd(x2, c10, MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd(x8, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd( + x8, MulAdd(x4, c12, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(c13, x, c12), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T 
c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, c14, MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4), x16(x8 * x8); + return MulAdd( + x16, c16, + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4), x16(x8 * x8); + return MulAdd( + x16, MulAdd(c17, x, c16), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17, + T c18) { + T x2(x * x), x4(x2 * x2), x8(x4 * x4), x16(x8 * x8); + return MulAdd( + x16, MulAdd(x2, c18, MulAdd(c17, x, c16)), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} + +template +struct AsinImpl {}; +template +struct AtanImpl {}; +template +struct CosSinImpl {}; +template +struct ExpImpl {}; +template +struct LogImpl {}; + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666677296f); + const auto k1 = Set(d, +0.07495029271f); + const auto k2 = Set(d, +0.04547423869f); + const auto k3 = Set(d, +0.02424046025f); + const auto k4 = Set(d, +0.04197454825f); + + return Estrin(x2, k0, k1, k2, k3, k4); + } +}; + +#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). 
+ template + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666666666666497543); + const auto k1 = Set(d, +0.07500000000378581611); + const auto k2 = Set(d, +0.04464285681377102438); + const auto k3 = Set(d, +0.03038195928038132237); + const auto k4 = Set(d, +0.02237176181932048341); + const auto k5 = Set(d, +0.01735956991223614604); + const auto k6 = Set(d, +0.01388715184501609218); + const auto k7 = Set(d, +0.01215360525577377331); + const auto k8 = Set(d, +0.006606077476277170610); + const auto k9 = Set(d, +0.01929045477267910674); + const auto k10 = Set(d, -0.01581918243329996643); + const auto k11 = Set(d, +0.03161587650653934628); + + return Estrin(x2, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11); + } +}; + +#endif + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333331018686294555664062f); + const auto k1 = Set(d, +0.199926957488059997558594f); + const auto k2 = Set(d, -0.142027363181114196777344f); + const auto k3 = Set(d, +0.106347933411598205566406f); + const auto k4 = Set(d, -0.0748900920152664184570312f); + const auto k5 = Set(d, +0.0425049886107444763183594f); + const auto k6 = Set(d, -0.0159569028764963150024414f); + const auto k7 = Set(d, +0.00282363896258175373077393f); + + const auto y = (x * x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7), (y * x), x); + } +}; + +#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333333333333311110369124); + const auto k1 = Set(d, +0.199999999996591265594148); + const auto k2 = Set(d, -0.14285714266771329383765); + const auto k3 = Set(d, +0.111111105648261418443745); + const auto k4 = Set(d, -0.090908995008245008229153); + const auto k5 = Set(d, +0.0769219538311769618355029); + const auto k6 = Set(d, -0.0666573579361080525984562); + const auto k7 = Set(d, +0.0587666392926673580854313); + const auto k8 = Set(d, -0.0523674852303482457616113); + const auto k9 = Set(d, +0.0466667150077840625632675); + const auto k10 = Set(d, -0.0407629191276836500001934); + const auto k11 = Set(d, +0.0337852580001353069993897); + const auto k12 = Set(d, -0.0254517624932312641616861); + const auto k13 = Set(d, +0.016599329773529201970117); + const auto k14 = Set(d, -0.00889896195887655491740809); + const auto k15 = Set(d, +0.00370026744188713119232403); + const auto k16 = Set(d, -0.00110611831486672482563471); + const auto k17 = Set(d, +0.000209850076645816976906797); + const auto k18 = Set(d, -1.88796008463073496563746e-5); + + const auto y = (x * x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, + k12, k13, k14, k15, k16, k17, k18), + (y * x), x); + } +}; + +#endif + +template <> +struct CosSinImpl { + // Rounds float toward zero and returns as int32_t. 
+ template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + template + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -1.66666597127914428710938e-1f); + const auto k1 = Set(d, +8.33307858556509017944336e-3f); + const auto k2 = Set(d, -1.981069071916863322258e-4f); + const auto k3 = Set(d, +2.6083159809786593541503e-6f); + + const auto y(x * x); + return MulAdd(Estrin(y, k0, k1, k2, k3), (y * x), x); + } + + template + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0f + kHalfPiPart1f + kHalfPiPart2f + kHalfPiPart3f ~= -pi/2 + const V kHalfPiPart0f = Set(d, -0.5f * 3.140625f); + const V kHalfPiPart1f = Set(d, -0.5f * 0.0009670257568359375f); + const V kHalfPiPart2f = Set(d, -0.5f * 6.2771141529083251953e-7f); + const V kHalfPiPart3f = Set(d, -0.5f * 1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kHalfPiPart0f, x); + x = MulAdd(qf, kHalfPiPart1f, x); + x = MulAdd(qf, kHalfPiPart2f, x); + x = MulAdd(qf, kHalfPiPart3f, x); + return x; + } + + template + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0f + kPiPart1f + kPiPart2f + kPiPart3f ~= -pi + const V kPiPart0f = Set(d, -3.140625f); + const V kPiPart1f = Set(d, -0.0009670257568359375f); + const V kPiPart2f = Set(d, -6.2771141529083251953e-7f); + const V kPiPart3f = Set(d, -1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kPiPart0f, x); + x = MulAdd(qf, kPiPart1f, x); + x = MulAdd(qf, kPiPart2f, x); + x = MulAdd(qf, kPiPart3f, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast(d, ShiftLeft<30>(AndNot(q, kTwo))); + } + + // ((q & 1) ? -0.0 : +0.0) + template + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast(d, ShiftLeft<31>(And(q, kOne))); + } +}; + +#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 + +template <> +struct CosSinImpl { + // Rounds double toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + template + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -0.166666666666666657414808); + const auto k1 = Set(d, +0.00833333333333332974823815); + const auto k2 = Set(d, -0.000198412698412696162806809); + const auto k3 = Set(d, +2.75573192239198747630416e-6); + const auto k4 = Set(d, -2.50521083763502045810755e-8); + const auto k5 = Set(d, +1.60590430605664501629054e-10); + const auto k6 = Set(d, -7.64712219118158833288484e-13); + const auto k7 = Set(d, +2.81009972710863200091251e-15); + const auto k8 = Set(d, -7.97255955009037868891952e-18); + + const auto y(x * x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8), (y * x), x); + } + + template + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0d + kHalfPiPart1d + kHalfPiPart2d + kHalfPiPart3d ~= -pi/2 + const V kHalfPiPart0d = Set(d, -0.5 * 3.1415926218032836914); + const V kHalfPiPart1d = Set(d, -0.5 * 3.1786509424591713469e-8); + const V kHalfPiPart2d = Set(d, -0.5 * 1.2246467864107188502e-16); + const V kHalfPiPart3d = Set(d, -0.5 * 1.2736634327021899816e-24); + + // Extended precision modular arithmetic. 
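+    // (The pi/2 constant is split into parts of decreasing magnitude so that
+    // each q * part product is exact - Cody-Waite style argument reduction.)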
+ const V qf = PromoteTo(d, q); + x = MulAdd(qf, kHalfPiPart0d, x); + x = MulAdd(qf, kHalfPiPart1d, x); + x = MulAdd(qf, kHalfPiPart2d, x); + x = MulAdd(qf, kHalfPiPart3d, x); + return x; + } + + template + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0d + kPiPart1d + kPiPart2d + kPiPart3d ~= -pi + const V kPiPart0d = Set(d, -3.1415926218032836914); + const V kPiPart1d = Set(d, -3.1786509424591713469e-8); + const V kPiPart2d = Set(d, -1.2246467864107188502e-16); + const V kPiPart3d = Set(d, -1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kPiPart0d, x); + x = MulAdd(qf, kPiPart1d, x); + x = MulAdd(qf, kPiPart2d, x); + x = MulAdd(qf, kPiPart3d, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast( + d, ShiftLeft<62>(PromoteTo(Rebind(), AndNot(q, kTwo)))); + } + + // ((q & 1) ? -0.0 : +0.0) + template + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast( + d, ShiftLeft<63>(PromoteTo(Rebind(), And(q, kOne)))); + } +}; + +#endif + +template <> +struct ExpImpl { + // Rounds float toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + template + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5f); + const auto k1 = Set(d, +0.166666671633720397949219f); + const auto k2 = Set(d, +0.0416664853692054748535156f); + const auto k3 = Set(d, +0.00833336077630519866943359f); + const auto k4 = Set(d, +0.00139304355252534151077271f); + const auto k5 = Set(d, +0.000198527617612853646278381f); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5), (x * x), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const VI32 kOffset = Set(di32, 0x7F); + return BitCast(d, ShiftLeft<23>(x + kOffset)); + } + + // Sets the exponent of 'x' to 2^e. + template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return x * Pow2I(d, y) * Pow2I(d, e - y); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0f + kLn2Part1f ~= -ln(2) + const V kLn2Part0f = Set(d, -0.693145751953125f); + const V kLn2Part1f = Set(d, -1.428606765330187045e-6f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kLn2Part0f, x); + x = MulAdd(qf, kLn2Part1f, x); + return x; + } +}; + +template <> +struct LogImpl { + template + HWY_INLINE Vec> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind di32; + const Rebind du32; + return BitCast(di32, ShiftRight<23>(BitCast(du32, x))) - Set(di32, 0x7F); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.66666662693f); + const V k1 = Set(d, 0.40000972152f); + const V k2 = Set(d, 0.28498786688f); + const V k3 = Set(d, 0.24279078841f); + + const V x2 = (x * x); + const V x4 = (x2 * x2); + return MulAdd(MulAdd(k2, x4, k0), x2, (MulAdd(k3, x4, k1) * x4)); + } +}; + +#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 +template <> +struct ExpImpl { + // Rounds double toward zero and returns as int32_t. 
+ template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + template + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5); + const auto k1 = Set(d, +0.166666666666666851703837); + const auto k2 = Set(d, +0.0416666666666665047591422); + const auto k3 = Set(d, +0.00833333333331652721664984); + const auto k4 = Set(d, +0.00138888888889774492207962); + const auto k5 = Set(d, +0.000198412698960509205564975); + const auto k6 = Set(d, +2.4801587159235472998791e-5); + const auto k7 = Set(d, +2.75572362911928827629423e-6); + const auto k8 = Set(d, +2.75573911234900471893338e-7); + const auto k9 = Set(d, +2.51112930892876518610661e-8); + const auto k10 = Set(d, +2.08860621107283687536341e-9); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10), + (x * x), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const Rebind di64; + const VI32 kOffset = Set(di32, 0x3FF); + return BitCast(d, ShiftLeft<52>(PromoteTo(di64, x + kOffset))); + } + + // Sets the exponent of 'x' to 2^e. + template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return (x * Pow2I(d, y) * Pow2I(d, e - y)); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0d + kLn2Part1d ~= -ln(2) + const V kLn2Part0d = Set(d, -0.6931471805596629565116018); + const V kLn2Part1d = Set(d, -0.28235290563031577122588448175e-12); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kLn2Part0d, x); + x = MulAdd(qf, kLn2Part1d, x); + return x; + } +}; + +template <> +struct LogImpl { + template + HWY_INLINE Vec> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind di64; + const Rebind du64; + return BitCast(di64, ShiftRight<52>(BitCast(du64, x))) - Set(di64, 0x3FF); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.6666666666666735130); + const V k1 = Set(d, 0.3999999999940941908); + const V k2 = Set(d, 0.2857142874366239149); + const V k3 = Set(d, 0.2222219843214978396); + const V k4 = Set(d, 0.1818357216161805012); + const V k5 = Set(d, 0.1531383769920937332); + const V k6 = Set(d, 0.1479819860511658591); + + const V x2 = (x * x); + const V x4 = (x2 * x2); + return MulAdd(MulAdd(MulAdd(MulAdd(k6, x4, k4), x4, k2), x4, k0), x2, + (MulAdd(MulAdd(k5, x4, k3), x4, k1) * x4)); + } +}; + +#endif + +template +HWY_INLINE V Log(const D d, V x) { + // http://git.musl-libc.org/cgit/musl/tree/src/math/log.c for more info. + using LaneType = LaneType; + impl::LogImpl impl; + + // clang-format off + constexpr bool kIsF32 = (sizeof(LaneType) == 4); + + // Float Constants + const V kLn2Hi = Set(d, (kIsF32 ? 0.69313812256f : + 0.693147180369123816490 )); + const V kLn2Lo = Set(d, (kIsF32 ? 9.0580006145e-6f : + 1.90821492927058770002e-10)); + const V kOne = Set(d, +1.0); + const V kMinNormal = Set(d, (kIsF32 ? 1.175494351e-38f : + 2.2250738585072014e-308 )); + const V kScale = Set(d, (kIsF32 ? 3.355443200e+7f : + 1.8014398509481984e+16 )); + + // Integer Constants + const Rebind, D> di; + using VI = decltype(Zero(di)); + const VI kLowerBits = Set(di, (kIsF32 ? 0x00000000L : 0xFFFFFFFFLL)); + const VI kMagic = Set(di, (kIsF32 ? 0x3F3504F3L : 0x3FE6A09E00000000LL)); + const VI kExpMask = Set(di, (kIsF32 ? 0x3F800000L : 0x3FF0000000000000LL)); + const VI kExpScale = Set(di, (kIsF32 ? -25 : -54)); + const VI kManMask = Set(di, (kIsF32 ? 
0x7FFFFFL : 0xFFFFF00000000LL)); + // clang-format on + + // Scale up 'x' so that it is no longer denormalized. + VI exp_bits; + V exp; + if (kAllowSubnormals == true) { + const auto is_denormal = (x < kMinNormal); + x = IfThenElse(is_denormal, (x * kScale), x); + + // Compute the new exponent. + exp_bits = (BitCast(di, x) + (kExpMask - kMagic)); + const VI exp_scale = + BitCast(di, IfThenElseZero(is_denormal, BitCast(d, kExpScale))); + exp = ConvertTo( + d, exp_scale + impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits))); + } else { + // Compute the new exponent. + exp_bits = (BitCast(di, x) + (kExpMask - kMagic)); + exp = ConvertTo(d, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits))); + } + + // Renormalize. + const V y = Or(And(x, BitCast(d, kLowerBits)), + BitCast(d, ((exp_bits & kManMask) + kMagic))); + + // Approximate and reconstruct. + const V ym1 = (y - kOne); + const V z = (ym1 / (y + kOne)); + + return MulSub(exp, kLn2Hi, + (MulSub(z, (ym1 - impl.LogPoly(d, z)), (exp * kLn2Lo)) - ym1)); +} + +} // namespace impl + +template +HWY_NOINLINE V Acos(const D d, V x) { + using LaneType = LaneType; + + const V kZero = Zero(d); + const V kHalf = Set(d, +0.5); + const V kOne = Set(d, +1.0); + const V kTwo = Set(d, +2.0); + const V kPi = Set(d, +3.14159265358979323846264); + const V kPiOverTwo = Set(d, +1.57079632679489661923132169); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = (abs_x < kHalf); + const V yy = IfThenElse(mask, (abs_x * abs_x), ((kOne - abs_x) * kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V t = (impl.AsinPoly(d, yy, y) * (y * yy)); + const V z = IfThenElse(mask, (kPiOverTwo - (Xor(y, sign_x) + Xor(t, sign_x))), + ((t + y) * kTwo)); + return IfThenElse(Or(mask, (x >= kZero)), z, (kPi - z)); +} + +template +HWY_NOINLINE V Acosh(const D d, V x) { + const V kLarge = Set(d, 268435456.0); + const V kLog2 = Set(d, 0.693147180559945286227); + const V kOne = Set(d, +1.0); + const V kTwo = Set(d, +2.0); + + const auto is_x_large = (x > kLarge); + const auto is_x_gt_2 = (x > kTwo); + + const V x_minus_1 = (x - kOne); + const V y0 = MulSub(kTwo, x, (kOne / (Sqrt(MulSub(x, x, kOne)) + x))); + const V y1 = + (Sqrt(MulAdd(x_minus_1, kTwo, (x_minus_1 * x_minus_1))) + x_minus_1); + const V y2 = + IfThenElse(is_x_gt_2, IfThenElse(is_x_large, x, y0), (y1 + kOne)); + const V z = impl::Log(d, y2); + + const auto is_pole = y2 == kOne; + const auto divisor = IfThenZeroElse(is_pole, y2) - kOne; + return IfThenElse(is_x_gt_2, z, IfThenElse(is_pole, y1, z * y1 / divisor)) + + IfThenElseZero(is_x_large, kLog2); +} + +template +HWY_NOINLINE V Asin(const D d, V x) { + using LaneType = LaneType; + + const V kHalf = Set(d, +0.5); + const V kOne = Set(d, +1.0); + const V kTwo = Set(d, +2.0); + const V kPiOverTwo = Set(d, +1.57079632679489661923132169); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = (abs_x < kHalf); + const V yy = IfThenElse(mask, (abs_x * abs_x), (kOne - abs_x) * kHalf); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V z0 = MulAdd(impl.AsinPoly(d, yy, y), (yy * y), y); + const V z1 = (kPiOverTwo - (z0 * kTwo)); + return Or(IfThenElse(mask, z0, z1), sign_x); +} + +template +HWY_NOINLINE V Asinh(const D d, V x) { + const V kSmall = Set(d, 1.0 / 268435456.0); + const V kLarge = Set(d, 268435456.0); + const V kLog2 = Set(d, 0.693147180559945286227); + const V kOne = Set(d, +1.0); + const V kTwo = Set(d, +2.0); + + const V 
sign_x = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign_x); + + const auto is_x_large = (abs_x > kLarge); + const auto is_x_lt_2 = (abs_x < kTwo); + + const V x2 = (x * x); + const V sqrt_x2_plus_1 = Sqrt(x2 + kOne); + + const V y0 = MulAdd(abs_x, kTwo, (kOne / (sqrt_x2_plus_1 + abs_x))); + const V y1 = ((x2 / (sqrt_x2_plus_1 + kOne)) + abs_x); + const V y2 = + IfThenElse(is_x_lt_2, (y1 + kOne), IfThenElse(is_x_large, abs_x, y0)); + const V z = impl::Log(d, y2); + + const auto is_pole = y2 == kOne; + const auto divisor = IfThenZeroElse(is_pole, y2) - kOne; + const auto large = IfThenElse(is_pole, y1, z * y1 / divisor); + const V y = IfThenElse(abs_x < kSmall, x, large); + return Or((IfThenElse(is_x_lt_2, y, z) + IfThenElseZero(is_x_large, kLog2)), + sign_x); +} + +template +HWY_NOINLINE V Atan(const D d, V x) { + using LaneType = LaneType; + + const V kOne = Set(d, +1.0); + const V kPiOverTwo = Set(d, +1.57079632679489661923132169); + + const V sign = And(SignBit(d), x); + const V abs_x = Xor(x, sign); + const auto mask = (abs_x > kOne); + + impl::AtanImpl impl; + const auto divisor = IfThenElse(mask, abs_x, kOne); + const V y = impl.AtanPoly(d, IfThenElse(mask, kOne / divisor, abs_x)); + return Or(IfThenElse(mask, (kPiOverTwo - y), y), sign); +} + +template +HWY_NOINLINE V Atanh(const D d, V x) { + const V kHalf = Set(d, +0.5); + const V kOne = Set(d, +1.0); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + return Log1p(d, ((abs_x + abs_x) / (kOne - abs_x))) * Xor(kHalf, sign); +} + +template +HWY_NOINLINE V Cos(const D d, V x) { + using LaneType = LaneType; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, 0.31830988618379067153); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + const VI32 kOne = Set(di32, 1); + + const V y = Abs(x); // cos(x) == cos(|x|) + + // Compute the quadrant, q = int(|x| / pi) * 2 + 1 + const VI32 q = (ShiftLeft<1>(impl.ToInt32(d, y * kOneOverPi)) + kOne); + + // Reduce range, apply sign, and approximate. + return impl.Poly( + d, Xor(impl.CosReduce(d, y, q), impl.CosSignFromQuadrant(d, q))); +} + +template +HWY_NOINLINE V Exp(const D d, V x) { + using LaneType = LaneType; + + // clang-format off + const V kHalf = Set(d, +0.5); + const V kLowerBound = Set(d, (sizeof(LaneType) == 4 ? -104.0 : -1000.0)); + const V kNegZero = Set(d, -0.0); + const V kOne = Set(d, +1.0); + const V kOneOverLog2 = Set(d, +1.442695040888963407359924681); + // clang-format on + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.LoadExpShortRange( + d, (impl.ExpPoly(d, impl.ExpReduce(d, x, q)) + kOne), q); + return IfThenElseZero(x >= kLowerBound, y); +} + +template +HWY_NOINLINE V Expm1(const D d, V x) { + using LaneType = LaneType; + + // clang-format off + const V kHalf = Set(d, +0.5); + const V kLowerBound = Set(d, (sizeof(LaneType) == 4 ? -104.0 : -1000.0)); + const V kLn2Over2 = Set(d, +0.346573590279972654708616); + const V kNegOne = Set(d, -1.0); + const V kNegZero = Set(d, -0.0); + const V kOne = Set(d, +1.0); + const V kOneOverLog2 = Set(d, +1.442695040888963407359924681); + // clang-format on + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? 
-0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.ExpPoly(d, impl.ExpReduce(d, x, q)); + const V z = IfThenElse(Abs(x) < kLn2Over2, y, + impl.LoadExpShortRange(d, (y + kOne), q) - kOne); + return IfThenElse(x < kLowerBound, kNegOne, z); +} + +template +HWY_NOINLINE V Log(const D d, V x) { + return impl::Log(d, x); +} + +template +HWY_NOINLINE V Log10(const D d, V x) { + return Log(d, x) * Set(d, 0.4342944819032518276511); +} + +template +HWY_NOINLINE V Log1p(const D d, V x) { + const V kOne = Set(d, +1.0); + + const V y = x + kOne; + const auto is_pole = y == kOne; + const auto divisor = IfThenZeroElse(is_pole, y) - kOne; + const auto non_pole = + impl::Log(d, y) * (x / divisor); + return IfThenElse(is_pole, x, non_pole); +} + +template +HWY_NOINLINE V Log2(const D d, V x) { + return Log(d, x) * Set(d, 1.44269504088896340735992); +} + +template +HWY_NOINLINE V Sin(const D d, V x) { + using LaneType = LaneType; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, 0.31830988618379067153); + const V kHalf = Set(d, 0.5); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + + const V abs_x = Abs(x); + const V sign_x = Xor(abs_x, x); + + // Compute the quadrant, q = int((|x| / pi) + 0.5) + const VI32 q = impl.ToInt32(d, MulAdd(abs_x, kOneOverPi, kHalf)); + + // Reduce range, apply sign, and approximate. + return impl.Poly(d, Xor(impl.SinReduce(d, abs_x, q), + Xor(impl.SinSignFromQuadrant(d, q), sign_x))); +} + +template +HWY_NOINLINE V Sinh(const D d, V x) { + const V kHalf = Set(d, +0.5); + const V kOne = Set(d, +1.0); + const V kTwo = Set(d, +2.0); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, abs_x); + const V z = ((y + kTwo) / (y + kOne) * (y * kHalf)); + return Xor(z, sign); // Reapply the sign bit +} + +template +HWY_NOINLINE V Tanh(const D d, V x) { + const V kLimit = Set(d, 18.714973875); + const V kOne = Set(d, +1.0); + const V kTwo = Set(d, +2.0); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, abs_x * kTwo); + const V z = IfThenElse((abs_x > kLimit), kOne, (y / (y + kTwo))); + return Xor(z, sign); // Reapply the sign bit +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_CONTRIB_MATH_MATH_INL_H_ diff --git a/third_party/highway/contrib/math/math_test.cc b/third_party/highway/contrib/math/math_test.cc new file mode 100644 index 000000000000..5cede6828542 --- /dev/null +++ b/third_party/highway/contrib/math/math_test.cc @@ -0,0 +1,188 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <cfloat>  // FLT_MAX
+#include <iostream>
+#include <string>
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "contrib/math/math_test.cc"
+#include "hwy/foreach_target.h"
+
+#include "contrib/math/math-inl.h"
+#include "hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+template <class T, class D>
+void TestMath(const std::string name, T (*fx1)(T), Vec<D> (*fxN)(D, Vec<D>),
+              D d, T min, T max, uint64_t max_error_ulp) {
+  constexpr bool kIsF32 = (sizeof(T) == 4);
+  using UintT = MakeUnsigned<T>;
+
+  const UintT min_bits = BitCast<UintT>(min);
+  const UintT max_bits = BitCast<UintT>(max);
+
+  // If min is negative and max is positive, the range needs to be broken into
+  // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+  int range_count = 1;
+  UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+  if ((min < 0.0) && (max > 0.0)) {
+    ranges[0][0] = BitCast<UintT>(static_cast<T>(+0.0));
+    ranges[0][1] = max_bits;
+    ranges[1][0] = BitCast<UintT>(static_cast<T>(-0.0));
+    ranges[1][1] = min_bits;
+    range_count = 2;
+  }
+
+  uint64_t max_ulp = 0;
+#if HWY_ARCH_ARM
+  // Emulation is slower, so cannot afford as many.
+  constexpr UintT kSamplesPerRange = 25000;
+#else
+  constexpr UintT kSamplesPerRange = 100000;
+#endif
+  for (int range_index = 0; range_index < range_count; ++range_index) {
+    const UintT start = ranges[range_index][0];
+    const UintT stop = ranges[range_index][1];
+    const UintT step = std::max<UintT>(1, ((stop - start) / kSamplesPerRange));
+    for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+      const T value = BitCast<T>(std::min(value_bits, stop));
+      const T actual = GetLane(fxN(d, Set(d, value)));
+      const T expected = fx1(value);
+
+      // Skip small inputs and outputs on armv7, it flushes subnormals to zero.
+#if HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7
+      if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+        continue;
+      }
+#endif
+
+      const auto ulp = ComputeUlpDelta(actual, expected);
+      max_ulp = std::max(max_ulp, ulp);
+      if (ulp > max_error_ulp) {
+        std::cout << name << "<" << (kIsF32 ? "F32x" : "F64x") << Lanes(d)
+                  << ">(" << value << ") expected: " << expected
+                  << " actual: " << actual << std::endl;
+      }
+      HWY_ASSERT(ulp <= max_error_ulp);
+    }
+  }
+  std::cout << (kIsF32 ? "F32x" : "F64x") << Lanes(d)
+            << ", Max ULP: " << max_ulp << std::endl;
+}
+
+#define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \
+                         F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR)       \
+  struct Test##NAME {                                                     \
+    template <typename T, class D>                                        \
+    HWY_NOINLINE void operator()(T, D d) {                                \
+      if (sizeof(T) == 4) {                                               \
+        TestMath<T, D>(HWY_STR(NAME), F32x1, F32xN, d, F32_MIN, F32_MAX,  \
+                       F32_ERROR);                                        \
+      } else {                                                            \
+        TestMath<T, D>(HWY_STR(NAME), F64x1, F64xN, d, F64_MIN, F64_MAX,  \
+                       F64_ERROR);                                        \
+      }                                                                   \
+    }                                                                     \
+  };                                                                      \
+  HWY_NOINLINE void TestAll##NAME() {                                     \
+    ForFloatTypes(ForPartialVectors<Test##NAME>());                       \
+  }
+
+// Floating point values closest to but less than 1.0
+const float kNearOneF = BitCast<float>(0x3F7FFFFF);
+const double kNearOneD = BitCast<double>(0x3FEFFFFFFFFFFFFFULL);
+
+// clang-format off
+DEFINE_MATH_TEST(Acos,
+  std::acos, CallAcos, -1.0, +1.0, 3,  // NEON is 3 instead of 2
+  std::acos, CallAcos, -1.0, +1.0, 2)
+DEFINE_MATH_TEST(Acosh,
+  std::acosh, CallAcosh, +1.0, +FLT_MAX, 3,
+  std::acosh, CallAcosh, +1.0, +DBL_MAX, 3)
+DEFINE_MATH_TEST(Asin,
+  std::asin, CallAsin, -1.0, +1.0, 3,  // NEON is 3 instead of 2
+  std::asin, CallAsin, -1.0, +1.0, 2)
+DEFINE_MATH_TEST(Asinh,
+  std::asinh, CallAsinh, -FLT_MAX, +FLT_MAX, 3,
+  std::asinh, CallAsinh, -DBL_MAX, +DBL_MAX, 3)
+DEFINE_MATH_TEST(Atan,
+  std::atan, CallAtan, -FLT_MAX, +FLT_MAX, 3,
+  std::atan, CallAtan, -DBL_MAX, +DBL_MAX, 3)
+DEFINE_MATH_TEST(Atanh,
+  std::atanh, CallAtanh, -kNearOneF, +kNearOneF, 4,  // NEON is 4 instead of 3
+  std::atanh, CallAtanh, -kNearOneD, +kNearOneD, 3)
+DEFINE_MATH_TEST(Cos,
+  std::cos, CallCos, -39000.0, +39000.0, 3,
+  std::cos, CallCos, -39000.0, +39000.0, 3)
+DEFINE_MATH_TEST(Exp,
+  std::exp, CallExp, -FLT_MAX, +104.0, 1,
+  std::exp, CallExp, -DBL_MAX, +104.0, 1)
+DEFINE_MATH_TEST(Expm1,
+  std::expm1, CallExpm1, -FLT_MAX, +104.0, 4,
+  std::expm1, CallExpm1, -DBL_MAX, +104.0, 4)
+DEFINE_MATH_TEST(Log,
+  std::log, CallLog, +FLT_MIN, +FLT_MAX, 1,
+  std::log, CallLog, +DBL_MIN, +DBL_MAX, 1)
+DEFINE_MATH_TEST(Log10,
+  std::log10, CallLog10, +FLT_MIN, +FLT_MAX, 2,
+  std::log10, CallLog10, +DBL_MIN, +DBL_MAX, 2)
+DEFINE_MATH_TEST(Log1p,
+  std::log1p, CallLog1p, +0.0f, +1e37, 3,  // NEON is 3 instead of 2
+  std::log1p, CallLog1p, +0.0, +DBL_MAX, 2)
+DEFINE_MATH_TEST(Log2,
+  std::log2, CallLog2, +FLT_MIN, +FLT_MAX, 2,
+  std::log2, CallLog2, +DBL_MIN, +DBL_MAX, 2)
+DEFINE_MATH_TEST(Sin,
+  std::sin, CallSin, -39000.0, +39000.0, 3,
+  std::sin, CallSin, -39000.0, +39000.0, 3)
+DEFINE_MATH_TEST(Sinh,
+  std::sinh, CallSinh, -80.0f, +80.0f, 4,
+  std::sinh, CallSinh, -709.0, +709.0, 4)
+DEFINE_MATH_TEST(Tanh,
+  std::tanh, CallTanh, -FLT_MAX, +FLT_MAX, 4,
+  std::tanh, CallTanh, -DBL_MAX, +DBL_MAX, 4)
+// clang-format on
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+HWY_BEFORE_TEST(HwyMathTest);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcos);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcosh);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAsin);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAsinh);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtan);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtanh);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllCos);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExpm1);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog1p);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog2);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllSin);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllSinh);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllTanh);
+}  // namespace hwy
+#endif
diff --git a/third_party/highway/debian/changelog b/third_party/highway/debian/changelog
new file mode 100644
index 000000000000..ea71560b3f4e
--- /dev/null
+++ b/third_party/highway/debian/changelog
@@ -0,0 +1,31 @@
+highway (0.12.0-1) UNRELEASED; urgency=medium
+
+  * Add Shift*8, Compress16, emulated Scatter/Gather, StoreInterleaved3/4
+  * Remove deprecated HWY_*_LANES, deprecate HWY_GATHER_LANES
+  * Proper IEEE rounding, reduce libstdc++ usage, inlined math
+
+ -- Jan Wassenberg  Thu, 15 Apr 2021 20:00:00 +0200
+
+highway (0.11.1-1) UNRELEASED; urgency=medium
+
+  * Fix clang7 asan error, finish f16 conversions and add test
+
+ -- Jan Wassenberg  Thu, 25 Feb 2021 16:00:00 +0200
+
+highway (0.11.0-1) UNRELEASED; urgency=medium
+
+  * Add RVV+mask logical ops, allow Shl/ShiftLeftSame on all targets, more math
+
+ -- Jan Wassenberg  Thu, 18 Feb 2021 20:00:00 +0200
+
+highway (0.7.0-1) UNRELEASED; urgency=medium
+
+  * Added API stability notice, Compress[Store], contrib/, SignBit, CopySign
+
+ -- Jan Wassenberg  Tue, 5 Jan 2021 17:00:00 +0200
+
+highway (0.1-1) UNRELEASED; urgency=medium
+
+  * Initial debian package.
+
+ -- Alex Deymo  Mon, 19 Oct 2020 16:48:07 +0200
diff --git a/third_party/highway/debian/compat b/third_party/highway/debian/compat
new file mode 100644
index 000000000000..f599e28b8ab0
--- /dev/null
+++ b/third_party/highway/debian/compat
@@ -0,0 +1 @@
+10
diff --git a/third_party/highway/debian/control b/third_party/highway/debian/control
new file mode 100644
index 000000000000..7c60ebc7f413
--- /dev/null
+++ b/third_party/highway/debian/control
@@ -0,0 +1,23 @@
+Source: highway
+Maintainer: JPEG XL Maintainers
+Section: misc
+Priority: optional
+Standards-Version: 3.9.8
+Build-Depends: cmake,
+               debhelper (>= 9),
+               libgtest-dev
+Homepage: https://github.com/google/highway
+
+Package: libhwy-dev
+Architecture: any
+Section: libdevel
+Depends: ${misc:Depends}
+Description: Efficient and performance-portable SIMD wrapper (developer files)
+ This library provides type-safe and source-code portable wrappers over
+ existing platform-specific intrinsics. Its design aims for simplicity,
+ reliable efficiency across platforms, and immediate usability with current
+ compilers.
+ .
+ This package installs the development files. There's no runtime library
+ since most of Highway is implemented in headers and only a very small
+ static library is needed.
diff --git a/third_party/highway/debian/copyright b/third_party/highway/debian/copyright
new file mode 100644
index 000000000000..53ea57aa97da
--- /dev/null
+++ b/third_party/highway/debian/copyright
@@ -0,0 +1,20 @@
+Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
+Upstream-Name: highway
+
+Files: *
+Copyright: 2020 Google LLC
+License: Apache-2.0
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ .
+ http://www.apache.org/licenses/LICENSE-2.0
+ .
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ .
+ On Debian systems, the complete text of the Apache License, Version 2
+ can be found in "/usr/share/common-licenses/Apache-2.0".
diff --git a/third_party/highway/debian/rules b/third_party/highway/debian/rules
new file mode 100644
index 000000000000..969fc120e87d
--- /dev/null
+++ b/third_party/highway/debian/rules
@@ -0,0 +1,6 @@
+#!/usr/bin/make -f
+%:
+	dh $@ --buildsystem=cmake
+
+override_dh_auto_configure:
+	dh_auto_configure -- -DHWY_SYSTEM_GTEST=ON
diff --git a/third_party/highway/debian/source/format b/third_party/highway/debian/source/format
new file mode 100644
index 000000000000..163aaf8d82b6
--- /dev/null
+++ b/third_party/highway/debian/source/format
@@ -0,0 +1 @@
+3.0 (quilt)
diff --git a/third_party/highway/hwy/aligned_allocator.cc b/third_party/highway/hwy/aligned_allocator.cc
new file mode 100644
index 000000000000..bec7c3bb1b70
--- /dev/null
+++ b/third_party/highway/hwy/aligned_allocator.cc
@@ -0,0 +1,138 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/aligned_allocator.h"
+
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>  // malloc
+
+#include <atomic>
+#include <limits>
+
+#include "hwy/base.h"
+
+namespace hwy {
+namespace {
+
+constexpr size_t kAlignment = HWY_MAX(HWY_ALIGNMENT, kMaxVectorSize);
+// On x86, aliasing can only occur at multiples of 2K, but that's too wasteful
+// if this is used for single-vector allocations. 256 is more reasonable.
+constexpr size_t kAlias = kAlignment * 4;
+
+#pragma pack(push, 1)
+struct AllocationHeader {
+  void* allocated;
+  size_t payload_size;
+};
+#pragma pack(pop)
+
+// Returns a 'random' (cyclical) offset for AllocateAlignedBytes.
+size_t NextAlignedOffset() {
+  static std::atomic<uint32_t> next{0};
+  constexpr uint32_t kGroups = kAlias / kAlignment;
+  const uint32_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups;
+  const size_t offset = kAlignment * group;
+  HWY_DASSERT((offset % kAlignment == 0) && offset <= kAlias);
+  return offset;
+}
+
+}  // namespace
+
+void* AllocateAlignedBytes(const size_t payload_size, AllocPtr alloc_ptr,
+                           void* opaque_ptr) {
+  if (payload_size >= std::numeric_limits<size_t>::max() / 2) {
+    HWY_DASSERT(false && "payload_size too large");
+    return nullptr;
+  }
+
+  size_t offset = NextAlignedOffset();
+
+  // What: | misalign | unused | AllocationHeader |payload
+  // Size: |<= kAlias | offset |payload_size
+  //       ^allocated.^aligned.^header............^payload
+  // The header must immediately precede payload, which must remain aligned.
+  // To avoid wasting space, the header resides at the end of `unused`,
+  // which therefore cannot be empty (offset == 0).
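+  // Illustrative walkthrough (editor's note, not upstream code): with
+  // kAlignment = 64 and kAlias = 256, an offset of 192 means we allocate
+  // 256 + 192 + payload_size bytes, round the allocation start up to the next
+  // multiple of 256, and return `aligned + 192`; the AllocationHeader (16
+  // bytes on 64-bit) then sits just below the returned payload, in `unused`.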
+  if (offset == 0) {
+    offset = kAlignment;  // = RoundUpTo(sizeof(AllocationHeader), kAlignment)
+    static_assert(sizeof(AllocationHeader) <= kAlignment, "Else: round up");
+  }
+
+  const size_t allocated_size = kAlias + offset + payload_size;
+  void* allocated;
+  if (alloc_ptr == nullptr) {
+    allocated = malloc(allocated_size);
+  } else {
+    allocated = (*alloc_ptr)(opaque_ptr, allocated_size);
+  }
+  if (allocated == nullptr) return nullptr;
+  // Always round up even if already aligned - we already asked for kAlias
+  // extra bytes and there's no way to give them back.
+  uintptr_t aligned = reinterpret_cast<uintptr_t>(allocated) + kAlias;
+  static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2");
+  static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias");
+  aligned &= ~(kAlias - 1);
+
+  const uintptr_t payload = aligned + offset;  // still aligned
+
+  // Stash `allocated` and payload_size inside header for FreeAlignedBytes().
+  // The allocated_size can be reconstructed from the payload_size.
+  AllocationHeader* header = reinterpret_cast<AllocationHeader*>(payload) - 1;
+  header->allocated = allocated;
+  header->payload_size = payload_size;
+
+  return HWY_ASSUME_ALIGNED(reinterpret_cast<void*>(payload), kMaxVectorSize);
+}
+
+void FreeAlignedBytes(const void* aligned_pointer, FreePtr free_ptr,
+                      void* opaque_ptr) {
+  if (aligned_pointer == nullptr) return;
+
+  const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer);
+  HWY_DASSERT(payload % kAlignment == 0);
+  const AllocationHeader* header =
+      reinterpret_cast<const AllocationHeader*>(payload) - 1;
+
+  if (free_ptr == nullptr) {
+    free(header->allocated);
+  } else {
+    (*free_ptr)(opaque_ptr, header->allocated);
+  }
+}
+
+// static
+void AlignedDeleter::DeleteAlignedArray(void* aligned_pointer, FreePtr free_ptr,
+                                        void* opaque_ptr,
+                                        ArrayDeleter deleter) {
+  if (aligned_pointer == nullptr) return;
+
+  const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer);
+  HWY_DASSERT(payload % kAlignment == 0);
+  const AllocationHeader* header =
+      reinterpret_cast<const AllocationHeader*>(payload) - 1;
+
+  if (deleter) {
+    (*deleter)(aligned_pointer, header->payload_size);
+  }
+
+  if (free_ptr == nullptr) {
+    free(header->allocated);
+  } else {
+    (*free_ptr)(opaque_ptr, header->allocated);
+  }
+}
+
+}  // namespace hwy
diff --git a/third_party/highway/hwy/aligned_allocator.h b/third_party/highway/hwy/aligned_allocator.h
new file mode 100644
index 000000000000..2ead7b523387
--- /dev/null
+++ b/third_party/highway/hwy/aligned_allocator.h
@@ -0,0 +1,179 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_
+#define HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_
+
+// Memory allocator with support for alignment and offsets.
+
+#include <stddef.h>
+
+#include <memory>
+
+namespace hwy {
+
+// Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which
+// requires a literal. This matches typical L1 cache line sizes, which prevents
+// false sharing.
+#define HWY_ALIGNMENT 64
+
+// Pointers to functions equivalent to malloc/free with an opaque void* passed
+// to them.
+using AllocPtr = void* (*)(void* opaque, size_t bytes);
+using FreePtr = void (*)(void* opaque, void* memory);
+
+// Returns null or a pointer to at least `payload_size` (which can be zero)
+// bytes of newly allocated memory, aligned to the larger of HWY_ALIGNMENT and
+// the vector size. Calls `alloc` with the passed `opaque` pointer to obtain
+// memory or malloc() if it is null.
+void* AllocateAlignedBytes(size_t payload_size, AllocPtr alloc_ptr,
+                           void* opaque_ptr);
+
+// Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it
+// must have been returned from a previous call to `AllocateAlignedBytes`.
+// Calls `free_ptr` with the passed `opaque_ptr` pointer to free the memory; if
+// the `free_ptr` function is null, uses the default free().
+void FreeAlignedBytes(const void* aligned_pointer, FreePtr free_ptr,
+                      void* opaque_ptr);
+
+// Class that deletes the aligned pointer passed to operator(), calling the
+// destructor before freeing the pointer. This is equivalent to
+// std::default_delete but for aligned objects. For a similar deleter
+// equivalent to free() for aligned memory see AlignedFreer().
+class AlignedDeleter {
+ public:
+  AlignedDeleter() : free_(nullptr), opaque_ptr_(nullptr) {}
+  AlignedDeleter(FreePtr free_ptr, void* opaque_ptr)
+      : free_(free_ptr), opaque_ptr_(opaque_ptr) {}
+
+  template <typename T>
+  void operator()(T* aligned_pointer) const {
+    return DeleteAlignedArray(aligned_pointer, free_, opaque_ptr_,
+                              TypedArrayDeleter<T>);
+  }
+
+ private:
+  template <typename T>
+  static void TypedArrayDeleter(void* ptr, size_t size_in_bytes) {
+    size_t elems = size_in_bytes / sizeof(T);
+    for (size_t i = 0; i < elems; i++) {
+      // Explicitly call the destructor on each element.
+      (static_cast<T*>(ptr) + i)->~T();
+    }
+  }
+
+  // Function prototype that calls the destructor for each element in a typed
+  // array. TypedArrayDeleter would match this prototype.
+  using ArrayDeleter = void (*)(void* t_ptr, size_t t_size);
+
+  static void DeleteAlignedArray(void* aligned_pointer, FreePtr free_ptr,
+                                 void* opaque_ptr, ArrayDeleter deleter);
+
+  FreePtr free_;
+  void* opaque_ptr_;
+};
+
+// Unique pointer to T with custom aligned deleter. This can be a single
+// element U or an array of elements if T is a U[]. The custom aligned deleter
+// will call the destructor on U or each element of a U[] in the array case.
+template <typename T>
+using AlignedUniquePtr = std::unique_ptr<T, AlignedDeleter>;
+
+// Aligned memory equivalent of make_unique using the custom allocators
+// alloc/free with the passed `opaque` pointer. This function calls the
+// constructor with the passed Args... and calls the destructor of the object
+// when the AlignedUniquePtr is destroyed.
+template <typename T, typename... Args>
+AlignedUniquePtr<T> MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free,
+                                               void* opaque, Args&&... args) {
+  T* ptr = static_cast<T*>(AllocateAlignedBytes(sizeof(T), alloc, opaque));
+  return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
+                             AlignedDeleter(free, opaque));
+}
+
+// Similar to MakeUniqueAlignedWithAlloc but using the default alloc/free
+// functions.
+template <typename T, typename... Args>
+AlignedUniquePtr<T> MakeUniqueAligned(Args&&... args) {
+  T* ptr = static_cast<T*>(AllocateAlignedBytes(
+      sizeof(T), /*alloc_ptr=*/nullptr, /*opaque_ptr=*/nullptr));
+  return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
+                             AlignedDeleter());
+}
+
+// Aligned memory equivalent of make_unique for array types using the
+// custom allocators alloc/free. This function calls the constructor with the
+// passed Args... on every created item. The destructor of each element will
+// be called when the AlignedUniquePtr is destroyed.
+template <typename T, typename... Args>
+AlignedUniquePtr<T[]> MakeUniqueAlignedArrayWithAlloc(
+    size_t items, AllocPtr alloc, FreePtr free, void* opaque, Args&&... args) {
+  T* ptr =
+      static_cast<T*>(AllocateAlignedBytes(items * sizeof(T), alloc, opaque));
+  for (size_t i = 0; i < items; i++) {
+    new (ptr + i) T(std::forward<Args>(args)...);
+  }
+  return AlignedUniquePtr<T[]>(ptr, AlignedDeleter(free, opaque));
+}
+
+template <typename T, typename... Args>
+AlignedUniquePtr<T[]> MakeUniqueAlignedArray(size_t items, Args&&... args) {
+  return MakeUniqueAlignedArrayWithAlloc<T, Args...>(
+      items, nullptr, nullptr, nullptr, std::forward<Args>(args)...);
+}
+
+// Custom deleter for std::unique_ptr equivalent to using free() as a deleter
+// but for aligned memory.
+class AlignedFreer {
+ public:
+  // Pass address of this to ctor to skip deleting externally-owned memory.
+  static void DoNothing(void* /*opaque*/, void* /*aligned_pointer*/) {}
+
+  AlignedFreer() : free_(nullptr), opaque_ptr_(nullptr) {}
+  AlignedFreer(FreePtr free_ptr, void* opaque_ptr)
+      : free_(free_ptr), opaque_ptr_(opaque_ptr) {}
+
+  template <typename T>
+  void operator()(T* aligned_pointer) const {
+    // TODO(deymo): assert that we are using a POD type T.
+    FreeAlignedBytes(aligned_pointer, free_, opaque_ptr_);
+  }
+
+ private:
+  FreePtr free_;
+  void* opaque_ptr_;
+};
+
+// Unique pointer to single POD, or (if T is U[]) an array of POD. For non POD
+// data use AlignedUniquePtr.
+template <typename T>
+using AlignedFreeUniquePtr = std::unique_ptr<T, AlignedFreer>;
+
+// Allocate an aligned and uninitialized array of POD values as a unique_ptr.
+// Upon destruction of the unique_ptr the aligned array will be freed.
+template <typename T>
+AlignedFreeUniquePtr<T[]> AllocateAligned(const size_t items, AllocPtr alloc,
+                                          FreePtr free, void* opaque) {
+  return AlignedFreeUniquePtr<T[]>(
+      static_cast<T*>(AllocateAlignedBytes(items * sizeof(T), alloc, opaque)),
+      AlignedFreer(free, opaque));
+}
+
+// Same as previous AllocateAligned(), using default allocate/free functions.
+template <typename T>
+AlignedFreeUniquePtr<T[]> AllocateAligned(const size_t items) {
+  return AllocateAligned<T>(items, nullptr, nullptr, nullptr);
+}
+
+}  // namespace hwy
+#endif  // HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_
diff --git a/third_party/highway/hwy/aligned_allocator_test.cc b/third_party/highway/hwy/aligned_allocator_test.cc
new file mode 100644
index 000000000000..63e0e993be55
--- /dev/null
+++ b/third_party/highway/hwy/aligned_allocator_test.cc
@@ -0,0 +1,250 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/aligned_allocator.h"
+
+#include <stddef.h>
+
+#include <random>
+#include <set>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "hwy/base.h"
+
+namespace {
+
+// Sample object that keeps track, via an external counter, of how many times
+// the explicit constructor and destructor were called.
+template <size_t N = 64>
+class SampleObject {
+ public:
+  SampleObject() { data_[0] = 'a'; }
+  explicit SampleObject(int* counter) : counter_(counter) {
+    if (counter) (*counter)++;
+    data_[0] = 'b';
+  }
+
+  ~SampleObject() {
+    if (counter_) (*counter_)--;
+  }
+
+  static_assert(N > sizeof(int*), "SampleObject size too small.");
+  int* counter_ = nullptr;
+  char data_[N - sizeof(int*)];
+};
+
+class FakeAllocator {
+ public:
+  // Static AllocPtr and FreePtr members to be used with the aligned
+  // allocator. These functions call the private non-static members.
+  static void* StaticAlloc(void* opaque, size_t bytes) {
+    return reinterpret_cast<FakeAllocator*>(opaque)->Alloc(bytes);
+  }
+  static void StaticFree(void* opaque, void* memory) {
+    return reinterpret_cast<FakeAllocator*>(opaque)->Free(memory);
+  }
+
+  // Returns the number of pending allocations to be freed.
+  size_t PendingAllocs() { return allocs_.size(); }
+
+ private:
+  void* Alloc(size_t bytes) {
+    void* ret = malloc(bytes);
+    allocs_.insert(ret);
+    return ret;
+  }
+  void Free(void* memory) {
+    if (!memory) return;
+    EXPECT_NE(allocs_.end(), allocs_.find(memory));
+    free(memory);
+    allocs_.erase(memory);
+  }
+
+  std::set<void*> allocs_;
+};
+
+}  // namespace
+
+namespace hwy {
+
+class AlignedAllocatorTest : public testing::Test {};
+
+TEST(AlignedAllocatorTest, FreeNullptr) {
+  // Calling free with a nullptr is always ok.
+  FreeAlignedBytes(/*aligned_pointer=*/nullptr, /*free_ptr=*/nullptr,
+                   /*opaque_ptr=*/nullptr);
+}
+
+TEST(AlignedAllocatorTest, AllocDefaultPointers) {
+  const size_t kSize = 7777;
+  void* ptr = AllocateAlignedBytes(kSize, /*alloc_ptr=*/nullptr,
+                                   /*opaque_ptr=*/nullptr);
+  ASSERT_NE(nullptr, ptr);
+  // Make sure the pointer is actually aligned.
+  EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr) % kMaxVectorSize);
+  char* p = static_cast<char*>(ptr);
+  size_t ret = 0;
+  for (size_t i = 0; i < kSize; i++) {
+    // Performs a computation using p[] to prevent it being optimized away.
+    p[i] = static_cast<char>(i & 0x7F);
+    if (i) ret += p[i] * p[i - 1];
+  }
+  EXPECT_NE(0U, ret);
+  FreeAlignedBytes(ptr, /*free_ptr=*/nullptr, /*opaque_ptr=*/nullptr);
+}
+
+TEST(AlignedAllocatorTest, EmptyAlignedUniquePtr) {
+  AlignedUniquePtr<SampleObject<24>> ptr(nullptr, AlignedDeleter());
+  AlignedUniquePtr<SampleObject<24>[]> arr(nullptr, AlignedDeleter());
+}
+
+TEST(AlignedAllocatorTest, EmptyAlignedFreeUniquePtr) {
+  AlignedFreeUniquePtr<SampleObject<24>> ptr(nullptr, AlignedFreer());
+  AlignedFreeUniquePtr<SampleObject<24>[]> arr(nullptr, AlignedFreer());
+}
+
+TEST(AlignedAllocatorTest, CustomAlloc) {
+  FakeAllocator fake_alloc;
+
+  const size_t kSize = 7777;
+  void* ptr =
+      AllocateAlignedBytes(kSize, &FakeAllocator::StaticAlloc, &fake_alloc);
+  ASSERT_NE(nullptr, ptr);
+  // We should have only requested one alloc from the allocator.
+  EXPECT_EQ(1U, fake_alloc.PendingAllocs());
+  // Make sure the pointer is actually aligned.
+  EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr) % kMaxVectorSize);
+  FreeAlignedBytes(ptr, &FakeAllocator::StaticFree, &fake_alloc);
+  EXPECT_EQ(0U, fake_alloc.PendingAllocs());
+}
+
+TEST(AlignedAllocatorTest, MakeUniqueAlignedDefaultConstructor) {
+  {
+    auto ptr = MakeUniqueAligned<SampleObject<24>>();
+    // Default constructor sets the data_[0] to 'a'.
+    EXPECT_EQ('a', ptr->data_[0]);
+    EXPECT_EQ(nullptr, ptr->counter_);
+  }
+}
+
+TEST(AlignedAllocatorTest, MakeUniqueAligned) {
+  int counter = 0;
+  {
+    // Creates the object, initializes it with the explicit constructor and
+    // returns an unique_ptr to it.
+    auto ptr = MakeUniqueAligned<SampleObject<24>>(&counter);
+    EXPECT_EQ(1, counter);
+    // Custom constructor sets the data_[0] to 'b'.
+    EXPECT_EQ('b', ptr->data_[0]);
+  }
+  EXPECT_EQ(0, counter);
+}
+
+TEST(AlignedAllocatorTest, MakeUniqueAlignedArray) {
+  int counter = 0;
+  {
+    // Creates the array of objects and initializes them with the explicit
+    // constructor.
+    auto arr = MakeUniqueAlignedArray<SampleObject<24>>(7, &counter);
+    EXPECT_EQ(7, counter);
+    for (size_t i = 0; i < 7; i++) {
+      // Custom constructor sets the data_[0] to 'b'.
+      EXPECT_EQ('b', arr[i].data_[0]) << "Where i = " << i;
+    }
+  }
+  EXPECT_EQ(0, counter);
+}
+
+TEST(AlignedAllocatorTest, AllocSingleInt) {
+  auto ptr = AllocateAligned<uint32_t>(1);
+  ASSERT_NE(nullptr, ptr.get());
+  EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr.get()) % kMaxVectorSize);
+  // Force delete of the unique_ptr now to check that it doesn't crash.
+  ptr.reset(nullptr);
+  EXPECT_EQ(nullptr, ptr.get());
+}
+
+TEST(AlignedAllocatorTest, AllocMultipleInt) {
+  const size_t kSize = 7777;
+  auto ptr = AllocateAligned<uint32_t>(kSize);
+  ASSERT_NE(nullptr, ptr.get());
+  EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr.get()) % kMaxVectorSize);
+  // ptr[i] is actually (*ptr.get())[i] which will use the operator[] of the
+  // underlying type chosen by AllocateAligned() for the std::unique_ptr.
+  EXPECT_EQ(&(ptr[0]) + 1, &(ptr[1]));
+
+  size_t ret = 0;
+  for (size_t i = 0; i < kSize; i++) {
+    // Performs a computation using ptr[] to prevent it being optimized away.
+    ptr[i] = static_cast<uint32_t>(i);
+    if (i) ret += ptr[i] * ptr[i - 1];
+  }
+  EXPECT_NE(0U, ret);
+}
+
+TEST(AlignedAllocatorTest, AllocateAlignedObjectWithoutDestructor) {
+  int counter = 0;
+  {
+    // This doesn't call the constructor.
+    auto obj = AllocateAligned<SampleObject<24>>(1);
+    obj[0].counter_ = &counter;
+  }
+  // Destroying the unique_ptr shouldn't have called the destructor of the
+  // SampleObject<24>.
+  EXPECT_EQ(0, counter);
+}
+
+TEST(AlignedAllocatorTest, MakeUniqueAlignedArrayWithCustomAlloc) {
+  FakeAllocator fake_alloc;
+  int counter = 0;
+  {
+    // Creates the array of objects and initializes them with the explicit
+    // constructor.
+    auto arr = MakeUniqueAlignedArrayWithAlloc<SampleObject<24>>(
+        7, FakeAllocator::StaticAlloc, FakeAllocator::StaticFree, &fake_alloc,
+        &counter);
+    // An array should still only require a single allocation.
+    EXPECT_EQ(1u, fake_alloc.PendingAllocs());
+    EXPECT_EQ(7, counter);
+    for (size_t i = 0; i < 7; i++) {
+      // Custom constructor sets the data_[0] to 'b'.
+      EXPECT_EQ('b', arr[i].data_[0]) << "Where i = " << i;
+    }
+  }
+  EXPECT_EQ(0, counter);
+  EXPECT_EQ(0u, fake_alloc.PendingAllocs());
+}
+
+TEST(AlignedAllocatorTest, DefaultInit) {
+  // The test is whether this compiles. Default-init is useful for output
+  // params and per-thread storage.
+  std::vector<AlignedUniquePtr<float[]>> ptrs;
+  std::vector<AlignedFreeUniquePtr<float[]>> free_ptrs;
+  ptrs.resize(128);
+  free_ptrs.resize(128);
+  // The following is to prevent elision of the pointers.
+  std::mt19937 rng(129);  // Emscripten lacks random_device.
+  std::uniform_int_distribution<size_t> dist(0, 127);
+  ptrs[dist(rng)] = MakeUniqueAlignedArray<float>(123);
+  free_ptrs[dist(rng)] = AllocateAligned<float>(456);
+  // "Use" pointer without resorting to printf. 0 == 0. Can't shift by 64.
+  const auto addr1 = reinterpret_cast<uintptr_t>(ptrs[dist(rng)].get());
+  const auto addr2 = reinterpret_cast<uintptr_t>(free_ptrs[dist(rng)].get());
+  constexpr size_t kBits = sizeof(uintptr_t) * 8;
+  EXPECT_EQ((addr1 >> (kBits - 1)) >> (kBits - 1),
+            (addr2 >> (kBits - 1)) >> (kBits - 1));
+}
+
+}  // namespace hwy
diff --git a/third_party/highway/hwy/base.h b/third_party/highway/hwy/base.h
new file mode 100644
index 000000000000..e7a859fcef52
--- /dev/null
+++ b/third_party/highway/hwy/base.h
@@ -0,0 +1,635 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HIGHWAY_HWY_BASE_H_
+#define HIGHWAY_HWY_BASE_H_
+
+// For SIMD module implementations and their callers, target-independent.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <atomic>
+#include <cfloat>
+
+// Add to #if conditions to prevent IDE from graying out code.
+#if (defined __CDT_PARSER__) || (defined __INTELLISENSE__) || \
+    (defined Q_CREATOR_RUN) || (defined(__CLANGD__))
+#define HWY_IDE 1
+#else
+#define HWY_IDE 0
+#endif
+
+//------------------------------------------------------------------------------
+// Detect compiler using predefined macros
+
+// clang-cl defines _MSC_VER but doesn't behave like MSVC in other aspects like
+// used in HWY_DIAGNOSTICS(). We include a check that we are not clang for that
+// purpose.
+#if defined(_MSC_VER) && !defined(__clang__)
+#define HWY_COMPILER_MSVC _MSC_VER
+#else
+#define HWY_COMPILER_MSVC 0
+#endif
+
+#ifdef __INTEL_COMPILER
+#define HWY_COMPILER_ICC __INTEL_COMPILER
+#else
+#define HWY_COMPILER_ICC 0
+#endif
+
+#ifdef __GNUC__
+#define HWY_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__)
+#else
+#define HWY_COMPILER_GCC 0
+#endif
+
+// Clang can masquerade as MSVC/GCC, in which case both are set.
+#ifdef __clang__
+#ifdef __APPLE__
+// Apple LLVM version is unrelated to the actual Clang version, which we need
+// for enabling workarounds. Use the presence of warning flags to deduce it.
+// Adapted from https://github.com/simd-everywhere/simde/ simde-detect-clang.h.
+#if __has_warning("-Wformat-insufficient-args")
+#define HWY_COMPILER_CLANG 1200
+#elif __has_warning("-Wimplicit-const-int-float-conversion")
+#define HWY_COMPILER_CLANG 1100
+#elif __has_warning("-Wmisleading-indentation")
+#define HWY_COMPILER_CLANG 1000
+#elif defined(__FILE_NAME__)
+#define HWY_COMPILER_CLANG 900
+#elif __has_warning("-Wextra-semi-stmt") || \
+    __has_builtin(__builtin_rotateleft32)
+#define HWY_COMPILER_CLANG 800
+#elif __has_warning("-Wc++98-compat-extra-semi")
+#define HWY_COMPILER_CLANG 700
+#else  // Anything older than 7.0 is not recommended for Highway.
+#define HWY_COMPILER_CLANG 600
+#endif  // __has_warning chain
+#else   // Non-Apple: normal version
+#define HWY_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__)
+#endif
+#else  // Not clang
+#define HWY_COMPILER_CLANG 0
+#endif
+
+// More than one may be nonzero, but we want at least one.
+#if !HWY_COMPILER_MSVC && !HWY_COMPILER_ICC && !HWY_COMPILER_GCC && \
+    !HWY_COMPILER_CLANG
+#error "Unsupported compiler"
+#endif
+
+//------------------------------------------------------------------------------
+// Compiler-specific definitions
+
+#define HWY_STR_IMPL(macro) #macro
+#define HWY_STR(macro) HWY_STR_IMPL(macro)
+
+#if HWY_COMPILER_MSVC
+
+#include <intrin.h>
+
+#define HWY_RESTRICT __restrict
+#define HWY_INLINE __forceinline
+#define HWY_NOINLINE __declspec(noinline)
+#define HWY_FLATTEN
+#define HWY_NORETURN __declspec(noreturn)
+#define HWY_LIKELY(expr) (expr)
+#define HWY_UNLIKELY(expr) (expr)
+#define HWY_PRAGMA(tokens) __pragma(tokens)
+#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(warning(tokens))
+#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(msc)
+#define HWY_MAYBE_UNUSED
+#define HWY_HAS_ASSUME_ALIGNED 0
+#if (_MSC_VER >= 1700)
+#define HWY_MUST_USE_RESULT _Check_return_
+#else
+#define HWY_MUST_USE_RESULT
+#endif
+
+#else
+
+#define HWY_RESTRICT __restrict__
+#define HWY_INLINE inline __attribute__((always_inline))
+#define HWY_NOINLINE __attribute__((noinline))
+#define HWY_FLATTEN __attribute__((flatten))
+#define HWY_NORETURN __attribute__((noreturn))
+#define HWY_LIKELY(expr) __builtin_expect(!!(expr), 1)
+#define HWY_UNLIKELY(expr) __builtin_expect(!!(expr), 0)
+#define HWY_PRAGMA(tokens) _Pragma(#tokens)
+#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(GCC diagnostic tokens)
+#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(gcc)
+// Encountered "attribute list cannot appear here" when using the C++17
+// [[maybe_unused]], so only use the old style attribute for now.
+#define HWY_MAYBE_UNUSED __attribute__((unused))
+#define HWY_MUST_USE_RESULT __attribute__((warn_unused_result))
+
+#endif  // !HWY_COMPILER_MSVC
+
+//------------------------------------------------------------------------------
+// Builtin/attributes
+
+#ifdef __has_builtin
+#define HWY_HAS_BUILTIN(name) __has_builtin(name)
+#else
+#define HWY_HAS_BUILTIN(name) 0
+#endif
+
+#ifdef __has_attribute
+#define HWY_HAS_ATTRIBUTE(name) __has_attribute(name)
+#else
+#define HWY_HAS_ATTRIBUTE(name) 0
+#endif
+
+// Enables error-checking of format strings.
+#if HWY_HAS_ATTRIBUTE(__format__)
+#define HWY_FORMAT(idx_fmt, idx_arg) \
+  __attribute__((__format__(__printf__, idx_fmt, idx_arg)))
+#else
+#define HWY_FORMAT(idx_fmt, idx_arg)
+#endif
+
+// Returns a void* pointer which the compiler then assumes is N-byte aligned.
+// Example: float* HWY_RESTRICT aligned = (float*)HWY_ASSUME_ALIGNED(in, 32);
+//
+// The assignment semantics are required by GCC/Clang. ICC provides an in-place
+// __assume_aligned, whereas MSVC's __assume appears unsuitable.
+#if HWY_HAS_BUILTIN(__builtin_assume_aligned)
+#define HWY_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align))
+#else
+#define HWY_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */
+#endif
+
+// Clang and GCC require attributes on each function into which SIMD intrinsics
+// are inlined. Support both per-function annotation (HWY_ATTR) for lambdas and
+// automatic annotation via pragmas.
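+// For example (editor's illustration, not upstream code), a pair such as
+//   HWY_PUSH_ATTRIBUTES("avx2,fma")
+//   /* functions using AVX2 intrinsics */
+//   HWY_POP_ATTRIBUTES
+// compiles the enclosed functions with that target's codegen enabled; the
+// HWY_BEFORE_NAMESPACE()/HWY_AFTER_NAMESPACE() macros used throughout this
+// patch expand to such a push/pop for the current target's attribute string.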
+#if HWY_COMPILER_CLANG
+#define HWY_PUSH_ATTRIBUTES(targets_str)                                \
+  HWY_PRAGMA(clang attribute push(__attribute__((target(targets_str))), \
+                                  apply_to = function))
+#define HWY_POP_ATTRIBUTES HWY_PRAGMA(clang attribute pop)
+#elif HWY_COMPILER_GCC
+#define HWY_PUSH_ATTRIBUTES(targets_str) \
+  HWY_PRAGMA(GCC push_options) HWY_PRAGMA(GCC target targets_str)
+#define HWY_POP_ATTRIBUTES HWY_PRAGMA(GCC pop_options)
+#else
+#define HWY_PUSH_ATTRIBUTES(targets_str)
+#define HWY_POP_ATTRIBUTES
+#endif
+
+//------------------------------------------------------------------------------
+// Detect architecture using predefined macros
+
+#if defined(__i386__) || defined(_M_IX86)
+#define HWY_ARCH_X86_32 1
+#else
+#define HWY_ARCH_X86_32 0
+#endif
+
+#if defined(__x86_64__) || defined(_M_X64)
+#define HWY_ARCH_X86_64 1
+#else
+#define HWY_ARCH_X86_64 0
+#endif
+
+#if HWY_ARCH_X86_32 || HWY_ARCH_X86_64
+#define HWY_ARCH_X86 1
+#else
+#define HWY_ARCH_X86 0
+#endif
+
+#if defined(__powerpc64__) || defined(_M_PPC)
+#define HWY_ARCH_PPC 1
+#else
+#define HWY_ARCH_PPC 0
+#endif
+
+#if defined(__ARM_ARCH_ISA_A64) || defined(__aarch64__) || defined(_M_ARM64)
+#define HWY_ARCH_ARM_A64 1
+#else
+#define HWY_ARCH_ARM_A64 0
+#endif
+
+#if defined(__arm__) || defined(_M_ARM)
+#define HWY_ARCH_ARM_V7 1
+#else
+#define HWY_ARCH_ARM_V7 0
+#endif
+
+#if HWY_ARCH_ARM_A64 && HWY_ARCH_ARM_V7
+#error "Cannot have both A64 and V7"
+#endif
+
+#if HWY_ARCH_ARM_A64 || HWY_ARCH_ARM_V7
+#define HWY_ARCH_ARM 1
+#else
+#define HWY_ARCH_ARM 0
+#endif
+
+#if defined(__EMSCRIPTEN__) || defined(__wasm__) || defined(__WASM__)
+#define HWY_ARCH_WASM 1
+#else
+#define HWY_ARCH_WASM 0
+#endif
+
+#ifdef __riscv
+#define HWY_ARCH_RVV 1
+#else
+#define HWY_ARCH_RVV 0
+#endif
+
+#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_WASM + \
+     HWY_ARCH_RVV) != 1
+#error "Must detect exactly one platform"
+#endif
+
+//------------------------------------------------------------------------------
+// Macros
+
+#define HWY_API static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED
+
+#define HWY_CONCAT_IMPL(a, b) a##b
+#define HWY_CONCAT(a, b) HWY_CONCAT_IMPL(a, b)
+
+#define HWY_MIN(a, b) ((a) < (b) ? (a) : (b))
+#define HWY_MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// Compile-time fence to prevent undesirable code reordering. On Clang x86, the
+// typical asm volatile("" : : : "memory") has no effect, whereas atomic fence
+// does, without generating code.
+#if HWY_ARCH_X86
+#define HWY_FENCE std::atomic_thread_fence(std::memory_order_acq_rel)
+#else
+// TODO(janwas): investigate alternatives. On ARM, the above generates
+// barriers.
+#define HWY_FENCE
+#endif
+
+// 4 instances of a given literal value, useful as input to LoadDup128.
+#define HWY_REP4(literal) literal, literal, literal, literal
+
+#define HWY_ABORT(format, ...) \
+  ::hwy::Abort(__FILE__, __LINE__, format, ##__VA_ARGS__)
+
+// Always enabled.
+#define HWY_ASSERT(condition)             \
+  do {                                    \
+    if (!(condition)) {                   \
+      HWY_ABORT("Assert %s", #condition); \
+    }                                     \
+  } while (0)
+
+// Only for "debug" builds
+#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) || \
+    defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER)
+#define HWY_DASSERT(condition) HWY_ASSERT(condition)
+#else
+#define HWY_DASSERT(condition) \
+  do {                         \
+  } while (0)
+#endif
+
+namespace hwy {
+
+//------------------------------------------------------------------------------
+// Alignment
+
+// Not guaranteed to be an upper bound, but the alignment established by
+// aligned_allocator is HWY_MAX(HWY_ALIGNMENT, kMaxVectorSize).
+#if HWY_ARCH_X86
+static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 64;  // AVX-512
+#define HWY_ALIGN_MAX alignas(64)
+#elif HWY_ARCH_RVV
+// Not actually an upper bound on the size, but this value prevents crossing a
+// 4K boundary (relevant on Andes).
+static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 4096;
+#define HWY_ALIGN_MAX alignas(8)  // only elements need be aligned
+#else
+static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16;
+#define HWY_ALIGN_MAX alignas(16)
+#endif
+
+//------------------------------------------------------------------------------
+// Lane types
+
+// Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name
+// by concatenating base type and bits.
+
+// RVV already has a builtin type and the GCC intrinsics require it.
+#if HWY_ARCH_RVV && HWY_COMPILER_GCC
+using float16_t = __fp16;
+// Clang does not allow __fp16 arguments, but scalar.h requires LaneType
+// arguments, so use a wrapper.
+// TODO(janwas): replace with _Float16 when that is supported?
+#else
+#pragma pack(push, 1)
+struct float16_t {
+  uint16_t bits;
+};
+#pragma pack(pop)
+#endif
+
+using float32_t = float;
+using float64_t = double;
+
+//------------------------------------------------------------------------------
+// Controlling overload resolution (SFINAE)
+
+template <bool Condition, class T>
+struct EnableIfT {};
+template <class T>
+struct EnableIfT<true, T> {
+  using type = T;
+};
+
+template <bool Condition, class T = void>
+using EnableIf = typename EnableIfT<Condition, T>::type;
+
+// Insert into template/function arguments to enable this overload only for
+// vectors of AT MOST this many bits.
+//
+// Note that enabling for exactly 128 bits is unnecessary because a function
+// can simply be overloaded with Vec128<T> and Full128<T> descriptor. Enabling
+// for other sizes (e.g. 64 bit) can be achieved with Simd<T, N>.
+#define HWY_IF_LE128(T, N) hwy::EnableIf<N * sizeof(T) <= 16>* = nullptr
+#define HWY_IF_LE64(T, N) hwy::EnableIf<N * sizeof(T) <= 8>* = nullptr
+#define HWY_IF_LE32(T, N) hwy::EnableIf<N * sizeof(T) <= 4>* = nullptr
+
+#define HWY_IF_UNSIGNED(T) hwy::EnableIf<!IsSigned<T>()>* = nullptr
+#define HWY_IF_SIGNED(T) \
+  hwy::EnableIf<IsSigned<T>() && !IsFloat<T>()>* = nullptr
+#define HWY_IF_FLOAT(T) hwy::EnableIf<IsFloat<T>()>* = nullptr
+#define HWY_IF_NOT_FLOAT(T) hwy::EnableIf<!IsFloat<T>()>* = nullptr
+
+#define HWY_IF_LANE_SIZE(T, bytes) \
+  hwy::EnableIf<sizeof(T) == (bytes)>* = nullptr
+#define HWY_IF_NOT_LANE_SIZE(T, bytes) \
+  hwy::EnableIf<sizeof(T) != (bytes)>* = nullptr
+
+// Empty struct used as a size tag type.
+template <size_t N>
+struct SizeTag {};
+
+//------------------------------------------------------------------------------
+// Type traits
+
+template <typename T>
+constexpr bool IsFloat() {
+  return T(1.25) != T(1);
+}
+
+template <typename T>
+constexpr bool IsSigned() {
+  return T(0) > T(-1);
+}
+
+// Largest/smallest representable integer values.
+template <typename T>
+constexpr T LimitsMax() {
+  static_assert(!IsFloat<T>(), "Only for integer types");
+  return IsSigned<T>() ? T((1ULL << (sizeof(T) * 8 - 1)) - 1)
+                       : static_cast<T>(~0ull);
+}
+template <typename T>
+constexpr T LimitsMin() {
+  static_assert(!IsFloat<T>(), "Only for integer types");
+  return IsSigned<T>() ? T(-1) - LimitsMax<T>() : T(0);
+}
+
+// Largest/smallest representable value (integer or float). This naming avoids
+// confusion with numeric_limits<T>::min() (the smallest positive value).
+template <typename T>
+constexpr T LowestValue() {
+  return LimitsMin<T>();
+}
+template <>
+constexpr float LowestValue<float>() {
+  return -FLT_MAX;
+}
+template <>
+constexpr double LowestValue<double>() {
+  return -DBL_MAX;
+}
+
+template <typename T>
+constexpr T HighestValue() {
+  return LimitsMax<T>();
+}
+template <>
+constexpr float HighestValue<float>() {
+  return FLT_MAX;
+}
+template <>
+constexpr double HighestValue<double>() {
+  return DBL_MAX;
+}
+
+// Returns bitmask of the exponent field in IEEE binary32/64.
+template <typename T>
+constexpr T ExponentMask() {
+  static_assert(sizeof(T) == 0, "Only instantiate the specializations");
+  return 0;
+}
+template <>
+constexpr uint32_t ExponentMask<uint32_t>() {
+  return 0x7F800000;
+}
+template <>
+constexpr uint64_t ExponentMask<uint64_t>() {
+  return 0x7FF0000000000000ULL;
+}
+
+// Returns 1 << mantissa_bits as a floating-point number. All integers whose
+// absolute value are less than this can be represented exactly.
+template <typename T>
+constexpr T MantissaEnd() {
+  static_assert(sizeof(T) == 0, "Only instantiate the specializations");
+  return 0;
+}
+template <>
+constexpr float MantissaEnd<float>() {
+  return 8388608.0f;  // 1 << 23
+}
+template <>
+constexpr double MantissaEnd<double>() {
+  // floating point literal with p52 requires C++17.
+  return 4503599627370496.0;  // 1 << 52
+}
+
+//------------------------------------------------------------------------------
+// Type relations
+
+namespace detail {
+
+template <typename T>
+struct Relations;
+template <>
+struct Relations<uint8_t> {
+  using Unsigned = uint8_t;
+  using Signed = int8_t;
+  using Wide = uint16_t;
+};
+template <>
+struct Relations<int8_t> {
+  using Unsigned = uint8_t;
+  using Signed = int8_t;
+  using Wide = int16_t;
+};
+template <>
+struct Relations<uint16_t> {
+  using Unsigned = uint16_t;
+  using Signed = int16_t;
+  using Wide = uint32_t;
+  using Narrow = uint8_t;
+};
+template <>
+struct Relations<int16_t> {
+  using Unsigned = uint16_t;
+  using Signed = int16_t;
+  using Wide = int32_t;
+  using Narrow = int8_t;
+};
+template <>
+struct Relations<uint32_t> {
+  using Unsigned = uint32_t;
+  using Signed = int32_t;
+  using Float = float;
+  using Wide = uint64_t;
+  using Narrow = uint16_t;
+};
+template <>
+struct Relations<int32_t> {
+  using Unsigned = uint32_t;
+  using Signed = int32_t;
+  using Float = float;
+  using Wide = int64_t;
+  using Narrow = int16_t;
+};
+template <>
+struct Relations<uint64_t> {
+  using Unsigned = uint64_t;
+  using Signed = int64_t;
+  using Float = double;
+  using Narrow = uint32_t;
+};
+template <>
+struct Relations<int64_t> {
+  using Unsigned = uint64_t;
+  using Signed = int64_t;
+  using Float = double;
+  using Narrow = int32_t;
+};
+template <>
+struct Relations<float16_t> {
+  using Unsigned = uint16_t;
+  using Signed = int16_t;
+  using Float = float16_t;
+  using Wide = float;
+};
+template <>
+struct Relations<float> {
+  using Unsigned = uint32_t;
+  using Signed = int32_t;
+  using Float = float;
+  using Wide = double;
+};
+template <>
+struct Relations<double> {
+  using Unsigned = uint64_t;
+  using Signed = int64_t;
+  using Float = double;
+  using Narrow = float;
+};
+
+}  // namespace detail
+
+// Aliases for types of a different category, but the same size.
+template <typename T>
+using MakeUnsigned = typename detail::Relations<T>::Unsigned;
+template <typename T>
+using MakeSigned = typename detail::Relations<T>::Signed;
+template <typename T>
+using MakeFloat = typename detail::Relations<T>::Float;
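+// For example (illustrative): MakeUnsigned<int32_t> is uint32_t,
+// MakeSigned<uint16_t> is int16_t, and MakeFloat<uint64_t> is double;
+// MakeWide<uint8_t> below is uint16_t.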
+// Aliases for types of the same category, but different size.
+template <typename T>
+using MakeWide = typename detail::Relations<T>::Wide;
+template <typename T>
+using MakeNarrow = typename detail::Relations<T>::Narrow;
+
+//------------------------------------------------------------------------------
+// Helper functions
+
+template <typename T1, typename T2>
+constexpr inline T1 DivCeil(T1 a, T2 b) {
+  return (a + b - 1) / b;
+}
+
+// Works for any `align`; if a power of two, compiler emits ADD+AND.
+constexpr inline size_t RoundUpTo(size_t what, size_t align) {
+  return DivCeil(what, align) * align;
+}
+
+// Undefined results for x == 0.
+HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) {
+#if HWY_COMPILER_MSVC
+  unsigned long index;  // NOLINT
+  _BitScanForward(&index, x);
+  return index;
+#else   // HWY_COMPILER_MSVC
+  return static_cast<size_t>(__builtin_ctz(x));
+#endif  // HWY_COMPILER_MSVC
+}
+
+HWY_API size_t PopCount(uint64_t x) {
+#if HWY_COMPILER_CLANG || HWY_COMPILER_GCC
+  return static_cast<size_t>(__builtin_popcountll(x));
+#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64
+  return _mm_popcnt_u64(x);
+#elif HWY_COMPILER_MSVC
+  return _mm_popcnt_u32(uint32_t(x)) + _mm_popcnt_u32(uint32_t(x >> 32));
+#else
+  x -= ((x >> 1) & 0x55555555U);
+  x = (((x >> 2) & 0x33333333U) + (x & 0x33333333U));
+  x = (((x >> 4) + x) & 0x0F0F0F0FU);
+  x += (x >> 8);
+  x += (x >> 16);
+  x += (x >> 32);
+  x = x & 0x0000007FU;
+  return (unsigned int)x;
+#endif
+}
+
+// The source/destination must not overlap/alias.
+template <size_t kBytes, typename From, typename To>
+HWY_API void CopyBytes(const From* from, To* to) {
+#if HWY_COMPILER_MSVC
+  const uint8_t* HWY_RESTRICT from_bytes =
+      reinterpret_cast<const uint8_t*>(from);
+  uint8_t* HWY_RESTRICT to_bytes = reinterpret_cast<uint8_t*>(to);
+  for (size_t i = 0; i < kBytes; ++i) {
+    to_bytes[i] = from_bytes[i];
+  }
+#else
+  // Avoids horrible codegen on Clang (series of PINSRB)
+  __builtin_memcpy(to, from, kBytes);
+#endif
+}
+
+HWY_NORETURN void HWY_FORMAT(3, 4)
+    Abort(const char* file, int line, const char* format, ...);
+
+}  // namespace hwy
+
+#endif  // HIGHWAY_HWY_BASE_H_
diff --git a/third_party/highway/hwy/base_test.cc b/third_party/highway/hwy/base_test.cc
new file mode 100644
index 000000000000..19e0b6f5445c
--- /dev/null
+++ b/third_party/highway/hwy/base_test.cc
@@ -0,0 +1,123 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
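As an aside, the helpers above are plain scalar utilities and can be exercised without any SIMD setup. A self-contained sketch (editor's illustration, not part of the patch; `main` and the printed values are arbitrary):

    #include <stdio.h>

    #include "hwy/base.h"

    int main() {
      // DivCeil/RoundUpTo: 70 rounded up to a multiple of 64 is 128.
      printf("RoundUpTo(70, 64) = %zu\n", hwy::RoundUpTo(70, 64));
      // PopCount: 0xF0 has four set bits.
      printf("PopCount(0xF0) = %zu\n", hwy::PopCount(0xF0ull));
      // The type traits are constexpr and usable in static_assert.
      static_assert(hwy::IsSigned<float>() && !hwy::IsSigned<uint32_t>(),
                    "floats are signed, uint32_t is not");
      return 0;
    }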
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <limits>
+
+#include "hwy/base.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "base_test.cc"
+#include "hwy/foreach_target.h"
+#include "hwy/highway.h"
+#include "hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+HWY_NOINLINE void TestAllLimits() {
+  HWY_ASSERT_EQ(uint8_t(0), LimitsMin<uint8_t>());
+  HWY_ASSERT_EQ(uint16_t(0), LimitsMin<uint16_t>());
+  HWY_ASSERT_EQ(uint32_t(0), LimitsMin<uint32_t>());
+  HWY_ASSERT_EQ(uint64_t(0), LimitsMin<uint64_t>());
+
+  HWY_ASSERT_EQ(int8_t(-128), LimitsMin<int8_t>());
+  HWY_ASSERT_EQ(int16_t(-32768), LimitsMin<int16_t>());
+  HWY_ASSERT_EQ(int32_t(0x80000000u), LimitsMin<int32_t>());
+  HWY_ASSERT_EQ(int64_t(0x8000000000000000ull), LimitsMin<int64_t>());
+
+  HWY_ASSERT_EQ(uint8_t(0xFF), LimitsMax<uint8_t>());
+  HWY_ASSERT_EQ(uint16_t(0xFFFF), LimitsMax<uint16_t>());
+  HWY_ASSERT_EQ(uint32_t(0xFFFFFFFFu), LimitsMax<uint32_t>());
+  HWY_ASSERT_EQ(uint64_t(0xFFFFFFFFFFFFFFFFull), LimitsMax<uint64_t>());
+
+  HWY_ASSERT_EQ(int8_t(0x7F), LimitsMax<int8_t>());
+  HWY_ASSERT_EQ(int16_t(0x7FFF), LimitsMax<int16_t>());
+  HWY_ASSERT_EQ(int32_t(0x7FFFFFFFu), LimitsMax<int32_t>());
+  HWY_ASSERT_EQ(int64_t(0x7FFFFFFFFFFFFFFFull), LimitsMax<int64_t>());
+}
+
+struct TestLowestHighest {
+  template <class T>
+  HWY_NOINLINE void operator()(T /*unused*/) const {
+    HWY_ASSERT_EQ(std::numeric_limits<T>::lowest(), LowestValue<T>());
+    HWY_ASSERT_EQ(std::numeric_limits<T>::max(), HighestValue<T>());
+  }
+};
+
+HWY_NOINLINE void TestAllLowestHighest() { ForAllTypes(TestLowestHighest()); }
+
+struct TestIsUnsigned {
+  template <class T>
+  HWY_NOINLINE void operator()(T /*unused*/) const {
+    static_assert(!IsFloat<T>(), "Expected !IsFloat");
+    static_assert(!IsSigned<T>(), "Expected !IsSigned");
+  }
+};
+
+struct TestIsSigned {
+  template <class T>
+  HWY_NOINLINE void operator()(T /*unused*/) const {
+    static_assert(!IsFloat<T>(), "Expected !IsFloat");
+    static_assert(IsSigned<T>(), "Expected IsSigned");
+  }
+};
+
+struct TestIsFloat {
+  template <class T>
+  HWY_NOINLINE void operator()(T /*unused*/) const {
+    static_assert(IsFloat<T>(), "Expected IsFloat");
+    static_assert(IsSigned<T>(), "Floats are also considered signed");
+  }
+};
+
+HWY_NOINLINE void TestAllType() {
+  ForUnsignedTypes(TestIsUnsigned());
+  ForSignedTypes(TestIsSigned());
+  ForFloatTypes(TestIsFloat());
+}
+
+HWY_NOINLINE void TestAllPopCount() {
+  HWY_ASSERT_EQ(size_t(0), PopCount(0u));
+  HWY_ASSERT_EQ(size_t(1), PopCount(1u));
+  HWY_ASSERT_EQ(size_t(1), PopCount(2u));
+  HWY_ASSERT_EQ(size_t(2), PopCount(3u));
+  HWY_ASSERT_EQ(size_t(1), PopCount(0x80000000u));
+  HWY_ASSERT_EQ(size_t(31), PopCount(0x7FFFFFFFu));
+  HWY_ASSERT_EQ(size_t(32), PopCount(0xFFFFFFFFu));
+
+  HWY_ASSERT_EQ(size_t(1), PopCount(0x80000000ull));
+  HWY_ASSERT_EQ(size_t(31), PopCount(0x7FFFFFFFull));
+  HWY_ASSERT_EQ(size_t(32), PopCount(0xFFFFFFFFull));
+  HWY_ASSERT_EQ(size_t(33), PopCount(0x10FFFFFFFFull));
+  HWY_ASSERT_EQ(size_t(63), PopCount(0xFFFEFFFFFFFFFFFFull));
+  HWY_ASSERT_EQ(size_t(64), PopCount(0xFFFFFFFFFFFFFFFFull));
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+HWY_BEFORE_TEST(BaseTest);
+HWY_EXPORT_AND_TEST_P(BaseTest, TestAllLimits);
+HWY_EXPORT_AND_TEST_P(BaseTest, TestAllLowestHighest);
+HWY_EXPORT_AND_TEST_P(BaseTest, TestAllType);
+HWY_EXPORT_AND_TEST_P(BaseTest, TestAllPopCount);
+}  // namespace hwy
+#endif
diff --git a/third_party/highway/hwy/cache_control.h b/third_party/highway/hwy/cache_control.h
new file mode 100644
index 000000000000..6581e640e92b
--- /dev/null
+++ b/third_party/highway/hwy/cache_control.h
@@ -0,0 +1,88 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HIGHWAY_HWY_CACHE_CONTROL_H_
+#define HIGHWAY_HWY_CACHE_CONTROL_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/base.h"
+
+// Requires SSE2; fails to compile on 32-bit Clang 7 (see
+// https://github.com/gperftools/gperftools/issues/946).
+#if !defined(__SSE2__) || (HWY_COMPILER_CLANG && HWY_ARCH_X86_32)
+#undef HWY_DISABLE_CACHE_CONTROL
+#define HWY_DISABLE_CACHE_CONTROL
+#endif
+
+// intrin.h is sufficient on MSVC and already included by base.h.
+#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC
+#include <emmintrin.h>  // SSE2
+#endif
+
+namespace hwy {
+
+// Even if N*sizeof(T) is smaller, Stream may write a multiple of this size.
+#define HWY_STREAM_MULTIPLE 16
+
+// The following functions may also require an attribute.
+#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC
+#define HWY_ATTR_CACHE __attribute__((target("sse2")))
+#else
+#define HWY_ATTR_CACHE
+#endif
+
+// Delays subsequent loads until prior loads are visible. On Intel CPUs, also
+// serves as a full fence (waits for all prior instructions to complete).
+// No effect on non-x86.
+HWY_INLINE HWY_ATTR_CACHE void LoadFence() {
+#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL)
+  _mm_lfence();
+#endif
+}
+
+// Ensures previous weakly-ordered stores are visible. No effect on non-x86.
+HWY_INLINE HWY_ATTR_CACHE void StoreFence() {
+#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL)
+  _mm_sfence();
+#endif
+}
+
+// Begins loading the cache line containing "p".
+template <typename T>
+HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) {
+#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL)
+  _mm_prefetch(reinterpret_cast<const char*>(p), _MM_HINT_T0);
+#elif HWY_COMPILER_GCC || HWY_COMPILER_CLANG
+  // Hint=0 (NTA) behavior differs, but skipping outer caches is probably not
+  // desirable, so use the default 3 (keep in caches).
+  __builtin_prefetch(p, /*write=*/0, /*hint=*/3);
+#else
+  (void)p;
+#endif
+}
+
+// Invalidates and flushes the cache line containing "p". No effect on non-x86.
+HWY_INLINE HWY_ATTR_CACHE void FlushCacheline(const void* p) {
+#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL)
+  _mm_clflush(p);
+#else
+  (void)p;
+#endif
+}
+
+}  // namespace hwy
+
+#endif  // HIGHWAY_HWY_CACHE_CONTROL_H_
diff --git a/third_party/highway/hwy/examples/benchmark.cc b/third_party/highway/hwy/examples/benchmark.cc
new file mode 100644
index 000000000000..13f408386c30
--- /dev/null
+++ b/third_party/highway/hwy/examples/benchmark.cc
@@ -0,0 +1,243 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/examples/benchmark.cc"
+#include "hwy/foreach_target.h"
+
+#include <stddef.h>
+#include <stdio.h>
+
+#include <cmath>
+#include <numeric>  // iota
+
+#include "hwy/aligned_allocator.h"
+#include "hwy/highway.h"
+#include "hwy/nanobenchmark.h"
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+#if HWY_TARGET != HWY_SCALAR
+using hwy::HWY_NAMESPACE::CombineShiftRightBytes;
+#endif
+
+class TwoArray {
+ public:
+  // Passed to ctor as a value NOT known to the compiler. Must be a multiple of
+  // the vector lane count * 8.
+  static size_t NumItems() { return 3456; }
+
+  explicit TwoArray(const size_t num_items)
+      : a_(AllocateAligned<float>(num_items * 2)), b_(a_.get() + num_items) {
+    const float init = num_items / NumItems();  // 1, but compiler doesn't know
+    std::iota(a_.get(), a_.get() + num_items, init);
+    std::iota(b_, b_ + num_items, init);
+  }
+
+ protected:
+  AlignedFreeUniquePtr<float[]> a_;
+  float* b_;
+};
+
+// Measures durations, verifies results, prints timings.
+template <class Benchmark>
+void RunBenchmark(const char* caption) {
+  printf("%10s: ", caption);
+  const size_t kNumInputs = 1;
+  const size_t num_items = Benchmark::NumItems() * Unpredictable1();
+  const FuncInput inputs[kNumInputs] = {num_items};
+  Result results[kNumInputs];
+
+  Benchmark benchmark(num_items);
+
+  Params p;
+  p.verbose = false;
+  p.max_evals = 7;
+  p.target_rel_mad = 0.002;
+  const size_t num_results = MeasureClosure(
+      [&benchmark](const FuncInput input) { return benchmark(input); }, inputs,
+      kNumInputs, results, p);
+  if (num_results != kNumInputs) {
+    fprintf(stderr, "MeasureClosure failed.\n");
+  }
+
+  benchmark.Verify(num_items);
+
+  for (size_t i = 0; i < num_results; ++i) {
+    const double cycles_per_item = results[i].ticks / results[i].input;
+    const double mad = results[i].variability * cycles_per_item;
+    printf("%6zu: %6.3f (+/- %5.3f)\n", results[i].input, cycles_per_item,
+           mad);
+  }
+}
+
+void Intro() {
+  HWY_ALIGN const float in[16] = {1, 2, 3, 4, 5, 6};
+  HWY_ALIGN float out[16];
+  HWY_FULL(float) d;  // largest possible vector
+  for (size_t i = 0; i < 16; i += Lanes(d)) {
+    const auto vec = Load(d, in + i);  // aligned!
+    auto result = vec * vec;
+    result += result;  // can update if not const
+    Store(result, d, out + i);
+  }
+  printf("\nF(x)->2*x^2, F(%.0f) = %.1f\n", in[2], out[2]);
+}
+
+// BEGINNER: dot product
+// 0.4 cyc/float = bronze, 0.25 = silver, 0.15 = gold!
+class BenchmarkDot : public TwoArray {
+ public:
+  explicit BenchmarkDot(size_t num_items) : TwoArray(num_items), dot_{-1.0f} {}
+
+  FuncOutput operator()(const size_t num_items) {
+    HWY_FULL(float) d;
+    const size_t N = Lanes(d);
+    using V = decltype(Zero(d));
+    constexpr int unroll = 8;
+    // Compiler doesn't make independent sum* accumulators, so unroll manually.
+    // Some older compilers might not be able to fit the 8 arrays in registers,
+    // so manual unrolling can be helpful if you run into this issue.
+    // 2 FMA ports * 4 cycle latency = 8x unrolled.
+    V sum[unroll];
+    for (int i = 0; i < unroll; ++i) {
+      sum[i] = Zero(d);
+    }
+    const float* const HWY_RESTRICT pa = &a_[0];
+    const float* const HWY_RESTRICT pb = b_;
+    for (size_t i = 0; i < num_items; i += unroll * N) {
+      for (int j = 0; j < unroll; ++j) {
+        const auto a = Load(d, pa + i + j * N);
+        const auto b = Load(d, pb + i + j * N);
+        sum[j] = MulAdd(a, b, sum[j]);
+      }
+    }
+    // Reduction tree: sum of all accumulators by pairs into sum[0], then the
+    // lanes.
+    for (int power = 1; power < unroll; power *= 2) {
+      for (int i = 0; i < unroll; i += 2 * power) {
+        sum[i] += sum[i + power];
+      }
+    }
+    return dot_ = GetLane(SumOfLanes(sum[0]));
+  }
+  void Verify(size_t num_items) {
+    if (dot_ == -1.0f) {
+      fprintf(stderr, "Dot: must call Verify after benchmark");
+      abort();
+    }
+
+    const float expected =
+        std::inner_product(a_.get(), a_.get() + num_items, b_, 0.0f);
+    const float rel_err = std::abs(expected - dot_) / expected;
+    if (rel_err > 1.1E-6f) {
+      fprintf(stderr, "Dot: expected %e actual %e (%e)\n", expected, dot_,
+              rel_err);
+      abort();
+    }
+  }
+
+ private:
+  float dot_;  // for Verify
+};
+
+// INTERMEDIATE: delta coding
+// 1.0 cycles/float = bronze, 0.7 = silver, 0.4 = gold!
+struct BenchmarkDelta : public TwoArray {
+  explicit BenchmarkDelta(size_t num_items) : TwoArray(num_items) {}
+
+  FuncOutput operator()(const size_t num_items) const {
+#if HWY_TARGET == HWY_SCALAR
+    b_[0] = a_[0];
+    for (size_t i = 1; i < num_items; ++i) {
+      b_[i] = a_[i] - a_[i - 1];
+    }
+#elif HWY_CAP_GE256
+    // Larger vectors are split into 128-bit blocks, easiest to use the
+    // unaligned load support to shift between them.
+    const HWY_FULL(float) df;
+    const size_t N = Lanes(df);
+    size_t i;
+    b_[0] = a_[0];
+    for (i = 1; i < N; ++i) {
+      b_[i] = a_[i] - a_[i - 1];
+    }
+    for (; i < num_items; i += N) {
+      const auto a = Load(df, &a_[i]);
+      const auto shifted = LoadU(df, &a_[i - 1]);
+      Store(a - shifted, df, &b_[i]);
+    }
+#else  // 128-bit
+    // Slightly better than unaligned loads
+    const HWY_CAPPED(float, 4) df;
+    const size_t N = Lanes(df);
+    size_t i;
+    b_[0] = a_[0];
+    for (i = 1; i < N; ++i) {
+      b_[i] = a_[i] - a_[i - 1];
+    }
+    auto prev = Load(df, &a_[0]);
+    for (; i < num_items; i += Lanes(df)) {
+      const auto a = Load(df, &a_[i]);
+      const auto shifted = CombineShiftRightLanes<3>(a, prev);
+      prev = a;
+      Store(a - shifted, df, &b_[i]);
+    }
+#endif
+    return b_[num_items - 1];
+  }
+
+  void Verify(size_t num_items) {
+    for (size_t i = 0; i < num_items; ++i) {
+      const float expected = (i == 0) ? a_[0] : a_[i] - a_[i - 1];
+      const float err = std::abs(expected - b_[i]);
+      if (err > 1E-6f) {
+        fprintf(stderr, "Delta: expected %e, actual %e\n", expected, b_[i]);
+      }
+    }
+  }
+};
+
+void RunBenchmarks() {
+  Intro();
+  printf("------------------------ %s\n", TargetName(HWY_TARGET));
+  RunBenchmark<BenchmarkDot>("dot");
+  RunBenchmark<BenchmarkDelta>("delta");
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+HWY_EXPORT(RunBenchmarks);
+
+void Run() {
+  for (uint32_t target : SupportedAndGeneratedTargets()) {
+    SetSupportedTargetsForTest(target);
+    HWY_DYNAMIC_DISPATCH(RunBenchmarks)();
+  }
+  SetSupportedTargetsForTest(0);  // Reset the mask afterwards.
+} + +} // namespace hwy + +int main(int /*argc*/, char** /*argv*/) { + hwy::Run(); + return 0; +} +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/examples/skeleton-inl.h b/third_party/highway/hwy/examples/skeleton-inl.h new file mode 100644 index 000000000000..d8136be4f503 --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton-inl.h @@ -0,0 +1,62 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo of functions that might be called from multiple SIMD modules (either +// other -inl.h files, or a .cc file between begin/end_target-inl). This is +// optional - all SIMD code can reside in .cc files. However, this allows +// splitting code into different files while still inlining instead of requiring +// calling through function pointers. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#else +#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#endif + +// It is fine to #include normal or *-inl headers. +#include + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +using namespace hwy::HWY_NAMESPACE; + +// Example of a type-agnostic (caller-specified lane type) and width-agnostic +// (uses best available instruction set) function in a header. +// +// Computes x[i] = mul_array[i] * x_array[i] + add_array[i] for i < size. +template +HWY_MAYBE_UNUSED void MulAddLoop(const D d, const T* HWY_RESTRICT mul_array, + const T* HWY_RESTRICT add_array, + const size_t size, T* HWY_RESTRICT x_array) { + for (size_t i = 0; i < size; i += Lanes(d)) { + const auto mul = Load(d, mul_array + i); + const auto add = Load(d, add_array + i); + auto x = Load(d, x_array + i); + x = MulAdd(mul, x, add); + Store(x, d, x_array + i); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#endif // include guard diff --git a/third_party/highway/hwy/examples/skeleton.cc b/third_party/highway/hwy/examples/skeleton.cc new file mode 100644 index 000000000000..df021f970cd7 --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton.cc @@ -0,0 +1,103 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/examples/skeleton.h" + +#include +#include + +// First undef to prevent error when re-included. 
+#undef HWY_TARGET_INCLUDE +// For runtime dispatch, specify the name of the current file (unfortunately +// __FILE__ is not reliable) so that foreach_target.h can re-include it. +#define HWY_TARGET_INCLUDE "hwy/examples/skeleton.cc" +// Generates code for each enabled target by re-including this source file. +#include "hwy/foreach_target.h" + +#include "hwy/highway.h" + +// Optional, can instead add HWY_ATTR to all functions. +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +// Highway ops reside here; ADL does not find templates nor builtins. +using namespace hwy::HWY_NAMESPACE; + +// Computes log2 by converting to a vector of floats. Compiled once per target. +template +HWY_NOINLINE void OneFloorLog2(const DF df, const uint8_t* HWY_RESTRICT values, + uint8_t* HWY_RESTRICT log2) { + // Type tags for converting to other element types (Rebind = same count). + const Rebind d32; + const Rebind d8; + + const auto u8 = Load(d8, values); + const auto bits = BitCast(d32, ConvertTo(df, PromoteTo(d32, u8))); + const auto exponent = ShiftRight<23>(bits) - Set(d32, 127); + Store(DemoteTo(d8, exponent), d8, log2); +} + +HWY_NOINLINE void CodepathDemo() { + // Highway defaults to portability, but per-target codepaths may be selected + // via #if HWY_TARGET == HWY_SSE4 or by testing capability macros: +#if HWY_CAP_INTEGER64 + const char* gather = "Has int64"; +#else + const char* gather = "No int64"; +#endif + printf("Target %s: %s\n", hwy::TargetName(HWY_TARGET), gather); +} + +HWY_NOINLINE void FloorLog2(const uint8_t* HWY_RESTRICT values, size_t count, + uint8_t* HWY_RESTRICT log2) { + CodepathDemo(); + + // Second argument is necessary on RVV until it supports fractional lengths. + HWY_FULL(float, 4) df; + + // Caller padded memory to a multiple of Lanes(). If that is not possible, + // we could use blended or masked stores, see README.md. + const size_t N = Lanes(df); + for (size_t i = 0; i < count; i += N) { + OneFloorLog2(df, values + i, log2 + i); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace skeleton { + +// This macro declares a static array used for dynamic dispatch; it resides in +// the same outer namespace that contains FloorLog2. +HWY_EXPORT(FloorLog2); + +// This function is optional and only needed in the case of exposing it in the +// header file. Otherwise using HWY_DYNAMIC_DISPATCH(FloorLog2) in this module +// is equivalent to inlining this function. +void CallFloorLog2(const uint8_t* HWY_RESTRICT in, const size_t count, + uint8_t* HWY_RESTRICT out) { + return HWY_DYNAMIC_DISPATCH(FloorLog2)(in, count, out); +} + +// Optional: anything to compile only once, e.g. non-SIMD implementations of +// public functions provided by this module, can go inside #if HWY_ONCE. + +} // namespace skeleton +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/examples/skeleton.h b/third_party/highway/hwy/examples/skeleton.h new file mode 100644 index 000000000000..4935b881ebba --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton.h @@ -0,0 +1,35 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo interface to target-specific code in skeleton.cc + +// Normal header with include guard and namespace. +#ifndef HIGHWAY_HWY_EXAMPLES_SKELETON_H_ +#define HIGHWAY_HWY_EXAMPLES_SKELETON_H_ + +#include + +// Platform-specific definitions used for declaring an interface, independent of +// the SIMD instruction set. +#include "hwy/base.h" // HWY_RESTRICT + +namespace skeleton { + +// Computes base-2 logarithm by converting to float. Supports dynamic dispatch. +void CallFloorLog2(const uint8_t* HWY_RESTRICT in, const size_t count, + uint8_t* HWY_RESTRICT out); + +} // namespace skeleton + +#endif // HIGHWAY_HWY_EXAMPLES_SKELETON_H_ diff --git a/third_party/highway/hwy/examples/skeleton_test.cc b/third_party/highway/hwy/examples/skeleton_test.cc new file mode 100644 index 000000000000..1173eabea264 --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton_test.cc @@ -0,0 +1,104 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Example of unit test for the "skeleton" library. + +#include "hwy/examples/skeleton.h" + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "examples/skeleton_test.cc" +#include "hwy/foreach_target.h" +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +// Optional: factor out parts of the implementation into *-inl.h +#include "hwy/examples/skeleton-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +using namespace hwy::HWY_NAMESPACE; + +// Calls function defined in skeleton.cc. +struct TestFloorLog2 { + template + HWY_NOINLINE void operator()(T /*unused*/, DF df) { + const size_t count = 5 * Lanes(df); + auto in = hwy::AllocateAligned(count); + auto expected = hwy::AllocateAligned(count); + + hwy::RandomState rng; + for (size_t i = 0; i < count; ++i) { + expected[i] = Random32(&rng) & 7; + in[i] = static_cast(1u << expected[i]); + } + auto out = hwy::AllocateAligned(count); + CallFloorLog2(in.get(), count, out.get()); + int sum = 0; + for (size_t i = 0; i < count; ++i) { + HWY_ASSERT_EQ(expected[i], out[i]); + sum += out[i]; + } + hwy::PreventElision(sum); + } +}; + +HWY_NOINLINE void TestAllFloorLog2() { + ForFullVectors()(float()); +} + +// Calls function defined in skeleton-inl.h. 
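+// The expected sum is computed from the inputs before MulAddLoop runs; because
+// RandomState is deterministically seeded, the constant asserted below is the
+// same for every target and lane count.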
+struct TestSumMulAdd { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + hwy::RandomState rng; + const size_t count = 4096; + EXPECT_TRUE(count % Lanes(d) == 0); + auto mul = hwy::AllocateAligned(count); + auto x = hwy::AllocateAligned(count); + auto add = hwy::AllocateAligned(count); + for (size_t i = 0; i < count; ++i) { + mul[i] = static_cast(Random32(&rng) & 0xF); + x[i] = static_cast(Random32(&rng) & 0xFF); + add[i] = static_cast(Random32(&rng) & 0xFF); + } + double expected_sum = 0.0; + for (size_t i = 0; i < count; ++i) { + expected_sum += mul[i] * x[i] + add[i]; + } + + MulAddLoop(d, mul.get(), add.get(), count, x.get()); + HWY_ASSERT_EQ(4344240.0, expected_sum); + } +}; + +HWY_NOINLINE void TestAllSumMulAdd() { + ForFloatTypes(ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace skeleton { +HWY_BEFORE_TEST(SkeletonTest); +HWY_EXPORT_AND_TEST_P(SkeletonTest, TestAllFloorLog2); +HWY_EXPORT_AND_TEST_P(SkeletonTest, TestAllSumMulAdd); +} // namespace skeleton +#endif diff --git a/third_party/highway/hwy/foreach_target.h b/third_party/highway/hwy/foreach_target.h new file mode 100644 index 000000000000..a0c4198b17c3 --- /dev/null +++ b/third_party/highway/hwy/foreach_target.h @@ -0,0 +1,161 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_FOREACH_TARGET_H_ +#define HIGHWAY_HWY_FOREACH_TARGET_H_ + +// Re-includes the translation unit zero or more times to compile for any +// targets except HWY_STATIC_TARGET. Defines unique HWY_TARGET each time so that +// highway.h defines the corresponding macro/namespace. + +#include "hwy/targets.h" + +// *_inl.h may include other headers, which requires include guards to prevent +// repeated inclusion. The guards must be reset after compiling each target, so +// the header is again visible. This is done by flipping HWY_TARGET_TOGGLE, +// defining it if undefined and vice versa. This macro is initially undefined +// so that IDEs don't gray out the contents of each header. +#ifdef HWY_TARGET_TOGGLE +#error "This macro must not be defined outside foreach_target.h" +#endif + +#ifdef HWY_HIGHWAY_INCLUDED // highway.h include guard +// Trigger fixup at the bottom of this header. +#define HWY_ALREADY_INCLUDED + +// The next highway.h must re-include set_macros-inl.h because the first +// highway.h chose the static target instead of what we will set below. +#undef HWY_SET_MACROS_PER_TARGET +#endif + +// Disable HWY_EXPORT in user code until we have generated all targets. Note +// that a subsequent highway.h will not override this definition. +#undef HWY_ONCE +#define HWY_ONCE (0 || HWY_IDE) + +// Avoid warnings on #include HWY_TARGET_INCLUDE by hiding them from the IDE; +// also skip if only 1 target defined (no re-inclusion will be necessary). 
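+// For example (hypothetical configuration): if HWY_TARGETS enables SSE4 and
+// AVX2 while HWY_STATIC_TARGET is AVX2, the blocks below re-include the
+// translation unit once with HWY_TARGET redefined to HWY_SSE4; the AVX2 code
+// is generated when compilation of the original translation unit resumes
+// after this header.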
+#if !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +#if !defined(HWY_TARGET_INCLUDE) +#error ">1 target enabled => define HWY_TARGET_INCLUDE before foreach_target.h" +#endif + +#if (HWY_TARGETS & HWY_SCALAR) && (HWY_STATIC_TARGET != HWY_SCALAR) +#undef HWY_TARGET +#define HWY_TARGET HWY_SCALAR +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_NEON) && (HWY_STATIC_TARGET != HWY_NEON) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSE4) && (HWY_STATIC_TARGET != HWY_SSE4) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSE4 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX2) && (HWY_STATIC_TARGET != HWY_AVX2) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3) && (HWY_STATIC_TARGET != HWY_AVX3) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_WASM) && (HWY_STATIC_TARGET != HWY_WASM) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_PPC8) && (HWY_STATIC_TARGET != HWY_PPC8) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC8 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#endif // !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +// Now that all but the static target have been generated, re-enable HWY_EXPORT. +#undef HWY_ONCE +#define HWY_ONCE 1 + +// If we re-include once per enabled target, the translation unit's +// implementation would have to be skipped via #if to avoid redefining symbols. +// We instead skip the re-include for HWY_STATIC_TARGET, and generate its +// implementation when resuming compilation of the translation unit. +#undef HWY_TARGET +#define HWY_TARGET HWY_STATIC_TARGET + +#ifdef HWY_ALREADY_INCLUDED +// Revert the previous toggle to prevent redefinitions for the static target. +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif + +// Force re-inclusion of set_macros-inl.h now that HWY_TARGET is restored. +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif +#endif + +#endif // HIGHWAY_HWY_FOREACH_TARGET_H_ diff --git a/third_party/highway/hwy/highway.h b/third_party/highway/hwy/highway.h new file mode 100644 index 000000000000..fdfd3b11825f --- /dev/null +++ b/third_party/highway/hwy/highway.h @@ -0,0 +1,337 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This include guard is checked by foreach_target, so avoid the usual _H_
+// suffix to prevent copybara from renaming it. NOTE: ops/*-inl.h are included
+// after/outside this include guard.
+#ifndef HWY_HIGHWAY_INCLUDED
+#define HWY_HIGHWAY_INCLUDED
+
+// Main header required before using vector types.
+
+#include "hwy/base.h"
+#include "hwy/targets.h"
+
+namespace hwy {
+
+// API version (https://semver.org/)
+#define HWY_MAJOR 0
+#define HWY_MINOR 12
+#define HWY_PATCH 0
+
+//------------------------------------------------------------------------------
+// Shorthand for descriptors (defined in shared-inl.h) used to select overloads.
+
+// Because Highway functions take descriptor and/or vector arguments, ADL finds
+// these functions without requiring users in project::HWY_NAMESPACE to
+// qualify Highway functions with hwy::HWY_NAMESPACE. However, ADL rules for
+// templates require `using hwy::HWY_NAMESPACE::ShiftLeft;` etc. declarations.
+
+// HWY_FULL(T[,LMUL=1]) is a native vector/group. LMUL is the number of
+// registers in the group, and is ignored on targets that do not support groups.
+#define HWY_FULL1(T) hwy::HWY_NAMESPACE::Simd<T, HWY_LANES(T)>
+#define HWY_3TH_ARG(arg1, arg2, arg3, ...) arg3
+// Workaround for MSVC grouping __VA_ARGS__ into a single argument
+#define HWY_FULL_RECOMPOSER(args_with_paren) HWY_3TH_ARG args_with_paren
+// Trailing comma avoids -pedantic false alarm
+#define HWY_CHOOSE_FULL(...) \
+  HWY_FULL_RECOMPOSER((__VA_ARGS__, HWY_FULL2, HWY_FULL1, ))
+#define HWY_FULL(...) HWY_CHOOSE_FULL(__VA_ARGS__())(__VA_ARGS__)
+
+// Vector of up to MAX_N lanes.
+#define HWY_CAPPED(T, MAX_N) \
+  hwy::HWY_NAMESPACE::Simd<T, HWY_MIN(MAX_N, HWY_LANES(T))>
+
+//------------------------------------------------------------------------------
+// Export user functions for static/dynamic dispatch
+
+// Evaluates to 0 inside a translation unit if it is generating anything but the
+// static target (the last one if multiple targets are enabled). Used to prevent
+// redefinitions of HWY_EXPORT. Unless foreach_target.h is included, we only
+// compile once anyway, so this is 1 unless it is or has been included.
+#ifndef HWY_ONCE
+#define HWY_ONCE 1
+#endif
+
+// HWY_STATIC_DISPATCH(FUNC_NAME) is the namespace-qualified FUNC_NAME for
+// HWY_STATIC_TARGET (the only defined namespace unless HWY_TARGET_INCLUDE is
+// defined), and can be used to deduce the return type of Choose*.
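+//
+// For example (sketch): if MyFunc is defined in project::HWY_NAMESPACE and the
+// static target is AVX2, HWY_STATIC_DISPATCH(MyFunc)(args) compiles to a call
+// of project::N_AVX2::MyFunc(args).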
+#if HWY_STATIC_TARGET == HWY_SCALAR
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SCALAR::FUNC_NAME
+#elif HWY_STATIC_TARGET == HWY_RVV
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_RVV::FUNC_NAME
+#elif HWY_STATIC_TARGET == HWY_WASM
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM::FUNC_NAME
+#elif HWY_STATIC_TARGET == HWY_NEON
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON::FUNC_NAME
+#elif HWY_STATIC_TARGET == HWY_PPC8
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC8::FUNC_NAME
+#elif HWY_STATIC_TARGET == HWY_SSE4
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE4::FUNC_NAME
+#elif HWY_STATIC_TARGET == HWY_AVX2
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX2::FUNC_NAME
+#elif HWY_STATIC_TARGET == HWY_AVX3
+#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3::FUNC_NAME
+#endif
+
+// Dynamic dispatch declarations.
+
+template <typename RetType, typename... Args>
+struct FunctionCache {
+ public:
+  typedef RetType(FunctionType)(Args...);
+
+  // A template function that when instantiated has the same signature as the
+  // function being called. This function initializes the global cache of the
+  // current supported targets mask used for dynamic dispatch and calls the
+  // appropriate function. Since this mask used for dynamic dispatch is a
+  // global cache, all the highway exported functions, even those exposed by
+  // different modules, will be initialized after this function runs for any
+  // one of those exported functions.
+  template <FunctionType* const table[]>
+  static RetType ChooseAndCall(Args... args) {
+    // If we are running here it means we need to update the chosen target.
+    chosen_target.Update();
+    return (table[chosen_target.GetIndex()])(args...);
+  }
+};
+
+// Factory function only used to infer the template parameters RetType and
+// Args from a function passed to the factory.
+template <typename RetType, typename... Args>
+FunctionCache<RetType, Args...> FunctionCacheFactory(RetType (*)(Args...)) {
+  return FunctionCache<RetType, Args...>();
+}
+
+// HWY_CHOOSE_*(FUNC_NAME) expands to the function pointer for that target or
+// nullptr if that target was not compiled.
+#if HWY_TARGETS & HWY_SCALAR
+#define HWY_CHOOSE_SCALAR(FUNC_NAME) &N_SCALAR::FUNC_NAME
+#else
+// When scalar is not present and we try to use scalar because other targets
+// were disabled at runtime, we fall back to the baseline with
+// HWY_STATIC_DISPATCH().
+#define HWY_CHOOSE_SCALAR(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME)
+#endif
+
+#if HWY_TARGETS & HWY_WASM
+#define HWY_CHOOSE_WASM(FUNC_NAME) &N_WASM::FUNC_NAME
+#else
+#define HWY_CHOOSE_WASM(FUNC_NAME) nullptr
+#endif
+
+#if HWY_TARGETS & HWY_RVV
+#define HWY_CHOOSE_RVV(FUNC_NAME) &N_RVV::FUNC_NAME
+#else
+#define HWY_CHOOSE_RVV(FUNC_NAME) nullptr
+#endif
+
+#if HWY_TARGETS & HWY_NEON
+#define HWY_CHOOSE_NEON(FUNC_NAME) &N_NEON::FUNC_NAME
+#else
+#define HWY_CHOOSE_NEON(FUNC_NAME) nullptr
+#endif
+
+#if HWY_TARGETS & HWY_PPC8
+#define HWY_CHOOSE_PPC8(FUNC_NAME) &N_PPC8::FUNC_NAME
+#else
+#define HWY_CHOOSE_PPC8(FUNC_NAME) nullptr
+#endif
+
+#if HWY_TARGETS & HWY_SSE4
+#define HWY_CHOOSE_SSE4(FUNC_NAME) &N_SSE4::FUNC_NAME
+#else
+#define HWY_CHOOSE_SSE4(FUNC_NAME) nullptr
+#endif
+
+#if HWY_TARGETS & HWY_AVX2
+#define HWY_CHOOSE_AVX2(FUNC_NAME) &N_AVX2::FUNC_NAME
+#else
+#define HWY_CHOOSE_AVX2(FUNC_NAME) nullptr
+#endif
+
+#if HWY_TARGETS & HWY_AVX3
+#define HWY_CHOOSE_AVX3(FUNC_NAME) &N_AVX3::FUNC_NAME
+#else
+#define HWY_CHOOSE_AVX3(FUNC_NAME) nullptr
+#endif
+
+#define HWY_DISPATCH_TABLE(FUNC_NAME) \
+  HWY_CONCAT(FUNC_NAME, HighwayDispatchTable)
+
+// HWY_EXPORT(FUNC_NAME); expands to a static array that is used by
+// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime. This
+// static array must be defined at the same namespace level as the function
+// it is exporting.
+// After being exported, it can be called from other parts of the same source
+// file using HWY_DYNAMIC_DISPATCH(), in particular from a function wrapper
+// like in the following example:
+//
+// #include "hwy/highway.h"
+// HWY_BEFORE_NAMESPACE();
+// namespace skeleton {
+// namespace HWY_NAMESPACE {
+//
+// void MyFunction(int a, char b, const char* c) { ... }
+//
+// // NOLINTNEXTLINE(google-readability-namespace-comments)
+// }  // namespace HWY_NAMESPACE
+// }  // namespace skeleton
+// HWY_AFTER_NAMESPACE();
+//
+// namespace skeleton {
+// HWY_EXPORT(MyFunction);  // Defines the dispatch table in this scope.
+//
+// void MyFunction(int a, char b, const char* c) {
+//   return HWY_DYNAMIC_DISPATCH(MyFunction)(a, b, c);
+// }
+// }  // namespace skeleton
+//
+
+#if HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0)
+
+// Simplified version for IDE or the dynamic dispatch case with only one target.
+// This case still uses a table, although of a single element, to provide the
+// same compile error conditions as with the dynamic dispatch case when multiple
+// targets are being compiled.
+#define HWY_EXPORT(FUNC_NAME)                                           \
+  HWY_MAYBE_UNUSED static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME))     \
+      const HWY_DISPATCH_TABLE(FUNC_NAME)[1] = {                        \
+          &HWY_STATIC_DISPATCH(FUNC_NAME)}
+#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME)
+
+#else
+
+// Dynamic dispatch case with one entry per dynamic target plus the scalar
+// mode and the initialization wrapper.
+#define HWY_EXPORT(FUNC_NAME)                                                \
+  static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME))                           \
+      const HWY_DISPATCH_TABLE(FUNC_NAME)[HWY_MAX_DYNAMIC_TARGETS + 2] = {   \
+          /* The first entry in the table initializes the global cache and   \
+           * calls the appropriate function. */                              \
+          &decltype(hwy::FunctionCacheFactory(&HWY_STATIC_DISPATCH(          \
+              FUNC_NAME)))::ChooseAndCall<HWY_DISPATCH_TABLE(FUNC_NAME)>,    \
+          HWY_CHOOSE_TARGET_LIST(FUNC_NAME),                                 \
+          HWY_CHOOSE_SCALAR(FUNC_NAME),                                      \
+      }
+#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) \
+  (*(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::chosen_target.GetIndex()]))
+
+#endif  // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0)
+
+}  // namespace hwy
+
+#endif  // HWY_HIGHWAY_INCLUDED
+
+//------------------------------------------------------------------------------
+
+// NOTE: the following definitions and ops/*.h depend on HWY_TARGET, so we want
+// to include them once per target, which is ensured by the toggle check.
+// Because ops/*.h are included under it, they do not need their own guard.
+#if defined(HWY_HIGHWAY_PER_TARGET) == defined(HWY_TARGET_TOGGLE)
+#ifdef HWY_HIGHWAY_PER_TARGET
+#undef HWY_HIGHWAY_PER_TARGET
+#else
+#define HWY_HIGHWAY_PER_TARGET
+#endif
+
+#undef HWY_FULL2
+#if HWY_TARGET == HWY_RVV
+#define HWY_FULL2(T, LMUL) hwy::HWY_NAMESPACE::Simd<T, HWY_LANES(T) * LMUL>
+#else
+#define HWY_FULL2(T, LMUL) hwy::HWY_NAMESPACE::Simd<T, HWY_LANES(T)>
+#endif
+
+// These define ops inside namespace hwy::HWY_NAMESPACE.
+#if HWY_TARGET == HWY_SSE4 +#include "hwy/ops/x86_128-inl.h" +#elif HWY_TARGET == HWY_AVX2 +#include "hwy/ops/x86_256-inl.h" +#elif HWY_TARGET == HWY_AVX3 +#include "hwy/ops/x86_512-inl.h" +#elif HWY_TARGET == HWY_PPC8 +#elif HWY_TARGET == HWY_NEON +#include "hwy/ops/arm_neon-inl.h" +#elif HWY_TARGET == HWY_WASM +#include "hwy/ops/wasm_128-inl.h" +#elif HWY_TARGET == HWY_RVV +#include "hwy/ops/rvv-inl.h" +#elif HWY_TARGET == HWY_SCALAR +#include "hwy/ops/scalar-inl.h" +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +// Commonly used functions/types that must come after ops are defined. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The lane type of a vector type, e.g. float for Vec>. +template +using LaneType = decltype(GetLane(V())); + +// Vector type, e.g. Vec128 for Simd. Useful as the return type +// of functions that do not take a vector argument, or as an argument type if +// the function only has a template argument for D, or for explicit type names +// instead of auto. This may be a built-in type. +template +using Vec = decltype(Zero(D())); + +// Mask type. Useful as the return type of functions that do not take a mask +// argument, or as an argument type if the function only has a template argument +// for D, or for explicit type names instead of auto. +template +using Mask = decltype(MaskFromVec(Zero(D()))); + +// Returns the closest value to v within [lo, hi]. +template +HWY_API V Clamp(const V v, const V lo, const V hi) { + return Min(Max(lo, v), hi); +} + +// CombineShiftRightBytes (and ..Lanes) are not available for the scalar target. +// TODO(janwas): implement for RVV +#if HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV + +template +HWY_API V CombineShiftRightLanes(const V hi, const V lo) { + return CombineShiftRightBytes)>(hi, lo); +} + +#endif + +// Returns lanes with the most significant bit set and all other bits zero. +template +HWY_API Vec SignBit(D d) { + using Unsigned = MakeUnsigned>; + const Unsigned bit = Unsigned(1) << (sizeof(Unsigned) * 8 - 1); + return BitCast(d, Set(Rebind(), bit)); +} + +// Returns quiet NaN. +template +HWY_API Vec NaN(D d) { + const RebindToSigned di; + // LimitsMax sets all exponent and mantissa bits to 1. The exponent plus + // mantissa MSB (to indicate quiet) would be sufficient. + return BitCast(d, Set(di, LimitsMax>())); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HWY_HIGHWAY_PER_TARGET diff --git a/third_party/highway/hwy/highway_test.cc b/third_party/highway/hwy/highway_test.cc new file mode 100644 index 000000000000..c80789411401 --- /dev/null +++ b/third_party/highway/hwy/highway_test.cc @@ -0,0 +1,313 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
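+
+// Tests of basic per-target functionality: Set/Iota, integer overflow
+// wraparound, SignBit, NaN propagation, copy/assign and GetLane.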
+ +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "highway_test.cc" +#include "hwy/foreach_target.h" +#include "hwy/highway.h" +#include "hwy/nanobenchmark.h" // Unpredictable1 +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestSet { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Zero + const auto v0 = Zero(d); + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + std::fill(expected.get(), expected.get() + N, T(0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), v0); + + // Set + const auto v2 = Set(d, T(2)); + for (size_t i = 0; i < N; ++i) { + expected[i] = 2; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), v2); + + // Iota + const auto vi = Iota(d, T(5)); + for (size_t i = 0; i < N; ++i) { + expected[i] = T(5 + i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), vi); + + // Undefined + const auto vu = Undefined(d); + Store(vu, d, expected.get()); + } +}; + +HWY_NOINLINE void TestAllSet() { ForAllTypes(ForPartialVectors()); } + +// Ensures wraparound (mod 2^bits) +struct TestOverflow { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(1)); + const auto vmax = Set(d, LimitsMax()); + const auto vmin = Set(d, LimitsMin()); + // Unsigned underflow / negative -> positive + HWY_ASSERT_VEC_EQ(d, vmax, vmin - v1); + // Unsigned overflow / positive -> negative + HWY_ASSERT_VEC_EQ(d, vmin, vmax + v1); + } +}; + +HWY_NOINLINE void TestAllOverflow() { + ForIntegerTypes(ForPartialVectors()); +} + +struct TestSignBitInteger { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto all = VecFromMask(d, Eq(v0, v0)); + const auto vs = SignBit(d); + const auto other = Sub(vs, Set(d, 1)); + + // Shifting left by one => overflow, equal zero + HWY_ASSERT_VEC_EQ(d, v0, Add(vs, vs)); + // Verify the lower bits are zero (only +/- and logical ops are available + // for all types) + HWY_ASSERT_VEC_EQ(d, all, Add(vs, other)); + } +}; + +struct TestSignBitFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vs = SignBit(d); + const auto vp = Set(d, 2.25); + const auto vn = Set(d, -2.25); + HWY_ASSERT_VEC_EQ(d, Or(vp, vs), vn); + HWY_ASSERT_VEC_EQ(d, AndNot(vs, vn), vp); + HWY_ASSERT_VEC_EQ(d, v0, vs); + } +}; + +HWY_NOINLINE void TestAllSignBit() { + ForIntegerTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +// std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. +template +bool IsNaN(TF f) { + MakeUnsigned bits; + memcpy(&bits, &f, sizeof(TF)); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + return bits > ExponentMask(); +} + +template +HWY_NOINLINE void AssertNaN(const D d, const V v, const char* file, int line) { + using T = TFromD; + const T lane = GetLane(v); + if (!IsNaN(lane)) { + const std::string type_name = TypeName(T(), Lanes(d)); + MakeUnsigned bits; + memcpy(&bits, &lane, sizeof(T)); + // RVV lacks PRIu64, so use size_t; double will be truncated on 32-bit. 
+ Abort(file, line, "Expected %s NaN, got %E (%zu)", type_name.c_str(), lane, + size_t(bits)); + } +} + +#define HWY_ASSERT_NAN(d, v) AssertNaN(d, v, __FILE__, __LINE__) + +struct TestNaN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + HWY_ASSERT_NAN(d, nan); + + // Arithmetic + HWY_ASSERT_NAN(d, Add(nan, v1)); + HWY_ASSERT_NAN(d, Add(v1, nan)); + HWY_ASSERT_NAN(d, Sub(nan, v1)); + HWY_ASSERT_NAN(d, Sub(v1, nan)); + HWY_ASSERT_NAN(d, Mul(nan, v1)); + HWY_ASSERT_NAN(d, Mul(v1, nan)); + HWY_ASSERT_NAN(d, Div(nan, v1)); + HWY_ASSERT_NAN(d, Div(v1, nan)); + + // FMA + HWY_ASSERT_NAN(d, MulAdd(nan, v1, v1)); + HWY_ASSERT_NAN(d, MulAdd(v1, nan, v1)); + HWY_ASSERT_NAN(d, MulAdd(v1, v1, nan)); + HWY_ASSERT_NAN(d, MulSub(nan, v1, v1)); + HWY_ASSERT_NAN(d, MulSub(v1, nan, v1)); + HWY_ASSERT_NAN(d, MulSub(v1, v1, nan)); + HWY_ASSERT_NAN(d, NegMulAdd(nan, v1, v1)); + HWY_ASSERT_NAN(d, NegMulAdd(v1, nan, v1)); + HWY_ASSERT_NAN(d, NegMulAdd(v1, v1, nan)); + HWY_ASSERT_NAN(d, NegMulSub(nan, v1, v1)); + HWY_ASSERT_NAN(d, NegMulSub(v1, nan, v1)); + HWY_ASSERT_NAN(d, NegMulSub(v1, v1, nan)); + + // Rcp/Sqrt + HWY_ASSERT_NAN(d, Sqrt(nan)); + + // Sign manipulation + HWY_ASSERT_NAN(d, Abs(nan)); + HWY_ASSERT_NAN(d, Neg(nan)); + HWY_ASSERT_NAN(d, CopySign(nan, v1)); + HWY_ASSERT_NAN(d, CopySignToAbs(nan, v1)); + + // Rounding + HWY_ASSERT_NAN(d, Ceil(nan)); + HWY_ASSERT_NAN(d, Floor(nan)); + HWY_ASSERT_NAN(d, Round(nan)); + HWY_ASSERT_NAN(d, Trunc(nan)); + + // Logical (And/AndNot/Xor will clear NaN!) + HWY_ASSERT_NAN(d, Or(nan, v1)); + + // Comparison + HWY_ASSERT(AllFalse(Eq(nan, v1))); + HWY_ASSERT(AllFalse(Gt(nan, v1))); + HWY_ASSERT(AllFalse(Lt(nan, v1))); + HWY_ASSERT(AllFalse(Ge(nan, v1))); + HWY_ASSERT(AllFalse(Le(nan, v1))); + } +}; + +// For functions only available for float32 +struct TestF32NaN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + HWY_ASSERT_NAN(d, ApproximateReciprocal(nan)); + HWY_ASSERT_NAN(d, ApproximateReciprocalSqrt(nan)); + HWY_ASSERT_NAN(d, AbsDiff(nan, v1)); + HWY_ASSERT_NAN(d, AbsDiff(v1, nan)); + } +}; + +// TODO(janwas): move to TestNaN once supported for partial vectors +struct TestFullNaN { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + + HWY_ASSERT_NAN(d, SumOfLanes(nan)); +// Reduction (pending clarification on RVV) +#if HWY_TARGET != HWY_RVV + HWY_ASSERT_NAN(d, MinOfLanes(nan)); + HWY_ASSERT_NAN(d, MaxOfLanes(nan)); +#endif + +#if HWY_ARCH_X86 && HWY_TARGET != HWY_SCALAR + // x86 SIMD returns the second operand if any input is NaN. + HWY_ASSERT_VEC_EQ(d, v1, Min(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Max(nan, v1)); + HWY_ASSERT_NAN(d, Min(v1, nan)); + HWY_ASSERT_NAN(d, Max(v1, nan)); +#elif HWY_ARCH_WASM + // Should return NaN if any input is NaN, but does not for scalar. + // TODO(janwas): remove once this is fixed. +#elif HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7 + // ARMv7 NEON returns NaN if any input is NaN. + HWY_ASSERT_NAN(d, Min(v1, nan)); + HWY_ASSERT_NAN(d, Max(v1, nan)); + HWY_ASSERT_NAN(d, Min(nan, v1)); + HWY_ASSERT_NAN(d, Max(nan, v1)); +#else + // IEEE 754-2019 minimumNumber is defined as the other argument if exactly + // one is NaN, and qNaN if both are. 
+ HWY_ASSERT_VEC_EQ(d, v1, Min(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Max(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, nan)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, nan)); +#endif + HWY_ASSERT_NAN(d, Min(nan, nan)); + HWY_ASSERT_NAN(d, Max(nan, nan)); + + // Comparison + HWY_ASSERT(AllFalse(Eq(nan, v1))); + HWY_ASSERT(AllFalse(Gt(nan, v1))); + HWY_ASSERT(AllFalse(Lt(nan, v1))); + HWY_ASSERT(AllFalse(Ge(nan, v1))); + HWY_ASSERT(AllFalse(Le(nan, v1))); + } +}; + +HWY_NOINLINE void TestAllNaN() { + ForFloatTypes(ForPartialVectors()); + ForPartialVectors()(float()); + ForFloatTypes(ForFullVectors()); +} + +struct TestCopyAndAssign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // copy V + const auto v3 = Iota(d, 3); + auto v3b(v3); + HWY_ASSERT_VEC_EQ(d, v3, v3b); + + // assign V + auto v3c = Undefined(d); + v3c = v3; + HWY_ASSERT_VEC_EQ(d, v3, v3c); + } +}; + +HWY_NOINLINE void TestAllCopyAndAssign() { + ForAllTypes(ForPartialVectors()); +} + +struct TestGetLane { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + HWY_ASSERT_EQ(T(0), GetLane(Zero(d))); + HWY_ASSERT_EQ(T(1), GetLane(Set(d, 1))); + } +}; + +HWY_NOINLINE void TestAllGetLane() { + ForAllTypes(ForPartialVectors()); +} + + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HighwayTest); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllSet); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllOverflow); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllSignBit); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllNaN); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllCopyAndAssign); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllGetLane); +} // namespace hwy +#endif diff --git a/third_party/highway/hwy/nanobenchmark.cc b/third_party/highway/hwy/nanobenchmark.cc new file mode 100644 index 000000000000..49e4e47ca116 --- /dev/null +++ b/third_party/highway/hwy/nanobenchmark.cc @@ -0,0 +1,722 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/nanobenchmark.h" + +#include +#include +#include // abort +#include // memcpy +#include // clock_gettime + +#include // sort +#include +#include +#include +#include // iota +#include +#include +#include + +#include "hwy/base.h" +#if HWY_ARCH_PPC +#include // NOLINT __ppc_get_timebase_freq +#elif HWY_ARCH_X86 + +#if HWY_COMPILER_MSVC +#include +#else +#include // NOLINT +#endif // HWY_COMPILER_MSVC + +#endif // HWY_ARCH_X86 + +namespace hwy { +namespace platform { +namespace { + +#if HWY_ARCH_X86 + +void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* HWY_RESTRICT abcd) { +#if HWY_COMPILER_MSVC + int regs[4]; + __cpuidex(regs, level, count); + for (int i = 0; i < 4; ++i) { + abcd[i] = regs[i]; + } +#else + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif +} + +std::string BrandString() { + char brand_string[49]; + std::array abcd; + + // Check if brand string is supported (it is on all reasonable Intel/AMD) + Cpuid(0x80000000U, 0, abcd.data()); + if (abcd[0] < 0x80000004U) { + return std::string(); + } + + for (size_t i = 0; i < 3; ++i) { + Cpuid(static_cast(0x80000002U + i), 0, abcd.data()); + memcpy(brand_string + i * 16, abcd.data(), sizeof(abcd)); + } + brand_string[48] = 0; + return brand_string; +} + +// Returns the frequency quoted inside the brand string. This does not +// account for throttling nor Turbo Boost. +double NominalClockRate() { + const std::string& brand_string = BrandString(); + // Brand strings include the maximum configured frequency. These prefixes are + // defined by Intel CPUID documentation. + const char* prefixes[3] = {"MHz", "GHz", "THz"}; + const double multipliers[3] = {1E6, 1E9, 1E12}; + for (size_t i = 0; i < 3; ++i) { + const size_t pos_prefix = brand_string.find(prefixes[i]); + if (pos_prefix != std::string::npos) { + const size_t pos_space = brand_string.rfind(' ', pos_prefix - 1); + if (pos_space != std::string::npos) { + const std::string digits = + brand_string.substr(pos_space + 1, pos_prefix - pos_space - 1); + return std::stod(digits) * multipliers[i]; + } + } + } + + return 0.0; +} + +#endif // HWY_ARCH_X86 + +} // namespace + +// Returns tick rate. Invariant means the tick counter frequency is independent +// of CPU throttling or sleep. May be expensive, caller should cache the result. +double InvariantTicksPerSecond() { +#if HWY_ARCH_PPC + return __ppc_get_timebase_freq(); +#elif HWY_ARCH_X86 + // We assume the TSC is invariant; it is on all recent Intel/AMD CPUs. + return NominalClockRate(); +#else + // Fall back to clock_gettime nanoseconds. + return 1E9; +#endif +} + +} // namespace platform +namespace { + +// Prevents the compiler from eliding the computations that led to "output". +template +inline void PreventElision(T&& output) { +#if HWY_COMPILER_MSVC == 0 + // Works by indicating to the compiler that "output" is being read and + // modified. The +r constraint avoids unnecessary writes to memory, but only + // works for built-in types (typically FuncOutput). + asm volatile("" : "+r"(output) : : "memory"); +#else + // MSVC does not support inline assembly anymore (and never supported GCC's + // RTL constraints). Self-assignment with #pragma optimize("off") might be + // expected to prevent elision, but it does not with MSVC 2015. Type-punning + // with volatile pointers generates inefficient code on MSVC 2017. 
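+  // Instead, store the value through a relaxed atomic: the compiler must
+  // materialize "output" for the store, which prevents eliding the
+  // computation that produced it.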
+ static std::atomic dummy(T{}); + dummy.store(output, std::memory_order_relaxed); +#endif +} + +namespace timer { + +// Start/Stop return absolute timestamps and must be placed immediately before +// and after the region to measure. We provide separate Start/Stop functions +// because they use different fences. +// +// Background: RDTSC is not 'serializing'; earlier instructions may complete +// after it, and/or later instructions may complete before it. 'Fences' ensure +// regions' elapsed times are independent of such reordering. The only +// documented unprivileged serializing instruction is CPUID, which acts as a +// full fence (no reordering across it in either direction). Unfortunately +// the latency of CPUID varies wildly (perhaps made worse by not initializing +// its EAX input). Because it cannot reliably be deducted from the region's +// elapsed time, it must not be included in the region to measure (i.e. +// between the two RDTSC). +// +// The newer RDTSCP is sometimes described as serializing, but it actually +// only serves as a half-fence with release semantics. Although all +// instructions in the region will complete before the final timestamp is +// captured, subsequent instructions may leak into the region and increase the +// elapsed time. Inserting another fence after the final RDTSCP would prevent +// such reordering without affecting the measured region. +// +// Fortunately, such a fence exists. The LFENCE instruction is only documented +// to delay later loads until earlier loads are visible. However, Intel's +// reference manual says it acts as a full fence (waiting until all earlier +// instructions have completed, and delaying later instructions until it +// completes). AMD assigns the same behavior to MFENCE. +// +// We need a fence before the initial RDTSC to prevent earlier instructions +// from leaking into the region, and arguably another after RDTSC to avoid +// region instructions from completing before the timestamp is recorded. +// When surrounded by fences, the additional RDTSCP half-fence provides no +// benefit, so the initial timestamp can be recorded via RDTSC, which has +// lower overhead than RDTSCP because it does not read TSC_AUX. In summary, +// we define Start = LFENCE/RDTSC/LFENCE; Stop = RDTSCP/LFENCE. +// +// Using Start+Start leads to higher variance and overhead than Stop+Stop. +// However, Stop+Stop includes an LFENCE in the region measurements, which +// adds a delay dependent on earlier loads. The combination of Start+Stop +// is faster than Start+Start and more consistent than Stop+Stop because +// the first LFENCE already delayed subsequent loads before the measured +// region. This combination seems not to have been considered in prior work: +// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c +// +// Note: performance counters can measure 'exact' instructions-retired or +// (unhalted) cycle counts. The RDPMC instruction is not serializing and also +// requires fences. Unfortunately, it is not accessible on all OSes and we +// prefer to avoid kernel-mode drivers. Performance counters are also affected +// by several under/over-count errata, so we use the TSC instead. + +// Returns a 64-bit timestamp in unit of 'ticks'; to convert to seconds, +// divide by InvariantTicksPerSecond. 
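+//
+// Usage sketch (illustrative only):
+//   const uint64_t t0 = timer::Start64();
+//   // ... region to measure ...
+//   const uint64_t t1 = timer::Stop64();
+//   const double seconds = (t1 - t0) / platform::InvariantTicksPerSecond();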
+inline uint64_t Start64() { + uint64_t t; +#if HWY_ARCH_PPC + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = __rdtsc(); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rdx", "memory", "cc"); +#elif HWY_ARCH_RVV + asm volatile("rdcycle %0" : "=r"(t)); +#else + // Fall back to OS - unsure how to reliably query cntvct_el0 frequency. + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + t = ts.tv_sec * 1000000000LL + ts.tv_nsec; +#endif + return t; +} + +inline uint64_t Stop64() { + uint64_t t; +#if HWY_ARCH_PPC + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + unsigned aux; + t = __rdtscp(&aux); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rcx", "rdx", "memory", "cc"); +#else + t = Start64(); +#endif + return t; +} + +// Returns a 32-bit timestamp with about 4 cycles less overhead than +// Start64. Only suitable for measuring very short regions because the +// timestamp overflows about once a second. +inline uint32_t Start32() { + uint32_t t; +#if HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = static_cast(__rdtsc()); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + : "rdx", "memory"); +#elif HWY_ARCH_RVV + asm volatile("rdcycle %0" : "=r"(t)); +#else + t = static_cast(Start64()); +#endif + return t; +} + +inline uint32_t Stop32() { + uint32_t t; +#if HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + unsigned aux; + t = static_cast(__rdtscp(&aux)); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + : "rcx", "rdx", "memory"); +#else + t = static_cast(Stop64()); +#endif + return t; +} + +} // namespace timer + +namespace robust_statistics { + +// Sorts integral values in ascending order (e.g. for Mode). About 3x faster +// than std::sort for input distributions with very few unique values. +template +void CountingSort(T* values, size_t num_values) { + // Unique values and their frequency (similar to flat_map). + using Unique = std::pair; + std::vector unique; + for (size_t i = 0; i < num_values; ++i) { + const T value = values[i]; + const auto pos = + std::find_if(unique.begin(), unique.end(), + [value](const Unique u) { return u.first == value; }); + if (pos == unique.end()) { + unique.push_back(std::make_pair(value, 1)); + } else { + ++pos->second; + } + } + + // Sort in ascending order of value (pair.first). 
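+  // (For example, input values {5, 3, 3} produce unique = {(5,1), (3,2)},
+  // sorted to {(3,2), (5,1)} and expanded back into {3, 3, 5} below.)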
+ std::sort(unique.begin(), unique.end()); + + // Write that many copies of each unique value to the array. + T* HWY_RESTRICT p = values; + for (const auto& value_count : unique) { + std::fill(p, p + value_count.second, value_count.first); + p += value_count.second; + } + NANOBENCHMARK_CHECK(p == values + num_values); +} + +// @return i in [idx_begin, idx_begin + half_count) that minimizes +// sorted[i + half_count] - sorted[i]. +template +size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin, + const size_t half_count) { + T min_range = std::numeric_limits::max(); + size_t min_idx = 0; + + for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) { + NANOBENCHMARK_CHECK(sorted[idx] <= sorted[idx + half_count]); + const T range = sorted[idx + half_count] - sorted[idx]; + if (range < min_range) { + min_range = range; + min_idx = idx; + } + } + + return min_idx; +} + +// Returns an estimate of the mode by calling MinRange on successively +// halved intervals. "sorted" must be in ascending order. This is the +// Half Sample Mode estimator proposed by Bickel in "On a fast, robust +// estimator of the mode", with complexity O(N log N). The mode is less +// affected by outliers in highly-skewed distributions than the median. +// The averaging operation below assumes "T" is an unsigned integer type. +template +T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) { + size_t idx_begin = 0; + size_t half_count = num_values / 2; + while (half_count > 1) { + idx_begin = MinRange(sorted, idx_begin, half_count); + half_count >>= 1; + } + + const T x = sorted[idx_begin + 0]; + if (half_count == 0) { + return x; + } + NANOBENCHMARK_CHECK(half_count == 1); + const T average = (x + sorted[idx_begin + 1] + 1) / 2; + return average; +} + +// Returns the mode. Side effect: sorts "values". +template +T Mode(T* values, const size_t num_values) { + CountingSort(values, num_values); + return ModeOfSorted(values, num_values); +} + +template +T Mode(T (&values)[N]) { + return Mode(&values[0], N); +} + +// Returns the median value. Side effect: sorts "values". +template +T Median(T* values, const size_t num_values) { + NANOBENCHMARK_CHECK(!values->empty()); + std::sort(values, values + num_values); + const size_t half = num_values / 2; + // Odd count: return middle + if (num_values % 2) { + return values[half]; + } + // Even count: return average of middle two. + return (values[half] + values[half - 1] + 1) / 2; +} + +// Returns a robust measure of variability. +template +T MedianAbsoluteDeviation(const T* values, const size_t num_values, + const T median) { + NANOBENCHMARK_CHECK(num_values != 0); + std::vector abs_deviations; + abs_deviations.reserve(num_values); + for (size_t i = 0; i < num_values; ++i) { + const int64_t abs = std::abs(int64_t(values[i]) - int64_t(median)); + abs_deviations.push_back(static_cast(abs)); + } + return Median(abs_deviations.data(), num_values); +} + +} // namespace robust_statistics + +// Ticks := platform-specific timer values (CPU cycles on x86). Must be +// unsigned to guarantee wraparound on overflow. 32 bit timers are faster to +// read than 64 bit. +using Ticks = uint32_t; + +// Returns timer overhead / minimum measurable difference. +Ticks TimerResolution() { + // Nested loop avoids exceeding stack/L1 capacity. 
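+  // Take the mode of kTimerSamples back-to-back Start/Stop differences, then
+  // the mode of kTimerSamples such estimates, to reject outliers caused by
+  // interrupts or thread migration.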
+ Ticks repetitions[Params::kTimerSamples]; + for (size_t rep = 0; rep < Params::kTimerSamples; ++rep) { + Ticks samples[Params::kTimerSamples]; + for (size_t i = 0; i < Params::kTimerSamples; ++i) { + const Ticks t0 = timer::Start32(); + const Ticks t1 = timer::Stop32(); + samples[i] = t1 - t0; + } + repetitions[rep] = robust_statistics::Mode(samples); + } + return robust_statistics::Mode(repetitions); +} + +static const Ticks timer_resolution = TimerResolution(); + +// Estimates the expected value of "lambda" values with a variable number of +// samples until the variability "rel_mad" is less than "max_rel_mad". +template +Ticks SampleUntilStable(const double max_rel_mad, double* rel_mad, + const Params& p, const Lambda& lambda) { + // Choose initial samples_per_eval based on a single estimated duration. + Ticks t0 = timer::Start32(); + lambda(); + Ticks t1 = timer::Stop32(); + Ticks est = t1 - t0; + static const double ticks_per_second = platform::InvariantTicksPerSecond(); + const size_t ticks_per_eval = + static_cast(ticks_per_second * p.seconds_per_eval); + size_t samples_per_eval = + est == 0 ? p.min_samples_per_eval : ticks_per_eval / est; + samples_per_eval = std::max(samples_per_eval, p.min_samples_per_eval); + + std::vector samples; + samples.reserve(1 + samples_per_eval); + samples.push_back(est); + + // Percentage is too strict for tiny differences, so also allow a small + // absolute "median absolute deviation". + const Ticks max_abs_mad = (timer_resolution + 99) / 100; + *rel_mad = 0.0; // ensure initialized + + for (size_t eval = 0; eval < p.max_evals; ++eval, samples_per_eval *= 2) { + samples.reserve(samples.size() + samples_per_eval); + for (size_t i = 0; i < samples_per_eval; ++i) { + t0 = timer::Start32(); + lambda(); + t1 = timer::Stop32(); + samples.push_back(t1 - t0); + } + + if (samples.size() >= p.min_mode_samples) { + est = robust_statistics::Mode(samples.data(), samples.size()); + } else { + // For "few" (depends also on the variance) samples, Median is safer. + est = robust_statistics::Median(samples.data(), samples.size()); + } + NANOBENCHMARK_CHECK(est != 0); + + // Median absolute deviation (mad) is a robust measure of 'variability'. + const Ticks abs_mad = robust_statistics::MedianAbsoluteDeviation( + samples.data(), samples.size(), est); + *rel_mad = static_cast(int(abs_mad)) / est; + + if (*rel_mad <= max_rel_mad || abs_mad <= max_abs_mad) { + if (p.verbose) { + printf("%6zu samples => %5u (abs_mad=%4u, rel_mad=%4.2f%%)\n", + samples.size(), est, abs_mad, *rel_mad * 100.0); + } + return est; + } + } + + if (p.verbose) { + printf( + "WARNING: rel_mad=%4.2f%% still exceeds %4.2f%% after %6zu samples.\n", + *rel_mad * 100.0, max_rel_mad * 100.0, samples.size()); + } + return est; +} + +using InputVec = std::vector; + +// Returns vector of unique input values. +InputVec UniqueInputs(const FuncInput* inputs, const size_t num_inputs) { + InputVec unique(inputs, inputs + num_inputs); + std::sort(unique.begin(), unique.end()); + unique.erase(std::unique(unique.begin(), unique.end()), unique.end()); + return unique; +} + +// Returns how often we need to call func for sufficient precision, or zero +// on failure (e.g. the elapsed time is too long for a 32-bit tick count). +size_t NumSkip(const Func func, const uint8_t* arg, const InputVec& unique, + const Params& p) { + // Min elapsed ticks for any input. + Ticks min_duration = ~0u; + + for (const FuncInput input : unique) { + // Make sure a 32-bit timer is sufficient. 
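+    // (SampleUntilStable uses the 32-bit timer, which wraps about once per
+    // second; durations approaching 2^30 ticks would be ambiguous, so detect
+    // them here with the 64-bit timer and bail out.)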
+ const uint64_t t0 = timer::Start64(); + PreventElision(func(arg, input)); + const uint64_t t1 = timer::Stop64(); + const uint64_t elapsed = t1 - t0; + if (elapsed >= (1ULL << 30)) { + fprintf(stderr, "Measurement failed: need 64-bit timer for input=%zu\n", + input); + return 0; + } + + double rel_mad; + const Ticks total = SampleUntilStable( + p.target_rel_mad, &rel_mad, p, + [func, arg, input]() { PreventElision(func(arg, input)); }); + min_duration = std::min(min_duration, total - timer_resolution); + } + + // Number of repetitions required to reach the target resolution. + const size_t max_skip = p.precision_divisor; + // Number of repetitions given the estimated duration. + const size_t num_skip = + min_duration == 0 ? 0 : (max_skip + min_duration - 1) / min_duration; + if (p.verbose) { + printf("res=%u max_skip=%zu min_dur=%u num_skip=%zu\n", timer_resolution, + max_skip, min_duration, num_skip); + } + return num_skip; +} + +// Replicates inputs until we can omit "num_skip" occurrences of an input. +InputVec ReplicateInputs(const FuncInput* inputs, const size_t num_inputs, + const size_t num_unique, const size_t num_skip, + const Params& p) { + InputVec full; + if (num_unique == 1) { + full.assign(p.subset_ratio * num_skip, inputs[0]); + return full; + } + + full.reserve(p.subset_ratio * num_skip * num_inputs); + for (size_t i = 0; i < p.subset_ratio * num_skip; ++i) { + full.insert(full.end(), inputs, inputs + num_inputs); + } + std::mt19937 rng; + std::shuffle(full.begin(), full.end(), rng); + return full; +} + +// Copies the "full" to "subset" in the same order, but with "num_skip" +// randomly selected occurrences of "input_to_skip" removed. +void FillSubset(const InputVec& full, const FuncInput input_to_skip, + const size_t num_skip, InputVec* subset) { + const size_t count = + static_cast(std::count(full.begin(), full.end(), input_to_skip)); + // Generate num_skip random indices: which occurrence to skip. + std::vector omit(count); + std::iota(omit.begin(), omit.end(), 0); + // omit[] is the same on every call, but that's OK because they identify the + // Nth instance of input_to_skip, so the position within full[] differs. + std::mt19937 rng; + std::shuffle(omit.begin(), omit.end(), rng); + omit.resize(num_skip); + std::sort(omit.begin(), omit.end()); + + uint32_t occurrence = ~0u; // 0 after preincrement + size_t idx_omit = 0; // cursor within omit[] + size_t idx_subset = 0; // cursor within *subset + for (const FuncInput next : full) { + if (next == input_to_skip) { + ++occurrence; + // Haven't removed enough already + if (idx_omit < num_skip) { + // This one is up for removal + if (occurrence == omit[idx_omit]) { + ++idx_omit; + continue; + } + } + } + if (idx_subset < subset->size()) { + (*subset)[idx_subset++] = next; + } + } + NANOBENCHMARK_CHECK(idx_subset == subset->size()); + NANOBENCHMARK_CHECK(idx_omit == omit.size()); + NANOBENCHMARK_CHECK(occurrence == count - 1); +} + +// Returns total ticks elapsed for all inputs. +Ticks TotalDuration(const Func func, const uint8_t* arg, const InputVec* inputs, + const Params& p, double* max_rel_mad) { + double rel_mad; + const Ticks duration = + SampleUntilStable(p.target_rel_mad, &rel_mad, p, [func, arg, inputs]() { + for (const FuncInput input : *inputs) { + PreventElision(func(arg, input)); + } + }); + *max_rel_mad = std::max(*max_rel_mad, rel_mad); + return duration; +} + +// (Nearly) empty Func for measuring timer overhead/resolution. 
+// (Nearly) empty Func for measuring timer overhead/resolution.
+HWY_NOINLINE FuncOutput EmptyFunc(const void* /*arg*/, const FuncInput input) {
+  return input;
+}
+
+// Returns overhead of accessing inputs[] and calling a function; this will
+// be deducted from future TotalDuration return values.
+Ticks Overhead(const uint8_t* arg, const InputVec* inputs, const Params& p) {
+  double rel_mad;
+  // Zero tolerance because repeatability is crucial and EmptyFunc is fast.
+  return SampleUntilStable(0.0, &rel_mad, p, [arg, inputs]() {
+    for (const FuncInput input : *inputs) {
+      PreventElision(EmptyFunc(arg, input));
+    }
+  });
+}
+
+}  // namespace
+
+int Unpredictable1() { return timer::Start64() != ~0ULL; }
+
+size_t Measure(const Func func, const uint8_t* arg, const FuncInput* inputs,
+               const size_t num_inputs, Result* results, const Params& p) {
+  NANOBENCHMARK_CHECK(num_inputs != 0);
+  const InputVec& unique = UniqueInputs(inputs, num_inputs);
+
+  const size_t num_skip = NumSkip(func, arg, unique, p);  // 0 on error
+  if (num_skip == 0) return 0;  // NumSkip already printed error message
+  // (slightly less work on x86 to cast from signed integer)
+  const float mul = 1.0f / static_cast<float>(static_cast<int>(num_skip));
+
+  const InputVec& full =
+      ReplicateInputs(inputs, num_inputs, unique.size(), num_skip, p);
+  InputVec subset(full.size() - num_skip);
+
+  const Ticks overhead = Overhead(arg, &full, p);
+  const Ticks overhead_skip = Overhead(arg, &subset, p);
+  if (overhead < overhead_skip) {
+    fprintf(stderr, "Measurement failed: overhead %u < %u\n", overhead,
+            overhead_skip);
+    return 0;
+  }
+
+  if (p.verbose) {
+    printf("#inputs=%5zu,%5zu overhead=%5u,%5u\n", full.size(), subset.size(),
+           overhead, overhead_skip);
+  }
+
+  double max_rel_mad = 0.0;
+  const Ticks total = TotalDuration(func, arg, &full, p, &max_rel_mad);
+
+  for (size_t i = 0; i < unique.size(); ++i) {
+    FillSubset(full, unique[i], num_skip, &subset);
+    const Ticks total_skip = TotalDuration(func, arg, &subset, p, &max_rel_mad);
+
+    if (total < total_skip) {
+      fprintf(stderr, "Measurement failed: total %u < %u\n", total, total_skip);
+      return 0;
+    }
+
+    const Ticks duration = (total - overhead) - (total_skip - overhead_skip);
+    results[i].input = unique[i];
+    results[i].ticks = static_cast<float>(duration) * mul;
+    results[i].variability = static_cast<float>(max_rel_mad);
+  }
+
+  return unique.size();
+}
+
+}  // namespace hwy
diff --git a/third_party/highway/hwy/nanobenchmark.h b/third_party/highway/hwy/nanobenchmark.h
new file mode 100644
index 000000000000..7c2db54bf569
--- /dev/null
+++ b/third_party/highway/hwy/nanobenchmark.h
@@ -0,0 +1,186 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HIGHWAY_HWY_NANOBENCHMARK_H_
+#define HIGHWAY_HWY_NANOBENCHMARK_H_
+
+// Benchmarks functions of a single integer argument with realistic branch
+// prediction hit rates. Uses a robust estimator to summarize the measurements.
+// The precision is about 0.2%.
+//
+// Examples: see nanobenchmark_test.cc.
+//
+// Background: Microbenchmarks such as http://github.com/google/benchmark
+// can measure elapsed times on the order of a microsecond. Shorter functions
+// are typically measured by repeating them thousands of times and dividing
+// the total elapsed time by this count. Unfortunately, repetition (especially
+// with the same input parameter!) influences the runtime. In time-critical
+// code, it is reasonable to expect warm instruction/data caches and TLBs,
+// but a perfect record of which branches will be taken is unrealistic.
+// Unless the application also repeatedly invokes the measured function with
+// the same parameter, the benchmark is measuring something very different -
+// a best-case result, almost as if the parameter were made a compile-time
+// constant. This may lead to erroneous conclusions about branch-heavy
+// algorithms outperforming branch-free alternatives.
+//
+// Our approach differs in three ways. Adding fences to the timer functions
+// reduces variability due to instruction reordering, improving the timer
+// resolution to about 40 CPU cycles. However, shorter functions must still
+// be invoked repeatedly. For more realistic branch prediction performance,
+// we vary the input parameter according to a user-specified distribution.
+// Thus, instead of VaryInputs(Measure(Repeat(func))), we change the
+// loop nesting to Measure(Repeat(VaryInputs(func))). We also estimate the
+// central tendency of the measurement samples with the "half sample mode",
+// which is more robust to outliers and skewed data than the mean or median.
+
+// WARNING if included from multiple translation units compiled with distinct
+// flags: this header requires textual inclusion and a predefined NB_NAMESPACE
+// macro that is unique to the current compile flags. We must also avoid
+// standard library headers such as <vector> and <functional> that define
+// functions.
+
+#include <stddef.h>
+#include <stdint.h>
+
+// Enables sanity checks that verify correct operation at the cost of
+// longer benchmark runs.
+#ifndef NANOBENCHMARK_ENABLE_CHECKS
+#define NANOBENCHMARK_ENABLE_CHECKS 0
+#endif
+
+#define NANOBENCHMARK_CHECK_ALWAYS(condition)                             \
+  while (!(condition)) {                                                  \
+    fprintf(stderr, "Nanobenchmark check failed at line %d\n", __LINE__); \
+    abort();                                                              \
+  }
+
+#if NANOBENCHMARK_ENABLE_CHECKS
+#define NANOBENCHMARK_CHECK(condition) NANOBENCHMARK_CHECK_ALWAYS(condition)
+#else
+#define NANOBENCHMARK_CHECK(condition)
+#endif
+
+namespace hwy {
+
+namespace platform {
+
+// Returns tick rate, useful for converting measurements to seconds. Invariant
+// means the tick counter frequency is independent of CPU throttling or sleep.
+// This call may be expensive; callers should cache the result.
+double InvariantTicksPerSecond();
+
+}  // namespace platform
+
+// Returns 1, but without the compiler knowing what the value is. This prevents
+// optimizing out code.
+int Unpredictable1();
+
+// Input influencing the function being measured (e.g. number of bytes to copy).
+using FuncInput = size_t;
+
+// "Proof of work" returned by Func to ensure the compiler does not elide it.
+using FuncOutput = uint64_t;
+
+// Function to measure: either 1) a captureless lambda or function with two
+// arguments or 2) a lambda with capture, in which case the first argument
+// is reserved for use by MeasureClosure.
+using Func = FuncOutput (*)(const void*, FuncInput);
+
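+// For illustration only (SumBytes is hypothetical, not part of the upstream
+// sources), a conforming Func could look like this; the return value must
+// depend on the computation so the call cannot be optimized away:
+//   FuncOutput SumBytes(const void* arg, FuncInput num_bytes) {
+//     const uint8_t* bytes = static_cast<const uint8_t*>(arg);
+//     FuncOutput sum = 0;
+//     for (FuncInput i = 0; i < num_bytes; ++i) sum += bytes[i];
+//     return sum;  // "proof of work"
+//   }
+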
+// Internal parameters that determine precision/resolution/measuring time.
+struct Params {
+  // For measuring timer overhead/resolution. Used in a nested loop =>
+  // quadratic time, acceptable because we know timer overhead is "low".
+  // constexpr because this is used to define array bounds.
+  static constexpr size_t kTimerSamples = 256;
+
+  // Best-case precision, expressed as a divisor of the timer resolution.
+  // Larger => more calls to Func and higher precision.
+  size_t precision_divisor = 1024;
+
+  // Ratio between full and subset input distribution sizes. Cannot be less
+  // than 2; larger values increase measurement time but more faithfully
+  // model the given input distribution.
+  size_t subset_ratio = 2;
+
+  // Together with the estimated Func duration, determines how many times to
+  // call Func before checking the sample variability. Larger values increase
+  // measurement time, memory/cache use and precision.
+  double seconds_per_eval = 4E-3;
+
+  // The minimum number of samples before estimating the central tendency.
+  size_t min_samples_per_eval = 7;
+
+  // The mode is better than median for estimating the central tendency of
+  // skewed/fat-tailed distributions, but it requires sufficient samples
+  // relative to the width of half-ranges.
+  size_t min_mode_samples = 64;
+
+  // Maximum permissible variability (= median absolute deviation / center).
+  double target_rel_mad = 0.002;
+
+  // Abort after this many evals without reaching target_rel_mad. This
+  // prevents infinite loops.
+  size_t max_evals = 9;
+
+  // Whether to print additional statistics to stdout.
+  bool verbose = true;
+};
+
+// Measurement result for each unique input.
+struct Result {
+  FuncInput input;
+
+  // Robust estimate (mode or median) of duration.
+  float ticks;
+
+  // Measure of variability (median absolute deviation relative to "ticks").
+  float variability;
+};
+
+// Precisely measures the number of ticks elapsed when calling "func" with the
+// given inputs, shuffled to ensure realistic branch prediction hit rates.
+//
+// "func" returns a 'proof of work' to ensure its computations are not elided.
+// "arg" is passed to Func, or reserved for internal use by MeasureClosure.
+// "inputs" is an array of "num_inputs" (not necessarily unique) arguments to
+// "func". The values should be chosen to maximize coverage of "func". This
+// represents a distribution, so a value's frequency should reflect its
+// probability in the real application. Order does not matter; for example, a
+// uniform distribution over [0, 4) could be represented as {3,0,2,1}.
+// Returns how many Result were written to "results": one per unique input, or
+// zero if the measurement failed (an error message goes to stderr).
+size_t Measure(const Func func, const uint8_t* arg, const FuncInput* inputs,
+               const size_t num_inputs, Result* results,
+               const Params& p = Params());
+
+// Calls operator() of the given closure (lambda function).
+template <class Closure>
+static FuncOutput CallClosure(const Closure* f, const FuncInput input) {
+  return (*f)(input);
+}
+
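+// Example use of MeasureClosure (declared below); illustrative only, see
+// nanobenchmark_test.cc for the canonical usage:
+//   const FuncInput inputs[] = {3, 0, 2, 1};  // models the input distribution
+//   Result results[4];
+//   const size_t num_results = MeasureClosure(
+//       [](FuncInput in) -> FuncOutput { return in + 1; }, inputs, 4, results);
+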
+// Same as Measure, except "closure" is typically a lambda function of
+// FuncInput -> FuncOutput with a capture list.
+template <class Closure>
+static inline size_t MeasureClosure(const Closure& closure,
+                                    const FuncInput* inputs,
+                                    const size_t num_inputs, Result* results,
+                                    const Params& p = Params()) {
+  return Measure(reinterpret_cast<Func>(&CallClosure<Closure>),
+                 reinterpret_cast<const uint8_t*>(&closure), inputs,
+                 num_inputs, results, p);
+}
+
+}  // namespace hwy
+
+#endif  // HIGHWAY_HWY_NANOBENCHMARK_H_
diff --git a/third_party/highway/hwy/nanobenchmark_test.cc b/third_party/highway/hwy/nanobenchmark_test.cc
new file mode 100644
index 000000000000..f73d773dcb8e
--- /dev/null
+++ b/third_party/highway/hwy/nanobenchmark_test.cc
@@ -0,0 +1,104 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/nanobenchmark.h"
+
+#include <stdio.h>
+#include <stdlib.h>  // strtol
+#include <unistd.h>  // sleep
+
+#include <random>
+
+namespace hwy {
+namespace {
+
+FuncOutput Div(const void*, FuncInput in) {
+  // Here we're measuring the throughput because benchmark invocations are
+  // independent. Any dividend will do; the divisor is nonzero.
+  return 0xFFFFF / in;
+}
+
+template <size_t N>
+void MeasureDiv(const FuncInput (&inputs)[N]) {
+  Result results[N];
+  Params params;
+  params.max_evals = 4;  // avoid test timeout
+  const size_t num_results = Measure(&Div, nullptr, inputs, N, results, params);
+  for (size_t i = 0; i < num_results; ++i) {
+    printf("%5zu: %6.2f ticks; MAD=%4.2f%%\n", results[i].input,
+           results[i].ticks, results[i].variability * 100.0);
+  }
+}
+
+std::mt19937 rng;
+
+// A function whose runtime depends on rng.
+FuncOutput Random(const void* /*arg*/, FuncInput in) {
+  const size_t r = rng() & 0xF;
+  uint32_t ret = in;
+  for (size_t i = 0; i < r; ++i) {
+    ret /= ((rng() & 1) + 2);
+  }
+  return ret;
+}
+
+// Ensure the measured variability is high.
+template <size_t N>
+void MeasureRandom(const FuncInput (&inputs)[N]) {
+  Result results[N];
+  Params p;
+  p.max_evals = 4;  // avoid test timeout
+  p.verbose = false;
+  const size_t num_results = Measure(&Random, nullptr, inputs, N, results, p);
+  for (size_t i = 0; i < num_results; ++i) {
+    NANOBENCHMARK_CHECK(results[i].variability > 1E-3);
+  }
+}
+
+template <size_t N>
+void EnsureLongMeasurementFails(const FuncInput (&inputs)[N]) {
+  printf("Expect a 'measurement failed' below:\n");
+  Result results[N];
+
+  const size_t num_results = Measure(
+      [](const void*, const FuncInput input) -> FuncOutput {
+        // Loop until the sleep succeeds (not interrupted by signal). We assume
+        // >= 512 MHz, so 2 seconds will exceed the 1 << 30 tick safety limit.
+        while (sleep(2) != 0) {
+        }
+        return input;
+      },
+      nullptr, inputs, N, results);
+  NANOBENCHMARK_CHECK(num_results == 0);
+  (void)num_results;
+}
+
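+// Illustrative note (not part of the upstream test): Result::ticks can be
+// converted to seconds via the tick rate, e.g.
+//   const double seconds =
+//       results[i].ticks / platform::InvariantTicksPerSecond();
+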
+void RunAll(const int argc, char** /*argv*/) {
+  // unpredictable == 1 but the compiler doesn't know that.
+  const int unpredictable = argc != 999;
+  static const FuncInput inputs[] = {static_cast<FuncInput>(unpredictable) + 2,
+                                     static_cast<FuncInput>(unpredictable + 9)};
+
+  MeasureDiv(inputs);
+  MeasureRandom(inputs);
+  EnsureLongMeasurementFails(inputs);
+}
+
+}  // namespace
+}  // namespace hwy
+
+int main(int argc, char* argv[]) {
+  hwy::RunAll(argc, argv);
+  return 0;
+}
diff --git a/third_party/highway/hwy/ops/arm_neon-inl.h b/third_party/highway/hwy/ops/arm_neon-inl.h
new file mode 100644
index 000000000000..b8f2350098aa
--- /dev/null
+++ b/third_party/highway/hwy/ops/arm_neon-inl.h
@@ -0,0 +1,4272 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// 128-bit ARM64 NEON vectors and operations.
+// External include guard in highway.h - see comment there.
+
+#include <arm_neon.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/base.h"
+#include "hwy/ops/shared-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+// Macros used to define single and double function calls for multiple types
+// for full and half vectors. These macros are undefined at the end of the file.
+
+// HWY_NEON_BUILD_TPL_* is the template<...> prefix to the function.
+#define HWY_NEON_BUILD_TPL_1
+#define HWY_NEON_BUILD_TPL_2
+#define HWY_NEON_BUILD_TPL_3
+
+// HWY_NEON_BUILD_RET_* is the return type.
+#define HWY_NEON_BUILD_RET_1(type, size) Vec128<type, size>
+#define HWY_NEON_BUILD_RET_2(type, size) Vec128<type, size>
+#define HWY_NEON_BUILD_RET_3(type, size) Vec128<type, size>
+
+// HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives.
+#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128<type, size> a
+#define HWY_NEON_BUILD_PARAM_2(type, size) \
+  const Vec128<type, size> a, const Vec128<type, size> b
+#define HWY_NEON_BUILD_PARAM_3(type, size)                \
+  const Vec128<type, size> a, const Vec128<type, size> b, \
+      const Vec128<type, size> c
+
+// HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying
+// function.
+#define HWY_NEON_BUILD_ARG_1 a.raw
+#define HWY_NEON_BUILD_ARG_2 a.raw, b.raw
+#define HWY_NEON_BUILD_ARG_3 a.raw, b.raw, c.raw
+
+// We use HWY_NEON_EVAL(func, ...) to delay the evaluation of func until after
+// the __VA_ARGS__ have been expanded. This allows "func" to itself be a macro,
+// as is the case for some of the library "functions" such as vshlq_u8. For
+// example, HWY_NEON_EVAL(vshlq_u8, MY_PARAMS) where MY_PARAMS is defined as
+// "a, b" (without the quotes) will end up expanding "vshlq_u8(a, b)" if
+// needed. Directly writing vshlq_u8(MY_PARAMS) would fail since the vshlq_u8()
+// macro expects two arguments.
+#define HWY_NEON_EVAL(func, ...) func(__VA_ARGS__)
+
+// Main macro definition that defines a single function for the given type and
+// size of vector, using the underlying (prefix##infix##suffix) function and
+// the template, return type, parameters and arguments defined by the "args"
+// parameters passed here (see HWY_NEON_BUILD_* macros defined before).
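+// For example (illustrative expansion, not literal preprocessor output),
+//   HWY_NEON_DEF_FUNCTION(uint8_t, 16, operator+, vaddq, _, u8, 2)
+// expands roughly to:
+//   HWY_INLINE Vec128<uint8_t, 16> operator+(const Vec128<uint8_t, 16> a,
+//                                            const Vec128<uint8_t, 16> b) {
+//     return Vec128<uint8_t, 16>(vaddq_u8(a.raw, b.raw));
+//   }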
+#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_CONCAT(HWY_NEON_BUILD_TPL_, args) \ + HWY_INLINE HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size) \ + name(HWY_CONCAT(HWY_NEON_BUILD_PARAM_, args)(type, size)) { \ + return HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)( \ + HWY_NEON_EVAL(prefix##infix##suffix, HWY_NEON_BUILD_ARG_##args)); \ + } + +// The HWY_NEON_DEF_FUNCTION_* macros define all the variants of a function +// called "name" using the set of neon functions starting with the given +// "prefix" for all the variants of certain types, as specified next to each +// macro. For example, the prefix "vsub" can be used to define the operator- +// using args=2. + +// uint8_t +#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8_t, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8_t, 8, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8_t, 4, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8_t, 2, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8_t, 1, name, prefix, infix, u8, args) + +// int8_t +#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int8_t, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8_t, 8, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8_t, 4, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8_t, 2, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8_t, 1, name, prefix, infix, s8, args) + +// uint16_t +#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint16_t, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16_t, 4, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16_t, 2, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16_t, 1, name, prefix, infix, u16, args) + +// int16_t +#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int16_t, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16_t, 4, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16_t, 2, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16_t, 1, name, prefix, infix, s16, args) + +// uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint32_t, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32_t, 2, name, prefix, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32_t, 1, name, prefix, infix, u32, args) + +// int32_t +#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int32_t, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32_t, 2, name, prefix, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32_t, 1, name, prefix, infix, s32, args) + +// uint64_t +#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64_t, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(uint64_t, 1, name, prefix, infix, u64, args) + +// int64_t +#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64_t, 2, name, prefix##q, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(int64_t, 1, name, prefix, infix, s64, args) + +// float and double +#if HWY_ARCH_ARM_A64 +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float, 2, 
name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float, 1, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(double, 2, name, prefix##q, infix, f64, args) \ + HWY_NEON_DEF_FUNCTION(double, 1, name, prefix, infix, f64, args) +#else +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float, 2, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float, 1, name, prefix, infix, f32, args) +#endif + +// Helper macros to define for more than one type. +// uint8_t, uint16_t and uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) + +// int8_t, int16_t and int32_t +#define HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) + +// uint8_t, uint16_t, uint32_t and uint64_t +#define HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) + +// int8_t, int16_t, int32_t and int64_t +#define HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) + +// All int*_t and uint*_t up to 64 +#define HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) + +// All previous types. +#define HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) + +// Emulation of some intrinsics on armv7. 
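+// For example, aarch64's vuzp1_u8(x, y) returns the even-indexed lanes of the
+// concatenation of x and y; armv7 only provides the combined vuzp_u8(x, y),
+// whose .val[0]/.val[1] members hold the vuzp1/vuzp2 results respectively.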
+#if HWY_ARCH_ARM_V7
+#define vuzp1_s8(x, y) vuzp_s8(x, y).val[0]
+#define vuzp1_u8(x, y) vuzp_u8(x, y).val[0]
+#define vuzp1_s16(x, y) vuzp_s16(x, y).val[0]
+#define vuzp1_u16(x, y) vuzp_u16(x, y).val[0]
+#define vuzp1_s32(x, y) vuzp_s32(x, y).val[0]
+#define vuzp1_u32(x, y) vuzp_u32(x, y).val[0]
+#define vuzp1_f32(x, y) vuzp_f32(x, y).val[0]
+#define vuzp1q_s8(x, y) vuzpq_s8(x, y).val[0]
+#define vuzp1q_u8(x, y) vuzpq_u8(x, y).val[0]
+#define vuzp1q_s16(x, y) vuzpq_s16(x, y).val[0]
+#define vuzp1q_u16(x, y) vuzpq_u16(x, y).val[0]
+#define vuzp1q_s32(x, y) vuzpq_s32(x, y).val[0]
+#define vuzp1q_u32(x, y) vuzpq_u32(x, y).val[0]
+#define vuzp1q_f32(x, y) vuzpq_f32(x, y).val[0]
+#define vuzp2_s8(x, y) vuzp_s8(x, y).val[1]
+#define vuzp2_u8(x, y) vuzp_u8(x, y).val[1]
+#define vuzp2_s16(x, y) vuzp_s16(x, y).val[1]
+#define vuzp2_u16(x, y) vuzp_u16(x, y).val[1]
+#define vuzp2_s32(x, y) vuzp_s32(x, y).val[1]
+#define vuzp2_u32(x, y) vuzp_u32(x, y).val[1]
+#define vuzp2_f32(x, y) vuzp_f32(x, y).val[1]
+#define vuzp2q_s8(x, y) vuzpq_s8(x, y).val[1]
+#define vuzp2q_u8(x, y) vuzpq_u8(x, y).val[1]
+#define vuzp2q_s16(x, y) vuzpq_s16(x, y).val[1]
+#define vuzp2q_u16(x, y) vuzpq_u16(x, y).val[1]
+#define vuzp2q_s32(x, y) vuzpq_s32(x, y).val[1]
+#define vuzp2q_u32(x, y) vuzpq_u32(x, y).val[1]
+#define vuzp2q_f32(x, y) vuzpq_f32(x, y).val[1]
+#define vzip1_s8(x, y) vzip_s8(x, y).val[0]
+#define vzip1_u8(x, y) vzip_u8(x, y).val[0]
+#define vzip1_s16(x, y) vzip_s16(x, y).val[0]
+#define vzip1_u16(x, y) vzip_u16(x, y).val[0]
+#define vzip1_f32(x, y) vzip_f32(x, y).val[0]
+#define vzip1_u32(x, y) vzip_u32(x, y).val[0]
+#define vzip1_s32(x, y) vzip_s32(x, y).val[0]
+#define vzip1q_s8(x, y) vzipq_s8(x, y).val[0]
+#define vzip1q_u8(x, y) vzipq_u8(x, y).val[0]
+#define vzip1q_s16(x, y) vzipq_s16(x, y).val[0]
+#define vzip1q_u16(x, y) vzipq_u16(x, y).val[0]
+#define vzip1q_s32(x, y) vzipq_s32(x, y).val[0]
+#define vzip1q_u32(x, y) vzipq_u32(x, y).val[0]
+#define vzip1q_f32(x, y) vzipq_f32(x, y).val[0]
+#define vzip2_s8(x, y) vzip_s8(x, y).val[1]
+#define vzip2_u8(x, y) vzip_u8(x, y).val[1]
+#define vzip2_s16(x, y) vzip_s16(x, y).val[1]
+#define vzip2_u16(x, y) vzip_u16(x, y).val[1]
+#define vzip2_s32(x, y) vzip_s32(x, y).val[1]
+#define vzip2_u32(x, y) vzip_u32(x, y).val[1]
+#define vzip2_f32(x, y) vzip_f32(x, y).val[1]
+#define vzip2q_s8(x, y) vzipq_s8(x, y).val[1]
+#define vzip2q_u8(x, y) vzipq_u8(x, y).val[1]
+#define vzip2q_s16(x, y) vzipq_s16(x, y).val[1]
+#define vzip2q_u16(x, y) vzipq_u16(x, y).val[1]
+#define vzip2q_s32(x, y) vzipq_s32(x, y).val[1]
+#define vzip2q_u32(x, y) vzipq_u32(x, y).val[1]
+#define vzip2q_f32(x, y) vzipq_f32(x, y).val[1]
+#endif
+
+template <typename T, size_t N>
+struct Raw128;
+
+// 128
+template <>
+struct Raw128<uint8_t, 16> {
+  using type = uint8x16_t;
+};
+
+template <>
+struct Raw128<uint16_t, 8> {
+  using type = uint16x8_t;
+};
+
+template <>
+struct Raw128<uint32_t, 4> {
+  using type = uint32x4_t;
+};
+
+template <>
+struct Raw128<uint64_t, 2> {
+  using type = uint64x2_t;
+};
+
+template <>
+struct Raw128<int8_t, 16> {
+  using type = int8x16_t;
+};
+
+template <>
+struct Raw128<int16_t, 8> {
+  using type = int16x8_t;
+};
+
+template <>
+struct Raw128<int32_t, 4> {
+  using type = int32x4_t;
+};
+
+template <>
+struct Raw128<int64_t, 2> {
+  using type = int64x2_t;
+};
+
+template <>
+struct Raw128<float16_t, 8> {
+  using type = uint16x8_t;
+};
+
+template <>
+struct Raw128<float, 4> {
+  using type = float32x4_t;
+};
+
+#if HWY_ARCH_ARM_A64
+template <>
+struct Raw128<double, 2> {
+  using type = float64x2_t;
+};
+#endif
+
+// 64
+template <>
+struct Raw128<uint8_t, 8> {
+  using type = uint8x8_t;
+};
+
+template <>
+struct Raw128<uint16_t, 4> {
+  using type = uint16x4_t;
+};
+
+template <>
+struct Raw128<uint32_t, 2> {
+  using type = uint32x2_t;
+};
+
+template <>
+struct Raw128<uint64_t, 1> {
+  using type = uint64x1_t;
+};
+
+template <>
+struct Raw128<int8_t, 8> {
+  using type = int8x8_t;
+};
+
+template <>
+struct Raw128<int16_t, 4> {
+  using type = int16x4_t;
+};
+
+template <>
+struct Raw128<int32_t, 2> {
+  using type = int32x2_t;
+};
+
+template <>
+struct Raw128<int64_t, 1> {
+  using type = int64x1_t;
+};
+
+template <>
+struct Raw128<float16_t, 4> {
+  using type = uint16x4_t;
+};
+
+template <>
+struct Raw128<float, 2> {
+  using type = float32x2_t;
+};
+
+#if HWY_ARCH_ARM_A64
+template <>
+struct Raw128<double, 1> {
+  using type = float64x1_t;
+};
+#endif
+
+// 32 (same as 64)
+template <>
+struct Raw128<uint8_t, 4> {
+  using type = uint8x8_t;
+};
+
+template <>
+struct Raw128<uint16_t, 2> {
+  using type = uint16x4_t;
+};
+
+template <>
+struct Raw128<uint32_t, 1> {
+  using type = uint32x2_t;
+};
+
+template <>
+struct Raw128<int8_t, 4> {
+  using type = int8x8_t;
+};
+
+template <>
+struct Raw128<int16_t, 2> {
+  using type = int16x4_t;
+};
+
+template <>
+struct Raw128<int32_t, 1> {
+  using type = int32x2_t;
+};
+
+template <>
+struct Raw128<float16_t, 2> {
+  using type = uint16x4_t;
+};
+
+template <>
+struct Raw128<float, 1> {
+  using type = float32x2_t;
+};
+
+// 16 (same as 64)
+template <>
+struct Raw128<uint8_t, 2> {
+  using type = uint8x8_t;
+};
+
+template <>
+struct Raw128<uint16_t, 1> {
+  using type = uint16x4_t;
+};
+
+template <>
+struct Raw128<int8_t, 2> {
+  using type = int8x8_t;
+};
+
+template <>
+struct Raw128<int16_t, 1> {
+  using type = int16x4_t;
+};
+
+template <>
+struct Raw128<float16_t, 1> {
+  using type = uint16x4_t;
+};
+
+// 8 (same as 64)
+template <>
+struct Raw128<uint8_t, 1> {
+  using type = uint8x8_t;
+};
+
+template <>
+struct Raw128<int8_t, 1> {
+  using type = int8x8_t;
+};
+
+template <typename T>
+using Full128 = Simd<T, 16 / sizeof(T)>;
+
+template <typename T, size_t N = 16 / sizeof(T)>
+class Vec128 {
+  using Raw = typename Raw128<T, N>::type;
+
+ public:
+  HWY_INLINE Vec128() {}
+  Vec128(const Vec128&) = default;
+  Vec128& operator=(const Vec128&) = default;
+  HWY_INLINE explicit Vec128(const Raw raw) : raw(raw) {}
+
+  // Compound assignment. Only usable if there is a corresponding non-member
+  // binary operator overload. For example, only f32 and f64 support division.
+  HWY_INLINE Vec128& operator*=(const Vec128 other) {
+    return *this = (*this * other);
+  }
+  HWY_INLINE Vec128& operator/=(const Vec128 other) {
+    return *this = (*this / other);
+  }
+  HWY_INLINE Vec128& operator+=(const Vec128 other) {
+    return *this = (*this + other);
+  }
+  HWY_INLINE Vec128& operator-=(const Vec128 other) {
+    return *this = (*this - other);
+  }
+  HWY_INLINE Vec128& operator&=(const Vec128 other) {
+    return *this = (*this & other);
+  }
+  HWY_INLINE Vec128& operator|=(const Vec128 other) {
+    return *this = (*this | other);
+  }
+  HWY_INLINE Vec128& operator^=(const Vec128 other) {
+    return *this = (*this ^ other);
+  }
+
+  Raw raw;
+};
+
+// FF..FF or 0, also for floating-point - see README.
+template <typename T, size_t N = 16 / sizeof(T)>
+class Mask128 {
+  using Raw = typename Raw128<T, N>::type;
+
+ public:
+  HWY_INLINE Mask128() {}
+  Mask128(const Mask128&) = default;
+  Mask128& operator=(const Mask128&) = default;
+  HWY_INLINE explicit Mask128(const Raw raw) : raw(raw) {}
+
+  Raw raw;
+};
+
+// ------------------------------ BitCast
+
+namespace detail {
+
+// Converts from Vec128<T, N> to Vec128<uint8_t, N * sizeof(T)> using the
+// vreinterpret*_u8_*() set of functions.
+#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8
+#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \
+  Vec128<uint8_t, size * sizeof(type)>
+#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128<type, size> v
+#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw
+
+// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined.
+template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return v; +} + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) + +// Special case for float16_t, which has the same Raw as uint16_t. +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} + +#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return v; +} + +// 64-bit or less: + +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s8_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_u16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_u32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_f32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_u64_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_s64_u8(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(vreinterpret_f64_u8(v.raw)); +} +#endif + +// 128-bit full: + +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s8_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u16_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s16_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_f32_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_u64_u8(v.raw)); +} +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_s64_u8(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 BitCastFromByte(Full128 /* tag */, + Vec128 v) { + return Vec128(vreinterpretq_f64_u8(v.raw)); +} +#endif + +// Special case for float16_t, which has the same Raw as uint16_t. 
+template +HWY_INLINE Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128(BitCastFromByte(Simd(), v).raw); +} + +} // namespace detail + +template +HWY_INLINE Vec128 BitCast( + Simd d, Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns a vector with all lanes set to "t". +#define HWY_NEON_BUILD_TPL_HWY_SET1 +#define HWY_NEON_BUILD_RET_HWY_SET1(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_SET1(type, size) \ + Simd /* tag */, const type t +#define HWY_NEON_BUILD_ARG_HWY_SET1 t + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(Set, vdup, _n_, HWY_SET1) + +#undef HWY_NEON_BUILD_TPL_HWY_SET1 +#undef HWY_NEON_BUILD_RET_HWY_SET1 +#undef HWY_NEON_BUILD_PARAM_HWY_SET1 +#undef HWY_NEON_BUILD_ARG_HWY_SET1 + +// Returns an all-zero vector. +template +HWY_INLINE Vec128 Zero(Simd d) { + return Set(d, 0); +} + +// Returns a vector with uninitialized elements. +template +HWY_INLINE Vec128 Undefined(Simd /*d*/) { + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") + typename Raw128::type a; + return Vec128(a); + HWY_DIAGNOSTICS(pop) +} + +// ------------------------------ Extract lane + +HWY_INLINE uint8_t GetLane(const Vec128 v) { + return vget_lane_u8(vget_low_u8(v.raw), 0); +} +template +HWY_INLINE uint8_t GetLane(const Vec128 v) { + return vget_lane_u8(v.raw, 0); +} + +HWY_INLINE int8_t GetLane(const Vec128 v) { + return vget_lane_s8(vget_low_s8(v.raw), 0); +} +template +HWY_INLINE int8_t GetLane(const Vec128 v) { + return vget_lane_s8(v.raw, 0); +} + +HWY_INLINE uint16_t GetLane(const Vec128 v) { + return vget_lane_u16(vget_low_u16(v.raw), 0); +} +template +HWY_INLINE uint16_t GetLane(const Vec128 v) { + return vget_lane_u16(v.raw, 0); +} + +HWY_INLINE int16_t GetLane(const Vec128 v) { + return vget_lane_s16(vget_low_s16(v.raw), 0); +} +template +HWY_INLINE int16_t GetLane(const Vec128 v) { + return vget_lane_s16(v.raw, 0); +} + +HWY_INLINE uint32_t GetLane(const Vec128 v) { + return vget_lane_u32(vget_low_u32(v.raw), 0); +} +template +HWY_INLINE uint32_t GetLane(const Vec128 v) { + return vget_lane_u32(v.raw, 0); +} + +HWY_INLINE int32_t GetLane(const Vec128 v) { + return vget_lane_s32(vget_low_s32(v.raw), 0); +} +template +HWY_INLINE int32_t GetLane(const Vec128 v) { + return vget_lane_s32(v.raw, 0); +} + +HWY_INLINE uint64_t GetLane(const Vec128 v) { + return vget_lane_u64(vget_low_u64(v.raw), 0); +} +HWY_INLINE uint64_t GetLane(const Vec128 v) { + return vget_lane_u64(v.raw, 0); +} +HWY_INLINE int64_t GetLane(const Vec128 v) { + return vget_lane_s64(vget_low_s64(v.raw), 0); +} +HWY_INLINE int64_t GetLane(const Vec128 v) { + return vget_lane_s64(v.raw, 0); +} + +HWY_INLINE float GetLane(const Vec128 v) { + return vget_lane_f32(vget_low_f32(v.raw), 0); +} +HWY_INLINE float GetLane(const Vec128 v) { + return vget_lane_f32(v.raw, 0); +} +HWY_INLINE float GetLane(const Vec128 v) { + return vget_lane_f32(v.raw, 0); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE double GetLane(const Vec128 v) { + return vget_lane_f64(vget_low_f64(v.raw), 0); +} +HWY_INLINE double GetLane(const Vec128 v) { + return vget_lane_f64(v.raw, 0); +} +#endif + +// ================================================== ARITHMETIC + +// ------------------------------ Addition +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2) + +// ------------------------------ Subtraction +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2) + +// ------------------------------ Saturating addition and subtraction +// Only 
defined for uint8_t, uint16_t and their signed versions, as in other +// architectures. + +// Returns a + b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedAdd, vqadd, _, 2) + +// Returns a - b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedSub, vqsub, _, 2) + +// Not part of API, used in implementation. +namespace detail { +HWY_NEON_DEF_FUNCTION_UINT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_64(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_64(SaturatedSub, vqsub, _, 2) +} // namespace detail + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_NEON_DEF_FUNCTION_UINT_8(AverageRound, vrhadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2) + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s8(v.raw)); +} +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s16(v.raw)); +} +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_s32(v.raw)); +} +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_f32(v.raw)); +} + +template +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s8(v.raw)); +} +template +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s16(v.raw)); +} +template +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabs_s32(v.raw)); +} +template +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabs_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabsq_f64(v.raw)); +} + +HWY_INLINE Vec128 Abs(const Vec128 v) { + return Vec128(vabs_f64(v.raw)); +} +#endif + +// ------------------------------ Neg + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below + +HWY_INLINE Vec128 Neg(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vneg_s64(v.raw)); +#else + return Zero(Simd()) - v; +#endif +} + +HWY_INLINE Vec128 Neg(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vnegq_s64(v.raw)); +#else + return Zero(Full128()) - v; +#endif +} + +// ------------------------------ ShiftLeft + +// Customize HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + template \ + HWY_INLINE Vec128 name(const Vec128 v) { \ + return kBits == 0 ? 
v \ + : Vec128(HWY_NEON_EVAL( \ + prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ + } + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, HWY_SHIFT) + +HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, HWY_SHIFT) +HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, HWY_SHIFT) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ Shl + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u8(v.raw, vreinterpretq_s8_u8(bits.raw))); +} +template +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u8(v.raw, vreinterpret_s8_u8(bits.raw))); +} + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u16(v.raw, vreinterpretq_s16_u16(bits.raw))); +} +template +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u16(v.raw, vreinterpret_s16_u16(bits.raw))); +} + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u32(v.raw, vreinterpretq_s32_u32(bits.raw))); +} +template +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u32(v.raw, vreinterpret_s32_u32(bits.raw))); +} + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); +} +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); +} + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s8(v.raw, bits.raw)); +} +template +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s8(v.raw, bits.raw)); +} + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s16(v.raw, bits.raw)); +} +template +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s16(v.raw, bits.raw)); +} + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s32(v.raw, bits.raw)); +} +template +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s32(v.raw, bits.raw)); +} + +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s64(v.raw, bits.raw)); +} +HWY_INLINE Vec128 operator<<(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s64(v.raw, bits.raw)); +} + +// ------------------------------ Shr (Neg) + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int8x16_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u8(v.raw, neg_bits)); +} +template +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int8x8_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u8(v.raw, neg_bits)); +} + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int16x8_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u16(v.raw, neg_bits)); +} +template +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int16x4_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u16(v.raw, neg_bits)); +} + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int32x4_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u32(v.raw, neg_bits)); +} +template +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int32x2_t 
neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u32(v.raw, neg_bits)); +} + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int64x2_t neg_bits = Neg(BitCast(Full128(), bits)).raw; + return Vec128(vshlq_u64(v.raw, neg_bits)); +} +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + const int64x1_t neg_bits = Neg(BitCast(Simd(), bits)).raw; + return Vec128(vshl_u64(v.raw, neg_bits)); +} + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s8(v.raw, Neg(bits).raw)); +} +template +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s32(v.raw, Neg(bits).raw)); +} +template +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_INLINE Vec128 operator>>(const Vec128 v, + const Vec128 bits) { + return Vec128(vshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ ShiftLeftSame (Shl) + +template +HWY_INLINE Vec128 ShiftLeftSame(const Vec128 v, int bits) { + return v << Set(Simd(), bits); +} +template +HWY_INLINE Vec128 ShiftRightSame(const Vec128 v, int bits) { + return v >> Set(Simd(), bits); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_u16(a.raw, b.raw)); +} +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_u32(a.raw, b.raw)); +} + +template +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_u16(a.raw, b.raw)); +} +template +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_u32(a.raw, b.raw)); +} + +// Signed +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_s16(a.raw, b.raw)); +} +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmulq_s32(a.raw, b.raw)); +} + +template +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_s16(a.raw, b.raw)); +} +template +HWY_INLINE Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128(vmul_s32(a.raw, b.raw)); +} + +// Returns the upper 16 bits of a * b in each lane. 
+HWY_INLINE Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + int32x4_t rlo = vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); +#if HWY_ARCH_ARM_A64 + int32x4_t rhi = vmull_high_s16(a.raw, b.raw); +#else + int32x4_t rhi = vmull_s16(vget_high_s16(a.raw), vget_high_s16(b.raw)); +#endif + return Vec128( + vuzp2q_s16(vreinterpretq_s16_s32(rlo), vreinterpretq_s16_s32(rhi))); +} +HWY_INLINE Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + uint32x4_t rlo = vmull_u16(vget_low_u16(a.raw), vget_low_u16(b.raw)); +#if HWY_ARCH_ARM_A64 + uint32x4_t rhi = vmull_high_u16(a.raw, b.raw); +#else + uint32x4_t rhi = vmull_u16(vget_high_u16(a.raw), vget_high_u16(b.raw)); +#endif + return Vec128( + vuzp2q_u16(vreinterpretq_u16_u32(rlo), vreinterpretq_u16_u32(rhi))); +} + +template +HWY_INLINE Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + int16x8_t hi_lo = vreinterpretq_s16_s32(vmull_s16(a.raw, b.raw)); + return Vec128(vget_low_s16(vuzp2q_s16(hi_lo, hi_lo))); +} +template +HWY_INLINE Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + uint16x8_t hi_lo = vreinterpretq_u16_u32(vmull_u16(a.raw, b.raw)); + return Vec128(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + int32x4_t a_packed = vuzp1q_s32(a.raw, a.raw); + int32x4_t b_packed = vuzp1q_s32(b.raw, b.raw); + return Vec128( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + uint32x4_t a_packed = vuzp1q_u32(a.raw, a.raw); + uint32x4_t b_packed = vuzp1q_u32(b.raw, b.raw); + return Vec128( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + int32x2_t a_packed = vuzp1_s32(a.raw, a.raw); + int32x2_t b_packed = vuzp1_s32(b.raw, b.raw); + return Vec128( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template +HWY_INLINE Vec128 MulEven(const Vec128 a, + const Vec128 b) { + uint32x2_t a_packed = vuzp1_u32(a.raw, a.raw); + uint32x2_t b_packed = vuzp1_u32(b.raw, b.raw); + return Vec128( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +// ------------------------------ Floating-point mul / div + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) + +// Approximate reciprocal +HWY_INLINE Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128(vrecpeq_f32(v.raw)); +} +template +HWY_INLINE Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128(vrecpe_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator/, vdiv, _, 2) +#else +// Emulated with approx reciprocal + Newton-Raphson + mul +template +HWY_INLINE Vec128 operator/(const Vec128 a, + const Vec128 b) { + auto x = ApproximateReciprocal(b); + // Newton-Raphson on 1/x - b + const auto two = Set(Simd(), 2); + x = x * (two - b * x); + x = x * (two - b * x); + x = x * (two - b * x); + return a * x; +} +#endif + +// Absolute value of difference. 
+HWY_INLINE Vec128 AbsDiff(const Vec128 a, const Vec128 b) { + return Vec128(vabdq_f32(a.raw, b.raw)); +} +template +HWY_INLINE Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Vec128(vabd_f32(a.raw, b.raw)); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns add + mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template +HWY_INLINE Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfma_f32(add.raw, mul.raw, x.raw)); +} +HWY_INLINE Vec128 MulAdd(const Vec128 mul, const Vec128 x, + const Vec128 add) { + return Vec128(vfmaq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. +template +HWY_INLINE Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return mul * x + add; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfma_f64(add.raw, mul.raw, x.raw)); +} +HWY_INLINE Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfmaq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns add - mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template +HWY_INLINE Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfms_f32(add.raw, mul.raw, x.raw)); +} +HWY_INLINE Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfmsq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. +template +HWY_INLINE Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return add - mul * x; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfms_f64(add.raw, mul.raw, x.raw)); +} +HWY_INLINE Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + return Vec128(vfmsq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns mul * x - sub +template +HWY_INLINE Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} + +// Returns -mul * x - sub +template +HWY_INLINE Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} + +#if HWY_ARCH_ARM_A64 +template +HWY_INLINE Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} +template +HWY_INLINE Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} +#endif + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_INLINE Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128(vrsqrteq_f32(v.raw)); +} +template +HWY_INLINE Vec128 ApproximateReciprocalSqrt( + const Vec128 v) { + return Vec128(vrsqrte_f32(v.raw)); +} + +// Full precision square root +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Sqrt, vsqrt, _, 1) +#else +// Not defined on armv7: emulate with approx reciprocal sqrt + Goldschmidt. 
+template +HWY_INLINE Vec128 Sqrt(const Vec128 v) { + auto b = v; + auto Y = ApproximateReciprocalSqrt(v); + auto x = v * Y; + const auto half = Set(Simd(), 0.5); + const auto oneandhalf = Set(Simd(), 1.5); + for (size_t i = 0; i < 3; i++) { + b = b * Y * Y; + Y = oneandhalf - half * b; + x = x * Y; + } + return IfThenZeroElse(v == Zero(Simd()), x); +} +#endif + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128{m.raw}; +} + +#define HWY_NEON_BUILD_TPL_HWY_COMPARE +#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw + +// ------------------------------ Equality +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator==, vceq, _, HWY_COMPARE) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator==, vceq, _, HWY_COMPARE) +#else +// No 64-bit comparisons on armv7: emulate them below, after Shuffle2301. +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator==, vceq, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator==, vceq, _, HWY_COMPARE) +#endif + +// ------------------------------ Strict inequality + +// Signed/float < (no unsigned) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS(operator<, vclt, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<, vclt, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<, vclt, _, HWY_COMPARE) + +// Signed/float > (no unsigned) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS(operator>, vcgt, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator>, vcgt, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator>, vcgt, _, HWY_COMPARE) + +// ------------------------------ Weak inequality + +// Float <= >= +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator>=, vcge, _, HWY_COMPARE) + +#undef HWY_NEON_BUILD_TPL_HWY_COMPARE +#undef HWY_NEON_BUILD_RET_HWY_COMPARE +#undef HWY_NEON_BUILD_PARAM_HWY_COMPARE +#undef HWY_NEON_BUILD_ARG_HWY_COMPARE + +// ================================================== LOGICAL + +// ------------------------------ Not + +// There is no 64-bit vmvn, so cast instead of using HWY_NEON_DEF_FUNCTION. +template +HWY_INLINE Vec128 Not(const Vec128 v) { + const Full128 d8; + return Vec128(vmvnq_u8(BitCast(d8, v).raw)); +} +template +HWY_INLINE Vec128 Not(const Vec128 v) { + const Repartition> d8; + return Vec128(vmvn_u8(BitCast(d8, v).raw)); +} + +// ------------------------------ And +HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) + +// Uses the u32/64 defined above. +template +HWY_INLINE Vec128 And(const Vec128 a, const Vec128 b) { + const Simd, N> d; + return BitCast(Simd(), BitCast(d, a) & BitCast(d, b)); +} + +// ------------------------------ AndNot + +namespace internal { +// reversed_andnot returns a & ~b. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2) +} // namespace internal + +// Returns ~not_mask & mask. +template +HWY_INLINE Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return internal::reversed_andnot(mask, not_mask); +} + +// Uses the u32/64 defined above. 
+template +HWY_INLINE Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + const Simd, N> du; + Vec128, N> ret = + internal::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); + return BitCast(Simd(), ret); +} + +// ------------------------------ Or + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) + +// Uses the u32/64 defined above. +template +HWY_INLINE Vec128 Or(const Vec128 a, const Vec128 b) { + const Simd, N> d; + return BitCast(Simd(), BitCast(d, a) | BitCast(d, b)); +} + +// ------------------------------ Xor + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) + +// Uses the u32/64 defined above. +template +HWY_INLINE Vec128 Xor(const Vec128 a, const Vec128 b) { + const Simd, N> d; + return BitCast(Simd(), BitCast(d, a) ^ BitCast(d, b)); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_INLINE Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_INLINE Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_INLINE Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} + +// ------------------------------ Make mask + +template +HWY_INLINE Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// Mask and Vec are the same (true = FF..FF). +template +HWY_INLINE Mask128 MaskFromVec(const Vec128 v) { + return Mask128(v.raw); +} + +template +HWY_INLINE Vec128 VecFromMask(const Mask128 v) { + return Vec128(v.raw); +} + +template +HWY_INLINE Vec128 VecFromMask(Simd /* tag */, + const Mask128 v) { + return Vec128(v.raw); +} + +// IfThenElse(mask, yes, no) +// Returns mask ? b : a. +#define HWY_NEON_BUILD_TPL_HWY_IF +#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ + const Mask128 mask, const Vec128 yes, \ + const Vec128 no +#define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) + +#undef HWY_NEON_BUILD_TPL_HWY_IF +#undef HWY_NEON_BUILD_RET_HWY_IF +#undef HWY_NEON_BUILD_PARAM_HWY_IF +#undef HWY_NEON_BUILD_ARG_HWY_IF + +// mask ? yes : 0 +template +HWY_INLINE Vec128 IfThenElseZero(const Mask128 mask, + const Vec128 yes) { + return yes & VecFromMask(Simd(), mask); +} + +// mask ? 
0 : no +template +HWY_INLINE Vec128 IfThenZeroElse(const Mask128 mask, + const Vec128 no) { + return AndNot(VecFromMask(Simd(), mask), no); +} + +template +HWY_INLINE Vec128 ZeroIfNegative(Vec128 v) { + const auto zero = Zero(Simd()); + return Max(zero, v); +} + + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +// ------------------------------ Min (IfThenElse, BroadcastSignBit) + +namespace detail { + +#if HWY_ARCH_ARM_A64 + +HWY_INLINE Vec128 Gt(Vec128 a, Vec128 b) { + return Vec128(vcgtq_u64(a.raw, b.raw)); +} +HWY_INLINE Vec128 Gt(Vec128 a, + Vec128 b) { + return Vec128(vcgt_u64(a.raw, b.raw)); +} + +HWY_INLINE Vec128 Gt(Vec128 a, Vec128 b) { + return Vec128(vcgtq_s64(a.raw, b.raw)); +} +HWY_INLINE Vec128 Gt(Vec128 a, Vec128 b) { + return Vec128(vcgt_s64(a.raw, b.raw)); +} + +#endif + +} // namespace detail + +// Unsigned +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Min, vmin, _, 2) + +template +HWY_INLINE Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(MaskFromVec(detail::Gt(a, b)), b, a); +#else + const Simd du; + const Simd di; + return BitCast(du, BitCast(di, a) - BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Min, vmin, _, 2) + +template +HWY_INLINE Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(MaskFromVec(detail::Gt(a, b)), b, a); +#else + const Vec128 sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), a, b); +#endif +} + +// Float: IEEE minimumNumber on v8, otherwise NaN if any is NaN. +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vminnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vmin, _, 2) +#endif + +// ------------------------------ Max (IfThenElse, BroadcastSignBit) + +// Unsigned (no u64) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Max, vmax, _, 2) + +template +HWY_INLINE Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(MaskFromVec(detail::Gt(a, b)), a, b); +#else + const Simd du; + const Simd di; + return BitCast(du, BitCast(di, b) + BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed (no i64) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Max, vmax, _, 2) + +template +HWY_INLINE Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(MaskFromVec(detail::Gt(a, b)), a, b); +#else + const Vec128 sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), b, a); +#endif +} + +// Float: IEEE maximumNumber on v8, otherwise NaN if any is NaN. 
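+// Example: on A64, a lane-wise Max(NaN, 2.0f) yields 2.0f because vmaxnm
+// implements IEEE maximumNumber (a quiet NaN loses to a number); the ARMv7
+// vmax fallback instead returns NaN for that lane.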
+#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmaxnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2) +#endif + +// ================================================== MEMORY + +// ------------------------------ Load 128 + +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const uint8_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_u8(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const uint16_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_u16(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const uint32_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_u32(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const uint64_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_u64(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const int8_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_s8(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const int16_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_s16(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const int32_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_s32(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const int64_t* HWY_RESTRICT aligned) { + return Vec128(vld1q_s64(aligned)); +} +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec128(vld1q_f32(aligned)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 LoadU(Full128 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec128(vld1q_f64(aligned)); +} +#endif + +// ------------------------------ Load 64 + +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const uint8_t* HWY_RESTRICT p) { + return Vec128(vld1_u8(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const uint16_t* HWY_RESTRICT p) { + return Vec128(vld1_u16(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const uint32_t* HWY_RESTRICT p) { + return Vec128(vld1_u32(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const uint64_t* HWY_RESTRICT p) { + return Vec128(vld1_u64(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const int8_t* HWY_RESTRICT p) { + return Vec128(vld1_s8(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const int16_t* HWY_RESTRICT p) { + return Vec128(vld1_s16(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const int32_t* HWY_RESTRICT p) { + return Vec128(vld1_s32(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const int64_t* HWY_RESTRICT p) { + return Vec128(vld1_s64(p)); +} +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const float* HWY_RESTRICT p) { + return Vec128(vld1_f32(p)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 LoadU(Simd /* tag */, + const double* HWY_RESTRICT p) { + return Vec128(vld1_f64(p)); +} +#endif + +// ------------------------------ Load 32 + +// In the following load functions, |a| is purposely undefined. +// It is a required parameter to the intrinsic, however +// we don't actually care what is in it, and we don't want +// to introduce extra overhead by initializing it to something. 
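+// Example: the first overload below touches exactly four bytes: it
+// reinterprets p as uint32_t*, loads one 32-bit lane into lane 0 of the
+// undefined register via vld1_lane_u32, and reinterprets the result as u8
+// lanes, so no bytes beyond the partial vector are dereferenced.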
+ +HWY_INLINE Vec128 LoadU(Simd d, + const uint8_t* HWY_RESTRICT p) { + uint32x2_t a = Undefined(d).raw; + uint32x2_t b = vld1_lane_u32(reinterpret_cast(p), a, 0); + return Vec128(vreinterpret_u8_u32(b)); +} +HWY_INLINE Vec128 LoadU(Simd d, + const uint16_t* HWY_RESTRICT p) { + uint32x2_t a = Undefined(d).raw; + uint32x2_t b = vld1_lane_u32(reinterpret_cast(p), a, 0); + return Vec128(vreinterpret_u16_u32(b)); +} +HWY_INLINE Vec128 LoadU(Simd d, + const uint32_t* HWY_RESTRICT p) { + uint32x2_t a = Undefined(d).raw; + uint32x2_t b = vld1_lane_u32(p, a, 0); + return Vec128(b); +} +HWY_INLINE Vec128 LoadU(Simd d, + const int8_t* HWY_RESTRICT p) { + int32x2_t a = Undefined(d).raw; + int32x2_t b = vld1_lane_s32(reinterpret_cast(p), a, 0); + return Vec128(vreinterpret_s8_s32(b)); +} +HWY_INLINE Vec128 LoadU(Simd d, + const int16_t* HWY_RESTRICT p) { + int32x2_t a = Undefined(d).raw; + int32x2_t b = vld1_lane_s32(reinterpret_cast(p), a, 0); + return Vec128(vreinterpret_s16_s32(b)); +} +HWY_INLINE Vec128 LoadU(Simd d, + const int32_t* HWY_RESTRICT p) { + int32x2_t a = Undefined(d).raw; + int32x2_t b = vld1_lane_s32(p, a, 0); + return Vec128(b); +} +HWY_INLINE Vec128 LoadU(Simd d, + const float* HWY_RESTRICT p) { + float32x2_t a = Undefined(d).raw; + float32x2_t b = vld1_lane_f32(p, a, 0); + return Vec128(b); +} + +// ------------------------------ Load 16 + +HWY_INLINE Vec128 LoadU(Simd d, + const uint8_t* HWY_RESTRICT p) { + uint16x4_t a = Undefined(d).raw; + uint16x4_t b = vld1_lane_u16(reinterpret_cast(p), a, 0); + return Vec128(vreinterpret_u8_u16(b)); +} +HWY_INLINE Vec128 LoadU(Simd d, + const uint16_t* HWY_RESTRICT p) { + uint16x4_t a = Undefined(d).raw; + uint16x4_t b = vld1_lane_u16(p, a, 0); + return Vec128(b); +} + +HWY_INLINE Vec128 LoadU(Simd d, + const int8_t* HWY_RESTRICT p) { + int16x4_t a = Undefined(d).raw; + int16x4_t b = vld1_lane_s16(reinterpret_cast(p), a, 0); + return Vec128(vreinterpret_s8_s16(b)); +} +HWY_INLINE Vec128 LoadU(Simd d, + const int16_t* HWY_RESTRICT p) { + int16x4_t a = Undefined(d).raw; + int16x4_t b = vld1_lane_s16(p, a, 0); + return Vec128(b); +} + +// ------------------------------ Load 8 + +HWY_INLINE Vec128 LoadU(Simd d, + const uint8_t* HWY_RESTRICT p) { + uint8x8_t a = Undefined(d).raw; + uint8x8_t b = vld1_lane_u8(p, a, 0); + return Vec128(b); +} + +HWY_INLINE Vec128 LoadU(Simd d, + const int8_t* HWY_RESTRICT p) { + int8x8_t a = Undefined(d).raw; + int8x8_t b = vld1_lane_s8(p, a, 0); + return Vec128(b); +} + +// float16_t uses the same Raw as uint16_t, so forward to that. +template +HWY_INLINE Vec128 LoadU(Simd /*d*/, + const float16_t* HWY_RESTRICT p) { + const Simd du16; + const auto pu16 = reinterpret_cast(p); + return Vec128(LoadU(du16, pu16).raw); +} + +// On ARM, Load is the same as LoadU. +template +HWY_INLINE Vec128 Load(Simd d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. 
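+// Example: on targets with vectors wider than 128 bits (e.g. x86 AVX2),
+// LoadDup128 broadcasts one 128-bit block into every block; NEON vectors are
+// a single block, so plain LoadU already matches that behavior.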
+template +HWY_INLINE Vec128 LoadDup128(Simd d, + const T* const HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store 128 + +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + uint8_t* HWY_RESTRICT aligned) { + vst1q_u8(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + uint16_t* HWY_RESTRICT aligned) { + vst1q_u16(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + uint32_t* HWY_RESTRICT aligned) { + vst1q_u32(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + uint64_t* HWY_RESTRICT aligned) { + vst1q_u64(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + int8_t* HWY_RESTRICT aligned) { + vst1q_s8(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + int16_t* HWY_RESTRICT aligned) { + vst1q_s16(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + int32_t* HWY_RESTRICT aligned) { + vst1q_s32(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + int64_t* HWY_RESTRICT aligned) { + vst1q_s64(aligned, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT aligned) { + vst1q_f32(aligned, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE void StoreU(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT aligned) { + vst1q_f64(aligned, v.raw); +} +#endif + +// ------------------------------ Store 64 + +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + uint8_t* HWY_RESTRICT p) { + vst1_u8(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + uint16_t* HWY_RESTRICT p) { + vst1_u16(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + uint32_t* HWY_RESTRICT p) { + vst1_u32(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + uint64_t* HWY_RESTRICT p) { + vst1_u64(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + int8_t* HWY_RESTRICT p) { + vst1_s8(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + int16_t* HWY_RESTRICT p) { + vst1_s16(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + int32_t* HWY_RESTRICT p) { + vst1_s32(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + int64_t* HWY_RESTRICT p) { + vst1_s64(p, v.raw); +} +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT p) { + vst1_f32(p, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE void StoreU(const Vec128 v, Simd /* tag */, + double* HWY_RESTRICT p) { + vst1_f64(p, v.raw); +} +#endif + +// ------------------------------ Store 32 + +HWY_INLINE void StoreU(const Vec128 v, Simd, + uint8_t* HWY_RESTRICT p) { + uint32x2_t a = vreinterpret_u32_u8(v.raw); + vst1_lane_u32(p, a, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + uint16_t* HWY_RESTRICT p) { + uint32x2_t a = vreinterpret_u32_u16(v.raw); + vst1_lane_u32(p, a, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + uint32_t* HWY_RESTRICT p) { + vst1_lane_u32(p, v.raw, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + int8_t* HWY_RESTRICT p) { + int32x2_t a = vreinterpret_s32_s8(v.raw); + vst1_lane_s32(p, a, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + int16_t* HWY_RESTRICT p) { + int32x2_t a = vreinterpret_s32_s16(v.raw); + vst1_lane_s32(p, a, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + int32_t* HWY_RESTRICT p) { + vst1_lane_s32(p, v.raw, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + float* HWY_RESTRICT p) { + 
vst1_lane_f32(p, v.raw, 0); +} + +// ------------------------------ Store 16 + +HWY_INLINE void StoreU(const Vec128 v, Simd, + uint8_t* HWY_RESTRICT p) { + uint16x4_t a = vreinterpret_u16_u8(v.raw); + vst1_lane_u16(p, a, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + uint16_t* HWY_RESTRICT p) { + vst1_lane_u16(p, v.raw, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + int8_t* HWY_RESTRICT p) { + int16x4_t a = vreinterpret_s16_s8(v.raw); + vst1_lane_s16(p, a, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + int16_t* HWY_RESTRICT p) { + vst1_lane_s16(p, v.raw, 0); +} + +// ------------------------------ Store 8 + +HWY_INLINE void StoreU(const Vec128 v, Simd, + uint8_t* HWY_RESTRICT p) { + vst1_lane_u8(p, v.raw, 0); +} +HWY_INLINE void StoreU(const Vec128 v, Simd, + int8_t* HWY_RESTRICT p) { + vst1_lane_s8(p, v.raw, 0); +} + +// float16_t uses the same Raw as uint16_t, so forward to that. +template +HWY_API void StoreU(Vec128 v, Simd /* tag */, + float16_t* HWY_RESTRICT p) { + const Simd du16; + const auto pu16 = reinterpret_cast(p); + return StoreU(Vec128(v.raw), du16, pu16); +} + +// On ARM, Store is the same as StoreU. +template +HWY_INLINE void Store(Vec128 v, Simd d, T* HWY_RESTRICT p) { + StoreU(v, d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_INLINE void Stream(const Vec128 v, Simd d, + T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend to full vector. +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_u8(v.raw)); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vmovl_u16(vget_low_u16(a))); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_u16(v.raw)); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_u32(v.raw)); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_u8(v.raw)); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vreinterpretq_s32_u16(vmovl_u16(vget_low_u16(a)))); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_u16(v.raw)); +} + +// Unsigned: zero-extend to half vector. 
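+// Example: a uint8_t lane holding 0x80 zero-extends to uint16_t 0x0080,
+// whereas the signed promotions further below sign-extend, so int8_t(-128)
+// (also 0x80) becomes int16_t(-128), i.e. 0xFF80.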
+template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u16(vmovl_u8(v.raw))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vget_low_u32(vmovl_u16(vget_low_u16(a)))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u32(vmovl_u16(v.raw))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_u64(vmovl_u32(v.raw))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s16(vmovl_u8(v.raw))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint16x8_t a = vmovl_u8(v.raw); + uint32x4_t b = vmovl_u16(vget_low_u16(a)); + return Vec128(vget_low_s32(vreinterpretq_s32_u32(b))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + uint32x4_t a = vmovl_u16(v.raw); + return Vec128(vget_low_s32(vreinterpretq_s32_u32(a))); +} + +// Signed: replicate sign bit to full vector. +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_s8(v.raw)); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + int16x8_t a = vmovl_s8(v.raw); + return Vec128(vmovl_s16(vget_low_s16(a))); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_s16(v.raw)); +} +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vmovl_s32(v.raw)); +} + +// Signed: replicate sign bit to half vector. +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s16(vmovl_s8(v.raw))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + int16x8_t a = vmovl_s8(v.raw); + int32x4_t b = vmovl_s16(vget_low_s16(a)); + return Vec128(vget_low_s32(b)); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s32(vmovl_s16(v.raw))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_s64(vmovl_s32(v.raw))); +} + +#if __ARM_FP & 2 + +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_f16(vreinterpret_f16_u16(v.raw))); +} +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_f32(vcvt_f32_f16(v.raw))); +} + +#else + +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + const Simd di32; + const Simd du32; + const Simd df32; + // Expand to u32 so we can shift. 
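+  // binary16 layout: 1 sign, 5 exponent (bias 15) and 10 mantissa bits;
+  // binary32: 1 sign, 8 exponent (bias 127) and 23 mantissa bits. For
+  // example, 0x3C00 (f16 1.0) has biased_exp = 15 and mantissa = 0, so the
+  // normal path yields ((15 - 15 + 127) << 23) | 0 = 0x3F800000, f32 1.0.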
+ const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +#endif + +#if HWY_ARCH_ARM_A64 + +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvt_f64_f32(v.raw)); +} + +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vget_low_f64(vcvt_f64_f32(v.raw))); +} + +HWY_INLINE Vec128 PromoteTo(Full128 /* tag */, + const Vec128 v) { + const int64x2_t i64 = vmovl_s32(v.raw); + return Vec128(vcvtq_f64_s64(i64)); +} + +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + const int64x1_t i64 = vget_low_s64(vmovl_s32(v.raw)); + return Vec128(vcvt_f64_s64(i64)); +} + +#endif + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +// From full vector to half or quarter +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s32(v.raw)); +} +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s32(v.raw)); +} +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const uint16x4_t a = vqmovun_s32(v.raw); + return Vec128(vqmovn_u16(vcombine_u16(a, a))); +} +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s16(v.raw)); +} +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const int16x4_t a = vqmovn_s32(v.raw); + return Vec128(vqmovn_s16(vcombine_s16(a, a))); +} +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s16(v.raw)); +} + +// From half vector to partial half +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); + return Vec128(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovun_s16(vcombine_s16(v.raw, v.raw))); +} +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, v.raw)); + return Vec128(vqmovn_s16(vcombine_s16(a, a))); +} +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vqmovn_s16(vcombine_s16(v.raw, v.raw))); +} + +#if __ARM_FP & 2 + +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{vreinterpret_u16_f16(vcvt_f16_f32(v.raw))}; +} +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{vcvt_f16_f32(vcombine_f32(v.raw, v.raw))}; +} + +#else + +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const Simd di; + const Simd du; + const Simd du16; + const auto bits32 = BitCast(du, v); + const auto sign = 
ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128(DemoteTo(du16, bits16).raw); +} + +#endif +#if HWY_ARCH_ARM_A64 + +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_f64(v.raw)); +} +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); +} + +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const int64x2_t i64 = vcvtq_s64_f64(v.raw); + return Vec128(vqmovn_s64(i64)); +} +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const int64x1_t i64 = vcvt_s64_f64(v.raw); + // There is no i64x1 -> i32x1 narrow, so expand to int64x2_t first. + const int64x2_t i64x2 = vcombine_s64(i64, i64); + return Vec128(vqmovn_s64(i64x2)); +} + +#endif + +HWY_API Vec128 U8FromU32(const Vec128 v) { + const uint8x16_t org_v = detail::BitCastToByte(v).raw; + const uint8x16_t w = vuzp1q_u8(org_v, org_v); + return Vec128(vget_low_u8(vuzp1q_u8(w, w))); +} +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const uint8x8_t org_v = detail::BitCastToByte(v).raw; + const uint8x8_t w = vuzp1_u8(org_v, org_v); + return Vec128(vuzp1_u8(w, w)); +} + +// In the following DemoteTo functions, |b| is purposely undefined. +// The value a needs to be extended to 128 bits so that vqmovn can be +// used and |b| is undefined so that no extra overhead is introduced. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") + +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 a = DemoteTo(Simd(), v); + Vec128 b; + uint16x8_t c = vcombine_u16(a.raw, b.raw); + return Vec128(vqmovn_u16(c)); +} + +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + Vec128 a = DemoteTo(Simd(), v); + Vec128 b; + int16x8_t c = vcombine_s16(a.raw, b.raw); + return Vec128(vqmovn_s16(c)); +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Convert integer <=> floating-point + +HWY_INLINE Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f32_s32(v.raw)); +} +template +HWY_INLINE Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f32_s32(v.raw)); +} + +// Truncates (rounds toward zero). 
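+// Example: ConvertTo(di, v) maps -1.7f to -1 and 2.9f to 2 (rounding toward
+// zero); on NEON the conversion also saturates out-of-range inputs and maps
+// NaN to 0.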
+HWY_INLINE Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_s32_f32(v.raw)); +} +template +HWY_INLINE Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_s32_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 + +HWY_INLINE Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_f64_s64(v.raw)); +} +HWY_INLINE Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_f64_s64(v.raw)); +} + +// Truncates (rounds toward zero). +HWY_INLINE Vec128 ConvertTo(Full128 /* tag */, + const Vec128 v) { + return Vec128(vcvtq_s64_f64(v.raw)); +} +HWY_INLINE Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128(vcvt_s64_f64(v.raw)); +} + +#endif + +// ------------------------------ Round (IfThenElse, mask, logical) + +#if HWY_ARCH_ARM_A64 +// Toward nearest integer +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Round, vrndn, _, 1) + +// Toward zero, aka truncate +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Trunc, vrnd, _, 1) + +// Toward +infinity, aka ceiling +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Ceil, vrndp, _, 1) + +// Toward -infinity, aka floor +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Floor, vrndm, _, 1) +#else + +// ------------------------------ Trunc + +// ARMv7 only supports truncation to integer. We can either convert back to +// float (3 floating-point and 2 logic operations) or manipulate the binary32 +// representation, clearing the lowest 23-exp mantissa bits. This requires 9 +// integer operations and 3 constants, which is likely more expensive. + +namespace detail { + +// The original value is already the desired result if NaN or the magnitude is +// large (i.e. the value is already an integer). +template +HWY_API Mask128 UseInt(const Vec128 v) { + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +template +HWY_INLINE Vec128 Trunc(const Vec128 v) { + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), int_f, v); +} + +template +HWY_INLINE Vec128 Round(const Vec128 v) { + const Simd df; + + // ARMv7 also lacks a native NearestInt, but we can instead rely on rounding + // (we assume the current mode is nearest-even) after addition with a large + // value such that no mantissa bits remain. We may need a compiler flag for + // precise floating-point to prevent this from being "optimized" out. + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +template +HWY_INLINE Vec128 Ceil(const Vec128 v) { + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +template +HWY_INLINE Vec128 Floor(const Vec128 v) { + const Simd df; + const Simd di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. 
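+  // Worked example: v = -1.5 truncates to int_f = -1.0, which is > v, so the
+  // comparison below is all-ones; read as an integer that is -1, ConvertTo
+  // makes neg1 = -1.0, and int_f + neg1 = -2.0 = floor(v).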
+  const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v)));
+
+  return IfThenElse(detail::UseInt(v), int_f + neg1, v);
+}
+
+#endif
+
+// ------------------------------ NearestInt (Round)
+
+#if HWY_ARCH_ARM_A64
+
+HWY_INLINE Vec128<int32_t> NearestInt(const Vec128<float> v) {
+  return Vec128<int32_t>(vcvtnq_s32_f32(v.raw));
+}
+template <size_t N, HWY_IF_LE64(float, N)>
+HWY_INLINE Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) {
+  return Vec128<int32_t, N>(vcvtn_s32_f32(v.raw));
+}
+
+#else
+
+template <size_t N>
+HWY_INLINE Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) {
+  const Simd<int32_t, N> di;
+  return ConvertTo(di, Round(v));
+}
+
+#endif
+
+// ================================================== SWIZZLE
+
+// ------------------------------ Extract half
+
+// <= 64 bit: just return different type
+template <typename T, size_t N, HWY_IF_LE64(T, N)>
+HWY_INLINE Vec128<T, N / 2> LowerHalf(const Vec128<T, N> v) {
+  return Vec128<T, N / 2>(v.raw);
+}
+
+HWY_INLINE Vec128<uint8_t, 8> LowerHalf(const Vec128<uint8_t> v) {
+  return Vec128<uint8_t, 8>(vget_low_u8(v.raw));
+}
+HWY_INLINE Vec128<uint16_t, 4> LowerHalf(const Vec128<uint16_t> v) {
+  return Vec128<uint16_t, 4>(vget_low_u16(v.raw));
+}
+HWY_INLINE Vec128<uint32_t, 2> LowerHalf(const Vec128<uint32_t> v) {
+  return Vec128<uint32_t, 2>(vget_low_u32(v.raw));
+}
+HWY_INLINE Vec128<uint64_t, 1> LowerHalf(const Vec128<uint64_t> v) {
+  return Vec128<uint64_t, 1>(vget_low_u64(v.raw));
+}
+HWY_INLINE Vec128<int8_t, 8> LowerHalf(const Vec128<int8_t> v) {
+  return Vec128<int8_t, 8>(vget_low_s8(v.raw));
+}
+HWY_INLINE Vec128<int16_t, 4> LowerHalf(const Vec128<int16_t> v) {
+  return Vec128<int16_t, 4>(vget_low_s16(v.raw));
+}
+HWY_INLINE Vec128<int32_t, 2> LowerHalf(const Vec128<int32_t> v) {
+  return Vec128<int32_t, 2>(vget_low_s32(v.raw));
+}
+HWY_INLINE Vec128<int64_t, 1> LowerHalf(const Vec128<int64_t> v) {
+  return Vec128<int64_t, 1>(vget_low_s64(v.raw));
+}
+HWY_INLINE Vec128<float, 2> LowerHalf(const Vec128<float> v) {
+  return Vec128<float, 2>(vget_low_f32(v.raw));
+}
+#if HWY_ARCH_ARM_A64
+HWY_INLINE Vec128<double, 1> LowerHalf(const Vec128<double> v) {
+  return Vec128<double, 1>(vget_low_f64(v.raw));
+}
+#endif
+
+HWY_INLINE Vec128<uint8_t, 8> UpperHalf(const Vec128<uint8_t> v) {
+  return Vec128<uint8_t, 8>(vget_high_u8(v.raw));
+}
+HWY_INLINE Vec128<uint16_t, 4> UpperHalf(const Vec128<uint16_t> v) {
+  return Vec128<uint16_t, 4>(vget_high_u16(v.raw));
+}
+HWY_INLINE Vec128<uint32_t, 2> UpperHalf(const Vec128<uint32_t> v) {
+  return Vec128<uint32_t, 2>(vget_high_u32(v.raw));
+}
+HWY_INLINE Vec128<uint64_t, 1> UpperHalf(const Vec128<uint64_t> v) {
+  return Vec128<uint64_t, 1>(vget_high_u64(v.raw));
+}
+HWY_INLINE Vec128<int8_t, 8> UpperHalf(const Vec128<int8_t> v) {
+  return Vec128<int8_t, 8>(vget_high_s8(v.raw));
+}
+HWY_INLINE Vec128<int16_t, 4> UpperHalf(const Vec128<int16_t> v) {
+  return Vec128<int16_t, 4>(vget_high_s16(v.raw));
+}
+HWY_INLINE Vec128<int32_t, 2> UpperHalf(const Vec128<int32_t> v) {
+  return Vec128<int32_t, 2>(vget_high_s32(v.raw));
+}
+HWY_INLINE Vec128<int64_t, 1> UpperHalf(const Vec128<int64_t> v) {
+  return Vec128<int64_t, 1>(vget_high_s64(v.raw));
+}
+HWY_INLINE Vec128<float, 2> UpperHalf(const Vec128<float> v) {
+  return Vec128<float, 2>(vget_high_f32(v.raw));
+}
+#if HWY_ARCH_ARM_A64
+HWY_INLINE Vec128<double, 1> UpperHalf(const Vec128<double> v) {
+  return Vec128<double, 1>(vget_high_f64(v.raw));
+}
+#endif
+
+// ------------------------------ Extract from 2x 128-bit at constant offset
+
+// Extracts 128 bits from <hi, lo> by skipping the least-significant kBytes.
+template <int kBytes, typename T>
+HWY_INLINE Vec128<T> CombineShiftRightBytes(const Vec128<T> hi,
+                                            const Vec128<T> lo) {
+  static_assert(0 < kBytes && kBytes < 16, "kBytes must be in [1, 15]");
+  const Full128<uint8_t> d8;
+  return BitCast(Full128<T>(),
+                 Vec128<uint8_t>(vextq_u8(BitCast(d8, lo).raw,
+                                          BitCast(d8, hi).raw, kBytes)));
+}
+
+// ------------------------------ Shift vector by constant #bytes
+
+namespace detail {
+
+// Need to partially specialize because CombineShiftRightBytes<16> and <0> are
+// compile errors.
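+// The kBytes == 0 specializations below simply return v, ensuring that the
+// invalid instantiations CombineShiftRightBytes<16 - 0> and
+// CombineShiftRightBytes<0> are never formed.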
+template +struct ShiftLeftBytesT { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return CombineShiftRightBytes<16 - kBytes>(v, Zero(Full128())); + } +}; +template <> +struct ShiftLeftBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; + +template +struct ShiftRightBytesT { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return CombineShiftRightBytes(Zero(Full128()), v); + } +}; +template <> +struct ShiftRightBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; + +} // namespace detail + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_INLINE Vec128 ShiftLeftBytes(const Vec128 v) { + return detail::ShiftLeftBytesT()(v); +} + +template +HWY_INLINE Vec128 ShiftLeftLanes(const Vec128 v) { + const Simd d8; + const Simd d; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_INLINE Vec128 ShiftRightBytes(const Vec128 v) { + return detail::ShiftRightBytesT()(v); +} + +template +HWY_INLINE Vec128 ShiftRightLanes(const Vec128 v) { + const Simd d8; + const Simd d; + return BitCast(d, ShiftRightBytes(BitCast(d8, v))); +} + +// ------------------------------ Broadcast/splat any lane + +#if HWY_ARCH_ARM_A64 +// Unsigned +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_u16(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_u32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_u64(v.raw, kLane)); +} +// Vec128 is defined below. + +// Signed +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_s16(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_s32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_s64(v.raw, kLane)); +} +// Vec128 is defined below. 
+ +// Float +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_f32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_f64(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +#else +// No vdupq_laneq_* on armv7: use vgetq_lane_* + vdupq_n_*. + +// Unsigned +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_u16(vgetq_lane_u16(v.raw, kLane))); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_u32(vgetq_lane_u32(v.raw, kLane))); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); +} +// Vec128 is defined below. + +// Signed +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_s16(vgetq_lane_s16(v.raw, kLane))); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_s32(vgetq_lane_s32(v.raw, kLane))); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); +} +// Vec128 is defined below. + +// Float +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_f32(vgetq_lane_f32(v.raw, kLane))); +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} + +#endif + +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} +template +HWY_INLINE Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +// ------------------------------ Shuffle bytes with variable indices + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). 
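+// Example: with bytes = {0x00, 0x10, 0x20, ...} and from = {2, 2, 0, ...},
+// the result starts {0x20, 0x20, 0x00, ...}; on NEON, indices >= 16 yield 0
+// (TBL/vtbl out-of-range semantics).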
+template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const Full128 d; + const Repartition d8; +#if HWY_ARCH_ARM_A64 + return BitCast(d, Vec128(vqtbl1q_u8(BitCast(d8, bytes).raw, + BitCast(d8, from).raw))); +#else + uint8x16_t table0 = BitCast(d8, bytes).raw; + uint8x8x2_t table; + table.val[0] = vget_low_u8(table0); + table.val[1] = vget_high_u8(table0); + uint8x16_t idx = BitCast(d8, from).raw; + uint8x8_t low = vtbl2_u8(table, vget_low_u8(idx)); + uint8x8_t hi = vtbl2_u8(table, vget_high_u8(idx)); + return BitCast(d, Vec128(vcombine_u8(low, hi))); +#endif +} + +template +HWY_INLINE Vec128 TableLookupBytes( + const Vec128 bytes, + const Vec128 from) { + const Simd d; + const Repartition d8; + return BitCast(d, decltype(Zero(d8))(vtbl1_u8(BitCast(d8, bytes).raw, + BitCast(d8, from).raw))); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bits +HWY_INLINE Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64_u32(v.raw)); +} +HWY_INLINE Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64_s32(v.raw)); +} +HWY_INLINE Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64_f32(v.raw)); +} +HWY_INLINE Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_u32(v.raw)); +} +HWY_INLINE Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_s32(v.raw)); +} +HWY_INLINE Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_f32(v.raw)); +} + +// Swap 64-bit halves +template +HWY_INLINE Vec128 Shuffle1032(const Vec128 v) { + return CombineShiftRightBytes<8>(v, v); +} +template +HWY_INLINE Vec128 Shuffle01(const Vec128 v) { + return CombineShiftRightBytes<8>(v, v); +} + +// Rotate right 32 bits +template +HWY_INLINE Vec128 Shuffle0321(const Vec128 v) { + return CombineShiftRightBytes<4>(v, v); +} + +// Rotate left 32 bits +template +HWY_INLINE Vec128 Shuffle2103(const Vec128 v) { + return CombineShiftRightBytes<12>(v, v); +} + +// Reverse +template +HWY_INLINE Vec128 Shuffle0123(const Vec128 v) { + static_assert(sizeof(T) == 4, + "Shuffle0123 should only be applied to 32-bit types"); + // TODO(janwas): more efficient implementation?, + // It is possible to use two instructions (vrev64q_u32 and vcombine_u32 of the + // high/low parts) instead of the extra memory and load. + static constexpr uint8_t bytes[16] = {12, 13, 14, 15, 8, 9, 10, 11, + 4, 5, 6, 7, 0, 1, 2, 3}; + const Full128 d8; + const Full128 d; + return TableLookupBytes(v, BitCast(d, Load(d8, bytes))); +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
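+// Usage sketch (illustrative, not from upstream): reversing the four lanes
+// of an int32 vector v:
+//   const int32_t idx[4] = {3, 2, 1, 0};
+//   const auto reversed =
+//       TableLookupLanes(v, SetTableIndices(Full128<int32_t>(), idx));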
+template +struct Indices128 { + typename Raw128::type raw; +}; + +template +HWY_INLINE Indices128 SetTableIndices(const Full128, const int32_t* idx) { +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) + const size_t N = 16 / sizeof(T); + for (size_t i = 0; i < N; ++i) { + HWY_DASSERT(0 <= idx[i] && idx[i] < static_cast(N)); + } +#endif + + const Full128 d8; + alignas(16) uint8_t control[16]; + for (size_t idx_byte = 0; idx_byte < 16; ++idx_byte) { + const size_t idx_lane = idx_byte / sizeof(T); + const size_t mod = idx_byte % sizeof(T); + control[idx_byte] = idx[idx_lane] * sizeof(T) + mod; + } + return Indices128{BitCast(Full128(), Load(d8, control)).raw}; +} + +HWY_INLINE Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + return TableLookupBytes(v, Vec128(idx.raw)); +} +HWY_INLINE Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + return TableLookupBytes(v, Vec128(idx.raw)); +} +HWY_INLINE Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + const Full128 di; + const Full128 df; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128(idx.raw))); +} + +// ------------------------------ Interleave lanes + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveLower, vzip1, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveLower, vzip1, _, 2) + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveUpper, vzip2, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveUpper, vzip2, _, 2) + +#if HWY_ARCH_ARM_A64 +// For 64 bit types, we only have the "q" version of the function defined as +// interleaving 64-wide registers with 64-wide types in them makes no sense. +HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_u64(a.raw, b.raw)); +} +HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_s64(a.raw, b.raw)); +} + +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vzip2q_u64(a.raw, b.raw)); +} +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vzip2q_s64(a.raw, b.raw)); +} +#else +// ARMv7 emulation. 
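+// Example: for u64, InterleaveLower(a, b) must produce {a0, b0}. The rotate
+// CombineShiftRightBytes<8>(a, a) gives flip = {a1, a0}; taking the middle
+// 128 bits of the concatenation (b : flip) then selects a0 as the lower lane
+// and b0 as the upper lane.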
+HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + auto flip = CombineShiftRightBytes<8>(a, a); + return CombineShiftRightBytes<8>(b, flip); +} +HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + auto flip = CombineShiftRightBytes<8>(a, a); + return CombineShiftRightBytes<8>(b, flip); +} + +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + auto flip = CombineShiftRightBytes<8>(b, b); + return CombineShiftRightBytes<8>(flip, a); +} +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + auto flip = CombineShiftRightBytes<8>(b, b); + return CombineShiftRightBytes<8>(flip, a); +} +#endif + +// Floats +HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_f32(a.raw, b.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128(vzip1q_f64(a.raw, b.raw)); +} +#endif + +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vzip2q_f32(a.raw, b.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vzip2q_f64(a.raw, b.raw)); +} +#endif + +// ------------------------------ Zip lanes + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. + +// Full vectors +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_u16_u8(vzip1q_u8(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_u32_u16(vzip1q_u16(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_u64_u32(vzip1q_u32(a.raw, b.raw))); +} + +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_s16_s8(vzip1q_s8(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_s32_s16(vzip1q_s16(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_s64_s32(vzip1q_s32(a.raw, b.raw))); +} + +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_u16_u8(vzip2q_u8(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_u32_u16(vzip2q_u16(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_u64_u32(vzip2q_u32(a.raw, b.raw))); +} + +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_s16_s8(vzip2q_s8(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_s32_s16(vzip2q_s16(a.raw, b.raw))); +} +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpretq_s64_s32(vzip2q_s32(a.raw, b.raw))); +} + +// Half vectors or less +template +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128( + vreinterpret_u16_u8(vzip1_u8(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128( + vreinterpret_u32_u16(vzip1_u16(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128( + vreinterpret_u64_u32(vzip1_u32(a.raw, b.raw))); +} + +template +HWY_INLINE Vec128 ZipLower(const Vec128 a, + 
const Vec128 b) { + return Vec128( + vreinterpret_s16_s8(vzip1_s8(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128( + vreinterpret_s32_s16(vzip1_s16(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128( + vreinterpret_s64_s32(vzip1_s32(a.raw, b.raw))); +} + +template +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpret_u16_u8(vzip2_u8(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpret_u32_u16(vzip2_u16(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpret_u64_u32(vzip2_u32(a.raw, b.raw))); +} + +template +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpret_s16_s8(vzip2_s8(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpret_s32_s16(vzip2_s16(a.raw, b.raw))); +} +template +HWY_INLINE Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128(vreinterpret_s64_s32(vzip2_s32(a.raw, b.raw))); +} + +// ------------------------------ Blocks + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_INLINE Vec128 ConcatLowerLower(const Vec128 hi, const Vec128 lo) { + const Full128 d64; + return BitCast(Full128(), + InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_INLINE Vec128 ConcatUpperUpper(const Vec128 hi, const Vec128 lo) { + const Full128 d64; + return BitCast(Full128(), + InterleaveUpper(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_INLINE Vec128 ConcatLowerUpper(const Vec128 hi, const Vec128 lo) { + return CombineShiftRightBytes<8>(hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_INLINE Vec128 ConcatUpperLower(const Vec128 hi, const Vec128 lo) { + // TODO(janwas): more efficient implementation? + alignas(16) const uint8_t kBytes[16] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0, 0, 0, 0}; + const auto vec = BitCast(Full128(), Load(Full128(), kBytes)); + return IfThenElse(MaskFromVec(vec), lo, hi); +} + +// ------------------------------ Odd/even lanes + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + alignas(16) constexpr uint8_t kBytes[16] = { + ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 0 : 0xFF, + ((2 / sizeof(T)) & 1) ? 0 : 0xFF, ((3 / sizeof(T)) & 1) ? 0 : 0xFF, + ((4 / sizeof(T)) & 1) ? 0 : 0xFF, ((5 / sizeof(T)) & 1) ? 0 : 0xFF, + ((6 / sizeof(T)) & 1) ? 0 : 0xFF, ((7 / sizeof(T)) & 1) ? 0 : 0xFF, + ((8 / sizeof(T)) & 1) ? 0 : 0xFF, ((9 / sizeof(T)) & 1) ? 0 : 0xFF, + ((10 / sizeof(T)) & 1) ? 0 : 0xFF, ((11 / sizeof(T)) & 1) ? 0 : 0xFF, + ((12 / sizeof(T)) & 1) ? 0 : 0xFF, ((13 / sizeof(T)) & 1) ? 0 : 0xFF, + ((14 / sizeof(T)) & 1) ? 0 : 0xFF, ((15 / sizeof(T)) & 1) ? 0 : 0xFF, + }; + const auto vec = BitCast(Full128(), Load(Full128(), kBytes)); + return IfThenElse(MaskFromVec(vec), b, a); +} + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. 
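+// Example: Iota(Full128<int32_t>(), 5) produces lanes {5, 6, 7, 8}; each
+// value is static_cast to T, so an integer "first" also works for float
+// vectors.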
+template +Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Simd(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Simd(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Simd(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Simd(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ------------------------------ ARMv7 int64 comparisons (requires Shuffle2301) + +#if HWY_ARCH_ARM_V7 + +template +HWY_INLINE Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, BitCast(d32, a) == BitCast(d32, b)); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +template +HWY_INLINE Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, BitCast(d32, a) == BitCast(d32, b)); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +HWY_INLINE Mask128 operator<(const Vec128 a, + const Vec128 b) { + const int64x2_t sub = vqsubq_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec128(sub))); +} +HWY_INLINE Mask128 operator<(const Vec128 a, + const Vec128 b) { + const int64x1_t sub = vqsub_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec128(sub))); +} + +template +HWY_INLINE Mask128 operator>(const Vec128 a, + const Vec128 b) { + return b < a; +} +#endif + +// ------------------------------ Reductions + +#if HWY_ARCH_ARM_A64 +// Supported for 32b and 64b vector types. Returns the sum in each lane. 
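+// Example: for a u32 vector holding {1, 2, 3, 4}, SumOfLanes returns
+// {10, 10, 10, 10}: vaddvq_u32 reduces to a scalar and vdupq_n_u32 splats it
+// back into every lane.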
+HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return Vec128(vdupq_n_u32(vaddvq_u32(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return Vec128(vdupq_n_s32(vaddvq_s32(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return Vec128(vdupq_n_f32(vaddvq_f32(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return Vec128(vdupq_n_u64(vaddvq_u64(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return Vec128(vdupq_n_s64(vaddvq_s64(v.raw))); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return Vec128(vdupq_n_f64(vaddvq_f64(v.raw))); +} +#else +// ARMv7 version for everything except doubles. +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + uint32x4x2_t v0 = vuzpq_u32(v.raw, v.raw); + uint32x4_t c0 = vaddq_u32(v0.val[0], v0.val[1]); + uint32x4x2_t v1 = vuzpq_u32(c0, c0); + return Vec128(vaddq_u32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + int32x4x2_t v0 = vuzpq_s32(v.raw, v.raw); + int32x4_t c0 = vaddq_s32(v0.val[0], v0.val[1]); + int32x4x2_t v1 = vuzpq_s32(c0, c0); + return Vec128(vaddq_s32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + float32x4x2_t v0 = vuzpq_f32(v.raw, v.raw); + float32x4_t c0 = vaddq_f32(v0.val[0], v0.val[1]); + float32x4x2_t v1 = vuzpq_f32(c0, c0); + return Vec128(vaddq_f32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return v + CombineShiftRightBytes<8>(v, v); +} +HWY_INLINE Vec128 SumOfLanes(const Vec128 v) { + return v + CombineShiftRightBytes<8>(v, v); +} +#endif + +namespace detail { + +// For u32/i32/f32. +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// For u64/i64[/f64]. +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +} // namespace detail + +template +HWY_API Vec128 MinOfLanes(const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint8_t kSliceLanes[16] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, + }; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(Full128(), mask)) & Load(du, kSliceLanes); + +#if HWY_ARCH_ARM_A64 + // Can't vaddv - we need two separate bytes (16 bits). + const uint8x8_t x2 = vget_low_u8(vpaddq_u8(values.raw, values.raw)); + const uint8x8_t x4 = vpadd_u8(x2, x2); + const uint8x8_t x8 = vpadd_u8(x4, x4); + return vreinterpret_u16_u8(x8)[0]; +#else + // Don't have vpaddq, so keep doubling lane size. 
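+  // Each vpaddl sums adjacent pairs while widening (16x8 -> 8x16 -> 4x32 ->
+  // 2x64), so afterwards x8[0] holds mask bits 0..7 and x8[1] bits 8..15,
+  // recombined below.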
+ const uint16x8_t x2 = vpaddlq_u8(values.raw); + const uint32x4_t x4 = vpaddlq_u16(x2); + const uint64x2_t x8 = vpaddlq_u32(x4); + return (uint64_t(x8[1]) << 8) | x8[0]; +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Simd d; + const Simd du; + const Vec128 slice(Load(Simd(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; + +#if HWY_ARCH_ARM_A64 + return vaddv_u8(values.raw); +#else + const uint16x4_t x2 = vpaddl_u8(values.raw); + const uint32x2_t x4 = vpaddl_u16(x2); + const uint64x1_t x8 = vpaddl_u32(x4); + return vget_lane_u64(x8, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint16_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u16(values.raw); +#else + const uint32x4_t x2 = vpaddlq_u16(values.raw); + const uint64x2_t x4 = vpaddlq_u32(x2); + return vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; + const Simd d; + const Simd du; + const Vec128 slice(Load(Simd(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u16(values.raw); +#else + const uint32x2_t x2 = vpaddl_u16(values.raw); + const uint64x1_t x4 = vpaddl_u32(x2); + return vget_lane_u64(x4, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + alignas(16) constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u32(values.raw); +#else + const uint64x2_t x2 = vpaddlq_u32(values.raw); + return vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. 
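+  // Example: for a one-lane u32 mask (N = 1), lane 1 of this two-lane slice
+  // is junk, but OnlyActive later ANDs the result with (1ull << 1) - 1 = 1,
+  // so only the valid bit 0 survives.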
+ alignas(8) constexpr uint32_t kSliceLanes[2] = {1, 2}; + const Simd d; + const Simd du; + const Vec128 slice(Load(Simd(), kSliceLanes).raw); + const Vec128 values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u32(values.raw); +#else + const uint64x1_t x2 = vpaddl_u32(values.raw); + return vget_lane_u64(x2, 0); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + alignas(16) constexpr uint64_t kSliceLanes[2] = {1, 2}; + const Full128 d; + const Full128 du; + const Vec128 values = + BitCast(du, VecFromMask(d, m)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u64(values.raw); +#else + return vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1); +#endif +} + +template +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 m) { + const Simd d; + const Simd du; + const Vec128 values = + BitCast(du, VecFromMask(d, m)) & Set(du, 1); + return vget_lane_u64(values.raw, 0); +} + +// Returns the lowest N for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) >= 8) ? bits : (bits & ((1ull << N) - 1)); +} + +template +HWY_INLINE uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +// Returns number of lanes whose mask is set. +// +// Masks are either FF..FF or 0. Unfortunately there is no reduce-sub op +// ("vsubv"). ANDing with 1 would work but requires a constant. Negating also +// changes each lane to 1 (if mask set) or 0. + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { + const Full128 di; + const int8x16_t ones = + vnegq_s8(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return vaddvq_s8(ones); +#else + const int16x8_t x2 = vpaddlq_s8(ones); + const int32x4_t x4 = vpaddlq_s16(x2); + const int64x2_t x8 = vpaddlq_s32(x4); + return x8[0] + x8[1]; +#endif +} +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> /*tag*/, const Mask128 mask) { + const Full128 di; + const int16x8_t ones = + vnegq_s16(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return vaddvq_s16(ones); +#else + const int32x4_t x2 = vpaddlq_s16(ones); + const int64x2_t x4 = vpaddlq_s32(x2); + return x4[0] + x4[1]; +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 mask) { + const Full128 di; + const int32x4_t ones = + vnegq_s32(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return vaddvq_s32(ones); +#else + const int64x2_t x2 = vpaddlq_s32(ones); + return x2[0] + x2[1]; +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128 mask) { +#if HWY_ARCH_ARM_A64 + const Full128 di; + const int64x2_t ones = + vnegq_s64(BitCast(di, VecFromMask(Full128(), mask)).raw); + return vaddvq_s64(ones); +#else + const Full128 di; + const int64x2_t ones = + vshrq_n_u64(BitCast(di, VecFromMask(Full128(), mask)).raw, 63); + return ones[0] + ones[1]; +#endif +} + +} // namespace detail + +// Full +template +HWY_INLINE size_t CountTrue(const Mask128 mask) { + return detail::CountTrue(hwy::SizeTag(), mask); +} + +// Partial +template +HWY_INLINE size_t CountTrue(const Mask128 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template +HWY_INLINE size_t StoreMaskBits(const Mask128 mask, uint8_t* p) { + const uint64_t bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes(&bits, p); + 
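+  // Example: N = 4 stores kNumBytes = (4 + 7) / 8 = 1 byte; bits above N are
+  // zero because only N slice lanes contribute (and OnlyActive clears any
+  // undefined upper bits for partial vectors).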
return kNumBytes; +} + +// Full +template +HWY_INLINE bool AllFalse(const Mask128 m) { +#if HWY_ARCH_ARM_A64 + return (vmaxvq_u32(m.raw) == 0); +#else + const auto v64 = BitCast(Full128(), VecFromMask(Full128(), m)); + uint32x2_t a = vqmovn_u64(v64.raw); + return vreinterpret_u64_u32(a)[0] == 0; +#endif +} + +// Partial +template +HWY_INLINE bool AllFalse(const Mask128 m) { + return detail::BitsFromMask(m) == 0; +} + +template +HWY_INLINE bool AllTrue(const Mask128 m) { + const Simd d; + return AllFalse(VecFromMask(d, m) == Zero(d)); +} + +// ------------------------------ Compress + +namespace detail { + +// Load 8 bytes, replicate into upper half so ZipLower can use the lower half. +HWY_INLINE Vec128 Load8Bytes(Full128 /*d*/, + const uint8_t* bytes) { + return Vec128(vreinterpretq_u8_u64( + vld1q_dup_u64(reinterpret_cast(bytes)))); +} + +// Load 8 bytes and return half-reg with N <= 8 bytes. +template +HWY_INLINE Vec128 Load8Bytes(Simd d, + const uint8_t* bytes) { + return Load(d, bytes); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<2> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // ARM does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
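+  // The 256 x 8 entries below are mechanical; a hedged sketch of an offline
+  // generator (hypothetical helper, not part of the library):
+  //
+  //   for (int bits = 0; bits < 256; ++bits) {
+  //     int pos = 0;
+  //     for (int lane = 0; lane < 8; ++lane) {
+  //       if (bits & (1 << lane)) table[8 * bits + pos++] = 2 * lane;
+  //     }
+  //     while (pos < 8) table[8 * bits + pos++] = 0;  // pad with lane 0
+  //   }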
+ alignas(16) constexpr uint8_t table[256 * 8] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, + 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, + 0, 6, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 2, + 6, 0, 0, 0, 0, 0, 4, 6, 0, 0, 0, 0, 0, 0, 0, 4, 6, 0, + 0, 0, 0, 0, 2, 4, 6, 0, 0, 0, 0, 0, 0, 2, 4, 6, 0, 0, + 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, + 2, 8, 0, 0, 0, 0, 0, 0, 0, 2, 8, 0, 0, 0, 0, 0, 4, 8, + 0, 0, 0, 0, 0, 0, 0, 4, 8, 0, 0, 0, 0, 0, 2, 4, 8, 0, + 0, 0, 0, 0, 0, 2, 4, 8, 0, 0, 0, 0, 6, 8, 0, 0, 0, 0, + 0, 0, 0, 6, 8, 0, 0, 0, 0, 0, 2, 6, 8, 0, 0, 0, 0, 0, + 0, 2, 6, 8, 0, 0, 0, 0, 4, 6, 8, 0, 0, 0, 0, 0, 0, 4, + 6, 8, 0, 0, 0, 0, 2, 4, 6, 8, 0, 0, 0, 0, 0, 2, 4, 6, + 8, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, + 0, 0, 2, 10, 0, 0, 0, 0, 0, 0, 0, 2, 10, 0, 0, 0, 0, 0, + 4, 10, 0, 0, 0, 0, 0, 0, 0, 4, 10, 0, 0, 0, 0, 0, 2, 4, + 10, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0, 0, 0, 0, 6, 10, 0, 0, + 0, 0, 0, 0, 0, 6, 10, 0, 0, 0, 0, 0, 2, 6, 10, 0, 0, 0, + 0, 0, 0, 2, 6, 10, 0, 0, 0, 0, 4, 6, 10, 0, 0, 0, 0, 0, + 0, 4, 6, 10, 0, 0, 0, 0, 2, 4, 6, 10, 0, 0, 0, 0, 0, 2, + 4, 6, 10, 0, 0, 0, 8, 10, 0, 0, 0, 0, 0, 0, 0, 8, 10, 0, + 0, 0, 0, 0, 2, 8, 10, 0, 0, 0, 0, 0, 0, 2, 8, 10, 0, 0, + 0, 0, 4, 8, 10, 0, 0, 0, 0, 0, 0, 4, 8, 10, 0, 0, 0, 0, + 2, 4, 8, 10, 0, 0, 0, 0, 0, 2, 4, 8, 10, 0, 0, 0, 6, 8, + 10, 0, 0, 0, 0, 0, 0, 6, 8, 10, 0, 0, 0, 0, 2, 6, 8, 10, + 0, 0, 0, 0, 0, 2, 6, 8, 10, 0, 0, 0, 4, 6, 8, 10, 0, 0, + 0, 0, 0, 4, 6, 8, 10, 0, 0, 0, 2, 4, 6, 8, 10, 0, 0, 0, + 0, 2, 4, 6, 8, 10, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 12, + 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 0, 0, 0, 0, 0, 2, 12, 0, + 0, 0, 0, 0, 4, 12, 0, 0, 0, 0, 0, 0, 0, 4, 12, 0, 0, 0, + 0, 0, 2, 4, 12, 0, 0, 0, 0, 0, 0, 2, 4, 12, 0, 0, 0, 0, + 6, 12, 0, 0, 0, 0, 0, 0, 0, 6, 12, 0, 0, 0, 0, 0, 2, 6, + 12, 0, 0, 0, 0, 0, 0, 2, 6, 12, 0, 0, 0, 0, 4, 6, 12, 0, + 0, 0, 0, 0, 0, 4, 6, 12, 0, 0, 0, 0, 2, 4, 6, 12, 0, 0, + 0, 0, 0, 2, 4, 6, 12, 0, 0, 0, 8, 12, 0, 0, 0, 0, 0, 0, + 0, 8, 12, 0, 0, 0, 0, 0, 2, 8, 12, 0, 0, 0, 0, 0, 0, 2, + 8, 12, 0, 0, 0, 0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 4, 8, 12, + 0, 0, 0, 0, 2, 4, 8, 12, 0, 0, 0, 0, 0, 2, 4, 8, 12, 0, + 0, 0, 6, 8, 12, 0, 0, 0, 0, 0, 0, 6, 8, 12, 0, 0, 0, 0, + 2, 6, 8, 12, 0, 0, 0, 0, 0, 2, 6, 8, 12, 0, 0, 0, 4, 6, + 8, 12, 0, 0, 0, 0, 0, 4, 6, 8, 12, 0, 0, 0, 2, 4, 6, 8, + 12, 0, 0, 0, 0, 2, 4, 6, 8, 12, 0, 0, 10, 12, 0, 0, 0, 0, + 0, 0, 0, 10, 12, 0, 0, 0, 0, 0, 2, 10, 12, 0, 0, 0, 0, 0, + 0, 2, 10, 12, 0, 0, 0, 0, 4, 10, 12, 0, 0, 0, 0, 0, 0, 4, + 10, 12, 0, 0, 0, 0, 2, 4, 10, 12, 0, 0, 0, 0, 0, 2, 4, 10, + 12, 0, 0, 0, 6, 10, 12, 0, 0, 0, 0, 0, 0, 6, 10, 12, 0, 0, + 0, 0, 2, 6, 10, 12, 0, 0, 0, 0, 0, 2, 6, 10, 12, 0, 0, 0, + 4, 6, 10, 12, 0, 0, 0, 0, 0, 4, 6, 10, 12, 0, 0, 0, 2, 4, + 6, 10, 12, 0, 0, 0, 0, 2, 4, 6, 10, 12, 0, 0, 8, 10, 12, 0, + 0, 0, 0, 0, 0, 8, 10, 12, 0, 0, 0, 0, 2, 8, 10, 12, 0, 0, + 0, 0, 0, 2, 8, 10, 12, 0, 0, 0, 4, 8, 10, 12, 0, 0, 0, 0, + 0, 4, 8, 10, 12, 0, 0, 0, 2, 4, 8, 10, 12, 0, 0, 0, 0, 2, + 4, 8, 10, 12, 0, 0, 6, 8, 10, 12, 0, 0, 0, 0, 0, 6, 8, 10, + 12, 0, 0, 0, 2, 6, 8, 10, 12, 0, 0, 0, 0, 2, 6, 8, 10, 12, + 0, 0, 4, 6, 8, 10, 12, 0, 0, 0, 0, 4, 6, 8, 10, 12, 0, 0, + 2, 4, 6, 8, 10, 12, 0, 0, 0, 2, 4, 6, 8, 10, 12, 0, 14, 0, + 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 2, 14, 0, 0, + 0, 0, 0, 0, 0, 2, 14, 0, 0, 0, 0, 0, 4, 14, 0, 0, 0, 0, + 0, 0, 0, 4, 14, 0, 0, 0, 0, 0, 2, 4, 14, 0, 0, 0, 0, 0, + 0, 2, 
4, 14, 0, 0, 0, 0, 6, 14, 0, 0, 0, 0, 0, 0, 0, 6, + 14, 0, 0, 0, 0, 0, 2, 6, 14, 0, 0, 0, 0, 0, 0, 2, 6, 14, + 0, 0, 0, 0, 4, 6, 14, 0, 0, 0, 0, 0, 0, 4, 6, 14, 0, 0, + 0, 0, 2, 4, 6, 14, 0, 0, 0, 0, 0, 2, 4, 6, 14, 0, 0, 0, + 8, 14, 0, 0, 0, 0, 0, 0, 0, 8, 14, 0, 0, 0, 0, 0, 2, 8, + 14, 0, 0, 0, 0, 0, 0, 2, 8, 14, 0, 0, 0, 0, 4, 8, 14, 0, + 0, 0, 0, 0, 0, 4, 8, 14, 0, 0, 0, 0, 2, 4, 8, 14, 0, 0, + 0, 0, 0, 2, 4, 8, 14, 0, 0, 0, 6, 8, 14, 0, 0, 0, 0, 0, + 0, 6, 8, 14, 0, 0, 0, 0, 2, 6, 8, 14, 0, 0, 0, 0, 0, 2, + 6, 8, 14, 0, 0, 0, 4, 6, 8, 14, 0, 0, 0, 0, 0, 4, 6, 8, + 14, 0, 0, 0, 2, 4, 6, 8, 14, 0, 0, 0, 0, 2, 4, 6, 8, 14, + 0, 0, 10, 14, 0, 0, 0, 0, 0, 0, 0, 10, 14, 0, 0, 0, 0, 0, + 2, 10, 14, 0, 0, 0, 0, 0, 0, 2, 10, 14, 0, 0, 0, 0, 4, 10, + 14, 0, 0, 0, 0, 0, 0, 4, 10, 14, 0, 0, 0, 0, 2, 4, 10, 14, + 0, 0, 0, 0, 0, 2, 4, 10, 14, 0, 0, 0, 6, 10, 14, 0, 0, 0, + 0, 0, 0, 6, 10, 14, 0, 0, 0, 0, 2, 6, 10, 14, 0, 0, 0, 0, + 0, 2, 6, 10, 14, 0, 0, 0, 4, 6, 10, 14, 0, 0, 0, 0, 0, 4, + 6, 10, 14, 0, 0, 0, 2, 4, 6, 10, 14, 0, 0, 0, 0, 2, 4, 6, + 10, 14, 0, 0, 8, 10, 14, 0, 0, 0, 0, 0, 0, 8, 10, 14, 0, 0, + 0, 0, 2, 8, 10, 14, 0, 0, 0, 0, 0, 2, 8, 10, 14, 0, 0, 0, + 4, 8, 10, 14, 0, 0, 0, 0, 0, 4, 8, 10, 14, 0, 0, 0, 2, 4, + 8, 10, 14, 0, 0, 0, 0, 2, 4, 8, 10, 14, 0, 0, 6, 8, 10, 14, + 0, 0, 0, 0, 0, 6, 8, 10, 14, 0, 0, 0, 2, 6, 8, 10, 14, 0, + 0, 0, 0, 2, 6, 8, 10, 14, 0, 0, 4, 6, 8, 10, 14, 0, 0, 0, + 0, 4, 6, 8, 10, 14, 0, 0, 2, 4, 6, 8, 10, 14, 0, 0, 0, 2, + 4, 6, 8, 10, 14, 0, 12, 14, 0, 0, 0, 0, 0, 0, 0, 12, 14, 0, + 0, 0, 0, 0, 2, 12, 14, 0, 0, 0, 0, 0, 0, 2, 12, 14, 0, 0, + 0, 0, 4, 12, 14, 0, 0, 0, 0, 0, 0, 4, 12, 14, 0, 0, 0, 0, + 2, 4, 12, 14, 0, 0, 0, 0, 0, 2, 4, 12, 14, 0, 0, 0, 6, 12, + 14, 0, 0, 0, 0, 0, 0, 6, 12, 14, 0, 0, 0, 0, 2, 6, 12, 14, + 0, 0, 0, 0, 0, 2, 6, 12, 14, 0, 0, 0, 4, 6, 12, 14, 0, 0, + 0, 0, 0, 4, 6, 12, 14, 0, 0, 0, 2, 4, 6, 12, 14, 0, 0, 0, + 0, 2, 4, 6, 12, 14, 0, 0, 8, 12, 14, 0, 0, 0, 0, 0, 0, 8, + 12, 14, 0, 0, 0, 0, 2, 8, 12, 14, 0, 0, 0, 0, 0, 2, 8, 12, + 14, 0, 0, 0, 4, 8, 12, 14, 0, 0, 0, 0, 0, 4, 8, 12, 14, 0, + 0, 0, 2, 4, 8, 12, 14, 0, 0, 0, 0, 2, 4, 8, 12, 14, 0, 0, + 6, 8, 12, 14, 0, 0, 0, 0, 0, 6, 8, 12, 14, 0, 0, 0, 2, 6, + 8, 12, 14, 0, 0, 0, 0, 2, 6, 8, 12, 14, 0, 0, 4, 6, 8, 12, + 14, 0, 0, 0, 0, 4, 6, 8, 12, 14, 0, 0, 2, 4, 6, 8, 12, 14, + 0, 0, 0, 2, 4, 6, 8, 12, 14, 0, 10, 12, 14, 0, 0, 0, 0, 0, + 0, 10, 12, 14, 0, 0, 0, 0, 2, 10, 12, 14, 0, 0, 0, 0, 0, 2, + 10, 12, 14, 0, 0, 0, 4, 10, 12, 14, 0, 0, 0, 0, 0, 4, 10, 12, + 14, 0, 0, 0, 2, 4, 10, 12, 14, 0, 0, 0, 0, 2, 4, 10, 12, 14, + 0, 0, 6, 10, 12, 14, 0, 0, 0, 0, 0, 6, 10, 12, 14, 0, 0, 0, + 2, 6, 10, 12, 14, 0, 0, 0, 0, 2, 6, 10, 12, 14, 0, 0, 4, 6, + 10, 12, 14, 0, 0, 0, 0, 4, 6, 10, 12, 14, 0, 0, 2, 4, 6, 10, + 12, 14, 0, 0, 0, 2, 4, 6, 10, 12, 14, 0, 8, 10, 12, 14, 0, 0, + 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 2, 8, 10, 12, 14, 0, 0, 0, + 0, 2, 8, 10, 12, 14, 0, 0, 4, 8, 10, 12, 14, 0, 0, 0, 0, 4, + 8, 10, 12, 14, 0, 0, 2, 4, 8, 10, 12, 14, 0, 0, 0, 2, 4, 8, + 10, 12, 14, 0, 6, 8, 10, 12, 14, 0, 0, 0, 0, 6, 8, 10, 12, 14, + 0, 0, 2, 6, 8, 10, 12, 14, 0, 0, 0, 2, 6, 8, 10, 12, 14, 0, + 4, 6, 8, 10, 12, 14, 0, 0, 0, 4, 6, 8, 10, 12, 14, 0, 2, 4, + 6, 8, 10, 12, 14, 0, 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<4> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + 
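+  // Example: mask_bits = 0b0101 selects lanes 0 and 2; the row loaded below
+  // is {0,1,2,3, 8,9,10,11, 0,1,2,3, 0,1,2,3}, i.e. byte indices that move
+  // those two 32-bit lanes to the front.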
+ // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t packed_array[16 * 16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, // + 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, // + 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); +} + +#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64 + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t packed_array[4 * 16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); +} + +#endif + +// Helper function called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. 
+template <typename T, size_t N>
+HWY_API Vec128<T, N> Compress(Vec128<T, N> v, const uint64_t mask_bits) {
+  const auto idx =
+      detail::IdxFromBits<T, N>(hwy::SizeTag<sizeof(T)>(), mask_bits);
+  using D = Simd<T, N>;
+  const RebindToSigned<D> di;
+  return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx)));
+}
+
+}  // namespace detail
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> Compress(Vec128<T, N> v, const Mask128<T, N> mask) {
+  return detail::Compress(v, detail::BitsFromMask(mask));
+}
+
+// ------------------------------ CompressStore
+
+template <typename T, size_t N>
+HWY_API size_t CompressStore(Vec128<T, N> v, const Mask128<T, N> mask,
+                             Simd<T, N> d, T* HWY_RESTRICT aligned) {
+  const uint64_t mask_bits = detail::BitsFromMask(mask);
+  Store(detail::Compress(v, mask_bits), d, aligned);
+  return PopCount(mask_bits);
+}
+
+// ------------------------------ StoreInterleaved3
+
+// 128 bits
+HWY_API void StoreInterleaved3(const Vec128<uint8_t> v0,
+                               const Vec128<uint8_t> v1,
+                               const Vec128<uint8_t> v2,
+                               Full128<uint8_t> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const uint8x16x3_t triple = {v0.raw, v1.raw, v2.raw};
+  vst3q_u8(unaligned, triple);
+}
+
+// 64 bits
+HWY_API void StoreInterleaved3(const Vec128<uint8_t, 8> v0,
+                               const Vec128<uint8_t, 8> v1,
+                               const Vec128<uint8_t, 8> v2,
+                               Simd<uint8_t, 8> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const uint8x8x3_t triple = {v0.raw, v1.raw, v2.raw};
+  vst3_u8(unaligned, triple);
+}
+
+// <= 32 bits: avoid writing more than N bytes by copying to buffer
+template <size_t N, HWY_IF_LE32(uint8_t, N)>
+HWY_API void StoreInterleaved3(const Vec128<uint8_t, N> v0,
+                               const Vec128<uint8_t, N> v1,
+                               const Vec128<uint8_t, N> v2,
+                               Simd<uint8_t, N> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  alignas(16) uint8_t buf[24];
+  const uint8x8x3_t triple = {v0.raw, v1.raw, v2.raw};
+  vst3_u8(buf, triple);
+  CopyBytes<N * 3>(buf, unaligned);
+}
+
+// ------------------------------ StoreInterleaved4
+
+// 128 bits
+HWY_API void StoreInterleaved4(const Vec128<uint8_t> v0,
+                               const Vec128<uint8_t> v1,
+                               const Vec128<uint8_t> v2,
+                               const Vec128<uint8_t> v3,
+                               Full128<uint8_t> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const uint8x16x4_t quad = {v0.raw, v1.raw, v2.raw, v3.raw};
+  vst4q_u8(unaligned, quad);
+}
+
+// 64 bits
+HWY_API void StoreInterleaved4(const Vec128<uint8_t, 8> v0,
+                               const Vec128<uint8_t, 8> v1,
+                               const Vec128<uint8_t, 8> v2,
+                               const Vec128<uint8_t, 8> v3,
+                               Simd<uint8_t, 8> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const uint8x8x4_t quad = {v0.raw, v1.raw, v2.raw, v3.raw};
+  vst4_u8(unaligned, quad);
+}
+
+// <= 32 bits: avoid writing more than N bytes by copying to buffer
+template <size_t N, HWY_IF_LE32(uint8_t, N)>
+HWY_API void StoreInterleaved4(const Vec128<uint8_t, N> v0,
+                               const Vec128<uint8_t, N> v1,
+                               const Vec128<uint8_t, N> v2,
+                               const Vec128<uint8_t, N> v3,
+                               Simd<uint8_t, N> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  alignas(16) uint8_t buf[32];
+  const uint8x8x4_t quad = {v0.raw, v1.raw, v2.raw, v3.raw};
+  vst4_u8(buf, quad);
+  CopyBytes<N * 4>(buf, unaligned);
+}
+
+// ================================================== Operator wrapper
+
+// These wrappers apply to this header as well (not only x86_*-inl.h) because
+// there are no restrictions on V.
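+//
+// A usage sketch (any vector type V with the matching operator works):
+//
+//   const auto sum  = Add(a, b);  // equivalent to a + b
+//   const auto mask = Lt(a, b);   // equivalent to a < b
+//
+// Named functions like these are easier to pass to generic code than raw
+// operators.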
+ +template +HWY_API V Add(V a, V b) { + return a + b; +} +template +HWY_API V Sub(V a, V b) { + return a - b; +} + +template +HWY_API V Mul(V a, V b) { + return a * b; +} +template +HWY_API V Div(V a, V b) { + return a / b; +} + +template +V Shl(V a, V b) { + return a << b; +} +template +V Shr(V a, V b) { + return a >> b; +} + +template +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +#if HWY_ARCH_ARM_V7 +#undef vuzp1_s8 +#undef vuzp1_u8 +#undef vuzp1_s16 +#undef vuzp1_u16 +#undef vuzp1_s32 +#undef vuzp1_u32 +#undef vuzp1_f32 +#undef vuzp1q_s8 +#undef vuzp1q_u8 +#undef vuzp1q_s16 +#undef vuzp1q_u16 +#undef vuzp1q_s32 +#undef vuzp1q_u32 +#undef vuzp1q_f32 +#undef vuzp2_s8 +#undef vuzp2_u8 +#undef vuzp2_s16 +#undef vuzp2_u16 +#undef vuzp2_s32 +#undef vuzp2_u32 +#undef vuzp2_f32 +#undef vuzp2q_s8 +#undef vuzp2q_u8 +#undef vuzp2q_s16 +#undef vuzp2q_u16 +#undef vuzp2q_s32 +#undef vuzp2q_u32 +#undef vuzp2q_f32 +#undef vzip1_s8 +#undef vzip1_u8 +#undef vzip1_s16 +#undef vzip1_u16 +#undef vzip1_s32 +#undef vzip1_u32 +#undef vzip1_f32 +#undef vzip1q_s8 +#undef vzip1q_u8 +#undef vzip1q_s16 +#undef vzip1q_u16 +#undef vzip1q_s32 +#undef vzip1q_u32 +#undef vzip1q_f32 +#undef vzip2_s8 +#undef vzip2_u8 +#undef vzip2_s16 +#undef vzip2_u16 +#undef vzip2_s32 +#undef vzip2_u32 +#undef vzip2_f32 +#undef vzip2q_s8 +#undef vzip2q_u8 +#undef vzip2q_s16 +#undef vzip2q_u16 +#undef vzip2q_s32 +#undef vzip2q_u32 +#undef vzip2q_f32 +#endif + +#undef HWY_NEON_BUILD_ARG_1 +#undef HWY_NEON_BUILD_ARG_2 +#undef HWY_NEON_BUILD_ARG_3 +#undef HWY_NEON_BUILD_PARAM_1 +#undef HWY_NEON_BUILD_PARAM_2 +#undef HWY_NEON_BUILD_PARAM_3 +#undef HWY_NEON_BUILD_RET_1 +#undef HWY_NEON_BUILD_RET_2 +#undef HWY_NEON_BUILD_RET_3 +#undef HWY_NEON_BUILD_TPL_1 +#undef HWY_NEON_BUILD_TPL_2 +#undef HWY_NEON_BUILD_TPL_3 +#undef HWY_NEON_DEF_FUNCTION +#undef HWY_NEON_DEF_FUNCTION_ALL_FLOATS +#undef HWY_NEON_DEF_FUNCTION_ALL_TYPES +#undef HWY_NEON_DEF_FUNCTION_INT_8 +#undef HWY_NEON_DEF_FUNCTION_INT_16 +#undef HWY_NEON_DEF_FUNCTION_INT_32 +#undef HWY_NEON_DEF_FUNCTION_INT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_INTS +#undef HWY_NEON_DEF_FUNCTION_INTS_UINTS +#undef HWY_NEON_DEF_FUNCTION_TPL +#undef HWY_NEON_DEF_FUNCTION_UINT_8 +#undef HWY_NEON_DEF_FUNCTION_UINT_16 +#undef HWY_NEON_DEF_FUNCTION_UINT_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UINTS +#undef HWY_NEON_EVAL + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/rvv-inl.h b/third_party/highway/hwy/ops/rvv-inl.h new file mode 100644 index 000000000000..ea359a44ca83 --- /dev/null +++ b/third_party/highway/hwy/ops/rvv-inl.h @@ -0,0 +1,1762 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RISC-V V vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t::type; + +template +using TFromV = TFromD>; + +#define HWY_IF_UNSIGNED_V(V) hwy::EnableIf>()>* = nullptr +#define HWY_IF_SIGNED_V(V) \ + hwy::EnableIf>() && !IsFloat>()>* = nullptr +#define HWY_IF_FLOAT_V(V) hwy::EnableIf>()>* = nullptr + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// For all mask sizes: (1/Nth of a register, one bit per lane) +#define HWY_RVV_FOREACH_B(X_MACRO, NAME, OP) \ + X_MACRO(64, NAME, OP) \ + X_MACRO(32, NAME, OP) \ + X_MACRO(16, NAME, OP) \ + X_MACRO(8, NAME, OP) \ + X_MACRO(4, NAME, OP) \ + X_MACRO(2, NAME, OP) \ + X_MACRO(1, NAME, OP) + +// For given SEW, iterate over all LMUL. Precompute SEW/LMUL => MLEN because the +// preprocessor cannot easily do it. +#define HWY_RVV_FOREACH_08(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 1, 8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 2, 4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 4, 2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 8, 1, NAME, OP) + +#define HWY_RVV_FOREACH_16(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 1, 16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 2, 8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 4, 4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 8, 2, NAME, OP) + +#define HWY_RVV_FOREACH_32(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 1, 32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 2, 16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 4, 8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 8, 4, NAME, OP) + +#define HWY_RVV_FOREACH_64(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, 1, 64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, 2, 32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, 4, 16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, 8, 8, NAME, OP) + +// SEW for unsigned: +#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_08(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_16(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_32(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_64(X_MACRO, uint, u, NAME, OP) + +// SEW for signed: +#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_08(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_16(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_32(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_64(X_MACRO, int, i, NAME, OP) + +// SEW for float: +#define 
HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_16(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_32(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_64(X_MACRO, float, f, NAME, OP) + +// For all combinations of SEW: +#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP) + +#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP) + +// Commonly used type categories for a given SEW: +#define HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP) + +#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP) + +#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP) + +// Commonly used type categories: +#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP) + +#define HWY_RVV_FOREACH(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_RVV_FOREACH_F(X_MACRO, NAME, OP) + +// Assemble types for use in x-macros +#define HWY_RVV_T(BASE, SEW) BASE##SEW##_t +#define HWY_RVV_D(CHAR, SEW, LMUL) D##CHAR##SEW##m##LMUL +#define HWY_RVV_V(BASE, SEW, LMUL) v##BASE##SEW##m##LMUL##_t +#define HWY_RVV_M(MLEN) vbool##MLEN##_t + +} // namespace detail + +// TODO(janwas): remove typedefs and only use HWY_RVV_V etc. directly + +// TODO(janwas): do we want fractional LMUL? (can encode as negative) +// Mixed-precision code can use LMUL 1..8 and that should be enough unless they +// need many registers. +#define HWY_SPECIALIZE(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + using HWY_RVV_D(CHAR, SEW, LMUL) = \ + Simd; \ + using V##CHAR##SEW##m##LMUL = HWY_RVV_V(BASE, SEW, LMUL); \ + template <> \ + struct DFromV_t { \ + using Lane = HWY_RVV_T(BASE, SEW); \ + using type = Simd; \ + }; +using Vf16m1 = vfloat16m1_t; +using Vf16m2 = vfloat16m2_t; +using Vf16m4 = vfloat16m4_t; +using Vf16m8 = vfloat16m8_t; +using Df16m1 = Simd; +using Df16m2 = Simd; +using Df16m4 = Simd; +using Df16m8 = Simd; + +HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _) +#undef HWY_SPECIALIZE + +// vector = f(d), e.g. Zero +#define HWY_RVV_RETV_ARGD(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(CHAR, SEW, LMUL) d) { \ + (void)Lanes(d); \ + return v##OP##_##CHAR##SEW##m##LMUL(); \ + } + +// vector = f(vector), e.g. Not +#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##m##LMUL(v); \ + } + +// vector = f(vector, scalar), e.g. 
detail::Add +#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##m##LMUL(a, b); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##m##LMUL(a, b); \ + } + +// ================================================== INIT + +// ------------------------------ Lanes + +// WARNING: we want to query VLMAX/sizeof(T), but this actually changes VL! +// vlenb is not exposed through intrinsics and vreadvl is not VLMAX. +#define HWY_RVV_LANES(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API size_t NAME(HWY_RVV_D(CHAR, SEW, LMUL) /* d */) { \ + return v##OP##SEW##m##LMUL(); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e) +#undef HWY_RVV_LANES + +// ------------------------------ Zero + +HWY_RVV_FOREACH(HWY_RVV_RETV_ARGD, Zero, zero) + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set +// vector = f(d, scalar), e.g. Set +#define HWY_RVV_SET(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(CHAR, SEW, LMUL) d, HWY_RVV_T(BASE, SEW) arg) { \ + (void)Lanes(d); \ + return v##OP##_##CHAR##SEW##m##LMUL(arg); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x) +HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f) +#undef HWY_RVV_SET + +// ------------------------------ Undefined + +// RVV vundefined is 'poisoned' such that even XORing a _variable_ initialized +// by it gives unpredictable results. It should only be used for maskoff, so +// keep it internal. For the Highway op, just use Zero (single instruction). 
+namespace detail { +HWY_RVV_FOREACH(HWY_RVV_RETV_ARGD, Undefined, undefined) +} // namespace detail + +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +// ------------------------------ BitCast + +namespace detail { + +// u8: no change +#define HWY_RVV_CAST_NOP(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + BitCastToByte(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v; \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(CHAR, SEW, LMUL) /* d */, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v; \ + } + +// Other integers +#define HWY_RVV_CAST_UI(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API vuint8m##LMUL##_t BitCastToByte(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##m##LMUL##_u8m##LMUL(v); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(CHAR, SEW, LMUL) /* d */, vuint8m##LMUL##_t v) { \ + return v##OP##_v_u8m##LMUL##_##CHAR##SEW##m##LMUL(v); \ + } + +// Float: first cast to/from unsigned +#define HWY_RVV_CAST_F(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API vuint8m##LMUL##_t BitCastToByte(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_u##SEW##m##LMUL##_u8m##LMUL( \ + v##OP##_v_f##SEW##m##LMUL##_u##SEW##m##LMUL(v)); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(CHAR, SEW, LMUL) /* d */, vuint8m##LMUL##_t v) { \ + return v##OP##_v_u##SEW##m##LMUL##_f##SEW##m##LMUL( \ + v##OP##_v_u8m##LMUL##_u##SEW##m##LMUL(v)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_CAST_NOP, _, _) +HWY_RVV_FOREACH_I08(HWY_RVV_CAST_UI, _, reinterpret) +HWY_RVV_FOREACH_UI16(HWY_RVV_CAST_UI, _, reinterpret) +HWY_RVV_FOREACH_UI32(HWY_RVV_CAST_UI, _, reinterpret) +HWY_RVV_FOREACH_UI64(HWY_RVV_CAST_UI, _, reinterpret) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_F, _, reinterpret) + +#undef HWY_RVV_CAST_NOP +#undef HWY_RVV_CAST_UI +#undef HWY_RVV_CAST_F + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +namespace detail { + +template >> +HWY_API VFromD BitCastToUnsigned(V v) { + return BitCast(DU(), v); +} + +} // namespace detail + +// ------------------------------ Iota + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGD, Iota0, id_v) + +template > +HWY_API VFromD Iota0(const D /*d*/) { + Lanes(DU()); + return BitCastToUnsigned(Iota0(DU())); +} + +} // namespace detail + +// ================================================== LOGICAL + +// ------------------------------ Not + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not ) + +template +HWY_API V Not(const V v) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Not(BitCast(DU(), v))); +} + +// ------------------------------ And + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, And, and_vx) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and) + +template +HWY_API V And(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), And(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Or + +// Scalar argument plus mask. Used by VecFromMask. 
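+// Sketch of the trick (see VecFromMask below): with v = maskedoff = Zero(d)
+// and imm = -1,
+//   detail::Or(v0, -1, mask, v0)
+// writes -1 only into lanes selected by the mask (the masked vor_vx leaves
+// the other lanes as maskedoff, i.e. zero), yielding the all-ones/all-zeros
+// lanes a mask vector requires.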
+#define HWY_RVV_OR_MASK(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_T(BASE, SEW) imm, \ + HWY_RVV_M(MLEN) mask, HWY_RVV_V(BASE, SEW, LMUL) maskedoff) { \ + return v##OP##_##CHAR##SEW##m##LMUL##_m(mask, maskedoff, v, imm); \ + } + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_OR_MASK, Or, or_vx) +} // namespace detail + +#undef HWY_RVV_OR_MASK + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or) + +template +HWY_API V Or(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Or(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Xor + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, Xor, xor_vx) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor) + +template +HWY_API V Xor(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Xor(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ AndNot + +template +HWY_API V AndNot(const V not_a, const V b) { + return And(Not(not_a), b); +} + +// ------------------------------ CopySign + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj) + +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + // RVV can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Add + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, Add, add_vx) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, Add, fadd_vf) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd) + +// ------------------------------ Sub +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub) + +// ------------------------------ SaturatedAdd + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd) + +// ------------------------------ SaturatedSub + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub) + +// ------------------------------ AverageRound + +// TODO(janwas): check vxrm rounding mode +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, AverageRound, aaddu) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, AverageRound, aaddu) + +// ------------------------------ ShiftLeft[Same] + +// Intrinsics do not define .vi forms, so use .vx instead. 
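+// Example expansion (u32, LMUL=1): ShiftLeft<3>(v) lowers to
+//   vsll_vx_u32m1(v, 3)
+// and ShiftLeftSame(v, bits) passes the runtime count through the same .vx
+// form, so both are a single vector-scalar shift.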
+#define HWY_RVV_SHIFT(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vx_##CHAR##SEW##m##LMUL(v, kBits); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return v##OP##_vx_##CHAR##SEW##m##LMUL(v, static_cast(bits)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll) + +// ------------------------------ ShiftRight[Same] + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra) + +#undef HWY_RVV_SHIFT + +// ------------------------------ Shl +#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##m##LMUL(v, bits); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll) + +#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##m##LMUL(v, \ + detail::BitCastToUnsigned(bits)); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll) + +// ------------------------------ Shr + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra) + +#undef HWY_RVV_SHIFT_II +#undef HWY_RVV_SHIFT_VV + +// ------------------------------ Min + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin) + +// ------------------------------ Max + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, Max, maxu_vx) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, Max, max_vx) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, Max, fmax_vf) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax) + +// ------------------------------ Mul + +HWY_RVV_FOREACH_UI16(HWY_RVV_RETV_ARGVV, Mul, mul) +HWY_RVV_FOREACH_UI32(HWY_RVV_RETV_ARGVV, Mul, mul) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul) + +// ------------------------------ MulHigh + +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, MulHigh, mulhu) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulHigh, mulh) + +// ------------------------------ Div + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv) + +// ------------------------------ ApproximateReciprocal + +// TODO(janwas): not yet supported in intrinsics +template +HWY_API V ApproximateReciprocal(const V v) { + return Set(DFromV(), 1) / v; +} +// HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frece7) + +// ------------------------------ Sqrt +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt) + +// ------------------------------ ApproximateReciprocalSqrt + +// TODO(janwas): not yet supported in intrinsics +template +HWY_API V ApproximateReciprocalSqrt(const V v) { + return ApproximateReciprocal(Sqrt(v)); +} +// HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrte7) + +// ------------------------------ MulAdd +// Note: op is still named vv, not vvv. 
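+// Argument order in the macro below: MulAdd(mul, x, add) = mul * x + add,
+// and the underlying intrinsic takes the accumulator first, e.g. (assuming
+// f32, LMUL=1):
+//   vfmacc_vv_f32m1(add, mul, x)  // add + mul * x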
+#define HWY_RVV_FMA(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) mul, HWY_RVV_V(BASE, SEW, LMUL) x, \ + HWY_RVV_V(BASE, SEW, LMUL) add) { \ + return v##OP##_vv_##CHAR##SEW##m##LMUL(add, mul, x); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc) + +// ------------------------------ NegMulAdd +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac) + +// ------------------------------ MulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac) + +// ------------------------------ NegMulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc) + +#undef HWY_RVV_FMA + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. The XX in +// vboolXX_t is a power of two divisor for vector bits. SLEN 8 / LMUL 1 = 1/8th +// of all bits; SLEN 8 / LMUL 4 = half of all bits. + +// mask = f(vector, vector) +#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + (void)Lanes(DFromV()); \ + return v##OP##_vv_##CHAR##SEW##m##LMUL##_b##MLEN(a, b); \ + } + +// ------------------------------ Eq +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq) + +// ------------------------------ Ne +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne) + +// ------------------------------ Lt +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt) + +// ------------------------------ Gt + +template +HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) { + return Lt(b, a); +} + +// ------------------------------ Le +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle) + +#undef HWY_RVV_RETM_ARGVV + +// ------------------------------ Ge + +template +HWY_API auto Ge(const V a, const V b) -> decltype(Le(a, b)) { + return Le(b, a); +} + +// ------------------------------ TestBit + +template +HWY_API auto TestBit(const V a, const V bit) -> decltype(Eq(a, bit)) { + return Ne(And(a, bit), Zero(DFromV())); +} + +// ------------------------------ Not + +// mask = f(mask) +#define HWY_RVV_RETM_ARGM(MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ + return vm##OP##_m_b##MLEN(m); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, Not, not ) + +#undef HWY_RVV_RETM_ARGM + +// ------------------------------ And + +// mask = f(mask_a, mask_b) (note arg2,arg1 order!) 
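+// The swap only matters for the non-commutative AndNot: Highway defines
+// AndNot(not_a, b) = ~not_a & b, and (assuming vmandnot computes
+// arg1 & ~arg2) calling it as vmandnot_mm(b, a) yields b & ~a as required.
+// For the commutative And/Or/Xor the order is irrelevant.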
+#define HWY_RVV_RETM_ARGMM(MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) a, HWY_RVV_M(MLEN) b) { \ + return vm##OP##_mm_b##MLEN(b, a); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, And, and) + +// ------------------------------ AndNot +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, AndNot, andnot) + +// ------------------------------ Or +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Or, or) + +// ------------------------------ Xor +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Xor, xor) + +#undef HWY_RVV_RETM_ARGMM + +// ------------------------------ IfThenElse +#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) yes, \ + HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_vvm_##CHAR##SEW##m##LMUL(m, no, yes); \ + } + +HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge) + +#undef HWY_RVV_IF_THEN_ELSE +// ------------------------------ IfThenElseZero + +template +HWY_API V IfThenElseZero(const M mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV())); +} + +// ------------------------------ IfThenZeroElse + +template +HWY_API V IfThenZeroElse(const M mask, const V no) { + return IfThenElse(mask, Zero(DFromV()), no); +} + +// ------------------------------ MaskFromVec + +template +HWY_API auto MaskFromVec(const V v) -> decltype(Eq(v, v)) { + return Ne(v, Zero(DFromV())); +} + +template +using MFromD = decltype(MaskFromVec(Zero(D()))); + +template +HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { + // No need to check lane size/LMUL are the same: if not, casting MFrom to + // MFromD would fail. + return mask; +} + +// ------------------------------ VecFromMask + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + const auto v0 = Zero(d); + return detail::Or(v0, -1, mask, v0); +} + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + return BitCast(d, VecFromMask(RebindToUnsigned(), mask)); +} + +// ------------------------------ ZeroIfNegative + +template +HWY_API V ZeroIfNegative(const V v) { + const auto v0 = Zero(DFromV()); + // We already have a zero constant, so avoid IfThenZeroElse. 
+ return IfThenElse(Lt(v, v0), v0, v); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ AllFalse + +#define HWY_RVV_ALL_FALSE(MLEN, NAME, OP) \ + HWY_API bool AllFalse(const HWY_RVV_M(MLEN) m) { \ + return vfirst_m_b##MLEN(m) < 0; \ + } +HWY_RVV_FOREACH_B(HWY_RVV_ALL_FALSE, _, _) +#undef HWY_RVV_ALL_FALSE + +// ------------------------------ AllTrue + +#define HWY_RVV_ALL_TRUE(MLEN, NAME, OP) \ + HWY_API bool AllTrue(HWY_RVV_M(MLEN) m) { \ + return AllFalse(vmnot_m_b##MLEN(m)); \ + } +HWY_RVV_FOREACH_B(HWY_RVV_ALL_TRUE, _, _) +#undef HWY_RVV_ALL_TRUE + +// ------------------------------ CountTrue + +#define HWY_RVV_COUNT_TRUE(MLEN, NAME, OP) \ + HWY_API size_t CountTrue(HWY_RVV_M(MLEN) m) { return vpopc_m_b##MLEN(m); } +HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) +#undef HWY_RVV_COUNT_TRUE + +// ================================================== MEMORY + +// ------------------------------ Load + +#define HWY_RVV_LOAD(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(CHAR, SEW, LMUL) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + (void)Lanes(d); \ + return v##OP##SEW##_v_##CHAR##SEW##m##LMUL(p); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le) +#undef HWY_RVV_LOAD + +// Partial load +template +HWY_API VFromD> Load(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ LoadU + +// RVV only requires lane alignment, not natural alignment of the entire vector. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ Store + +#define HWY_RVV_RET_ARGVDP(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(CHAR, SEW, LMUL) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + (void)Lanes(d); \ + return v##OP##SEW##_v_##CHAR##SEW##m##LMUL(p, v); \ + } +HWY_RVV_FOREACH(HWY_RVV_RET_ARGVDP, Store, se) +#undef HWY_RVV_RET_ARGVDP + +// ------------------------------ StoreU + +// RVV only requires lane alignment, not natural alignment of the entire vector. 
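+// For example, storing a vector of uint32_t to p + 1 (elements, not bytes)
+// is fine: the address stays 4-byte aligned, so no separate unaligned path
+// is needed and StoreU below simply forwards to Store.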
+template +HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ Stream + +template +HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ ScatterOffset + +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(CHAR, SEW, LMUL) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##m##LMUL( \ + base, detail::BitCastToUnsigned(offset), v); \ + } +HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sx) +#undef HWY_RVV_SCATTER + +// ------------------------------ ScatterIndex + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + const VFromD> index) { + return ScatterOffset(v, d, base, ShiftLeft<2>(index)); +} + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + const VFromD> index) { + return ScatterOffset(v, d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ GatherOffset + +#define HWY_RVV_GATHER(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(CHAR, SEW, LMUL) /* d */, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##m##LMUL( \ + base, detail::BitCastToUnsigned(offset)); \ + } +HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lx) +#undef HWY_RVV_GATHER + +// ------------------------------ GatherIndex + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + return GatherOffset(d, base, ShiftLeft<2>(index)); +} + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + return GatherOffset(d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ StoreInterleaved3 + +#define HWY_RVV_STORE3(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b, \ + HWY_RVV_V(BASE, SEW, LMUL) c, HWY_RVV_D(CHAR, SEW, LMUL) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + const v##BASE##SEW##m##LMUL##x3_t triple = \ + vcreate_##CHAR##SEW##m##LMUL##x3(a, b, c); \ + return v##OP##e8_v_##CHAR##SEW##m##LMUL##x3(unaligned, triple); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_STORE3(uint, u, 8, 1, 8, StoreInterleaved3, sseg3) +HWY_RVV_STORE3(uint, u, 8, 2, 4, StoreInterleaved3, sseg3) + +#undef HWY_RVV_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_RVV_STORE4(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3, \ + HWY_RVV_D(CHAR, SEW, LMUL) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned) { \ + const v##BASE##SEW##m##LMUL##x4_t quad = \ + vcreate_##CHAR##SEW##m##LMUL##x4(v0, v1, v2, v3); \ + return v##OP##e8_v_##CHAR##SEW##m##LMUL##x4(aligned, quad); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. 
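+// Usage sketch (d must be a u8 descriptor with LMUL <= 2 per the note
+// above): interleaving four planes into RGBA-style output:
+//   StoreInterleaved4(r, g, b, a, d, out);  // out = r0 g0 b0 a0 r1 g1 ...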
+HWY_RVV_STORE4(uint, u, 8, 1, 8, StoreInterleaved4, sseg4) +HWY_RVV_STORE4(uint, u, 8, 2, 4, StoreInterleaved4, sseg4) + +#undef HWY_RVV_STORE4 + +// ================================================== CONVERT + +// ------------------------------ PromoteTo U + +HWY_API Vu16m2 PromoteTo(Du16m2 /* d */, Vu8m1 v) { return vzext_vf2_u16m2(v); } +HWY_API Vu16m4 PromoteTo(Du16m4 /* d */, Vu8m2 v) { return vzext_vf2_u16m4(v); } +HWY_API Vu16m8 PromoteTo(Du16m8 /* d */, Vu8m4 v) { return vzext_vf2_u16m8(v); } + +HWY_API Vu32m4 PromoteTo(Du32m4 /* d */, Vu8m1 v) { return vzext_vf4_u32m4(v); } +HWY_API Vu32m8 PromoteTo(Du32m8 /* d */, Vu8m2 v) { return vzext_vf4_u32m8(v); } + +HWY_API Vu32m2 PromoteTo(Du32m2 /* d */, const Vu16m1 v) { + return vzext_vf2_u32m2(v); +} +HWY_API Vu32m4 PromoteTo(Du32m4 /* d */, const Vu16m2 v) { + return vzext_vf2_u32m4(v); +} +HWY_API Vu32m8 PromoteTo(Du32m8 /* d */, const Vu16m4 v) { + return vzext_vf2_u32m8(v); +} + +HWY_API Vu64m2 PromoteTo(Du64m2 /* d */, const Vu32m1 v) { + return vzext_vf2_u64m2(v); +} +HWY_API Vu64m4 PromoteTo(Du64m4 /* d */, const Vu32m2 v) { + return vzext_vf2_u64m4(v); +} +HWY_API Vu64m8 PromoteTo(Du64m8 /* d */, const Vu32m4 v) { + return vzext_vf2_u64m8(v); +} + +template +HWY_API VFromD> PromoteTo(Simd d, + VFromD> v) { + return BitCast(d, PromoteTo(Simd(), v)); +} + +template +HWY_API VFromD> PromoteTo(Simd d, + VFromD> v) { + return BitCast(d, PromoteTo(Simd(), v)); +} + +template +HWY_API VFromD> PromoteTo(Simd d, + VFromD> v) { + return BitCast(d, PromoteTo(Simd(), v)); +} + +// ------------------------------ PromoteTo I + +HWY_API Vi16m2 PromoteTo(Di16m2 /* d */, Vi8m1 v) { return vsext_vf2_i16m2(v); } +HWY_API Vi16m4 PromoteTo(Di16m4 /* d */, Vi8m2 v) { return vsext_vf2_i16m4(v); } +HWY_API Vi16m8 PromoteTo(Di16m8 /* d */, Vi8m4 v) { return vsext_vf2_i16m8(v); } + +HWY_API Vi32m4 PromoteTo(Di32m4 /* d */, Vi8m1 v) { return vsext_vf4_i32m4(v); } +HWY_API Vi32m8 PromoteTo(Di32m8 /* d */, Vi8m2 v) { return vsext_vf4_i32m8(v); } + +HWY_API Vi32m2 PromoteTo(Di32m2 /* d */, const Vi16m1 v) { + return vsext_vf2_i32m2(v); +} +HWY_API Vi32m4 PromoteTo(Di32m4 /* d */, const Vi16m2 v) { + return vsext_vf2_i32m4(v); +} +HWY_API Vi32m8 PromoteTo(Di32m8 /* d */, const Vi16m4 v) { + return vsext_vf2_i32m8(v); +} + +HWY_API Vi64m2 PromoteTo(Di64m2 /* d */, const Vi32m1 v) { + return vsext_vf2_i64m2(v); +} +HWY_API Vi64m4 PromoteTo(Di64m4 /* d */, const Vi32m2 v) { + return vsext_vf2_i64m4(v); +} +HWY_API Vi64m8 PromoteTo(Di64m8 /* d */, const Vi32m4 v) { + return vsext_vf2_i64m8(v); +} + +// ------------------------------ PromoteTo F + +HWY_API Vf32m2 PromoteTo(Df32m2 /* d */, const Vf16m1 v) { + return vfwcvt_f_f_v_f32m2(v); +} +HWY_API Vf32m4 PromoteTo(Df32m4 /* d */, const Vf16m2 v) { + return vfwcvt_f_f_v_f32m4(v); +} +HWY_API Vf32m8 PromoteTo(Df32m8 /* d */, const Vf16m4 v) { + return vfwcvt_f_f_v_f32m8(v); +} + +HWY_API Vf64m2 PromoteTo(Df64m2 /* d */, const Vf32m1 v) { + return vfwcvt_f_f_v_f64m2(v); +} +HWY_API Vf64m4 PromoteTo(Df64m4 /* d */, const Vf32m2 v) { + return vfwcvt_f_f_v_f64m4(v); +} +HWY_API Vf64m8 PromoteTo(Df64m8 /* d */, const Vf32m4 v) { + return vfwcvt_f_f_v_f64m8(v); +} + +HWY_API Vf64m2 PromoteTo(Df64m2 /* d */, const Vi32m1 v) { + return vfwcvt_f_x_v_f64m2(v); +} +HWY_API Vf64m4 PromoteTo(Df64m4 /* d */, const Vi32m2 v) { + return vfwcvt_f_x_v_f64m4(v); +} +HWY_API Vf64m8 PromoteTo(Df64m8 /* d */, const Vi32m4 v) { + return vfwcvt_f_x_v_f64m8(v); +} + +// ------------------------------ DemoteTo U + +// First clamp 
negative numbers to zero to match x86 packus. +HWY_API Vu16m1 DemoteTo(Du16m1 /* d */, const Vi32m2 v) { + return vnclipu_wx_u16m1(detail::BitCastToUnsigned(detail::Max(v, 0)), 0); +} +HWY_API Vu16m2 DemoteTo(Du16m2 /* d */, const Vi32m4 v) { + return vnclipu_wx_u16m2(detail::BitCastToUnsigned(detail::Max(v, 0)), 0); +} +HWY_API Vu16m4 DemoteTo(Du16m4 /* d */, const Vi32m8 v) { + return vnclipu_wx_u16m4(detail::BitCastToUnsigned(detail::Max(v, 0)), 0); +} + +HWY_API Vu8m1 DemoteTo(Du8m1 /* d */, const Vi32m4 v) { + return vnclipu_wx_u8m1(DemoteTo(Du16m2(), v), 0); +} +HWY_API Vu8m2 DemoteTo(Du8m2 /* d */, const Vi32m8 v) { + return vnclipu_wx_u8m2(DemoteTo(Du16m4(), v), 0); +} + +HWY_API Vu8m1 DemoteTo(Du8m1 /* d */, const Vi16m2 v) { + return vnclipu_wx_u8m1(detail::BitCastToUnsigned(detail::Max(v, 0)), 0); +} +HWY_API Vu8m2 DemoteTo(Du8m2 /* d */, const Vi16m4 v) { + return vnclipu_wx_u8m2(detail::BitCastToUnsigned(detail::Max(v, 0)), 0); +} +HWY_API Vu8m4 DemoteTo(Du8m4 /* d */, const Vi16m8 v) { + return vnclipu_wx_u8m4(detail::BitCastToUnsigned(detail::Max(v, 0)), 0); +} + +HWY_API Vu8m1 U8FromU32(const Vu32m4 v) { + return vnclipu_wx_u8m1(vnclipu_wx_u16m2(v, 0), 0); +} +HWY_API Vu8m2 U8FromU32(const Vu32m8 v) { + return vnclipu_wx_u8m2(vnclipu_wx_u16m4(v, 0), 0); +} + +// ------------------------------ DemoteTo I + +HWY_API Vi8m1 DemoteTo(Di8m1 /* d */, const Vi16m2 v) { + return vnclip_wx_i8m1(v, 0); +} +HWY_API Vi8m2 DemoteTo(Di8m2 /* d */, const Vi16m4 v) { + return vnclip_wx_i8m2(v, 0); +} +HWY_API Vi8m4 DemoteTo(Di8m4 /* d */, const Vi16m8 v) { + return vnclip_wx_i8m4(v, 0); +} + +HWY_API Vi16m1 DemoteTo(Di16m1 /* d */, const Vi32m2 v) { + return vnclip_wx_i16m1(v, 0); +} +HWY_API Vi16m2 DemoteTo(Di16m2 /* d */, const Vi32m4 v) { + return vnclip_wx_i16m2(v, 0); +} +HWY_API Vi16m4 DemoteTo(Di16m4 /* d */, const Vi32m8 v) { + return vnclip_wx_i16m4(v, 0); +} + +HWY_API Vi8m1 DemoteTo(Di8m1 d, const Vi32m4 v) { + return DemoteTo(d, DemoteTo(Di16m2(), v)); +} +HWY_API Vi8m2 DemoteTo(Di8m2 d, const Vi32m8 v) { + return DemoteTo(d, DemoteTo(Di16m4(), v)); +} + +// ------------------------------ DemoteTo F + +HWY_API Vf16m1 DemoteTo(Df16m1 /* d */, const Vf32m2 v) { + return vfncvt_rod_f_f_w_f16m1(v); +} +HWY_API Vf16m2 DemoteTo(Df16m2 /* d */, const Vf32m4 v) { + return vfncvt_rod_f_f_w_f16m2(v); +} +HWY_API Vf16m4 DemoteTo(Df16m4 /* d */, const Vf32m8 v) { + return vfncvt_rod_f_f_w_f16m4(v); +} + +HWY_API Vf32m1 DemoteTo(Df32m1 /* d */, const Vf64m2 v) { + return vfncvt_rod_f_f_w_f32m1(v); +} +HWY_API Vf32m2 DemoteTo(Df32m2 /* d */, const Vf64m4 v) { + return vfncvt_rod_f_f_w_f32m2(v); +} +HWY_API Vf32m4 DemoteTo(Df32m4 /* d */, const Vf64m8 v) { + return vfncvt_rod_f_f_w_f32m4(v); +} + +HWY_API Vi32m1 DemoteTo(Di32m1 /* d */, const Vf64m2 v) { + return vfncvt_rtz_x_f_w_i32m1(v); +} +HWY_API Vi32m2 DemoteTo(Di32m2 /* d */, const Vf64m4 v) { + return vfncvt_rtz_x_f_w_i32m2(v); +} +HWY_API Vi32m4 DemoteTo(Di32m4 /* d */, const Vf64m8 v) { + return vfncvt_rtz_x_f_w_i32m4(v); +} + +// ------------------------------ ConvertTo F + +#define HWY_RVV_CONVERT(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(CHAR, SEW, LMUL) /* d */, HWY_RVV_V(int, SEW, LMUL) v) { \ + return vfcvt_f_x_v_f##SEW##m##LMUL(v); \ + } \ + /* Truncates (rounds toward zero). 
*/ \ + HWY_API HWY_RVV_V(int, SEW, LMUL) ConvertTo(HWY_RVV_D(i, SEW, LMUL) /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_rtz_x_f_v_i##SEW##m##LMUL(v); \ + } \ + /* Uses default rounding mode. */ \ + HWY_API HWY_RVV_V(int, SEW, LMUL) NearestInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_x_f_v_i##SEW##m##LMUL(v); \ + } + +// API only requires f32 but we provide f64 for internal use (otherwise, it +// seems difficult to implement Iota without a _mf2 vector half). +HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _) +#undef HWY_RVV_CONVERT + +// ================================================== SWIZZLE + +// ------------------------------ Compress + +#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) mask) { \ + return v##OP##_vm_##CHAR##SEW##m##LMUL(mask, v, v); \ + } + +HWY_RVV_FOREACH_UI16(HWY_RVV_COMPRESS, Compress, compress) +HWY_RVV_FOREACH_UI32(HWY_RVV_COMPRESS, Compress, compress) +HWY_RVV_FOREACH_UI64(HWY_RVV_COMPRESS, Compress, compress) +HWY_RVV_FOREACH_F(HWY_RVV_COMPRESS, Compress, compress) +#undef HWY_RVV_COMPRESS + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT aligned) { + Store(Compress(v, mask), d, aligned); + return CountTrue(mask); +} + +// ------------------------------ TableLookupLanes + +template > +HWY_API VFromD SetTableIndices(D d, const TFromD* idx) { +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) + const size_t N = Lanes(d); + for (size_t i = 0; i < N; ++i) { + HWY_DASSERT(0 <= idx[i] && idx[i] < static_cast>(N)); + } +#endif + return Load(DU(), idx); +} + +// <32bit are not part of Highway API, but used in Broadcast. This limits VLMAX +// to 2048! We could instead use vrgatherei16. +#define HWY_RVV_TABLE(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return v##OP##_vv_##CHAR##SEW##m##LMUL(v, idx); \ + } + +HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather) +#undef HWY_RVV_TABLE + +// ------------------------------ Shuffle01 + +template +HWY_API V Shuffle01(const V v) { + using D = DFromV; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + const auto idx = detail::Xor(detail::Iota0(D()), 1); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Shuffle2301 + +template +HWY_API V Shuffle2301(const V v) { + using D = DFromV; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const auto idx = detail::Xor(detail::Iota0(D()), 1); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Shuffle1032 + +template +HWY_API V Shuffle1032(const V v) { + using D = DFromV; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const auto idx = detail::Xor(detail::Iota0(D()), 2); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Shuffle0123 + +template +HWY_API V Shuffle0123(const V v) { + using D = DFromV; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const auto idx = detail::Xor(detail::Iota0(D()), 3); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Shuffle2103 + +template +HWY_API V Shuffle2103(const V v) { + using D = DFromV; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + // This shuffle is a rotation. 
We can compute subtraction modulo 4 (number of + // lanes per 128-bit block) via bitwise ops. + const auto i = detail::Xor(detail::Iota0(D()), 1); + const auto lsb = detail::And(i, 1); + const auto borrow = Add(lsb, lsb); + const auto idx = Xor(i, borrow); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Shuffle0321 + +template +HWY_API V Shuffle0321(const V v) { + using D = DFromV; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + // This shuffle is a rotation. We can compute subtraction modulo 4 (number of + // lanes per 128-bit block) via bitwise ops. + const auto i = detail::Xor(detail::Iota0(D()), 3); + const auto lsb = detail::And(i, 1); + const auto borrow = Add(lsb, lsb); + const auto idx = Xor(i, borrow); + return TableLookupLanes(v, idx); +} + +// ------------------------------ TableLookupBytes + +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template +constexpr size_t LanesPerBlock(D) { + return 16 / sizeof(TFromD); +} + +template +HWY_API V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return detail::And(iota0, static_cast(~(LanesPerBlock(d) - 1))); +} + +} // namespace detail + +template +HWY_API V TableLookupBytes(const V v, const V idx) { + using D = DFromV; + const Repartition d8; + const auto offsets128 = detail::OffsetsOf128BitBlocks(d8, detail::Iota0(d8)); + const auto idx8 = Add(BitCast(d8, idx), offsets128); + return BitCast(D(), TableLookupLanes(BitCast(d8, v), idx8)); +} + +// ------------------------------ Broadcast + +template +HWY_API V Broadcast(const V v) { + const DFromV d; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); + auto idx = detail::OffsetsOf128BitBlocks(d, detail::Iota0(d)); + if (kLane != 0) { + idx = detail::Add(idx, kLane); + } + return TableLookupLanes(v, idx); +} + +// ------------------------------ GetLane + +#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_s_##CHAR##SEW##m##LMUL##_##CHAR##SEW(v); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x) +HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f) +#undef HWY_RVV_GET_LANE + +// ------------------------------ ShiftLeftLanes + +// vector = f(vector, size_t) +#define HWY_RVV_SLIDE(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, size_t lanes) { \ + return v##OP##_vx_##CHAR##SEW##m##LMUL(v, v, lanes); \ + } + +namespace detail { +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideUp, slideup) +} // namespace detail + +template +HWY_API V ShiftLeftLanes(const V v) { + using D = DFromV; + const RebindToSigned di; + const auto shifted = detail::SlideUp(v, kLanes); + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di); + const auto idx_mod = detail::And(detail::Iota0(di), kLanesPerBlock - 1); + const auto clear = Lt(BitCast(di, idx_mod), Set(di, kLanes)); + return IfThenZeroElse(clear, shifted); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API V ShiftLeftBytes(const V v) { + using D = DFromV; + const Repartition d8; + Lanes(d8); + return BitCast(D(), ShiftLeftLanes(BitCast(d8, v))); +} + +// ------------------------------ ShiftRightLanes + +namespace detail 
{ +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideDown, slidedown) +} // namespace detail + +#undef HWY_RVV_SLIDE + +template +HWY_API V ShiftRightLanes(const V v) { + using D = DFromV; + const RebindToSigned di; + const auto shifted = detail::SlideDown(v, kLanes); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(di); + const auto idx_mod = detail::And(detail::Iota0(di), kLanesPerBlock - 1); + const auto keep = Lt(BitCast(di, idx_mod), Set(di, kLanesPerBlock - kLanes)); + return IfThenElseZero(keep, shifted); +} + +// ------------------------------ ShiftRightBytes + +template +HWY_API V ShiftRightBytes(const V v) { + using D = DFromV; + const Repartition d8; + Lanes(d8); + return BitCast(D(), ShiftRightLanes(BitCast(d8, v))); +} + +// ------------------------------ OddEven + +template +HWY_API V OddEven(const V a, const V b) { + const RebindToUnsigned> du; // Iota0 is unsigned only + const auto is_even = Eq(detail::And(detail::Iota0(du), 1), Zero(du)); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ ConcatUpperLower + +template +HWY_API V ConcatUpperLower(const V hi, const V lo) { + const RebindToSigned> di; + const auto idx_half = Set(di, Lanes(di) / 2); + const auto is_lower_half = Lt(BitCast(di, detail::Iota0(di)), idx_half); + return IfThenElse(is_lower_half, lo, hi); +} + +// ------------------------------ ConcatLowerLower + +template +HWY_API V ConcatLowerLower(const V hi, const V lo) { + // Move lower half into upper + const auto hi_up = detail::SlideUp(hi, Lanes(DFromV()) / 2); + return ConcatUpperLower(hi_up, lo); +} + +// ------------------------------ ConcatUpperUpper + +template +HWY_API V ConcatUpperUpper(const V hi, const V lo) { + // Move upper half into lower + const auto lo_down = detail::SlideDown(lo, Lanes(DFromV()) / 2); + return ConcatUpperLower(hi, lo_down); +} + +// ------------------------------ ConcatLowerUpper + +template +HWY_API V ConcatLowerUpper(const V hi, const V lo) { + // Move half of both inputs to the other half + const auto hi_up = detail::SlideUp(hi, Lanes(DFromV()) / 2); + const auto lo_down = detail::SlideDown(lo, Lanes(DFromV()) / 2); + return ConcatUpperLower(hi_up, lo_down); +} + +// ------------------------------ InterleaveLower + +template +HWY_API V InterleaveLower(const V a, const V b) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + const auto i = detail::Iota0(d); + const auto idx_mod = ShiftRight<1>(detail::And(i, kLanesPerBlock - 1)); + const auto idx = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto is_even = Eq(detail::And(i, 1), Zero(du)); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +// ------------------------------ InterleaveUpper + +template +HWY_API V InterleaveUpper(const V a, const V b) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + const auto i = detail::Iota0(d); + const auto idx_mod = ShiftRight<1>(detail::And(i, kLanesPerBlock - 1)); + const auto idx_lower = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto idx = detail::Add(idx_lower, kLanesPerBlock / 2); + const auto is_even = Eq(detail::And(i, 1), Zero(du)); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +// ------------------------------ ZipLower + +template +HWY_API VFromD>> ZipLower(const V a, const V b) { + RepartitionToWide> dw; + 
return BitCast(dw, InterleaveLower(a, b)); +} + +// ------------------------------ ZipUpper + +template +HWY_API VFromD>> ZipUpper(const V a, const V b) { + RepartitionToWide> dw; + return BitCast(dw, InterleaveUpper(a, b)); +} + +// ------------------------------ Combine + +// TODO(janwas): implement after LMUL ext/trunc +#if 0 + +template +HWY_API V Combine(const V a, const V b) { + using D = DFromV; + // double LMUL of inputs, then SlideUp with Lanes(). +} + +#endif + +// ================================================== REDUCE + +// vector = f(vector, zero_m1) +#define HWY_RVV_REDUCE(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, 1) v0) { \ + vsetvlmax_e##SEW##m##LMUL(); \ + return Set(HWY_RVV_D(CHAR, SEW, LMUL)(), \ + GetLane(v##OP##_vs_##CHAR##SEW##m##LMUL##_##CHAR##SEW##m1( \ + v0, v, v0))); \ + } + +// ------------------------------ SumOfLanes + +namespace detail { + +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredsum) + +} // namespace detail + +template +HWY_API V SumOfLanes(const V v) { + using T = TFromV; + const auto v0 = Zero(Simd()); // always m1 + return detail::RedSum(v, v0); +} + +// ------------------------------ MinOfLanes +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin) + +} // namespace detail + +template +HWY_API V MinOfLanes(const V v) { + using T = TFromV; + const Simd d1; // always m1 + const auto neutral = Set(d1, HighestValue()); + return detail::RedMin(v, neutral); +} + +// ------------------------------ MaxOfLanes +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax) + +} // namespace detail + +template +HWY_API V MaxOfLanes(const V v) { + using T = TFromV; + const Simd d1; // always m1 + const auto neutral = Set(d1, LowestValue()); + return detail::RedMax(v, neutral); +} + +#undef HWY_RVV_REDUCE + +// ================================================== Ops with dependencies + +// ------------------------------ LoadDup128 + +template +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + // TODO(janwas): set VL + const auto loaded = Load(d, p); + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + // Broadcast the first block + const auto idx = detail::And(detail::Iota0(d), kLanesPerBlock - 1); + return TableLookupLanes(loaded, idx); +} + +// ------------------------------ StoreMaskBits +#define HWY_RVV_STORE_MASK_BITS(MLEN, NAME, OP) \ + HWY_API size_t StoreMaskBits(HWY_RVV_M(MLEN) m, uint8_t* p) { \ + /* LMUL=1 is always enough */ \ + Simd d8; \ + const size_t num_bytes = (Lanes(d8) + MLEN - 1) / MLEN; \ + /* TODO(janwas): how to convert vbool* to vuint?*/ \ + /*Store(m, d8, p);*/ \ + (void)m; \ + (void)p; \ + return num_bytes; \ + } +HWY_RVV_FOREACH_B(HWY_RVV_STORE_MASK_BITS, _, _) +#undef HWY_RVV_STORE_MASK_BITS + +// ------------------------------ Neg + +template +HWY_API V Neg(const V v) { + return Sub(Zero(DFromV()), v); +} + +// vector = f(vector), but argument is repeated +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, LMUL, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vv_##CHAR##SEW##m##LMUL(v, v); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn) + +// 
------------------------------ Abs + +template +HWY_API V Abs(const V v) { + return Max(v, Neg(v)); +} + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx) + +#undef HWY_RVV_RETV_ARGV2 + +// ------------------------------ AbsDiff + +template +HWY_API V AbsDiff(const V a, const V b) { + return Abs(Sub(a, b)); +} + +// ------------------------------ Round + +// IEEE-754 roundToIntegralTiesToEven returns floating-point, but we do not have +// a dedicated instruction for that. Rounding to integer and converting back to +// float is correct except when the input magnitude is large, in which case the +// input was already an integer (because mantissa >> exponent is zero). + +namespace detail { +enum RoundingModes { kNear, kTrunc, kDown, kUp }; + +template +HWY_API auto UseInt(const V v) -> decltype(MaskFromVec(v)) { + return Lt(Abs(v), Set(DFromV(), MantissaEnd>())); +} + +} // namespace detail + +template +HWY_API V Round(const V v) { + const DFromV df; + + const auto integer = NearestInt(v); // round using current mode + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Trunc + +template +HWY_API V Trunc(const V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Ceil + +template +HWY_API V Ceil(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kUp)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Floor + +template +HWY_API V Floor(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kDown)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Iota + +template +HWY_API VFromD Iota(const D d, TFromD first) { + return Add(detail::Iota0(d), Set(d, first)); +} + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToUnsigned du; + return Add(BitCast(d, detail::Iota0(du)), Set(d, first)); +} + +template +HWY_API VFromD Iota(const D d, TFromD first) { + const RebindToUnsigned du; + const RebindToSigned di; + return detail::Add(ConvertTo(d, BitCast(di, detail::Iota0(du))), first); +} + +// ------------------------------ MulEven + +// Using vwmul does not work for m8, so use mulh instead. Highway only provides +// MulHigh for 16-bit, so use a private wrapper. 
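+// Illustrative sketch (not upstream code): the MulEven construction below
+// builds the full 64-bit product of even-indexed 32-bit lanes by pairing the
+// low half (Mul) with the high half (MulHigh), then interleaving. The same
+// idea in plain C++ for one pair of lanes:
+#if 0
+#include <cstdint>
+uint64_t MulEvenScalar(uint32_t a, uint32_t b) {
+  const uint32_t lo = a * b;  // low 32 bits, as computed by Mul
+  const uint32_t hi = static_cast<uint32_t>(
+      (static_cast<uint64_t>(a) * b) >> 32);  // high 32 bits, as MulHigh
+  // OddEven(SlideUp(hi, 1), lo) places hi in the odd lane adjacent to lo,
+  // which is exactly the little-endian lane layout of the 64-bit product:
+  return (static_cast<uint64_t>(hi) << 32) | lo;
+}
+#endif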
+namespace detail { + +HWY_RVV_FOREACH_U32(HWY_RVV_RETV_ARGVV, MulHigh, mulhu) +HWY_RVV_FOREACH_I32(HWY_RVV_RETV_ARGVV, MulHigh, mulh) + +} // namespace detail + +template +HWY_API VFromD>> MulEven(const V a, const V b) { + const DFromV d; + Lanes(d); + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + const RepartitionToWide> dw; + return BitCast(dw, OddEven(detail::SlideUp(hi, 1), lo)); +} + +// ================================================== END MACROS +namespace detail { // for code folding +#undef HWY_IF_FLOAT_V +#undef HWY_IF_SIGNED_V +#undef HWY_IF_UNSIGNED_V + +#undef HWY_RVV_FOREACH +#undef HWY_RVV_FOREACH_08 +#undef HWY_RVV_FOREACH_16 +#undef HWY_RVV_FOREACH_32 +#undef HWY_RVV_FOREACH_64 +#undef HWY_RVV_FOREACH_B +#undef HWY_RVV_FOREACH_F +#undef HWY_RVV_FOREACH_F32 +#undef HWY_RVV_FOREACH_F64 +#undef HWY_RVV_FOREACH_I +#undef HWY_RVV_FOREACH_I08 +#undef HWY_RVV_FOREACH_I16 +#undef HWY_RVV_FOREACH_I32 +#undef HWY_RVV_FOREACH_I64 +#undef HWY_RVV_FOREACH_U +#undef HWY_RVV_FOREACH_U08 +#undef HWY_RVV_FOREACH_U16 +#undef HWY_RVV_FOREACH_U32 +#undef HWY_RVV_FOREACH_U64 +#undef HWY_RVV_FOREACH_UI +#undef HWY_RVV_FOREACH_UI16 +#undef HWY_RVV_FOREACH_UI32 +#undef HWY_RVV_FOREACH_UI64 + +#undef HWY_RVV_RETV_ARGD +#undef HWY_RVV_RETV_ARGV +#undef HWY_RVV_RETV_ARGVS +#undef HWY_RVV_RETV_ARGVV + +#undef HWY_RVV_T +#undef HWY_RVV_D +#undef HWY_RVV_V +#undef HWY_RVV_M + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/scalar-inl.h b/third_party/highway/hwy/ops/scalar-inl.h new file mode 100644 index 000000000000..3acb7ff1a823 --- /dev/null +++ b/third_party/highway/hwy/ops/scalar-inl.h @@ -0,0 +1,1191 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include +#include + +#include // std::min + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Single instruction, single data. +template +using Sisd = Simd; + +// (Wrapper class required for overloading comparison operators.) 
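+// Illustrative sketch (not upstream code): because the scalar target
+// implements the whole Highway API with a single lane, generic code written
+// against descriptors runs here unchanged. A minimal caller, assuming the
+// usual Load/Store/Lanes ops defined in this file:
+#if 0
+template <typename T>
+void AddArrays(const T* HWY_RESTRICT a, const T* HWY_RESTRICT b,
+               T* HWY_RESTRICT out, size_t count) {
+  const Sisd<T> d;  // one lane on this target; Lanes(d) == 1
+  for (size_t i = 0; i < count; i += Lanes(d)) {
+    const auto va = Load(d, a + i);
+    const auto vb = Load(d, b + i);
+    Store(va + vb, d, out + i);
+  }
+}
+#endif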
+template +struct Vec1 { + HWY_INLINE Vec1() = default; + Vec1(const Vec1&) = default; + Vec1& operator=(const Vec1&) = default; + HWY_INLINE explicit Vec1(const T t) : raw(t) {} + + HWY_INLINE Vec1& operator*=(const Vec1 other) { + return *this = (*this * other); + } + HWY_INLINE Vec1& operator/=(const Vec1 other) { + return *this = (*this / other); + } + HWY_INLINE Vec1& operator+=(const Vec1 other) { + return *this = (*this + other); + } + HWY_INLINE Vec1& operator-=(const Vec1 other) { + return *this = (*this - other); + } + HWY_INLINE Vec1& operator&=(const Vec1 other) { + return *this = (*this & other); + } + HWY_INLINE Vec1& operator|=(const Vec1 other) { + return *this = (*this | other); + } + HWY_INLINE Vec1& operator^=(const Vec1 other) { + return *this = (*this ^ other); + } + + T raw; +}; + +// 0 or FF..FF, same size as Vec1. +template +class Mask1 { + using Raw = hwy::MakeUnsigned; + + public: + static HWY_INLINE Mask1 FromBool(bool b) { + Mask1 mask; + mask.bits = b ? ~Raw(0) : 0; + return mask; + } + + Raw bits; +}; + +// ------------------------------ BitCast + +template +HWY_INLINE Vec1 BitCast(Sisd /* tag */, Vec1 v) { + static_assert(sizeof(T) <= sizeof(FromT), "Promoting is undefined"); + T to; + CopyBytes(&v.raw, &to); + return Vec1(to); +} + +// ------------------------------ Set + +template +HWY_INLINE Vec1 Zero(Sisd /* tag */) { + return Vec1(T(0)); +} + +template +HWY_INLINE Vec1 Set(Sisd /* tag */, const T2 t) { + return Vec1(static_cast(t)); +} + +template +HWY_INLINE Vec1 Undefined(Sisd d) { + return Zero(d); +} + +template +Vec1 Iota(const Sisd /* tag */, const T2 first) { + return Vec1(static_cast(first)); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_INLINE Vec1 Not(const Vec1 v) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(~BitCast(du, v).raw)); +} + +// ------------------------------ And + +template +HWY_INLINE Vec1 And(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw & BitCast(du, b).raw)); +} +template +HWY_INLINE Vec1 operator&(const Vec1 a, const Vec1 b) { + return And(a, b); +} + +// ------------------------------ AndNot + +template +HWY_INLINE Vec1 AndNot(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(~BitCast(du, a).raw & BitCast(du, b).raw)); +} + +// ------------------------------ Or + +template +HWY_INLINE Vec1 Or(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw | BitCast(du, b).raw)); +} +template +HWY_INLINE Vec1 operator|(const Vec1 a, const Vec1 b) { + return Or(a, b); +} + +// ------------------------------ Xor + +template +HWY_INLINE Vec1 Xor(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw ^ BitCast(du, b).raw)); +} +template +HWY_INLINE Vec1 operator^(const Vec1 a, const Vec1 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec1 CopySign(const Vec1 magn, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Sisd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec1 CopySignToAbs(const Vec1 abs, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Sisd()), sign)); +} 
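+
+// Worked example (illustrative, not upstream code): CopySign above is the
+// standard sign-transfer bit trick. For float, the sign bit is 0x80000000;
+// clearing it in |magn| and OR-ing in the sign of |sign| yields, e.g.,
+// CopySign(2.0f, -0.0f) == -2.0f without branches:
+#if 0
+#include <cstdint>
+#include <cstring>
+float CopySignScalar(float magn, float sign) {
+  uint32_t m, s;
+  std::memcpy(&m, &magn, 4);
+  std::memcpy(&s, &sign, 4);
+  const uint32_t kSignBit = 0x80000000u;
+  const uint32_t bits = (m & ~kSignBit) | (s & kSignBit);
+  float out;
+  std::memcpy(&out, &bits, 4);
+  return out;
+}
+#endif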
+ +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec1 BroadcastSignBit(const Vec1 v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + return v.raw < 0 ? Vec1(T(-1)) : Vec1(0); +} + +// ------------------------------ Mask + +template +HWY_API Mask1 RebindMask(Sisd /*tag*/, Mask1 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask1{m.bits}; +} + +// v must be 0 or FF..FF. +template +HWY_INLINE Mask1 MaskFromVec(const Vec1 v) { + Mask1 mask; + CopyBytes(&v.raw, &mask.bits); + return mask; +} + +template +Vec1 VecFromMask(const Mask1 mask) { + Vec1 v; + CopyBytes(&mask.bits, &v.raw); + return v; +} + +template +Vec1 VecFromMask(Sisd /* tag */, const Mask1 mask) { + Vec1 v; + CopyBytes(&mask.bits, &v.raw); + return v; +} + +// Returns mask ? yes : no. +template +HWY_INLINE Vec1 IfThenElse(const Mask1 mask, const Vec1 yes, + const Vec1 no) { + return mask.bits ? yes : no; +} + +template +HWY_INLINE Vec1 IfThenElseZero(const Mask1 mask, const Vec1 yes) { + return mask.bits ? yes : Vec1(0); +} + +template +HWY_INLINE Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { + return mask.bits ? Vec1(0) : no; +} + +template +HWY_INLINE Vec1 ZeroIfNegative(const Vec1 v) { + return v.raw < 0 ? Vec1(0) : v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask1 Not(const Mask1 m) { + const Sisd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask1 And(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 AndNot(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Or(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Xor(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft (BroadcastSignBit) + +template +HWY_INLINE Vec1 ShiftLeft(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1(static_cast>(v.raw) << kBits); +} + +template +HWY_INLINE Vec1 ShiftRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1(v.raw >> kBits); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = BitCast(du, v).raw >> kBits; + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const TU upper = sign << (sizeof(TU) * 8 - 1 - kBits); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { + return Vec1(v.raw >> kBits); // unsigned, logical shift + } +#endif +} + +// ------------------------------ ShiftLeftSame (BroadcastSignBit) + +template +HWY_INLINE Vec1 ShiftLeftSame(const Vec1 v, int bits) { + return Vec1(static_cast>(v.raw) << bits); +} + +template +HWY_INLINE Vec1 ShiftRightSame(const Vec1 v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. 
shifting in the sign bit). + return Vec1(v.raw >> bits); +#else + if (IsSigned()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned; + const Sisd du; + const TU shifted = BitCast(du, v).raw >> bits; + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const TU upper = sign << (sizeof(TU) * 8 - 1 - bits); + return BitCast(Sisd(), Vec1(shifted | upper)); + } else { + return Vec1(v.raw >> bits); // unsigned, logical shift + } +#endif +} + +// ------------------------------ Shl + +// Single-lane => same as ShiftLeftSame except for the argument type. +template +HWY_INLINE Vec1 operator<<(const Vec1 v, const Vec1 bits) { + return ShiftLeftSame(v, static_cast(bits.raw)); +} + +template +HWY_INLINE Vec1 operator>>(const Vec1 v, const Vec1 bits) { + return ShiftRightSame(v, static_cast(bits.raw)); +} + +// ================================================== ARITHMETIC + +template +HWY_INLINE Vec1 operator+(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 + b64) & static_cast(~T(0)))); +} +HWY_INLINE Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} +HWY_INLINE Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} + +template +HWY_INLINE Vec1 operator-(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 - b64) & static_cast(~T(0)))); +} +HWY_INLINE Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} +HWY_INLINE Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} + +// ------------------------------ Saturating addition + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_INLINE Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255))); +} +HWY_INLINE Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 65535))); +} + +// Signed +HWY_INLINE Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127))); +} +HWY_INLINE Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw + b.raw), 32767))); +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. 
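+// Illustrative sketch (not upstream code): the clamp-to-range pattern used
+// below, shown for uint8_t. The arithmetic happens in a wider type (int, via
+// integer promotion) and is then clamped, so 3 - 5 saturates to 0 instead of
+// wrapping to 254:
+#if 0
+#include <cstdint>
+uint8_t SaturatedSubU8(uint8_t a, uint8_t b) {
+  const int wide = static_cast<int>(a) - static_cast<int>(b);
+  return static_cast<uint8_t>(HWY_MIN(HWY_MAX(0, wide), 255));
+}
+#endif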
+ +// Unsigned +HWY_INLINE Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255))); +} +HWY_INLINE Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 65535))); +} + +// Signed +HWY_INLINE Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127))); +} +HWY_INLINE Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-32768, a.raw - b.raw), 32767))); +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +HWY_INLINE Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} +HWY_INLINE Vec1 AverageRound(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((a.raw + b.raw + 1) / 2)); +} + +// ------------------------------ Absolute value + +template +HWY_INLINE Vec1 Abs(const Vec1 a) { + const T i = a.raw; + return (i >= 0 || i == hwy::LimitsMin()) ? a : Vec1(-i); +} +HWY_INLINE Vec1 Abs(const Vec1 a) { + return Vec1(std::abs(a.raw)); +} +HWY_INLINE Vec1 Abs(const Vec1 a) { + return Vec1(std::abs(a.raw)); +} + +// ------------------------------ min/max + +template +HWY_INLINE Vec1 Min(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_INLINE Vec1 Min(const Vec1 a, const Vec1 b) { + if (std::isnan(a.raw)) return b; + if (std::isnan(b.raw)) return a; + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_INLINE Vec1 Max(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +template +HWY_INLINE Vec1 Max(const Vec1 a, const Vec1 b) { + if (std::isnan(a.raw)) return b; + if (std::isnan(b.raw)) return a; + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +// ------------------------------ Floating-point negate + +template +HWY_INLINE Vec1 Neg(const Vec1 v) { + return Xor(v, SignBit(Sisd())); +} + +template +HWY_INLINE Vec1 Neg(const Vec1 v) { + return Zero(Sisd()) - v; +} + +// ------------------------------ mul/div + +template +HWY_INLINE Vec1 operator*(const Vec1 a, const Vec1 b) { + if (hwy::IsFloat()) { + return Vec1(static_cast(double(a.raw) * b.raw)); + } else if (hwy::IsSigned()) { + return Vec1(static_cast(int64_t(a.raw) * b.raw)); + } else { + return Vec1(static_cast(uint64_t(a.raw) * b.raw)); + } +} + +template +HWY_INLINE Vec1 operator/(const Vec1 a, const Vec1 b) { + return Vec1(a.raw / b.raw); +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_INLINE Vec1 MulHigh(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((a.raw * b.raw) >> 16)); +} +HWY_INLINE Vec1 MulHigh(const Vec1 a, + const Vec1 b) { + // Cast to uint32_t first to prevent overflow. Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the result + // is the same but this way it is also defined. + return Vec1(static_cast( + (static_cast(a.raw) * static_cast(b.raw)) >> 16)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +HWY_INLINE Vec1 MulEven(const Vec1 a, const Vec1 b) { + const int64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} +HWY_INLINE Vec1 MulEven(const Vec1 a, + const Vec1 b) { + const uint64_t a64 = a.raw; + return Vec1(a64 * b.raw); +} + +// Approximate reciprocal +HWY_INLINE Vec1 ApproximateReciprocal(const Vec1 v) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). 
This check + // avoids a ubsan error. The return value is arbitrary. + if (v.raw == 0.0f) return Vec1(0.0f); + return Vec1(1.0f / v.raw); +} + +// Absolute value of difference. +HWY_INLINE Vec1 AbsDiff(const Vec1 a, const Vec1 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_INLINE Vec1 MulAdd(const Vec1 mul, const Vec1 x, + const Vec1 add) { + return mul * x + add; +} + +template +HWY_INLINE Vec1 NegMulAdd(const Vec1 mul, const Vec1 x, + const Vec1 add) { + return add - mul * x; +} + +template +HWY_INLINE Vec1 MulSub(const Vec1 mul, const Vec1 x, + const Vec1 sub) { + return mul * x - sub; +} + +template +HWY_INLINE Vec1 NegMulSub(const Vec1 mul, const Vec1 x, + const Vec1 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_INLINE Vec1 ApproximateReciprocalSqrt(const Vec1 v) { + float f = v.raw; + const float half = f * 0.5f; + uint32_t bits; + CopyBytes<4>(&f, &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopyBytes<4>(&bits, &f); + // One Newton-Raphson iteration + return Vec1(f * (1.5f - (half * f * f))); +} + +// Square root +HWY_INLINE Vec1 Sqrt(const Vec1 v) { + return Vec1(std::sqrt(v.raw)); +} +HWY_INLINE Vec1 Sqrt(const Vec1 v) { + return Vec1(std::sqrt(v.raw)); +} + +// ------------------------------ Floating-point rounding + +template +HWY_INLINE Vec1 Round(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw < MantissaEnd())) { // Huge or NaN + return v; + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1(0), v); + // Round to even + if ((rounded & 1) && std::abs(rounded - v.raw) == T(0.5)) { + return Vec1(static_cast(rounded - (v.raw < T(0) ? -1 : 1))); + } + return Vec1(static_cast(rounded)); +} + +// Round-to-nearest even. +HWY_INLINE Vec1 NearestInt(const Vec1 v) { + using T = float; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool signbit = std::signbit(v.raw); + + if (!(abs < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs <= static_cast(LimitsMax()))) { + return Vec1(signbit ? LimitsMin() : LimitsMax()); + } + return Vec1(static_cast(v.raw)); + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast(v.raw + bias); + if (rounded == 0) return Vec1(0); + // Round to even + if ((rounded & 1) && std::abs(rounded - v.raw) == T(0.5)) { + return Vec1(rounded - (signbit ? -1 : 1)); + } + return Vec1(rounded); +} + +template +HWY_INLINE Vec1 Trunc(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw <= MantissaEnd())) { // Huge or NaN + return v; + } + const TI truncated = static_cast(v.raw); + if (truncated == 0) return CopySignToAbs(Vec1(0), v); + return Vec1(static_cast(truncated)); +} + +template +V Ceiling(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool positive = f > Float(0.0); + + Bits bits; + CopyBytes(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => 0 or 1. + if (exponent < 0) return positive ? 
V(1) : V(-0.0); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopyBytes(&bits, &f); + return V(f); +} + +template +V Floor(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool negative = f < Float(0.0); + + Bits bits; + CopyBytes(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => -1 or 0. + if (exponent < 0) return V(negative ? Float(-1.0) : Float(0.0)); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopyBytes(&bits, &f); + return V(f); +} + +// Toward +infinity, aka ceiling +HWY_INLINE Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} +HWY_INLINE Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} + +// Toward -infinity, aka floor +HWY_INLINE Vec1 Floor(const Vec1 v) { + return Floor(v); +} +HWY_INLINE Vec1 Floor(const Vec1 v) { + return Floor(v); +} + +// ================================================== COMPARE + +template +HWY_INLINE Mask1 operator==(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw == b.raw); +} + +template +HWY_INLINE Mask1 TestBit(const Vec1 v, const Vec1 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_INLINE Mask1 operator<(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw < b.raw); +} +template +HWY_INLINE Mask1 operator>(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw > b.raw); +} + +template +HWY_INLINE Mask1 operator<=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw <= b.raw); +} +template +HWY_INLINE Mask1 operator>=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw >= b.raw); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_INLINE Vec1 Load(Sisd /* tag */, const T* HWY_RESTRICT aligned) { + T t; + CopyBytes(aligned, &t); + return Vec1(t); +} + +template +HWY_INLINE Vec1 LoadU(Sisd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. 
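+// Before the single-lane memory wrappers below, a worked example for the
+// Floor/Ceiling bit manipulation above (illustrative, not upstream code).
+// For f = 2.5f the biased exponent is 128, so exponent = 1 and the
+// fractional bits are nonzero; clearing them yields 2.0f, and the negative
+// case biases the mantissa first so truncation rounds toward -infinity:
+#if 0
+#include <cstdint>
+#include <cstring>
+float FloorBits(float f) {
+  uint32_t bits;
+  std::memcpy(&bits, &f, 4);
+  const int exponent = static_cast<int>(((bits >> 23) & 0xFF) - 127);
+  if (exponent >= 23) return f;                      // already an integer
+  if (exponent < 0) return f < 0.0f ? -1.0f : 0.0f;  // |f| < 1
+  const uint32_t mantissa_mask = 0x7FFFFFu >> exponent;
+  if ((bits & mantissa_mask) == 0) return f;         // already an integer
+  // For negatives, bias before truncating so the result rounds toward -inf.
+  if (f < 0.0f) bits += 0x800000u >> exponent;
+  bits &= ~mantissa_mask;
+  std::memcpy(&f, &bits, 4);
+  return f;
+}
+#endif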
+template +HWY_INLINE Vec1 LoadDup128(Sisd d, const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template +HWY_INLINE void Store(const Vec1 v, Sisd /* tag */, + T* HWY_RESTRICT aligned) { + CopyBytes(&v.raw, aligned); +} + +template +HWY_INLINE void StoreU(const Vec1 v, Sisd d, T* HWY_RESTRICT p) { + return Store(v, d, p); +} + +// ------------------------------ StoreInterleaved3 + +HWY_API void StoreInterleaved3(const Vec1 v0, const Vec1 v1, + const Vec1 v2, Sisd d, + uint8_t* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +HWY_API void StoreInterleaved4(const Vec1 v0, const Vec1 v1, + const Vec1 v2, const Vec1 v3, + Sisd d, + uint8_t* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); + StoreU(v3, d, unaligned + 3); +} + +// ------------------------------ Stream + +template +HWY_INLINE void Stream(const Vec1 v, Sisd d, T* HWY_RESTRICT aligned) { + return Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template +HWY_INLINE void ScatterOffset(Vec1 v, Sisd d, T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + uint8_t* const base8 = reinterpret_cast(base) + offset.raw; + return Store(v, d, reinterpret_cast(base8)); +} + +template +HWY_INLINE void ScatterIndex(Vec1 v, Sisd d, T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Store(v, d, base + index.raw); +} + +// ------------------------------ Gather + +template +HWY_INLINE Vec1 GatherOffset(Sisd d, const T* base, + const Vec1 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + const uintptr_t addr = reinterpret_cast(base) + offset.raw; + return Load(d, reinterpret_cast(addr)); +} + +template +HWY_INLINE Vec1 GatherIndex(Sisd d, const T* HWY_RESTRICT base, + const Vec1 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Load(d, base + index.raw); +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +template +HWY_INLINE Vec1 PromoteTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + // For bits Y > X, floatX->floatY and intX->intY are always representable. + return Vec1(static_cast(from.raw)); +} + +template +HWY_INLINE Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + // Prevent ubsan errors when converting float to narrower integer/float + if (std::isinf(from.raw) || + std::fabs(from.raw) > static_cast(HighestValue())) { + return Vec1(std::signbit(from.raw) ? 
LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_INLINE Vec1 DemoteTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw = std::min(std::max(LimitsMin(), from.raw), + LimitsMax()); + return Vec1(static_cast(from.raw)); +} + +static HWY_INLINE Vec1 PromoteTo(Sisd /* tag */, + const Vec1 v) { + uint16_t bits16; + CopyBytes<2>(&v.raw, &bits16); + const uint32_t sign = bits16 >> 15; + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + return Vec1(sign ? -subnormal : subnormal); + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + float out; + CopyBytes<4>(&bits32, &out); + return Vec1(out); +} + +static HWY_INLINE Vec1 DemoteTo(Sisd /* tag */, + const Vec1 v) { + uint32_t bits32; + CopyBytes<4>(&v.raw, &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast(biased_exp32) - 127, 15); + + // Tiny or zero => zero. + Vec1 out; + if (exp < -24) { + bits32 = 0; + CopyBytes<2>(&bits32, &out); + return out; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = (1 << (10 - sub_exp)) + (mantissa32 >> (13 + sub_exp)); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + CopyBytes<2>(&bits16, &out); + return out; +} + +template +HWY_INLINE Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax in FromT, so use double. + const double f = static_cast(from.raw); + if (std::isinf(from.raw) || + std::fabs(f) > static_cast(LimitsMax())) { + return Vec1(std::signbit(from.raw) ? LimitsMin() + : LimitsMax()); + } + return Vec1(static_cast(from.raw)); +} + +template +HWY_INLINE Vec1 ConvertTo(Sisd /* tag */, Vec1 from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // int## -> float##: no check needed + return Vec1(static_cast(from.raw)); +} + +HWY_INLINE Vec1 U8FromU32(const Vec1 v) { + return DemoteTo(Sisd(), v); +} + +// ================================================== SWIZZLE + +// Unsupported: Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle*, +// UpperHalf - these require more than one lane and/or actual 128-bit vectors. 
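+// Worked example for the f32 -> f16 DemoteTo above (illustrative, not
+// upstream code): 1.5f has bits32 = 0x3FC00000, so sign = 0, exp = 0 and
+// mantissa32 = 0x400000. exp lies in [-14, 15], hence biased_exp16 = 15 and
+// mantissa16 = 0x400000 >> 13 = 0x200, giving bits16 = 0x3E00, which is the
+// binary16 encoding of 1.5:
+#if 0
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+void DemoteToF16Example() {
+  const float f = 1.5f;
+  uint32_t bits32;
+  std::memcpy(&bits32, &f, 4);
+  assert(bits32 == 0x3FC00000u);
+  const uint32_t mantissa16 = (bits32 & 0x7FFFFF) >> 13;             // 0x200
+  const uint32_t biased_exp16 = ((bits32 >> 23) & 0xFF) - 127 + 15;  // 15
+  assert(((biased_exp16 << 10) | mantissa16) == 0x3E00u);
+}
+#endif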
+ +template +HWY_INLINE T GetLane(const Vec1 v) { + return v.raw; +} + +template +HWY_INLINE Vec1 LowerHalf(Vec1 v) { + return v; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_INLINE Vec1 Broadcast(const Vec1 v) { + static_assert(kLane == 0, "Scalar only has one lane"); + return v; +} + +// ------------------------------ Shuffle bytes with variable indices + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// indices in [0, sizeof(T)). +template +HWY_API Vec1 TableLookupBytes(const Vec1 in, const Vec1 from) { + uint8_t in_bytes[sizeof(T)]; + uint8_t from_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); + CopyBytes(&from, &from_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = in_bytes[from_bytes[i]]; + } + T out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices1 { + int raw; +}; + +template +HWY_API Indices1 SetTableIndices(Sisd, const int32_t* idx) { +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) + HWY_DASSERT(idx[0] == 0); +#endif + return Indices1{idx[0]}; +} + +template +HWY_API Vec1 TableLookupLanes(const Vec1 v, const Indices1 /* idx */) { + return v; +} + +// ------------------------------ Zip/unpack + +HWY_INLINE Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast((uint32_t(b.raw) << 8) + a.raw)); +} +HWY_INLINE Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint32_t(b.raw) << 16) + a.raw); +} +HWY_INLINE Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((uint64_t(b.raw) << 32) + a.raw); +} +HWY_INLINE Vec1 ZipLower(const Vec1 a, const Vec1 b) { + return Vec1(static_cast((int32_t(b.raw) << 8) + a.raw)); +} +HWY_INLINE Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((int32_t(b.raw) << 16) + a.raw); +} +HWY_INLINE Vec1 ZipLower(const Vec1 a, + const Vec1 b) { + return Vec1((int64_t(b.raw) << 32) + a.raw); +} + +// ------------------------------ Mask + +template +HWY_INLINE bool AllFalse(const Mask1 mask) { + return mask.bits == 0; +} + +template +HWY_INLINE bool AllTrue(const Mask1 mask) { + return mask.bits != 0; +} + +template +HWY_INLINE size_t StoreMaskBits(const Mask1 mask, uint8_t* p) { + *p = AllTrue(mask); + return 1; +} +template +HWY_INLINE size_t CountTrue(const Mask1 mask) { + return mask.bits == 0 ? 0 : 1; +} + +template +HWY_API Vec1 Compress(Vec1 v, const Mask1 /* mask */) { + // Upper lanes are undefined, so result is the same independent of mask. + return v; +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec1 v, const Mask1 mask, Sisd d, + T* HWY_RESTRICT aligned) { + Store(Compress(v, mask), d, aligned); + return CountTrue(mask); +} + +// ------------------------------ Reductions + +// Sum of all lanes, i.e. the only one. 
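+// Illustrative sketch (not upstream code): with a single lane the reductions
+// below are identities, so the portable reduce-then-extract idiom still
+// works unchanged on this target:
+#if 0
+template <typename T>
+T SumArray(const T* HWY_RESTRICT p, size_t count) {
+  const Sisd<T> d;
+  auto sum = Zero(d);
+  for (size_t i = 0; i < count; i += Lanes(d)) {
+    sum = sum + Load(d, p + i);
+  }
+  return GetLane(SumOfLanes(sum));  // broadcast reduction -> lane 0
+}
+#endif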
+template +HWY_INLINE Vec1 SumOfLanes(const Vec1 v0) { + return v0; +} +template +HWY_INLINE Vec1 MinOfLanes(const Vec1 v) { + return v; +} +template +HWY_INLINE Vec1 MaxOfLanes(const Vec1 v) { + return v; +} + +// ================================================== Operator wrapper + +template +HWY_API V Add(V a, V b) { + return a + b; +} +template +HWY_API V Sub(V a, V b) { + return a - b; +} + +template +HWY_API V Mul(V a, V b) { + return a * b; +} +template +HWY_API V Div(V a, V b) { + return a / b; +} + +template +V Shl(V a, V b) { + return a << b; +} +template +V Shr(V a, V b) { + return a >> b; +} + +template +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/set_macros-inl.h b/third_party/highway/hwy/ops/set_macros-inl.h new file mode 100644 index 000000000000..d5ce5e568f34 --- /dev/null +++ b/third_party/highway/hwy/ops/set_macros-inl.h @@ -0,0 +1,226 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Sets macros based on HWY_TARGET. + +// This include guard is toggled by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. +#if defined(HWY_SET_MACROS_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif + +#endif // HWY_SET_MACROS_PER_TARGET + +#include "hwy/targets.h" + +#undef HWY_NAMESPACE +#undef HWY_ALIGN +#undef HWY_LANES + +#undef HWY_CAP_INTEGER64 +#undef HWY_CAP_FLOAT64 +#undef HWY_CAP_GE256 +#undef HWY_CAP_GE512 + +#undef HWY_TARGET_STR + +// Before include guard so we redefine HWY_TARGET_STR on each include, +// governed by the current HWY_TARGET. 
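+// Illustrative sketch (not upstream code): the toggled include guard above
+// is what lets foreach_target.h re-include a translation unit once per
+// target, redefining HWY_NAMESPACE/HWY_TARGET_STR on each pass. A user file
+// following that pattern looks roughly like:
+#if 0
+// my_ops.cc
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "my_ops.cc"  // this file, re-included per target
+#include "hwy/foreach_target.h"
+#include "hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace project {
+namespace HWY_NAMESPACE {  // N_SSE4, N_AVX2, ... on successive passes
+// Target-specific implementations go here.
+}  // namespace HWY_NAMESPACE
+}  // namespace project
+HWY_AFTER_NAMESPACE();
+#endif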
+//----------------------------------------------------------------------------- +// SSE4 +#if HWY_TARGET == HWY_SSE4 + +#define HWY_NAMESPACE N_SSE4 +#define HWY_ALIGN alignas(16) +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_CAP_INTEGER64 1 +#define HWY_CAP_FLOAT64 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR "sse2,ssse3,sse4.1" + +//----------------------------------------------------------------------------- +// AVX2 +#elif HWY_TARGET == HWY_AVX2 + +#define HWY_NAMESPACE N_AVX2 +#define HWY_ALIGN alignas(32) +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_CAP_INTEGER64 1 +#define HWY_CAP_FLOAT64 1 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#if defined(HWY_DISABLE_BMI2_FMA) +#define HWY_TARGET_STR "avx,avx2,f16c" +#else +#define HWY_TARGET_STR "avx,avx2,bmi,bmi2,fma,f16c" +#endif + +//----------------------------------------------------------------------------- +// AVX3 +#elif HWY_TARGET == HWY_AVX3 + +#define HWY_ALIGN alignas(64) +#define HWY_LANES(T) (64 / sizeof(T)) + +#define HWY_CAP_INTEGER64 1 +#define HWY_CAP_FLOAT64 1 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 1 + +#define HWY_NAMESPACE N_AVX3 + +// Must include AVX2 because an AVX3 test may call AVX2 functions (e.g. when +// converting to half-vectors). HWY_DISABLE_BMI2_FMA is not relevant because if +// we have AVX3, we should also have BMI2/FMA. +#define HWY_TARGET_STR \ + "avx,avx2,bmi,bmi2,fma,f16c,avx512f,avx512vl,avx512dq,avx512bw" + +//----------------------------------------------------------------------------- +// PPC8 +#elif HWY_TARGET == HWY_PPC8 + +#define HWY_ALIGN alignas(16) +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_CAP_INTEGER64 1 +#define HWY_CAP_FLOAT64 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_PPC8 + +#define HWY_TARGET_STR "altivec,vsx" + +//----------------------------------------------------------------------------- +// NEON +#elif HWY_TARGET == HWY_NEON + +#define HWY_ALIGN alignas(16) +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_CAP_INTEGER64 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_ARCH_ARM_A64 +#define HWY_CAP_FLOAT64 1 +#else +#define HWY_CAP_FLOAT64 0 +#endif + +#define HWY_NAMESPACE N_NEON + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +//----------------------------------------------------------------------------- +// WASM +#elif HWY_TARGET == HWY_WASM + +#define HWY_ALIGN alignas(16) +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_CAP_INTEGER64 0 +#define HWY_CAP_FLOAT64 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// RVV +#elif HWY_TARGET == HWY_RVV + +// RVV only requires lane alignment, not natural alignment of the entire vector, +// and the compiler already aligns builtin types, so nothing to do here. +#define HWY_ALIGN + +// Arbitrary constant, not the actual lane count! Large enough that we can +// mul/div by 8 for LMUL. Value matches kMaxVectorSize, see base.h. +#define HWY_LANES(T) (4096 / sizeof(T)) + + +#define HWY_CAP_INTEGER64 1 +#define HWY_CAP_FLOAT64 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_RVV + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. 
+// (rv64gcv is not a valid target) + +//----------------------------------------------------------------------------- +// SCALAR +#elif HWY_TARGET == HWY_SCALAR + +#define HWY_ALIGN +#define HWY_LANES(T) 1 + +#define HWY_CAP_INTEGER64 1 +#define HWY_CAP_FLOAT64 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_SCALAR + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +// Clang <9 requires this be invoked at file scope, before any namespace. +#undef HWY_BEFORE_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_BEFORE_NAMESPACE() \ + HWY_PUSH_ATTRIBUTES(HWY_TARGET_STR) \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_BEFORE_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +// Clang <9 requires any namespaces be closed before this macro. +#undef HWY_AFTER_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_AFTER_NAMESPACE() \ + HWY_POP_ATTRIBUTES \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_AFTER_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +#undef HWY_ATTR +#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target) +#define HWY_ATTR __attribute__((target(HWY_TARGET_STR))) +#else +#define HWY_ATTR +#endif + +// DEPRECATED +#undef HWY_GATHER_LANES +#define HWY_GATHER_LANES(T) HWY_LANES(T) diff --git a/third_party/highway/hwy/ops/shared-inl.h b/third_party/highway/hwy/ops/shared-inl.h new file mode 100644 index 000000000000..9e8560af8ed8 --- /dev/null +++ b/third_party/highway/hwy/ops/shared-inl.h @@ -0,0 +1,125 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target definitions shared by ops/*.h and user code. + +#include + +// Separate header because foreach_target.h re-enables its include guard. +#include "hwy/ops/set_macros-inl.h" + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// SIMD operations are implemented as overloaded functions selected using a +// "descriptor" D := Simd. T is the lane type, N a number of lanes >= 1 +// (always a power of two). Users generally do not choose N directly, but +// instead use HWY_FULL(T[, LMUL]) (the largest available size). N is not +// necessarily the actual number of lanes, which is returned by Lanes(D()). +// +// Only HWY_FULL(T) and N <= 16 / sizeof(T) are guaranteed to be available - the +// latter are useful if >128 bit vectors are unnecessary or undesirable. +template +struct Simd { + constexpr Simd() = default; + using T = Lane; + static_assert((N & (N - 1)) == 0 && N != 0, "N must be a power of two"); + + // Widening/narrowing ops change the number of lanes and/or their type. 
+ // To initialize such vectors, we need the corresponding descriptor types: + + // PromoteTo/DemoteTo() with another lane type, but same number of lanes. + template + using Rebind = Simd; + + // MulEven() with another lane type, but same total size. + // Round up to correctly handle scalars with N=1. + template + using Repartition = + Simd; + + // LowerHalf() with the same lane type, but half the lanes. + // Round up to correctly handle scalars with N=1. + using Half = Simd; + + // Combine() with the same lane type, but twice the lanes. + using Twice = Simd; +}; + +template +using TFromD = typename D::T; + +// Descriptor for the same number of lanes as D, but with the LaneType T. +template +using Rebind = typename D::template Rebind; + +template +using RebindToSigned = Rebind>, D>; +template +using RebindToUnsigned = Rebind>, D>; +template +using RebindToFloat = Rebind>, D>; + +// Descriptor for the same total size as D, but with the LaneType T. +template +using Repartition = typename D::template Repartition; + +template +using RepartitionToWide = Repartition>, D>; +template +using RepartitionToNarrow = Repartition>, D>; + +// Descriptor for the same lane type as D, but half the lanes. +template +using Half = typename D::Half; + +// Descriptor for the same lane type as D, but twice the lanes. +template +using Twice = typename D::Twice; + +// Same as base.h macros but with a Simd argument instead of T. +#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(TFromD) +#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(TFromD) +#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(TFromD) +#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(TFromD) +#define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_LANE_SIZE(TFromD, bytes) +#define HWY_IF_NOT_LANE_SIZE_D(D, bytes) HWY_IF_NOT_LANE_SIZE(TFromD, bytes) + +// Compile-time-constant, (typically but not guaranteed) an upper bound on the +// number of lanes. +// Prefer instead using Lanes() and dynamic allocation, or Rebind, or +// `#if HWY_CAP_GE*`. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(Simd) { + return N; +} + +// Targets with non-constexpr Lanes define this themselves. +#if HWY_TARGET != HWY_RVV + +// (Potentially) non-constant actual size of the vector at runtime, subject to +// the limit imposed by the Simd. Useful for advancing loop counters. +template +HWY_INLINE HWY_MAYBE_UNUSED size_t Lanes(Simd) { + return N; +} + +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/wasm_128-inl.h b/third_party/highway/hwy/ops/wasm_128-inl.h new file mode 100644 index 000000000000..6efe4273eb8f --- /dev/null +++ b/third_party/highway/hwy/ops/wasm_128-inl.h @@ -0,0 +1,3008 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit WASM vectors and operations. +// External include guard in highway.h - see comment there. 
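+
+// Illustrative aside (not upstream code), referring to the descriptor
+// aliases defined in shared-inl.h above: Rebind keeps the lane count while
+// changing the lane type (PromoteTo/DemoteTo), whereas Repartition keeps the
+// total size (ZipLower/MulEven):
+#if 0
+static_assert(MaxLanes(Simd<uint8_t, 16>()) == 16, "full 128-bit u8");
+static_assert(MaxLanes(Rebind<uint16_t, Simd<uint8_t, 16>>()) == 16,
+              "Rebind keeps the lane count");
+static_assert(MaxLanes(Repartition<uint16_t, Simd<uint8_t, 16>>()) == 8,
+              "Repartition keeps the total size");
+#endif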
+ +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct Raw128 { + using type = __v128_u; +}; +template <> +struct Raw128 { + using type = __f32x4; +}; + +template +using Full128 = Simd; + +template +class Vec128 { + using Raw = typename Raw128::type; + + public: + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +// Integer: FF..FF or 0. Float: MSB, all other bits undefined - see README. +template +class Mask128 { + using Raw = typename Raw128::type; + + public: + Raw raw; +}; + +// ------------------------------ BitCast + +namespace detail { + +HWY_API __v128_u BitCastToInteger(__v128_u v) { return v; } +HWY_API __v128_u BitCastToInteger(__f32x4 v) { + return static_cast<__v128_u>(v); +} +HWY_API __v128_u BitCastToInteger(__f64x2 v) { + return static_cast<__v128_u>(v); +} + +template +HWY_API Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __v128_u operator()(__v128_u v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __f32x4 operator()(__v128_u v) { return static_cast<__f32x4>(v); } +}; + +template +HWY_API Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128{BitCastFromInteger128()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector/part. +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{wasm_i32x4_splat(0)}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{wasm_f32x4_splat(0.0f)}; +} + +// Returns a vector/part with all lanes set to "t". 
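An aside on BitCast above (illustrative scalar analogue, not part of the patch): routing every cast through a byte vector means only the to-byte and from-byte steps need per-type handling, and the BitCastFromInteger128 functor exists because C++ cannot overload on return type alone. The lane-level equivalent of such a bit-preserving cast is a memcpy:

#include <cstdint>
#include <cstring>

// Reinterpret the bits of a uint32_t as float with no value conversion,
// mirroring what the vector BitCast does for whole registers.
inline float BitCastLane(uint32_t bits) {
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}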
+template +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { + return Vec128{wasm_i8x16_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const uint16_t t) { + return Vec128{wasm_i16x8_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const uint32_t t) { + return Vec128{wasm_i32x4_splat(t)}; +} + +template +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { + return Vec128{wasm_i8x16_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { + return Vec128{wasm_i16x8_splat(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { + return Vec128{wasm_i32x4_splat(t)}; +} + +template +HWY_API Vec128 Set(Simd /* tag */, const float t) { + return Vec128{wasm_f32x4_splat(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec128 Undefined(Simd d) { + return Zero(d); +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_add(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_sub(a.raw, b.raw)}; +} + +// ------------------------------ Saturating addition + +// Returns a + b clamped to the destination range. 
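Scalar reference for the saturating lane semantics below (illustrative sketch, not part of the vendored code): the sum is computed at full precision and clamped to the lane type's range instead of wrapping.

#include <cstdint>
#include <limits>

// Reference behavior of wasm_*_add_saturate for one 8- or 16-bit lane.
template <typename T>
T SaturatedAddRef(T a, T b) {
  const int32_t wide = int32_t{a} + int32_t{b};  // cannot overflow for <= 16-bit lanes
  const int32_t lo = std::numeric_limits<T>::min();
  const int32_t hi = std::numeric_limits<T>::max();
  return static_cast<T>(wide < lo ? lo : (wide > hi ? hi : wide));
}

// SaturatedAddRef<uint8_t>(200, 100) == 255; SaturatedAddRef<int8_t>(100, 100) == 127.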
+ +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_add_saturate(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_add_saturate(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add_saturate(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add_saturate(a.raw, b.raw)}; +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_sub_saturate(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_sub_saturate(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub_saturate(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub_saturate(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_avgr(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_avgr(a.raw, b.raw)}; +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i8x16_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i16x8_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i32x4_abs(v.raw)}; +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_f32x4_abs(v.raw)}; +} + +// ------------------------------ Shift lanes by constant #bits + +// Unsigned +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u32x4_shr(v.raw, kBits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i32x4_shr(v.raw, kBits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const Simd d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const Simd d8; + // Use raw instead of BitCast to support N=1. 
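The 8-bit shift helpers here reuse the 16-bit shift hardware: shifting the vector as wider lanes lets bits cross byte boundaries, and a per-byte mask then removes the leaked bits. Scalar sketch of two byte lanes packed in one uint16_t (illustrative only):

#include <cstdint>

template <int kBits>
uint16_t ShiftRightBytePairRef(uint16_t two_lanes) {
  const uint16_t shifted = static_cast<uint16_t>(two_lanes >> kBits);  // high lane leaks into low lane
  const uint16_t lane_mask = static_cast<uint16_t>(0xFF >> kBits);
  return shifted & static_cast<uint16_t>((lane_mask << 8) | lane_mask);
}

// ShiftRightBytePairRef<4>(0xABCD) == 0x0A0C, i.e. each byte shifted independently.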
+ const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const Simd di; + const Simd du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Shift lanes by same variable #bits + +// Unsigned +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u32x4_shr(v.raw, bits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shr(v.raw, bits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const Simd d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, (0xFF << bits) & 0xFF); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const Simd d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const Simd di; + const Simd du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Minimum + +// Unsigned +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + alignas(16) float min[4]; + min[0] = + std::min(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); + min[1] = + std::min(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); + return Vec128{wasm_v128_load(min)}; + // TODO(janwas): new op? 
+ // return Vec128{wasm_u64x2_min(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + alignas(16) float min[4]; + min[0] = + std::min(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); + min[1] = + std::min(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); + return Vec128{wasm_v128_load(min)}; + // TODO(janwas): new op? (also do not yet have wasm_u64x2_make) + // return Vec128{wasm_i64x2_min(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_min(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + alignas(16) float max[4]; + max[0] = + std::max(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); + max[1] = + std::max(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); + return Vec128{wasm_v128_load(max)}; + // TODO(janwas): new op? (also do not yet have wasm_u64x2_make) + // return Vec128{wasm_u64x2_max(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + alignas(16) float max[4]; + max[0] = + std::max(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); + max[1] = + std::max(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); + return Vec128{wasm_v128_load(max)}; + // TODO(janwas): new op? (also do not yet have wasm_u64x2_make) + // return Vec128{wasm_i64x2_max(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_max(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. 
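MulHigh below is emulated by widening to 32-bit products and keeping each product's upper half; the (1, 3, 5, ...) 16-bit shuffle selects those halves because WASM lanes are little-endian, so half 1 of each 32-bit product is the high half. One-lane reference (illustrative sketch):

#include <cstdint>

inline uint16_t MulHighU16Ref(uint16_t a, uint16_t b) {
  return static_cast<uint16_t>((uint32_t{a} * uint32_t{b}) >> 16);
}

inline int16_t MulHighI16Ref(int16_t a, int16_t b) {
  const int32_t product = int32_t{a} * int32_t{b};
  return static_cast<int16_t>(static_cast<uint32_t>(product) >> 16);
}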
+ const auto al = wasm_i32x4_widen_low_u16x8(a.raw); + const auto ah = wasm_i32x4_widen_high_u16x8(a.raw); + const auto bl = wasm_i32x4_widen_low_u16x8(b.raw); + const auto bh = wasm_i32x4_widen_high_u16x8(b.raw); + const auto l = wasm_i32x4_mul(al, bl); + const auto h = wasm_i32x4_mul(ah, bh); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_v16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto al = wasm_i32x4_widen_low_i16x8(a.raw); + const auto ah = wasm_i32x4_widen_high_i16x8(a.raw); + const auto bl = wasm_i32x4_widen_low_i16x8(b.raw); + const auto bh = wasm_i32x4_widen_high_i16x8(b.raw); + const auto l = wasm_i32x4_mul(al, bl); + const auto h = wasm_i32x4_mul(ah, bh); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_v16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-width result. +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto kEvenMask = wasm_i32x4_make(0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + // TODO(eustas): replace, when implemented in WASM. + const auto kEvenMask = wasm_i32x4_make(0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} + +// ------------------------------ Negate + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(Simd())); +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i8x16_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i16x8_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i32x4_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i64x2_neg(v.raw)}; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{wasm_f32x4_mul(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_div(a.raw, b.raw)}; +} + +// Approximate reciprocal +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + // TODO(eustas): replace, when implemented in WASM. + const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; + return one / v; +} + +// Absolute value of difference. +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfma? + return mul * x + add; +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { + // TODO(eustas): replace, when implemented in WASM. + return add - mul * x; +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + // TODO(eustas): replace, when implemented in WASM. 
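Why the even-mask trick in MulEven above is exact for the unsigned variant (illustrative scalar view, not part of the patch): zeroing the odd 32-bit lanes leaves each 64-bit lane holding a zero-extended 32-bit value, so an ordinary 64x64-bit multiply cannot overflow and equals the widening 32x32 -> 64-bit product.

#include <cstdint>

inline uint64_t MulEvenLaneRef(uint32_t a_even, uint32_t b_even) {
  const uint64_t a64 = a_even;  // zero-extended, like the masked lane
  const uint64_t b64 = b_even;
  return a64 * b64;  // fits in 64 bits: (2^32 - 1)^2 < 2^64
}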
+ // TODO(eustas): is it wasm_f32x4_qfms? + return mul * x - sub; +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { + // TODO(eustas): replace, when implemented in WASM. + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{wasm_f32x4_sqrt(v.raw)}; +} + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + // TODO(eustas): find cheaper a way to calculate this. + const Vec128 one = Vec128{wasm_f32x4_splat(1.0f)}; + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + // IEEE-754 roundToIntegralTiesToEven returns floating-point, but we do not + // yet have an instruction for that (f32x4.nearest is not implemented). We + // rely on rounding after addition with a large value such that no mantissa + // bits remain (assuming the current mode is nearest-even). We may need a + // compiler flag for precise floating-point to prevent "optimizing" this out. + const Simd df; + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +namespace detail { + +// Truncating to integer and converting back to float is correct except when the +// input magnitude is large, in which case the input was already an integer +// (because mantissa >> exponent is zero). +template +HWY_API Mask128 UseInt(const Vec128 v) { + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + // TODO(eustas): is it f32x4.trunc? (not implemented yet) + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// Toward +infinity, aka ceiling +template +HWY_INLINE Vec128 Ceil(const Vec128 v) { + // TODO(eustas): is it f32x4.ceil? (not implemented yet) + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +// Toward -infinity, aka floor +template +HWY_INLINE Vec128 Floor(const Vec128 v) { + // TODO(eustas): is it f32x4.floor? (not implemented yet) + const Simd df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
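Before moving on to comparisons, a scalar note on the Round implementation above (illustrative sketch): adding and then subtracting a constant large enough to consume every mantissa bit (MantissaEnd() is 2^23 for float) forces the FPU to round to integer in its nearest-even mode. This assumes round-to-nearest-even is in effect and the compiler does not reassociate floating point:

#include <cmath>

inline float RoundRef(float v) {
  const float kMantissaEnd = 8388608.0f;  // 2^23
  if (!(std::fabs(v) < kMantissaEnd)) return v;  // already integral, or NaN
  const float large = std::copysign(kMantissaEnd, v);
  return (large + v) - large;  // the add rounds; ties go to even
}

// RoundRef(2.5f) == 2.0f and RoundRef(3.5f) == 4.0f, matching ties-to-even.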
+ +template +HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128{m.raw}; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_eq(a.raw, b.raw)}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Strict inequality + +// Signed/float > +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + const Simd d32; + const auto a32 = BitCast(d32, a); + const auto b32 = BitCast(d32, b); + // If the upper half is less than or greater, this is the answer. + const auto m_gt = a32 < b32; + + // Otherwise, the lower half decides. + const auto m_eq = a32 == b32; + const auto lo_in_hi = wasm_v32x4_shuffle(m_gt, m_gt, 2, 2, 0, 0); + const auto lo_gt = And(m_eq, lo_in_hi); + + const auto gt = Or(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. + return Mask128{wasm_v32x4_shuffle(gt, gt, 3, 3, 1, 1)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float <= >= +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_le(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ge(a.raw, b.raw)}; +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec128 Not(Vec128 v) { + return Vec128{wasm_v128_not(v.raw)}; +} + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_and(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
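The 64-bit signed operator> above is assembled from 32-bit lane compares. The scalar shape of that decomposition (illustrative sketch, not the vendored code): the signed comparison of the high halves decides unless they are equal, in which case the low halves decide, and the low halves must be compared as unsigned.

#include <cstdint>

inline bool GreaterI64Ref(int64_t a, int64_t b) {
  const int32_t a_hi = static_cast<int32_t>(a >> 32);
  const int32_t b_hi = static_cast<int32_t>(b >> 32);
  if (a_hi != b_hi) return a_hi > b_hi;
  return static_cast<uint32_t>(a) > static_cast<uint32_t>(b);  // unsigned low half
}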
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{wasm_v128_andnot(mask.raw, not_mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_or(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_xor(a.raw, b.raw)}; +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd()), sign)); +} + +// ------------------------------ BroadcastSignBit (compare) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return VecFromMask(Simd(), v < Zero(Simd())); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(Simd /* tag */, Mask128 v) { + return Vec128{v.raw}; +} + +// DEPRECATED +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(Simd(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(Simd(), mask), no); +} + +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + const Simd d; + const auto zero = Zero(d); + return IfThenElse(Mask128{(v > zero).raw}, v, zero); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) + +// The x86 multiply-by-Pow2() trick will not work because WASM saturates +// float->int correctly to 2^31-1 (not 2^31). 
Because WASM's shifts take a +// scalar count operand, per-lane shift instructions would require extract_lane +// for each lane, and hoping that shuffle is correctly mapped to a native +// instruction. Using non-vector shifts would incur a store-load forwarding +// stall when loading the result vector. We instead test bits of the shift +// count to "predicate" a shift of the entire vector by a constant. + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const Simd d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const Simd d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const Simd d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const Simd d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. 
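The bit-testing approach used by operator<< and operator>> here amounts to a fixed number of predicated whole-vector shifts, one per bit of the shift count. Scalar equivalent for a single 16-bit lane (illustrative sketch):

#include <cstdint>

inline uint16_t ShiftLeftVarRef(uint16_t v, unsigned bits) {  // bits in [0, 16)
  if (bits & 8) v = static_cast<uint16_t>(v << 8);
  if (bits & 4) v = static_cast<uint16_t>(v << 4);
  if (bits & 2) v = static_cast<uint16_t>(v << 2);
  if (bits & 1) v = static_cast<uint16_t>(v << 1);
  return v;
}

The vector code tests the same bits from most- to least-significant by first moving the highest valid bit of the count into the sign bit (ShiftLeft<12> for 16-bit lanes, ShiftLeft<27> for 32-bit lanes) and re-testing after each single-bit shift of the count.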
+ auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{wasm_v128_load(aligned)}; +} + +// Partial load. +template +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { + Vec128 v; + CopyBytes(p, &v); + return v; +} + +// LoadU == Load. +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// Partial store. +template +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { + CopyBytes(&v, p); +} + +HWY_API void Store(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT p) { + *p = wasm_f32x4_extract_lane(v.raw, 0); +} + +// StoreU == Store. +template +HWY_API void StoreU(Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. 
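Partial loads and stores above are plain byte copies of the N valid lanes (CopyBytes is a fixed-size copy). A sketch of the store path (illustrative; names hypothetical):

#include <cstddef>
#include <cstring>

template <typename T, size_t N>
void StorePartialRef(const T (&lanes)[N], T* p) {
  std::memcpy(p, lanes, N * sizeof(T));  // write only the N*sizeof(T) valid bytes
}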
+ +template +HWY_API void Stream(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// ------------------------------ Scatter (Store) + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Simd(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Simd(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Simd(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Simd(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ Extract lane + +// Gets the single value stored in a vector/part. +template +HWY_API uint8_t GetLane(const Vec128 v) { + return wasm_i8x16_extract_lane(v.raw, 0); +} +template +HWY_API int8_t GetLane(const Vec128 v) { + return wasm_i8x16_extract_lane(v.raw, 0); +} +template +HWY_API uint16_t GetLane(const Vec128 v) { + return wasm_i16x8_extract_lane(v.raw, 0); +} +template +HWY_API int16_t GetLane(const Vec128 v) { + return wasm_i16x8_extract_lane(v.raw, 0); +} +template +HWY_API uint32_t GetLane(const Vec128 v) { + return wasm_i32x4_extract_lane(v.raw, 0); +} +template +HWY_API int32_t GetLane(const Vec128 v) { + return wasm_i32x4_extract_lane(v.raw, 0); +} +template +HWY_API float GetLane(const Vec128 v) { + return wasm_f32x4_extract_lane(v.raw, 0); +} + +// ------------------------------ Extract half + +// Returns upper/lower half of a vector. +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// These copy hi into lo (smaller instruction encoding than shifts). +template +HWY_API Vec128 UpperHalf(Vec128 v) { + // TODO(eustas): use swizzle? + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} +template <> +HWY_INLINE Vec128 UpperHalf(Vec128 v) { + // TODO(eustas): use swizzle? 
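Scatter/Gather above have no WASM instruction and are spelled as scalar loops over stored lanes; offsets are byte offsets from base (hence the uint8_t* arithmetic), while indices are element indices. Standalone sketch of the offset flavor (illustrative only):

#include <cstddef>
#include <cstdint>
#include <cstring>

void GatherOffsetRef(const int32_t* base, const int32_t* byte_offsets,
                     int32_t* out, size_t n) {
  const uint8_t* base_bytes = reinterpret_cast<const uint8_t*>(base);
  for (size_t i = 0; i < n; ++i) {
    std::memcpy(&out[i], base_bytes + byte_offsets[i], sizeof(int32_t));
  }
}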
+ return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} + +// ------------------------------ Shift vector by constant #bytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14)}; + + case 2: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13)}; + + case 3: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12)}; + + case 4: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11)}; + + case 5: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10)}; + + case 6: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + + case 7: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; + + case 8: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; + + case 9: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 0, 1, 2, 3, 4, 5, 6)}; + + case 10: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 0, 1, 2, 3, 4, 5)}; + + case 11: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 0, 1, 2, 3, 4)}; + + case 12: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 0, 1, 2, 3)}; + + case 13: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 0, 1, 2)}; + + case 14: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 0, + 1)}; + + case 15: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0)}; + } + return Vec128{zero}; +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + const Simd d8; + const Simd d; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec128 ShiftRightBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16)}; + + case 2: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 16)}; + + case 3: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 16, 16)}; + + case 4: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 16, 16, 16)}; + + case 5: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 16, 16, 16, 16)}; + + case 6: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 16, 16, 16, 16)}; + + case 7: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 16, 16, 16, 16, 16, 16)}; + + case 8: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 
8, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16, 16, 16, 16)}; + + case 9: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16, 16, 16, 16, + 16)}; + + case 10: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16)}; + + case 11: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16)}; + + case 12: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16)}; + + case 13: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16)}; + + case 14: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16)}; + + case 15: + return Vec128{wasm_v8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16)}; + } + return Vec128{zero}; +} + +template +HWY_API Vec128 ShiftRightLanes(const Vec128 v) { + const Simd d8; + const Simd d; + return BitCast(d, ShiftRightBytes(BitCast(d8, v))); +} + +// ------------------------------ Extract from 2x 128-bit at constant offset + +// Extracts 128 bits from by skipping the least-significant kBytes. +template +HWY_API Vec128 CombineShiftRightBytes(const Vec128 hi, + const Vec128 lo) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + switch (kBytes) { + case 0: + return lo; + + case 1: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16)}; + + case 2: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17)}; + + case 3: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18)}; + + case 4: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19)}; + + case 5: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20)}; + + case 6: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21)}; + + case 7: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22)}; + + case 8: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23)}; + + case 9: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24)}; + + case 10: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25)}; + + case 11: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26)}; + + case 12: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27)}; + + case 13: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28)}; + + case 14: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29)}; + + case 15: + return Vec128{wasm_v8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30)}; + } + return hi; +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned 
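CombineShiftRightBytes<kBytes> above treats hi:lo as one 32-byte buffer and extracts the 16 bytes starting at byte kBytes, which is exactly what shuffle indices kBytes..kBytes+15 select (indices 0-15 address lo, 16-31 address hi). Standalone sketch (illustrative):

#include <cstddef>
#include <cstdint>
#include <cstring>

void CombineShiftRightBytesRef(const uint8_t lo[16], const uint8_t hi[16],
                               size_t k_bytes, uint8_t out[16]) {
  uint8_t both[32];
  std::memcpy(both, lo, 16);
  std::memcpy(both + 16, hi, 16);
  std::memcpy(out, both + k_bytes, 16);  // k_bytes in [0, 16]
}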
+template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_v16x8_shuffle( + v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{ + wasm_v32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_v16x8_shuffle( + v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{ + wasm_v32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{ + wasm_v32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +// ------------------------------ Shuffle bytes with variable indices + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { +// Not yet available in all engines, see +// https://github.com/WebAssembly/simd/blob/bdcc304b2d379f4601c2c44ea9b44ed9484fde7e/proposals/simd/ImplementationStatus.md +// V8 implementation of this had a bug, fixed on 2021-04-03: +// https://chromium-review.googlesource.com/c/v8/v8/+/2822951 +#if 0 + return Vec128{wasm_v8x16_swizzle(bytes.raw, from.raw)}; +#else + alignas(16) uint8_t control[16]; + alignas(16) uint8_t input[16]; + alignas(16) uint8_t output[16]; + wasm_v128_store(control, from.raw); + wasm_v128_store(input, bytes.raw); + for (size_t i = 0; i < 16; ++i) { + output[i] = control[i] < 16 ? input[control[i]] : 0; + } + return Vec128{wasm_v128_load(output)}; +#endif +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. 
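A note on the naming below: ShuffleABCD lists the source lane for each result position from most- to least-significant, so Shuffle0321 is a rotate right by one 32-bit lane. Scalar picture (illustrative):

#include <cstdint>

void Shuffle0321Ref(const uint32_t in[4], uint32_t out[4]) {
  out[0] = in[1];  // each lane moves down one position...
  out[1] = in[2];
  out[2] = in[3];
  out[3] = in[0];  // ...and the old least-significant lane wraps to the top
}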
+HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{wasm_v64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{wasm_v64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{wasm_v64x2_shuffle(v.raw, v.raw, 1, 0)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{wasm_v32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + __v128_u raw; +}; + +template +HWY_API Indices128 SetTableIndices(Full128, const int32_t* idx) { +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) + const size_t N = 16 / sizeof(T); + for (size_t i = 0; i < N; ++i) { + HWY_DASSERT(0 <= idx[i] && idx[i] < static_cast(N)); + } +#endif + + const Full128 d8; + alignas(16) uint8_t control[16]; // = Lanes() + for (size_t idx_byte = 0; idx_byte < 16; ++idx_byte) { + const size_t idx_lane = idx_byte / sizeof(T); + const size_t mod = idx_byte % sizeof(T); + control[idx_byte] = idx[idx_lane] * sizeof(T) + mod; + } + return Indices128{Load(d8, control).raw}; +} + +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + return TableLookupBytes(v, Vec128{idx.raw}); +} + +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + return TableLookupBytes(v, Vec128{idx.raw}); +} + +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + const Full128 di; + const Full128 df; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +} + +// ------------------------------ Zip lanes + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
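ZipLower/ZipUpper below interleave lanes and re-declare the result as double-width integers; in little-endian lane order, each output lane pairs a_i in the lower half with b_i in the upper half. Per-lane reference (illustrative sketch):

#include <cstdint>

inline uint16_t ZipLaneRef(uint8_t a_i, uint8_t b_i) {
  return static_cast<uint16_t>(uint16_t{a_i} | (uint16_t{b_i} << 8));
}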
+ +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_v16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} + +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_v16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} + +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, + 10, 26, 11, 27, 12, 28, 13, + 29, 14, 30, 15, 31)}; +} +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_v16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} + +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, + 10, 26, 11, 27, 12, 28, 13, + 29, 14, 30, 15, 31)}; +} +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_v16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} + +// ------------------------------ Interleave lanes + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(const Vec128 a, const Vec128 b) { + return Vec128{ZipLower(a, b).raw}; +} +template <> +HWY_INLINE Vec128 InterleaveLower( + const Vec128 a, const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template <> +HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template <> +HWY_INLINE Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} + +template +HWY_API Vec128 InterleaveUpper(const Vec128 a, const Vec128 b) { + return Vec128{ZipUpper(a, b).raw}; +} +template <> +HWY_INLINE Vec128 InterleaveUpper( + const Vec128 a, const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template <> +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template <> +HWY_INLINE Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} + +// ------------------------------ Blocks + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec128 ConcatLowerLower(const Vec128 hi, const Vec128 lo) { + return Vec128{wasm_v64x2_shuffle(lo.raw, hi.raw, 0, 2)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec128 ConcatUpperUpper(const Vec128 hi, const Vec128 lo) { + return Vec128{wasm_v64x2_shuffle(lo.raw, hi.raw, 1, 3)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API Vec128 ConcatLowerUpper(const Vec128 hi, const Vec128 lo) { + return CombineShiftRightBytes<8>(hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec128 ConcatUpperLower(const Vec128 hi, const Vec128 lo) { + return 
Vec128{wasm_v64x2_shuffle(lo.raw, hi.raw, 0, 3)}; +} + +// ------------------------------ Odd/even lanes + +namespace { + +template +HWY_API Vec128 odd_even_impl(hwy::SizeTag<1> /* tag */, const Vec128 a, + const Vec128 b) { + const Full128 d; + const Full128 d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template +HWY_API Vec128 odd_even_impl(hwy::SizeTag<2> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; +} +template +HWY_API Vec128 odd_even_impl(hwy::SizeTag<4> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} +// TODO(eustas): implement +// template +// HWY_API Vec128 odd_even_impl(hwy::SizeTag<8> /* tag */, +// const Vec128 a, +// const Vec128 b) + +} // namespace + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + return odd_even_impl(hwy::SizeTag(), a, b); +} +template <> +HWY_INLINE Vec128 OddEven(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_v32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_widen_low_u8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_i32x4_widen_low_u16x8(wasm_i16x8_widen_low_u8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_widen_low_u8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_i32x4_widen_low_u16x8(wasm_i16x8_widen_low_u8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_widen_low_u16x8(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_widen_low_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_widen_low_i8x16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{ + wasm_i32x4_widen_low_i16x8(wasm_i16x8_widen_low_i8x16(v.raw))}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_widen_low_i16x8(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd df, + const Vec128 v) { + // TODO(janwas): use https://github.com/WebAssembly/simd/pull/383 + alignas(16) int32_t lanes[4]; + Store(v, Simd(), lanes); + alignas(16) double lanes64[2]; + lanes64[0] = lanes[0]; + lanes64[1] = N >= 2 ? lanes[1] : 0.0; + return Load(df, lanes64); +} + +template +HWY_INLINE Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + const Simd di32; + const Simd du32; + const Simd df32; + // Expand to u32 so we can shift. 
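The half-width float promotion continued below is pure bit arithmetic: rebase the exponent from bias 15 to bias 127, widen the mantissa from 10 to 23 bits, and rescale subnormal inputs (value = mantissa * 2^-24) with a float multiply. Scalar equivalent (illustrative sketch; like the vector code, Inf/NaN inputs are not special-cased):

#include <cstdint>
#include <cstring>

inline float F16ToF32Ref(uint16_t bits16) {
  const uint32_t sign = bits16 >> 15;
  const uint32_t biased_exp = (bits16 >> 10) & 0x1F;
  const uint32_t mantissa = bits16 & 0x3FF;
  uint32_t bits32;
  if (biased_exp == 0) {  // zero or subnormal
    const float subnormal = static_cast<float>(mantissa) * (1.0f / 16384 / 1024);
    std::memcpy(&bits32, &subnormal, 4);
  } else {
    bits32 = ((biased_exp + 127 - 15) << 23) | (mantissa << (23 - 10));
  }
  bits32 |= sign << 31;
  float out;
  std::memcpy(&out, &bits32, 4);
  return out;
}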
+ const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd di, + const Vec128 v) { + // TODO(janwas): use https://github.com/WebAssembly/simd/pull/383 + alignas(16) double lanes64[2]; + Store(v, Simd(), lanes64); + alignas(16) int32_t lanes[4] = {static_cast(lanes64[0])}; + if (N >= 2) lanes[1] = static_cast(lanes64[1]); + return Load(di, lanes); +} + +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const Simd di; + const Simd du; + const Simd du16; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128{DemoteTo(du16, bits16).raw}; +} + +// For already range-limited input [0, 255]. 
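// (Implemented below as two saturating narrows, i32x4 -> i16x8 -> u8x16;
// inputs outside [0, 255] would saturate rather than wrap, hence the
// precondition.)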
+template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_f32x4_convert_i32x4(v.raw)}; +} +// Truncates (rounds toward zero). +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{wasm_i32x4_trunc_saturate_f32x4(v.raw)}; +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return ConvertTo(Simd(), Round(v)); +} + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, mask.raw); + + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t lo = ((lanes[0] * kMagic) >> 56); + const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; + return (hi + lo); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const __i16x8 zero = wasm_i16x8_splat(0); + const Mask128 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + return BitsFromMask(hwy::SizeTag<1>(), mask8); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); + const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint32_t lanes[4]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1] | lanes[2] | lanes[3]; +} + +// Returns the lowest N bits for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) == 16) ? bits : bits & ((1ull << N) - 1); +} + +// Returns 0xFF for bytes with index >= N, otherwise 0. +template +constexpr __i8x16 BytesAbove() { + return /**/ + (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) + : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) + : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) + : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) + : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) + : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) + : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) + : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) + : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) + : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1) + : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1) + : (N == 9) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, + -1, -1, -1) + : (N == 11) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) + : (N == 13) + ? 
wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) + : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); +} + +template +HWY_API uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +template +HWY_API size_t CountTrue(hwy::SizeTag<1> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_API size_t CountTrue(hwy::SizeTag<2> tag, const Mask128 m) { + return PopCount(BitsFromMask(tag, m)); +} + +template +HWY_API size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, shifted_bits); + return PopCount(lanes[0] | lanes[1]); +} + +} // namespace detail + +template +HWY_INLINE size_t StoreMaskBits(const Mask128 mask, uint8_t* p) { + const uint64_t bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes(&bits, p); + return kNumBytes; +} + +template +HWY_API size_t CountTrue(const Mask128 m) { + return detail::CountTrue(hwy::SizeTag(), m); +} + +// Partial vector +template +HWY_API size_t CountTrue(const Mask128 m) { + // Ensure all undefined bytes are 0. + const Mask128 mask{detail::BytesAbove()}; + return CountTrue(Mask128{AndNot(mask, m).raw}); +} + +// Full vector, type-independent +template +HWY_API bool AllFalse(const Mask128 m) { +#if 0 + // Casting followed by wasm_i8x16_any_true results in wasm error: + // i32.eqz[0] expected type i32, found i8x16.popcnt of type s128 + const auto v8 = BitCast(Full128(), VecFromMask(Full128(), m)); + return !wasm_i8x16_any_true(v8.raw); +#else + return (wasm_i64x2_extract_lane(m.raw, 0) | + wasm_i64x2_extract_lane(m.raw, 1)) == 0; +#endif +} + +// Full vector, type-dependent +namespace detail { +template +HWY_API bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + return wasm_i8x16_all_true(m.raw); +} +template +HWY_API bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + return wasm_i16x8_all_true(m.raw); +} +template +HWY_API bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + return wasm_i32x4_all_true(m.raw); +} + +} // namespace detail + +template +HWY_API bool AllTrue(const Mask128 m) { + return detail::AllTrue(hwy::SizeTag(), m); +} + +// Partial vectors + +template +HWY_API bool AllFalse(const Mask128 m) { + // Ensure all undefined bytes are 0. + const Mask128 mask{detail::BytesAbove()}; + return AllFalse(Mask128{AndNot(mask, m).raw}); +} + +template +HWY_API bool AllTrue(const Mask128 m) { + // Ensure all undefined bytes are FF. + const Mask128 mask{detail::BytesAbove()}; + return AllTrue(Mask128{Or(mask, m).raw}); +} + +// ------------------------------ Compress + +namespace detail { + +template +HWY_INLINE Vec128 Idx16x8FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. 
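  // Worked example: mask_bits = 0b00000101 selects lanes 0 and 2, and row 5
  // of the table below is {0, 4, 0, 0, 0, 0, 0, 0} - lane indices already
  // doubled. ZipLower(byte_idx, byte_idx) replicates each entry into a u16
  // (0x0000, 0x0404, ...), and adding 0x0100 turns each pair into consecutive
  // byte indices (0,1, 4,5, ...) as required by TableLookupBytes.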
+ alignas(16) constexpr uint8_t table[256 * 8] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, + 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, + 0, 6, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 2, + 6, 0, 0, 0, 0, 0, 4, 6, 0, 0, 0, 0, 0, 0, 0, 4, 6, 0, + 0, 0, 0, 0, 2, 4, 6, 0, 0, 0, 0, 0, 0, 2, 4, 6, 0, 0, + 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, + 2, 8, 0, 0, 0, 0, 0, 0, 0, 2, 8, 0, 0, 0, 0, 0, 4, 8, + 0, 0, 0, 0, 0, 0, 0, 4, 8, 0, 0, 0, 0, 0, 2, 4, 8, 0, + 0, 0, 0, 0, 0, 2, 4, 8, 0, 0, 0, 0, 6, 8, 0, 0, 0, 0, + 0, 0, 0, 6, 8, 0, 0, 0, 0, 0, 2, 6, 8, 0, 0, 0, 0, 0, + 0, 2, 6, 8, 0, 0, 0, 0, 4, 6, 8, 0, 0, 0, 0, 0, 0, 4, + 6, 8, 0, 0, 0, 0, 2, 4, 6, 8, 0, 0, 0, 0, 0, 2, 4, 6, + 8, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, + 0, 0, 2, 10, 0, 0, 0, 0, 0, 0, 0, 2, 10, 0, 0, 0, 0, 0, + 4, 10, 0, 0, 0, 0, 0, 0, 0, 4, 10, 0, 0, 0, 0, 0, 2, 4, + 10, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0, 0, 0, 0, 6, 10, 0, 0, + 0, 0, 0, 0, 0, 6, 10, 0, 0, 0, 0, 0, 2, 6, 10, 0, 0, 0, + 0, 0, 0, 2, 6, 10, 0, 0, 0, 0, 4, 6, 10, 0, 0, 0, 0, 0, + 0, 4, 6, 10, 0, 0, 0, 0, 2, 4, 6, 10, 0, 0, 0, 0, 0, 2, + 4, 6, 10, 0, 0, 0, 8, 10, 0, 0, 0, 0, 0, 0, 0, 8, 10, 0, + 0, 0, 0, 0, 2, 8, 10, 0, 0, 0, 0, 0, 0, 2, 8, 10, 0, 0, + 0, 0, 4, 8, 10, 0, 0, 0, 0, 0, 0, 4, 8, 10, 0, 0, 0, 0, + 2, 4, 8, 10, 0, 0, 0, 0, 0, 2, 4, 8, 10, 0, 0, 0, 6, 8, + 10, 0, 0, 0, 0, 0, 0, 6, 8, 10, 0, 0, 0, 0, 2, 6, 8, 10, + 0, 0, 0, 0, 0, 2, 6, 8, 10, 0, 0, 0, 4, 6, 8, 10, 0, 0, + 0, 0, 0, 4, 6, 8, 10, 0, 0, 0, 2, 4, 6, 8, 10, 0, 0, 0, + 0, 2, 4, 6, 8, 10, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 12, + 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 0, 0, 0, 0, 0, 2, 12, 0, + 0, 0, 0, 0, 4, 12, 0, 0, 0, 0, 0, 0, 0, 4, 12, 0, 0, 0, + 0, 0, 2, 4, 12, 0, 0, 0, 0, 0, 0, 2, 4, 12, 0, 0, 0, 0, + 6, 12, 0, 0, 0, 0, 0, 0, 0, 6, 12, 0, 0, 0, 0, 0, 2, 6, + 12, 0, 0, 0, 0, 0, 0, 2, 6, 12, 0, 0, 0, 0, 4, 6, 12, 0, + 0, 0, 0, 0, 0, 4, 6, 12, 0, 0, 0, 0, 2, 4, 6, 12, 0, 0, + 0, 0, 0, 2, 4, 6, 12, 0, 0, 0, 8, 12, 0, 0, 0, 0, 0, 0, + 0, 8, 12, 0, 0, 0, 0, 0, 2, 8, 12, 0, 0, 0, 0, 0, 0, 2, + 8, 12, 0, 0, 0, 0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 4, 8, 12, + 0, 0, 0, 0, 2, 4, 8, 12, 0, 0, 0, 0, 0, 2, 4, 8, 12, 0, + 0, 0, 6, 8, 12, 0, 0, 0, 0, 0, 0, 6, 8, 12, 0, 0, 0, 0, + 2, 6, 8, 12, 0, 0, 0, 0, 0, 2, 6, 8, 12, 0, 0, 0, 4, 6, + 8, 12, 0, 0, 0, 0, 0, 4, 6, 8, 12, 0, 0, 0, 2, 4, 6, 8, + 12, 0, 0, 0, 0, 2, 4, 6, 8, 12, 0, 0, 10, 12, 0, 0, 0, 0, + 0, 0, 0, 10, 12, 0, 0, 0, 0, 0, 2, 10, 12, 0, 0, 0, 0, 0, + 0, 2, 10, 12, 0, 0, 0, 0, 4, 10, 12, 0, 0, 0, 0, 0, 0, 4, + 10, 12, 0, 0, 0, 0, 2, 4, 10, 12, 0, 0, 0, 0, 0, 2, 4, 10, + 12, 0, 0, 0, 6, 10, 12, 0, 0, 0, 0, 0, 0, 6, 10, 12, 0, 0, + 0, 0, 2, 6, 10, 12, 0, 0, 0, 0, 0, 2, 6, 10, 12, 0, 0, 0, + 4, 6, 10, 12, 0, 0, 0, 0, 0, 4, 6, 10, 12, 0, 0, 0, 2, 4, + 6, 10, 12, 0, 0, 0, 0, 2, 4, 6, 10, 12, 0, 0, 8, 10, 12, 0, + 0, 0, 0, 0, 0, 8, 10, 12, 0, 0, 0, 0, 2, 8, 10, 12, 0, 0, + 0, 0, 0, 2, 8, 10, 12, 0, 0, 0, 4, 8, 10, 12, 0, 0, 0, 0, + 0, 4, 8, 10, 12, 0, 0, 0, 2, 4, 8, 10, 12, 0, 0, 0, 0, 2, + 4, 8, 10, 12, 0, 0, 6, 8, 10, 12, 0, 0, 0, 0, 0, 6, 8, 10, + 12, 0, 0, 0, 2, 6, 8, 10, 12, 0, 0, 0, 0, 2, 6, 8, 10, 12, + 0, 0, 4, 6, 8, 10, 12, 0, 0, 0, 0, 4, 6, 8, 10, 12, 0, 0, + 2, 4, 6, 8, 10, 12, 0, 0, 0, 2, 4, 6, 8, 10, 12, 0, 14, 0, + 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 2, 14, 0, 0, + 0, 0, 0, 0, 0, 2, 14, 0, 0, 0, 0, 0, 4, 14, 0, 0, 0, 0, + 0, 0, 0, 4, 14, 0, 0, 0, 0, 0, 2, 4, 14, 0, 0, 0, 0, 0, + 0, 2, 
4, 14, 0, 0, 0, 0, 6, 14, 0, 0, 0, 0, 0, 0, 0, 6, + 14, 0, 0, 0, 0, 0, 2, 6, 14, 0, 0, 0, 0, 0, 0, 2, 6, 14, + 0, 0, 0, 0, 4, 6, 14, 0, 0, 0, 0, 0, 0, 4, 6, 14, 0, 0, + 0, 0, 2, 4, 6, 14, 0, 0, 0, 0, 0, 2, 4, 6, 14, 0, 0, 0, + 8, 14, 0, 0, 0, 0, 0, 0, 0, 8, 14, 0, 0, 0, 0, 0, 2, 8, + 14, 0, 0, 0, 0, 0, 0, 2, 8, 14, 0, 0, 0, 0, 4, 8, 14, 0, + 0, 0, 0, 0, 0, 4, 8, 14, 0, 0, 0, 0, 2, 4, 8, 14, 0, 0, + 0, 0, 0, 2, 4, 8, 14, 0, 0, 0, 6, 8, 14, 0, 0, 0, 0, 0, + 0, 6, 8, 14, 0, 0, 0, 0, 2, 6, 8, 14, 0, 0, 0, 0, 0, 2, + 6, 8, 14, 0, 0, 0, 4, 6, 8, 14, 0, 0, 0, 0, 0, 4, 6, 8, + 14, 0, 0, 0, 2, 4, 6, 8, 14, 0, 0, 0, 0, 2, 4, 6, 8, 14, + 0, 0, 10, 14, 0, 0, 0, 0, 0, 0, 0, 10, 14, 0, 0, 0, 0, 0, + 2, 10, 14, 0, 0, 0, 0, 0, 0, 2, 10, 14, 0, 0, 0, 0, 4, 10, + 14, 0, 0, 0, 0, 0, 0, 4, 10, 14, 0, 0, 0, 0, 2, 4, 10, 14, + 0, 0, 0, 0, 0, 2, 4, 10, 14, 0, 0, 0, 6, 10, 14, 0, 0, 0, + 0, 0, 0, 6, 10, 14, 0, 0, 0, 0, 2, 6, 10, 14, 0, 0, 0, 0, + 0, 2, 6, 10, 14, 0, 0, 0, 4, 6, 10, 14, 0, 0, 0, 0, 0, 4, + 6, 10, 14, 0, 0, 0, 2, 4, 6, 10, 14, 0, 0, 0, 0, 2, 4, 6, + 10, 14, 0, 0, 8, 10, 14, 0, 0, 0, 0, 0, 0, 8, 10, 14, 0, 0, + 0, 0, 2, 8, 10, 14, 0, 0, 0, 0, 0, 2, 8, 10, 14, 0, 0, 0, + 4, 8, 10, 14, 0, 0, 0, 0, 0, 4, 8, 10, 14, 0, 0, 0, 2, 4, + 8, 10, 14, 0, 0, 0, 0, 2, 4, 8, 10, 14, 0, 0, 6, 8, 10, 14, + 0, 0, 0, 0, 0, 6, 8, 10, 14, 0, 0, 0, 2, 6, 8, 10, 14, 0, + 0, 0, 0, 2, 6, 8, 10, 14, 0, 0, 4, 6, 8, 10, 14, 0, 0, 0, + 0, 4, 6, 8, 10, 14, 0, 0, 2, 4, 6, 8, 10, 14, 0, 0, 0, 2, + 4, 6, 8, 10, 14, 0, 12, 14, 0, 0, 0, 0, 0, 0, 0, 12, 14, 0, + 0, 0, 0, 0, 2, 12, 14, 0, 0, 0, 0, 0, 0, 2, 12, 14, 0, 0, + 0, 0, 4, 12, 14, 0, 0, 0, 0, 0, 0, 4, 12, 14, 0, 0, 0, 0, + 2, 4, 12, 14, 0, 0, 0, 0, 0, 2, 4, 12, 14, 0, 0, 0, 6, 12, + 14, 0, 0, 0, 0, 0, 0, 6, 12, 14, 0, 0, 0, 0, 2, 6, 12, 14, + 0, 0, 0, 0, 0, 2, 6, 12, 14, 0, 0, 0, 4, 6, 12, 14, 0, 0, + 0, 0, 0, 4, 6, 12, 14, 0, 0, 0, 2, 4, 6, 12, 14, 0, 0, 0, + 0, 2, 4, 6, 12, 14, 0, 0, 8, 12, 14, 0, 0, 0, 0, 0, 0, 8, + 12, 14, 0, 0, 0, 0, 2, 8, 12, 14, 0, 0, 0, 0, 0, 2, 8, 12, + 14, 0, 0, 0, 4, 8, 12, 14, 0, 0, 0, 0, 0, 4, 8, 12, 14, 0, + 0, 0, 2, 4, 8, 12, 14, 0, 0, 0, 0, 2, 4, 8, 12, 14, 0, 0, + 6, 8, 12, 14, 0, 0, 0, 0, 0, 6, 8, 12, 14, 0, 0, 0, 2, 6, + 8, 12, 14, 0, 0, 0, 0, 2, 6, 8, 12, 14, 0, 0, 4, 6, 8, 12, + 14, 0, 0, 0, 0, 4, 6, 8, 12, 14, 0, 0, 2, 4, 6, 8, 12, 14, + 0, 0, 0, 2, 4, 6, 8, 12, 14, 0, 10, 12, 14, 0, 0, 0, 0, 0, + 0, 10, 12, 14, 0, 0, 0, 0, 2, 10, 12, 14, 0, 0, 0, 0, 0, 2, + 10, 12, 14, 0, 0, 0, 4, 10, 12, 14, 0, 0, 0, 0, 0, 4, 10, 12, + 14, 0, 0, 0, 2, 4, 10, 12, 14, 0, 0, 0, 0, 2, 4, 10, 12, 14, + 0, 0, 6, 10, 12, 14, 0, 0, 0, 0, 0, 6, 10, 12, 14, 0, 0, 0, + 2, 6, 10, 12, 14, 0, 0, 0, 0, 2, 6, 10, 12, 14, 0, 0, 4, 6, + 10, 12, 14, 0, 0, 0, 0, 4, 6, 10, 12, 14, 0, 0, 2, 4, 6, 10, + 12, 14, 0, 0, 0, 2, 4, 6, 10, 12, 14, 0, 8, 10, 12, 14, 0, 0, + 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 2, 8, 10, 12, 14, 0, 0, 0, + 0, 2, 8, 10, 12, 14, 0, 0, 4, 8, 10, 12, 14, 0, 0, 0, 0, 4, + 8, 10, 12, 14, 0, 0, 2, 4, 8, 10, 12, 14, 0, 0, 0, 2, 4, 8, + 10, 12, 14, 0, 6, 8, 10, 12, 14, 0, 0, 0, 0, 6, 8, 10, 12, 14, + 0, 0, 2, 6, 8, 10, 12, 14, 0, 0, 0, 2, 6, 8, 10, 12, 14, 0, + 4, 6, 8, 10, 12, 14, 0, 0, 0, 4, 6, 8, 10, 12, 14, 0, 2, 4, + 6, 8, 10, 12, 14, 0, 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 Idx32x4FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 
lanes, so we can afford to load the index vector directly.
+  alignas(16) constexpr uint8_t packed_array[16 * 16] = {
+      0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3,  //
+      12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3,  //
+      8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3,  //
+      4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3,  //
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+
+  const Simd<T, N> d;
+  const Repartition<uint8_t, decltype(d)> d8;
+  return BitCast(d, Load(d8, packed_array + 16 * mask_bits));
+}
+
+#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64
+
+template <typename T, size_t N>
+HWY_INLINE Vec128<T, N> Idx64x2FromBits(const uint64_t mask_bits) {
+  HWY_DASSERT(mask_bits < 4);
+
+  // There are only 2 lanes, so we can afford to load the index vector directly.
+  alignas(16) constexpr uint8_t packed_array[4 * 16] = {
+      0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,  //
+      0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,  //
+      8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,  //
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+
+  const Simd<T, N> d;
+  const Repartition<uint8_t, decltype(d)> d8;
+  return BitCast(d, Load(d8, packed_array + 16 * mask_bits));
+}
+
+#endif
+
+// Helper functions called by both Compress and CompressStore - avoid a
+// redundant BitsFromMask in the latter.
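// For reference, a scalar model of the Compress semantics implemented below
// (an illustrative sketch; CompressScalar is not part of this API):
static inline size_t CompressScalar(const uint32_t* in, size_t n,
                                    uint64_t mask_bits, uint32_t* out) {
  size_t count = 0;
  for (size_t i = 0; i < n; ++i) {
    // Keep lane i only if its mask bit is set; selected lanes stay in order.
    if ((mask_bits >> i) & 1) out[count++] = in[i];
  }
  return count;  // == PopCount(mask_bits), which CompressStore returns.
}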
+ +template +HWY_API Vec128 Compress(hwy::SizeTag<2> /*tag*/, Vec128 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx16x8FromBits(mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_API Vec128 Compress(hwy::SizeTag<4> /*tag*/, Vec128 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx32x4FromBits(mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +#if HWY_CAP_INTEGER64 || HWY_CAP_FLOAT64 + +template +HWY_API Vec128 Compress(hwy::SizeTag<8> /*tag*/, + Vec128 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx64x2FromBits(mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +#endif + +} // namespace detail + +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + return detail::Compress(hwy::SizeTag(), v, + detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, + Simd d, T* HWY_RESTRICT aligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + Store(detail::Compress(hwy::SizeTag(), v, mask_bits), d, aligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +// 128 bits +HWY_API void StoreInterleaved3(const Vec128 a, const Vec128 b, + const Vec128 c, Full128 d, + uint8_t* HWY_RESTRICT unaligned) { + const auto k5 = Set(d, 5); + const auto k6 = Set(d, 6); + + // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0]. + // 0x80 so lanes to be filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_r0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_g0[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + const auto shuf_r0 = Load(d, tbl_r0); + const auto shuf_g0 = Load(d, tbl_g0); // cannot reuse r0 due to 5 in MSB + const auto shuf_b0 = CombineShiftRightBytes<15>(shuf_g0, shuf_g0); + const auto r0 = TableLookupBytes(a, shuf_r0); // 5..4..3..2..1..0 + const auto g0 = TableLookupBytes(b, shuf_g0); // ..4..3..2..1..0. + const auto b0 = TableLookupBytes(c, shuf_b0); // .4..3..2..1..0.. + const auto int0 = r0 | g0 | b0; + StoreU(int0, d, unaligned + 0 * 16); + + // Second vector: g10,r10, bgr[9:6], b5,g5 + const auto shuf_r1 = shuf_b0 + k6; // .A..9..8..7..6.. + const auto shuf_g1 = shuf_r0 + k5; // A..9..8..7..6..5 + const auto shuf_b1 = shuf_g0 + k5; // ..9..8..7..6..5. + const auto r1 = TableLookupBytes(a, shuf_r1); + const auto g1 = TableLookupBytes(b, shuf_g1); + const auto b1 = TableLookupBytes(c, shuf_b1); + const auto int1 = r1 | g1 | b1; + StoreU(int1, d, unaligned + 1 * 16); + + // Third vector: bgr[15:11], b10 + const auto shuf_r2 = shuf_b1 + k6; // ..F..E..D..C..B. + const auto shuf_g2 = shuf_r1 + k5; // .F..E..D..C..B.. 
+ const auto shuf_b2 = shuf_g1 + k5; // F..E..D..C..B..A + const auto r2 = TableLookupBytes(a, shuf_r2); + const auto g2 = TableLookupBytes(b, shuf_g2); + const auto b2 = TableLookupBytes(c, shuf_b2); + const auto int2 = r2 | g2 | b2; + StoreU(int2, d, unaligned + 2 * 16); +} + +// 64 bits +HWY_API void StoreInterleaved3(const Vec128 a, + const Vec128 b, + const Vec128 c, Simd d, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and first result. + const Full128 d_full; + const auto k5 = Set(d_full, 5); + const auto k6 = Set(d_full, 6); + + const Vec128 full_a{a.raw}; + const Vec128 full_b{b.raw}; + const Vec128 full_c{c.raw}; + + // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0]. + // 0x80 so lanes to be filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_r0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_g0[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + const auto shuf_r0 = Load(d_full, tbl_r0); + const auto shuf_g0 = Load(d_full, tbl_g0); // cannot reuse r0 due to 5 in MSB + const auto shuf_b0 = CombineShiftRightBytes<15>(shuf_g0, shuf_g0); + const auto r0 = TableLookupBytes(full_a, shuf_r0); // 5..4..3..2..1..0 + const auto g0 = TableLookupBytes(full_b, shuf_g0); // ..4..3..2..1..0. + const auto b0 = TableLookupBytes(full_c, shuf_b0); // .4..3..2..1..0.. + const auto int0 = r0 | g0 | b0; + StoreU(int0, d_full, unaligned + 0 * 16); + + // Second (HALF) vector: bgr[7:6], b5,g5 + const auto shuf_r1 = shuf_b0 + k6; // ..7..6.. + const auto shuf_g1 = shuf_r0 + k5; // .7..6..5 + const auto shuf_b1 = shuf_g0 + k5; // 7..6..5. + const auto r1 = TableLookupBytes(full_a, shuf_r1); + const auto g1 = TableLookupBytes(full_b, shuf_g1); + const auto b1 = TableLookupBytes(full_c, shuf_b1); + const decltype(Zero(d)) int1{(r1 | g1 | b1).raw}; + StoreU(int1, d, unaligned + 1 * 16); +} + +// <= 32 bits +template +HWY_API void StoreInterleaved3(const Vec128 a, + const Vec128 b, + const Vec128 c, + Simd /*tag*/, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 d_full; + + const Vec128 full_a{a.raw}; + const Vec128 full_b{b.raw}; + const Vec128 full_c{c.raw}; + + // Shuffle (a,b,c) vector bytes to bgr[3:0]. + // 0x80 so lanes to be filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_r0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, // + 0x80, 0x80, 0x80, 0x80}; + const auto shuf_r0 = Load(d_full, tbl_r0); + const auto shuf_g0 = CombineShiftRightBytes<15>(shuf_r0, shuf_r0); + const auto shuf_b0 = CombineShiftRightBytes<14>(shuf_r0, shuf_r0); + const auto r0 = TableLookupBytes(full_a, shuf_r0); // ......3..2..1..0 + const auto g0 = TableLookupBytes(full_b, shuf_g0); // .....3..2..1..0. + const auto b0 = TableLookupBytes(full_c, shuf_b0); // ....3..2..1..0.. + const auto int0 = r0 | g0 | b0; + alignas(16) uint8_t buf[16]; + StoreU(int0, d_full, buf); + CopyBytes(buf, unaligned); +} + +// ------------------------------ StoreInterleaved4 + +// 128 bits +HWY_API void StoreInterleaved4(const Vec128 v0, + const Vec128 v1, + const Vec128 v2, + const Vec128 v3, Full128 d, + uint8_t* HWY_RESTRICT unaligned) { + // let a,b,c,d denote v0..3. + const auto ba0 = ZipLower(v0, v1); // b7 a7 .. b0 a0 + const auto dc0 = ZipLower(v2, v3); // d7 c7 .. 
d0 c0 + const auto ba8 = ZipUpper(v0, v1); + const auto dc8 = ZipUpper(v2, v3); + const auto dcba_0 = ZipLower(ba0, dc0); // d..a3 d..a0 + const auto dcba_4 = ZipUpper(ba0, dc0); // d..a7 d..a4 + const auto dcba_8 = ZipLower(ba8, dc8); // d..aB d..a8 + const auto dcba_C = ZipUpper(ba8, dc8); // d..aF d..aC + StoreU(BitCast(d, dcba_0), d, unaligned + 0 * 16); + StoreU(BitCast(d, dcba_4), d, unaligned + 1 * 16); + StoreU(BitCast(d, dcba_8), d, unaligned + 2 * 16); + StoreU(BitCast(d, dcba_C), d, unaligned + 3 * 16); +} + +// 64 bits +HWY_API void StoreInterleaved4(const Vec128 in0, + const Vec128 in1, + const Vec128 in2, + const Vec128 in3, + Simd /*tag*/, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Vec128 v0{in0.raw}; + const Vec128 v1{in1.raw}; + const Vec128 v2{in2.raw}; + const Vec128 v3{in3.raw}; + // let a,b,c,d denote v0..3. + const auto ba0 = ZipLower(v0, v1); // b7 a7 .. b0 a0 + const auto dc0 = ZipLower(v2, v3); // d7 c7 .. d0 c0 + const auto dcba_0 = ZipLower(ba0, dc0); // d..a3 d..a0 + const auto dcba_4 = ZipUpper(ba0, dc0); // d..a7 d..a4 + const Full128 d_full; + StoreU(BitCast(d_full, dcba_0), d_full, unaligned + 0 * 16); + StoreU(BitCast(d_full, dcba_4), d_full, unaligned + 1 * 16); +} + +// <= 32 bits +template +HWY_API void StoreInterleaved4(const Vec128 in0, + const Vec128 in1, + const Vec128 in2, + const Vec128 in3, + Simd /*tag*/, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Vec128 v0{in0.raw}; + const Vec128 v1{in1.raw}; + const Vec128 v2{in2.raw}; + const Vec128 v3{in3.raw}; + // let a,b,c,d denote v0..3. + const auto ba0 = ZipLower(v0, v1); // b3 a3 .. b0 a0 + const auto dc0 = ZipLower(v2, v3); // d3 c3 .. d0 c0 + const auto dcba_0 = ZipLower(ba0, dc0); // d..a3 d..a0 + alignas(16) uint8_t buf[16]; + const Full128 d_full; + StoreU(BitCast(d_full, dcba_0), d_full, buf); + CopyBytes<4 * N>(buf, unaligned); +} + +// ------------------------------ Reductions + +namespace detail { + +// For u32/i32/f32. +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = v3210 + v1032; + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// For u64/i64/f64. +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the sum in each lane. 
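// (In the 4-lane case above, the first shuffle+add leaves lane i holding
// x[i] + x[i^2]; the rotate+add then folds in the remaining pair, so all
// four lanes end up with x0 + x1 + x2 + x3. The 2-lane case needs only the
// single Shuffle01 swap.)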
+template <typename T, size_t N>
+HWY_API Vec128<T, N> SumOfLanes(const Vec128<T, N> v) {
+  return detail::SumOfLanes(hwy::SizeTag<sizeof(T)>(), v);
+}
+template <typename T, size_t N>
+HWY_API Vec128<T, N> MinOfLanes(const Vec128<T, N> v) {
+  return detail::MinOfLanes(hwy::SizeTag<sizeof(T)>(), v);
+}
+template <typename T, size_t N>
+HWY_API Vec128<T, N> MaxOfLanes(const Vec128<T, N> v) {
+  return detail::MaxOfLanes(hwy::SizeTag<sizeof(T)>(), v);
+}
+
+// ================================================== Operator wrapper
+
+template <class V>
+HWY_API V Add(V a, V b) {
+  return a + b;
+}
+template <class V>
+HWY_API V Sub(V a, V b) {
+  return a - b;
+}
+
+template <class V>
+HWY_API V Mul(V a, V b) {
+  return a * b;
+}
+template <class V>
+HWY_API V Div(V a, V b) {
+  return a / b;
+}
+
+template <class V>
+V Shl(V a, V b) {
+  return a << b;
+}
+template <class V>
+V Shr(V a, V b) {
+  return a >> b;
+}
+
+template <class V>
+HWY_API auto Eq(V a, V b) -> decltype(a == b) {
+  return a == b;
+}
+template <class V>
+HWY_API auto Lt(V a, V b) -> decltype(a == b) {
+  return a < b;
+}
+
+template <class V>
+HWY_API auto Gt(V a, V b) -> decltype(a == b) {
+  return a > b;
+}
+template <class V>
+HWY_API auto Ge(V a, V b) -> decltype(a == b) {
+  return a >= b;
+}
+
+template <class V>
+HWY_API auto Le(V a, V b) -> decltype(a == b) {
+  return a <= b;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
diff --git a/third_party/highway/hwy/ops/x86_128-inl.h b/third_party/highway/hwy/ops/x86_128-inl.h
new file mode 100644
index 000000000000..2b7ba2f2aa96
--- /dev/null
+++ b/third_party/highway/hwy/ops/x86_128-inl.h
@@ -0,0 +1,3694 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// 128-bit vectors and SSE4 instructions, plus some AVX2 and AVX512-VL
+// operations when compiling for those targets.
+// External include guard in highway.h - see comment there.
+
+#include <emmintrin.h>
+#include <smmintrin.h>  // SSE4
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/base.h"
+#include "hwy/ops/shared-inl.h"
+
+// Clang 3.9 generates VINSERTF128 instead of the desired VBROADCASTF128,
+// which would free up port5. However, inline assembly isn't supported on
+// MSVC, results in incorrect output on GCC 8.3, and raises "invalid output size
+// for constraint" errors on Clang (https://gcc.godbolt.org/z/-Jt_-F), hence we
+// disable it.
+#ifndef HWY_LOADDUP_ASM
+#define HWY_LOADDUP_ASM 0
+#endif
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+template <typename T>
+struct Raw128 {
+  using type = __m128i;
+};
+template <>
+struct Raw128<float> {
+  using type = __m128;
+};
+template <>
+struct Raw128<double> {
+  using type = __m128d;
+};
+
+template <typename T>
+using Full128 = Simd<T, 16 / sizeof(T)>;
+
+template <typename T, size_t N = 16 / sizeof(T)>
+class Vec128 {
+  using Raw = typename Raw128<T>::type;
+
+ public:
+  // Compound assignment. Only usable if there is a corresponding non-member
+  // binary operator overload. For example, only f32 and f64 support division.
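  // e.g. for Vec128<float> v, w:  v /= w;  // rewritten as v = v / w.
  // Integer vectors have no operator/ and therefore no /=.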
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +// Integer: FF..FF or 0. Float: MSB, all other bits undefined - see README. +template +class Mask128 { + using Raw = typename Raw128::type; + + public: + Raw raw; +}; + +// ------------------------------ BitCast + +namespace detail { + +HWY_API __m128i BitCastToInteger(__m128i v) { return v; } +HWY_API __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } +HWY_API __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } + +template +HWY_API Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128 operator()(__m128i v) { return _mm_castsi128_ps(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128d operator()(__m128i v) { return _mm_castsi128_pd(v); } +}; + +template +HWY_API Vec128 BitCastFromByte(Simd /* tag */, + Vec128 v) { + return Vec128{BitCastFromInteger128()(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 BitCast(Simd d, + Vec128 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector/part. +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_si128()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_ps()}; +} +template +HWY_API Vec128 Zero(Simd /* tag */) { + return Vec128{_mm_setzero_pd()}; +} + +// Returns a vector/part with all lanes set to "t". +template +HWY_API Vec128 Set(Simd /* tag */, const uint8_t t) { + return Vec128{_mm_set1_epi8(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const uint16_t t) { + return Vec128{_mm_set1_epi16(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const uint32_t t) { + return Vec128{_mm_set1_epi32(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const uint64_t t) { + return Vec128{_mm_set1_epi64x(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int8_t t) { + return Vec128{_mm_set1_epi8(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int16_t t) { + return Vec128{_mm_set1_epi16(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int32_t t) { + return Vec128{_mm_set1_epi32(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const int64_t t) { + return Vec128{_mm_set1_epi64x(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const float t) { + return Vec128{_mm_set1_ps(t)}; +} +template +HWY_API Vec128 Set(Simd /* tag */, const double t) { + return Vec128{_mm_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. 
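// (Cheaper than Zero() when every lane will be written before being read;
// reading lanes of an Undefined() vector yields unspecified values.)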
+template +HWY_API Vec128 Undefined(Simd /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec128{_mm_undefined_si128()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_ps()}; +} +template +HWY_API Vec128 Undefined(Simd /* tag */) { + return Vec128{_mm_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +// Gets the single value stored in a vector/part. +template +HWY_API uint8_t GetLane(const Vec128 v) { + return _mm_cvtsi128_si32(v.raw) & 0xFF; +} +template +HWY_API int8_t GetLane(const Vec128 v) { + return _mm_cvtsi128_si32(v.raw) & 0xFF; +} +template +HWY_API uint16_t GetLane(const Vec128 v) { + return _mm_cvtsi128_si32(v.raw) & 0xFFFF; +} +template +HWY_API int16_t GetLane(const Vec128 v) { + return _mm_cvtsi128_si32(v.raw) & 0xFFFF; +} +template +HWY_API uint32_t GetLane(const Vec128 v) { + return _mm_cvtsi128_si32(v.raw); +} +template +HWY_API int32_t GetLane(const Vec128 v) { + return _mm_cvtsi128_si32(v.raw); +} +template +HWY_API float GetLane(const Vec128 v) { + return _mm_cvtss_f32(v.raw); +} +template +HWY_API uint64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) uint64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return _mm_cvtsi128_si64(v.raw); +#endif +} +template +HWY_API int64_t GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + alignas(16) int64_t lanes[2]; + Store(v, Simd(), lanes); + return lanes[0]; +#else + return _mm_cvtsi128_si64(v.raw); +#endif +} +template +HWY_API double GetLane(const Vec128 v) { + return _mm_cvtsd_f64(v.raw); +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_si128(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
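// i.e. AndNot(a, b) clears in b every bit that is set in a, matching the
// single PANDN / ANDNPS / ANDNPD instruction used below.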
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{_mm_andnot_si128(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_ps(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return Vec128{_mm_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Or(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_si128(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Xor(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not + +template +HWY_API Vec128 Not(const Vec128 v) { + using TU = MakeUnsigned; +#if HWY_TARGET == HWY_AVX3 + const __m128i vu = BitCast(Simd(), v).raw; + return BitCast(Simd(), + Vec128{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(Simd(), Vec128{_mm_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const Simd d; + const auto msb = SignBit(d); + +#if HWY_TARGET == HWY_AVX3 + const Rebind, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m128i out = _mm_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { +#if HWY_TARGET == HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(Simd()), sign)); +#endif +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +template +HWY_API Vec128 VecFromMask(const Simd /* tag */, + const Mask128 v) { + return Vec128{v.raw}; +} + +// mask ? 
yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(const Mask128 mask, + const Vec128 yes, + const Vec128 no) { + return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(Simd(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(Simd(), mask), no); +} + +template +HWY_API Vec128 ZeroIfNegative(Vec128 v) { + const Simd d; + return IfThenElse(MaskFromVec(v), Zero(d), v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API Mask128 RebindMask(Simd /*tag*/, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(Simd(), VecFromMask(d, m))); +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template 
+HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_epi8(b.raw, a.raw)}; +} +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_epi16(b.raw, a.raw)}; +} +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_epi32(b.raw, a.raw)}; +} +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmplt_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmplt_pd(a.raw, b.raw)}; +} + +// Signed/float > +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_SSE4 // SSE4.1 + // If the upper half is less than or greater, this is the answer. + const __m128i m_gt = _mm_cmpgt_epi32(a.raw, b.raw); + + // Otherwise, the lower half decides. + const __m128i m_eq = _mm_cmpeq_epi32(a.raw, b.raw); + const __m128i lo_in_hi = _mm_shuffle_epi32(m_gt, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i lo_gt = _mm_and_si128(m_eq, lo_in_hi); + + const __m128i gt = _mm_or_si128(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. 
+ return Mask128{_mm_shuffle_epi32(gt, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float <= >= +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmple_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmple_pd(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{_mm_cmpge_pd(a.raw, b.raw)}; +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ Saturating addition + +// Returns a + b clamped to the destination 
range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (reaches breakpoint) + const auto zero = Zero(Simd()); + return Vec128{_mm_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec128{_mm_abs_epi8(v.raw)}; +#endif +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi16(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi32(v.raw)}; +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(Simd(), mask); +} +template +HWY_API Vec128 Abs(const Vec128 v) { + const Vec128 mask{_mm_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(Simd(), mask); +} + +// ------------------------------ Integer multiplication + +// Unsigned +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi32(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi32(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
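// e.g. for u32x4 inputs a and b: out64[0] = uint64(a[0]) * b[0] and
// out64[1] = uint64(a[2]) * b[2]; PMULDQ/PMULUDQ read only the even lanes.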
+template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const Simd d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const Simd d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const Simd di; + const Simd du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return VecFromMask(v < Zero(Simd())); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<15>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<31>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_srai_epi64(v.raw, 63)}; +#elif HWY_TARGET == HWY_AVX2 + return VecFromMask(v < Zero(Simd())); +#else + // Efficient Gt() requires SSE4.2 but we only have SSE4.1. BLENDVPD requires + // two constants and domain crossing. 32-bit compare only requires Zero() + // plus a shuffle to replicate the upper 32 bits. 
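  // Scalar model: BroadcastSignBit(x) == (x < 0 ? -1LL : 0). Comparing the
  // upper 32-bit halves against zero and duplicating them via
  // _MM_SHUFFLE(3, 3, 1, 1) achieves this without a 64-bit compare.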
+ const Simd d32; + const auto sign = BitCast(d32, v) < Zero(d32); + return Vec128{ + _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +#endif +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_srai_epi64(v.raw, kBits)}; +#else + const Simd di; + const Simd du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const Simd d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, (0xFF << bits) & 0xFF); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const Simd d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{_mm_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const Simd di; + const Simd du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const Simd di; + const Simd du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Negate + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(Simd())); +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Zero(Simd()) - v; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_pd(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_sd(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_pd(a.raw, b.raw)}; +} +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_sd(a.raw, b.raw)}; +} + +// Approximate reciprocal +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ss(v.raw)}; +} + +// Absolute value of difference. 
+template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 MulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 NegMulAdd(const Vec128 mul, + const Vec128 x, + const Vec128 add) { +#if HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 MulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 NegMulSub(const Vec128 mul, + const Vec128 x, + const Vec128 sub) { +#if HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ps(v.raw)}; +} +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_ss(v.raw)}; +} +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_pd(v.raw)}; +} +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +} + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(const Vec128 v) { + return Vec128{_mm_rsqrt_ss(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + 
_mm_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epu32(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_min_epu64(a.raw, b.raw)}; +#else + const Simd du; + const Simd di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +#endif +} + +// Signed +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epu32(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_max_epu64(a.raw, b.raw)}; +#else + const Simd du; + const Simd di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_max_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(const Vec128 a, + const Vec128 b) { 
+ return Vec128{_mm_max_pd(a.raw, b.raw)}; +} + + +// ================================================== MEMORY + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". This is a false alarm because msan does not +// raise errors. We work around this by using CopyBytes instead of intrinsics, +// but only for the analyzer to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. +#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load + +template +HWY_API Vec128 Load(Full128 /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{_mm_load_si128(reinterpret_cast(aligned))}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ps(aligned)}; +} +HWY_API Vec128 Load(Full128 /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec128{_mm_load_pd(aligned)}; +} + +template +HWY_API Vec128 LoadU(Full128 /* tag */, const T* HWY_RESTRICT p) { + return Vec128{_mm_loadu_si128(reinterpret_cast(p))}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const float* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ps(p)}; +} +HWY_API Vec128 LoadU(Full128 /* tag */, + const double* HWY_RESTRICT p) { + return Vec128{_mm_loadu_pd(p)}; +} + +template +HWY_API Vec128 Load(Simd /* tag */, + const T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); + return Vec128{v}; +#else + return Vec128{ + _mm_loadl_epi64(reinterpret_cast(p))}; +#endif +} + +HWY_API Vec128 Load(Simd /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); + return Vec128{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec128{_mm_loadl_pi(hi, reinterpret_cast(p))}; +#endif +} + +HWY_API Vec128 Load(Simd /* tag */, + const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); + return Vec128{v}; +#else + return Vec128{_mm_load_sd(p)}; +#endif +} + +HWY_API Vec128 Load(Simd /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); + return Vec128{v}; +#else + return Vec128{_mm_load_ss(p)}; +#endif +} + +// Any <= 32 bit except +template +HWY_API Vec128 Load(Simd /* tag */, const T* HWY_RESTRICT p) { + constexpr size_t kSize = sizeof(T) * N; +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes(p, &v); + return Vec128{v}; +#else + // TODO(janwas): load_ss? + int32_t bits; + CopyBytes(p, &bits); + return Vec128{_mm_cvtsi32_si128(bits)}; +#endif +} + +// For < 128 bit, LoadU == Load. +template +HWY_API Vec128 LoadU(Simd d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. 
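+// (On the 256/512-bit targets, LoadDup128 broadcasts one 128-bit block into
+// every block; with a single block it reduces to LoadU, as below.)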
+template +HWY_API Vec128 LoadDup128(Simd d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(Vec128 v, Full128 /* tag */, T* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec128 v, Full128 /* tag */, + double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); +} + +template +HWY_API void Store(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); +#else + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec128 v, Simd /* tag */, + double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); +#else + _mm_storel_pd(p, v.raw); +#endif +} + +// Any <= 32 bit except +template +HWY_API void Store(Vec128 v, Simd /* tag */, T* HWY_RESTRICT p) { + CopyBytes(&v, p); +} +HWY_API void Store(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); +#else + _mm_store_ss(p, v.raw); +#endif +} + +// For < 128 bit, StoreU == Store. +template +HWY_API void StoreU(const Vec128 v, Simd d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ Non-temporal stores + +// On clang6, we see incorrect code generated for _mm_stream_pi, so +// round even partial vectors up to 16 bytes. +template +HWY_API void Stream(Vec128 v, Simd /* tag */, + T* HWY_RESTRICT aligned) { + _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + float* HWY_RESTRICT aligned) { + _mm_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(const Vec128 v, Simd /* tag */, + double* HWY_RESTRICT aligned) { + _mm_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Unfortunately the GCC/Clang intrinsics do not accept int64_t*. 
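+// Usage sketch (illustrative): with d = Simd<int32_t, 4>(), the call
+//   ScatterIndex(v, d, base, Load(d, idx));
+// stores lane i of v to base[idx[i]]; Index must be as wide as the lane type.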
+using GatherIndex64 = long long int; // NOLINT(google-runtime-int) +static_assert(sizeof(GatherIndex64) == 8, "Must be 64-bit type"); + +#if HWY_TARGET == HWY_AVX3 +namespace detail { + +template +HWY_API void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_epi32(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_epi32(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_API void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_epi64(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_API void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec128 v, + Simd /* tag */, T* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_epi64(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, index.raw, v.raw, 8); + } +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +template +HWY_INLINE void ScatterOffset(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 4) { + _mm_i32scatter_ps(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(Vec128 v, Simd /* tag */, + float* HWY_RESTRICT base, + const Vec128 index) { + if (N == 4) { + _mm_i32scatter_ps(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, index.raw, v.raw, 4); + } +} + +template +HWY_INLINE void ScatterOffset(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 offset) { + if (N == 2) { + _mm_i64scatter_pd(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, offset.raw, v.raw, 1); + } +} +template +HWY_INLINE void ScatterIndex(Vec128 v, Simd /* tag */, + double* HWY_RESTRICT base, + const Vec128 index) { + if (N == 2) { + _mm_i64scatter_pd(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, index.raw, v.raw, 8); + } +} +#else // HWY_TARGET == HWY_AVX3 + +template +HWY_API void ScatterOffset(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Simd(), 
offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec128 v, Simd d, T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Simd(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather (Load/Store) + +#if HWY_TARGET == HWY_SSE4 + +template +HWY_API Vec128 GatherOffset(const Simd d, + const T* HWY_RESTRICT base, + const Vec128 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Simd(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template +HWY_API Vec128 GatherIndex(const Simd d, const T* HWY_RESTRICT base, + const Vec128 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Simd(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +#else + +namespace detail { + +template +HWY_API Vec128 GatherOffset(hwy::SizeTag<4> /* tag */, Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(hwy::SizeTag<4> /* tag */, Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), index.raw, 4)}; +} + +template +HWY_API Vec128 GatherOffset(hwy::SizeTag<8> /* tag */, Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(hwy::SizeTag<8> /* tag */, Simd /* d */, + const T* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), index.raw, 8)}; +} + +} // namespace detail + +template +HWY_API Vec128 GatherOffset(Simd d, const T* HWY_RESTRICT base, + const Vec128 offset) { + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec128 GatherIndex(Simd d, const T* HWY_RESTRICT base, + const Vec128 index) { + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i32gather_ps(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const float* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i32gather_ps(base, index.raw, 4)}; +} + +template +HWY_API Vec128 GatherOffset(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 offset) { + return Vec128{_mm_i64gather_pd(base, offset.raw, 1)}; +} +template +HWY_API Vec128 GatherIndex(Simd /* tag */, + const double* HWY_RESTRICT base, + const Vec128 index) { + return Vec128{_mm_i64gather_pd(base, index.raw, 8)}; +} + +#endif // HWY_TARGET != HWY_SSE4 + +// ================================================== SWIZZLE + +// 
------------------------------ Extract half + +// Returns upper/lower half of a vector. +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// These copy hi into lo (smaller instruction encoding than shifts). +template +HWY_API Vec128 UpperHalf(Vec128 v) { + return Vec128{_mm_unpackhi_epi64(v.raw, v.raw)}; +} +template <> +HWY_INLINE Vec128 UpperHalf(Vec128 v) { + return Vec128{_mm_movehl_ps(v.raw, v.raw)}; +} +template <> +HWY_INLINE Vec128 UpperHalf(Vec128 v) { + return Vec128{_mm_unpackhi_pd(v.raw, v.raw)}; +} + +// ------------------------------ Shift vector by constant #bytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API Vec128 ShiftLeftBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec128{_mm_slli_si128(v.raw, kBytes)}; +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + const Simd d8; + const Simd d; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec128 ShiftRightBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec128{_mm_srli_si128(v.raw, kBytes)}; +} + +template +HWY_API Vec128 ShiftRightLanes(const Vec128 v) { + const Simd d8; + const Simd d; + return BitCast(d, ShiftRightBytes(BitCast(d8, v))); +} + +// ------------------------------ Extract from 2x 128-bit at constant offset + +// Extracts 128 bits from by skipping the least-significant kBytes. +template +HWY_API Vec128 CombineShiftRightBytes(const Vec128 hi, + const Vec128 lo) { + const Full128 d8; + const Vec128 extracted_bytes{ + _mm_alignr_epi8(BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}; + return BitCast(Full128(), extracted_bytes); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128{_mm_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 
0xEE : 0x44)}; +} + +// Float +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 3 * kLane)}; +} + +// ------------------------------ Shuffle bytes with variable indices + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + return Vec128{_mm_shuffle_epi8(bytes.raw, from.raw)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
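+// Usage sketch (illustrative): reversing four i32 lanes:
+//   alignas(16) const int32_t idx[4] = {3, 2, 1, 0};
+//   const auto rev =
+//       TableLookupLanes(v, SetTableIndices(Full128<int32_t>(), idx));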
+template +struct Indices128 { + __m128i raw; +}; + +template +HWY_API Indices128 SetTableIndices(Full128, const int32_t* idx) { +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) + const size_t N = 16 / sizeof(T); + for (size_t i = 0; i < N; ++i) { + HWY_DASSERT(0 <= idx[i] && idx[i] < static_cast(N)); + } +#endif + + const Full128 d8; + alignas(16) uint8_t control[16]; + for (size_t idx_byte = 0; idx_byte < 16; ++idx_byte) { + const size_t idx_lane = idx_byte / sizeof(T); + const size_t mod = idx_byte % sizeof(T); + control[idx_byte] = static_cast(idx[idx_lane] * sizeof(T) + mod); + } + return Indices128{Load(d8, control).raw}; +} + +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + return TableLookupBytes(v, Vec128{idx.raw}); +} +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + return TableLookupBytes(v, Vec128{idx.raw}); +} +HWY_API Vec128 TableLookupLanes(const Vec128 v, + const Indices128 idx) { + const Full128 di; + const Full128 df; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +} + +// ------------------------------ Interleave lanes + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi64(a.raw, 
b.raw)}; +} + +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec128 InterleaveUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +// ------------------------------ Zip lanes + +// Same as interleave_*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. + +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipLower(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 ZipUpper(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_unpackhi_epi32(a.raw, b.raw)}; +} + +// ------------------------------ Blocks + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec128 ConcatLowerLower(const Vec128 hi, const Vec128 lo) { + const Full128 d64; + return BitCast(Full128(), + InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec128 ConcatUpperUpper(const Vec128 hi, const Vec128 lo) { + const Full128 d64; + return BitCast(Full128(), + InterleaveUpper(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API Vec128 ConcatLowerUpper(const Vec128 hi, const Vec128 lo) { + return CombineShiftRightBytes<8>(hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec128 ConcatUpperLower(const Vec128 hi, const Vec128 lo) { + return Vec128{_mm_blend_epi16(hi.raw, lo.raw, 0x0F)}; +} +template <> +HWY_INLINE Vec128 ConcatUpperLower(const Vec128 hi, + const Vec128 lo) { + return Vec128{_mm_blend_ps(hi.raw, lo.raw, 3)}; +} +template <> +HWY_INLINE Vec128 ConcatUpperLower(const Vec128 hi, + const Vec128 lo) { + return Vec128{_mm_blend_pd(hi.raw, lo.raw, 1)}; +} + +// ------------------------------ OddEven (IfThenElse) + +namespace detail { + +template +HWY_API Vec128 OddEven(hwy::SizeTag<1> /* tag */, const Vec128 a, + const Vec128 b) { + const Full128 d; + const Full128 d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template 
+HWY_API Vec128 OddEven(hwy::SizeTag<2> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{_mm_blend_epi16(a.raw, b.raw, 0x55)}; +} +template +HWY_API Vec128 OddEven(hwy::SizeTag<4> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{_mm_blend_epi16(a.raw, b.raw, 0x33)}; +} +template +HWY_API Vec128 OddEven(hwy::SizeTag<8> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{_mm_blend_epi16(a.raw, b.raw, 0x0F)}; +} + +} // namespace detail + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +template <> +HWY_INLINE Vec128 OddEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_blend_ps(a.raw, b.raw, 5)}; +} + +template <> +HWY_INLINE Vec128 OddEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_blend_pd(a.raw, b.raw, 1)}; +} + +// ------------------------------ Shl (ZipLower, Mul) + +// Use AVX2/3 variable shifts where available, otherwise multiply by powers of +// two from loading float exponents, which is considerably faster (according +// to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. + +#if HWY_TARGET != HWY_AVX3 +namespace detail { + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_API Vec128, N> Pow2(const Vec128 v) { + const Simd d; + const Repartition df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(zero, upper); + const auto f1 = ZipUpper(zero, upper); + // See comment below. + const Vec128 bits0{_mm_cvtps_epi32(BitCast(df, f0).raw)}; + const Vec128 bits1{_mm_cvtps_epi32(BitCast(df, f1).raw)}; + return Vec128, N>{_mm_packus_epi32(bits0.raw, bits1.raw)}; +} + +// Same, for 32-bit shifts. +template +HWY_API Vec128, N> Pow2(const Vec128 v) { + const Simd d; + const auto exp = ShiftLeft<23>(v); + const auto f = exp + Set(d, 0x3F800000); // 1.0f + // Do not use ConvertTo because we rely on the native 0x80..00 overflow + // behavior. cvt instead of cvtt should be equivalent, but avoids test + // failure under GCC 10.2.1. + return Vec128, N>{_mm_cvtps_epi32(_mm_castsi128_ps(f.raw))}; +} + +} // namespace detail +#endif // HWY_TARGET != HWY_AVX3 + +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_sllv_epi16(v.raw, bits.raw)}; +#else + return v * detail::Pow2(bits); +#endif +} + +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSE4 + return v * detail::Pow2(bits); +#else + return Vec128{_mm_sllv_epi32(v.raw, bits.raw)}; +#endif +} + +template +HWY_API Vec128 operator<<(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const __m128i out0 = _mm_sll_epi64(v.raw, bits.raw); + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const __m128i out1 = _mm_sll_epi64(v.raw, bits1); + return Vec128{_mm_blend_epi16(out0, out1, 0xF0)}; +#else + return Vec128{_mm_sllv_epi64(v.raw, bits.raw)}; +#endif +} + +// Signed left shift is the same as unsigned. 
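+// (The Pow2 trick above also powers the signed overload below. Worked example:
+// for a shift count of 3, ShiftLeft<23> places 3 in the f32 exponent field;
+// adding 0x3F800000 biases it to 130, the bit pattern of 8.0f, and
+// cvtps_epi32 turns that into the per-lane multiplier 8.)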
+template +HWY_API Vec128 operator<<(const Vec128 v, const Vec128 bits) { + const Simd di; + const Simd, N> du; + return BitCast(di, BitCast(du, v) << BitCast(du, bits)); +} + +// ------------------------------ Shr (mul, mask, BroadcastSignBit) + +// Use AVX2+ variable shifts except for the SSE4 target or 16-bit. There, we use +// widening multiplication by powers of two obtained by loading float exponents, +// followed by a constant right-shift. This is still faster than a scalar or +// bit-test approach: https://gcc.godbolt.org/z/9G7Y9v. + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_srlv_epi16(in.raw, bits.raw)}; +#else + const Simd d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), in, out); +#endif +} + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSE4 + // 32x32 -> 64 bit mul, then shift right by 32. + const Simd d32; + // Move odd lanes into position for the second mul. Shuffle more gracefully + // handles N=1 than repartitioning to u64 and shifting 32 bits right. + const Vec128 in31{_mm_shuffle_epi32(in.raw, 0x31)}; + // For bits=0, we cannot mul by 2^32, so fix the result later. + const auto mul = detail::Pow2(Set(d32, 32) - bits); + const auto out20 = ShiftRight<32>(MulEven(in, mul)); // z 2 z 0 + const Vec128 mul31{_mm_shuffle_epi32(mul.raw, 0x31)}; + // No need to shift right, already in the correct position. + const auto out31 = MulEven(in31, mul31); // 3 ? 1 ? + // OddEven is defined below, avoid the dependency. + const Vec128 out{_mm_blend_epi16(out31.raw, out20.raw, 0x33)}; + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d32), in, out); +#else + return Vec128{_mm_srlv_epi32(in.raw, bits.raw)}; +#endif +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const __m128i out0 = _mm_srl_epi64(v.raw, bits.raw); + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const __m128i out1 = _mm_srl_epi64(v.raw, bits1); + return Vec128{_mm_blend_epi16(out0, out1, 0xF0)}; +#else + return Vec128{_mm_srlv_epi64(v.raw, bits.raw)}; +#endif +} + +#if HWY_TARGET != HWY_AVX3 +namespace detail { + +// Also used in x86_256-inl.h. +template +HWY_API V SignedShr(const DI di, const V v, const V count_i) { + const RebindToUnsigned du; + const auto count = BitCast(du, count_i); // same type as value to shift + // Clear sign and restore afterwards. This is preferable to shifting the MSB + // downwards because Shr is somewhat more expensive than Shl. 
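+  // Worked example: v = -5, count = 1. sign is all-ones, so v ^ sign = 4
+  // (i.e. ~v = -v - 1); 4 >> 1 = 2, and 2 ^ sign = ~2 = -3, which is -5 >> 1.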
+ const auto sign = BroadcastSignBit(v); + const auto abs = BitCast(du, v ^ sign); // off by one, but fixed below + return BitCast(di, abs >> count) ^ sign; +} + +} // namespace detail +#endif // HWY_TARGET != HWY_AVX3 + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_srav_epi32(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} + +template +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec128{_mm_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd(), v, bits); +#endif +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepu8_epi16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepu8_epi32(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepu8_epi16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepu8_epi32(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepu16_epi32(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepu16_epi32(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi8_epi16(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi8_epi32(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi16_epi32(v.raw)}; +} +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_epi64(v.raw)}; +} + +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if defined(MEMORY_SANITIZER) && \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template +HWY_INLINE_F16 Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSE4 + const Simd di32; + const Simd du32; + const Simd df32; + // Expand to u32 so we can shift. 
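+  // binary16 recap: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits;
+  // re-biasing to binary32 therefore adds 127 - 15 to the exponent below.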
+ const auto bits16 = PromoteTo(du32, Vec128{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + return Vec128{_mm_cvtph_ps(v.raw)}; +#endif +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtps_pd(v.raw)}; +} + +template +HWY_API Vec128 PromoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packus_epi32(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi32(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i u16 = _mm_packus_epi32(v.raw, v.raw); + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const __m128i i16 = _mm_and_si128(u16, _mm_set1_epi16(0x7FFF)); + return Vec128{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packus_epi16(v.raw, v.raw)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_packs_epi16(v.raw, v.raw)}; +} + +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { +#if HWY_TARGET == HWY_SSE4 + const Simd di; + const Simd du; + const Simd du16; + const Simd df16; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + return Vec128{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtpd_ps(v.raw)}; +} + +namespace detail { + +// For well-defined float->int demotion in all x86_*-inl.h. 
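+// (Why two strategies: 2147483647 is not representable in binary32, whose
+// nearest value is 2147483648.0f, so f32->i32 must repair overflow after the
+// conversion, whereas binary64 holds the max exactly and can clamp up front.)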
+ +template +HWY_API auto ClampF64ToI32Max(Simd d, decltype(Zero(d)) v) + -> decltype(Zero(d)) { + // The max can be exactly represented in binary64, so clamping beforehand + // prevents x86 conversion from raising an exception and returning 80..00. + return Min(v, Set(d, 2147483647.0)); +} + +// For ConvertTo float->int of same size, clamping before conversion would +// change the result because the max integer value is not exactly representable. +// Instead detect the overflow result after conversion and fix it. +template , N>> +HWY_API auto FixConversionOverflow(Simd di, + decltype(Zero(DF())) original, + decltype(Zero(di).raw) converted_raw) + -> decltype(Zero(di)) { + // Combinations of original and output sign: + // --: normal <0 or -huge_val to 80..00: OK + // -+: -0 to 0 : OK + // +-: +huge_val to 80..00 : xor with FF..FF to get 7F..FF + // ++: normal >0 : OK + const auto converted = decltype(Zero(di)){converted_raw}; + const auto sign_wrong = AndNot(BitCast(di, original), converted); + return BitCast(di, Xor(converted, BroadcastSignBit(sign_wrong))); +} + +} // namespace detail + +template +HWY_INLINE Vec128 DemoteTo(Simd /* tag */, + const Vec128 v) { + const auto clamped = detail::ClampF64ToI32Max(Simd(), v); + return Vec128{_mm_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const Simd d32; + const Simd d8; + alignas(16) static constexpr uint32_t k8From32[4] = { + 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; + // Also replicate bytes into all 32 bit lanes for safety. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + return LowerHalf(LowerHalf(BitCast(d8, quad))); +} + +// ------------------------------ Convert integer <=> floating point + +template +HWY_API Vec128 ConvertTo(Simd /* tag */, + const Vec128 v) { + return Vec128{_mm_cvtepi32_ps(v.raw)}; +} + +template +HWY_API Vec128 ConvertTo(Simd dd, + const Vec128 v) { +#if HWY_TARGET == HWY_AVX3 + (void)dd; + return Vec128{_mm_cvtepi64_pd(v.raw)}; +#else + alignas(16) int64_t lanes_i[2]; + Store(v, Simd(), lanes_i); + alignas(16) double lanes_d[2]; + for (size_t i = 0; i < N; ++i) { + lanes_d[i] = static_cast(lanes_i[i]); + } + return Load(dd, lanes_d); +#endif +} + +// Truncates (rounds toward zero). +template +HWY_API Vec128 ConvertTo(const Simd di, + const Vec128 v) { + return detail::FixConversionOverflow(di, v, _mm_cvttps_epi32(v.raw)); +} + +template +HWY_API Vec128 ConvertTo(Simd di, + const Vec128 v) { +#if HWY_TARGET == HWY_AVX3 + return detail::FixConversionOverflow(di, v, _mm_cvttpd_epi64(v.raw)); +#else + alignas(16) double lanes_d[2]; + Store(v, Simd(), lanes_d); + alignas(16) int64_t lanes_i[2]; + for (size_t i = 0; i < N; ++i) { + if (lanes_d[i] >= static_cast(LimitsMax())) { + lanes_i[i] = LimitsMax(); + } else if (lanes_d[i] <= static_cast(LimitsMin())) { + lanes_i[i] = LimitsMin(); + } else { + lanes_i[i] = static_cast(lanes_d[i]); + } + } + return Load(di, lanes_i); +#endif +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const Simd di; + return detail::FixConversionOverflow(di, v, _mm_cvtps_epi32(v.raw)); +} + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. 
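+// Usage sketch (illustrative): Iota(Full128<int32_t>(), 0) yields lanes
+// {0, 1, 2, 3}, with lane 0 least-significant.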
+template +Vec128 Iota(const Simd d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ------------------------------ Mask + +namespace detail { + +constexpr HWY_INLINE uint64_t U64FromInt(int bits) { + return static_cast(static_cast(bits)); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128 mask) { + const Simd d; + const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128 mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const auto sign_bits = _mm_packs_epi16(mask.raw, _mm_setzero_si128()); + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_ps(sign_bits.raw)); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128 mask) { + const Simd d; + const Simd df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_pd(sign_bits.raw)); +} + +// Returns the lowest N of the _mm_movemask* bits. +template +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) == 16) ? bits : bits & ((1ull << N) - 1); +} + +template +HWY_API uint64_t BitsFromMask(const Mask128 mask) { + return OnlyActive(BitsFromMask(hwy::SizeTag(), mask)); +} + +} // namespace detail + +template +HWY_INLINE size_t StoreMaskBits(const Mask128 mask, uint8_t* p) { + const uint64_t bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7)/8; + CopyBytes(&bits, p); + return kNumBytes; +} + +template +HWY_API bool AllFalse(const Mask128 mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template +HWY_API bool AllTrue(const Mask128 mask) { + constexpr uint64_t kAllBits = + detail::OnlyActive((1ull << (16 / sizeof(T))) - 1); + return detail::BitsFromMask(mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(const Mask128 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +// ------------------------------ Compress + +namespace detail { + +template +HWY_INLINE Vec128 Idx16x8FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
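+  // Worked example: mask_bits = 0b101 selects lanes 0 and 2, whose table row
+  // is {0, 4, 0, ..} (doubled lane indices). ZipLower duplicates each byte
+  // into a u16 pair, and adding 0x0100 turns {b, b} into bytes {b, b + 1},
+  // yielding byte indices {0, 1, 4, 5, ..} for the PSHUFB.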
+ alignas(16) constexpr uint8_t table[256 * 8] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, + 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, + 0, 6, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 2, + 6, 0, 0, 0, 0, 0, 4, 6, 0, 0, 0, 0, 0, 0, 0, 4, 6, 0, + 0, 0, 0, 0, 2, 4, 6, 0, 0, 0, 0, 0, 0, 2, 4, 6, 0, 0, + 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, + 2, 8, 0, 0, 0, 0, 0, 0, 0, 2, 8, 0, 0, 0, 0, 0, 4, 8, + 0, 0, 0, 0, 0, 0, 0, 4, 8, 0, 0, 0, 0, 0, 2, 4, 8, 0, + 0, 0, 0, 0, 0, 2, 4, 8, 0, 0, 0, 0, 6, 8, 0, 0, 0, 0, + 0, 0, 0, 6, 8, 0, 0, 0, 0, 0, 2, 6, 8, 0, 0, 0, 0, 0, + 0, 2, 6, 8, 0, 0, 0, 0, 4, 6, 8, 0, 0, 0, 0, 0, 0, 4, + 6, 8, 0, 0, 0, 0, 2, 4, 6, 8, 0, 0, 0, 0, 0, 2, 4, 6, + 8, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, + 0, 0, 2, 10, 0, 0, 0, 0, 0, 0, 0, 2, 10, 0, 0, 0, 0, 0, + 4, 10, 0, 0, 0, 0, 0, 0, 0, 4, 10, 0, 0, 0, 0, 0, 2, 4, + 10, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0, 0, 0, 0, 6, 10, 0, 0, + 0, 0, 0, 0, 0, 6, 10, 0, 0, 0, 0, 0, 2, 6, 10, 0, 0, 0, + 0, 0, 0, 2, 6, 10, 0, 0, 0, 0, 4, 6, 10, 0, 0, 0, 0, 0, + 0, 4, 6, 10, 0, 0, 0, 0, 2, 4, 6, 10, 0, 0, 0, 0, 0, 2, + 4, 6, 10, 0, 0, 0, 8, 10, 0, 0, 0, 0, 0, 0, 0, 8, 10, 0, + 0, 0, 0, 0, 2, 8, 10, 0, 0, 0, 0, 0, 0, 2, 8, 10, 0, 0, + 0, 0, 4, 8, 10, 0, 0, 0, 0, 0, 0, 4, 8, 10, 0, 0, 0, 0, + 2, 4, 8, 10, 0, 0, 0, 0, 0, 2, 4, 8, 10, 0, 0, 0, 6, 8, + 10, 0, 0, 0, 0, 0, 0, 6, 8, 10, 0, 0, 0, 0, 2, 6, 8, 10, + 0, 0, 0, 0, 0, 2, 6, 8, 10, 0, 0, 0, 4, 6, 8, 10, 0, 0, + 0, 0, 0, 4, 6, 8, 10, 0, 0, 0, 2, 4, 6, 8, 10, 0, 0, 0, + 0, 2, 4, 6, 8, 10, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 12, + 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 0, 0, 0, 0, 0, 2, 12, 0, + 0, 0, 0, 0, 4, 12, 0, 0, 0, 0, 0, 0, 0, 4, 12, 0, 0, 0, + 0, 0, 2, 4, 12, 0, 0, 0, 0, 0, 0, 2, 4, 12, 0, 0, 0, 0, + 6, 12, 0, 0, 0, 0, 0, 0, 0, 6, 12, 0, 0, 0, 0, 0, 2, 6, + 12, 0, 0, 0, 0, 0, 0, 2, 6, 12, 0, 0, 0, 0, 4, 6, 12, 0, + 0, 0, 0, 0, 0, 4, 6, 12, 0, 0, 0, 0, 2, 4, 6, 12, 0, 0, + 0, 0, 0, 2, 4, 6, 12, 0, 0, 0, 8, 12, 0, 0, 0, 0, 0, 0, + 0, 8, 12, 0, 0, 0, 0, 0, 2, 8, 12, 0, 0, 0, 0, 0, 0, 2, + 8, 12, 0, 0, 0, 0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 4, 8, 12, + 0, 0, 0, 0, 2, 4, 8, 12, 0, 0, 0, 0, 0, 2, 4, 8, 12, 0, + 0, 0, 6, 8, 12, 0, 0, 0, 0, 0, 0, 6, 8, 12, 0, 0, 0, 0, + 2, 6, 8, 12, 0, 0, 0, 0, 0, 2, 6, 8, 12, 0, 0, 0, 4, 6, + 8, 12, 0, 0, 0, 0, 0, 4, 6, 8, 12, 0, 0, 0, 2, 4, 6, 8, + 12, 0, 0, 0, 0, 2, 4, 6, 8, 12, 0, 0, 10, 12, 0, 0, 0, 0, + 0, 0, 0, 10, 12, 0, 0, 0, 0, 0, 2, 10, 12, 0, 0, 0, 0, 0, + 0, 2, 10, 12, 0, 0, 0, 0, 4, 10, 12, 0, 0, 0, 0, 0, 0, 4, + 10, 12, 0, 0, 0, 0, 2, 4, 10, 12, 0, 0, 0, 0, 0, 2, 4, 10, + 12, 0, 0, 0, 6, 10, 12, 0, 0, 0, 0, 0, 0, 6, 10, 12, 0, 0, + 0, 0, 2, 6, 10, 12, 0, 0, 0, 0, 0, 2, 6, 10, 12, 0, 0, 0, + 4, 6, 10, 12, 0, 0, 0, 0, 0, 4, 6, 10, 12, 0, 0, 0, 2, 4, + 6, 10, 12, 0, 0, 0, 0, 2, 4, 6, 10, 12, 0, 0, 8, 10, 12, 0, + 0, 0, 0, 0, 0, 8, 10, 12, 0, 0, 0, 0, 2, 8, 10, 12, 0, 0, + 0, 0, 0, 2, 8, 10, 12, 0, 0, 0, 4, 8, 10, 12, 0, 0, 0, 0, + 0, 4, 8, 10, 12, 0, 0, 0, 2, 4, 8, 10, 12, 0, 0, 0, 0, 2, + 4, 8, 10, 12, 0, 0, 6, 8, 10, 12, 0, 0, 0, 0, 0, 6, 8, 10, + 12, 0, 0, 0, 2, 6, 8, 10, 12, 0, 0, 0, 0, 2, 6, 8, 10, 12, + 0, 0, 4, 6, 8, 10, 12, 0, 0, 0, 0, 4, 6, 8, 10, 12, 0, 0, + 2, 4, 6, 8, 10, 12, 0, 0, 0, 2, 4, 6, 8, 10, 12, 0, 14, 0, + 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 2, 14, 0, 0, + 0, 0, 0, 0, 0, 2, 14, 0, 0, 0, 0, 0, 4, 14, 0, 0, 0, 0, + 0, 0, 0, 4, 14, 0, 0, 0, 0, 0, 2, 4, 14, 0, 0, 0, 0, 0, + 0, 2, 
4, 14, 0, 0, 0, 0, 6, 14, 0, 0, 0, 0, 0, 0, 0, 6, + 14, 0, 0, 0, 0, 0, 2, 6, 14, 0, 0, 0, 0, 0, 0, 2, 6, 14, + 0, 0, 0, 0, 4, 6, 14, 0, 0, 0, 0, 0, 0, 4, 6, 14, 0, 0, + 0, 0, 2, 4, 6, 14, 0, 0, 0, 0, 0, 2, 4, 6, 14, 0, 0, 0, + 8, 14, 0, 0, 0, 0, 0, 0, 0, 8, 14, 0, 0, 0, 0, 0, 2, 8, + 14, 0, 0, 0, 0, 0, 0, 2, 8, 14, 0, 0, 0, 0, 4, 8, 14, 0, + 0, 0, 0, 0, 0, 4, 8, 14, 0, 0, 0, 0, 2, 4, 8, 14, 0, 0, + 0, 0, 0, 2, 4, 8, 14, 0, 0, 0, 6, 8, 14, 0, 0, 0, 0, 0, + 0, 6, 8, 14, 0, 0, 0, 0, 2, 6, 8, 14, 0, 0, 0, 0, 0, 2, + 6, 8, 14, 0, 0, 0, 4, 6, 8, 14, 0, 0, 0, 0, 0, 4, 6, 8, + 14, 0, 0, 0, 2, 4, 6, 8, 14, 0, 0, 0, 0, 2, 4, 6, 8, 14, + 0, 0, 10, 14, 0, 0, 0, 0, 0, 0, 0, 10, 14, 0, 0, 0, 0, 0, + 2, 10, 14, 0, 0, 0, 0, 0, 0, 2, 10, 14, 0, 0, 0, 0, 4, 10, + 14, 0, 0, 0, 0, 0, 0, 4, 10, 14, 0, 0, 0, 0, 2, 4, 10, 14, + 0, 0, 0, 0, 0, 2, 4, 10, 14, 0, 0, 0, 6, 10, 14, 0, 0, 0, + 0, 0, 0, 6, 10, 14, 0, 0, 0, 0, 2, 6, 10, 14, 0, 0, 0, 0, + 0, 2, 6, 10, 14, 0, 0, 0, 4, 6, 10, 14, 0, 0, 0, 0, 0, 4, + 6, 10, 14, 0, 0, 0, 2, 4, 6, 10, 14, 0, 0, 0, 0, 2, 4, 6, + 10, 14, 0, 0, 8, 10, 14, 0, 0, 0, 0, 0, 0, 8, 10, 14, 0, 0, + 0, 0, 2, 8, 10, 14, 0, 0, 0, 0, 0, 2, 8, 10, 14, 0, 0, 0, + 4, 8, 10, 14, 0, 0, 0, 0, 0, 4, 8, 10, 14, 0, 0, 0, 2, 4, + 8, 10, 14, 0, 0, 0, 0, 2, 4, 8, 10, 14, 0, 0, 6, 8, 10, 14, + 0, 0, 0, 0, 0, 6, 8, 10, 14, 0, 0, 0, 2, 6, 8, 10, 14, 0, + 0, 0, 0, 2, 6, 8, 10, 14, 0, 0, 4, 6, 8, 10, 14, 0, 0, 0, + 0, 4, 6, 8, 10, 14, 0, 0, 2, 4, 6, 8, 10, 14, 0, 0, 0, 2, + 4, 6, 8, 10, 14, 0, 12, 14, 0, 0, 0, 0, 0, 0, 0, 12, 14, 0, + 0, 0, 0, 0, 2, 12, 14, 0, 0, 0, 0, 0, 0, 2, 12, 14, 0, 0, + 0, 0, 4, 12, 14, 0, 0, 0, 0, 0, 0, 4, 12, 14, 0, 0, 0, 0, + 2, 4, 12, 14, 0, 0, 0, 0, 0, 2, 4, 12, 14, 0, 0, 0, 6, 12, + 14, 0, 0, 0, 0, 0, 0, 6, 12, 14, 0, 0, 0, 0, 2, 6, 12, 14, + 0, 0, 0, 0, 0, 2, 6, 12, 14, 0, 0, 0, 4, 6, 12, 14, 0, 0, + 0, 0, 0, 4, 6, 12, 14, 0, 0, 0, 2, 4, 6, 12, 14, 0, 0, 0, + 0, 2, 4, 6, 12, 14, 0, 0, 8, 12, 14, 0, 0, 0, 0, 0, 0, 8, + 12, 14, 0, 0, 0, 0, 2, 8, 12, 14, 0, 0, 0, 0, 0, 2, 8, 12, + 14, 0, 0, 0, 4, 8, 12, 14, 0, 0, 0, 0, 0, 4, 8, 12, 14, 0, + 0, 0, 2, 4, 8, 12, 14, 0, 0, 0, 0, 2, 4, 8, 12, 14, 0, 0, + 6, 8, 12, 14, 0, 0, 0, 0, 0, 6, 8, 12, 14, 0, 0, 0, 2, 6, + 8, 12, 14, 0, 0, 0, 0, 2, 6, 8, 12, 14, 0, 0, 4, 6, 8, 12, + 14, 0, 0, 0, 0, 4, 6, 8, 12, 14, 0, 0, 2, 4, 6, 8, 12, 14, + 0, 0, 0, 2, 4, 6, 8, 12, 14, 0, 10, 12, 14, 0, 0, 0, 0, 0, + 0, 10, 12, 14, 0, 0, 0, 0, 2, 10, 12, 14, 0, 0, 0, 0, 0, 2, + 10, 12, 14, 0, 0, 0, 4, 10, 12, 14, 0, 0, 0, 0, 0, 4, 10, 12, + 14, 0, 0, 0, 2, 4, 10, 12, 14, 0, 0, 0, 0, 2, 4, 10, 12, 14, + 0, 0, 6, 10, 12, 14, 0, 0, 0, 0, 0, 6, 10, 12, 14, 0, 0, 0, + 2, 6, 10, 12, 14, 0, 0, 0, 0, 2, 6, 10, 12, 14, 0, 0, 4, 6, + 10, 12, 14, 0, 0, 0, 0, 4, 6, 10, 12, 14, 0, 0, 2, 4, 6, 10, + 12, 14, 0, 0, 0, 2, 4, 6, 10, 12, 14, 0, 8, 10, 12, 14, 0, 0, + 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 2, 8, 10, 12, 14, 0, 0, 0, + 0, 2, 8, 10, 12, 14, 0, 0, 4, 8, 10, 12, 14, 0, 0, 0, 0, 4, + 8, 10, 12, 14, 0, 0, 2, 4, 8, 10, 12, 14, 0, 0, 0, 2, 4, 8, + 10, 12, 14, 0, 6, 8, 10, 12, 14, 0, 0, 0, 0, 6, 8, 10, 12, 14, + 0, 0, 2, 6, 8, 10, 12, 14, 0, 0, 0, 2, 6, 8, 10, 12, 14, 0, + 4, 6, 8, 10, 12, 14, 0, 0, 0, 4, 6, 8, 10, 12, 14, 0, 2, 4, + 6, 8, 10, 12, 14, 0, 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 Idx32x4FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 
lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t packed_array[16 * 16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, // + 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, // + 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 Idx64x2FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t packed_array[4 * 16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); +} + +// Helper function called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template +HWY_API Vec128 Compress(hwy::SizeTag<2> /*tag*/, Vec128 v, + const uint64_t mask_bits) { + const auto idx = detail::Idx16x8FromBits(mask_bits); + using D = Simd; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_API Vec128 Compress(hwy::SizeTag<4> /*tag*/, Vec128 v, + const uint64_t mask_bits) { + using D = Simd; + using TI = MakeSigned; + const Rebind di; +#if HWY_TARGET == HWY_AVX3 + return BitCast(D(), Vec128{_mm_maskz_compress_epi32( + mask_bits, BitCast(di, v).raw)}); +#else + const auto idx = detail::Idx32x4FromBits(mask_bits); + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +#endif +} + +template +HWY_API Vec128 Compress(hwy::SizeTag<8> /*tag*/, Vec128 v, + const uint64_t mask_bits) { + using D = Simd; + using TI = MakeSigned; + const Rebind di; +#if HWY_TARGET == HWY_AVX3 + return BitCast(D(), Vec128{_mm_maskz_compress_epi64( + mask_bits, BitCast(di, v).raw)}); +#else + const auto idx = detail::Idx64x2FromBits(mask_bits); + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +#endif +} + +} // namespace detail + +template +HWY_API Vec128 Compress(Vec128 v, const Mask128 mask) { + return detail::Compress(hwy::SizeTag(), v, + detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec128 v, const Mask128 mask, + Simd d, T* HWY_RESTRICT aligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). 
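+  // detail::Compress packs the selected lanes into the lowest lanes, so one
+  // full aligned Store suffices; the returned PopCount tells callers how many
+  // of the stored lanes are valid.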
+ Store(detail::Compress(hwy::SizeTag(), v, mask_bits), d, aligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +// 128 bits +HWY_API void StoreInterleaved3(const Vec128 v0, + const Vec128 v1, + const Vec128 v2, Full128 d, + uint8_t* HWY_RESTRICT unaligned) { + const auto k5 = Set(d, 5); + const auto k6 = Set(d, 6); + + // Shuffle (v0,v1,v2) vector bytes to (MSB on left): r5, bgr[4:0]. + // 0x80 so lanes to be filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_r0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_g0[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + const auto shuf_r0 = Load(d, tbl_r0); + const auto shuf_g0 = Load(d, tbl_g0); // cannot reuse r0 due to 5 in MSB + const auto shuf_b0 = CombineShiftRightBytes<15>(shuf_g0, shuf_g0); + const auto r0 = TableLookupBytes(v0, shuf_r0); // 5..4..3..2..1..0 + const auto g0 = TableLookupBytes(v1, shuf_g0); // ..4..3..2..1..0. + const auto b0 = TableLookupBytes(v2, shuf_b0); // .4..3..2..1..0.. + const auto int0 = r0 | g0 | b0; + StoreU(int0, d, unaligned + 0 * 16); + + // Second vector: g10,r10, bgr[9:6], b5,g5 + const auto shuf_r1 = shuf_b0 + k6; // .A..9..8..7..6.. + const auto shuf_g1 = shuf_r0 + k5; // A..9..8..7..6..5 + const auto shuf_b1 = shuf_g0 + k5; // ..9..8..7..6..5. + const auto r1 = TableLookupBytes(v0, shuf_r1); + const auto g1 = TableLookupBytes(v1, shuf_g1); + const auto b1 = TableLookupBytes(v2, shuf_b1); + const auto int1 = r1 | g1 | b1; + StoreU(int1, d, unaligned + 1 * 16); + + // Third vector: bgr[15:11], b10 + const auto shuf_r2 = shuf_b1 + k6; // ..F..E..D..C..B. + const auto shuf_g2 = shuf_r1 + k5; // .F..E..D..C..B.. + const auto shuf_b2 = shuf_g1 + k5; // F..E..D..C..B..A + const auto r2 = TableLookupBytes(v0, shuf_r2); + const auto g2 = TableLookupBytes(v1, shuf_g2); + const auto b2 = TableLookupBytes(v2, shuf_b2); + const auto int2 = r2 | g2 | b2; + StoreU(int2, d, unaligned + 2 * 16); +} + +// 64 bits +HWY_API void StoreInterleaved3(const Vec128 v0, + const Vec128 v1, + const Vec128 v2, Simd d, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and first result. + const Full128 d_full; + const auto k5 = Set(d_full, 5); + const auto k6 = Set(d_full, 6); + + const Vec128 full_a{v0.raw}; + const Vec128 full_b{v1.raw}; + const Vec128 full_c{v2.raw}; + + // Shuffle (v0,v1,v2) vector bytes to (MSB on left): r5, bgr[4:0]. + // 0x80 so lanes to be filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_r0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_g0[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + const auto shuf_r0 = Load(d_full, tbl_r0); + const auto shuf_g0 = Load(d_full, tbl_g0); // cannot reuse r0 due to 5 in MSB + const auto shuf_b0 = CombineShiftRightBytes<15>(shuf_g0, shuf_g0); + const auto r0 = TableLookupBytes(full_a, shuf_r0); // 5..4..3..2..1..0 + const auto g0 = TableLookupBytes(full_b, shuf_g0); // ..4..3..2..1..0. + const auto b0 = TableLookupBytes(full_c, shuf_b0); // .4..3..2..1..0.. 
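+  // (Notation in the comments above: each digit/letter is the source byte
+  // index taken from that input vector, MSB on the left; '.' marks a zero
+  // byte produced by a 0x80 shuffle-table entry.)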
+ const auto int0 = r0 | g0 | b0; + StoreU(int0, d_full, unaligned + 0 * 16); + + // Second (HALF) vector: bgr[7:6], b5,g5 + const auto shuf_r1 = shuf_b0 + k6; // ..7..6.. + const auto shuf_g1 = shuf_r0 + k5; // .7..6..5 + const auto shuf_b1 = shuf_g0 + k5; // 7..6..5. + const auto r1 = TableLookupBytes(full_a, shuf_r1); + const auto g1 = TableLookupBytes(full_b, shuf_g1); + const auto b1 = TableLookupBytes(full_c, shuf_b1); + const decltype(Zero(d)) int1{(r1 | g1 | b1).raw}; + StoreU(int1, d, unaligned + 1 * 16); +} + +// <= 32 bits +template +HWY_API void StoreInterleaved3(const Vec128 v0, + const Vec128 v1, + const Vec128 v2, + Simd /*tag*/, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 d_full; + + const Vec128 full_a{v0.raw}; + const Vec128 full_b{v1.raw}; + const Vec128 full_c{v2.raw}; + + // Shuffle (v0,v1,v2) vector bytes to bgr[3:0]. + // 0x80 so lanes to be filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_r0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, // + 0x80, 0x80, 0x80, 0x80}; + const auto shuf_r0 = Load(d_full, tbl_r0); + const auto shuf_g0 = CombineShiftRightBytes<15>(shuf_r0, shuf_r0); + const auto shuf_b0 = CombineShiftRightBytes<14>(shuf_r0, shuf_r0); + const auto r0 = TableLookupBytes(full_a, shuf_r0); // ......3..2..1..0 + const auto g0 = TableLookupBytes(full_b, shuf_g0); // .....3..2..1..0. + const auto b0 = TableLookupBytes(full_c, shuf_b0); // ....3..2..1..0.. + const auto int0 = r0 | g0 | b0; + alignas(16) uint8_t buf[16]; + StoreU(int0, d_full, buf); + CopyBytes(buf, unaligned); +} + +// ------------------------------ StoreInterleaved4 + +// 128 bits +HWY_API void StoreInterleaved4(const Vec128 v0, + const Vec128 v1, + const Vec128 v2, + const Vec128 v3, Full128 d, + uint8_t* HWY_RESTRICT unaligned) { + // let a,b,c,d denote v0..3. + const auto ba0 = ZipLower(v0, v1); // b7 a7 .. b0 a0 + const auto dc0 = ZipLower(v2, v3); // d7 c7 .. d0 c0 + const auto ba8 = ZipUpper(v0, v1); + const auto dc8 = ZipUpper(v2, v3); + const auto dcba_0 = ZipLower(ba0, dc0); // d..a3 d..a0 + const auto dcba_4 = ZipUpper(ba0, dc0); // d..a7 d..a4 + const auto dcba_8 = ZipLower(ba8, dc8); // d..aB d..a8 + const auto dcba_C = ZipUpper(ba8, dc8); // d..aF d..aC + StoreU(BitCast(d, dcba_0), d, unaligned + 0 * 16); + StoreU(BitCast(d, dcba_4), d, unaligned + 1 * 16); + StoreU(BitCast(d, dcba_8), d, unaligned + 2 * 16); + StoreU(BitCast(d, dcba_C), d, unaligned + 3 * 16); +} + +// 64 bits +HWY_API void StoreInterleaved4(const Vec128 in0, + const Vec128 in1, + const Vec128 in2, + const Vec128 in3, + Simd /*tag*/, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Vec128 v0{in0.raw}; + const Vec128 v1{in1.raw}; + const Vec128 v2{in2.raw}; + const Vec128 v3{in3.raw}; + // let a,b,c,d denote v0..3. + const auto ba0 = ZipLower(v0, v1); // b7 a7 .. b0 a0 + const auto dc0 = ZipLower(v2, v3); // d7 c7 .. d0 c0 + const auto dcba_0 = ZipLower(ba0, dc0); // d..a3 d..a0 + const auto dcba_4 = ZipUpper(ba0, dc0); // d..a7 d..a4 + const Full128 d_full; + StoreU(BitCast(d_full, dcba_0), d_full, unaligned + 0 * 16); + StoreU(BitCast(d_full, dcba_4), d_full, unaligned + 1 * 16); +} + +// <= 32 bits +template +HWY_API void StoreInterleaved4(const Vec128 in0, + const Vec128 in1, + const Vec128 in2, + const Vec128 in3, + Simd /*tag*/, + uint8_t* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. 
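+  // A direct partial store of 4 * N <= 16 bytes could write past the caller's
+  // buffer, so the result is staged in the 16-byte stack buffer below and
+  // only 4 * N bytes are copied out.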
+ const Vec128 v0{in0.raw}; + const Vec128 v1{in1.raw}; + const Vec128 v2{in2.raw}; + const Vec128 v3{in3.raw}; + // let a,b,c,d denote v0..3. + const auto ba0 = ZipLower(v0, v1); // b3 a3 .. b0 a0 + const auto dc0 = ZipLower(v2, v3); // d3 c3 .. d0 c0 + const auto dcba_0 = ZipLower(ba0, dc0); // d..a3 d..a0 + alignas(16) uint8_t buf[16]; + const Full128 d_full; + StoreU(BitCast(d_full, dcba_0), d_full, buf); + CopyBytes<4 * N>(buf, unaligned); +} + +// ------------------------------ Reductions + +namespace detail { + +// For u32/i32/f32. +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = v3210 + v1032; + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Min(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128 v3210) { + const Vec128 v1032 = Shuffle1032(v3210); + const Vec128 v31_20_31_20 = Max(v3210, v1032); + const Vec128 v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// For u64/i64/f64. +template +HWY_API Vec128 SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_API Vec128 MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_API Vec128 MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128 v10) { + const Vec128 v01 = Shuffle01(v10); + return Max(v10, v01); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the sum in each lane. +template +HWY_API Vec128 SumOfLanes(const Vec128 v) { + return detail::SumOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MinOfLanes(const Vec128 v) { + return detail::MinOfLanes(hwy::SizeTag(), v); +} +template +HWY_API Vec128 MaxOfLanes(const Vec128 v) { + return detail::MaxOfLanes(hwy::SizeTag(), v); +} + +// ================================================== Operator wrapper + +// These apply to all x86_*-inl.h because there are no restrictions on V. 
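+// Illustrative aside (sketch, not upstream code): the named wrappers below
+// let generic code compose operations uniformly across targets, e.g.:
+//   template <class V>
+//   V MulAddViaWrappers(V m, V x, V a) { return Add(Mul(m, x), a); }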
+ +template +HWY_API V Add(V a, V b) { + return a + b; +} +template +HWY_API V Sub(V a, V b) { + return a - b; +} + +template +HWY_API V Mul(V a, V b) { + return a * b; +} +template +HWY_API V Div(V a, V b) { + return a / b; +} + +template +V Shl(V a, V b) { + return a << b; +} +template +V Shr(V a, V b) { + return a >> b; +} + +template +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/x86_256-inl.h b/third_party/highway/hwy/ops/x86_256-inl.h new file mode 100644 index 000000000000..265632a174e1 --- /dev/null +++ b/third_party/highway/hwy/ops/x86_256-inl.h @@ -0,0 +1,2907 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when +// compiling for that target. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +#include // AVX2+ +#if defined(_MSC_VER) && defined(__clang__) +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +#include +#include +#include +#include +#include +#endif + +#include +#include + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct Raw256 { + using type = __m256i; +}; +template <> +struct Raw256 { + using type = __m256; +}; +template <> +struct Raw256 { + using type = __m256d; +}; + +template +using Full256 = Simd; + +template +class Vec256 { + using Raw = typename Raw256::type; + + public: + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
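+  // e.g. Vec256<float> f = ...; f /= f; compiles, whereas i /= i for a
+  // Vec256<int32_t> i would not, because no integer operator/ is defined.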
+  HWY_INLINE Vec256& operator*=(const Vec256 other) {
+    return *this = (*this * other);
+  }
+  HWY_INLINE Vec256& operator/=(const Vec256 other) {
+    return *this = (*this / other);
+  }
+  HWY_INLINE Vec256& operator+=(const Vec256 other) {
+    return *this = (*this + other);
+  }
+  HWY_INLINE Vec256& operator-=(const Vec256 other) {
+    return *this = (*this - other);
+  }
+  HWY_INLINE Vec256& operator&=(const Vec256 other) {
+    return *this = (*this & other);
+  }
+  HWY_INLINE Vec256& operator|=(const Vec256 other) {
+    return *this = (*this | other);
+  }
+  HWY_INLINE Vec256& operator^=(const Vec256 other) {
+    return *this = (*this ^ other);
+  }
+
+  Raw raw;
+};
+
+// Integer: FF..FF or 0. Float: MSB, all other bits undefined - see README.
+template <typename T>
+class Mask256 {
+  using Raw = typename Raw256<T>::type;
+
+ public:
+  Raw raw;
+};
+
+// ------------------------------ BitCast
+
+namespace detail {
+
+HWY_API __m256i BitCastToInteger(__m256i v) { return v; }
+HWY_API __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); }
+HWY_API __m256i BitCastToInteger(__m256d v) { return _mm256_castpd_si256(v); }
+
+template <typename T>
+HWY_API Vec256<uint8_t> BitCastToByte(Vec256<T> v) {
+  return Vec256<uint8_t>{BitCastToInteger(v.raw)};
+}
+
+// Cannot rely on function overloading because return types differ.
+template <typename T>
+struct BitCastFromInteger256 {
+  HWY_INLINE __m256i operator()(__m256i v) { return v; }
+};
+template <>
+struct BitCastFromInteger256<float> {
+  HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); }
+};
+template <>
+struct BitCastFromInteger256<double> {
+  HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); }
+};
+
+template <typename T>
+HWY_API Vec256<T> BitCastFromByte(Full256<T> /* tag */, Vec256<uint8_t> v) {
+  return Vec256<T>{BitCastFromInteger256<T>()(v.raw)};
+}
+
+}  // namespace detail
+
+template <typename T, typename FromT>
+HWY_API Vec256<T> BitCast(Full256<T> d, Vec256<FromT> v) {
+  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
+}
+
+// ------------------------------ Set
+
+// Returns an all-zero vector.
+template <typename T>
+HWY_API Vec256<T> Zero(Full256<T> /* tag */) {
+  return Vec256<T>{_mm256_setzero_si256()};
+}
+HWY_API Vec256<float> Zero(Full256<float> /* tag */) {
+  return Vec256<float>{_mm256_setzero_ps()};
+}
+HWY_API Vec256<double> Zero(Full256<double> /* tag */) {
+  return Vec256<double>{_mm256_setzero_pd()};
+}
+
+// Returns a vector with all lanes set to "t".
+HWY_API Vec256<uint8_t> Set(Full256<uint8_t> /* tag */, const uint8_t t) {
+  return Vec256<uint8_t>{_mm256_set1_epi8(static_cast<char>(t))};  // NOLINT
+}
+HWY_API Vec256<uint16_t> Set(Full256<uint16_t> /* tag */, const uint16_t t) {
+  return Vec256<uint16_t>{_mm256_set1_epi16(static_cast<short>(t))};  // NOLINT
+}
+HWY_API Vec256<uint32_t> Set(Full256<uint32_t> /* tag */, const uint32_t t) {
+  return Vec256<uint32_t>{_mm256_set1_epi32(static_cast<int>(t))};  // NOLINT
+}
+HWY_API Vec256<uint64_t> Set(Full256<uint64_t> /* tag */, const uint64_t t) {
+  return Vec256<uint64_t>{
+      _mm256_set1_epi64x(static_cast<long long>(t))};  // NOLINT
+}
+HWY_API Vec256<int8_t> Set(Full256<int8_t> /* tag */, const int8_t t) {
+  return Vec256<int8_t>{_mm256_set1_epi8(t)};
+}
+HWY_API Vec256<int16_t> Set(Full256<int16_t> /* tag */, const int16_t t) {
+  return Vec256<int16_t>{_mm256_set1_epi16(t)};
+}
+HWY_API Vec256<int32_t> Set(Full256<int32_t> /* tag */, const int32_t t) {
+  return Vec256<int32_t>{_mm256_set1_epi32(t)};
+}
+HWY_API Vec256<int64_t> Set(Full256<int64_t> /* tag */, const int64_t t) {
+  return Vec256<int64_t>{_mm256_set1_epi64x(t)};
+}
+HWY_API Vec256<float> Set(Full256<float> /* tag */, const float t) {
+  return Vec256<float>{_mm256_set1_ps(t)};
+}
+HWY_API Vec256<double> Set(Full256<double> /* tag */, const double t) {
+  return Vec256<double>{_mm256_set1_pd(t)};
+}
+
+HWY_DIAGNOSTICS(push)
+HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")
+
+// Returns a vector with uninitialized elements.
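+// (Illustrative caution, not upstream text: only use this when every lane is
+// overwritten before being read, e.g. an accumulator that is fully written on
+// the first loop iteration.)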
+template +HWY_API Vec256 Undefined(Full256 /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec256{_mm256_undefined_si256()}; +} +HWY_API Vec256 Undefined(Full256 /* tag */) { + return Vec256{_mm256_undefined_ps()}; +} +HWY_API Vec256 Undefined(Full256 /* tag */) { + return Vec256{_mm256_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{_mm256_and_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 And(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_and_ps(a.raw, b.raw)}; +} +HWY_API Vec256 And(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{_mm256_andnot_si256(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(const Vec256 not_mask, + const Vec256 mask) { + return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(const Vec256 not_mask, + const Vec256 mask) { + return Vec256{_mm256_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{_mm256_or_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_or_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{_mm256_xor_si256(a.raw, b.raw)}; +} + +HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not + +template +HWY_API Vec256 Not(const Vec256 v) { + using TU = MakeUnsigned; +#if HWY_TARGET == HWY_AVX3 + const __m256i vu = BitCast(Full256(), v).raw; + return BitCast(Full256(), + Vec256{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(Full256(), Vec256{_mm256_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const Full256 d; + const auto msb = SignBit(d); + +#if HWY_TARGET == HWY_AVX3 + const Rebind, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. 
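+  // Reading the table's output column as imm8 bits (bit index =
+  // msb*4 + magn*2 + sign) gives 0b1010'1100 = 0xAC used below.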
+ const __m256i out = _mm256_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { +#if HWY_TARGET == HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(Full256()), sign)); +#endif +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { + return Vec256{v.raw}; +} + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(const Mask256 mask, + const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(const Mask256 mask, + const Vec256 yes, + const Vec256 no) { + return Vec256{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(Full256(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(Full256(), mask), no); +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + const auto zero = Zero(Full256()); + return IfThenElse(MaskFromVec(v), zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + const Full256 d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
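+// Typical use (illustrative sketch): masks usually feed IfThenElse rather
+// than being inspected directly; e.g. for signed/float T:
+//   const auto lt = a < b;                 // Mask256<T>
+//   const auto lo = IfThenElse(lt, a, b);  // per-lane minimum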
+
+template <typename TFrom, typename TTo>
+HWY_API Mask256<TTo> RebindMask(Full256<TTo> d_to, Mask256<TFrom> m) {
+  static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size");
+  return MaskFromVec(BitCast(d_to, VecFromMask(Full256<TFrom>(), m)));
+}
+
+// ------------------------------ Equality
+
+// Unsigned
+HWY_API Mask256<uint8_t> operator==(const Vec256<uint8_t> a,
+                                    const Vec256<uint8_t> b) {
+  return Mask256<uint8_t>{_mm256_cmpeq_epi8(a.raw, b.raw)};
+}
+HWY_API Mask256<uint16_t> operator==(const Vec256<uint16_t> a,
+                                     const Vec256<uint16_t> b) {
+  return Mask256<uint16_t>{_mm256_cmpeq_epi16(a.raw, b.raw)};
+}
+HWY_API Mask256<uint32_t> operator==(const Vec256<uint32_t> a,
+                                     const Vec256<uint32_t> b) {
+  return Mask256<uint32_t>{_mm256_cmpeq_epi32(a.raw, b.raw)};
+}
+HWY_API Mask256<uint64_t> operator==(const Vec256<uint64_t> a,
+                                     const Vec256<uint64_t> b) {
+  return Mask256<uint64_t>{_mm256_cmpeq_epi64(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Mask256<int8_t> operator==(const Vec256<int8_t> a,
+                                   const Vec256<int8_t> b) {
+  return Mask256<int8_t>{_mm256_cmpeq_epi8(a.raw, b.raw)};
+}
+HWY_API Mask256<int16_t> operator==(const Vec256<int16_t> a,
+                                    const Vec256<int16_t> b) {
+  return Mask256<int16_t>{_mm256_cmpeq_epi16(a.raw, b.raw)};
+}
+HWY_API Mask256<int32_t> operator==(const Vec256<int32_t> a,
+                                    const Vec256<int32_t> b) {
+  return Mask256<int32_t>{_mm256_cmpeq_epi32(a.raw, b.raw)};
+}
+HWY_API Mask256<int64_t> operator==(const Vec256<int64_t> a,
+                                    const Vec256<int64_t> b) {
+  return Mask256<int64_t>{_mm256_cmpeq_epi64(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Mask256<float> operator==(const Vec256<float> a,
+                                  const Vec256<float> b) {
+  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)};
+}
+HWY_API Mask256<double> operator==(const Vec256<double> a,
+                                   const Vec256<double> b) {
+  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)};
+}
+
+template <typename T>
+HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) {
+  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
+  return (v & bit) == bit;
+}
+
+// ------------------------------ Strict inequality
+
+// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8
+// to perform an unsigned comparison instead of the intended signed. Workaround
+// is to cast to an explicitly signed type.
See https://godbolt.org/z/PL7Ujy +#if HWY_COMPILER_GCC != 0 && HWY_COMPILER_GCC < 930 +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 +#else +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 +#endif + +// Signed/float < +HWY_API Mask256 operator<(const Vec256 a, + const Vec256 b) { +#if HWY_AVX2_GCC_CMPGT8_WORKAROUND + using i8x32 = signed char __attribute__((__vector_size__(32))); + return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) < + reinterpret_cast(b.raw))}; +#else + return Mask256{_mm256_cmpgt_epi8(b.raw, a.raw)}; +#endif +} +HWY_API Mask256 operator<(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epi16(b.raw, a.raw)}; +} +HWY_API Mask256 operator<(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epi32(b.raw, a.raw)}; +} +HWY_API Mask256 operator<(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epi64(b.raw, a.raw)}; +} +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_LT_OQ)}; +} +HWY_API Mask256 operator<(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_LT_OQ)}; +} + +// Signed/float > +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { +#if HWY_AVX2_GCC_CMPGT8_WORKAROUND + using i8x32 = signed char __attribute__((__vector_size__(32))); + return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) > + reinterpret_cast(b.raw))}; +#else + return Mask256{_mm256_cmpgt_epi8(a.raw, b.raw)}; +#endif +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epi16(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epi32(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpgt_epi64(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 operator>(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +// Float <= >= +HWY_API Mask256 operator<=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_LE_OQ)}; +} +HWY_API Mask256 operator<=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_LE_OQ)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_min_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +#endif +} + +// Signed +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const 
Vec256 b) { + return Vec256{_mm256_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_max_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +#endif +} + +// Signed +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_pd(a.raw, b.raw)}; +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_add_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 
b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ Saturating addition + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. 
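+// (Concretely, for int8_t: Abs(-128) yields -128 again because +128 wraps;
+// this matches _mm256_abs_epi8.)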
+HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi8(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi16(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi32(v.raw)}; +} + +HWY_API Vec256 Abs(const Vec256 v) { + const Vec256 mask{_mm256_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(Full256(), mask); +} +HWY_API Vec256 Abs(const Vec256 v) { + const Vec256 mask{_mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(Full256(), mask); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec256 MulHigh(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mulhi_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec256 MulEven(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeft + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. 
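+  // (x86 lacks 8-bit shifts, so the bytes are shifted as 16-bit lanes below
+  // and the bits pulled in from each byte's high neighbor are masked off.)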
+ const Vec256 shifted{ShiftRight(Vec256{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{_mm256_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return VecFromMask(v < Zero(Full256())); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<31>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { +#if HWY_TARGET == HWY_AVX2 + return VecFromMask(v < Zero(Full256())); +#else + return Vec256{_mm256_srai_epi64(v.raw, 63)}; +#endif +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_srai_epi64(v.raw, kBits)}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ ShiftLeftSame + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, (0xFF << bits) & 0xFF); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, 0xFF >> bits); +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 
v, + const int bits) { + return Vec256{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Negate + +template +HWY_API Vec256 Neg(const Vec256 v) { + return Xor(v, SignBit(Full256())); +} + +template +HWY_API Vec256 Neg(const Vec256 v) { + return Zero(Full256()) - v; +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_mul_pd(a.raw, b.raw)}; +} + +HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_div_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { + return Vec256{_mm256_rcp_ps(v.raw)}; +} + +// Absolute value of difference. +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 NegMulAdd(const Vec256 mul, + const Vec256 x, + const Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 NegMulSub(const Vec256 mul, + const Vec256 x, + const Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + 
return Vec256<double>{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)};
+#endif
+}
+
+// ------------------------------ Floating-point square root
+
+// Full precision square root
+HWY_API Vec256<float> Sqrt(const Vec256<float> v) {
+  return Vec256<float>{_mm256_sqrt_ps(v.raw)};
+}
+HWY_API Vec256<double> Sqrt(const Vec256<double> v) {
+  return Vec256<double>{_mm256_sqrt_pd(v.raw)};
+}
+
+// Approximate reciprocal square root
+HWY_API Vec256<float> ApproximateReciprocalSqrt(const Vec256<float> v) {
+  return Vec256<float>{_mm256_rsqrt_ps(v.raw)};
+}
+
+// ------------------------------ Floating-point rounding
+
+// Toward nearest integer, tie to even
+HWY_API Vec256<float> Round(const Vec256<float> v) {
+  return Vec256<float>{
+      _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec256<double> Round(const Vec256<double> v) {
+  return Vec256<double>{
+      _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
+}
+
+// Toward zero, aka truncate
+HWY_API Vec256<float> Trunc(const Vec256<float> v) {
+  return Vec256<float>{
+      _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec256<double> Trunc(const Vec256<double> v) {
+  return Vec256<double>{
+      _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
+}
+
+// Toward +infinity, aka ceiling
+HWY_API Vec256<float> Ceil(const Vec256<float> v) {
+  return Vec256<float>{
+      _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec256<double> Ceil(const Vec256<double> v) {
+  return Vec256<double>{
+      _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
+}
+
+// Toward -infinity, aka floor
+HWY_API Vec256<float> Floor(const Vec256<float> v) {
+  return Vec256<float>{
+      _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec256<double> Floor(const Vec256<double> v) {
+  return Vec256<double>{
+      _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
+}
+
+// ================================================== MEMORY
+
+// ------------------------------ Load
+
+template <typename T>
+HWY_API Vec256<T> Load(Full256<T> /* tag */, const T* HWY_RESTRICT aligned) {
+  return Vec256<T>{
+      _mm256_load_si256(reinterpret_cast<const __m256i*>(aligned))};
+}
+HWY_API Vec256<float> Load(Full256<float> /* tag */,
+                           const float* HWY_RESTRICT aligned) {
+  return Vec256<float>{_mm256_load_ps(aligned)};
+}
+HWY_API Vec256<double> Load(Full256<double> /* tag */,
+                            const double* HWY_RESTRICT aligned) {
+  return Vec256<double>{_mm256_load_pd(aligned)};
+}
+
+template <typename T>
+HWY_API Vec256<T> LoadU(Full256<T> /* tag */, const T* HWY_RESTRICT p) {
+  return Vec256<T>{_mm256_loadu_si256(reinterpret_cast<const __m256i*>(p))};
+}
+HWY_API Vec256<float> LoadU(Full256<float> /* tag */,
+                            const float* HWY_RESTRICT p) {
+  return Vec256<float>{_mm256_loadu_ps(p)};
+}
+HWY_API Vec256<double> LoadU(Full256<double> /* tag */,
+                             const double* HWY_RESTRICT p) {
+  return Vec256<double>{_mm256_loadu_pd(p)};
+}
+
+// Loads 128 bit and duplicates into both 128-bit halves. This avoids the
+// 3-cycle cost of moving data between 128-bit halves and avoids port 5.
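+// Typical use (illustrative sketch, names mine): AVX2 byte shuffles operate
+// within each 128-bit half, so 16-byte lookup tables are usually broadcast:
+//   alignas(16) static constexpr uint8_t kTable[16] = {
+//       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+//   const auto tbl = LoadDup128(Full256<uint8_t>(), kTable);
+//   const auto shuffled = TableLookupBytes(tbl, indices);  // per-half lookup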
+template +HWY_API Vec256 LoadDup128(Full256 /* tag */, const T* HWY_RESTRICT p) { +#if HWY_LOADDUP_ASM + __m256i out; + asm("vbroadcasti128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); + return Vec256{out}; +#else + return Vec256{_mm256_broadcastsi128_si256(LoadU(Full128(), p).raw)}; +#endif +} +HWY_API Vec256 LoadDup128(Full256 /* tag */, + const float* const HWY_RESTRICT p) { +#if HWY_LOADDUP_ASM + __m256 out; + asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); + return Vec256{out}; +#else + return Vec256{_mm256_broadcast_ps(reinterpret_cast(p))}; +#endif +} +HWY_API Vec256 LoadDup128(Full256 /* tag */, + const double* const HWY_RESTRICT p) { +#if HWY_LOADDUP_ASM + __m256d out; + asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); + return Vec256{out}; +#else + return Vec256{ + _mm256_broadcast_pd(reinterpret_cast(p))}; +#endif +} + +// ------------------------------ Store + +template +HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { + _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Store(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT p) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); +} +HWY_API void StoreU(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT p) { + _mm256_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT p) { + _mm256_storeu_pd(p, v.raw); +} + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(Vec256 v, Full256 /* tag */, + T* HWY_RESTRICT aligned) { + _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Stream(const Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_stream_ps(aligned, v.raw); +} +HWY_API void Stream(const Vec256 v, Full256 /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +#if HWY_TARGET == HWY_AVX3 +namespace detail { + +template +HWY_API void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); +} +template +HWY_API void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); +} +template +HWY_API void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec256 v, + Full256 /* tag */, T* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +} // namespace detail + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); +} +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return 
detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); +} + +template <> +HWY_INLINE void ScatterOffset(Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); +} +template <> +HWY_INLINE void ScatterIndex(Vec256 v, Full256 /* tag */, + float* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i32scatter_ps(base, index.raw, v.raw, 4); +} + +template <> +HWY_INLINE void ScatterOffset(Vec256 v, + Full256 /* tag */, + double* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); +} +template <> +HWY_INLINE void ScatterIndex(Vec256 v, + Full256 /* tag */, + double* HWY_RESTRICT base, + const Vec256 index) { + _mm256_i64scatter_pd(base, index.raw, v.raw, 8); +} + +#else + +template +HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Simd(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template +HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[N]; + Store(index, Simd(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather + +namespace detail { + +template +HWY_API Vec256 GatherOffset(hwy::SizeTag<4> /* tag */, Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_API Vec256 GatherIndex(hwy::SizeTag<4> /* tag */, Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), index.raw, 4)}; +} + +template +HWY_API Vec256 GatherOffset(hwy::SizeTag<8> /* tag */, Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), offset.raw, 1)}; +} +template +HWY_API Vec256 GatherIndex(hwy::SizeTag<8> /* tag */, Full256 /* tag */, + const T* HWY_RESTRICT base, + const Vec256 index) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), index.raw, 8)}; +} + +} // namespace detail + +template +HWY_API Vec256 GatherOffset(Full256 d, const T* HWY_RESTRICT base, + const Vec256 offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::GatherOffset(hwy::SizeTag(), d, base, offset); +} +template +HWY_API Vec256 GatherIndex(Full256 d, const T* HWY_RESTRICT base, + const Vec256 index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::GatherIndex(hwy::SizeTag(), d, base, index); +} + +template <> +HWY_INLINE Vec256 GatherOffset(Full256 /* tag */, + const float* HWY_RESTRICT base, + const Vec256 offset) { + return Vec256{_mm256_i32gather_ps(base, offset.raw, 1)}; +} +template <> +HWY_INLINE Vec256 GatherIndex(Full256 /* tag */, + const float* HWY_RESTRICT base, + const Vec256 index) { + return 
Vec256<float>{_mm256_i32gather_ps(base, index.raw, 4)};
+}
+
+template <>
+HWY_INLINE Vec256<double> GatherOffset<double>(Full256<double> /* tag */,
+                                               const double* HWY_RESTRICT base,
+                                               const Vec256<int64_t> offset) {
+  return Vec256<double>{_mm256_i64gather_pd(base, offset.raw, 1)};
+}
+template <>
+HWY_INLINE Vec256<double> GatherIndex<double>(Full256<double> /* tag */,
+                                              const double* HWY_RESTRICT base,
+                                              const Vec256<int64_t> index) {
+  return Vec256<double>{_mm256_i64gather_pd(base, index.raw, 8)};
+}
+
+// ================================================== SWIZZLE
+
+template <typename T>
+HWY_API T GetLane(const Vec256<T> v) {
+  return GetLane(LowerHalf(v));
+}
+
+// ------------------------------ Extract half
+
+template <typename T>
+HWY_API Vec128<T> LowerHalf(Vec256<T> v) {
+  return Vec128<T>{_mm256_castsi256_si128(v.raw)};
+}
+template <>
+HWY_INLINE Vec128<float> LowerHalf(Vec256<float> v) {
+  return Vec128<float>{_mm256_castps256_ps128(v.raw)};
+}
+template <>
+HWY_INLINE Vec128<double> LowerHalf(Vec256<double> v) {
+  return Vec128<double>{_mm256_castpd256_pd128(v.raw)};
+}
+
+template <typename T>
+HWY_API Vec128<T> UpperHalf(Vec256<T> v) {
+  return Vec128<T>{_mm256_extracti128_si256(v.raw, 1)};
+}
+template <>
+HWY_INLINE Vec128<float> UpperHalf(Vec256<float> v) {
+  return Vec128<float>{_mm256_extractf128_ps(v.raw, 1)};
+}
+template <>
+HWY_INLINE Vec128<double> UpperHalf(Vec256<double> v) {
+  return Vec128<double>{_mm256_extractf128_pd(v.raw, 1)};
+}
+
+// ------------------------------ ZeroExtendVector
+
+// Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper
+// bits undefined. Although it makes sense for them to be zero (VEX encoded
+// 128-bit instructions zero the upper lanes to avoid large penalties), a
+// compiler could decide to optimize out code that relies on this.
+//
+// The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the
+// zeroing, but it is not available on GCC until 10.1. For older GCC, we can
+// still obtain the desired code thanks to pattern recognition; note that the
+// expensive insert instruction is not actually generated, see
+// https://gcc.godbolt.org/z/1MKGaP.
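+
+// Illustrative sketch (editorial addition, not part of upstream Highway):
+// the raw-intrinsic form of the workaround described above. Inserting the
+// 128-bit value into a zeroed 256-bit vector guarantees the upper 128 bits
+// are zero, whereas a plain _mm256_castsi128_si256 leaves them unspecified.
+HWY_API __m256i ExampleZeroExtendRaw(__m128i lo) {
+  return _mm256_inserti128_si256(_mm256_setzero_si256(), lo, 0);
+}
+// ZeroExtendVector itself follows.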
+ +template +HWY_API Vec256 ZeroExtendVector(Vec128 lo) { +#if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) + return Vec256{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; +#else + return Vec256{_mm256_zextsi128_si256(lo.raw)}; +#endif +} +template <> +HWY_INLINE Vec256 ZeroExtendVector(Vec128 lo) { +#if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) + return Vec256{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; +#else + return Vec256{_mm256_zextps128_ps256(lo.raw)}; +#endif +} +template <> +HWY_INLINE Vec256 ZeroExtendVector(Vec128 lo) { +#if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) + return Vec256{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; +#else + return Vec256{_mm256_zextpd128_pd256(lo.raw)}; +#endif +} + +// ------------------------------ Combine + +template +HWY_API Vec256 Combine(Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(lo); + return Vec256{_mm256_inserti128_si256(lo256.raw, hi.raw, 1)}; +} +template <> +HWY_INLINE Vec256 Combine(Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(lo); + return Vec256{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; +} +template <> +HWY_INLINE Vec256 Combine(Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(lo); + return Vec256{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)}; +} + +// ------------------------------ Shift vector by constant #bytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API Vec256 ShiftLeftBytes(const Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bslli_epi128. + return Vec256{_mm256_slli_si256(v.raw, kBytes)}; +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + const Full256 d8; + const Full256 d; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API Vec256 ShiftRightBytes(const Vec256 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bsrli_epi128. + return Vec256{_mm256_srli_si256(v.raw, kBytes)}; +} + +template +HWY_API Vec256 ShiftRightLanes(const Vec256 v) { + const Full256 d8; + const Full256 d; + return BitCast(d, ShiftRightBytes(BitCast(d8, v))); +} + +// ------------------------------ Extract from 2x 128-bit at constant offset + +// Extracts 128 bits from by skipping the least-significant kBytes. +template +HWY_API Vec256 CombineShiftRightBytes(const Vec256 hi, + const Vec256 lo) { + const Full256 d8; + const Vec256 extracted_bytes{ + _mm256_alignr_epi8(BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}; + return BitCast(Full256(), extracted_bytes); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 
0xEE : 0x44)}; +} + +// Signed +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Float +template +HWY_API Vec256 Broadcast(Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// Swap 64-bit halves +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + // Shorter encoding than _mm256_permute_ps. + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + // Shorter encoding than _mm256_permute_pd. 
+ return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 5)}; +} + +// Rotate right 32 bits +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices256 { + __m256i raw; +}; + +template +HWY_API Indices256 SetTableIndices(const Full256, const int32_t* idx) { +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) + const size_t N = 32 / sizeof(T); + for (size_t i = 0; i < N; ++i) { + HWY_DASSERT(0 <= idx[i] && idx[i] < static_cast(N)); + } +#endif + return Indices256{LoadU(Full256(), idx).raw}; +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +} +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +} +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { + return Vec256{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; +} + +// ------------------------------ Interleave lanes + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). 
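+
+// Illustrative sketch (editorial addition, not part of upstream Highway):
+// unlike the block-local shuffles above, TableLookupLanes permutes across
+// the full vector, e.g. reversing all eight int32 lanes:
+HWY_API Vec256<int32_t> ExampleReverse(const Vec256<int32_t> v) {
+  alignas(32) static constexpr int32_t kRev[8] = {7, 6, 5, 4, 3, 2, 1, 0};
+  return TableLookupLanes(v, SetTableIndices(Full256<int32_t>(), kRev));
+}
+// The InterleaveLower/InterleaveUpper implementations follow.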
+ +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_pd(a.raw, b.raw)}; +} + +// ------------------------------ Zip lanes + +// Same as interleave_*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
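+
+// Illustrative sketch (editorial addition, not part of upstream Highway):
+// interleaving with a zero vector zero-extends bytes in place; ZipLower
+// (below) exposes the same result with the widened (u16) lane type:
+HWY_API Vec256<uint8_t> ExampleZeroExtendBytes(const Vec256<uint8_t> v) {
+  // v0,0,v1,0,... within each 128-bit block, per the warning at file top.
+  return InterleaveLower(v, Zero(Full256<uint8_t>()));
+}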
+ +HWY_API Vec256 ZipLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 ZipLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 ZipLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} + +HWY_API Vec256 ZipLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 ZipLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 ZipLower(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} + +HWY_API Vec256 ZipUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 ZipUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 ZipUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} + +HWY_API Vec256 ZipUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 ZipUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 ZipUpper(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} + +// ------------------------------ Blocks (LowerHalf, ZeroExtendVector) + +// _mm256_broadcastsi128_si256 has 7 cycle latency. _mm256_permute2x128_si256 is +// slow on Zen1 (8 uops); we can avoid it for LowerLower and UpperLower, and on +// UpperUpper at the cost of one extra cycle/instruction. + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API Vec256 ConcatLowerLower(const Vec256 hi, const Vec256 lo) { + return Vec256{_mm256_inserti128_si256(lo.raw, LowerHalf(hi).raw, 1)}; +} +template <> +HWY_INLINE Vec256 ConcatLowerLower(const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_insertf128_ps(lo.raw, LowerHalf(hi).raw, 1)}; +} +template <> +HWY_INLINE Vec256 ConcatLowerLower(const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_insertf128_pd(lo.raw, LowerHalf(hi).raw, 1)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API Vec256 ConcatLowerUpper(const Vec256 hi, const Vec256 lo) { + return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x21)}; +} +template <> +HWY_INLINE Vec256 ConcatLowerUpper(const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; +} +template <> +HWY_INLINE Vec256 ConcatLowerUpper(const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API Vec256 ConcatUpperLower(const Vec256 hi, const Vec256 lo) { + return Vec256{_mm256_blend_epi32(hi.raw, lo.raw, 0x0F)}; +} +template <> +HWY_INLINE Vec256 ConcatUpperLower(const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; +} +template <> +HWY_INLINE Vec256 ConcatUpperLower(const Vec256 hi, + const Vec256 lo) { + return Vec256{_mm256_blend_pd(hi.raw, lo.raw, 3)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API Vec256 ConcatUpperUpper(const Vec256 hi, const Vec256 lo) { + return ConcatUpperLower(hi, ZeroExtendVector(UpperHalf(lo))); +} + +// ------------------------------ Odd/even lanes + +namespace detail { + +template +HWY_API Vec256 
OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, + const Vec256 b) { + const Full256 d; + const Full256 d8; + alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, LoadDup128(d8, mask))), b, a); +} +template +HWY_API Vec256 OddEven(hwy::SizeTag<2> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi16(a.raw, b.raw, 0x55)}; +} +template +HWY_API Vec256 OddEven(hwy::SizeTag<4> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; +} +template +HWY_API Vec256 OddEven(hwy::SizeTag<8> /* tag */, const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; +} + +} // namespace detail + +template +HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +template <> +HWY_INLINE Vec256 OddEven(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_ps(a.raw, b.raw, 0x55)}; +} + +template <> +HWY_INLINE Vec256 OddEven(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; +} + +// ------------------------------ Shuffle bytes with variable indices + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, + const Vec256 from) { + return Vec256{_mm256_shuffle_epi8(bytes.raw, from.raw)}; +} + +// ------------------------------ Shl (Mul, ZipLower) + +#if HWY_TARGET != HWY_AVX3 +namespace detail { + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_API Vec256> Pow2(const Vec256 v) { + const Full256 d; + const Full256 df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(zero, upper); + const auto f1 = ZipUpper(zero, upper); + // Do not use ConvertTo because it checks for overflow, which is redundant + // because we only care about v in [0, 16). + const Vec256 bits0{_mm256_cvttps_epi32(BitCast(df, f0).raw)}; + const Vec256 bits1{_mm256_cvttps_epi32(BitCast(df, f1).raw)}; + return Vec256>{_mm256_packus_epi32(bits0.raw, bits1.raw)}; +} + +} // namespace detail +#endif // HWY_TARGET != HWY_AVX3 + +HWY_API Vec256 operator<<(const Vec256 v, + const Vec256 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_sllv_epi16(v.raw, bits.raw)}; +#else + return v * detail::Pow2(bits); +#endif +} + +HWY_API Vec256 operator<<(const Vec256 v, + const Vec256 bits) { + return Vec256{_mm256_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator<<(const Vec256 v, + const Vec256 bits) { + return Vec256{_mm256_sllv_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec256 operator<<(const Vec256 v, const Vec256 bits) { + const Full256 di; + const Full256> du; + return BitCast(di, BitCast(du, v) << BitCast(du, bits)); +} + +// ------------------------------ Shr (MulHigh, IfThenElse, Not) + +HWY_API Vec256 operator>>(const Vec256 v, + const Vec256 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_srlv_epi16(v.raw, bits.raw)}; +#else + const Full256 d; + // For bits=0, we cannot mul by 2^16, so fix the result later. 
+ const auto out = MulHigh(v, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), v, out); +#endif +} + +HWY_API Vec256 operator>>(const Vec256 v, + const Vec256 bits) { + return Vec256{_mm256_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(const Vec256 v, + const Vec256 bits) { + return Vec256{_mm256_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(const Vec256 v, + const Vec256 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256(), v, bits); +#endif +} + +HWY_API Vec256 operator>>(const Vec256 v, + const Vec256 bits) { + return Vec256{_mm256_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(const Vec256 v, + const Vec256 bits) { +#if HWY_TARGET == HWY_AVX3 + return Vec256{_mm256_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256(), v, bits); +#endif +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{_mm256_cvtph_ps(v.raw)}; +} + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{_mm256_cvtps_pd(v.raw)}; +} + +HWY_API Vec256 PromoteTo(Full256 /* tag */, + const Vec128 v) { + return Vec256{_mm256_cvtepi32_pd(v.raw)}; +} + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi8_epi16(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi8_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi16_epi32(v.raw)}; +} +HWY_API Vec256 PromoteTo(Full256 /* tag */, + Vec128 v) { + return Vec256{_mm256_cvtepi32_epi64(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); + // Concatenating lower halves of both 128-bit blocks afterward is more + // efficient than an extra input with low block = high block of v. 
+ return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec256 v) { + const __m256i u16_blocks = _mm256_packus_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i u16_concat = _mm256_permute4x64_epi64(u16_blocks, 0x88); + const __m128i u16 = _mm256_castsi256_si128(u16_concat); + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const __m128i i16 = _mm_and_si128(u16, _mm_set1_epi16(0x7FFF)); + return Vec128{_mm_packus_epi16(i16, i16)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Simd /* tag */, + const Vec256 v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return Vec128{_mm_packs_epi16(i16, i16)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); + return Vec128{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + return Vec128{_mm256_cvtpd_ps(v.raw)}; +} + +HWY_API Vec128 DemoteTo(Full128 /* tag */, + const Vec256 v) { + const auto clamped = detail::ClampF64ToI32Max(Full256(), v); + return Vec128{_mm256_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec256 v) { + const Full256 d32; + alignas(32) static constexpr uint32_t k8From32[8] = { + 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; + // Place first four bytes in lo[0], remaining 4 in hi[1]. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(quad); + const auto pair = LowerHalf(lo | hi); + return BitCast(Simd(), pair); +} + +// ------------------------------ Convert integer <=> floating point + +HWY_API Vec256 ConvertTo(Full256 /* tag */, + const Vec256 v) { + return Vec256{_mm256_cvtepi32_ps(v.raw)}; +} + +HWY_API Vec256 ConvertTo(Full256 dd, const Vec256 v) { +#if HWY_TARGET == HWY_AVX3 + (void)dd; + return Vec256{_mm256_cvtepi64_pd(v.raw)}; +#else + alignas(32) int64_t lanes_i[4]; + Store(v, Full256(), lanes_i); + alignas(32) double lanes_d[4]; + for (size_t i = 0; i < 4; ++i) { + lanes_d[i] = static_cast(lanes_i[i]); + } + return Load(dd, lanes_d); +#endif +} + +// Truncates (rounds toward zero). 
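+
+// Illustrative sketch (editorial addition, not part of upstream Highway):
+// int->float conversion composes with the arithmetic ops, e.g. normalizing
+// 8-bit pixel values already widened to i32 into [0, 1]:
+HWY_API Vec256<float> ExampleToUnit(const Vec256<int32_t> v) {
+  const Full256<float> df;
+  return ConvertTo(df, v) * Set(df, 1.0f / 255);
+}
+// The truncating float->int conversions follow, as noted above.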
+HWY_API Vec256 ConvertTo(Full256 d, const Vec256 v) { + return detail::FixConversionOverflow(d, v, _mm256_cvttps_epi32(v.raw)); +} + +HWY_API Vec256 ConvertTo(Full256 di, const Vec256 v) { +#if HWY_TARGET == HWY_AVX3 + return detail::FixConversionOverflow(di, v, _mm256_cvttpd_epi64(v.raw)); +#else + alignas(32) double lanes_d[4]; + Store(v, Full256(), lanes_d); + alignas(32) int64_t lanes_i[4]; + for (size_t i = 0; i < 4; ++i) { + if (lanes_d[i] >= static_cast(LimitsMax())) { + lanes_i[i] = LimitsMax(); + } else if (lanes_d[i] <= static_cast(LimitsMin())) { + lanes_i[i] = LimitsMin(); + } else { + lanes_i[i] = static_cast(lanes_d[i]); + } + } + return Load(di, lanes_i); +#endif +} + +HWY_API Vec256 NearestInt(const Vec256 v) { + const Full256 di; + return detail::FixConversionOverflow(di, v, _mm256_cvtps_epi32(v.raw)); +} + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template +Vec256 Iota(const Full256 d, const T2 first) { + HWY_ALIGN T lanes[32 / sizeof(T)]; + for (size_t i = 0; i < 32 / sizeof(T); ++i) { + lanes[i] = static_cast(first + static_cast(i)); + } + return Load(d, lanes); +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { + const Full256 d; + const Full256 d8; + const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; + // Prevent sign-extension of 32-bit masks because the intrinsic returns int. + return static_cast(_mm256_movemask_epi8(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_ARCH_X86_64 + const uint64_t sign_bits8 = BitsFromMask(hwy::SizeTag<1>(), mask); + // Skip the bits from the lower byte of each u16 (better not to use the + // same packs_epi16 as SSE4, because that requires an extra swizzle here). + return _pext_u64(sign_bits8, 0xAAAAAAAAull); +#else + // Slow workaround for 32-bit builds, which lack _pext_u64. + // Remove useless lower half of each u16 while preserving the sign bit. + // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. + const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); + // Move odd qwords (value zero) to top so they don't affect the mask value. + const auto compressed = + _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0)); + return static_cast(_mm256_movemask_epi8(compressed)); + +#endif +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { + const Full256 d; + const Full256 df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_ps(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + const Full256 d; + const Full256 df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_pd(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(const Mask256 mask) { + return BitsFromMask(hwy::SizeTag(), mask); +} + +} // namespace detail + +template +HWY_INLINE size_t StoreMaskBits(const Mask256 mask, uint8_t* p) { + const uint64_t bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (4 + sizeof(T) - 1) / sizeof(T); + CopyBytes(&bits, p); + return kNumBytes; +} + +template +HWY_API bool AllFalse(const Mask256 mask) { + // Cheaper than PTEST, which is 2 uop / 3L. 
+ return detail::BitsFromMask(mask) == 0; +} + +template +HWY_API bool AllTrue(const Mask256 mask) { + constexpr uint64_t kAllBits = (1ull << (32 / sizeof(T))) - 1; + return detail::BitsFromMask(mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(const Mask256 mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +// ------------------------------ Compress + +namespace detail { + +HWY_INLINE Vec256 Idx32x8FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Full256 d32; + + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) constexpr uint32_t packed_array[256] = { + 0x00000000, 0x00000000, 0x00000001, 0x00000010, 0x00000002, 0x00000020, + 0x00000021, 0x00000210, 0x00000003, 0x00000030, 0x00000031, 0x00000310, + 0x00000032, 0x00000320, 0x00000321, 0x00003210, 0x00000004, 0x00000040, + 0x00000041, 0x00000410, 0x00000042, 0x00000420, 0x00000421, 0x00004210, + 0x00000043, 0x00000430, 0x00000431, 0x00004310, 0x00000432, 0x00004320, + 0x00004321, 0x00043210, 0x00000005, 0x00000050, 0x00000051, 0x00000510, + 0x00000052, 0x00000520, 0x00000521, 0x00005210, 0x00000053, 0x00000530, + 0x00000531, 0x00005310, 0x00000532, 0x00005320, 0x00005321, 0x00053210, + 0x00000054, 0x00000540, 0x00000541, 0x00005410, 0x00000542, 0x00005420, + 0x00005421, 0x00054210, 0x00000543, 0x00005430, 0x00005431, 0x00054310, + 0x00005432, 0x00054320, 0x00054321, 0x00543210, 0x00000006, 0x00000060, + 0x00000061, 0x00000610, 0x00000062, 0x00000620, 0x00000621, 0x00006210, + 0x00000063, 0x00000630, 0x00000631, 0x00006310, 0x00000632, 0x00006320, + 0x00006321, 0x00063210, 0x00000064, 0x00000640, 0x00000641, 0x00006410, + 0x00000642, 0x00006420, 0x00006421, 0x00064210, 0x00000643, 0x00006430, + 0x00006431, 0x00064310, 0x00006432, 0x00064320, 0x00064321, 0x00643210, + 0x00000065, 0x00000650, 0x00000651, 0x00006510, 0x00000652, 0x00006520, + 0x00006521, 0x00065210, 0x00000653, 0x00006530, 0x00006531, 0x00065310, + 0x00006532, 0x00065320, 0x00065321, 0x00653210, 0x00000654, 0x00006540, + 0x00006541, 0x00065410, 0x00006542, 0x00065420, 0x00065421, 0x00654210, + 0x00006543, 0x00065430, 0x00065431, 0x00654310, 0x00065432, 0x00654320, + 0x00654321, 0x06543210, 0x00000007, 0x00000070, 0x00000071, 0x00000710, + 0x00000072, 0x00000720, 0x00000721, 0x00007210, 0x00000073, 0x00000730, + 0x00000731, 0x00007310, 0x00000732, 0x00007320, 0x00007321, 0x00073210, + 0x00000074, 0x00000740, 0x00000741, 0x00007410, 0x00000742, 0x00007420, + 0x00007421, 0x00074210, 0x00000743, 0x00007430, 0x00007431, 0x00074310, + 0x00007432, 0x00074320, 0x00074321, 0x00743210, 0x00000075, 0x00000750, + 0x00000751, 0x00007510, 0x00000752, 0x00007520, 0x00007521, 0x00075210, + 0x00000753, 0x00007530, 0x00007531, 0x00075310, 0x00007532, 0x00075320, + 0x00075321, 0x00753210, 0x00000754, 0x00007540, 0x00007541, 0x00075410, + 0x00007542, 0x00075420, 0x00075421, 0x00754210, 0x00007543, 0x00075430, + 0x00075431, 0x00754310, 0x00075432, 0x00754320, 0x00754321, 0x07543210, + 0x00000076, 0x00000760, 0x00000761, 0x00007610, 0x00000762, 0x00007620, + 0x00007621, 0x00076210, 0x00000763, 0x00007630, 0x00007631, 0x00076310, + 0x00007632, 0x00076320, 0x00076321, 0x00763210, 0x00000764, 0x00007640, + 0x00007641, 0x00076410, 0x00007642, 0x00076420, 
0x00076421, 0x00764210, + 0x00007643, 0x00076430, 0x00076431, 0x00764310, 0x00076432, 0x00764320, + 0x00764321, 0x07643210, 0x00000765, 0x00007650, 0x00007651, 0x00076510, + 0x00007652, 0x00076520, 0x00076521, 0x00765210, 0x00007653, 0x00076530, + 0x00076531, 0x00765310, 0x00076532, 0x00765320, 0x00765321, 0x07653210, + 0x00007654, 0x00076540, 0x00076541, 0x00765410, 0x00076542, 0x00765420, + 0x00765421, 0x07654210, 0x00076543, 0x00765430, 0x00765431, 0x07654310, + 0x00765432, 0x07654320, 0x07654321, 0x76543210}; + + // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. + const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +HWY_INLINE Vec256 Idx64x4FromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + const Full256 d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) constexpr uint32_t packed_array[16 * 8] = { + 0, 1, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 0, 1, 0, 1, 0, 1, // + 2, 3, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 2, 3, 0, 1, 0, 1, // + 4, 5, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 4, 5, 0, 1, 0, 1, // + 2, 3, 4, 5, 0, 1, 0, 1, /**/ 0, 1, 2, 3, 4, 5, 0, 1, // + 6, 7, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 6, 7, 0, 1, 0, 1, // + 2, 3, 6, 7, 0, 1, 0, 1, /**/ 0, 1, 2, 3, 6, 7, 0, 1, // + 4, 5, 6, 7, 0, 1, 0, 1, /**/ 0, 1, 4, 5, 6, 7, 0, 1, + 2, 3, 4, 5, 6, 7, 0, 1, /**/ 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(d32, packed_array + 8 * mask_bits); +} + +// Helper functions called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template +HWY_API Vec256 Compress(hwy::SizeTag<4> /*tag*/, Vec256 v, + const uint64_t mask_bits) { + const auto vu = BitCast(Full256(), v); +#if HWY_TARGET == HWY_AVX3 + const __m256i ret = + _mm256_maskz_compress_epi32(static_cast<__mmask8>(mask_bits), vu.raw); +#else + const Vec256 idx = detail::Idx32x8FromBits(mask_bits); + const __m256i ret = _mm256_permutevar8x32_epi32(vu.raw, idx.raw); +#endif + return BitCast(Full256(), Vec256{ret}); +} + +template +HWY_API Vec256 Compress(hwy::SizeTag<8> /*tag*/, Vec256 v, + const uint64_t mask_bits) { + const auto vu = BitCast(Full256(), v); +#if HWY_TARGET == HWY_AVX3 + const __m256i ret = + _mm256_maskz_compress_epi64(static_cast<__mmask8>(mask_bits), vu.raw); +#else + const Vec256 idx = detail::Idx64x4FromBits(mask_bits); + const __m256i ret = _mm256_permutevar8x32_epi32(vu.raw, idx.raw); +#endif + return BitCast(Full256(), Vec256{ret}); +} + +// Otherwise, defined in x86_512-inl.h so it can use wider vectors. +#if HWY_TARGET != HWY_AVX3 + +// LUTs are infeasible for 2^16 possible masks. Promoting to 32-bit and using +// the native Compress is probably more efficient than 2 LUTs. 
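+
+// Illustrative sketch (editorial addition, not part of upstream Highway):
+// the 32-bit path above is driven by a plain 64-bit mask value, one bit per
+// lane, obtained from BitsFromMask (defined earlier in this namespace):
+HWY_API Vec256<uint32_t> ExampleCompress(Vec256<uint32_t> v,
+                                         Mask256<uint32_t> m) {
+  return Compress(hwy::SizeTag<4>(), v, BitsFromMask(m));
+}
+// The promote/compress/demote sequence for 16-bit lanes follows.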
+template +HWY_API Vec256 Compress(hwy::SizeTag<2> /*tag*/, Vec256 v, + const uint64_t mask_bits) { + using D = Full256; + const Rebind du; + const Repartition dw; + const auto vu16 = BitCast(du, v); // (required for float16_t inputs) + const auto promoted0 = PromoteTo(dw, LowerHalf(vu16)); + const auto promoted1 = PromoteTo(dw, UpperHalf(vu16)); + + const uint64_t mask_bits0 = mask_bits & 0xFF; + const uint64_t mask_bits1 = mask_bits >> 8; + const auto compressed0 = Compress(hwy::SizeTag<4>(), promoted0, mask_bits0); + const auto compressed1 = Compress(hwy::SizeTag<4>(), promoted1, mask_bits1); + + const Half dh; + const auto demoted0 = ZeroExtendVector(DemoteTo(dh, compressed0)); + const auto demoted1 = ZeroExtendVector(DemoteTo(dh, compressed1)); + + const size_t count0 = PopCount(mask_bits0); + // Now combine by shifting demoted1 up. AVX2 lacks VPERMW, so start with + // VPERMD for shifting at 4 byte granularity. + alignas(32) constexpr int32_t iota4[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7}; + const auto indices = SetTableIndices(dw, iota4 + 8 - count0 / 2); + const auto shift1_multiple4 = + BitCast(du, TableLookupLanes(BitCast(dw, demoted1), indices)); + + // Whole-register unconditional shift by 2 bytes. + // TODO(janwas): slow on AMD, use 2 shifts + permq + OR instead? + const __m256i lo_zz = _mm256_permute2x128_si256(shift1_multiple4.raw, + shift1_multiple4.raw, 0x08); + const auto shift1_multiple2 = + Vec256{_mm256_alignr_epi8(shift1_multiple4.raw, lo_zz, 14)}; + + // Make the shift conditional on the lower bit of count0. + const auto m_odd = TestBit(Set(du, count0), Set(du, 1)); + const auto shifted1 = IfThenElse(m_odd, shift1_multiple2, shift1_multiple4); + + // Blend the lower and shifted upper parts. + constexpr uint16_t on = 0xFFFF; + alignas(32) constexpr uint16_t lower_lanes[32] = {HWY_REP4(on), HWY_REP4(on), + HWY_REP4(on), HWY_REP4(on)}; + const auto m_lower = MaskFromVec(LoadU(du, lower_lanes + 16 - count0)); + return BitCast(D(), IfThenElse(m_lower, demoted0, shifted1)); +} + +#endif // HWY_TARGET != HWY_AVX3 + +} // namespace detail + +// Otherwise, defined in x86_512-inl.h after detail::Compress. +#if HWY_TARGET != HWY_AVX3 + +template +HWY_API Vec256 Compress(Vec256 v, const Mask256 mask) { + return detail::Compress(hwy::SizeTag(), v, + detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressStore + +template +HWY_API size_t CompressStore(Vec256 v, const Mask256 mask, Full256 d, + T* HWY_RESTRICT aligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + // NOTE: it is tempting to split inputs into two halves for 16-bit lanes, but + // using StoreU to concatenate the results would cause page faults if + // `aligned` is the last valid vector. Instead rely on in-register splicing. + Store(detail::Compress(hwy::SizeTag(), v, mask_bits), d, aligned); + return PopCount(mask_bits); +} + +#endif // HWY_TARGET != HWY_AVX3 + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes, ConcatUpperLower) + +HWY_API void StoreInterleaved3(const Vec256 v0, + const Vec256 v1, + const Vec256 v2, Full256 d, + uint8_t* HWY_RESTRICT unaligned) { + const auto k5 = Set(d, 5); + const auto k6 = Set(d, 6); + + // Shuffle (v0,v1,v2) vector bytes to (MSB on left): r5, bgr[4:0]. + // 0x80 so lanes to be filled from other vectors are 0 for blending. 
+  alignas(16) static constexpr uint8_t tbl_r0[16] = {
+      0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80,  //
+      3, 0x80, 0x80, 4, 0x80, 0x80, 5};
+  alignas(16) static constexpr uint8_t tbl_g0[16] = {
+      0x80, 0, 0x80, 0x80, 1, 0x80,  //
+      0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80};
+  const auto shuf_r0 = LoadDup128(d, tbl_r0);
+  const auto shuf_g0 = LoadDup128(d, tbl_g0);  // cannot reuse r0 due to 5
+  const auto shuf_b0 = CombineShiftRightBytes<15>(shuf_g0, shuf_g0);
+  const auto r0 = TableLookupBytes(v0, shuf_r0);  // 5..4..3..2..1..0
+  const auto g0 = TableLookupBytes(v1, shuf_g0);  // ..4..3..2..1..0.
+  const auto b0 = TableLookupBytes(v2, shuf_b0);  // .4..3..2..1..0..
+  const auto interleaved_10_00 = r0 | g0 | b0;
+
+  // Second vector: g10,r10, bgr[9:6], b5,g5
+  const auto shuf_r1 = shuf_b0 + k6;  // .A..9..8..7..6..
+  const auto shuf_g1 = shuf_r0 + k5;  // A..9..8..7..6..5
+  const auto shuf_b1 = shuf_g0 + k5;  // ..9..8..7..6..5.
+  const auto r1 = TableLookupBytes(v0, shuf_r1);
+  const auto g1 = TableLookupBytes(v1, shuf_g1);
+  const auto b1 = TableLookupBytes(v2, shuf_b1);
+  const auto interleaved_15_05 = r1 | g1 | b1;
+
+  // We want to write the lower halves of the interleaved vectors, then the
+  // upper halves. We could obtain 10_05 and 15_0A via ConcatUpperLower, but
+  // that would require two unaligned stores. For the lower halves, we can
+  // merge two 128-bit stores for the same swizzling cost:
+  const auto out0 = ConcatLowerLower(interleaved_15_05, interleaved_10_00);
+  StoreU(out0, d, unaligned + 0 * 32);
+
+  // Third vector: bgr[15:11], b10
+  const auto shuf_r2 = shuf_b1 + k6;  // ..F..E..D..C..B.
+  const auto shuf_g2 = shuf_r1 + k5;  // .F..E..D..C..B..
+  const auto shuf_b2 = shuf_g1 + k5;  // F..E..D..C..B..A
+  const auto r2 = TableLookupBytes(v0, shuf_r2);
+  const auto g2 = TableLookupBytes(v1, shuf_g2);
+  const auto b2 = TableLookupBytes(v2, shuf_b2);
+  const auto interleaved_1A_0A = r2 | g2 | b2;
+
+  const auto out1 = ConcatUpperLower(interleaved_10_00, interleaved_1A_0A);
+  StoreU(out1, d, unaligned + 1 * 32);
+
+  const auto out2 = ConcatUpperUpper(interleaved_1A_0A, interleaved_15_05);
+  StoreU(out2, d, unaligned + 2 * 32);
+}
+
+// ------------------------------ StoreInterleaved4
+
+HWY_API void StoreInterleaved4(const Vec256<uint8_t> v0,
+                               const Vec256<uint8_t> v1,
+                               const Vec256<uint8_t> v2,
+                               const Vec256<uint8_t> v3, Full256<uint8_t> d,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  // let a,b,c,d denote v0..3.
+  const auto ba0 = ZipLower(v0, v1);  // b7 a7 .. b0 a0
+  const auto dc0 = ZipLower(v2, v3);  // d7 c7 .. d0 c0
+  const auto ba8 = ZipUpper(v0, v1);
+  const auto dc8 = ZipUpper(v2, v3);
+  const auto dcba_0 = ZipLower(ba0, dc0);  // d..a13 d..a10 | d..a03 d..a00
+  const auto dcba_4 = ZipUpper(ba0, dc0);  // d..a17 d..a14 | d..a07 d..a04
+  const auto dcba_8 = ZipLower(ba8, dc8);  // d..a1B d..a18 | d..a0B d..a08
+  const auto dcba_C = ZipUpper(ba8, dc8);  // d..a1F d..a1C | d..a0F d..a0C
+  // Write lower halves, then upper.
vperm2i128 is slow on Zen1 but we can + // efficiently combine two lower halves into 256 bits: + const auto out0 = BitCast(d, ConcatLowerLower(dcba_4, dcba_0)); + const auto out1 = BitCast(d, ConcatLowerLower(dcba_C, dcba_8)); + StoreU(out0, d, unaligned + 0 * 32); + StoreU(out1, d, unaligned + 1 * 32); + const auto out2 = BitCast(d, ConcatUpperUpper(dcba_4, dcba_0)); + const auto out3 = BitCast(d, ConcatUpperUpper(dcba_C, dcba_8)); + StoreU(out2, d, unaligned + 2 * 32); + StoreU(out3, d, unaligned + 3 * 32); +} + +// ------------------------------ Reductions + +namespace detail { + +// Returns sum{lane[i]} in each lane. "v3210" is a replicated 128-bit block. +// Same logic as x86/128.h, but with Vec256 arguments. +template +HWY_API Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = v3210 + v1032; + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template +HWY_API Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Min(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template +HWY_API Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Max(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +template +HWY_API Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return v10 + v01; +} +template +HWY_API Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return Min(v10, v01); +} +template +HWY_API Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { + const auto v01 = Shuffle01(v10); + return Max(v10, v01); +} + +} // namespace detail + +// Supported for {uif}32x8, {uif}64x4. Returns the sum in each lane. +template +HWY_API Vec256 SumOfLanes(const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(vHL, vHL); + return detail::SumOfLanes(hwy::SizeTag(), vLH + vHL); +} +template +HWY_API Vec256 MinOfLanes(const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(vHL, vHL); + return detail::MinOfLanes(hwy::SizeTag(), Min(vLH, vHL)); +} +template +HWY_API Vec256 MaxOfLanes(const Vec256 vHL) { + const Vec256 vLH = ConcatLowerUpper(vHL, vHL); + return detail::MaxOfLanes(hwy::SizeTag(), Max(vLH, vHL)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/x86_512-inl.h b/third_party/highway/hwy/ops/x86_512-inl.h new file mode 100644 index 000000000000..9f1595e59ee6 --- /dev/null +++ b/third_party/highway/hwy/ops/x86_512-inl.h @@ -0,0 +1,3050 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+// 512-bit AVX512 vectors and operations.
+// External include guard in highway.h - see comment there.
+
+// WARNING: most operations do not cross 128-bit block boundaries. In
+// particular, "Broadcast", pack and zip behavior may be surprising.
+
+#include <immintrin.h>  // AVX2+
+#if defined(_MSC_VER) && defined(__clang__)
+// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
+// including these headers when _MSC_VER is defined, like when using clang-cl.
+// Include these directly here.
+#include <smmintrin.h>
+#include <avxintrin.h>
+#include <avx2intrin.h>
+#include <f16cintrin.h>
+#include <fmaintrin.h>
+#include <avx512fintrin.h>
+#include <avx512vlintrin.h>
+#include <avx512bwintrin.h>
+#include <avx512dqintrin.h>
+#include <avx512vlbwintrin.h>
+#include <avx512vldqintrin.h>
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+// For half-width vectors. Already includes base.h and shared-inl.h.
+#include "hwy/ops/x86_256-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+template <typename T>
+struct Raw512 {
+  using type = __m512i;
+};
+template <>
+struct Raw512<float> {
+  using type = __m512;
+};
+template <>
+struct Raw512<double> {
+  using type = __m512d;
+};
+
+template <typename T>
+using Full512 = Simd<T, 64 / sizeof(T)>;
+
+template <typename T>
+class Vec512 {
+  using Raw = typename Raw512<T>::type;
+
+ public:
+  // Compound assignment. Only usable if there is a corresponding non-member
+  // binary operator overload. For example, only f32 and f64 support division.
+  HWY_INLINE Vec512& operator*=(const Vec512 other) {
+    return *this = (*this * other);
+  }
+  HWY_INLINE Vec512& operator/=(const Vec512 other) {
+    return *this = (*this / other);
+  }
+  HWY_INLINE Vec512& operator+=(const Vec512 other) {
+    return *this = (*this + other);
+  }
+  HWY_INLINE Vec512& operator-=(const Vec512 other) {
+    return *this = (*this - other);
+  }
+  HWY_INLINE Vec512& operator&=(const Vec512 other) {
+    return *this = (*this & other);
+  }
+  HWY_INLINE Vec512& operator|=(const Vec512 other) {
+    return *this = (*this | other);
+  }
+  HWY_INLINE Vec512& operator^=(const Vec512 other) {
+    return *this = (*this ^ other);
+  }
+
+  Raw raw;
+};
+
+// Template arg: sizeof(lane type)
+template <size_t size>
+struct RawMask512 {};
+template <>
+struct RawMask512<1> {
+  using type = __mmask64;
+};
+template <>
+struct RawMask512<2> {
+  using type = __mmask32;
+};
+template <>
+struct RawMask512<4> {
+  using type = __mmask16;
+};
+template <>
+struct RawMask512<8> {
+  using type = __mmask8;
+};
+
+// Mask register: one bit per lane.
+template <typename T>
+class Mask512 {
+  using Raw = typename RawMask512<sizeof(T)>::type;
+
+ public:
+  Raw raw;
+};
+
+// ------------------------------ BitCast
+
+namespace detail {
+
+HWY_API __m512i BitCastToInteger(__m512i v) { return v; }
+HWY_API __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); }
+HWY_API __m512i BitCastToInteger(__m512d v) { return _mm512_castpd_si512(v); }
+
+template <typename T>
+HWY_API Vec512<uint8_t> BitCastToByte(Vec512<T> v) {
+  return Vec512<uint8_t>{BitCastToInteger(v.raw)};
+}
+
+// Cannot rely on function overloading because return types differ.
+template <typename T>
+struct BitCastFromInteger512 {
+  HWY_INLINE __m512i operator()(__m512i v) { return v; }
+};
+template <>
+struct BitCastFromInteger512<float> {
+  HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); }
+};
+template <>
+struct BitCastFromInteger512<double> {
+  HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); }
+};
+
+template <typename T>
+HWY_API Vec512<T> BitCastFromByte(Full512<T> /* tag */, Vec512<uint8_t> v) {
+  return Vec512<T>{BitCastFromInteger512<T>()(v.raw)};
+}
+
+}  // namespace detail
+
+template <typename T, typename FromT>
+HWY_API Vec512<T> BitCast(Full512<T> d, Vec512<FromT> v) {
+  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
+}
+
+// ------------------------------ Set
+
+// Returns an all-zero vector.
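+
+// Illustrative sketch (editorial addition, not part of upstream Highway):
+// BitCast reinterprets lane bits without any value conversion, e.g. viewing
+// the IEEE-754 bit patterns of f32 lanes as u32:
+HWY_API Vec512<uint32_t> ExampleFloatBits(const Vec512<float> v) {
+  return BitCast(Full512<uint32_t>(), v);
+}
+// Zero, Set and Undefined follow, starting with the all-zero vector noted
+// above.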
+template +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_si512()}; +} +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_ps()}; +} +HWY_API Vec512 Zero(Full512 /* tag */) { + return Vec512{_mm512_setzero_pd()}; +} + +// Returns a vector with all lanes set to "t". +HWY_API Vec512 Set(Full512 /* tag */, const uint8_t t) { + return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const uint16_t t) { + return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const uint32_t t) { + return Vec512{_mm512_set1_epi32(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const uint64_t t) { + return Vec512{ + _mm512_set1_epi64(static_cast(t))}; // NOLINT +} +HWY_API Vec512 Set(Full512 /* tag */, const int8_t t) { + return Vec512{_mm512_set1_epi8(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const int16_t t) { + return Vec512{_mm512_set1_epi16(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const int32_t t) { + return Vec512{_mm512_set1_epi32(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const int64_t t) { + return Vec512{_mm512_set1_epi64(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const float t) { + return Vec512{_mm512_set1_ps(t)}; +} +HWY_API Vec512 Set(Full512 /* tag */, const double t) { + return Vec512{_mm512_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec512 Undefined(Full512 /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec512{_mm512_undefined_epi32()}; +} +HWY_API Vec512 Undefined(Full512 /* tag */) { + return Vec512{_mm512_undefined_ps()}; +} +HWY_API Vec512 Undefined(Full512 /* tag */) { + return Vec512{_mm512_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec512 Not(const Vec512 v) { + using TU = MakeUnsigned; + const __m512i vu = BitCast(Full512(), v).raw; + return BitCast(Full512(), + Vec512{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); +} + +// ------------------------------ And + +template +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_ps(a.raw, b.raw)}; +} +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
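+
+// Illustrative sketch (editorial addition, not part of upstream Highway):
+// clearing the sign bit with And, building the mask via Set and BitCast:
+HWY_API Vec512<float> ExampleClearSign(const Vec512<float> v) {
+  const Full512<uint32_t> du;
+  const Full512<float> df;
+  return And(v, BitCast(df, Set(du, 0x7FFFFFFFu)));  // single VANDPS
+}
+// AndNot, computing ~not_mask & mask as noted above, follows.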
+template +HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { + return Vec512{_mm512_andnot_si512(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_si512(a.raw, b.raw)}; +} + +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec512 operator&(const Vec512 a, const Vec512 b) { + return And(a, b); +} + +template +HWY_API Vec512 operator|(const Vec512 a, const Vec512 b) { + return Or(a, b); +} + +template +HWY_API Vec512 operator^(const Vec512 a, const Vec512 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template +HWY_API Vec512 CopySign(const Vec512 magn, const Vec512 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + + const Full512 d; + const auto msb = SignBit(d); + + const Rebind, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m512i out = _mm512_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +} + +template +HWY_API Vec512 CopySignToAbs(const Vec512 abs, const Vec512 sign) { + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ------------------------------ Select/blend + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. 
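+
+// Illustrative note (editorial addition, not part of upstream Highway):
+// AVX-512 masks are bits in dedicated k-registers rather than vectors, so a
+// Mask512<int32_t> wraps a 16-bit __mmask16, e.g. selecting the even lanes:
+HWY_API Mask512<int32_t> ExampleEvenLanes() {
+  return Mask512<int32_t>{static_cast<__mmask16>(0x5555)};
+}
+// The size-dispatched IfThenElse overloads follow.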
+
+// ------------------------------ Select/blend
+
+// Returns mask ? b : a.
+
+namespace detail {
+
+// Templates for signed/unsigned integer of a particular size.
+template <typename T>
+HWY_API Vec512<T> IfThenElse(hwy::SizeTag<1> /* tag */, const Mask512<T> mask,
+                             const Vec512<T> yes, const Vec512<T> no) {
+  return Vec512<T>{_mm512_mask_mov_epi8(no.raw, mask.raw, yes.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenElse(hwy::SizeTag<2> /* tag */, const Mask512<T> mask,
+                             const Vec512<T> yes, const Vec512<T> no) {
+  return Vec512<T>{_mm512_mask_mov_epi16(no.raw, mask.raw, yes.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenElse(hwy::SizeTag<4> /* tag */, const Mask512<T> mask,
+                             const Vec512<T> yes, const Vec512<T> no) {
+  return Vec512<T>{_mm512_mask_mov_epi32(no.raw, mask.raw, yes.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenElse(hwy::SizeTag<8> /* tag */, const Mask512<T> mask,
+                             const Vec512<T> yes, const Vec512<T> no) {
+  return Vec512<T>{_mm512_mask_mov_epi64(no.raw, mask.raw, yes.raw)};
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API Vec512<T> IfThenElse(const Mask512<T> mask, const Vec512<T> yes,
+                             const Vec512<T> no) {
+  return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no);
+}
+template <>
+HWY_INLINE Vec512<float> IfThenElse(const Mask512<float> mask,
+                                    const Vec512<float> yes,
+                                    const Vec512<float> no) {
+  return Vec512<float>{_mm512_mask_mov_ps(no.raw, mask.raw, yes.raw)};
+}
+template <>
+HWY_INLINE Vec512<double> IfThenElse(const Mask512<double> mask,
+                                     const Vec512<double> yes,
+                                     const Vec512<double> no) {
+  return Vec512<double>{_mm512_mask_mov_pd(no.raw, mask.raw, yes.raw)};
+}
+
+namespace detail {
+
+template <typename T>
+HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<1> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> yes) {
+  return Vec512<T>{_mm512_maskz_mov_epi8(mask.raw, yes.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<2> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> yes) {
+  return Vec512<T>{_mm512_maskz_mov_epi16(mask.raw, yes.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<4> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> yes) {
+  return Vec512<T>{_mm512_maskz_mov_epi32(mask.raw, yes.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenElseZero(hwy::SizeTag<8> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> yes) {
+  return Vec512<T>{_mm512_maskz_mov_epi64(mask.raw, yes.raw)};
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API Vec512<T> IfThenElseZero(const Mask512<T> mask, const Vec512<T> yes) {
+  return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes);
+}
+template <>
+HWY_INLINE Vec512<float> IfThenElseZero(const Mask512<float> mask,
+                                        const Vec512<float> yes) {
+  return Vec512<float>{_mm512_maskz_mov_ps(mask.raw, yes.raw)};
+}
+template <>
+HWY_INLINE Vec512<double> IfThenElseZero(const Mask512<double> mask,
+                                         const Vec512<double> yes) {
+  return Vec512<double>{_mm512_maskz_mov_pd(mask.raw, yes.raw)};
+}
+
+namespace detail {
+
+template <typename T>
+HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> no) {
+  // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16.
+  return Vec512<T>{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> no) {
+  return Vec512<T>{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> no) {
+  return Vec512<T>{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)};
+}
+template <typename T>
+HWY_API Vec512<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */,
+                                 const Mask512<T> mask, const Vec512<T> no) {
+  return Vec512<T>{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)};
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API Vec512<T> IfThenZeroElse(const Mask512<T> mask, const Vec512<T> no) {
+  return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no);
+}
+template <>
+HWY_INLINE Vec512<float> IfThenZeroElse(const Mask512<float> mask,
+                                        const Vec512<float> no) {
+  return Vec512<float>{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)};
+}
+template <>
+HWY_INLINE Vec512<double> IfThenZeroElse(const Mask512<double> mask,
+                                         const Vec512<double> no) {
+  return Vec512<double>{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)};
+}
+
+template <typename T>
+HWY_API Vec512<T> ZeroIfNegative(const Vec512<T> v) {
+  // AVX3 MaskFromVec only looks at the MSB
+  return IfThenZeroElse(MaskFromVec(v), v);
+}
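+
+// Editor's note: illustrative sketch, not upstream Highway code. The masked
+// sub/xor forms above zero a lane by combining `no` with itself, so no
+// separate zero constant is needed. The hypothetical helper below is the
+// equivalent two-input formulation, for exposition only.
+namespace example {
+HWY_INLINE Vec512<int32_t> RefIfThenZeroElse(const Mask512<int32_t> mask,
+                                             const Vec512<int32_t> no) {
+  // Same lanes as IfThenZeroElse(mask, no), but needs an extra zero register.
+  return IfThenElse(mask, Zero(Full512<int32_t>()), no);
+}
+}  // namespace example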
+
+// ================================================== ARITHMETIC
+
+// ------------------------------ Addition
+
+// Unsigned
+HWY_API Vec512<uint8_t> operator+(const Vec512<uint8_t> a,
+                                  const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_add_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> operator+(const Vec512<uint16_t> a,
+                                   const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_add_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> operator+(const Vec512<uint32_t> a,
+                                   const Vec512<uint32_t> b) {
+  return Vec512<uint32_t>{_mm512_add_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> operator+(const Vec512<uint64_t> a,
+                                   const Vec512<uint64_t> b) {
+  return Vec512<uint64_t>{_mm512_add_epi64(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec512<int8_t> operator+(const Vec512<int8_t> a,
+                                 const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_add_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> operator+(const Vec512<int16_t> a,
+                                  const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_add_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> operator+(const Vec512<int32_t> a,
+                                  const Vec512<int32_t> b) {
+  return Vec512<int32_t>{_mm512_add_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> operator+(const Vec512<int64_t> a,
+                                  const Vec512<int64_t> b) {
+  return Vec512<int64_t>{_mm512_add_epi64(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Vec512<float> operator+(const Vec512<float> a, const Vec512<float> b) {
+  return Vec512<float>{_mm512_add_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> operator+(const Vec512<double> a,
+                                 const Vec512<double> b) {
+  return Vec512<double>{_mm512_add_pd(a.raw, b.raw)};
+}
+
+// ------------------------------ Subtraction
+
+// Unsigned
+HWY_API Vec512<uint8_t> operator-(const Vec512<uint8_t> a,
+                                  const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_sub_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> operator-(const Vec512<uint16_t> a,
+                                   const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_sub_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> operator-(const Vec512<uint32_t> a,
+                                   const Vec512<uint32_t> b) {
+  return Vec512<uint32_t>{_mm512_sub_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> operator-(const Vec512<uint64_t> a,
+                                   const Vec512<uint64_t> b) {
+  return Vec512<uint64_t>{_mm512_sub_epi64(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec512<int8_t> operator-(const Vec512<int8_t> a,
+                                 const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_sub_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> operator-(const Vec512<int16_t> a,
+                                  const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_sub_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> operator-(const Vec512<int32_t> a,
+                                  const Vec512<int32_t> b) {
+  return Vec512<int32_t>{_mm512_sub_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> operator-(const Vec512<int64_t> a,
+                                  const Vec512<int64_t> b) {
+  return Vec512<int64_t>{_mm512_sub_epi64(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Vec512<float> operator-(const Vec512<float> a, const Vec512<float> b) {
+  return Vec512<float>{_mm512_sub_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> operator-(const Vec512<double> a,
+                                 const Vec512<double> b) {
+  return Vec512<double>{_mm512_sub_pd(a.raw, b.raw)};
+}
+
+// ------------------------------ Saturating addition
+
+// Returns a + b clamped to the destination range.
+
+// Unsigned
+HWY_API Vec512<uint8_t> SaturatedAdd(const Vec512<uint8_t> a,
+                                     const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_adds_epu8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> SaturatedAdd(const Vec512<uint16_t> a,
+                                      const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_adds_epu16(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec512<int8_t> SaturatedAdd(const Vec512<int8_t> a,
+                                    const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_adds_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> SaturatedAdd(const Vec512<int16_t> a,
+                                     const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_adds_epi16(a.raw, b.raw)};
+}
+
+// ------------------------------ Saturating subtraction
+
+// Returns a - b clamped to the destination range.
+
+// Unsigned
+HWY_API Vec512<uint8_t> SaturatedSub(const Vec512<uint8_t> a,
+                                     const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_subs_epu8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> SaturatedSub(const Vec512<uint16_t> a,
+                                      const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_subs_epu16(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec512<int8_t> SaturatedSub(const Vec512<int8_t> a,
+                                    const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_subs_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> SaturatedSub(const Vec512<int16_t> a,
+                                     const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_subs_epi16(a.raw, b.raw)};
+}
+
+// ------------------------------ Average
+
+// Returns (a + b + 1) / 2
+
+// Unsigned
+HWY_API Vec512<uint8_t> AverageRound(const Vec512<uint8_t> a,
+                                     const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_avg_epu8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> AverageRound(const Vec512<uint16_t> a,
+                                      const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_avg_epu16(a.raw, b.raw)};
+}
+
+// ------------------------------ Absolute value
+
+// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
+HWY_API Vec512<int8_t> Abs(const Vec512<int8_t> v) {
+  return Vec512<int8_t>{_mm512_abs_epi8(v.raw)};
+}
+HWY_API Vec512<int16_t> Abs(const Vec512<int16_t> v) {
+  return Vec512<int16_t>{_mm512_abs_epi16(v.raw)};
+}
+HWY_API Vec512<int32_t> Abs(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_abs_epi32(v.raw)};
+}
+
+// These aren't native instructions, they also involve AND with constant.
+HWY_API Vec512<float> Abs(const Vec512<float> v) {
+  return Vec512<float>{_mm512_abs_ps(v.raw)};
+}
+HWY_API Vec512<double> Abs(const Vec512<double> v) {
+  return Vec512<double>{_mm512_abs_pd(v.raw)};
+}
+
+// ------------------------------ ShiftLeft
+
+template <int kBits>
+HWY_API Vec512<uint16_t> ShiftLeft(const Vec512<uint16_t> v) {
+  return Vec512<uint16_t>{_mm512_slli_epi16(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<uint32_t> ShiftLeft(const Vec512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_slli_epi32(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<uint64_t> ShiftLeft(const Vec512<uint64_t> v) {
+  return Vec512<uint64_t>{_mm512_slli_epi64(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<int16_t> ShiftLeft(const Vec512<int16_t> v) {
+  return Vec512<int16_t>{_mm512_slli_epi16(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<int32_t> ShiftLeft(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_slli_epi32(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<int64_t> ShiftLeft(const Vec512<int64_t> v) {
+  return Vec512<int64_t>{_mm512_slli_epi64(v.raw, kBits)};
+}
+
+template <int kBits, typename T, HWY_IF_LANE_SIZE(T, 1)>
+HWY_API Vec512<T> ShiftLeft(const Vec512<T> v) {
+  const Full512<T> d8;
+  const RepartitionToWide<decltype(d8)> d16;
+  const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v)));
+  return kBits == 1
+             ? (v + v)
+             : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF)));
+}
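+
+// Editor's note: illustrative, not upstream Highway code. The 8-bit shift
+// above reuses the 16-bit shifter, then masks away bits that crossed in from
+// the neighboring byte; after <<3, each byte keeps (0xFF << 3) & 0xFF = 0xF8.
+static_assert(((0xFF << 3) & 0xFF) == 0xF8,
+              "mask keeps the high 5 bits of each byte for kBits=3");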
+
+// ------------------------------ ShiftRight
+
+template <int kBits>
+HWY_API Vec512<uint16_t> ShiftRight(const Vec512<uint16_t> v) {
+  return Vec512<uint16_t>{_mm512_srli_epi16(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<uint32_t> ShiftRight(const Vec512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_srli_epi32(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<uint64_t> ShiftRight(const Vec512<uint64_t> v) {
+  return Vec512<uint64_t>{_mm512_srli_epi64(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<uint8_t> ShiftRight(const Vec512<uint8_t> v) {
+  const Full512<uint8_t> d8;
+  // Use raw instead of BitCast to support N=1.
+  const Vec512<uint8_t> shifted{
+      ShiftRight<kBits>(Vec512<uint16_t>{v.raw}).raw};
+  return shifted & Set(d8, 0xFF >> kBits);
+}
+
+template <int kBits>
+HWY_API Vec512<int16_t> ShiftRight(const Vec512<int16_t> v) {
+  return Vec512<int16_t>{_mm512_srai_epi16(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<int32_t> ShiftRight(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_srai_epi32(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<int64_t> ShiftRight(const Vec512<int64_t> v) {
+  return Vec512<int64_t>{_mm512_srai_epi64(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec512<int8_t> ShiftRight(const Vec512<int8_t> v) {
+  const Full512<int8_t> di;
+  const Full512<uint8_t> du;
+  const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
+  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits));
+  return (shifted ^ shifted_sign) - shifted_sign;
+}
+
+// ------------------------------ ShiftLeftSame
+
+HWY_API Vec512<uint16_t> ShiftLeftSame(const Vec512<uint16_t> v,
+                                       const int bits) {
+  return Vec512<uint16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
+}
+HWY_API Vec512<uint32_t> ShiftLeftSame(const Vec512<uint32_t> v,
+                                       const int bits) {
+  return Vec512<uint32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
+}
+HWY_API Vec512<uint64_t> ShiftLeftSame(const Vec512<uint64_t> v,
+                                       const int bits) {
+  return Vec512<uint64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
+}
+
+HWY_API Vec512<int16_t> ShiftLeftSame(const Vec512<int16_t> v,
+                                      const int bits) {
+  return Vec512<int16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
+}
+
+HWY_API Vec512<int32_t> ShiftLeftSame(const Vec512<int32_t> v,
+                                      const int bits) {
+  return Vec512<int32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
+}
+
+HWY_API Vec512<int64_t> ShiftLeftSame(const Vec512<int64_t> v,
+                                      const int bits) {
+  return Vec512<int64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
+}
+
+template <typename T, HWY_IF_LANE_SIZE(T, 1)>
+HWY_API Vec512<T> ShiftLeftSame(const Vec512<T> v, const int bits) {
+  const Full512<T> d8;
+  const RepartitionToWide<decltype(d8)> d16;
+  const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits));
+  return shifted & Set(d8, (0xFF << bits) & 0xFF);
+}
+
+// ------------------------------ ShiftRightSame
+
+HWY_API Vec512<uint16_t> ShiftRightSame(const Vec512<uint16_t> v,
+                                        const int bits) {
+  return Vec512<uint16_t>{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))};
+}
+HWY_API Vec512<uint32_t> ShiftRightSame(const Vec512<uint32_t> v,
+                                        const int bits) {
+  return Vec512<uint32_t>{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))};
+}
+HWY_API Vec512<uint64_t> ShiftRightSame(const Vec512<uint64_t> v,
+                                        const int bits) {
+  return Vec512<uint64_t>{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))};
+}
+
+HWY_API Vec512<uint8_t> ShiftRightSame(Vec512<uint8_t> v, const int bits) {
+  const Full512<uint8_t> d8;
+  const RepartitionToWide<decltype(d8)> d16;
+  const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits));
+  return shifted & Set(d8, 0xFF >> bits);
+}
+
+HWY_API Vec512<int16_t> ShiftRightSame(const Vec512<int16_t> v,
+                                       const int bits) {
+  return Vec512<int16_t>{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))};
+}
+
+HWY_API Vec512<int32_t> ShiftRightSame(const Vec512<int32_t> v,
+                                       const int bits) {
+  return Vec512<int32_t>{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))};
+}
+HWY_API Vec512<int64_t> ShiftRightSame(const Vec512<int64_t> v,
+                                       const int bits) {
+  return Vec512<int64_t>{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))};
+}
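+
+// Editor's note: illustrative, not upstream Highway code. The i8 shifts here
+// use the identity "arithmetic shift = logical shift, then (x ^ s) - s",
+// where s holds the shifted-down sign bit: subtracting s re-extends the sign
+// into the cleared high bits. For -128 >> 2, the logical shift gives 0x20 and
+// ((0x20 ^ 0x20) - 0x20) mod 256 = 0xE0 = -32, the arithmetic result.
+static_assert(((((0x80u >> 2) ^ 0x20u) - 0x20u) & 0xFFu) == 0xE0u,
+              "sign-extension identity for int8");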
+
+HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
+  const Full512<int8_t> di;
+  const Full512<uint8_t> du;
+  const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits));
+  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits));
+  return (shifted ^ shifted_sign) - shifted_sign;
+}
+
+// ------------------------------ Shl
+
+HWY_API Vec512<uint16_t> operator<<(const Vec512<uint16_t> v,
+                                    const Vec512<uint16_t> bits) {
+  return Vec512<uint16_t>{_mm512_sllv_epi16(v.raw, bits.raw)};
+}
+
+HWY_API Vec512<uint32_t> operator<<(const Vec512<uint32_t> v,
+                                    const Vec512<uint32_t> bits) {
+  return Vec512<uint32_t>{_mm512_sllv_epi32(v.raw, bits.raw)};
+}
+
+HWY_API Vec512<uint64_t> operator<<(const Vec512<uint64_t> v,
+                                    const Vec512<uint64_t> bits) {
+  return Vec512<uint64_t>{_mm512_sllv_epi64(v.raw, bits.raw)};
+}
+
+// Signed left shift is the same as unsigned.
+template <typename T, HWY_IF_SIGNED(T)>
+HWY_API Vec512<T> operator<<(const Vec512<T> v, const Vec512<T> bits) {
+  const Full512<T> di;
+  const Full512<MakeUnsigned<T>> du;
+  return BitCast(di, BitCast(du, v) << BitCast(du, bits));
+}
+
+// ------------------------------ Shr
+
+HWY_API Vec512<uint16_t> operator>>(const Vec512<uint16_t> v,
+                                    const Vec512<uint16_t> bits) {
+  return Vec512<uint16_t>{_mm512_srlv_epi16(v.raw, bits.raw)};
+}
+
+HWY_API Vec512<uint32_t> operator>>(const Vec512<uint32_t> v,
+                                    const Vec512<uint32_t> bits) {
+  return Vec512<uint32_t>{_mm512_srlv_epi32(v.raw, bits.raw)};
+}
+
+HWY_API Vec512<uint64_t> operator>>(const Vec512<uint64_t> v,
+                                    const Vec512<uint64_t> bits) {
+  return Vec512<uint64_t>{_mm512_srlv_epi64(v.raw, bits.raw)};
+}
+
+HWY_API Vec512<int16_t> operator>>(const Vec512<int16_t> v,
+                                   const Vec512<int16_t> bits) {
+  return Vec512<int16_t>{_mm512_srav_epi16(v.raw, bits.raw)};
+}
+
+HWY_API Vec512<int32_t> operator>>(const Vec512<int32_t> v,
+                                   const Vec512<int32_t> bits) {
+  return Vec512<int32_t>{_mm512_srav_epi32(v.raw, bits.raw)};
+}
+
+HWY_API Vec512<int64_t> operator>>(const Vec512<int64_t> v,
+                                   const Vec512<int64_t> bits) {
+  return Vec512<int64_t>{_mm512_srav_epi64(v.raw, bits.raw)};
+}
+
+// ------------------------------ Minimum
+
+// Unsigned
+HWY_API Vec512<uint8_t> Min(const Vec512<uint8_t> a, const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_min_epu8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> Min(const Vec512<uint16_t> a,
+                             const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_min_epu16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> Min(const Vec512<uint32_t> a,
+                             const Vec512<uint32_t> b) {
+  return Vec512<uint32_t>{_mm512_min_epu32(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> Min(const Vec512<uint64_t> a,
+                             const Vec512<uint64_t> b) {
+  return Vec512<uint64_t>{_mm512_min_epu64(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec512<int8_t> Min(const Vec512<int8_t> a, const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_min_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> Min(const Vec512<int16_t> a, const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_min_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> Min(const Vec512<int32_t> a, const Vec512<int32_t> b) {
+  return Vec512<int32_t>{_mm512_min_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> Min(const Vec512<int64_t> a, const Vec512<int64_t> b) {
+  return Vec512<int64_t>{_mm512_min_epi64(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Vec512<float> Min(const Vec512<float> a, const Vec512<float> b) {
+  return Vec512<float>{_mm512_min_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> Min(const Vec512<double> a, const Vec512<double> b) {
+  return Vec512<double>{_mm512_min_pd(a.raw, b.raw)};
+}
+
+// ------------------------------ Maximum
+
+// Unsigned
+HWY_API Vec512<uint8_t> Max(const Vec512<uint8_t> a, const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_max_epu8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> Max(const Vec512<uint16_t> a,
+                             const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_max_epu16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> Max(const Vec512<uint32_t> a,
+                             const Vec512<uint32_t> b) {
+  return Vec512<uint32_t>{_mm512_max_epu32(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> Max(const Vec512<uint64_t> a,
+                             const Vec512<uint64_t> b) {
+  return Vec512<uint64_t>{_mm512_max_epu64(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec512<int8_t> Max(const Vec512<int8_t> a, const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_max_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> Max(const Vec512<int16_t> a, const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_max_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> Max(const Vec512<int32_t> a, const Vec512<int32_t> b) {
+  return Vec512<int32_t>{_mm512_max_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> Max(const Vec512<int64_t> a, const Vec512<int64_t> b) {
+  return Vec512<int64_t>{_mm512_max_epi64(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Vec512<float> Max(const Vec512<float> a, const Vec512<float> b) {
+  return Vec512<float>{_mm512_max_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> Max(const Vec512<double> a, const Vec512<double> b) {
+  return Vec512<double>{_mm512_max_pd(a.raw, b.raw)};
+}
+
+// ------------------------------ Integer multiplication
+
+// Unsigned
+HWY_API Vec512<uint16_t> operator*(const Vec512<uint16_t> a,
+                                   const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> operator*(const Vec512<uint32_t> a,
+                                   const Vec512<uint32_t> b) {
+  return Vec512<uint32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec512<int16_t> operator*(const Vec512<int16_t> a,
+                                  const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> operator*(const Vec512<int32_t> a,
+                                  const Vec512<int32_t> b) {
+  return Vec512<int32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
+}
+
+// Returns the upper 16 bits of a * b in each lane.
+HWY_API Vec512<uint16_t> MulHigh(const Vec512<uint16_t> a,
+                                 const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_mulhi_epu16(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> MulHigh(const Vec512<int16_t> a,
+                                const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_mulhi_epi16(a.raw, b.raw)};
+}
+
+// Multiplies even lanes (0, 2 ..) and places the double-wide result into
+// even and the upper half into its odd neighbor lane.
+HWY_API Vec512<int64_t> MulEven(const Vec512<int32_t> a,
+                                const Vec512<int32_t> b) {
+  return Vec512<int64_t>{_mm512_mul_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> MulEven(const Vec512<uint32_t> a,
+                                 const Vec512<uint32_t> b) {
+  return Vec512<uint64_t>{_mm512_mul_epu32(a.raw, b.raw)};
+}
+
+// ------------------------------ Negate
+
+template <typename T, HWY_IF_FLOAT(T)>
+HWY_API Vec512<T> Neg(const Vec512<T> v) {
+  return Xor(v, SignBit(Full512<T>()));
+}
+
+template <typename T, HWY_IF_NOT_FLOAT(T)>
+HWY_API Vec512<T> Neg(const Vec512<T> v) {
+  return Zero(Full512<T>()) - v;
+}
+
+// ------------------------------ Floating-point mul / div
+
+HWY_API Vec512<float> operator*(const Vec512<float> a, const Vec512<float> b) {
+  return Vec512<float>{_mm512_mul_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> operator*(const Vec512<double> a,
+                                 const Vec512<double> b) {
+  return Vec512<double>{_mm512_mul_pd(a.raw, b.raw)};
+}
+
+HWY_API Vec512<float> operator/(const Vec512<float> a, const Vec512<float> b) {
+  return Vec512<float>{_mm512_div_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> operator/(const Vec512<double> a,
+                                 const Vec512<double> b) {
+  return Vec512<double>{_mm512_div_pd(a.raw, b.raw)};
+}
+
+// Approximate reciprocal
+HWY_API Vec512<float> ApproximateReciprocal(const Vec512<float> v) {
+  return Vec512<float>{_mm512_rcp14_ps(v.raw)};
+}
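+
+// Editor's note: illustrative, not upstream Highway code. Floating-point Neg
+// above is a pure sign-bit XOR with no arithmetic or NaN handling: 1.0f
+// (0x3F800000) becomes -1.0f (0xBF800000).
+static_assert((0x3F800000u ^ 0x80000000u) == 0xBF800000u,
+              "negation flips only the sign bit");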
+
+// Absolute value of difference.
+HWY_API Vec512<float> AbsDiff(const Vec512<float> a, const Vec512<float> b) {
+  return Abs(a - b);
+}
+
+// ------------------------------ Floating-point multiply-add variants
+
+// Returns mul * x + add
+HWY_API Vec512<float> MulAdd(const Vec512<float> mul, const Vec512<float> x,
+                             const Vec512<float> add) {
+  return Vec512<float>{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)};
+}
+HWY_API Vec512<double> MulAdd(const Vec512<double> mul, const Vec512<double> x,
+                              const Vec512<double> add) {
+  return Vec512<double>{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)};
+}
+
+// Returns add - mul * x
+HWY_API Vec512<float> NegMulAdd(const Vec512<float> mul, const Vec512<float> x,
+                                const Vec512<float> add) {
+  return Vec512<float>{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)};
+}
+HWY_API Vec512<double> NegMulAdd(const Vec512<double> mul,
+                                 const Vec512<double> x,
+                                 const Vec512<double> add) {
+  return Vec512<double>{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)};
+}
+
+// Returns mul * x - sub
+HWY_API Vec512<float> MulSub(const Vec512<float> mul, const Vec512<float> x,
+                             const Vec512<float> sub) {
+  return Vec512<float>{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)};
+}
+HWY_API Vec512<double> MulSub(const Vec512<double> mul, const Vec512<double> x,
+                              const Vec512<double> sub) {
+  return Vec512<double>{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)};
+}
+
+// Returns -mul * x - sub
+HWY_API Vec512<float> NegMulSub(const Vec512<float> mul, const Vec512<float> x,
+                                const Vec512<float> sub) {
+  return Vec512<float>{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)};
+}
+HWY_API Vec512<double> NegMulSub(const Vec512<double> mul,
+                                 const Vec512<double> x,
+                                 const Vec512<double> sub) {
+  return Vec512<double>{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)};
+}
+
+// ------------------------------ Floating-point square root
+
+// Full precision square root
+HWY_API Vec512<float> Sqrt(const Vec512<float> v) {
+  return Vec512<float>{_mm512_sqrt_ps(v.raw)};
+}
+HWY_API Vec512<double> Sqrt(const Vec512<double> v) {
+  return Vec512<double>{_mm512_sqrt_pd(v.raw)};
+}
+
+// Approximate reciprocal square root
+HWY_API Vec512<float> ApproximateReciprocalSqrt(const Vec512<float> v) {
+  return Vec512<float>{_mm512_rsqrt14_ps(v.raw)};
+}
+
+// ------------------------------ Floating-point rounding
+
+// Toward nearest integer, tie to even
+HWY_API Vec512<float> Round(const Vec512<float> v) {
+  return Vec512<float>{_mm512_roundscale_ps(
+      v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec512<double> Round(const Vec512<double> v) {
+  return Vec512<double>{_mm512_roundscale_pd(
+      v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)};
+}
+
+// Toward zero, aka truncate
+HWY_API Vec512<float> Trunc(const Vec512<float> v) {
+  return Vec512<float>{
+      _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec512<double> Trunc(const Vec512<double> v) {
+  return Vec512<double>{
+      _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)};
+}
+
+// Toward +infinity, aka ceiling
+HWY_API Vec512<float> Ceil(const Vec512<float> v) {
+  return Vec512<float>{
+      _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec512<double> Ceil(const Vec512<double> v) {
+  return Vec512<double>{
+      _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)};
+}
+
+// Toward -infinity, aka floor
+HWY_API Vec512<float> Floor(const Vec512<float> v) {
+  return Vec512<float>{
+      _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
+}
+HWY_API Vec512<double> Floor(const Vec512<double> v) {
+  return Vec512<double>{
+      _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)};
+}
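+
+// Editor's usage sketch (illustrative, not upstream Highway code): Round ties
+// to even, unlike std::round, which ties away from zero.
+//   const Full512<float> d;
+//   Round(Set(d, 2.5f));   // 2.0f in all lanes (tie to even)
+//   Round(Set(d, 3.5f));   // 4.0f
+//   Trunc(Set(d, -2.7f));  // -2.0f (toward zero)
+//   Floor(Set(d, -2.7f));  // -3.0f (toward -inf)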
+
+// ================================================== COMPARE
+
+// Comparisons set a mask bit to 1 if the condition is true, else 0.
+
+template <typename TFrom, typename TTo>
+HWY_API Mask512<TTo> RebindMask(Full512<TTo> /*tag*/, Mask512<TFrom> m) {
+  static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size");
+  return Mask512<TTo>{m.raw};
+}
+
+namespace detail {
+
+template <typename T>
+HWY_API Mask512<T> TestBit(hwy::SizeTag<1> /*tag*/, const Vec512<T> v,
+                           const Vec512<T> bit) {
+  return Mask512<T>{_mm512_test_epi8_mask(v.raw, bit.raw)};
+}
+template <typename T>
+HWY_API Mask512<T> TestBit(hwy::SizeTag<2> /*tag*/, const Vec512<T> v,
+                           const Vec512<T> bit) {
+  return Mask512<T>{_mm512_test_epi16_mask(v.raw, bit.raw)};
+}
+template <typename T>
+HWY_API Mask512<T> TestBit(hwy::SizeTag<4> /*tag*/, const Vec512<T> v,
+                           const Vec512<T> bit) {
+  return Mask512<T>{_mm512_test_epi32_mask(v.raw, bit.raw)};
+}
+template <typename T>
+HWY_API Mask512<T> TestBit(hwy::SizeTag<8> /*tag*/, const Vec512<T> v,
+                           const Vec512<T> bit) {
+  return Mask512<T>{_mm512_test_epi64_mask(v.raw, bit.raw)};
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API Mask512<T> TestBit(const Vec512<T> v, const Vec512<T> bit) {
+  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
+  return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit);
+}
+
+// ------------------------------ Equality
+
+// Unsigned
+HWY_API Mask512<uint8_t> operator==(const Vec512<uint8_t> a,
+                                    const Vec512<uint8_t> b) {
+  return Mask512<uint8_t>{_mm512_cmpeq_epi8_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<uint16_t> operator==(const Vec512<uint16_t> a,
+                                     const Vec512<uint16_t> b) {
+  return Mask512<uint16_t>{_mm512_cmpeq_epi16_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<uint32_t> operator==(const Vec512<uint32_t> a,
+                                     const Vec512<uint32_t> b) {
+  return Mask512<uint32_t>{_mm512_cmpeq_epi32_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<uint64_t> operator==(const Vec512<uint64_t> a,
+                                     const Vec512<uint64_t> b) {
+  return Mask512<uint64_t>{_mm512_cmpeq_epi64_mask(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Mask512<int8_t> operator==(const Vec512<int8_t> a,
+                                   const Vec512<int8_t> b) {
+  return Mask512<int8_t>{_mm512_cmpeq_epi8_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<int16_t> operator==(const Vec512<int16_t> a,
+                                    const Vec512<int16_t> b) {
+  return Mask512<int16_t>{_mm512_cmpeq_epi16_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<int32_t> operator==(const Vec512<int32_t> a,
+                                    const Vec512<int32_t> b) {
+  return Mask512<int32_t>{_mm512_cmpeq_epi32_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<int64_t> operator==(const Vec512<int64_t> a,
+                                    const Vec512<int64_t> b) {
+  return Mask512<int64_t>{_mm512_cmpeq_epi64_mask(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Mask512<float> operator==(const Vec512<float> a,
+                                  const Vec512<float> b) {
+  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)};
+}
+HWY_API Mask512<double> operator==(const Vec512<double> a,
+                                   const Vec512<double> b) {
+  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)};
+}
+
+// ------------------------------ Strict inequality
+
+// Signed/float <
+HWY_API Mask512<int8_t> operator<(const Vec512<int8_t> a,
+                                  const Vec512<int8_t> b) {
+  return Mask512<int8_t>{_mm512_cmpgt_epi8_mask(b.raw, a.raw)};
+}
+HWY_API Mask512<int16_t> operator<(const Vec512<int16_t> a,
+                                   const Vec512<int16_t> b) {
+  return Mask512<int16_t>{_mm512_cmpgt_epi16_mask(b.raw, a.raw)};
+}
+HWY_API Mask512<int32_t> operator<(const Vec512<int32_t> a,
+                                   const Vec512<int32_t> b) {
+  return Mask512<int32_t>{_mm512_cmpgt_epi32_mask(b.raw, a.raw)};
+}
+HWY_API Mask512<int64_t> operator<(const Vec512<int64_t> a,
+                                   const Vec512<int64_t> b) {
+  return Mask512<int64_t>{_mm512_cmpgt_epi64_mask(b.raw, a.raw)};
+}
+HWY_API Mask512<float> operator<(const Vec512<float> a,
+                                 const Vec512<float> b) {
+  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_LT_OQ)};
+}
+HWY_API Mask512<double> operator<(const Vec512<double> a,
+                                  const Vec512<double> b) {
+  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_LT_OQ)};
+}
+
+// Signed/float >
+HWY_API Mask512<int8_t> operator>(const Vec512<int8_t> a,
+                                  const Vec512<int8_t> b) {
+  return Mask512<int8_t>{_mm512_cmpgt_epi8_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<int16_t> operator>(const Vec512<int16_t> a,
+                                   const Vec512<int16_t> b) {
+  return Mask512<int16_t>{_mm512_cmpgt_epi16_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<int32_t> operator>(const Vec512<int32_t> a,
+                                   const Vec512<int32_t> b) {
+  return Mask512<int32_t>{_mm512_cmpgt_epi32_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<int64_t> operator>(const Vec512<int64_t> a,
+                                   const Vec512<int64_t> b) {
+  return Mask512<int64_t>{_mm512_cmpgt_epi64_mask(a.raw, b.raw)};
+}
+HWY_API Mask512<float> operator>(const Vec512<float> a,
+                                 const Vec512<float> b) {
+  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)};
+}
+HWY_API Mask512<double> operator>(const Vec512<double> a,
+                                  const Vec512<double> b) {
+  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)};
+}
+
+// ------------------------------ Weak inequality
+
+// Float <= >=
+HWY_API Mask512<float> operator<=(const Vec512<float> a,
+                                  const Vec512<float> b) {
+  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_LE_OQ)};
+}
+HWY_API Mask512<double> operator<=(const Vec512<double> a,
+                                   const Vec512<double> b) {
+  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_LE_OQ)};
+}
+HWY_API Mask512<float> operator>=(const Vec512<float> a,
+                                  const Vec512<float> b) {
+  return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)};
+}
+HWY_API Mask512<double> operator>=(const Vec512<double> a,
+                                   const Vec512<double> b) {
+  return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)};
+}
+
+// ------------------------------ Mask
+
+namespace detail {
+
+template <typename T>
+HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec512<T> v) {
+  return Mask512<T>{_mm512_movepi8_mask(v.raw)};
+}
+template <typename T>
+HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec512<T> v) {
+  return Mask512<T>{_mm512_movepi16_mask(v.raw)};
+}
+template <typename T>
+HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec512<T> v) {
+  return Mask512<T>{_mm512_movepi32_mask(v.raw)};
+}
+template <typename T>
+HWY_API Mask512<T> MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec512<T> v) {
+  return Mask512<T>{_mm512_movepi64_mask(v.raw)};
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API Mask512<T> MaskFromVec(const Vec512<T> v) {
+  return detail::MaskFromVec(hwy::SizeTag<sizeof(T)>(), v);
+}
+// There do not seem to be native floating-point versions of these
+// instructions.
+HWY_API Mask512<float> MaskFromVec(const Vec512<float> v) {
+  return Mask512<float>{MaskFromVec(BitCast(Full512<int32_t>(), v)).raw};
+}
+HWY_API Mask512<double> MaskFromVec(const Vec512<double> v) {
+  return Mask512<double>{MaskFromVec(BitCast(Full512<int64_t>(), v)).raw};
+}
+
+HWY_API Vec512<uint8_t> VecFromMask(const Mask512<uint8_t> v) {
+  return Vec512<uint8_t>{_mm512_movm_epi8(v.raw)};
+}
+HWY_API Vec512<int8_t> VecFromMask(const Mask512<int8_t> v) {
+  return Vec512<int8_t>{_mm512_movm_epi8(v.raw)};
+}
+
+HWY_API Vec512<uint16_t> VecFromMask(const Mask512<uint16_t> v) {
+  return Vec512<uint16_t>{_mm512_movm_epi16(v.raw)};
+}
+HWY_API Vec512<int16_t> VecFromMask(const Mask512<int16_t> v) {
+  return Vec512<int16_t>{_mm512_movm_epi16(v.raw)};
+}
+
+HWY_API Vec512<uint32_t> VecFromMask(const Mask512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_movm_epi32(v.raw)};
+}
+HWY_API Vec512<int32_t> VecFromMask(const Mask512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_movm_epi32(v.raw)};
+}
+HWY_API Vec512<float> VecFromMask(const Mask512<float> v) {
+  return Vec512<float>{_mm512_castsi512_ps(_mm512_movm_epi32(v.raw))};
+}
+
+HWY_API Vec512<uint64_t> VecFromMask(const Mask512<uint64_t> v) {
+  return Vec512<uint64_t>{_mm512_movm_epi64(v.raw)};
+}
+HWY_API Vec512<int64_t> VecFromMask(const Mask512<int64_t> v) {
+  return Vec512<int64_t>{_mm512_movm_epi64(v.raw)};
+}
+HWY_API Vec512<double> VecFromMask(const Mask512<double> v) {
+  return Vec512<double>{_mm512_castsi512_pd(_mm512_movm_epi64(v.raw))};
+}
+
+template <typename T>
+HWY_API Vec512<T> VecFromMask(Full512<T> /* tag */, const Mask512<T> v) {
+  return VecFromMask(v);
+}
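+
+// Editor's usage sketch (illustrative, not upstream Highway code): Mask512<T>
+// wraps a k-register (__mmask8..__mmask64) holding one bit per lane, not an
+// all-ones vector lane as on SSE4/AVX2. Round-tripping recovers vector form:
+//   const Full512<int32_t> d;
+//   const Mask512<int32_t> m = v > Zero(d);
+//   const Vec512<int32_t> ones = VecFromMask(d, m);  // 0xFFFFFFFF where set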
+
+// ------------------------------ Mask logical
+
+// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently.
+#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) &&         \
+    (HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC >= 700 || \
+     HWY_COMPILER_CLANG >= 800)
+#define HWY_COMPILER_HAS_MASK_INTRINSICS 1
+#else
+#define HWY_COMPILER_HAS_MASK_INTRINSICS 0
+#endif
+
+namespace detail {
+
+template <typename T>
+HWY_API Mask512<T> Not(hwy::SizeTag<1> /*tag*/, const Mask512<T> m) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_knot_mask64(m.raw)};
+#else
+  return Mask512<T>{~m.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Not(hwy::SizeTag<2> /*tag*/, const Mask512<T> m) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_knot_mask32(m.raw)};
+#else
+  return Mask512<T>{~m.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Not(hwy::SizeTag<4> /*tag*/, const Mask512<T> m) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_knot_mask16(m.raw)};
+#else
+  return Mask512<T>{static_cast<uint16_t>(~m.raw & 0xFFFF)};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Not(hwy::SizeTag<8> /*tag*/, const Mask512<T> m) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_knot_mask8(m.raw)};
+#else
+  return Mask512<T>{static_cast<uint8_t>(~m.raw & 0xFF)};
+#endif
+}
+
+template <typename T>
+HWY_API Mask512<T> And(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kand_mask64(a.raw, b.raw)};
+#else
+  return Mask512<T>{a.raw & b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> And(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kand_mask32(a.raw, b.raw)};
+#else
+  return Mask512<T>{a.raw & b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> And(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kand_mask16(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint16_t>(a.raw & b.raw)};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> And(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kand_mask8(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint8_t>(a.raw & b.raw)};
+#endif
+}
+
+template <typename T>
+HWY_API Mask512<T> AndNot(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
+                          const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kandn_mask64(a.raw, b.raw)};
+#else
+  return Mask512<T>{~a.raw & b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> AndNot(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
+                          const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kandn_mask32(a.raw, b.raw)};
+#else
+  return Mask512<T>{~a.raw & b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> AndNot(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
+                          const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kandn_mask16(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint16_t>(~a.raw & b.raw)};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> AndNot(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
+                          const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kandn_mask8(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint8_t>(~a.raw & b.raw)};
+#endif
+}
+
+template <typename T>
+HWY_API Mask512<T> Or(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
+                      const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kor_mask64(a.raw, b.raw)};
+#else
+  return Mask512<T>{a.raw | b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Or(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
+                      const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kor_mask32(a.raw, b.raw)};
+#else
+  return Mask512<T>{a.raw | b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Or(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
+                      const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kor_mask16(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint16_t>(a.raw | b.raw)};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Or(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
+                      const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kor_mask8(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint8_t>(a.raw | b.raw)};
+#endif
+}
+
+template <typename T>
+HWY_API Mask512<T> Xor(hwy::SizeTag<1> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kxor_mask64(a.raw, b.raw)};
+#else
+  return Mask512<T>{a.raw ^ b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Xor(hwy::SizeTag<2> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kxor_mask32(a.raw, b.raw)};
+#else
+  return Mask512<T>{a.raw ^ b.raw};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Xor(hwy::SizeTag<4> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kxor_mask16(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint16_t>(a.raw ^ b.raw)};
+#endif
+}
+template <typename T>
+HWY_API Mask512<T> Xor(hwy::SizeTag<8> /*tag*/, const Mask512<T> a,
+                       const Mask512<T> b) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return Mask512<T>{_kxor_mask8(a.raw, b.raw)};
+#else
+  return Mask512<T>{static_cast<uint8_t>(a.raw ^ b.raw)};
+#endif
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API Mask512<T> Not(const Mask512<T> m) {
+  return detail::Not(hwy::SizeTag<sizeof(T)>(), m);
+}
+
+template <typename T>
+HWY_API Mask512<T> And(const Mask512<T> a, Mask512<T> b) {
+  return detail::And(hwy::SizeTag<sizeof(T)>(), a, b);
+}
+
+template <typename T>
+HWY_API Mask512<T> AndNot(const Mask512<T> a, Mask512<T> b) {
+  return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b);
+}
+
+template <typename T>
+HWY_API Mask512<T> Or(const Mask512<T> a, Mask512<T> b) {
+  return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b);
+}
+
+template <typename T>
+HWY_API Mask512<T> Xor(const Mask512<T> a, Mask512<T> b) {
+  return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b);
+}
+
+// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask)
+
+HWY_API Vec512<int8_t> BroadcastSignBit(const Vec512<int8_t> v) {
+  return VecFromMask(v < Zero(Full512<int8_t>()));
+}
+
+HWY_API Vec512<int16_t> BroadcastSignBit(const Vec512<int16_t> v) {
+  return ShiftRight<15>(v);
+}
+
+HWY_API Vec512<int32_t> BroadcastSignBit(const Vec512<int32_t> v) {
+  return ShiftRight<31>(v);
+}
+
+HWY_API Vec512<int64_t> BroadcastSignBit(const Vec512<int64_t> v) {
+  return Vec512<int64_t>{_mm512_srai_epi64(v.raw, 63)};
+}
+
+// ================================================== MEMORY
+
+// ------------------------------ Load
+
+template <typename T>
+HWY_API Vec512<T> Load(Full512<T> /* tag */, const T* HWY_RESTRICT aligned) {
+  return Vec512<T>{
+      _mm512_load_si512(reinterpret_cast<const __m512i*>(aligned))};
+}
+HWY_API Vec512<float> Load(Full512<float> /* tag */,
+                           const float* HWY_RESTRICT aligned) {
+  return Vec512<float>{_mm512_load_ps(aligned)};
+}
+HWY_API Vec512<double> Load(Full512<double> /* tag */,
+                            const double* HWY_RESTRICT aligned) {
+  return Vec512<double>{_mm512_load_pd(aligned)};
+}
+
+template <typename T>
+HWY_API Vec512<T> LoadU(Full512<T> /* tag */, const T* HWY_RESTRICT p) {
+  return Vec512<T>{_mm512_loadu_si512(reinterpret_cast<const __m512i*>(p))};
+}
+HWY_API Vec512<float> LoadU(Full512<float> /* tag */,
+                            const float* HWY_RESTRICT p) {
+  return Vec512<float>{_mm512_loadu_ps(p)};
+}
+HWY_API Vec512<double> LoadU(Full512<double> /* tag */,
+                             const double* HWY_RESTRICT p) {
+  return Vec512<double>{_mm512_loadu_pd(p)};
+}
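+
+// Editor's usage sketch (illustrative, not upstream Highway code): Load and
+// Store require 64-byte-aligned pointers (e.g. alignas(64) or
+// hwy::AllocateAligned); LoadU/StoreU accept any address.
+//   alignas(64) float buf[32] = {0};
+//   const auto v = Load(Full512<float>(), buf);       // aligned fast path
+//   const auto u = LoadU(Full512<float>(), buf + 1);  // unaligned is OK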
+
+// Loads 128 bit and duplicates into both 128-bit halves. This avoids the
+// 3-cycle cost of moving data between 128-bit halves and avoids port 5.
+template <typename T>
+HWY_API Vec512<T> LoadDup128(Full512<T> /* tag */,
+                             const T* const HWY_RESTRICT p) {
+  // Clang 3.9 generates VINSERTF128 which is slower, but inline assembly leads
+  // to "invalid output size for constraint" without -mavx512:
+  // https://gcc.godbolt.org/z/-Jt_-F
+#if HWY_LOADDUP_ASM
+  __m512i out;
+  asm("vbroadcasti128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0]));
+  return Vec512<T>{out};
+#else
+  const auto x4 = LoadU(Full128<T>(), p);
+  return Vec512<T>{_mm512_broadcast_i32x4(x4.raw)};
+#endif
+}
+HWY_API Vec512<float> LoadDup128(Full512<float> /* tag */,
+                                 const float* const HWY_RESTRICT p) {
+#if HWY_LOADDUP_ASM
+  __m512 out;
+  asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0]));
+  return Vec512<float>{out};
+#else
+  const __m128 x4 = _mm_loadu_ps(p);
+  return Vec512<float>{_mm512_broadcast_f32x4(x4)};
+#endif
+}
+
+HWY_API Vec512<double> LoadDup128(Full512<double> /* tag */,
+                                  const double* const HWY_RESTRICT p) {
+#if HWY_LOADDUP_ASM
+  __m512d out;
+  asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0]));
+  return Vec512<double>{out};
+#else
+  const __m128d x2 = _mm_loadu_pd(p);
+  return Vec512<double>{_mm512_broadcast_f64x2(x2)};
+#endif
+}
+
+// ------------------------------ Store
+
+template <typename T>
+HWY_API void Store(const Vec512<T> v, Full512<T> /* tag */,
+                   T* HWY_RESTRICT aligned) {
+  _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw);
+}
+HWY_API void Store(const Vec512<float> v, Full512<float> /* tag */,
+                   float* HWY_RESTRICT aligned) {
+  _mm512_store_ps(aligned, v.raw);
+}
+HWY_API void Store(const Vec512<double> v, Full512<double> /* tag */,
+                   double* HWY_RESTRICT aligned) {
+  _mm512_store_pd(aligned, v.raw);
+}
+
+template <typename T>
+HWY_API void StoreU(const Vec512<T> v, Full512<T> /* tag */,
+                    T* HWY_RESTRICT p) {
+  _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw);
+}
+HWY_API void StoreU(const Vec512<float> v, Full512<float> /* tag */,
+                    float* HWY_RESTRICT p) {
+  _mm512_storeu_ps(p, v.raw);
+}
+HWY_API void StoreU(const Vec512<double> v, Full512<double>,
+                    double* HWY_RESTRICT p) {
+  _mm512_storeu_pd(p, v.raw);
+}
+
+// ------------------------------ Non-temporal stores
+
+template <typename T>
+HWY_API void Stream(const Vec512<T> v, Full512<T> /* tag */,
+                    T* HWY_RESTRICT aligned) {
+  _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), v.raw);
+}
+HWY_API void Stream(const Vec512<float> v, Full512<float> /* tag */,
+                    float* HWY_RESTRICT aligned) {
+  _mm512_stream_ps(aligned, v.raw);
+}
+HWY_API void Stream(const Vec512<double> v, Full512<double>,
+                    double* HWY_RESTRICT aligned) {
+  _mm512_stream_pd(aligned, v.raw);
+}
+
+// ------------------------------ Scatter
+
+namespace detail {
+
+template <typename T>
+HWY_API void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec512<T> v,
+                           Full512<T> /* tag */, T* HWY_RESTRICT base,
+                           const Vec512<int32_t> offset) {
+  _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1);
+}
+template <typename T>
+HWY_API void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec512<T> v,
+                          Full512<T> /* tag */, T* HWY_RESTRICT base,
+                          const Vec512<int32_t> index) {
+  _mm512_i32scatter_epi32(base, index.raw, v.raw, 4);
+}
+
+template <typename T>
+HWY_API void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec512<T> v,
+                           Full512<T> /* tag */, T* HWY_RESTRICT base,
+                           const Vec512<int64_t> offset) {
+  _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1);
+}
+template <typename T>
+HWY_API void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec512<T> v,
+                          Full512<T> /* tag */, T* HWY_RESTRICT base,
+                          const Vec512<int64_t> index) {
+  _mm512_i64scatter_epi64(base, index.raw, v.raw, 8);
+}
+
+}  // namespace detail
+
+template <typename T, typename Offset>
+HWY_API void ScatterOffset(Vec512<T> v, Full512<T> d, T* HWY_RESTRICT base,
+                           const Vec512<Offset> offset) {
+  static_assert(sizeof(T) == sizeof(Offset), "Must match for portability");
+  return detail::ScatterOffset(hwy::SizeTag<sizeof(T)>(), v, d, base, offset);
+}
+template <typename T, typename Index>
+HWY_API void ScatterIndex(Vec512<T> v, Full512<T> d, T* HWY_RESTRICT base,
+                          const Vec512<Index> index) {
+  static_assert(sizeof(T) == sizeof(Index), "Must match for portability");
+  return detail::ScatterIndex(hwy::SizeTag<sizeof(T)>(), v, d, base, index);
+}
+
+template <>
+HWY_INLINE void ScatterOffset<float>(Vec512<float> v, Full512<float> /* tag */,
+                                     float* HWY_RESTRICT base,
+                                     const Vec512<int32_t> offset) {
+  _mm512_i32scatter_ps(base, offset.raw, v.raw, 1);
+}
+template <>
+HWY_INLINE void ScatterIndex<float>(Vec512<float> v, Full512<float> /* tag */,
+                                    float* HWY_RESTRICT base,
+                                    const Vec512<int32_t> index) {
+  _mm512_i32scatter_ps(base, index.raw, v.raw, 4);
+}
+
+template <>
+HWY_INLINE void ScatterOffset<double>(Vec512<double> v,
+                                      Full512<double> /* tag */,
+                                      double* HWY_RESTRICT base,
+                                      const Vec512<int64_t> offset) {
+  _mm512_i64scatter_pd(base, offset.raw, v.raw, 1);
+}
+template <>
+HWY_INLINE void ScatterIndex<double>(Vec512<double> v,
+                                     Full512<double> /* tag */,
+                                     double* HWY_RESTRICT base,
+                                     const Vec512<int64_t> index) {
+  _mm512_i64scatter_pd(base, index.raw, v.raw, 8);
+}
+
+// ------------------------------ Gather
+
+namespace detail {
+
+template <typename T>
+HWY_API Vec512<T> GatherOffset(hwy::SizeTag<4> /* tag */, Full512<T> /* tag */,
+                               const T* HWY_RESTRICT base,
+                               const Vec512<int32_t> offset) {
+  return Vec512<T>{_mm512_i32gather_epi32(offset.raw, base, 1)};
+}
+template <typename T>
+HWY_API Vec512<T> GatherIndex(hwy::SizeTag<4> /* tag */, Full512<T> /* tag */,
+                              const T* HWY_RESTRICT base,
+                              const Vec512<int32_t> index) {
+  return Vec512<T>{_mm512_i32gather_epi32(index.raw, base, 4)};
+}
+
+template <typename T>
+HWY_API Vec512<T> GatherOffset(hwy::SizeTag<8> /* tag */, Full512<T> /* tag */,
+                               const T* HWY_RESTRICT base,
+                               const Vec512<int64_t> offset) {
+  return Vec512<T>{_mm512_i64gather_epi64(offset.raw, base, 1)};
+}
+template <typename T>
+HWY_API Vec512<T> GatherIndex(hwy::SizeTag<8> /* tag */, Full512<T> /* tag */,
+                              const T* HWY_RESTRICT base,
+                              const Vec512<int64_t> index) {
+  return Vec512<T>{_mm512_i64gather_epi64(index.raw, base, 8)};
+}
+
+}  // namespace detail
+
+template <typename T, typename Offset>
+HWY_API Vec512<T> GatherOffset(Full512<T> d, const T* HWY_RESTRICT base,
+                               const Vec512<Offset> offset) {
+  static_assert(sizeof(T) == sizeof(Offset), "Must match for portability");
+  return detail::GatherOffset(hwy::SizeTag<sizeof(T)>(), d, base, offset);
+}
+template <typename T, typename Index>
+HWY_API Vec512<T> GatherIndex(Full512<T> d, const T* HWY_RESTRICT base,
+                              const Vec512<Index> index) {
+  static_assert(sizeof(T) == sizeof(Index), "Must match for portability");
+  return detail::GatherIndex(hwy::SizeTag<sizeof(T)>(), d, base, index);
+}
+
+template <>
+HWY_INLINE Vec512<float> GatherOffset<float>(Full512<float> /* tag */,
+                                             const float* HWY_RESTRICT base,
+                                             const Vec512<int32_t> offset) {
+  return Vec512<float>{_mm512_i32gather_ps(offset.raw, base, 1)};
+}
+template <>
+HWY_INLINE Vec512<float> GatherIndex<float>(Full512<float> /* tag */,
+                                            const float* HWY_RESTRICT base,
+                                            const Vec512<int32_t> index) {
+  return Vec512<float>{_mm512_i32gather_ps(index.raw, base, 4)};
+}
+
+template <>
+HWY_INLINE Vec512<double> GatherOffset<double>(Full512<double> /* tag */,
+                                               const double* HWY_RESTRICT base,
+                                               const Vec512<int64_t> offset) {
+  return Vec512<double>{_mm512_i64gather_pd(offset.raw, base, 1)};
+}
+template <>
+HWY_INLINE Vec512<double> GatherIndex<double>(Full512<double> /* tag */,
+                                              const double* HWY_RESTRICT base,
+                                              const Vec512<int64_t> index) {
+  return Vec512<double>{_mm512_i64gather_pd(index.raw, base, 8)};
+}
+
+// ================================================== SWIZZLE
+
+template <typename T>
+HWY_API T GetLane(const Vec512<T> v) {
+  return GetLane(LowerHalf(v));
+}
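+
+// Editor's note: illustrative, not upstream Highway code. GatherOffset scales
+// indices by 1 (byte offsets), GatherIndex by sizeof(T) (lane indices), so
+// for 32-bit lanes GatherIndex(d, base, i) reads the same memory as
+// GatherOffset(d, base, ShiftLeft<2>(i)).
+static_assert(sizeof(float) == 4 && sizeof(double) == 8,
+              "scales used by the gather/scatter specializations above");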
+
+// ------------------------------ Extract half
+
+template <typename T>
+HWY_API Vec256<T> LowerHalf(Vec512<T> v) {
+  return Vec256<T>{_mm512_castsi512_si256(v.raw)};
+}
+template <>
+HWY_INLINE Vec256<float> LowerHalf(Vec512<float> v) {
+  return Vec256<float>{_mm512_castps512_ps256(v.raw)};
+}
+template <>
+HWY_INLINE Vec256<double> LowerHalf(Vec512<double> v) {
+  return Vec256<double>{_mm512_castpd512_pd256(v.raw)};
+}
+
+template <typename T>
+HWY_API Vec256<T> UpperHalf(Vec512<T> v) {
+  return Vec256<T>{_mm512_extracti32x8_epi32(v.raw, 1)};
+}
+template <>
+HWY_INLINE Vec256<float> UpperHalf(Vec512<float> v) {
+  return Vec256<float>{_mm512_extractf32x8_ps(v.raw, 1)};
+}
+template <>
+HWY_INLINE Vec256<double> UpperHalf(Vec512<double> v) {
+  return Vec256<double>{_mm512_extractf64x4_pd(v.raw, 1)};
+}
+
+// ------------------------------ ZeroExtendVector
+
+// Unfortunately the initial _mm512_castsi256_si512 intrinsic leaves the upper
+// bits undefined. Although it makes sense for them to be zero (EVEX encoded
+// instructions have that effect), a compiler could decide to optimize out code
+// that relies on this.
+//
+// The newer _mm512_zextsi256_si512 intrinsic fixes this by specifying the
+// zeroing, but it is not available on GCC until 10.1. For older GCC, we can
+// still obtain the desired code thanks to pattern recognition; note that the
+// expensive insert instruction is not actually generated, see
+// https://gcc.godbolt.org/z/1MKGaP.
+
+template <typename T>
+HWY_API Vec512<T> ZeroExtendVector(Vec256<T> lo) {
+#if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000)
+  return Vec512<T>{_mm512_inserti32x8(_mm512_setzero_si512(), lo.raw, 0)};
+#else
+  return Vec512<T>{_mm512_zextsi256_si512(lo.raw)};
+#endif
+}
+template <>
+HWY_INLINE Vec512<float> ZeroExtendVector(Vec256<float> lo) {
+#if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000)
+  return Vec512<float>{_mm512_insertf32x8(_mm512_setzero_ps(), lo.raw, 0)};
+#else
+  return Vec512<float>{_mm512_zextps256_ps512(lo.raw)};
+#endif
+}
+template <>
+HWY_INLINE Vec512<double> ZeroExtendVector(Vec256<double> lo) {
+#if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000)
+  return Vec512<double>{_mm512_insertf64x4(_mm512_setzero_pd(), lo.raw, 0)};
+#else
+  return Vec512<double>{_mm512_zextpd256_pd512(lo.raw)};
+#endif
+}
+
+// ------------------------------ Combine
+
+template <typename T>
+HWY_API Vec512<T> Combine(Vec256<T> hi, Vec256<T> lo) {
+  const auto lo512 = ZeroExtendVector(lo);
+  return Vec512<T>{_mm512_inserti32x8(lo512.raw, hi.raw, 1)};
+}
+template <>
+HWY_INLINE Vec512<float> Combine(Vec256<float> hi, Vec256<float> lo) {
+  const auto lo512 = ZeroExtendVector(lo);
+  return Vec512<float>{_mm512_insertf32x8(lo512.raw, hi.raw, 1)};
+}
+template <>
+HWY_INLINE Vec512<double> Combine(Vec256<double> hi, Vec256<double> lo) {
+  const auto lo512 = ZeroExtendVector(lo);
+  return Vec512<double>{_mm512_insertf64x4(lo512.raw, hi.raw, 1)};
+}
+
+// ------------------------------ Shift vector by constant #bytes
+
+// 0x01..0F, kBytes = 1  => 0x02..0F00
+template <int kBytes, typename T>
+HWY_API Vec512<T> ShiftLeftBytes(const Vec512<T> v) {
+  static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
+  return Vec512<T>{_mm512_bslli_epi128(v.raw, kBytes)};
+}
+
+template <int kLanes, typename T>
+HWY_API Vec512<T> ShiftLeftLanes(const Vec512<T> v) {
+  const Full512<uint8_t> d8;
+  const Full512<T> d;
+  return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v)));
+}
+
+// 0x01..0F, kBytes = 1  => 0x0001..0E
+template <int kBytes, typename T>
+HWY_API Vec512<T> ShiftRightBytes(const Vec512<T> v) {
+  static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
+  return Vec512<T>{_mm512_bsrli_epi128(v.raw, kBytes)};
+}
+
+template <int kLanes, typename T>
+HWY_API Vec512<T> ShiftRightLanes(const Vec512<T> v) {
+  const Full512<uint8_t> d8;
+  const Full512<T> d;
+  return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(BitCast(d8, v)));
+}
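+
+// Editor's note: illustrative, not upstream Highway code. bslli/bsrli shift
+// within each 128-bit block independently; bytes never cross block
+// boundaries. Per block, 0F 0E .. 01 with kBytes=1 becomes 0E 0D .. 01 00
+// (ShiftLeftBytes) or 00 0F .. 02 (ShiftRightBytes), matching the comments
+// above.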
+
+// ------------------------------ Extract from 2x 128-bit at constant offset
+
+// Extracts 128 bits from <hi, lo> by skipping the least-significant kBytes.
+template <int kBytes, typename T>
+HWY_API Vec512<T> CombineShiftRightBytes(const Vec512<T> hi,
+                                         const Vec512<T> lo) {
+  const Full512<uint8_t> d8;
+  const Vec512<uint8_t> extracted_bytes{
+      _mm512_alignr_epi8(BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)};
+  return BitCast(Full512<T>(), extracted_bytes);
+}
+
+// ------------------------------ Broadcast/splat any lane
+
+// Unsigned
+template <int kLane>
+HWY_API Vec512<uint16_t> Broadcast(const Vec512<uint16_t> v) {
+  static_assert(0 <= kLane && kLane < 8, "Invalid lane");
+  if (kLane < 4) {
+    const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF);
+    return Vec512<uint16_t>{_mm512_unpacklo_epi64(lo, lo)};
+  } else {
+    const __m512i hi =
+        _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF);
+    return Vec512<uint16_t>{_mm512_unpackhi_epi64(hi, hi)};
+  }
+}
+template <int kLane>
+HWY_API Vec512<uint32_t> Broadcast(const Vec512<uint32_t> v) {
+  static_assert(0 <= kLane && kLane < 4, "Invalid lane");
+  constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane);
+  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, perm)};
+}
+template <int kLane>
+HWY_API Vec512<uint64_t> Broadcast(const Vec512<uint64_t> v) {
+  static_assert(0 <= kLane && kLane < 2, "Invalid lane");
+  constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA;
+  return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, perm)};
+}
+
+// Signed
+template <int kLane>
+HWY_API Vec512<int16_t> Broadcast(const Vec512<int16_t> v) {
+  static_assert(0 <= kLane && kLane < 8, "Invalid lane");
+  if (kLane < 4) {
+    const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF);
+    return Vec512<int16_t>{_mm512_unpacklo_epi64(lo, lo)};
+  } else {
+    const __m512i hi =
+        _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF);
+    return Vec512<int16_t>{_mm512_unpackhi_epi64(hi, hi)};
+  }
+}
+template <int kLane>
+HWY_API Vec512<int32_t> Broadcast(const Vec512<int32_t> v) {
+  static_assert(0 <= kLane && kLane < 4, "Invalid lane");
+  constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane);
+  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, perm)};
+}
+template <int kLane>
+HWY_API Vec512<int64_t> Broadcast(const Vec512<int64_t> v) {
+  static_assert(0 <= kLane && kLane < 2, "Invalid lane");
+  constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA;
+  return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, perm)};
+}
+
+// Float
+template <int kLane>
+HWY_API Vec512<float> Broadcast(const Vec512<float> v) {
+  static_assert(0 <= kLane && kLane < 4, "Invalid lane");
+  constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane);
+  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, perm)};
+}
+template <int kLane>
+HWY_API Vec512<double> Broadcast(const Vec512<double> v) {
+  static_assert(0 <= kLane && kLane < 2, "Invalid lane");
+  constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane);
+  return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, perm)};
+}
+
+// ------------------------------ Hard-coded shuffles
+
+// Notation: let Vec512<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is
+// least-significant). Shuffle0321 rotates four-lane blocks one lane to the
+// right (the previous least-significant lane is now most-significant =>
+// 47650321). These could also be implemented via CombineShiftRightBytes but
+// the shuffle_abcd notation is more convenient.
+
+// Swap 32-bit halves in 64-bit halves.
+HWY_API Vec512<uint32_t> Shuffle2301(const Vec512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)};
+}
+HWY_API Vec512<int32_t> Shuffle2301(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)};
+}
+HWY_API Vec512<float> Shuffle2301(const Vec512<float> v) {
+  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)};
+}
+
+// Swap 64-bit halves
+HWY_API Vec512<uint32_t> Shuffle1032(const Vec512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
+}
+HWY_API Vec512<int32_t> Shuffle1032(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
+}
+HWY_API Vec512<float> Shuffle1032(const Vec512<float> v) {
+  // Shorter encoding than _mm512_permute_ps.
+  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)};
+}
+HWY_API Vec512<uint64_t> Shuffle01(const Vec512<uint64_t> v) {
+  return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
+}
+HWY_API Vec512<int64_t> Shuffle01(const Vec512<int64_t> v) {
+  return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)};
+}
+HWY_API Vec512<double> Shuffle01(const Vec512<double> v) {
+  // Shorter encoding than _mm512_permute_pd.
+  return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)};
+}
+
+// Rotate right 32 bits
+HWY_API Vec512<uint32_t> Shuffle0321(const Vec512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)};
+}
+HWY_API Vec512<int32_t> Shuffle0321(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)};
+}
+HWY_API Vec512<float> Shuffle0321(const Vec512<float> v) {
+  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)};
+}
+// Rotate left 32 bits
+HWY_API Vec512<uint32_t> Shuffle2103(const Vec512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)};
+}
+HWY_API Vec512<int32_t> Shuffle2103(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)};
+}
+HWY_API Vec512<float> Shuffle2103(const Vec512<float> v) {
+  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)};
+}
+
+// Reverse
+HWY_API Vec512<uint32_t> Shuffle0123(const Vec512<uint32_t> v) {
+  return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)};
+}
+HWY_API Vec512<int32_t> Shuffle0123(const Vec512<int32_t> v) {
+  return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)};
+}
+HWY_API Vec512<float> Shuffle0123(const Vec512<float> v) {
+  return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)};
+}
+
+// ------------------------------ TableLookupLanes
+
+// Returned by SetTableIndices for use by TableLookupLanes.
+template <typename T>
+struct Indices512 {
+  __m512i raw;
+};
+
+template <typename T>
+HWY_API Indices512<T> SetTableIndices(const Full512<T>, const int32_t* idx) {
+#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER)
+  const size_t N = 64 / sizeof(T);
+  for (size_t i = 0; i < N; ++i) {
+    HWY_DASSERT(0 <= idx[i] && idx[i] < static_cast<int32_t>(N));
+  }
+#endif
+  return Indices512<T>{LoadU(Full512<int32_t>(), idx).raw};
+}
+
+HWY_API Vec512<uint32_t> TableLookupLanes(const Vec512<uint32_t> v,
+                                          const Indices512<uint32_t> idx) {
+  return Vec512<uint32_t>{_mm512_permutexvar_epi32(idx.raw, v.raw)};
+}
+HWY_API Vec512<int32_t> TableLookupLanes(const Vec512<int32_t> v,
+                                         const Indices512<int32_t> idx) {
+  return Vec512<int32_t>{_mm512_permutexvar_epi32(idx.raw, v.raw)};
+}
+HWY_API Vec512<float> TableLookupLanes(const Vec512<float> v,
+                                       const Indices512<float> idx) {
+  return Vec512<float>{_mm512_permutexvar_ps(idx.raw, v.raw)};
+}
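+
+// Editor's usage sketch (illustrative, not upstream Highway code): reversing
+// the 16 float lanes of a vector with a runtime index table.
+//   alignas(64) const int32_t idx[16] = {15, 14, 13, 12, 11, 10, 9, 8,
+//                                        7,  6,  5,  4,  3,  2,  1, 0};
+//   const Full512<float> d;
+//   const auto rev = TableLookupLanes(v, SetTableIndices(d, idx));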
+
+// ------------------------------ Interleave lanes
+
+// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides
+// the least-significant lane) and "b". To concatenate two half-width integers
+// into one, use ZipLower/Upper instead (also works with scalar).
+
+HWY_API Vec512<uint8_t> InterleaveLower(const Vec512<uint8_t> a,
+                                        const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_unpacklo_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> InterleaveLower(const Vec512<uint16_t> a,
+                                         const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_unpacklo_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> InterleaveLower(const Vec512<uint32_t> a,
+                                         const Vec512<uint32_t> b) {
+  return Vec512<uint32_t>{_mm512_unpacklo_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> InterleaveLower(const Vec512<uint64_t> a,
+                                         const Vec512<uint64_t> b) {
+  return Vec512<uint64_t>{_mm512_unpacklo_epi64(a.raw, b.raw)};
+}
+
+HWY_API Vec512<int8_t> InterleaveLower(const Vec512<int8_t> a,
+                                       const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_unpacklo_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> InterleaveLower(const Vec512<int16_t> a,
+                                        const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_unpacklo_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> InterleaveLower(const Vec512<int32_t> a,
+                                        const Vec512<int32_t> b) {
+  return Vec512<int32_t>{_mm512_unpacklo_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> InterleaveLower(const Vec512<int64_t> a,
+                                        const Vec512<int64_t> b) {
+  return Vec512<int64_t>{_mm512_unpacklo_epi64(a.raw, b.raw)};
+}
+
+HWY_API Vec512<float> InterleaveLower(const Vec512<float> a,
+                                      const Vec512<float> b) {
+  return Vec512<float>{_mm512_unpacklo_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> InterleaveLower(const Vec512<double> a,
+                                       const Vec512<double> b) {
+  return Vec512<double>{_mm512_unpacklo_pd(a.raw, b.raw)};
+}
+
+HWY_API Vec512<uint8_t> InterleaveUpper(const Vec512<uint8_t> a,
+                                        const Vec512<uint8_t> b) {
+  return Vec512<uint8_t>{_mm512_unpackhi_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint16_t> InterleaveUpper(const Vec512<uint16_t> a,
+                                         const Vec512<uint16_t> b) {
+  return Vec512<uint16_t>{_mm512_unpackhi_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> InterleaveUpper(const Vec512<uint32_t> a,
+                                         const Vec512<uint32_t> b) {
+  return Vec512<uint32_t>{_mm512_unpackhi_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> InterleaveUpper(const Vec512<uint64_t> a,
+                                         const Vec512<uint64_t> b) {
+  return Vec512<uint64_t>{_mm512_unpackhi_epi64(a.raw, b.raw)};
+}
+
+HWY_API Vec512<int8_t> InterleaveUpper(const Vec512<int8_t> a,
+                                       const Vec512<int8_t> b) {
+  return Vec512<int8_t>{_mm512_unpackhi_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int16_t> InterleaveUpper(const Vec512<int16_t> a,
+                                        const Vec512<int16_t> b) {
+  return Vec512<int16_t>{_mm512_unpackhi_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> InterleaveUpper(const Vec512<int32_t> a,
+                                        const Vec512<int32_t> b) {
+  return Vec512<int32_t>{_mm512_unpackhi_epi32(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> InterleaveUpper(const Vec512<int64_t> a,
+                                        const Vec512<int64_t> b) {
+  return Vec512<int64_t>{_mm512_unpackhi_epi64(a.raw, b.raw)};
+}
+
+HWY_API Vec512<float> InterleaveUpper(const Vec512<float> a,
+                                      const Vec512<float> b) {
+  return Vec512<float>{_mm512_unpackhi_ps(a.raw, b.raw)};
+}
+HWY_API Vec512<double> InterleaveUpper(const Vec512<double> a,
+                                       const Vec512<double> b) {
+  return Vec512<double>{_mm512_unpackhi_pd(a.raw, b.raw)};
+}
+
+// ------------------------------ Zip lanes
+
+// Same as interleave_*, except that the return lanes are double-width
+// integers; this is necessary because the single-lane scalar cannot return
+// two values.
+
+HWY_API Vec512<uint16_t> ZipLower(const Vec512<uint8_t> a,
+                                  const Vec512<uint8_t> b) {
+  return Vec512<uint16_t>{_mm512_unpacklo_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> ZipLower(const Vec512<uint16_t> a,
+                                  const Vec512<uint16_t> b) {
+  return Vec512<uint32_t>{_mm512_unpacklo_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> ZipLower(const Vec512<uint32_t> a,
+                                  const Vec512<uint32_t> b) {
+  return Vec512<uint64_t>{_mm512_unpacklo_epi32(a.raw, b.raw)};
+}
+
+HWY_API Vec512<int16_t> ZipLower(const Vec512<int8_t> a,
+                                 const Vec512<int8_t> b) {
+  return Vec512<int16_t>{_mm512_unpacklo_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> ZipLower(const Vec512<int16_t> a,
+                                 const Vec512<int16_t> b) {
+  return Vec512<int32_t>{_mm512_unpacklo_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> ZipLower(const Vec512<int32_t> a,
+                                 const Vec512<int32_t> b) {
+  return Vec512<int64_t>{_mm512_unpacklo_epi32(a.raw, b.raw)};
+}
+
+HWY_API Vec512<uint16_t> ZipUpper(const Vec512<uint8_t> a,
+                                  const Vec512<uint8_t> b) {
+  return Vec512<uint16_t>{_mm512_unpackhi_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<uint32_t> ZipUpper(const Vec512<uint16_t> a,
+                                  const Vec512<uint16_t> b) {
+  return Vec512<uint32_t>{_mm512_unpackhi_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<uint64_t> ZipUpper(const Vec512<uint32_t> a,
+                                  const Vec512<uint32_t> b) {
+  return Vec512<uint64_t>{_mm512_unpackhi_epi32(a.raw, b.raw)};
+}
+
+HWY_API Vec512<int16_t> ZipUpper(const Vec512<int8_t> a,
+                                 const Vec512<int8_t> b) {
+  return Vec512<int16_t>{_mm512_unpackhi_epi8(a.raw, b.raw)};
+}
+HWY_API Vec512<int32_t> ZipUpper(const Vec512<int16_t> a,
+                                 const Vec512<int16_t> b) {
+  return Vec512<int32_t>{_mm512_unpackhi_epi16(a.raw, b.raw)};
+}
+HWY_API Vec512<int64_t> ZipUpper(const Vec512<int32_t> a,
+                                 const Vec512<int32_t> b) {
+  return Vec512<int64_t>{_mm512_unpackhi_epi32(a.raw, b.raw)};
+}
+
+// ------------------------------ Concat* halves
+
+// hiH,hiL loH,loL |-> hiL,loL (= lower halves)
+template <typename T>
+HWY_API Vec512<T> ConcatLowerLower(const Vec512<T> hi, const Vec512<T> lo) {
+  return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)};
+}
+template <>
+HWY_INLINE Vec512<float> ConcatLowerLower(const Vec512<float> hi,
+                                          const Vec512<float> lo) {
+  return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)};
+}
+template <>
+HWY_INLINE Vec512<double> ConcatLowerLower(const Vec512<double> hi,
+                                           const Vec512<double> lo) {
+  return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)};
+}
+
+// hiH,hiL loH,loL |-> hiH,loH (= upper halves)
+template <typename T>
+HWY_API Vec512<T> ConcatUpperUpper(const Vec512<T> hi, const Vec512<T> lo) {
+  return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)};
+}
+template <>
+HWY_INLINE Vec512<float> ConcatUpperUpper(const Vec512<float> hi,
+                                          const Vec512<float> lo) {
+  return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)};
+}
+template <>
+HWY_INLINE Vec512<double> ConcatUpperUpper(const Vec512<double> hi,
+                                           const Vec512<double> lo) {
+  return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)};
+}
+
+// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks)
+template <typename T>
+HWY_API Vec512<T> ConcatLowerUpper(const Vec512<T> hi, const Vec512<T> lo) {
+  return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, 0x4E)};
+}
+template <>
+HWY_INLINE Vec512<float> ConcatLowerUpper(const Vec512<float> hi,
+                                          const Vec512<float> lo) {
+  return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, 0x4E)};
+}
+template <>
+HWY_INLINE Vec512<double> ConcatLowerUpper(const Vec512<double> hi,
+                                           const Vec512<double> lo) {
+  return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, 0x4E)};
+}
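+
+// Editor's note: illustrative, not upstream Highway code. These shuffles work
+// at 128-bit block granularity: the immediate holds four 2-bit block
+// selectors, and 0x4E picks lo's upper two blocks and hi's lower two, i.e.
+// the "inner halves" named above.
+static_assert(0x4E == ((1 << 6) | (0 << 4) | (3 << 2) | 2),
+              "block selector fields 1,0,3,2");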
+  const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF);
+  return Vec512<T>{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)};
+}
+template <>
+HWY_INLINE Vec512<float> ConcatUpperLower(const Vec512<float> hi,
+                                          const Vec512<float> lo) {
+  const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF);
+  return Vec512<float>{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)};
+}
+template <>
+HWY_INLINE Vec512<double> ConcatUpperLower(const Vec512<double> hi,
+                                           const Vec512<double> lo) {
+  const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F);
+  return Vec512<double>{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)};
+}
+
+// ------------------------------ Odd/even lanes
+
+template <typename T>
+HWY_API Vec512<T> OddEven(const Vec512<T> a, const Vec512<T> b) {
+  constexpr size_t s = sizeof(T);
+  constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56;
+  return IfThenElse(Mask512<T>{0x5555555555555555ull >> shift}, b, a);
+}
+
+// ------------------------------ Shuffle bytes with variable indices
+
+// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e.
+// lane indices in [0, 16).
+template <typename T>
+HWY_API Vec512<T> TableLookupBytes(const Vec512<T> bytes,
+                                   const Vec512<T> from) {
+  return Vec512<T>{_mm512_shuffle_epi8(bytes.raw, from.raw)};
+}
+
+// ================================================== CONVERT
+
+// ------------------------------ Promotions (part w/ narrow lanes -> full)
+
+HWY_API Vec512<float> PromoteTo(Full512<float> /* tag */,
+                                const Vec256<float16_t> v) {
+  return Vec512<float>{_mm512_cvtph_ps(v.raw)};
+}
+
+HWY_API Vec512<double> PromoteTo(Full512<double> /* tag */, Vec256<float> v) {
+  return Vec512<double>{_mm512_cvtps_pd(v.raw)};
+}
+
+HWY_API Vec512<double> PromoteTo(Full512<double> /* tag */, Vec256<int32_t> v) {
+  return Vec512<double>{_mm512_cvtepi32_pd(v.raw)};
+}
+
+// Unsigned: zero-extend.
+// Note: these have 3 cycle latency; if inputs are already split across the
+// 128 bit blocks (in their upper/lower halves), then Zip* would be faster.
+HWY_API Vec512<uint16_t> PromoteTo(Full512<uint16_t> /* tag */,
+                                   Vec256<uint8_t> v) {
+  return Vec512<uint16_t>{_mm512_cvtepu8_epi16(v.raw)};
+}
+HWY_API Vec512<uint32_t> PromoteTo(Full512<uint32_t> /* tag */,
+                                   Vec128<uint8_t> v) {
+  return Vec512<uint32_t>{_mm512_cvtepu8_epi32(v.raw)};
+}
+HWY_API Vec512<int16_t> PromoteTo(Full512<int16_t> /* tag */,
+                                  Vec256<uint8_t> v) {
+  return Vec512<int16_t>{_mm512_cvtepu8_epi16(v.raw)};
+}
+HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */,
+                                  Vec128<uint8_t> v) {
+  return Vec512<int32_t>{_mm512_cvtepu8_epi32(v.raw)};
+}
+HWY_API Vec512<uint32_t> PromoteTo(Full512<uint32_t> /* tag */,
+                                   Vec256<uint16_t> v) {
+  return Vec512<uint32_t>{_mm512_cvtepu16_epi32(v.raw)};
+}
+HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */,
+                                  Vec256<uint16_t> v) {
+  return Vec512<int32_t>{_mm512_cvtepu16_epi32(v.raw)};
+}
+HWY_API Vec512<uint64_t> PromoteTo(Full512<uint64_t> /* tag */,
+                                   Vec256<uint32_t> v) {
+  return Vec512<uint64_t>{_mm512_cvtepu32_epi64(v.raw)};
+}
+
+// Signed: replicate sign bit.
+// Note: these have 3 cycle latency; if inputs are already split across the
+// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by
+// signed shift would be faster.
+HWY_API Vec512<int16_t> PromoteTo(Full512<int16_t> /* tag */,
+                                  Vec256<int8_t> v) {
+  return Vec512<int16_t>{_mm512_cvtepi8_epi16(v.raw)};
+}
+HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */,
+                                  Vec128<int8_t> v) {
+  return Vec512<int32_t>{_mm512_cvtepi8_epi32(v.raw)};
+}
+HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */,
+                                  Vec256<int16_t> v) {
+  return Vec512<int32_t>{_mm512_cvtepi16_epi32(v.raw)};
+}
+HWY_API Vec512<int64_t> PromoteTo(Full512<int64_t> /* tag */,
+                                  Vec256<int32_t> v) {
+  return Vec512<int64_t>{_mm512_cvtepi32_epi64(v.raw)};
+}
+
+// ------------------------------ Demotions (full -> part w/ narrow lanes)
+
+HWY_API Vec256<uint16_t> DemoteTo(Full256<uint16_t> /* tag */,
+                                  const Vec512<int32_t> v) {
+  const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)};
+
+  // Compress even u64 lanes into 256 bit.
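+  // (packus operates within 128-bit blocks, so each block holds its four
+  // 16-bit results twice; the unique copies sit in u64 lanes 0, 2, 4 and 6,
+  // which the permute below moves into the lower 256 bits.)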
+  alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
+  const auto idx64 = Load(Full512<uint64_t>(), kLanes);
+  const Vec512<uint16_t> even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)};
+  return LowerHalf(even);
+}
+
+HWY_API Vec256<int16_t> DemoteTo(Full256<int16_t> /* tag */,
+                                 const Vec512<int32_t> v) {
+  const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)};
+
+  // Compress even u64 lanes into 256 bit.
+  alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
+  const auto idx64 = Load(Full512<uint64_t>(), kLanes);
+  const Vec512<int16_t> even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)};
+  return LowerHalf(even);
+}
+
+HWY_API Vec128<uint8_t> DemoteTo(Full128<uint8_t> /* tag */,
+                                 const Vec512<int32_t> v) {
+  const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)};
+  // packus treats the input as signed; we want unsigned. Clear the MSB to get
+  // unsigned saturation to u8.
+  const Vec512<int16_t> i16{
+      _mm512_and_si512(u16.raw, _mm512_set1_epi16(0x7FFF))};
+  const Vec512<uint8_t> u8{_mm512_packus_epi16(i16.raw, i16.raw)};
+
+  alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12};
+  const auto idx32 = LoadDup128(Full512<uint32_t>(), kLanes);
+  const Vec512<uint8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)};
+  return LowerHalf(LowerHalf(fixed));
+}
+
+HWY_API Vec256<uint8_t> DemoteTo(Full256<uint8_t> /* tag */,
+                                 const Vec512<int16_t> v) {
+  const Vec512<uint8_t> u8{_mm512_packus_epi16(v.raw, v.raw)};
+
+  // Compress even u64 lanes into 256 bit.
+  alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
+  const auto idx64 = Load(Full512<uint64_t>(), kLanes);
+  const Vec512<uint8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)};
+  return LowerHalf(even);
+}
+
+HWY_API Vec128<int8_t> DemoteTo(Full128<int8_t> /* tag */,
+                                const Vec512<int32_t> v) {
+  const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)};
+  const Vec512<int8_t> i8{_mm512_packs_epi16(i16.raw, i16.raw)};
+
+  alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12,
+                                                      0, 4, 8, 12, 0, 4, 8, 12};
+  const auto idx32 = LoadDup128(Full512<uint32_t>(), kLanes);
+  const Vec512<int8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)};
+  return LowerHalf(LowerHalf(fixed));
+}
+
+HWY_API Vec256<int8_t> DemoteTo(Full256<int8_t> /* tag */,
+                                const Vec512<int16_t> v) {
+  const Vec512<int8_t> u8{_mm512_packs_epi16(v.raw, v.raw)};
+
+  // Compress even u64 lanes into 256 bit.
+  alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6};
+  const auto idx64 = Load(Full512<uint64_t>(), kLanes);
+  const Vec512<int8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)};
+  return LowerHalf(even);
+}
+
+HWY_API Vec256<float16_t> DemoteTo(Full256<float16_t> /* tag */,
+                                   const Vec512<float> v) {
+  return Vec256<float16_t>{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)};
+}
+
+HWY_API Vec256<float> DemoteTo(Full256<float> /* tag */,
+                               const Vec512<double> v) {
+  return Vec256<float>{_mm512_cvtpd_ps(v.raw)};
+}
+
+HWY_API Vec256<int32_t> DemoteTo(Full256<int32_t> /* tag */,
+                                 const Vec512<double> v) {
+  const auto clamped = detail::ClampF64ToI32Max(Full512<double>(), v);
+  return Vec256<int32_t>{_mm512_cvttpd_epi32(clamped.raw)};
+}
+
+// For already range-limited input [0, 255].
+HWY_API Vec128<uint8_t> U8FromU32(const Vec512<uint32_t> v) {
+  const Full512<uint32_t> d32;
+  // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the
+  // lowest 4 bytes.
+  alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u,
+                                                       ~0u};
+  const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32));
+  // Gather the lowest 4 bytes of 4 128-bit blocks.
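+  // (u32 lanes 0, 4, 8 and 12 are the first lane of each 128-bit block, so
+  // this permute packs the four byte-quads into the lowest 128 bits.)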
+  alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12};
+  const Vec512<uint8_t> bytes{
+      _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)};
+  return LowerHalf(LowerHalf(bytes));
+}
+
+// ------------------------------ Convert integer <=> floating point
+
+HWY_API Vec512<float> ConvertTo(Full512<float> /* tag */,
+                                const Vec512<int32_t> v) {
+  return Vec512<float>{_mm512_cvtepi32_ps(v.raw)};
+}
+
+HWY_API Vec512<double> ConvertTo(Full512<double> /* tag */,
+                                 const Vec512<int64_t> v) {
+  return Vec512<double>{_mm512_cvtepi64_pd(v.raw)};
+}
+
+// Truncates (rounds toward zero).
+HWY_API Vec512<int32_t> ConvertTo(Full512<int32_t> d, const Vec512<float> v) {
+  return detail::FixConversionOverflow(d, v, _mm512_cvttps_epi32(v.raw));
+}
+HWY_API Vec512<int64_t> ConvertTo(Full512<int64_t> di, const Vec512<double> v) {
+  return detail::FixConversionOverflow(di, v, _mm512_cvttpd_epi64(v.raw));
+}
+
+HWY_API Vec512<int32_t> NearestInt(const Vec512<float> v) {
+  const Full512<int32_t> di;
+  return detail::FixConversionOverflow(di, v, _mm512_cvtps_epi32(v.raw));
+}
+
+// ================================================== MISC
+
+// Returns a vector with lane i=[0, N) set to "first" + i.
+template <typename T, typename T2>
+Vec512<T> Iota(const Full512<T> d, const T2 first) {
+  HWY_ALIGN T lanes[64 / sizeof(T)];
+  for (size_t i = 0; i < 64 / sizeof(T); ++i) {
+    lanes[i] = static_cast<T>(first + static_cast<T2>(i));
+  }
+  return Load(d, lanes);
+}
+
+// ------------------------------ Mask
+
+// Beware: the suffix indicates the number of mask bits, not lane size!
+
+namespace detail {
+
+template <typename T>
+HWY_API bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestz_mask64_u8(v.raw, v.raw);
+#else
+  return v.raw == 0;
+#endif
+}
+template <typename T>
+HWY_API bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestz_mask32_u8(v.raw, v.raw);
+#else
+  return v.raw == 0;
+#endif
+}
+template <typename T>
+HWY_API bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestz_mask16_u8(v.raw, v.raw);
+#else
+  return v.raw == 0;
+#endif
+}
+template <typename T>
+HWY_API bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestz_mask8_u8(v.raw, v.raw);
+#else
+  return v.raw == 0;
+#endif
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API bool AllFalse(const Mask512<T> v) {
+  return detail::AllFalse(hwy::SizeTag<sizeof(T)>(), v);
+}
+
+namespace detail {
+
+template <typename T>
+HWY_API bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestc_mask64_u8(v.raw, v.raw);
+#else
+  return v.raw == 0xFFFFFFFFFFFFFFFFull;
+#endif
+}
+template <typename T>
+HWY_API bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestc_mask32_u8(v.raw, v.raw);
+#else
+  return v.raw == 0xFFFFFFFFull;
+#endif
+}
+template <typename T>
+HWY_API bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestc_mask16_u8(v.raw, v.raw);
+#else
+  return v.raw == 0xFFFFull;
+#endif
+}
+template <typename T>
+HWY_API bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512<T> v) {
+#if HWY_COMPILER_HAS_MASK_INTRINSICS
+  return _kortestc_mask8_u8(v.raw, v.raw);
+#else
+  return v.raw == 0xFFull;
+#endif
+}
+
+}  // namespace detail
+
+template <typename T>
+HWY_API bool AllTrue(const Mask512<T> v) {
+  return detail::AllTrue(hwy::SizeTag<sizeof(T)>(), v);
+}
+
+template <typename T>
+HWY_INLINE size_t StoreMaskBits(const Mask512<T> mask, uint8_t* p) {
+  const size_t kNumBytes = 8 / sizeof(T);
+  CopyBytes<kNumBytes>(&mask.raw, p);
+  return kNumBytes;
+}
+
+template <typename T>
+HWY_API size_t CountTrue(const Mask512<T> mask) {
+  return PopCount(mask.raw);
+}
+
+// ------------------------------ Compress
+
+HWY_API Vec512<uint32_t> Compress(Vec512<uint32_t> v,
+                                  const Mask512<uint32_t> mask) {
+  return Vec512<uint32_t>{_mm512_maskz_compress_epi32(mask.raw, v.raw)};
+}
+HWY_API Vec512<int32_t> Compress(Vec512<int32_t> v,
+                                 const Mask512<int32_t> mask) {
+  return Vec512<int32_t>{_mm512_maskz_compress_epi32(mask.raw, v.raw)};
+}
+
+HWY_API Vec512<uint64_t> Compress(Vec512<uint64_t> v,
+                                  const Mask512<uint64_t> mask) {
+  return Vec512<uint64_t>{_mm512_maskz_compress_epi64(mask.raw, v.raw)};
+}
+HWY_API Vec512<int64_t> Compress(Vec512<int64_t> v,
+                                 const Mask512<int64_t> mask) {
+  return Vec512<int64_t>{_mm512_maskz_compress_epi64(mask.raw, v.raw)};
+}
+
+HWY_API Vec512<float> Compress(Vec512<float> v, const Mask512<float> mask) {
+  return Vec512<float>{_mm512_maskz_compress_ps(mask.raw, v.raw)};
+}
+
+HWY_API Vec512<double> Compress(Vec512<double> v, const Mask512<double> mask) {
+  return Vec512<double>{_mm512_maskz_compress_pd(mask.raw, v.raw)};
+}
+
+namespace detail {
+
+// Ignore IDE redefinition error for these two functions: if this header is
+// included, then the functions weren't actually defined in x86_256-inl.h.
+template <typename T>
+HWY_API Vec256<T> Compress(hwy::SizeTag<2> /*tag*/, Vec256<T> v,
+                           const uint64_t mask_bits) {
+  using D = Full256<T>;
+  const Rebind<uint16_t, D> du;
+  const Rebind<int32_t, D> dw;       // 512-bit, not 256!
+  const auto vu16 = BitCast(du, v);  // (required for float16_t inputs)
+  const Mask512<int32_t> mask{static_cast<__mmask16>(mask_bits)};
+  return BitCast(D(), DemoteTo(du, Compress(PromoteTo(dw, vu16), mask)));
+}
+
+}  // namespace detail
+
+template <typename T, HWY_IF_LANE_SIZE(T, 2)>
+HWY_API Vec256<T> Compress(Vec256<T> v, const Mask256<T> mask) {
+  return detail::Compress(hwy::SizeTag<sizeof(T)>(), v,
+                          detail::BitsFromMask(mask));
+}
+
+// Expands to 32-bit, compresses, concatenate demoted halves.
+template <typename T, HWY_IF_LANE_SIZE(T, 2)>
+HWY_API Vec512<T> Compress(Vec512<T> v, const Mask512<T> mask) {
+  using D = Full512<T>;
+  const Rebind<uint16_t, D> du;
+  const Repartition<int32_t, D> dw;
+  const auto vu16 = BitCast(du, v);  // (required for float16_t inputs)
+  const auto promoted0 = PromoteTo(dw, LowerHalf(vu16));
+  const auto promoted1 = PromoteTo(dw, UpperHalf(vu16));
+
+  const Mask512<int32_t> mask0{static_cast<__mmask16>(mask.raw & 0xFFFF)};
+  const Mask512<int32_t> mask1{static_cast<__mmask16>(mask.raw >> 16)};
+  const auto compressed0 = Compress(promoted0, mask0);
+  const auto compressed1 = Compress(promoted1, mask1);
+
+  const Half<decltype(du)> dh;
+  const auto demoted0 = ZeroExtendVector(DemoteTo(dh, compressed0));
+  const auto demoted1 = ZeroExtendVector(DemoteTo(dh, compressed1));
+
+  // Concatenate into single vector by shifting upper with writemask.
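+  // (The iota table below is arranged so that destination lane i >= num0
+  // reads lane i - num0 of demoted1, while the writemask keeps the first
+  // num0 lanes of demoted0.)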
+  const size_t num0 = CountTrue(mask0);
+  const __mmask32 m_upper = ~((1u << num0) - 1);
+  alignas(64) uint16_t iota[64] = {
+      0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
+      0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
+      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
+  const auto idx = LoadU(du, iota + 32 - num0);
+  return Vec512<T>{_mm512_mask_permutexvar_epi16(demoted0.raw, m_upper, idx.raw,
+                                                 demoted1.raw)};
+}
+
+// ------------------------------ CompressStore
+
+template <typename T, HWY_IF_LANE_SIZE(T, 2)>
+HWY_API size_t CompressStore(Vec256<T> v, const Mask256<T> mask, Full256<T> d,
+                             T* HWY_RESTRICT aligned) {
+  const uint64_t mask_bits = detail::BitsFromMask(mask);
+  Store(detail::Compress(hwy::SizeTag<sizeof(T)>(), v, mask_bits), d, aligned);
+  return PopCount(mask_bits);
+}
+
+template <typename T, HWY_IF_LANE_SIZE(T, 2)>
+HWY_API size_t CompressStore(Vec512<T> v, const Mask512<T> mask, Full512<T> d,
+                             T* HWY_RESTRICT aligned) {
+  // NOTE: it is tempting to split inputs into two halves for 16-bit lanes, but
+  // using StoreU to concatenate the results would cause page faults if
+  // `aligned` is the last valid vector. Instead rely on in-register splicing.
+  Store(Compress(v, mask), d, aligned);
+  return CountTrue(mask);
+}
+
+HWY_API size_t CompressStore(Vec512<uint32_t> v, const Mask512<uint32_t> mask,
+                             Full512<uint32_t> /* tag */,
+                             uint32_t* HWY_RESTRICT aligned) {
+  _mm512_mask_compressstoreu_epi32(aligned, mask.raw, v.raw);
+  return CountTrue(mask);
+}
+HWY_API size_t CompressStore(Vec512<int32_t> v, const Mask512<int32_t> mask,
+                             Full512<int32_t> /* tag */,
+                             int32_t* HWY_RESTRICT aligned) {
+  _mm512_mask_compressstoreu_epi32(aligned, mask.raw, v.raw);
+  return CountTrue(mask);
+}
+
+HWY_API size_t CompressStore(Vec512<uint64_t> v, const Mask512<uint64_t> mask,
+                             Full512<uint64_t> /* tag */,
+                             uint64_t* HWY_RESTRICT aligned) {
+  _mm512_mask_compressstoreu_epi64(aligned, mask.raw, v.raw);
+  return CountTrue(mask);
+}
+HWY_API size_t CompressStore(Vec512<int64_t> v, const Mask512<int64_t> mask,
+                             Full512<int64_t> /* tag */,
+                             int64_t* HWY_RESTRICT aligned) {
+  _mm512_mask_compressstoreu_epi64(aligned, mask.raw, v.raw);
+  return CountTrue(mask);
+}
+
+HWY_API size_t CompressStore(Vec512<float> v, const Mask512<float> mask,
+                             Full512<float> /* tag */,
+                             float* HWY_RESTRICT aligned) {
+  _mm512_mask_compressstoreu_ps(aligned, mask.raw, v.raw);
+  return CountTrue(mask);
+}
+
+HWY_API size_t CompressStore(Vec512<double> v, const Mask512<double> mask,
+                             Full512<double> /* tag */,
+                             double* HWY_RESTRICT aligned) {
+  _mm512_mask_compressstoreu_pd(aligned, mask.raw, v.raw);
+  return CountTrue(mask);
+}
+
+// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes,
+// TableLookupBytes)
+
+HWY_API void StoreInterleaved3(const Vec512<uint8_t> a, const Vec512<uint8_t> b,
+                               const Vec512<uint8_t> c, Full512<uint8_t> d,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const auto k5 = Set(d, 5);
+  const auto k6 = Set(d, 6);
+
+  // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0].
+  // 0x80 so lanes to be filled from other vectors are 0 for blending.
+  alignas(16) static constexpr uint8_t tbl_r0[16] = {
+      0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80,  //
+      3, 0x80, 0x80, 4, 0x80, 0x80, 5};
+  alignas(16) static constexpr uint8_t tbl_g0[16] = {
+      0x80, 0, 0x80, 0x80, 1, 0x80,  //
+      0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80};
+  const auto shuf_r0 = LoadDup128(d, tbl_r0);
+  const auto shuf_g0 = LoadDup128(d, tbl_g0);  // cannot reuse r0 due to 5
+  const auto shuf_b0 = CombineShiftRightBytes<15>(shuf_g0, shuf_g0);
+  const auto r0 = TableLookupBytes(a, shuf_r0);  // 5..4..3..2..1..0
+  const auto g0 = TableLookupBytes(b, shuf_g0);  // ..4..3..2..1..0.
+  const auto b0 = TableLookupBytes(c, shuf_b0);  // .4..3..2..1..0..
+  const auto i = (r0 | g0 | b0).raw;  // low byte in each 128bit: 30 20 10 00
+
+  // Second vector: g10,r10, bgr[9:6], b5,g5
+  const auto shuf_r1 = shuf_b0 + k6;  // .A..9..8..7..6..
+  const auto shuf_g1 = shuf_r0 + k5;  // A..9..8..7..6..5
+  const auto shuf_b1 = shuf_g0 + k5;  // ..9..8..7..6..5.
+  const auto r1 = TableLookupBytes(a, shuf_r1);
+  const auto g1 = TableLookupBytes(b, shuf_g1);
+  const auto b1 = TableLookupBytes(c, shuf_b1);
+  const auto j = (r1 | g1 | b1).raw;  // low byte in each 128bit: 35 25 15 05
+
+  // Third vector: bgr[15:11], b10
+  const auto shuf_r2 = shuf_b1 + k6;  // ..F..E..D..C..B.
+  const auto shuf_g2 = shuf_r1 + k5;  // .F..E..D..C..B..
+  const auto shuf_b2 = shuf_g1 + k5;  // F..E..D..C..B..A
+  const auto r2 = TableLookupBytes(a, shuf_r2);
+  const auto g2 = TableLookupBytes(b, shuf_g2);
+  const auto b2 = TableLookupBytes(c, shuf_b2);
+  const auto k = (r2 | g2 | b2).raw;  // low byte in each 128bit: 3A 2A 1A 0A
+
+  // To obtain 10 0A 05 00 in one vector, transpose "rows" into "columns".
+  const auto k3_k0_i3_i0 = _mm512_shuffle_i64x2(i, k, _MM_SHUFFLE(3, 0, 3, 0));
+  const auto i1_i2_j0_j1 = _mm512_shuffle_i64x2(j, i, _MM_SHUFFLE(1, 2, 0, 1));
+  const auto j2_j3_k1_k2 = _mm512_shuffle_i64x2(k, j, _MM_SHUFFLE(2, 3, 1, 2));
+
+  // Alternating order, most-significant 128 bits from the second arg.
+  const __mmask8 m = 0xCC;
+  const auto i1_k0_j0_i0 = _mm512_mask_blend_epi64(m, k3_k0_i3_i0, i1_i2_j0_j1);
+  const auto j2_i2_k1_j1 = _mm512_mask_blend_epi64(m, i1_i2_j0_j1, j2_j3_k1_k2);
+  const auto k3_j3_i3_k2 = _mm512_mask_blend_epi64(m, j2_j3_k1_k2, k3_k0_i3_i0);
+
+  StoreU(Vec512<uint8_t>{i1_k0_j0_i0}, d, unaligned + 0 * 64);  // 10 0A 05 00
+  StoreU(Vec512<uint8_t>{j2_i2_k1_j1}, d, unaligned + 1 * 64);  // 25 20 1A 15
+  StoreU(Vec512<uint8_t>{k3_j3_i3_k2}, d, unaligned + 2 * 64);  // 3A 35 30 2A
+}
+
+// ------------------------------ StoreInterleaved4
+
+HWY_API void StoreInterleaved4(const Vec512<uint8_t> v0,
+                               const Vec512<uint8_t> v1,
+                               const Vec512<uint8_t> v2,
+                               const Vec512<uint8_t> v3, Full512<uint8_t> d,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  // let a,b,c,d denote v0..3.
+  const auto ba0 = ZipLower(v0, v1);  // b7 a7 .. b0 a0
+  const auto dc0 = ZipLower(v2, v3);  // d7 c7 .. d0 c0
+  const auto ba8 = ZipUpper(v0, v1);
+  const auto dc8 = ZipUpper(v2, v3);
+  const auto i = ZipLower(ba0, dc0).raw;  // 4x128bit: d..a3 d..a0
+  const auto j = ZipUpper(ba0, dc0).raw;  // 4x128bit: d..a7 d..a4
+  const auto k = ZipLower(ba8, dc8).raw;  // 4x128bit: d..aB d..a8
+  const auto l = ZipUpper(ba8, dc8).raw;  // 4x128bit: d..aF d..aC
+  // 128-bit blocks were independent until now; transpose 4x4.
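+  // (Two rounds of _mm512_shuffle_i64x2: the first pairs up blocks 0/1 and
+  // 2/3 of each input, the second gathers blocks with equal index into one
+  // output vector.)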
+  const auto j1_j0_i1_i0 = _mm512_shuffle_i64x2(i, j, _MM_SHUFFLE(1, 0, 1, 0));
+  const auto l1_l0_k1_k0 = _mm512_shuffle_i64x2(k, l, _MM_SHUFFLE(1, 0, 1, 0));
+  const auto j3_j2_i3_i2 = _mm512_shuffle_i64x2(i, j, _MM_SHUFFLE(3, 2, 3, 2));
+  const auto l3_l2_k3_k2 = _mm512_shuffle_i64x2(k, l, _MM_SHUFFLE(3, 2, 3, 2));
+  constexpr int k20 = _MM_SHUFFLE(2, 0, 2, 0);
+  constexpr int k31 = _MM_SHUFFLE(3, 1, 3, 1);
+  const auto l0_k0_j0_i0 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k20);
+  const auto l1_k1_j1_i1 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k31);
+  const auto l2_k2_j2_i2 = _mm512_shuffle_i64x2(j3_j2_i3_i2, l3_l2_k3_k2, k20);
+  const auto l3_k3_j3_i3 = _mm512_shuffle_i64x2(j3_j2_i3_i2, l3_l2_k3_k2, k31);
+  StoreU(Vec512<uint8_t>{l0_k0_j0_i0}, d, unaligned + 0 * 64);
+  StoreU(Vec512<uint8_t>{l1_k1_j1_i1}, d, unaligned + 1 * 64);
+  StoreU(Vec512<uint8_t>{l2_k2_j2_i2}, d, unaligned + 2 * 64);
+  StoreU(Vec512<uint8_t>{l3_k3_j3_i3}, d, unaligned + 3 * 64);
+}
+
+// ------------------------------ Reductions
+
+// Returns the sum in each lane.
+HWY_API Vec512<int32_t> SumOfLanes(const Vec512<int32_t> v) {
+  return Set(Full512<int32_t>(), _mm512_reduce_add_epi32(v.raw));
+}
+HWY_API Vec512<int64_t> SumOfLanes(const Vec512<int64_t> v) {
+  return Set(Full512<int64_t>(), _mm512_reduce_add_epi64(v.raw));
+}
+HWY_API Vec512<uint32_t> SumOfLanes(const Vec512<uint32_t> v) {
+  return BitCast(Full512<uint32_t>(),
+                 SumOfLanes(BitCast(Full512<int32_t>(), v)));
+}
+HWY_API Vec512<uint64_t> SumOfLanes(const Vec512<uint64_t> v) {
+  return BitCast(Full512<uint64_t>(),
+                 SumOfLanes(BitCast(Full512<int64_t>(), v)));
+}
+HWY_API Vec512<float> SumOfLanes(const Vec512<float> v) {
+  return Set(Full512<float>(), _mm512_reduce_add_ps(v.raw));
+}
+HWY_API Vec512<double> SumOfLanes(const Vec512<double> v) {
+  return Set(Full512<double>(), _mm512_reduce_add_pd(v.raw));
+}
+
+// Returns the minimum in each lane.
+HWY_API Vec512<int32_t> MinOfLanes(const Vec512<int32_t> v) {
+  return Set(Full512<int32_t>(), _mm512_reduce_min_epi32(v.raw));
+}
+HWY_API Vec512<int64_t> MinOfLanes(const Vec512<int64_t> v) {
+  return Set(Full512<int64_t>(), _mm512_reduce_min_epi64(v.raw));
+}
+HWY_API Vec512<uint32_t> MinOfLanes(const Vec512<uint32_t> v) {
+  return Set(Full512<uint32_t>(), _mm512_reduce_min_epu32(v.raw));
+}
+HWY_API Vec512<uint64_t> MinOfLanes(const Vec512<uint64_t> v) {
+  return Set(Full512<uint64_t>(), _mm512_reduce_min_epu64(v.raw));
+}
+HWY_API Vec512<float> MinOfLanes(const Vec512<float> v) {
+  return Set(Full512<float>(), _mm512_reduce_min_ps(v.raw));
+}
+HWY_API Vec512<double> MinOfLanes(const Vec512<double> v) {
+  return Set(Full512<double>(), _mm512_reduce_min_pd(v.raw));
+}
+
+// Returns the maximum in each lane.
+HWY_API Vec512<int32_t> MaxOfLanes(const Vec512<int32_t> v) {
+  return Set(Full512<int32_t>(), _mm512_reduce_max_epi32(v.raw));
+}
+HWY_API Vec512<int64_t> MaxOfLanes(const Vec512<int64_t> v) {
+  return Set(Full512<int64_t>(), _mm512_reduce_max_epi64(v.raw));
+}
+HWY_API Vec512<uint32_t> MaxOfLanes(const Vec512<uint32_t> v) {
+  return Set(Full512<uint32_t>(), _mm512_reduce_max_epu32(v.raw));
+}
+HWY_API Vec512<uint64_t> MaxOfLanes(const Vec512<uint64_t> v) {
+  return Set(Full512<uint64_t>(), _mm512_reduce_max_epu64(v.raw));
+}
+HWY_API Vec512<float> MaxOfLanes(const Vec512<float> v) {
+  return Set(Full512<float>(), _mm512_reduce_max_ps(v.raw));
+}
+HWY_API Vec512<double> MaxOfLanes(const Vec512<double> v) {
+  return Set(Full512<double>(), _mm512_reduce_max_pd(v.raw));
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
diff --git a/third_party/highway/hwy/targets.cc b/third_party/highway/hwy/targets.cc
new file mode 100644
index 000000000000..287c49732d41
--- /dev/null
+++ b/third_party/highway/hwy/targets.cc
@@ -0,0 +1,286 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/targets.h"
+
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <atomic>
+#include <limits>
+
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+#include "sanitizer/common_interface_defs.h"  // __sanitizer_print_stack_trace
+#endif  // defined(*_SANITIZER)
+
+#if HWY_ARCH_X86
+#include <xmmintrin.h>
+#if HWY_COMPILER_MSVC
+#include <intrin.h>
+#else  // HWY_COMPILER_MSVC
+#include <cpuid.h>
+#endif  // HWY_COMPILER_MSVC
+#endif
+
+namespace hwy {
+namespace {
+
+#if HWY_ARCH_X86
+
+bool IsBitSet(const uint32_t reg, const int index) {
+  return (reg & (1U << index)) != 0;
+}
+
+// Calls CPUID instruction with eax=level and ecx=count and returns the result
+// in abcd array where abcd = {eax, ebx, ecx, edx} (hence the name abcd).
+void Cpuid(const uint32_t level, const uint32_t count,
+           uint32_t* HWY_RESTRICT abcd) {
+#if HWY_COMPILER_MSVC
+  int regs[4];
+  __cpuidex(regs, level, count);
+  for (int i = 0; i < 4; ++i) {
+    abcd[i] = regs[i];
+  }
+#else   // HWY_COMPILER_MSVC
+  uint32_t a;
+  uint32_t b;
+  uint32_t c;
+  uint32_t d;
+  __cpuid_count(level, count, a, b, c, d);
+  abcd[0] = a;
+  abcd[1] = b;
+  abcd[2] = c;
+  abcd[3] = d;
+#endif  // HWY_COMPILER_MSVC
+}
+
+// Returns the lower 32 bits of extended control register 0.
+// Requires CPU support for "OSXSAVE" (see below).
+uint32_t ReadXCR0() {
+#if HWY_COMPILER_MSVC
+  return static_cast<uint32_t>(_xgetbv(0));
+#else   // HWY_COMPILER_MSVC
+  uint32_t xcr0, xcr0_high;
+  const uint32_t index = 0;
+  asm volatile(".byte 0x0F, 0x01, 0xD0"
+               : "=a"(xcr0), "=d"(xcr0_high)
+               : "c"(index));
+  return xcr0;
+#endif  // HWY_COMPILER_MSVC
+}
+
+#endif  // HWY_ARCH_X86
+
+// Not function-local => no compiler-generated locking.
+std::atomic<uint32_t> supported_{0};  // Not yet initialized
+
+// When running tests, this value can be set to the mocked supported targets
+// mask. Only written to from a single thread before the test starts.
+uint32_t supported_targets_for_test_ = 0;
+
+// Mask of targets disabled at runtime with DisableTargets.
+uint32_t supported_mask_{std::numeric_limits<uint32_t>::max()};
+
+#if HWY_ARCH_X86
+// Bits indicating which instruction set extensions are supported.
+constexpr uint32_t kSSE = 1 << 0;
+constexpr uint32_t kSSE2 = 1 << 1;
+constexpr uint32_t kSSE3 = 1 << 2;
+constexpr uint32_t kSSSE3 = 1 << 3;
+constexpr uint32_t kSSE41 = 1 << 4;
+constexpr uint32_t kSSE42 = 1 << 5;
+constexpr uint32_t kGroupSSE4 = kSSE | kSSE2 | kSSE3 | kSSSE3 | kSSE41 | kSSE42;
+
+constexpr uint32_t kAVX = 1u << 6;
+constexpr uint32_t kAVX2 = 1u << 7;
+constexpr uint32_t kFMA = 1u << 8;
+constexpr uint32_t kLZCNT = 1u << 9;
+constexpr uint32_t kBMI = 1u << 10;
+constexpr uint32_t kBMI2 = 1u << 11;
+
+// We normally assume BMI/BMI2/FMA are available if AVX2 is. This allows us to
+// use BZHI and (compiler-generated) MULX. However, VirtualBox lacks them
+// [https://www.virtualbox.org/ticket/15471]. Thus we provide the option of
+// avoiding using and requiring these so AVX2 can still be used.
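+// For example, a build configured with -DHWY_DISABLE_BMI2_FMA can still
+// dispatch to the AVX2 target inside a VirtualBox guest that reports AVX2
+// support without BMI/BMI2/FMA.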
+#ifdef HWY_DISABLE_BMI2_FMA +constexpr uint32_t kGroupAVX2 = kAVX | kAVX2 | kLZCNT; +#else +constexpr uint32_t kGroupAVX2 = kAVX | kAVX2 | kFMA | kLZCNT | kBMI | kBMI2; +#endif + +constexpr uint32_t kAVX512F = 1u << 12; +constexpr uint32_t kAVX512VL = 1u << 13; +constexpr uint32_t kAVX512DQ = 1u << 14; +constexpr uint32_t kAVX512BW = 1u << 15; +constexpr uint32_t kGroupAVX3 = kAVX512F | kAVX512VL | kAVX512DQ | kAVX512BW; +#endif + +} // namespace + +HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...) { + char buf[2000]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + fprintf(stderr, "Abort at %s:%d: %s\n", file, line, buf); +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + // If compiled with any sanitizer print a stack trace. This call doesn't crash + // the program, instead the trap below will crash it also allowing gdb to + // break there. + __sanitizer_print_stack_trace(); +#endif // defined(*_SANITIZER) + fflush(stderr); + +#if HWY_COMPILER_MSVC + abort(); // Compile error without this due to HWY_NORETURN. +#else + __builtin_trap(); +#endif +} + +void DisableTargets(uint32_t disabled_targets) { + supported_mask_ = ~(disabled_targets & ~uint32_t(HWY_ENABLED_BASELINE)); + // We can call Update() here to initialize the mask but that will trigger a + // call to SupportedTargets() which we use in tests to tell whether any of the + // highway dynamic dispatch functions were used. + chosen_target.DeInit(); +} + +void SetSupportedTargetsForTest(uint32_t targets) { + // Reset the cached supported_ value to 0 to force a re-evaluation in the + // next call to SupportedTargets() which will use the mocked value set here + // if not zero. + supported_.store(0, std::memory_order_release); + supported_targets_for_test_ = targets; + chosen_target.DeInit(); +} + +bool SupportedTargetsCalledForTest() { + return supported_.load(std::memory_order_acquire) != 0; +} + +uint32_t SupportedTargets() { + uint32_t bits = supported_.load(std::memory_order_acquire); + // Already initialized? + if (HWY_LIKELY(bits != 0)) { + return bits & supported_mask_; + } + + // When running tests, this allows to mock the current supported targets. + if (HWY_UNLIKELY(supported_targets_for_test_ != 0)) { + // Store the value to signal that this was used. + supported_.store(supported_targets_for_test_, std::memory_order_release); + return supported_targets_for_test_ & supported_mask_; + } + + bits = HWY_SCALAR; + +#if HWY_ARCH_X86 + uint32_t flags = 0; + uint32_t abcd[4]; + + Cpuid(0, 0, abcd); + const uint32_t max_level = abcd[0]; + + // Standard feature flags + Cpuid(1, 0, abcd); + flags |= IsBitSet(abcd[3], 25) ? kSSE : 0; + flags |= IsBitSet(abcd[3], 26) ? kSSE2 : 0; + flags |= IsBitSet(abcd[2], 0) ? kSSE3 : 0; + flags |= IsBitSet(abcd[2], 9) ? kSSSE3 : 0; + flags |= IsBitSet(abcd[2], 19) ? kSSE41 : 0; + flags |= IsBitSet(abcd[2], 20) ? kSSE42 : 0; + flags |= IsBitSet(abcd[2], 12) ? kFMA : 0; + flags |= IsBitSet(abcd[2], 28) ? kAVX : 0; + const bool has_osxsave = IsBitSet(abcd[2], 27); + + // Extended feature flags + Cpuid(0x80000001U, 0, abcd); + flags |= IsBitSet(abcd[2], 5) ? kLZCNT : 0; + + // Extended features + if (max_level >= 7) { + Cpuid(7, 0, abcd); + flags |= IsBitSet(abcd[1], 3) ? kBMI : 0; + flags |= IsBitSet(abcd[1], 5) ? kAVX2 : 0; + flags |= IsBitSet(abcd[1], 8) ? kBMI2 : 0; + + flags |= IsBitSet(abcd[1], 16) ? kAVX512F : 0; + flags |= IsBitSet(abcd[1], 17) ? 
kAVX512DQ : 0;
+    flags |= IsBitSet(abcd[1], 30) ? kAVX512BW : 0;
+    flags |= IsBitSet(abcd[1], 31) ? kAVX512VL : 0;
+  }
+
+  // Verify OS support for XSAVE, without which XMM/YMM registers are not
+  // preserved across context switches and are not safe to use.
+  if (has_osxsave) {
+    const uint32_t xcr0 = ReadXCR0();
+    // XMM
+    if (!IsBitSet(xcr0, 1)) {
+      flags = 0;
+    }
+    // YMM
+    if (!IsBitSet(xcr0, 2)) {
+      flags &= ~kGroupAVX2;
+    }
+    // ZMM + opmask
+    if ((xcr0 & 0x70) != 0x70) {
+      flags &= ~kGroupAVX3;
+    }
+  }
+
+  // Set target bit(s) if all their group's flags are all set.
+  if ((flags & kGroupAVX3) == kGroupAVX3) {
+    bits |= HWY_AVX3;
+  }
+  if ((flags & kGroupAVX2) == kGroupAVX2) {
+    bits |= HWY_AVX2;
+  }
+  if ((flags & kGroupSSE4) == kGroupSSE4) {
+    bits |= HWY_SSE4;
+  }
+#else
+  // TODO(janwas): detect for other platforms
+  bits = HWY_ENABLED_BASELINE;
+#endif  // HWY_ARCH_X86
+
+  if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) {
+    fprintf(stderr, "WARNING: CPU supports %zx but software requires %x\n",
+            size_t(bits), HWY_ENABLED_BASELINE);
+  }
+
+  supported_.store(bits, std::memory_order_release);
+  return bits & supported_mask_;
+}
+
+// Declared in targets.h
+ChosenTarget chosen_target;
+
+void ChosenTarget::Update() {
+  // The supported variable contains the current CPU supported targets shifted
+  // to the location expected by the ChosenTarget mask. We enable SCALAR
+  // regardless of whether it was compiled since it is also used as the
+  // fallback mechanism to the baseline target.
+  uint32_t supported = HWY_CHOSEN_TARGET_SHIFT(hwy::SupportedTargets()) |
+                       HWY_CHOSEN_TARGET_MASK_SCALAR;
+  mask_.store(supported);
+}
+
+}  // namespace hwy
diff --git a/third_party/highway/hwy/targets.h b/third_party/highway/hwy/targets.h
new file mode 100644
index 000000000000..bf2665aead68
--- /dev/null
+++ b/third_party/highway/hwy/targets.h
@@ -0,0 +1,491 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HIGHWAY_HWY_TARGETS_H_
+#define HIGHWAY_HWY_TARGETS_H_
+
+#include <vector>
+
+// For SIMD module implementations and their callers. Defines which targets to
+// generate and call.
+
+#include "hwy/base.h"
+
+//------------------------------------------------------------------------------
+// Optional configuration
+
+// See ../quick_reference.md for documentation of these macros.
+
+// Uncomment to override the default baseline determined from predefined macros:
+// #define HWY_BASELINE_TARGETS (HWY_SSE4 | HWY_SCALAR)
+
+// Uncomment to override the default blocklist:
+// #define HWY_BROKEN_TARGETS HWY_AVX3
+
+// Uncomment to definitely avoid generating those target(s):
+// #define HWY_DISABLED_TARGETS HWY_SSE4
+
+// Uncomment to avoid emitting BMI/BMI2/FMA instructions (allows generating
+// AVX2 target for VMs which support AVX2 but not the other instruction sets)
+// #define HWY_DISABLE_BMI2_FMA
+
+//------------------------------------------------------------------------------
+// Targets
+
+// Unique bit value for each target.
A lower value is "better" (e.g. more lanes) +// than a higher value within the same group/platform - see HWY_STATIC_TARGET. +// +// All values are unconditionally defined so we can test HWY_TARGETS without +// first checking the HWY_ARCH_*. +// +// The C99 preprocessor evaluates #if expressions using intmax_t types, so we +// can use 32-bit literals. + +// 1,2,4: reserved +#define HWY_AVX3 8 +#define HWY_AVX2 16 +// 32: reserved for AVX +#define HWY_SSE4 64 +// 0x80, 0x100, 0x200: reserved for SSSE3, SSE3, SSE2 + +// The highest bit in the HWY_TARGETS mask that a x86 target can have. Used for +// dynamic dispatch. All x86 target bits must be lower or equal to +// (1 << HWY_HIGHEST_TARGET_BIT_X86) and they can only use +// HWY_MAX_DYNAMIC_TARGETS in total. +#define HWY_HIGHEST_TARGET_BIT_X86 9 + +// 0x400, 0x800, 0x1000 reserved for SVE, SVE2, Helium +#define HWY_NEON 0x2000 + +#define HWY_HIGHEST_TARGET_BIT_ARM 13 + +// 0x4000, 0x8000 reserved +#define HWY_PPC8 0x10000 // v2.07 or 3 +// 0x20000, 0x40000 reserved for prior VSX/AltiVec + +#define HWY_HIGHEST_TARGET_BIT_PPC 18 + +// 0x80000 reserved +#define HWY_WASM 0x100000 + +#define HWY_HIGHEST_TARGET_BIT_WASM 20 + +// 0x200000, 0x400000, 0x800000 reserved + +#define HWY_RVV 0x1000000 + +#define HWY_HIGHEST_TARGET_BIT_RVV 24 + +// 0x2000000, 0x4000000, 0x8000000, 0x10000000 reserved + +#define HWY_SCALAR 0x20000000 +// Cannot use higher values, otherwise HWY_TARGETS computation might overflow. + +//------------------------------------------------------------------------------ +// Set default blocklists + +// Disabled means excluded from enabled at user's request. A separate config +// macro allows disabling without deactivating the blocklist below. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS 0 +#endif + +// Broken means excluded from enabled due to known compiler issues. Allow the +// user to override this blocklist without any guarantee of success. +#ifndef HWY_BROKEN_TARGETS + +// x86 clang-6: we saw multiple AVX2/3 compile errors and in one case invalid +// SSE4 codegen (msan failure), so disable all those targets. +#if HWY_ARCH_X86 && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +// TODO: Disable all non-scalar targets for every build target once we have +// clang-7 enabled in our builders. +#ifdef MEMORY_SANITIZER +#define HWY_BROKEN_TARGETS (HWY_SSE4 | HWY_AVX2 | HWY_AVX3) +#else +#define HWY_BROKEN_TARGETS 0 +#endif +// This entails a major speed reduction, so warn unless the user explicitly +// opts in to scalar-only. +#if !defined(HWY_COMPILE_ONLY_SCALAR) +#pragma message("x86 Clang <= 6: define HWY_COMPILE_ONLY_SCALAR or upgrade.") +#endif + +// MSVC, or 32-bit may fail to compile AVX2/3. +#elif HWY_COMPILER_MSVC != 0 || HWY_ARCH_X86_32 +#define HWY_BROKEN_TARGETS (HWY_AVX2 | HWY_AVX3) +#pragma message("Disabling AVX2/3 due to known issues with MSVC/32-bit builds") + +#else +#define HWY_BROKEN_TARGETS 0 +#endif + +#endif // HWY_BROKEN_TARGETS + +// Enabled means not disabled nor blocklisted. +#define HWY_ENABLED(targets) \ + ((targets) & ~((HWY_DISABLED_TARGETS) | (HWY_BROKEN_TARGETS))) + +//------------------------------------------------------------------------------ +// Detect baseline targets using predefined macros + +// Baseline means the targets for which the compiler is allowed to generate +// instructions, implying the target CPU would have to support them. Do not use +// this directly because it does not take the blocklist into account. 
Allow the +// user to override this without any guarantee of success. +#ifndef HWY_BASELINE_TARGETS + +#ifdef __wasm_simd128__ +#define HWY_BASELINE_WASM HWY_WASM +#else +#define HWY_BASELINE_WASM 0 +#endif + +#ifdef __VSX__ +#define HWY_BASELINE_PPC8 HWY_PPC8 +#else +#define HWY_BASELINE_PPC8 0 +#endif + +// GCC 4.5.4 only defines the former; 5.4 defines both. +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define HWY_BASELINE_NEON HWY_NEON +#else +#define HWY_BASELINE_NEON 0 +#endif + +#ifdef __SSE4_1__ +#define HWY_BASELINE_SSE4 HWY_SSE4 +#else +#define HWY_BASELINE_SSE4 0 +#endif + +#ifdef __AVX2__ +#define HWY_BASELINE_AVX2 HWY_AVX2 +#else +#define HWY_BASELINE_AVX2 0 +#endif + +#ifdef __AVX512F__ +#define HWY_BASELINE_AVX3 HWY_AVX3 +#else +#define HWY_BASELINE_AVX3 0 +#endif + +#ifdef __riscv_vector +#define HWY_BASELINE_RVV HWY_RVV +#else +#define HWY_BASELINE_RVV 0 +#endif + +#define HWY_BASELINE_TARGETS \ + (HWY_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | HWY_BASELINE_NEON | \ + HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \ + HWY_BASELINE_RVV) + +#endif // HWY_BASELINE_TARGETS + +//------------------------------------------------------------------------------ +// Choose target for static dispatch + +#define HWY_ENABLED_BASELINE HWY_ENABLED(HWY_BASELINE_TARGETS) +#if HWY_ENABLED_BASELINE == 0 +#error "At least one baseline target must be defined and enabled" +#endif + +// Best baseline, used for static dispatch. This is the least-significant 1-bit +// within HWY_ENABLED_BASELINE and lower bit values imply "better". +#define HWY_STATIC_TARGET (HWY_ENABLED_BASELINE & -HWY_ENABLED_BASELINE) + +// Start by assuming static dispatch. If we later use dynamic dispatch, this +// will be defined to other targets during the multiple-inclusion, and finally +// return to the initial value. Defining this outside begin/end_target ensures +// inl headers successfully compile by themselves (required by Bazel). +#define HWY_TARGET HWY_STATIC_TARGET + +//------------------------------------------------------------------------------ +// Choose targets for dynamic dispatch according to one of four policies + +#if (defined(HWY_COMPILE_ONLY_SCALAR) + defined(HWY_COMPILE_ONLY_STATIC) + \ + defined(HWY_COMPILE_ALL_ATTAINABLE)) > 1 +#error "Invalid config: can only define a single policy for targets" +#endif + +// Attainable means enabled and the compiler allows intrinsics (even when not +// allowed to autovectorize). Used in 3 and 4. +#if HWY_ARCH_X86 +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_SCALAR | HWY_SSE4 | HWY_AVX2 | HWY_AVX3) +#else +#define HWY_ATTAINABLE_TARGETS HWY_ENABLED_BASELINE +#endif + +// 1) For older compilers: disable all SIMD (could also set HWY_DISABLED_TARGETS +// to ~HWY_SCALAR, but this is more explicit). +#if defined(HWY_COMPILE_ONLY_SCALAR) +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_SCALAR // override baseline +#define HWY_TARGETS HWY_SCALAR + +// 2) For forcing static dispatch without code changes (removing HWY_EXPORT) +#elif defined(HWY_COMPILE_ONLY_STATIC) +#define HWY_TARGETS HWY_STATIC_TARGET + +// 3) For tests: include all attainable targets (in particular: scalar) +#elif defined(HWY_COMPILE_ALL_ATTAINABLE) +#define HWY_TARGETS HWY_ATTAINABLE_TARGETS + +// 4) Default: attainable WITHOUT non-best baseline. This reduces code size by +// excluding superseded targets, in particular scalar. 
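+// Worked example: with HWY_STATIC_TARGET == HWY_SSE4 (bit value 64), the mask
+// (2 * 64 - 1) == 127 keeps HWY_AVX3 (8), HWY_AVX2 (16) and HWY_SSE4 (64) but
+// drops HWY_SCALAR (0x20000000), since lower bit values denote better targets.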
+#else
+
+#define HWY_TARGETS (HWY_ATTAINABLE_TARGETS & (2 * HWY_STATIC_TARGET - 1))
+
+#endif  // target policy
+
+// HWY_ONCE and the multiple-inclusion mechanism rely on HWY_STATIC_TARGET being
+// one of the dynamic targets. This also implies HWY_TARGETS != 0 and
+// (HWY_TARGETS & HWY_ENABLED_BASELINE) != 0.
+#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0
+#error "Logic error: best baseline should be included in dynamic targets"
+#endif
+
+//------------------------------------------------------------------------------
+
+namespace hwy {
+
+// Returns (cached) bitfield of enabled targets that are supported on this CPU.
+// Implemented in supported_targets.cc; unconditionally compiled to support the
+// use case of binary-only distributions. The HWY_SUPPORTED_TARGETS wrapper may
+// allow eliding calls to this function.
+uint32_t SupportedTargets();
+
+// Disable from runtime dispatch the mask of compiled in targets. Targets that
+// were not enabled at compile time are ignored. This function is useful to
+// disable a target supported by the CPU that is known to have bugs or when a
+// lower target is desired. For this reason, attempts to disable targets which
+// are in HWY_ENABLED_BASELINE have no effect so SupportedTargets() always
+// returns at least the baseline target.
+void DisableTargets(uint32_t disabled_targets);
+
+// Single target: reduce code size by eliding the call and conditional branches
+// inside Choose*() functions.
+#if (HWY_TARGETS & (HWY_TARGETS - 1)) == 0
+#define HWY_SUPPORTED_TARGETS HWY_TARGETS
+#else
+#define HWY_SUPPORTED_TARGETS hwy::SupportedTargets()
+#endif
+
+// Set the mock mask of CPU supported targets instead of the actual CPU
+// supported targets computed in SupportedTargets(). The return value of
+// SupportedTargets() will still be affected by the DisabledTargets() mask
+// regardless of this mock, to prevent accidentally adding targets that are
+// known to be buggy in the current CPU. Call with a mask of 0 to disable the
+// mock and use the actual CPU supported targets instead.
+void SetSupportedTargetsForTest(uint32_t targets);
+
+// Returns whether the SupportedTargets() function was called since the last
+// SetSupportedTargetsForTest() call.
+bool SupportedTargetsCalledForTest();
+
+// Return the list of targets in HWY_TARGETS supported by the CPU as a list of
+// individual HWY_* target macros such as HWY_SCALAR or HWY_NEON. This list
+// is affected by the current SetSupportedTargetsForTest() mock if any.
+HWY_INLINE std::vector<uint32_t> SupportedAndGeneratedTargets() {
+  std::vector<uint32_t> ret;
+  for (uint32_t targets = SupportedTargets() & HWY_TARGETS; targets != 0;
+       targets = targets & (targets - 1)) {
+    uint32_t current_target = targets & ~(targets - 1);
+    ret.push_back(current_target);
+  }
+  return ret;
+}
+
+static inline HWY_MAYBE_UNUSED const char* TargetName(uint32_t target) {
+  switch (target) {
+#if HWY_ARCH_X86
+    case HWY_SSE4:
+      return "SSE4";
+    case HWY_AVX2:
+      return "AVX2";
+    case HWY_AVX3:
+      return "AVX3";
+#endif
+
+#if HWY_ARCH_ARM
+    case HWY_NEON:
+      return "Neon";
+#endif
+
+#if HWY_ARCH_PPC
+    case HWY_PPC8:
+      return "Power8";
+#endif
+
+#if HWY_ARCH_WASM
+    case HWY_WASM:
+      return "Wasm";
+#endif
+
+#if HWY_ARCH_RVV
+    case HWY_RVV:
+      return "RVV";
+#endif
+
+    case HWY_SCALAR:
+      return "Scalar";
+
+    default:
+      return "?";
+  }
+}
+
+// The maximum number of dynamic targets on any architecture is defined by
+// HWY_MAX_DYNAMIC_TARGETS and depends on the arch.
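+//
+// Usage sketch (illustrative, not part of the header): list the targets this
+// binary can dispatch to on the current CPU, e.g. for logging:
+//   for (const uint32_t target : hwy::SupportedAndGeneratedTargets()) {
+//     printf("%s\n", hwy::TargetName(target));
+//   }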
+ +// For the ChosenTarget mask and index we use a different bit arrangement than +// in the HWY_TARGETS mask. Only the targets involved in the current +// architecture are used in this mask, and therefore only the least significant +// (HWY_MAX_DYNAMIC_TARGETS + 2) bits of the uint32_t mask are used. The least +// significant bit is set when the mask is not initialized, the next +// HWY_MAX_DYNAMIC_TARGETS more significant bits are a range of bits from the +// HWY_TARGETS or SupportedTargets() mask for the given architecture shifted to +// that position and the next more significant bit is used for the scalar +// target. Because of this we need to define equivalent values for HWY_TARGETS +// in this representation. +// This mask representation allows to use ctz() on this mask and obtain a small +// number that's used as an index of the table for dynamic dispatch. In this +// way the first entry is used when the mask is uninitialized, the following +// HWY_MAX_DYNAMIC_TARGETS are for dynamic dispatch and the last one is for +// scalar. + +// The HWY_SCALAR bit in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_SCALAR (1u << (HWY_MAX_DYNAMIC_TARGETS + 1)) + +// Converts from a HWY_TARGETS mask to a ChosenTarget mask format for the +// current architecture. +#define HWY_CHOSEN_TARGET_SHIFT(X) \ + ((((X) >> (HWY_HIGHEST_TARGET_BIT + 1 - HWY_MAX_DYNAMIC_TARGETS)) & \ + ((1u << HWY_MAX_DYNAMIC_TARGETS) - 1)) \ + << 1) + +// The HWY_TARGETS mask in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_TARGETS \ + (HWY_CHOSEN_TARGET_SHIFT(HWY_TARGETS) | HWY_CHOSEN_TARGET_MASK_SCALAR | 1u) + +#if HWY_ARCH_X86 +// Maximum number of dynamic targets, changing this value is an ABI incompatible +// change +#define HWY_MAX_DYNAMIC_TARGETS 10 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_X86 +// These must match the order in which the HWY_TARGETS are defined +// starting by the least significant (HWY_HIGHEST_TARGET_BIT + 1 - +// HWY_MAX_DYNAMIC_TARGETS) bit. This list must contain exactly +// HWY_MAX_DYNAMIC_TARGETS elements and does not include SCALAR. The first entry +// corresponds to the best target. Don't include a "," at the end of the list. +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_AVX3(func_name), /* AVX3 */ \ + HWY_CHOOSE_AVX2(func_name), /* AVX2 */ \ + nullptr, /* AVX */ \ + HWY_CHOOSE_SSE4(func_name), /* SSE4 */ \ + nullptr, /* SSSE3 */ \ + nullptr, /* SSE3 */ \ + nullptr /* SSE2 */ + +#endif // HWY_ARCH_X86 + +#if HWY_ARCH_ARM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 4 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_ARM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_NEON(func_name) /* NEON */ + +#endif // HWY_ARCH_ARM + +#if HWY_ARCH_PPC +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 5 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_PPC +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_PPC8(func_name), /* PPC8 */ \ + nullptr, /* VSX */ \ + nullptr /* AltiVec */ + +#endif // HWY_ARCH_PPC + +#if HWY_ARCH_WASM +// See HWY_ARCH_X86 above for details. 
+#define HWY_MAX_DYNAMIC_TARGETS 4
+#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_WASM
+#define HWY_CHOOSE_TARGET_LIST(func_name)       \
+  nullptr,                       /* reserved */ \
+  nullptr,                       /* reserved */ \
+  nullptr,                       /* reserved */ \
+  HWY_CHOOSE_WASM(func_name)     /* WASM */
+
+#endif  // HWY_ARCH_WASM
+
+#if HWY_ARCH_RVV
+// See HWY_ARCH_X86 above for details.
+#define HWY_MAX_DYNAMIC_TARGETS 4
+#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_RVV
+#define HWY_CHOOSE_TARGET_LIST(func_name)       \
+  nullptr,                       /* reserved */ \
+  nullptr,                       /* reserved */ \
+  nullptr,                       /* reserved */ \
+  HWY_CHOOSE_RVV(func_name)      /* RVV */
+
+#endif  // HWY_ARCH_RVV
+
+struct ChosenTarget {
+ public:
+  // Update the ChosenTarget mask based on the current CPU supported
+  // targets.
+  void Update();
+
+  // Reset the ChosenTarget to the uninitialized state.
+  void DeInit() { mask_.store(1); }
+
+  // Whether the ChosenTarget was initialized. This is useful to know whether
+  // any HWY_DYNAMIC_DISPATCH function was called.
+  bool IsInitialized() const { return mask_.load() != 1; }
+
+  // Return the index in the dynamic dispatch table to be used by the current
+  // CPU. Note that this method must be in the header file so it uses the value
+  // of HWY_CHOSEN_TARGET_MASK_TARGETS defined in the translation unit that
+  // calls it, which may be different from others. This allows to only consider
+  // those targets that were actually compiled in this module.
+  size_t HWY_INLINE GetIndex() const {
+    return hwy::Num0BitsBelowLS1Bit_Nonzero32(mask_.load() &
+                                              HWY_CHOSEN_TARGET_MASK_TARGETS);
+  }
+
+ private:
+  // Initialized to 1 so GetChosenTargetIndex() returns 0.
+  std::atomic<uint32_t> mask_{1};
+};
+
+extern ChosenTarget chosen_target;
+
+}  // namespace hwy
+
+#endif  // HIGHWAY_HWY_TARGETS_H_
diff --git a/third_party/highway/hwy/targets_test.cc b/third_party/highway/hwy/targets_test.cc
new file mode 100644
index 000000000000..4cb9291d15f9
--- /dev/null
+++ b/third_party/highway/hwy/targets_test.cc
@@ -0,0 +1,102 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "hwy/targets.h"
+
+#include "hwy/tests/test_util-inl.h"
+
+namespace fake {
+
+#define DECLARE_FUNCTION(TGT)                        \
+  namespace N_##TGT {                                \
+    uint32_t FakeFunction(int) { return HWY_##TGT; } \
+  }
+
+DECLARE_FUNCTION(AVX3)
+DECLARE_FUNCTION(AVX2)
+DECLARE_FUNCTION(SSE4)
+DECLARE_FUNCTION(NEON)
+DECLARE_FUNCTION(PPC8)
+DECLARE_FUNCTION(WASM)
+DECLARE_FUNCTION(RVV)
+DECLARE_FUNCTION(SCALAR)
+
+HWY_EXPORT(FakeFunction);
+
+void CheckFakeFunction() {
+#define CHECK_ARRAY_ENTRY(TGT)                                              \
+  if ((HWY_TARGETS & HWY_##TGT) != 0) {                                     \
+    hwy::SetSupportedTargetsForTest(HWY_##TGT);                             \
+    /* Calling Update() first to make &HWY_DYNAMIC_DISPATCH() return */    \
+    /* the pointer to the already cached function. */                      \
+    hwy::chosen_target.Update();                                            \
+    EXPECT_EQ(uint32_t(HWY_##TGT), HWY_DYNAMIC_DISPATCH(FakeFunction)(42)); \
+    /* Calling DeInit() will test that the initializer function */          \
+    /* also calls the right function.
*/ \ + hwy::chosen_target.DeInit(); \ + EXPECT_EQ(uint32_t(HWY_##TGT), HWY_DYNAMIC_DISPATCH(FakeFunction)(42)); \ + /* Second call uses the cached value from the previous call. */ \ + EXPECT_EQ(uint32_t(HWY_##TGT), HWY_DYNAMIC_DISPATCH(FakeFunction)(42)); \ + } + CHECK_ARRAY_ENTRY(AVX3) + CHECK_ARRAY_ENTRY(AVX2) + CHECK_ARRAY_ENTRY(SSE4) + CHECK_ARRAY_ENTRY(NEON) + CHECK_ARRAY_ENTRY(PPC8) + CHECK_ARRAY_ENTRY(WASM) + CHECK_ARRAY_ENTRY(RVV) + CHECK_ARRAY_ENTRY(SCALAR) +#undef CHECK_ARRAY_ENTRY +} + +} // namespace fake + +namespace hwy { + +class HwyTargetsTest : public testing::Test { + protected: + void TearDown() override { + SetSupportedTargetsForTest(0); + DisableTargets(0); // Reset the mask. + } +}; + +// Test that the order in the HWY_EXPORT static array matches the expected +// value of the target bits. This is only checked for the targets that are +// enabled in the current compilation. +TEST_F(HwyTargetsTest, ChosenTargetOrderTest) { fake::CheckFakeFunction(); } + +TEST_F(HwyTargetsTest, DisabledTargetsTest) { + DisableTargets(~0u); + // Check that the baseline can't be disabled. + HWY_ASSERT(HWY_ENABLED_BASELINE == SupportedTargets()); + + DisableTargets(0); // Reset the mask. + uint32_t current_targets = SupportedTargets(); + if ((current_targets & ~HWY_ENABLED_BASELINE) == 0) { + // We can't test anything else if the only compiled target is the baseline. + return; + } + // Get the lowest bit in the mask (the best target) and disable that one. + uint32_t lowest_target = current_targets & (~current_targets + 1); + // The lowest target shouldn't be one in the baseline. + HWY_ASSERT((lowest_target & ~HWY_ENABLED_BASELINE) != 0); + DisableTargets(lowest_target); + + // Check that the other targets are still enabled. + HWY_ASSERT((lowest_target ^ current_targets) == SupportedTargets()); + DisableTargets(0); // Reset the mask. +} + +} // namespace hwy diff --git a/third_party/highway/hwy/tests/arithmetic_test.cc b/third_party/highway/hwy/tests/arithmetic_test.cc new file mode 100644 index 000000000000..02e4cbde5629 --- /dev/null +++ b/third_party/highway/hwy/tests/arithmetic_test.cc @@ -0,0 +1,1259 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
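+
+// Note: this test is compiled once per target. HWY_TARGET_INCLUDE names this
+// very file, and including foreach_target.h below re-includes it with
+// HWY_TARGET redefined for each enabled target.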
+
+#include <stdint.h>
+#include <string.h>
+
+#include <algorithm>
+#include <limits>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "tests/arithmetic_test.cc"
+#include "hwy/foreach_target.h"
+#include "hwy/highway.h"
+#include "hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+struct TestPlusMinus {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const auto v2 = Iota(d, T(2));
+    const auto v3 = Iota(d, T(3));
+    const auto v4 = Iota(d, T(4));
+
+    const size_t N = Lanes(d);
+    auto lanes = AllocateAligned<T>(N);
+    for (size_t i = 0; i < N; ++i) {
+      lanes[i] = static_cast<T>((2 + i) + (3 + i));
+    }
+    HWY_ASSERT_VEC_EQ(d, lanes.get(), v2 + v3);
+    HWY_ASSERT_VEC_EQ(d, Set(d, 2), Sub(v4, v2));
+
+    for (size_t i = 0; i < N; ++i) {
+      lanes[i] = static_cast<T>((2 + i) + (4 + i));
+    }
+    auto sum = v2;
+    sum = Add(sum, v4);  // sum == 6,8..
+    HWY_ASSERT_VEC_EQ(d, Load(d, lanes.get()), sum);
+
+    sum = Sub(sum, v4);
+    HWY_ASSERT_VEC_EQ(d, v2, sum);
+  }
+};
+
+HWY_NOINLINE void TestAllPlusMinus() {
+  ForAllTypes(ForPartialVectors<TestPlusMinus>());
+}
+
+struct TestUnsignedSaturatingArithmetic {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const auto v0 = Zero(d);
+    const auto vi = Iota(d, 1);
+    const auto vm = Set(d, LimitsMax<T>());
+
+    HWY_ASSERT_VEC_EQ(d, Add(v0, v0), SaturatedAdd(v0, v0));
+    HWY_ASSERT_VEC_EQ(d, Add(v0, vi), SaturatedAdd(v0, vi));
+    HWY_ASSERT_VEC_EQ(d, Add(v0, vm), SaturatedAdd(v0, vm));
+    HWY_ASSERT_VEC_EQ(d, vm, SaturatedAdd(vi, vm));
+    HWY_ASSERT_VEC_EQ(d, vm, SaturatedAdd(vm, vm));
+
+    HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, v0));
+    HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, vi));
+    HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(vi, vi));
+    HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(vi, vm));
+    HWY_ASSERT_VEC_EQ(d, Sub(vm, vi), SaturatedSub(vm, vi));
+  }
+};
+
+struct TestSignedSaturatingArithmetic {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const auto v0 = Zero(d);
+    const auto vpm = Set(d, LimitsMax<T>());
+    // Ensure all lanes are positive, even if Iota wraps around
+    const auto vi = Or(And(Iota(d, 0), vpm), Set(d, 1));
+    const auto vn = Sub(v0, vi);
+    const auto vnm = Set(d, LimitsMin<T>());
+    HWY_ASSERT_MASK_EQ(d, MaskTrue(d), Gt(vi, v0));
+    HWY_ASSERT_MASK_EQ(d, MaskTrue(d), Lt(vn, v0));
+
+    HWY_ASSERT_VEC_EQ(d, v0, SaturatedAdd(v0, v0));
+    HWY_ASSERT_VEC_EQ(d, vi, SaturatedAdd(v0, vi));
+    HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(v0, vpm));
+    HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(vi, vpm));
+    HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(vpm, vpm));
+
+    HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, v0));
+    HWY_ASSERT_VEC_EQ(d, Sub(v0, vi), SaturatedSub(v0, vi));
+    HWY_ASSERT_VEC_EQ(d, vn, SaturatedSub(vn, v0));
+    HWY_ASSERT_VEC_EQ(d, vnm, SaturatedSub(vnm, vi));
+    HWY_ASSERT_VEC_EQ(d, vnm, SaturatedSub(vnm, vpm));
+  }
+};
+
+HWY_NOINLINE void TestAllSaturatingArithmetic() {
+  const ForPartialVectors<TestUnsignedSaturatingArithmetic> test_unsigned;
+  test_unsigned(uint8_t());
+  test_unsigned(uint16_t());
+
+  const ForPartialVectors<TestSignedSaturatingArithmetic> test_signed;
+  test_signed(int8_t());
+  test_signed(int16_t());
+}
+
+struct TestAverage {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const auto v0 = Zero(d);
+    const auto v1 = Set(d, T(1));
+    const auto v2 = Set(d, T(2));
+
+    HWY_ASSERT_VEC_EQ(d, v0, AverageRound(v0, v0));
+    HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v0, v1));
+    HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v1, v1));
+    HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v1, v2));
+    HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v2, v2));
+  }
+};
+
+HWY_NOINLINE void TestAllAverage() {
+  const ForPartialVectors<TestAverage> test;
+  test(uint8_t());
+  test(uint16_t());
+}
+
+struct TestAbs {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const auto v0 = Zero(d);
+    const auto vp1 = Set(d, T(1));
+    const auto vn1 = Set(d, T(-1));
+    const auto vpm = Set(d, LimitsMax<T>());
+    const auto vnm = Set(d, LimitsMin<T>());
+
+    HWY_ASSERT_VEC_EQ(d, v0, Abs(v0));
+    HWY_ASSERT_VEC_EQ(d, vp1, Abs(vp1));
+    HWY_ASSERT_VEC_EQ(d, vp1, Abs(vn1));
+    HWY_ASSERT_VEC_EQ(d, vpm, Abs(vpm));
+    HWY_ASSERT_VEC_EQ(d, vnm, Abs(vnm));
+  }
+};
+
+struct TestFloatAbs {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const auto v0 = Zero(d);
+    const auto vp1 = Set(d, T(1));
+    const auto vn1 = Set(d, T(-1));
+    const auto vp2 = Set(d, T(0.01));
+    const auto vn2 = Set(d, T(-0.01));
+
+    HWY_ASSERT_VEC_EQ(d, v0, Abs(v0));
+    HWY_ASSERT_VEC_EQ(d, vp1, Abs(vp1));
+    HWY_ASSERT_VEC_EQ(d, vp1, Abs(vn1));
+    HWY_ASSERT_VEC_EQ(d, vp2, Abs(vp2));
+    HWY_ASSERT_VEC_EQ(d, vp2, Abs(vn2));
+  }
+};
+
+HWY_NOINLINE void TestAllAbs() {
+  const ForPartialVectors<TestAbs> test;
+  test(int8_t());
+  test(int16_t());
+  test(int32_t());
+
+  const ForPartialVectors<TestFloatAbs> test_float;
+  test_float(float());
+#if HWY_CAP_FLOAT64
+  test_float(double());
+#endif
+}
+
+template <bool kSigned>
+struct TestLeftShifts {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T t, D d) {
+    if (kSigned) {
+      // Also test positive values
+      TestLeftShifts<false>()(t, d);
+    }
+
+    using TI = MakeSigned<T>;
+    using TU = MakeUnsigned<T>;
+    const size_t N = Lanes(d);
+    auto expected = AllocateAligned<T>(N);
+
+    const auto values = Iota(d, kSigned ? -TI(N) : TI(0));  // value to shift
+    constexpr size_t kMaxShift = (sizeof(T) * 8) - 1;
+
+    // 0
+    HWY_ASSERT_VEC_EQ(d, values, ShiftLeft<0>(values));
+    HWY_ASSERT_VEC_EQ(d, values, ShiftLeftSame(values, 0));
+
+    // 1
+    for (size_t i = 0; i < N; ++i) {
+      const T value = kSigned ? T(i) - T(N) : T(i);
+      expected[i] = T(TU(value) << 1);
+    }
+    HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft<1>(values));
+    HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, 1));
+
+    // max
+    for (size_t i = 0; i < N; ++i) {
+      const T value = kSigned ? T(i) - T(N) : T(i);
+      expected[i] = T(TU(value) << kMaxShift);
+    }
+    HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft<kMaxShift>(values));
+    HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, kMaxShift));
+  }
+};
+
+template <bool kSigned>
+struct TestVariableLeftShifts {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T t, D d) {
+    if (kSigned) {
+      // Also test positive values
+      TestVariableLeftShifts<false>()(t, d);
+    }
+
+    using TI = MakeSigned<T>;
+    using TU = MakeUnsigned<T>;
+    const size_t N = Lanes(d);
+    auto expected = AllocateAligned<T>(N);
+
+    const auto v0 = Zero(d);
+    const auto v1 = Set(d, 1);
+    const auto values = Iota(d, kSigned ? -TI(N) : TI(0));  // value to shift
+
+    constexpr size_t kMaxShift = (sizeof(T) * 8) - 1;
+    const auto max_shift = Set(d, kMaxShift);
+    const auto small_shifts = And(Iota(d, 0), max_shift);
+    const auto large_shifts = max_shift - small_shifts;
+
+    // Same: 0
+    HWY_ASSERT_VEC_EQ(d, values, Shl(values, v0));
+
+    // Same: 1
+    for (size_t i = 0; i < N; ++i) {
+      const T value = kSigned ? T(i) - T(N) : T(i);
+      expected[i] = T(TU(value) << 1);
+    }
+    HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, v1));
+
+    // Same: max
+    for (size_t i = 0; i < N; ++i) {
+      const T value = kSigned ? T(i) - T(N) : T(i);
+      expected[i] = T(TU(value) << kMaxShift);
+    }
+    HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, max_shift));
+
+    // Variable: small
+    for (size_t i = 0; i < N; ++i) {
+      const T value = kSigned ? T(i) - T(N) : T(i);
T(i) - T(N) : T(i); + expected[i] = T(TU(value) << (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, small_shifts)); + + // Variable: large + for (size_t i = 0; i < N; ++i) { + expected[i] = T(TU(1) << (kMaxShift - (i & kMaxShift))); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(v1, large_shifts)); + } +}; + +struct TestUnsignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto values = Iota(d, 0); + + const T kMax = LimitsMax(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, kMaxShift)); + } +}; + +struct TestVariableUnsignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto values = Iota(d, 0); + + const T kMax = LimitsMax(); + const auto max = Set(d, kMax); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + // Same: 0 + HWY_ASSERT_VEC_EQ(d, values, Shr(values, v0)); + + // Same: 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, v1)); + + // Same: max + HWY_ASSERT_VEC_EQ(d, v0, Shr(values, max_shift)); + + // Variable: small + for (size_t i = 0; i < N; ++i) { + expected[i] = T(i) >> (i & kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, small_shifts)); + + // Variable: Large + for (size_t i = 0; i < N; ++i) { + expected[i] = kMax >> (kMaxShift - (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(max, large_shifts)); + } +}; + +template +T RightShiftNegative(T val) { + // C++ shifts are implementation-defined for negative numbers, and we have + // seen divisions replaced with shifts, so resort to bit operations. + using TU = hwy::MakeUnsigned; + TU bits; + CopyBytes(&val, &bits); + + const TU shifted = bits >> kAmount; + + const TU all = ~TU(0); + const size_t num_zero = sizeof(TU) * 8 - 1 - kAmount; + const TU sign_extended = static_cast((all << num_zero) & LimitsMax()); + + bits = shifted | sign_extended; + CopyBytes(&bits, &val); + return val; +} + +class TestSignedRightShifts { + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + constexpr T kMin = LimitsMin(); + constexpr T kMax = LimitsMax(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. 
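+    // Editor's addition, not upstream Highway: sanity-check the scalar
+    // reference helper on a known value first. Arithmetic right shift of an
+    // even negative number equals division by two, e.g. -128 >> 1 == -64.
+    HWY_ASSERT_EQ(T(kMin / 2), RightShiftNegative<1>(kMin));
+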
+ const auto v0 = Zero(d); + const auto values = Iota(d, 0) & Set(d, kMax); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(values)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(values, kMaxShift)); + + // Even negative value + Test<0>(kMin, d, __LINE__); + Test<1>(kMin, d, __LINE__); + Test<2>(kMin, d, __LINE__); + Test(kMin, d, __LINE__); + + const T odd = static_cast(kMin + 1); + Test<0>(odd, d, __LINE__); + Test<1>(odd, d, __LINE__); + Test<2>(odd, d, __LINE__); + Test(odd, d, __LINE__); + } + + private: + template + void Test(T val, D d, int line) { + const auto expected = Set(d, RightShiftNegative(val)); + const auto in = Set(d, val); + const char* file = __FILE__; + AssertVecEqual(d, expected, ShiftRight(in), file, line); + AssertVecEqual(d, expected, ShiftRightSame(in, kAmount), file, line); + } +}; + +struct TestVariableSignedRightShifts { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + constexpr T kMin = LimitsMin(); + constexpr T kMax = LimitsMax(); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. + const auto v0 = Zero(d); + const auto positive = Iota(d, 0) & Set(d, kMax); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, positive, ShiftRight<0>(positive)); + HWY_ASSERT_VEC_EQ(d, positive, ShiftRightSame(positive, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(positive)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(positive, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight(positive)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(positive, kMaxShift)); + + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + const auto negative = Iota(d, kMin); + + // Test varying negative to shift + for (size_t i = 0; i < N; ++i) { + expected[i] = RightShiftNegative<1>(static_cast(kMin + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(negative, Set(d, 1))); + + // Shift MSB right by small amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = i & kMaxShift; + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopyBytes(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), small_shifts)); + + // Shift MSB right by large amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = kMaxShift - (i & kMaxShift); + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopyBytes(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), large_shifts)); + } +}; + +HWY_NOINLINE void TestAllShifts() { + ForUnsignedTypes(ForPartialVectors>()); + ForSignedTypes(ForPartialVectors>()); + ForUnsignedTypes(ForPartialVectors()); + ForSignedTypes(ForPartialVectors()); +} + +HWY_NOINLINE void TestAllVariableShifts() { + const ForPartialVectors> shl_u; + const ForPartialVectors> shl_s; + const ForPartialVectors shr_u; + const ForPartialVectors shr_s; + + shl_u(uint16_t()); + 
shr_u(uint16_t()); + + shl_u(uint32_t()); + shr_u(uint32_t()); + + shl_s(int16_t()); + shr_s(int16_t()); + + shl_s(int32_t()); + shr_s(int32_t()); + +#if HWY_CAP_INTEGER64 + shl_u(uint64_t()); + shr_u(uint64_t()); + + shl_s(int64_t()); + shr_s(int64_t()); +#endif +} + +struct TestUnsignedMinMax { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + // Leave headroom such that v1 < v2 even after wraparound. + const auto mod = And(Iota(d, 0), Set(d, LimitsMax() >> 1)); + const auto v1 = Add(mod, Set(d, 1)); + const auto v2 = Add(mod, Set(d, 2)); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v0, Min(v1, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v0)); + + const auto vmin = Set(d, LimitsMin()); + const auto vmax = Set(d, LimitsMax()); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +struct TestSignedMinMax { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Leave headroom such that v1 < v2 even after wraparound. + const auto mod = And(Iota(d, 0), Set(d, LimitsMax() >> 1)); + const auto v1 = Add(mod, Set(d, 1)); + const auto v2 = Add(mod, Set(d, 2)); + const auto v_neg = Sub(Zero(d), v1); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v_neg, Min(v1, v_neg)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v_neg)); + + const auto v0 = Zero(d); + const auto vmin = Set(d, LimitsMin()); + const auto vmax = Set(d, LimitsMax()); + HWY_ASSERT_VEC_EQ(d, vmin, Min(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Max(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, v0, Max(vmin, v0)); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +struct TestFloatMinMax { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + const auto v_neg = Iota(d, -T(Lanes(d))); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v_neg, Min(v1, v_neg)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v_neg)); + + const auto v0 = Zero(d); + const auto vmin = Set(d, T(-1E30)); + const auto vmax = Set(d, T(1E30)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Max(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, v0, Max(vmin, v0)); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +HWY_NOINLINE void TestAllMinMax() { + ForUnsignedTypes(ForPartialVectors()); + ForSignedTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +struct TestUnsignedMul { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto vi = Iota(d, 1); + const auto vj = Iota(d, 3); + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + HWY_ASSERT_VEC_EQ(d, v0, Mul(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Mul(v1, v1)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(v1, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(vi, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((1 + i) * 
(1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), vi * vi); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((1 + i) * (3 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vj)); + + const T max = LimitsMax(); + const auto vmax = Set(d, max); + HWY_ASSERT_VEC_EQ(d, vmax, Mul(vmax, v1)); + HWY_ASSERT_VEC_EQ(d, vmax, Mul(v1, vmax)); + + const size_t bits = sizeof(T) * 8; + const uint64_t mask = (1ull << bits) - 1; + const T max2 = (uint64_t(max) * max) & mask; + HWY_ASSERT_VEC_EQ(d, Set(d, max2), Mul(vmax, vmax)); + } +}; + +struct TestSignedMul { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto vi = Iota(d, 1); + const auto vn = Iota(d, -T(N)); // no i8 supported, so no wraparound + HWY_ASSERT_VEC_EQ(d, v0, Mul(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Mul(v1, v1)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(v1, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(vi, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((1 + i) * (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vi)); + + for (int i = 0; i < static_cast(N); ++i) { + expected[i] = static_cast((-T(N) + i) * (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vn, vi)); + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vn)); + } +}; + +HWY_NOINLINE void TestAllMul() { + const ForPartialVectors test_unsigned; + // No u8. + test_unsigned(uint16_t()); + test_unsigned(uint32_t()); + // No u64. + + const ForPartialVectors test_signed; + // No i8. + test_signed(int16_t()); + test_signed(int32_t()); + // No i64. +} + +struct TestMulHigh { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Wide = MakeWide; + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + auto expected_lanes = AllocateAligned(N); + + const auto vi = Iota(d, 1); + const auto vni = Iota(d, -T(N)); // no i8 supported, so no wraparound + + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(vi, v0)); + + // Large positive squared + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = LimitsMax() >> i; + expected_lanes[i] = (Wide(in_lanes[i]) * in_lanes[i]) >> 16; + } + auto v = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, v)); + + // Large positive * small positive + for (int i = 0; i < static_cast(N); ++i) { + expected_lanes[i] = static_cast((Wide(in_lanes[i]) * T(1 + i)) >> 16); + } + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, vi)); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(vi, v)); + + // Large positive * small negative + for (size_t i = 0; i < N; ++i) { + expected_lanes[i] = (Wide(in_lanes[i]) * T(i - N)) >> 16; + } + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, vni)); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(vni, v)); + } +}; + +HWY_NOINLINE void TestAllMulHigh() { + ForPartialVectors test; + test(int16_t()); + test(uint16_t()); +} + +struct TestMulEven { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Wide = MakeWide; + const Repartition d2; + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d2, Zero(d2), MulEven(v0, v0)); + + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + auto expected = AllocateAligned(Lanes(d2)); + for (size_t i = 0; i < N; i += 2) { + in_lanes[i + 0] = LimitsMax() >> i; + if (N != 1) { + in_lanes[i + 1] = 1; // unused + } + 
expected[i / 2] = Wide(in_lanes[i + 0]) * in_lanes[i + 0]; + } + + const auto v = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(d2, expected.get(), MulEven(v, v)); + } +}; + +HWY_NOINLINE void TestAllMulEven() { + ForPartialVectors test; + test(int32_t()); + test(uint32_t()); +} + +struct TestMulAdd { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto k0 = Zero(d); + const auto kNeg0 = Set(d, T(-0.0)); + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + HWY_ASSERT_VEC_EQ(d, k0, MulAdd(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, v2, MulAdd(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, MulAdd(v1, k0, v2)); + HWY_ASSERT_VEC_EQ(d, k0, NegMulAdd(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(v1, k0, v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 1) * (i + 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v1, v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(Neg(v2), v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v1, Neg(v2), k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 2) * (i + 2) + (i + 1)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v2, v1)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(Neg(v2), v2, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = -T(i + 2) * (i + 2) + (1 + i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v2, v2, v1)); + + HWY_ASSERT_VEC_EQ(d, k0, MulSub(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, kNeg0, NegMulSub(k0, k0, k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = -T(i + 2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v1, k0, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(k0), v1, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(v1, Neg(k0), v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 1) * (i + 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v1, v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v2, v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(v1), v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(v2, Neg(v1), k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast((i + 2) * (i + 2) - (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v2, v2, v1)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(v2), v2, v1)); + } +}; + +HWY_NOINLINE void TestAllMulAdd() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestDiv { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(-2)); + const auto v1 = Set(d, T(1)); + + // Unchanged after division by 1. 
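+    // (Editor's note: IEEE-754 guarantees x / 1 is exact for every finite x,
+    // so the exact-equality checks here are safe even for float lanes.
+    // The negated check is an editor's addition; negation is also exact.)
+    HWY_ASSERT_VEC_EQ(d, Neg(v), Div(Neg(v), v1));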
+ HWY_ASSERT_VEC_EQ(d, v, Div(v, v1)); + + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = (T(i) - 2) / T(2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Div(v, Set(d, T(2)))); + } +}; + +HWY_NOINLINE void TestAllDiv() { ForFloatTypes(ForPartialVectors()); } + +struct TestApproximateReciprocal { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(-2)); + const auto nonzero = IfThenElse(Eq(v, Zero(d)), Set(d, T(1)), v); + const size_t N = Lanes(d); + auto input = AllocateAligned(N); + Store(nonzero, d, input.get()); + + auto actual = AllocateAligned(N); + Store(ApproximateReciprocal(nonzero), d, actual.get()); + + double max_l1 = 0.0; + for (size_t i = 0; i < N; ++i) { + max_l1 = std::max(max_l1, std::abs((1.0 / input[i]) - actual[i])); + } + const double max_rel = max_l1 / std::abs(1.0 / input[N - 1]); + printf("max err %f\n", max_rel); + + HWY_ASSERT(max_rel < 0.002); + } +}; + +HWY_NOINLINE void TestAllApproximateReciprocal() { + ForPartialVectors()(float()); +} + +struct TestSquareRoot { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto vi = Iota(d, 0); + HWY_ASSERT_VEC_EQ(d, vi, Sqrt(vi * vi)); + } +}; + +HWY_NOINLINE void TestAllSquareRoot() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestReciprocalSquareRoot { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Set(d, 123.0f); + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + Store(ApproximateReciprocalSqrt(v), d, lanes.get()); + for (size_t i = 0; i < N; ++i) { + float err = lanes[i] - 0.090166f; + if (err < 0.0f) err = -err; + HWY_ASSERT(err < 1E-4f); + } + } +}; + +HWY_NOINLINE void TestAllReciprocalSquareRoot() { + ForPartialVectors()(float()); +} + +template +AlignedFreeUniquePtr RoundTestCases(T /*unused*/, D d, size_t& padded) { + const T eps = std::numeric_limits::epsilon(); + const T test_cases[] = { + // +/- 1 + T(1), T(-1), + // +/- 0 + T(0), T(-0), + // near 0 + T(0.4), T(-0.4), + // +/- integer + T(4), T(-32), + // positive near limit + MantissaEnd() - T(1.5), MantissaEnd() + T(1.5), + // negative near limit + -MantissaEnd() - T(1.5), -MantissaEnd() + T(1.5), + // +/- huge (but still fits in float) + T(1E34), T(-1E35), + // positive tiebreak + T(1.5), T(2.5), + // negative tiebreak + T(-1.5), T(-2.5), + // positive +/- delta + T(2.0001), T(3.9999), + // negative +/- delta + T(-999.9999), T(-998.0001), + // positive +/- epsilon + T(1) + eps, T(1) - eps, + // negative +/- epsilon + T(-1) + eps, T(-1) - eps, + // +/- infinity + std::numeric_limits::infinity(), -std::numeric_limits::infinity(), + // qNaN + GetLane(NaN(d))}; + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned(padded); + auto expected = AllocateAligned(padded); + std::copy(test_cases, test_cases + kNumTestCases, in.get()); + std::fill(in.get() + kNumTestCases, in.get() + padded, T(0)); + return in; +} + +struct TestRound { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + // Avoid [std::]round, which does not round to nearest *even*. 
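+      // (Example: std::round(2.5) == 3, whereas round-to-nearest-even gives
+      // 2; nearbyint matches the latter under the default FE_TONEAREST
+      // rounding mode, which is what the vector Round op implements.)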
+ // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + expected[i] = nearbyint(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Round(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllRound() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestNearestInt { + template + HWY_NOINLINE void operator()(TF tf, const DF df) { + using TI = MakeSigned; + const RebindToSigned di; + + size_t padded; + auto in = RoundTestCases(tf, df, padded); + auto expected = AllocateAligned(padded); + + constexpr double max = static_cast(LimitsMax()); + for (size_t i = 0; i < padded; ++i) { + if (std::isnan(in[i])) { + // We replace NaN with 0 below (no_nan) + expected[i] = 0; + } else if (std::isinf(in[i]) || double(std::abs(in[i])) >= max) { + // Avoid undefined result for lrintf + expected[i] = std::signbit(in[i]) ? LimitsMin() : LimitsMax(); + } else { + expected[i] = lrintf(in[i]); + } + } + for (size_t i = 0; i < padded; i += Lanes(df)) { + const auto v = Load(df, &in[i]); + const auto no_nan = IfThenElse(Eq(v, v), v, Zero(df)); + HWY_ASSERT_VEC_EQ(di, &expected[i], NearestInt(no_nan)); + } + } +}; + +HWY_NOINLINE void TestAllNearestInt() { + ForPartialVectors()(float()); +} + +struct TestTrunc { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + expected[i] = trunc(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Trunc(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllTrunc() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestCeil { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + expected[i] = std::ceil(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Ceil(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllCeil() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestFloor { + template + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned(padded); + + for (size_t i = 0; i < padded; ++i) { + expected[i] = std::floor(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Floor(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllFloor() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestSumOfLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + + // Lane i = bit i, higher lanes 0 + double sum = 0.0; + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? 
static_cast(1ull << i) : 0; + sum += static_cast(in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, T(sum)), SumOfLanes(Load(d, in_lanes.get()))); + + // Lane i = i (iota) to include upper lanes + sum = 0.0; + for (size_t i = 0; i < N; ++i) { + sum += static_cast(i); + } + HWY_ASSERT_VEC_EQ(d, Set(d, T(sum)), SumOfLanes(Iota(d, 0))); + } +}; + +HWY_NOINLINE void TestAllSumOfLanes() { + // Only full vectors because lanes in partial vectors are undefined. + const ForFullVectors sum; + + // No u8/u16/i8/i16. + sum(uint32_t()); + sum(int32_t()); + +#if HWY_CAP_INTEGER64 + sum(uint64_t()); + sum(int64_t()); +#endif + + ForFloatTypes(sum); +} + +struct TestMinOfLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + + // Lane i = bit i, higher lanes = 2 (not the minimum) + T min = HighestValue(); + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast(1ull << i) : 2; + min = std::min(min, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(Load(d, in_lanes.get()))); + + // Lane i = N - i to include upper lanes + min = HighestValue(); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = static_cast(N - i); // no 8-bit T so no wraparound + min = std::min(min, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(Load(d, in_lanes.get()))); + } +}; + +struct TestMaxOfLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + + T max = LowestValue(); + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast(1ull << i) : 0; + max = std::max(max, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(Load(d, in_lanes.get()))); + + // Lane i = i to include upper lanes + max = LowestValue(); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = static_cast(i); // no 8-bit T so no wraparound + max = std::max(max, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(Load(d, in_lanes.get()))); + } +}; + +HWY_NOINLINE void TestAllMinMaxOfLanes() { + // Only full vectors because lanes in partial vectors are undefined. + const ForFullVectors min; + const ForFullVectors max; + + // No u8/u16/i8/i16. 
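+  // (Editor's note: as with SumOfLanes above, MinOfLanes/MaxOfLanes are only
+  // provided for 32/64-bit integer and floating-point lanes here.)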
+ min(uint32_t()); + max(uint32_t()); + min(int32_t()); + max(int32_t()); + +#if HWY_CAP_INTEGER64 + min(uint64_t()); + max(uint64_t()); + min(int64_t()); + max(int64_t()); +#endif + + ForFloatTypes(min); + ForFloatTypes(max); +} + +struct TestAbsDiff { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes_a = AllocateAligned(N); + auto in_lanes_b = AllocateAligned(N); + auto out_lanes = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + in_lanes_a[i] = static_cast((i ^ 1u) << i); + in_lanes_b[i] = static_cast(i << i); + out_lanes[i] = std::abs(in_lanes_a[i] - in_lanes_b[i]); + } + const auto a = Load(d, in_lanes_a.get()); + const auto b = Load(d, in_lanes_b.get()); + const auto expected = Load(d, out_lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected, AbsDiff(a, b)); + HWY_ASSERT_VEC_EQ(d, expected, AbsDiff(b, a)); + } +}; + +HWY_NOINLINE void TestAllAbsDiff() { + ForPartialVectors()(float()); +} + +struct TestNeg { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vn = Set(d, T(-3)); + const auto vp = Set(d, T(3)); + HWY_ASSERT_VEC_EQ(d, v0, Neg(v0)); + HWY_ASSERT_VEC_EQ(d, vp, Neg(vn)); + HWY_ASSERT_VEC_EQ(d, vn, Neg(vp)); + } +}; + +HWY_NOINLINE void TestAllNeg() { + ForSignedTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HwyArithmeticTest); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllPlusMinus); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllSaturatingArithmetic); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllShifts); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllVariableShifts); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAverage); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAbs); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMul); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMulHigh); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMulEven); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMulAdd); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllDiv); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllApproximateReciprocal); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllSquareRoot); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllReciprocalSquareRoot); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllSumOfLanes); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMaxOfLanes); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllRound); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllNearestInt); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllTrunc); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllCeil); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllFloor); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAbsDiff); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllNeg); +} // namespace hwy +#endif diff --git a/third_party/highway/hwy/tests/combine_test.cc b/third_party/highway/hwy/tests/combine_test.cc new file mode 100644 index 000000000000..4f7942f67cc5 --- /dev/null +++ b/third_party/highway/hwy/tests/combine_test.cc @@ -0,0 +1,287 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "tests/combine_test.cc"
+#include "hwy/foreach_target.h"
+
+#include "hwy/highway.h"
+#include "hwy/tests/test_util-inl.h"
+
+// Not yet implemented
+#if HWY_TARGET != HWY_RVV
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+struct TestLowerHalf {
+  template <class T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const Half<D> d2;
+
+    const size_t N = Lanes(d);
+    auto lanes = AllocateAligned<T>(N);
+    std::fill(lanes.get(), lanes.get() + N, T(0));
+    const auto v = Iota(d, 1);
+    Store(LowerHalf(v), d2, lanes.get());
+    size_t i = 0;
+    for (; i < Lanes(d2); ++i) {
+      HWY_ASSERT_EQ(T(1 + i), lanes[i]);
+    }
+    // Other half remains unchanged
+    for (; i < N; ++i) {
+      HWY_ASSERT_EQ(T(0), lanes[i]);
+    }
+  }
+};
+
+struct TestLowerQuarter {
+  template <class T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const Half<Half<D>> d4;
+
+    const size_t N = Lanes(d);
+    auto lanes = AllocateAligned<T>(N);
+    std::fill(lanes.get(), lanes.get() + N, T(0));
+    const auto v = Iota(d, 1);
+    const auto lo = LowerHalf(LowerHalf(v));
+    Store(lo, d4, lanes.get());
+    size_t i = 0;
+    for (; i < Lanes(d4); ++i) {
+      HWY_ASSERT_EQ(T(i + 1), lanes[i]);
+    }
+    // Upper 3/4 remain unchanged
+    for (; i < N; ++i) {
+      HWY_ASSERT_EQ(T(0), lanes[i]);
+    }
+  }
+};
+
+HWY_NOINLINE void TestAllLowerHalf() {
+  constexpr size_t kDiv = 1;
+  ForAllTypes(ForPartialVectors<TestLowerHalf, kDiv, 2>());
+  ForAllTypes(ForPartialVectors<TestLowerQuarter, kDiv, 4>());
+}
+
+struct TestUpperHalf {
+  template <class T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    // Scalar does not define UpperHalf.
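+    // (Editor's note: a single-lane vector cannot be split in half, so on
+    // the scalar target the body below is compiled out and the functor is a
+    // no-op.)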
+#if HWY_TARGET != HWY_SCALAR + const Half d2; + + const auto v = Iota(d, 1); + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + + Store(UpperHalf(v), d2, lanes.get()); + size_t i = 0; + for (; i < Lanes(d2); ++i) { + HWY_ASSERT_EQ(T(Lanes(d2) + 1 + i), lanes[i]); + } + // Other half remains unchanged + for (; i < N; ++i) { + HWY_ASSERT_EQ(T(0), lanes[i]); + } +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllUpperHalf() { + ForAllTypes(ForGE128Vectors()); +} + +struct TestZeroExtendVector { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_CAP_GE256 + const Twice d2; + + const auto v = Iota(d, 1); + const size_t N2 = Lanes(d2); + auto lanes = AllocateAligned(N2); + Store(v, d, &lanes[0]); + Store(v, d, &lanes[N2 / 2]); + + const auto ext = ZeroExtendVector(v); + Store(ext, d2, lanes.get()); + + size_t i = 0; + // Lower half is unchanged + for (; i < N2 / 2; ++i) { + HWY_ASSERT_EQ(T(1 + i), lanes[i]); + } + // Upper half is zero + for (; i < N2; ++i) { + HWY_ASSERT_EQ(T(0), lanes[i]); + } +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllZeroExtendVector() { + ForAllTypes(ForExtendableVectors()); +} + +struct TestCombine { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_CAP_GE256 + const Twice d2; + const size_t N2 = Lanes(d2); + auto lanes = AllocateAligned(N2); + + const auto lo = Iota(d, 1); + const auto hi = Iota(d, N2 / 2 + 1); + const auto combined = Combine(hi, lo); + Store(combined, d2, lanes.get()); + + const auto expected = Iota(d2, 1); + HWY_ASSERT_VEC_EQ(d2, expected, combined); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllCombine() { + ForAllTypes(ForExtendableVectors()); +} + + +template +struct TestCombineShiftRightBytesR { + template + HWY_NOINLINE void operator()(T t, D d) { +// Scalar does not define CombineShiftRightBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const Repartition d8; + const size_t N8 = Lanes(d8); + const auto lo = BitCast(d, Iota(d8, 1)); + const auto hi = BitCast(d, Iota(d8, 1 + N8)); + + auto expected = AllocateAligned(Lanes(d)); + uint8_t* expected_bytes = reinterpret_cast(expected.get()); + + const size_t kBlockSize = 16; + for (size_t i = 0; i < N8; ++i) { + const size_t block = i / kBlockSize; + const size_t lane = i % kBlockSize; + const size_t first_lo = block * kBlockSize; + const size_t idx = lane + kBytes; + const size_t offset = (idx < kBlockSize) ? 0 : N8 - kBlockSize; + const bool at_end = idx >= 2 * kBlockSize; + expected_bytes[i] = + at_end ? 0 : static_cast(first_lo + idx + 1 + offset); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), + CombineShiftRightBytes(hi, lo)); + + TestCombineShiftRightBytesR()(t, d); +#else + (void)t; + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +template +struct TestCombineShiftRightLanesR { + template + HWY_NOINLINE void operator()(T t, D d) { +// Scalar does not define CombineShiftRightBytes (needed for *Lanes). 
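+// (Editor's worked example: with 4 lanes per 128-bit block and kLanes == 1,
+// CombineShiftRightLanes returns {lo[1], lo[2], lo[3], hi[0]}, i.e. the
+// concatenation hi:lo shifted right by one full lane.)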
+#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const Repartition d8; + const size_t N8 = Lanes(d8); + const auto lo = BitCast(d, Iota(d8, 1)); + const auto hi = BitCast(d, Iota(d8, 1 + N8)); + + auto expected = AllocateAligned(Lanes(d)); + + uint8_t* expected_bytes = reinterpret_cast(expected.get()); + + const size_t kBlockSize = 16; + for (size_t i = 0; i < N8; ++i) { + const size_t block = i / kBlockSize; + const size_t lane = i % kBlockSize; + const size_t first_lo = block * kBlockSize; + const size_t idx = lane + kLanes * sizeof(T); + const size_t offset = (idx < kBlockSize) ? 0 : N8 - kBlockSize; + const bool at_end = idx >= 2 * kBlockSize; + expected_bytes[i] = + at_end ? 0 : static_cast(first_lo + idx + 1 + offset); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), + CombineShiftRightLanes(hi, lo)); + + TestCombineShiftRightBytesR()(t, d); +#else + (void)t; + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +template <> +struct TestCombineShiftRightBytesR<0> { + template + void operator()(T /*unused*/, D /*unused*/) {} +}; + +template <> +struct TestCombineShiftRightLanesR<0> { + template + void operator()(T /*unused*/, D /*unused*/) {} +}; + +struct TestCombineShiftRight { + template + HWY_NOINLINE void operator()(T t, D d) { + TestCombineShiftRightBytesR<15>()(t, d); + TestCombineShiftRightLanesR<16 / sizeof(T) - 1>()(t, d); + } +}; + +HWY_NOINLINE void TestAllCombineShiftRight() { + ForAllTypes(ForGE128Vectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HwyCombineTest); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllLowerHalf); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllUpperHalf); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllZeroExtendVector); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllCombine); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllCombineShiftRight); +} // namespace hwy +#endif + +#else +int main(int, char**) { return 0; } +#endif // HWY_TARGET != HWY_RVV diff --git a/third_party/highway/hwy/tests/compare_test.cc b/third_party/highway/hwy/tests/compare_test.cc new file mode 100644 index 000000000000..9e7803b87aca --- /dev/null +++ b/third_party/highway/hwy/tests/compare_test.cc @@ -0,0 +1,217 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memset + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/compare_test.cc" +#include "hwy/foreach_target.h" +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// All types. 
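+// (Editor's note: a mask lane is canonically all-bits-set when true; that is
+// why TestMask below fills memory with 0xFF bytes rather than storing 1s to
+// build the expected true mask.)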
+struct TestMask { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + + std::fill(lanes.get(), lanes.get() + N, T(0)); + const auto actual_false = MaskFromVec(Load(d, lanes.get())); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), actual_false); + + memset(lanes.get(), 0xFF, N * sizeof(T)); + const auto actual_true = MaskFromVec(Load(d, lanes.get())); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), actual_true); + } +}; + +HWY_NOINLINE void TestAllMask() { ForAllTypes(ForPartialVectors()); } + +// All types. +struct TestEquality { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, 2); + const auto v2b = Iota(d, 2); + const auto v3 = Iota(d, 3); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq(v2, v3)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq(v2, v2b)); + } +}; + +HWY_NOINLINE void TestAllEquality() { + ForAllTypes(ForPartialVectors()); +} + +// a > b should be true, verify that for Gt/Lt and with swapped args. +template +void EnsureGreater(D d, TFromD a, TFromD b, const char* file, int line) { + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + const auto va = Set(d, a); + const auto vb = Set(d, b); + AssertMaskEqual(d, mask_true, Gt(va, vb), file, line); + AssertMaskEqual(d, mask_false, Lt(va, vb), file, line); + + // Swapped order + AssertMaskEqual(d, mask_false, Gt(vb, va), file, line); + AssertMaskEqual(d, mask_true, Lt(vb, va), file, line); + + // Also ensure irreflexive + AssertMaskEqual(d, mask_false, Gt(va, va), file, line); + AssertMaskEqual(d, mask_false, Gt(vb, vb), file, line); + AssertMaskEqual(d, mask_false, Lt(va, va), file, line); + AssertMaskEqual(d, mask_false, Lt(vb, vb), file, line); +} + +#define HWY_ENSURE_GREATER(d, a, b) EnsureGreater(d, a, b, __FILE__, __LINE__) + +struct TestStrictInt { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T min = LimitsMin(); + const T max = LimitsMax(); + const auto v0 = Zero(d); + const auto v2 = And(Iota(d, T(2)), Set(d, 127)); // 0..127 + const auto vn = Neg(v2) - Set(d, 1); // -1..-128 + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 0, -1); + HWY_ENSURE_GREATER(d, -1, -2); + HWY_ENSURE_GREATER(d, max, max / 2); + HWY_ENSURE_GREATER(d, max, 1); + HWY_ENSURE_GREATER(d, max, 0); + HWY_ENSURE_GREATER(d, max, -1); + HWY_ENSURE_GREATER(d, max, min); + HWY_ENSURE_GREATER(d, 0, min); + HWY_ENSURE_GREATER(d, min / 2, min); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_true, Gt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt(vn, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(vn, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, vn)); + } +}; + +HWY_NOINLINE void TestAllStrictInt() { + ForSignedTypes(ForExtendableVectors()); +} + +struct TestStrictFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T huge_neg = -1E35; + const T huge_pos = 1E36; + const auto v0 = Zero(d); + const 
auto v2 = Iota(d, T(2)); + const auto vn = Neg(v2); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 0, -1); + HWY_ENSURE_GREATER(d, -1, -2); + HWY_ENSURE_GREATER(d, huge_pos, 1); + HWY_ENSURE_GREATER(d, huge_pos, 0); + HWY_ENSURE_GREATER(d, huge_pos, -1); + HWY_ENSURE_GREATER(d, huge_pos, huge_neg); + HWY_ENSURE_GREATER(d, 0, huge_neg); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_true, Gt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt(vn, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(vn, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, vn)); + } +}; + +HWY_NOINLINE void TestAllStrictFloat() { + ForFloatTypes(ForExtendableVectors()); +} + +struct TestWeakFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, 2); + const auto vn = Iota(d, -T(Lanes(d))); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ge(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Le(vn, vn)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ge(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Le(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Le(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ge(vn, v2)); + } +}; + +HWY_NOINLINE void TestAllWeakFloat() { + ForFloatTypes(ForPartialVectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HwyCompareTest); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMask); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEquality); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictInt); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictFloat); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllWeakFloat); +} // namespace hwy +#endif diff --git a/third_party/highway/hwy/tests/convert_test.cc b/third_party/highway/hwy/tests/convert_test.cc new file mode 100644 index 000000000000..870955fcafbe --- /dev/null +++ b/third_party/highway/hwy/tests/convert_test.cc @@ -0,0 +1,568 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/convert_test.cc" +#include "hwy/foreach_target.h" + +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Cast and ensure bytes are the same. Called directly from TestAllBitCast or +// via TestBitCastFrom. 
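+// (Editor's note: BitCast reinterprets bytes without converting values, so
+// lane counts scale with lane size, e.g. 4 x float <-> 16 x uint8_t in a
+// 128-bit vector; the Repartition below expresses exactly that.)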
+template +struct TestBitCast { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Repartition dto; + HWY_ASSERT_EQ(Lanes(d) * sizeof(T), Lanes(dto) * sizeof(ToT)); + const auto vf = Iota(d, 1); + const auto vt = BitCast(dto, vf); + // Must return the same bits + auto from_lanes = AllocateAligned(Lanes(d)); + auto to_lanes = AllocateAligned(Lanes(dto)); + Store(vf, d, from_lanes.get()); + Store(vt, dto, to_lanes.get()); + HWY_ASSERT( + BytesEqual(from_lanes.get(), to_lanes.get(), Lanes(d) * sizeof(T))); + } +}; + +// From D to all types. +struct TestBitCastFrom { + template + HWY_NOINLINE void operator()(T t, D d) { + TestBitCast()(t, d); + TestBitCast()(t, d); + TestBitCast()(t, d); +#if HWY_CAP_INTEGER64 + TestBitCast()(t, d); +#endif + TestBitCast()(t, d); + TestBitCast()(t, d); + TestBitCast()(t, d); +#if HWY_CAP_INTEGER64 + TestBitCast()(t, d); +#endif + TestBitCast()(t, d); +#if HWY_CAP_FLOAT64 + TestBitCast()(t, d); +#endif + } +}; + +HWY_NOINLINE void TestAllBitCast() { + // For HWY_SCALAR and partial vectors, we can only cast to same-sized types: + // the former can't partition its single lane, and the latter can be smaller + // than a destination type. + const ForPartialVectors> to_u8; + to_u8(uint8_t()); + to_u8(int8_t()); + + const ForPartialVectors> to_i8; + to_i8(uint8_t()); + to_i8(int8_t()); + + const ForPartialVectors> to_u16; + to_u16(uint16_t()); + to_u16(int16_t()); + + const ForPartialVectors> to_i16; + to_i16(uint16_t()); + to_i16(int16_t()); + + const ForPartialVectors> to_u32; + to_u32(uint32_t()); + to_u32(int32_t()); + to_u32(float()); + + const ForPartialVectors> to_i32; + to_i32(uint32_t()); + to_i32(int32_t()); + to_i32(float()); + +#if HWY_CAP_INTEGER64 + const ForPartialVectors> to_u64; + to_u64(uint64_t()); + to_u64(int64_t()); +#if HWY_CAP_FLOAT64 + to_u64(double()); +#endif + + const ForPartialVectors> to_i64; + to_i64(uint64_t()); + to_i64(int64_t()); +#if HWY_CAP_FLOAT64 + to_i64(double()); +#endif +#endif // HWY_CAP_INTEGER64 + + const ForPartialVectors> to_float; + to_float(uint32_t()); + to_float(int32_t()); + to_float(float()); + +#if HWY_CAP_FLOAT64 + const ForPartialVectors> to_double; + to_double(double()); +#if HWY_CAP_INTEGER64 + to_double(uint64_t()); + to_double(int64_t()); +#endif // HWY_CAP_INTEGER64 +#endif // HWY_CAP_FLOAT64 + + // For non-scalar vectors, we can cast all types to all. 
+ ForAllTypes(ForGE128Vectors()); +} + +template +struct TestPromoteTo { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + static_assert(sizeof(T) < sizeof(ToT), "Input type must be narrower"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + RandomState rng; + for (size_t rep = 0; rep < 200; ++rep) { + for (size_t i = 0; i < N; ++i) { + const uint64_t bits = rng(); + memcpy(&from[i], &bits, sizeof(T)); + expected[i] = from[i]; + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + PromoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllPromoteTo() { + const ForPartialVectors, 2> to_u16div2; + to_u16div2(uint8_t()); + + const ForPartialVectors, 4> to_u32div4; + to_u32div4(uint8_t()); + + const ForPartialVectors, 2> to_u32div2; + to_u32div2(uint16_t()); + + const ForPartialVectors, 2> to_i16div2; + to_i16div2(uint8_t()); + to_i16div2(int8_t()); + + const ForPartialVectors, 2> to_i32div2; + to_i32div2(uint16_t()); + to_i32div2(int16_t()); + + const ForPartialVectors, 4> to_i32div4; + to_i32div4(uint8_t()); + to_i32div4(int8_t()); + + // Must test f16 separately because we can only load/store/convert them. + +#if HWY_CAP_INTEGER64 + const ForPartialVectors, 2> to_u64div2; + to_u64div2(uint32_t()); + + const ForPartialVectors, 2> to_i64div2; + to_i64div2(int32_t()); +#endif + +#if HWY_CAP_FLOAT64 + const ForPartialVectors, 2> to_f64div2; + to_f64div2(int32_t()); + to_f64div2(float()); +#endif +} + +template +bool IsFinite(T t) { + return std::isfinite(t); +} +// Wrapper avoids calling std::isfinite for integer types (ambiguous). +template +bool IsFinite(T /*unused*/) { + return true; +} + +template +struct TestDemoteTo { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + static_assert(!IsFloat(), "Use TestDemoteToFloat for float output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // Narrower range in the wider type, for clamping before we cast + const T min = LimitsMin(); + const T max = LimitsMax(); + + RandomState rng; + for (size_t rep = 0; rep < 1000; ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + memcpy(&from[i], &bits, sizeof(T)); + } while (!IsFinite(from[i])); + expected[i] = static_cast(std::min(std::max(min, from[i]), max)); + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + DemoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToInt() { + ForDemoteVectors, 2>()(int16_t()); + ForDemoteVectors, 4>()(int32_t()); + + ForDemoteVectors, 2>()(int16_t()); + ForDemoteVectors, 4>()(int32_t()); + + const ForDemoteVectors, 2> to_u16; + to_u16(int32_t()); + + const ForDemoteVectors, 2> to_i16; + to_i16(int32_t()); +} + +HWY_NOINLINE void TestAllDemoteToMixed() { +#if HWY_CAP_FLOAT64 + const ForDemoteVectors, 2> to_i32; + to_i32(double()); +#endif +} + +template +struct TestDemoteToFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + // For floats, we clamp differently and cannot call LimitsMin. 
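+    // (Editor's note: LimitsMin/Max exist for integer types only; the float
+    // path below instead clamps the magnitude to HighestValue and restores
+    // the sign with copysign.)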
+ static_assert(IsFloat(), "Use TestDemoteTo for integer output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + + RandomState rng; + for (size_t rep = 0; rep < 1000; ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + memcpy(&from[i], &bits, sizeof(T)); + } while (!IsFinite(from[i])); + const T magn = std::abs(from[i]); + const T max_abs = HighestValue(); + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + const T clipped = copysign(std::min(magn, max_abs), from[i]); + expected[i] = static_cast(clipped); + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + DemoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToFloat() { + // Must test f16 separately because we can only load/store/convert them. + +#if HWY_CAP_FLOAT64 + const ForDemoteVectors, 2> to_float; + to_float(double()); +#endif +} + +template +AlignedFreeUniquePtr F16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // +/- 1 + 1.0f, -1.0f, + // +/- 0 + 0.0f, -0.0f, + // near 0 + 0.25f, -0.25f, + // +/- integer + 4.0f, -32.0f, + // positive near limit + 65472.0f, 65504.0f, + // negative near limit + -65472.0f, -65504.0f, + // positive +/- delta + 2.00390625f, 3.99609375f, + // negative +/- delta + -2.00390625f, -3.99609375f, + // No infinity/NaN - implementation-defined due to ARM. + }; + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned(padded); + auto expected = AllocateAligned(padded); + std::copy(test_cases, test_cases + kNumTestCases, in.get()); + std::fill(in.get() + kNumTestCases, in.get() + padded, 0.0f); + return in; +} + +struct TestF16 { + template + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { + size_t padded; + auto in = F16TestCases(d32, padded); + using TF16 = float16_t; + const Rebind d16; + const size_t N = Lanes(d32); // same count for f16 + auto temp16 = AllocateAligned(N); + + for (size_t i = 0; i < padded; i += N) { + const auto loaded = Load(d32, &in[i]); + Store(DemoteTo(d16, loaded), d16, temp16.get()); + HWY_ASSERT_VEC_EQ(d32, loaded, PromoteTo(d32, Load(d16, temp16.get()))); + } + } +}; + +HWY_NOINLINE void TestAllF16() { ForDemoteVectors()(float()); } + +struct TestConvertU8 { + template + HWY_NOINLINE void operator()(T /*unused*/, const D du32) { + const Rebind du8; + auto lanes8 = AllocateAligned(Lanes(du8)); + Store(Iota(du8, 0), du8, lanes8.get()); + HWY_ASSERT_VEC_EQ(du8, Iota(du8, 0), U8FromU32(Iota(du32, 0))); + HWY_ASSERT_VEC_EQ(du8, Iota(du8, 0x7F), U8FromU32(Iota(du32, 0x7F))); + } +}; + +HWY_NOINLINE void TestAllConvertU8() { + ForDemoteVectors()(uint32_t()); +} + +// Separate function to attempt to work around a compiler bug on ARM: when this +// is merged with TestIntFromFloat, outputs match a previous Iota(-(N+1)) input. +struct TestIntFromFloatHuge { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + // Still does not work, although ARMv7 manual says that float->int + // saturates, i.e. chooses the nearest representable value. 
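+    // (Editor's note: out-of-range float->int conversion is undefined in
+    // C++ and targets disagree on the result, e.g. x86 produces the
+    // "integer indefinite" value 0x80000000; hence the huge inputs are
+    // exercised separately here and skipped on NEON.)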
+#if HWY_TARGET != HWY_NEON + using TI = MakeSigned; + const Rebind di; + + // Huge positive (lvalue works around GCC bug, tested with 10.2.1, where + // the expected i32 value is otherwise 0x80..00). + const auto expected_max = Set(di, LimitsMax()); + HWY_ASSERT_VEC_EQ(di, expected_max, ConvertTo(di, Set(df, TF(1E20)))); + + // Huge negative (also lvalue for safety, but GCC bug was not triggered) + const auto expected_min = Set(di, LimitsMin()); + HWY_ASSERT_VEC_EQ(di, expected_min, ConvertTo(di, Set(df, TF(-1E20)))); +#else + (void)df; +#endif + } +}; + +struct TestIntFromFloat { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = MakeSigned; + const Rebind di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), ConvertTo(di, Iota(df, TF(4.0)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), ConvertTo(di, Iota(df, -TF(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), ConvertTo(di, Iota(df, TF(2.001)))); + + // Below positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), ConvertTo(di, Iota(df, TF(3.9999)))); + + const TF eps = static_cast(0.0001); + // Above negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), + ConvertTo(di, Iota(df, -TF(N + 1) + eps))); + + // Below negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), + ConvertTo(di, Iota(df, -TF(N + 1) - eps))); + + // TF does not have enough precision to represent TI. + const double min = static_cast(LimitsMin()); + const double max = static_cast(LimitsMax()); + + // Also check random values. + auto from = AllocateAligned(N); + auto expected = AllocateAligned(N); + RandomState rng; + for (size_t rep = 0; rep < 1000; ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + memcpy(&from[i], &bits, sizeof(TF)); + } while (!std::isfinite(from[i])); + if (from[i] >= max) { + expected[i] = LimitsMax(); + } else if (from[i] <= min) { + expected[i] = LimitsMin(); + } else { + expected[i] = static_cast(from[i]); + } + } + + HWY_ASSERT_VEC_EQ(di, expected.get(), + ConvertTo(di, Load(df, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllIntFromFloat() { + ForFloatTypes(ForPartialVectors()); + ForFloatTypes(ForPartialVectors()); +} + +struct TestFloatFromInt { + template + HWY_NOINLINE void operator()(TI /*unused*/, const DI di) { + using TF = MakeFloat; + const Rebind df; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), ConvertTo(df, Iota(di, TI(4)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(df, Iota(df, -TF(N)), ConvertTo(df, Iota(di, -TI(N)))); + + // Max positive + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax())), + ConvertTo(df, Set(di, LimitsMax()))); + + // Min negative + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMin())), + ConvertTo(df, Set(di, LimitsMin()))); + } +}; + +HWY_NOINLINE void TestAllFloatFromInt() { + ForPartialVectors()(int32_t()); +#if HWY_CAP_FLOAT64 && HWY_CAP_INTEGER64 + ForPartialVectors()(int64_t()); +#endif +} + +struct TestI32F64 { + template + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = int32_t; + const Rebind di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), DemoteTo(di, Iota(df, TF(4.0)))); + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), DemoteTo(di, Iota(df, -TF(N)))); + HWY_ASSERT_VEC_EQ(df, Iota(df, -TF(N)), PromoteTo(df, Iota(di, -TI(N)))); + + // Above 
positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), DemoteTo(di, Iota(df, TF(2.001)))); + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(2.0)), PromoteTo(df, Iota(di, TI(2)))); + + // Below positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), DemoteTo(di, Iota(df, TF(3.9999)))); + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); + + const TF eps = static_cast(0.0001); + // Above negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), + DemoteTo(di, Iota(df, -TF(N + 1) + eps))); + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-4.0)), PromoteTo(df, Iota(di, TI(-4)))); + + // Below negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), + DemoteTo(di, Iota(df, -TF(N + 1) - eps))); + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-2.0)), PromoteTo(df, Iota(di, TI(-2)))); + + // Huge positive float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMax()), + DemoteTo(di, Set(df, TF(1E12)))); + + // Huge negative float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMin()), + DemoteTo(di, Set(df, TF(-1E12)))); + + // Max positive int + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax())), + PromoteTo(df, Set(di, LimitsMax()))); + + // Min negative int + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMin())), + PromoteTo(df, Set(di, LimitsMin()))); + } +}; + +HWY_NOINLINE void TestAllI32F64() { +#if HWY_CAP_FLOAT64 + ForDemoteVectors()(double()); +#endif +} + + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HwyConvertTest); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllBitCast); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllPromoteTo); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllDemoteToInt); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllDemoteToMixed); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllDemoteToFloat); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllF16); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllConvertU8); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllIntFromFloat); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllFloatFromInt); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllI32F64); +} // namespace hwy +#endif diff --git a/third_party/highway/hwy/tests/list_targets.cc b/third_party/highway/hwy/tests/list_targets.cc new file mode 100644 index 000000000000..4b0cdcedd2c3 --- /dev/null +++ b/third_party/highway/hwy/tests/list_targets.cc @@ -0,0 +1,34 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Simple tool to print the list of targets that were compiled in when building +// this tool. 
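The PrintTargets loop below walks a target bitmask one bit at a time using two classic tricks: x & (x - 1) clears the lowest set bit, and x & (~x + 1), i.e. x & -x in two's complement, isolates it. A minimal standalone illustration:

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t bits = 0x2C;  // 0b101100: three "targets" compiled in
  for (uint32_t x = bits; x != 0; x = x & (x - 1)) {
    const uint32_t lowest = x & (~x + 1);  // isolate the lowest set bit
    printf("0x%X\n", lowest);              // prints 0x4, 0x8, 0x20
  }
  return 0;
}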
+ +#include + +#include "hwy/highway.h" + +void PrintTargets(const char* msg, uint32_t targets) { + fprintf(stderr, "%s", msg); + for (unsigned x = targets; x != 0; x = x & (x - 1)) { + fprintf(stderr, " %s", hwy::TargetName(x & (~x + 1))); + } + fprintf(stderr, "\n"); +} + +int main() { + PrintTargets("Compiled HWY_TARGETS:", HWY_TARGETS); + PrintTargets("HWY_BASELINE_TARGETS:", HWY_BASELINE_TARGETS); + return 0; +} diff --git a/third_party/highway/hwy/tests/logical_test.cc b/third_party/highway/hwy/tests/logical_test.cc new file mode 100644 index 000000000000..a82795a53268 --- /dev/null +++ b/third_party/highway/hwy/tests/logical_test.cc @@ -0,0 +1,691 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include // memcmp + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/logical_test.cc" +#include "hwy/foreach_target.h" + +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLogicalInteger { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vi = Iota(d, 0); + const auto ones = VecFromMask(d, Eq(v0, v0)); + const auto v1 = Set(d, 1); + const auto vnot1 = Set(d, ~T(1)); + + HWY_ASSERT_VEC_EQ(d, v0, Not(ones)); + HWY_ASSERT_VEC_EQ(d, ones, Not(v0)); + HWY_ASSERT_VEC_EQ(d, v1, Not(vnot1)); + HWY_ASSERT_VEC_EQ(d, vnot1, Not(v1)); + + HWY_ASSERT_VEC_EQ(d, v0, And(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, And(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, And(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Or(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Xor(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Xor(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Xor(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, AndNot(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, vi)); + + auto v = vi; + v = And(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = And(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + + v = Or(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = Or(v, v0); + HWY_ASSERT_VEC_EQ(d, vi, v); + + v = Xor(v, vi); + HWY_ASSERT_VEC_EQ(d, v0, v); + v = Xor(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + } +}; + +HWY_NOINLINE void TestAllLogicalInteger() { + ForIntegerTypes(ForPartialVectors()); +} + +struct TestLogicalFloat { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vi = Iota(d, 0); + + HWY_ASSERT_VEC_EQ(d, v0, And(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, And(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, And(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Or(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Xor(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Xor(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Xor(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, AndNot(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, v0)); + HWY_ASSERT_VEC_EQ(d, 
v0, AndNot(vi, vi)); + + auto v = vi; + v = And(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = And(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + + v = Or(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = Or(v, v0); + HWY_ASSERT_VEC_EQ(d, vi, v); + + v = Xor(v, vi); + HWY_ASSERT_VEC_EQ(d, v0, v); + v = Xor(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + } +}; + +HWY_NOINLINE void TestAllLogicalFloat() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestCopySign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Iota(d, T(-1E5)); // assumes N < 10^5 + + // Zero remains zero regardless of sign + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, vp)); + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, vn)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, vp)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, vn)); + + // Positive input, positive sign => unchanged + HWY_ASSERT_VEC_EQ(d, vp, CopySign(vp, vp)); + HWY_ASSERT_VEC_EQ(d, vp, CopySignToAbs(vp, vp)); + + // Positive input, negative sign => negated + HWY_ASSERT_VEC_EQ(d, Neg(vp), CopySign(vp, vn)); + HWY_ASSERT_VEC_EQ(d, Neg(vp), CopySignToAbs(vp, vn)); + + // Negative input, negative sign => unchanged + HWY_ASSERT_VEC_EQ(d, vn, CopySign(vn, vn)); + + // Negative input, positive sign => negated + HWY_ASSERT_VEC_EQ(d, Neg(vn), CopySign(vn, vp)); + } +}; + +HWY_NOINLINE void TestAllCopySign() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestIfThenElse { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const size_t N = Lanes(d); + auto in1 = AllocateAligned(N); + auto in2 = AllocateAligned(N); + auto mask_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + + // NOTE: reverse polarity (mask is true iff lane == 0) because we cannot + // reliably compare against all bits set (NaN for float types). + const T off = 1; + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < 50; ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast(Random32(&rng)); + in2[i] = static_cast(Random32(&rng)); + mask_lanes[i] = (Random32(&rng) & 1024) ? off : T(0); + } + + const auto v1 = Load(d, in1.get()); + const auto v2 = Load(d, in2.get()); + const auto mask = Eq(Load(d, mask_lanes.get()), Zero(d)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = (mask_lanes[i] == off) ? in2[i] : in1[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenElse(mask, v1, v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = mask_lanes[i] ? T(0) : in1[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenElseZero(mask, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = mask_lanes[i] ? in2[i] : T(0); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenZeroElse(mask, v2)); + } + } +}; + +HWY_NOINLINE void TestAllIfThenElse() { + ForAllTypes(ForPartialVectors()); +} + +struct TestMaskVec { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const size_t N = Lanes(d); + auto mask_lanes = AllocateAligned(N); + + // Each lane should have a chance of having mask=true. 
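The round-trip verified in the loop below has a simple scalar reading: a true mask lane materializes via VecFromMask as an all-bits-set lane, and MaskFromVec turns such a lane back into a true mask. A sketch under the assumption that lanes are only ever zero or all-ones, as comparisons produce:

#include <cassert>
#include <cstdint>

uint32_t VecFromMaskScalar(bool m) { return m ? 0xFFFFFFFFu : 0u; }
bool MaskFromVecScalar(uint32_t lane) { return lane != 0u; }

int main() {
  const bool cases[2] = {false, true};
  for (int i = 0; i < 2; ++i) {
    assert(MaskFromVecScalar(VecFromMaskScalar(cases[i])) == cases[i]);
  }
  return 0;
}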
+ for (size_t rep = 0; rep < 100; ++rep) {
+ for (size_t i = 0; i < N; ++i) {
+ mask_lanes[i] = static_cast<T>(Random32(&rng) & 1);
+ }
+
+ const auto mask = RebindMask(d, Eq(Load(d, mask_lanes.get()), Zero(d)));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskFromVec(VecFromMask(d, mask)));
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllMaskVec() {
+ const ForPartialVectors<TestMaskVec> test;
+
+ test(uint16_t());
+ test(int16_t());
+ // TODO(janwas): float16_t - cannot compare yet
+
+ test(uint32_t());
+ test(int32_t());
+ test(float());
+
+#if HWY_CAP_INTEGER64
+ test(uint64_t());
+ test(int64_t());
+#endif
+#if HWY_CAP_FLOAT64
+ test(double());
+#endif
+}
+
+struct TestCompress {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+ using TU = MakeUnsigned<T>;
+ const Rebind<TU, D> du;
+ const size_t N = Lanes(d);
+ auto in_lanes = AllocateAligned<T>(N);
+ auto mask_lanes = AllocateAligned<TU>(N);
+ auto expected = AllocateAligned<T>(N);
+ auto actual = AllocateAligned<T>(N);
+
+ // Each lane should have a chance of having mask=true.
+ for (size_t rep = 0; rep < 100; ++rep) {
+ size_t expected_pos = 0;
+ for (size_t i = 0; i < N; ++i) {
+ const uint64_t bits = Random32(&rng);
+ in_lanes[i] = T(); // cannot initialize float16_t directly.
+ CopyBytes<sizeof(T)>(&bits, &in_lanes[i]);
+ mask_lanes[i] = static_cast<TU>(Random32(&rng) & 1);
+ if (mask_lanes[i] == 0) { // Zero means true (easier to compare)
+ expected[expected_pos++] = in_lanes[i];
+ }
+ }
+
+ const auto in = Load(d, in_lanes.get());
+ const auto mask = RebindMask(d, Eq(Load(du, mask_lanes.get()), Zero(du)));
+
+ Store(Compress(in, mask), d, actual.get());
+ // Upper lanes are undefined.
+ for (size_t i = 0; i < expected_pos; ++i) {
+ HWY_ASSERT(memcmp(&actual[i], &expected[i], sizeof(T)) == 0);
+ }
+
+ // Also check CompressStore in the same way.
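Both checks compare against the same scalar semantics: lanes whose mask is true are packed toward the front in their original order, the count of packed lanes is what CompressStore returns, and anything past that count is undefined (hence only the first expected_pos lanes are compared). A minimal sketch, with names chosen here for illustration:

#include <cstddef>

// Scalar model of Compress/CompressStore: returns the number of lanes
// written, matching the expected_pos computed in the test.
template <typename T>
size_t CompressScalar(const T* in, const bool* mask, size_t n, T* out) {
  size_t pos = 0;
  for (size_t i = 0; i < n; ++i) {
    if (mask[i]) out[pos++] = in[i];  // keep original order of true lanes
  }
  return pos;
}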
+ memset(actual.get(), 0, N * sizeof(T)); + const size_t num_written = CompressStore(in, mask, d, actual.get()); + HWY_ASSERT_EQ(expected_pos, num_written); + for (size_t i = 0; i < expected_pos; ++i) { + HWY_ASSERT(memcmp(&actual[i], &expected[i], sizeof(T)) == 0); + } + } + } +}; + +#if 0 +namespace detail { // for code folding +void PrintCompress16x8Tables() { + constexpr size_t N = 8; // 128-bit SIMD + for (uint64_t code = 0; code < 1ull << N; ++code) { + std::array indices{0}; + size_t pos = 0; + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + + // Doubled (for converting lane to byte indices) + for (size_t i = 0; i < N; ++i) { + printf("%d,", 2 * indices[i]); + } + } + printf("\n"); +} + +// Compressed to nibbles +void PrintCompress32x8Tables() { + constexpr size_t N = 8; // AVX2 + for (uint64_t code = 0; code < 1ull << N; ++code) { + std::array indices{0}; + size_t pos = 0; + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < 16); + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << 32)); + printf("0x%08x,", static_cast(packed)); + } + printf("\n"); +} + +// Pairs of 32-bit lane indices +void PrintCompress64x4Tables() { + constexpr size_t N = 4; // AVX2 + for (uint64_t code = 0; code < 1ull << N; ++code) { + std::array indices{0}; + size_t pos = 0; + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + + for (size_t i = 0; i < N; ++i) { + printf("%d,%d,", 2 * indices[i], 2 * indices[i] + 1); + } + } + printf("\n"); +} + +// 4-tuple of byte indices +void PrintCompress32x4Tables() { + using T = uint32_t; + constexpr size_t N = 4; // SSE4 + for (uint64_t code = 0; code < 1ull << N; ++code) { + std::array indices{0}; + size_t pos = 0; + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%zu,", sizeof(T) * indices[i] + idx_byte); + } + } + } + printf("\n"); +} + +// 8-tuple of byte indices +void PrintCompress64x2Tables() { + using T = uint64_t; + constexpr size_t N = 2; // SSE4 + for (uint64_t code = 0; code < 1ull << N; ++code) { + std::array indices{0}; + size_t pos = 0; + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%zu,", sizeof(T) * indices[i] + idx_byte); + } + } + } + printf("\n"); +} +} // namespace detail +#endif + +HWY_NOINLINE void TestAllCompress() { + // detail::PrintCompress32x8Tables(); + // detail::PrintCompress64x4Tables(); + // detail::PrintCompress32x4Tables(); + // detail::PrintCompress64x2Tables(); + // detail::PrintCompress16x8Tables(); + + const ForPartialVectors test; + + test(uint16_t()); + test(int16_t()); + test(float16_t()); + + test(uint32_t()); + test(int32_t()); + test(float()); + +#if HWY_CAP_INTEGER64 + test(uint64_t()); + test(int64_t()); +#endif +#if HWY_CAP_FLOAT64 + test(double()); +#endif +} + +struct TestZeroIfNegative { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Iota(d, T(-1E5)); // assumes N < 10^5 + + // Zero and positive remain unchanged + HWY_ASSERT_VEC_EQ(d, v0, ZeroIfNegative(v0)); + 
HWY_ASSERT_VEC_EQ(d, vp, ZeroIfNegative(vp)); + + // Negative are all replaced with zero + HWY_ASSERT_VEC_EQ(d, v0, ZeroIfNegative(vn)); + } +}; + +HWY_NOINLINE void TestAllZeroIfNegative() { + ForFloatTypes(ForPartialVectors()); +} + +struct TestBroadcastSignBit { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto s0 = Zero(d); + const auto s1 = Set(d, -1); // all bit set + const auto vpos = And(Iota(d, 0), Set(d, LimitsMax())); + const auto vneg = s1 - vpos; + + HWY_ASSERT_VEC_EQ(d, s0, BroadcastSignBit(vpos)); + HWY_ASSERT_VEC_EQ(d, s0, BroadcastSignBit(Set(d, LimitsMax()))); + + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(vneg)); + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(Set(d, LimitsMin()))); + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(Set(d, LimitsMin() / 2))); + } +}; + +HWY_NOINLINE void TestAllBroadcastSignBit() { + ForSignedTypes(ForPartialVectors()); +} + +struct TestTestBit { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t kNumBits = sizeof(T) * 8; + for (size_t i = 0; i < kNumBits; ++i) { + const auto bit1 = Set(d, 1ull << i); + const auto bit2 = Set(d, 1ull << ((i + 1) % kNumBits)); + const auto bit3 = Set(d, 1ull << ((i + 2) % kNumBits)); + const auto bits12 = Or(bit1, bit2); + const auto bits23 = Or(bit2, bit3); + HWY_ASSERT(AllTrue(TestBit(bit1, bit1))); + HWY_ASSERT(AllTrue(TestBit(bits12, bit1))); + HWY_ASSERT(AllTrue(TestBit(bits12, bit2))); + + HWY_ASSERT(AllFalse(TestBit(bits12, bit3))); + HWY_ASSERT(AllFalse(TestBit(bits23, bit1))); + HWY_ASSERT(AllFalse(TestBit(bit1, bit2))); + HWY_ASSERT(AllFalse(TestBit(bit2, bit1))); + HWY_ASSERT(AllFalse(TestBit(bit1, bit3))); + HWY_ASSERT(AllFalse(TestBit(bit3, bit1))); + HWY_ASSERT(AllFalse(TestBit(bit2, bit3))); + HWY_ASSERT(AllFalse(TestBit(bit3, bit2))); + } + } +}; + +HWY_NOINLINE void TestAllTestBit() { + ForIntegerTypes(ForFullVectors()); +} + +struct TestAllTrueFalse { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto zero = Zero(d); + auto v = zero; + + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + + HWY_ASSERT(AllTrue(Eq(v, zero))); + HWY_ASSERT(!AllFalse(Eq(v, zero))); + + // Single lane implies AllFalse = !AllTrue. Otherwise, there are multiple + // lanes and one is nonzero. + const bool expected_all_false = (N != 1); + + // Set each lane to nonzero and back to zero + for (size_t i = 0; i < N; ++i) { + lanes[i] = T(1); + v = Load(d, lanes.get()); + HWY_ASSERT(!AllTrue(Eq(v, zero))); + HWY_ASSERT(expected_all_false ^ AllFalse(Eq(v, zero))); + + lanes[i] = T(-1); + v = Load(d, lanes.get()); + HWY_ASSERT(!AllTrue(Eq(v, zero))); + HWY_ASSERT(expected_all_false ^ AllFalse(Eq(v, zero))); + + // Reset to all zero + lanes[i] = T(0); + v = Load(d, lanes.get()); + HWY_ASSERT(AllTrue(Eq(v, zero))); + HWY_ASSERT(!AllFalse(Eq(v, zero))); + } + } +}; + +HWY_NOINLINE void TestAllAllTrueFalse() { + ForAllTypes(ForPartialVectors()); +} + +class TestStoreMaskBits { + public: + template + HWY_NOINLINE void operator()(T /*t*/, D d) { + // TODO(janwas): remove once implemented (cast or vse1) +#if HWY_TARGET != HWY_RVV + RandomState rng; + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + const size_t expected_bytes = (N + 7) / 8; + auto bits = AllocateAligned(expected_bytes); + + for (size_t rep = 0; rep < 100; ++rep) { + // Generate random mask pattern. + for (size_t i = 0; i < N; ++i) { + lanes[i] = static_cast((rng() & 1024) ? 
1 : 0); + } + const auto mask = Load(d, lanes.get()) == Zero(d); + + const size_t bytes_written = StoreMaskBits(mask, bits.get()); + + HWY_ASSERT_EQ(expected_bytes, bytes_written); + size_t i = 0; + // Stored bits must match original mask + for (; i < N; ++i) { + const bool bit = (bits[i / 8] & (1 << (i % 8))) != 0; + HWY_ASSERT_EQ(bit, lanes[i] == 0); + } + // Any partial bits in the last byte must be zero + for (; i < 8 * bytes_written; ++i) { + const int bit = (bits[i / 8] & (1 << (i % 8))); + HWY_ASSERT_EQ(bit, 0); + } + } +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllStoreMaskBits() { + ForAllTypes(ForPartialVectors()); +} + +struct TestCountTrue { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = std::min(N, size_t(10)); + + auto lanes = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T(1)); + + for (size_t code = 0; code < (1ull << max_lanes); ++code) { + // Number of zeros written = number of mask lanes that are true. + size_t expected = 0; + for (size_t i = 0; i < max_lanes; ++i) { + lanes[i] = T(1); + if (code & (1ull << i)) { + ++expected; + lanes[i] = T(0); + } + } + + const auto mask = Eq(Load(d, lanes.get()), Zero(d)); + const size_t actual = CountTrue(mask); + HWY_ASSERT_EQ(expected, actual); + } + } +}; + +HWY_NOINLINE void TestAllCountTrue() { + ForAllTypes(ForPartialVectors()); +} + +struct TestLogicalMask { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto m0 = MaskFalse(d); + const auto m_all = MaskTrue(d); + + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + std::fill(lanes.get(), lanes.get() + N, T(1)); + + HWY_ASSERT_MASK_EQ(d, m0, Not(m_all)); + HWY_ASSERT_MASK_EQ(d, m_all, Not(m0)); + + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = std::min(N, size_t(6)); + for (size_t code = 0; code < (1ull << max_lanes); ++code) { + for (size_t i = 0; i < max_lanes; ++i) { + lanes[i] = T(1); + if (code & (1ull << i)) { + lanes[i] = T(0); + } + } + + const auto m = Eq(Load(d, lanes.get()), Zero(d)); + + HWY_ASSERT_MASK_EQ(d, m0, Xor(m, m)); + HWY_ASSERT_MASK_EQ(d, m0, AndNot(m, m)); + HWY_ASSERT_MASK_EQ(d, m0, AndNot(m_all, m)); + + HWY_ASSERT_MASK_EQ(d, m, Or(m, m)); + HWY_ASSERT_MASK_EQ(d, m, Or(m0, m)); + HWY_ASSERT_MASK_EQ(d, m, Or(m, m0)); + HWY_ASSERT_MASK_EQ(d, m, Xor(m0, m)); + HWY_ASSERT_MASK_EQ(d, m, Xor(m, m0)); + HWY_ASSERT_MASK_EQ(d, m, And(m, m)); + HWY_ASSERT_MASK_EQ(d, m, And(m_all, m)); + HWY_ASSERT_MASK_EQ(d, m, And(m, m_all)); + HWY_ASSERT_MASK_EQ(d, m, AndNot(m0, m)); + } + } +}; + +HWY_NOINLINE void TestAllLogicalMask() { + ForAllTypes(ForFullVectors()); +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HwyLogicalTest); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogicalInteger); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogicalFloat); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllCopySign); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllIfThenElse); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllMaskVec); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllCompress); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllZeroIfNegative); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllBroadcastSignBit); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit); 
+HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllTrueFalse); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllStoreMaskBits); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllCountTrue); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogicalMask); +} // namespace hwy +#endif diff --git a/third_party/highway/hwy/tests/memory_test.cc b/third_party/highway/hwy/tests/memory_test.cc new file mode 100644 index 000000000000..8c6a51e61050 --- /dev/null +++ b/third_party/highway/hwy/tests/memory_test.cc @@ -0,0 +1,413 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/memory_test.cc" +#include "hwy/cache_control.h" +#include "hwy/foreach_target.h" +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLoadStore { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto hi = Iota(d, 1 + N); + const auto lo = Iota(d, 1); + auto lanes = AllocateAligned(2 * N); + Store(hi, d, &lanes[N]); + Store(lo, d, &lanes[0]); + + // Aligned load + const auto lo2 = Load(d, &lanes[0]); + HWY_ASSERT_VEC_EQ(d, lo2, lo); + + // Aligned store + auto lanes2 = AllocateAligned(2 * N); + Store(lo2, d, &lanes2[0]); + Store(hi, d, &lanes2[N]); + for (size_t i = 0; i < 2 * N; ++i) { + HWY_ASSERT_EQ(lanes[i], lanes2[i]); + } + + // Unaligned load + const auto vu = LoadU(d, &lanes[1]); + auto lanes3 = AllocateAligned(N); + Store(vu, d, lanes3.get()); + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT_EQ(T(i + 2), lanes3[i]); + } + + // Unaligned store + StoreU(lo2, d, &lanes2[N / 2]); + size_t i = 0; + for (; i < N / 2; ++i) { + HWY_ASSERT_EQ(lanes[i], lanes2[i]); + } + for (; i < 3 * N / 2; ++i) { + HWY_ASSERT_EQ(T(i - N / 2 + 1), lanes2[i]); + } + // Subsequent values remain unchanged. 
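For a concrete instance of the StoreU expectations being checked here, take N = 4 int lanes; this standalone sketch mirrors the test's three ranges (untouched prefix, stored values, untouched suffix):

#include <cassert>
#include <cstring>

int main() {
  int lanes2[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // lo = Iota(1), hi = Iota(5)
  const int lo[4] = {1, 2, 3, 4};
  std::memcpy(&lanes2[2], lo, sizeof(lo));   // StoreU at offset N/2 = 2
  const int expected[8] = {1, 2, 1, 2, 3, 4, 7, 8};
  assert(std::memcmp(lanes2, expected, sizeof(expected)) == 0);
  return 0;
}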
+ for (; i < 2 * N; ++i) { + HWY_ASSERT_EQ(T(i + 1), lanes2[i]); + } + } +}; + +HWY_NOINLINE void TestAllLoadStore() { + ForAllTypes(ForPartialVectors()); +} + +struct TestStoreInterleaved3 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned(3 * N); + for (size_t i = 0; i < 3 * N; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + const auto in2 = Load(d, &bytes[2 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned(4 * N); + auto actual_aligned = AllocateAligned(4 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[3 * i + 0] = bytes[0 * N + i]; + expected[3 * i + 1] = bytes[1 * N + i]; + expected[3 * i + 2] = bytes[2 * N + i]; + // Ensure we do not write more than 3*N bytes + expected[3 * N + i] = actual[3 * N + i] = 0; + } + StoreInterleaved3(in0, in1, in2, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 4 * N, &pos)) { + Print(d, "in0", in0, pos / 3); + Print(d, "in1", in1, pos / 3); + Print(d, "in2", in2, pos / 3); + const size_t i = pos - pos % 3; + fprintf(stderr, "interleaved %d %d %d %d %d %d\n", actual[i], + actual[i + 1], actual[i + 2], actual[i + 3], actual[i + 4], + actual[i + 5]); + HWY_ASSERT(false); + } + } + } +}; + +HWY_NOINLINE void TestAllStoreInterleaved3() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. + const ForExtendableVectors test; +#else + const ForPartialVectors test; +#endif + test(uint8_t()); +} + +struct TestStoreInterleaved4 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned(4 * N); + for (size_t i = 0; i < 4 * N; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + const auto in2 = Load(d, &bytes[2 * N]); + const auto in3 = Load(d, &bytes[3 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned(5 * N); + auto actual_aligned = AllocateAligned(5 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[4 * i + 0] = bytes[0 * N + i]; + expected[4 * i + 1] = bytes[1 * N + i]; + expected[4 * i + 2] = bytes[2 * N + i]; + expected[4 * i + 3] = bytes[3 * N + i]; + // Ensure we do not write more than 4*N bytes + expected[4 * N + i] = actual[4 * N + i] = 0; + } + StoreInterleaved4(in0, in1, in2, in3, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 5 * N, &pos)) { + Print(d, "in0", in0, pos / 4); + Print(d, "in1", in1, pos / 4); + Print(d, "in2", in2, pos / 4); + Print(d, "in3", in3, pos / 4); + const size_t i = pos; + fprintf(stderr, "interleaved %d %d %d %d %d %d %d %d\n", actual[i], + actual[i + 1], actual[i + 2], actual[i + 3], actual[i + 4], + actual[i + 5], actual[i + 6], actual[i + 7]); + HWY_ASSERT(false); + } + } + } +}; + +HWY_NOINLINE void TestAllStoreInterleaved4() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. 
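For reference, the scalar interleaving that TestStoreInterleaved3 above compares against is the classic planar-to-packed loop, e.g. planar R/G/B channels into RGB triples (the expected[] array in the test is built by exactly this pattern):

#include <cstddef>
#include <cstdint>

// Scalar reference for StoreInterleaved3: output byte 3*i+p takes lane i of
// input plane p.
void Interleave3Scalar(const uint8_t* in0, const uint8_t* in1,
                       const uint8_t* in2, size_t n, uint8_t* out) {
  for (size_t i = 0; i < n; ++i) {
    out[3 * i + 0] = in0[i];
    out[3 * i + 1] = in1[i];
    out[3 * i + 2] = in2[i];
  }
}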
+ const ForExtendableVectors test; +#else + const ForPartialVectors test; +#endif + test(uint8_t()); +} + +struct TestLoadDup128 { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define LoadDup128. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + constexpr size_t N128 = 16 / sizeof(T); + alignas(16) T lanes[N128]; + for (size_t i = 0; i < N128; ++i) { + lanes[i] = static_cast(1 + i); + } + const auto v = LoadDup128(d, lanes); + const size_t N = Lanes(d); + auto out = AllocateAligned(N); + Store(v, d, out.get()); + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT_EQ(T(i % N128 + 1), out[i]); + } +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllLoadDup128() { + ForAllTypes(ForGE128Vectors()); +} + +struct TestStream { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(1)); + const size_t affected_bytes = + (Lanes(d) * sizeof(T) + HWY_STREAM_MULTIPLE - 1) & + ~size_t(HWY_STREAM_MULTIPLE - 1); + const size_t affected_lanes = affected_bytes / sizeof(T); + auto out = AllocateAligned(2 * affected_lanes); + std::fill(out.get(), out.get() + 2 * affected_lanes, T(0)); + + Stream(v, d, out.get()); + StoreFence(); + const auto actual = Load(d, out.get()); + HWY_ASSERT_VEC_EQ(d, v, actual); + // Ensure Stream didn't modify more memory than expected + for (size_t i = affected_lanes; i < 2 * affected_lanes; ++i) { + HWY_ASSERT_EQ(T(0), out[i]); + } + } +}; + +HWY_NOINLINE void TestAllStream() { + const ForPartialVectors test; + // No u8,u16. + test(uint32_t()); + test(uint64_t()); + // No i8,i16. + test(int32_t()); + test(int64_t()); + ForFloatTypes(test); +} + +// Assumes little-endian byte order! +struct TestScatter { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Offset = MakeSigned; + + const size_t N = Lanes(d); + const size_t range = 4 * N; // number of items to scatter + const size_t max_bytes = range * sizeof(T); // upper bound on offset + + RandomState rng; + + // Data to be scattered + auto bytes = AllocateAligned(max_bytes); + for (size_t i = 0; i < max_bytes; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + const auto data = Load(d, reinterpret_cast(bytes.get())); + + // Scatter into these regions, ensure vector results match scalar + auto expected = AllocateAligned(range); + auto actual = AllocateAligned(range); + + const Rebind d_offsets; + auto offsets = AllocateAligned(N); // or indices + + for (size_t rep = 0; rep < 100; ++rep) { + // Byte offsets + std::fill(expected.get(), expected.get() + range, T(0)); + std::fill(actual.get(), actual.get() + range, T(0)); + for (size_t i = 0; i < N; ++i) { + offsets[i] = + static_cast(Random32(&rng) % (max_bytes - sizeof(T))); + CopyBytes( + bytes.get() + i * sizeof(T), + reinterpret_cast(expected.get()) + offsets[i]); + } + const auto voffsets = Load(d_offsets, offsets.get()); + ScatterOffset(data, d, actual.get(), voffsets); + if (!BytesEqual(expected.get(), actual.get(), max_bytes)) { + Print(d, "Data", data); + Print(d_offsets, "Offsets", voffsets); + HWY_ASSERT(false); + } + + // Indices + std::fill(expected.get(), expected.get() + range, T(0)); + std::fill(actual.get(), actual.get() + range, T(0)); + for (size_t i = 0; i < N; ++i) { + offsets[i] = static_cast(Random32(&rng) % range); + CopyBytes(bytes.get() + i * sizeof(T), + &expected[offsets[i]]); + } + const auto vindices = Load(d_offsets, offsets.get()); + ScatterIndex(data, d, actual.get(), vindices); + if (!BytesEqual(expected.get(), actual.get(), max_bytes)) { + Print(d, 
"Data", data); + Print(d_offsets, "Indices", vindices); + HWY_ASSERT(false); + } + } + } +}; + +HWY_NOINLINE void TestAllScatter() { + // No u8,u16,i8,i16. + const ForPartialVectors test; + test(uint32_t()); + test(int32_t()); + +#if HWY_CAP_INTEGER64 + test(uint64_t()); + test(int64_t()); +#endif + + ForFloatTypes(test); +} + +struct TestGather { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Offset = MakeSigned; + + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be gathered from + const size_t max_bytes = 4 * N * sizeof(T); // upper bound on offset + auto bytes = AllocateAligned(max_bytes); + for (size_t i = 0; i < max_bytes; ++i) { + bytes[i] = static_cast(Random32(&rng) & 0xFF); + } + + auto expected = AllocateAligned(N); + auto offsets = AllocateAligned(N); + auto indices = AllocateAligned(N); + + for (size_t rep = 0; rep < 100; ++rep) { + // Offsets + for (size_t i = 0; i < N; ++i) { + offsets[i] = + static_cast(Random32(&rng) % (max_bytes - sizeof(T))); + CopyBytes(bytes.get() + offsets[i], &expected[i]); + } + + const Rebind d_offset; + const T* base = reinterpret_cast(bytes.get()); + auto actual = GatherOffset(d, base, Load(d_offset, offsets.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + + // Indices + for (size_t i = 0; i < N; ++i) { + indices[i] = + static_cast(Random32(&rng) % (max_bytes / sizeof(T))); + CopyBytes(base + indices[i], &expected[i]); + } + actual = GatherIndex(d, base, Load(d_offset, indices.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + } + } +}; + +HWY_NOINLINE void TestAllGather() { + // No u8,u16,i8,i16. + const ForPartialVectors test; + test(uint32_t()); + test(int32_t()); + +#if HWY_CAP_INTEGER64 + test(uint64_t()); + test(int64_t()); +#endif + ForFloatTypes(test); +} + +HWY_NOINLINE void TestAllCache() { + LoadFence(); + StoreFence(); + int test = 0; + Prefetch(&test); + FlushCacheline(&test); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HwyMemoryTest); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadStore); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStoreInterleaved3); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStoreInterleaved4); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadDup128); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStream); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllScatter); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllGather); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllCache); +} // namespace hwy +#endif diff --git a/third_party/highway/hwy/tests/swizzle_test.cc b/third_party/highway/hwy/tests/swizzle_test.cc new file mode 100644 index 000000000000..1ec7c5299851 --- /dev/null +++ b/third_party/highway/hwy/tests/swizzle_test.cc @@ -0,0 +1,642 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/swizzle_test.cc" +#include "hwy/foreach_target.h" +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestShiftBytes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Bytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const Repartition du8; + const size_t N8 = Lanes(du8); + + // Zero remains zero + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(v0)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightBytes<1>(v0)); + + // Zero after shifting out the high/low byte + auto bytes = AllocateAligned(N8); + std::fill(bytes.get(), bytes.get() + N8, 0); + bytes[N8 - 1] = 0x7F; + const auto vhi = BitCast(d, Load(du8, bytes.get())); + bytes[N8 - 1] = 0; + bytes[0] = 0x7F; + const auto vlo = BitCast(d, Load(du8, bytes.get())); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(vhi)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightBytes<1>(vlo)); + + // Check expected result with Iota + const size_t N = Lanes(d); + auto in = AllocateAligned(N); + const uint8_t* in_bytes = reinterpret_cast(in.get()); + const auto v = BitCast(d, Iota(du8, 1)); + Store(v, d, in.get()); + + auto expected = AllocateAligned(N); + uint8_t* expected_bytes = reinterpret_cast(expected.get()); + + const size_t kBlockSize = HWY_MIN(N8, 16); + for (size_t block = 0; block < N8; block += kBlockSize) { + expected_bytes[block] = 0; + memcpy(expected_bytes + block + 1, in_bytes + block, kBlockSize - 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftBytes<1>(v)); + + for (size_t block = 0; block < N8; block += kBlockSize) { + memcpy(expected_bytes + block, in_bytes + block + 1, kBlockSize - 1); + expected_bytes[block + kBlockSize - 1] = 0; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightBytes<1>(v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllShiftBytes() { + ForIntegerTypes(ForGE128Vectors()); +} + +struct TestShiftLanes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Lanes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const auto v = Iota(d, T(1)); + const size_t N = Lanes(d); + auto expected = AllocateAligned(N); + + HWY_ASSERT_VEC_EQ(d, v, ShiftLeftLanes<0>(v)); + HWY_ASSERT_VEC_EQ(d, v, ShiftRightLanes<0>(v)); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + + for (size_t i = 0; i < N; ++i) { + expected[i] = (i % kLanesPerBlock) == 0 ? T(0) : T(i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftLanes<1>(v)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = + (i % kLanesPerBlock) == (kLanesPerBlock - 1) ? 
T(0) : T(2 + i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightLanes<1>(v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllShiftLanes() { + ForAllTypes(ForGE128Vectors()); +} + +template +struct TestBroadcastR { + HWY_NOINLINE void operator()() const { +// TODO(janwas): fix failure +#if HWY_TARGET != HWY_WASM + using T = typename D::T; + const D d; + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned(N); + std::fill(in_lanes.get(), in_lanes.get() + N, T(0)); + const size_t blockN = HWY_MIN(N * sizeof(T), 16) / sizeof(T); + // Need to set within each 128-bit block + for (size_t block = 0; block < N; block += blockN) { + in_lanes[block + kLane] = static_cast(block + 1); + } + const auto in = Load(d, in_lanes.get()); + auto expected = AllocateAligned(N); + for (size_t block = 0; block < N; block += blockN) { + for (size_t i = 0; i < blockN; ++i) { + expected[block + i] = T(block + 1); + } + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Broadcast(in)); + + TestBroadcastR()(); +#endif + } +}; + +template +struct TestBroadcastR { + void operator()() const {} +}; + +struct TestBroadcast { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + TestBroadcastR()(); + } +}; + +HWY_NOINLINE void TestAllBroadcast() { + const ForPartialVectors test; + // No u8. + test(uint16_t()); + test(uint32_t()); +#if HWY_CAP_INTEGER64 + test(uint64_t()); +#endif + + // No i8. + test(int16_t()); + test(int32_t()); +#if HWY_CAP_INTEGER64 + test(int64_t()); +#endif + + ForFloatTypes(test); +} + +struct TestTableLookupBytes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + const size_t N = Lanes(d); + const size_t N8 = Lanes(Repartition()); + auto in_bytes = AllocateAligned(N8); + for (size_t i = 0; i < N8; ++i) { + in_bytes[i] = Random32(&rng) & 0xFF; + } + const auto in = + BitCast(d, Load(d, reinterpret_cast(in_bytes.get()))); + + // Enough test data; for larger vectors, upper lanes will be zero. + const uint8_t index_bytes_source[64] = { + // Same index as source, multiple outputs from same input, + // unused input (9), ascending/descending and nonconsecutive neighbors. + 0, 2, 1, 2, 15, 12, 13, 14, 6, 7, 8, 5, 4, 3, 10, 11, + 11, 10, 3, 4, 5, 8, 7, 6, 14, 13, 12, 15, 2, 1, 2, 0, + 4, 3, 2, 2, 5, 6, 7, 7, 15, 15, 15, 15, 15, 15, 0, 1}; + auto index_bytes = AllocateAligned(N8); + for (size_t i = 0; i < N8; ++i) { + index_bytes[i] = (i < 64) ? index_bytes_source[i] : 0; + // Avoid undefined results / asan error for scalar by capping indices. 
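Aside from that capping, the expectation computed below reduces to the following scalar model: within each 16-byte block, index byte j selects input byte block + j, wrapping as the test's `mod` does so that short vectors stay in range. A sketch (wrapping on n here rather than the test's HWY_MIN(N8, 256)):

#include <cstddef>
#include <cstdint>

void TableLookupBytesScalar(const uint8_t* in, const uint8_t* idx, size_t n,
                            uint8_t* out) {
  for (size_t block = 0; block < n; block += 16) {
    for (size_t i = 0; i < 16 && block + i < n; ++i) {
      out[block + i] = in[(block + idx[block + i]) % n];
    }
  }
}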
+ if (index_bytes[i] >= N * sizeof(T)) { + index_bytes[i] = static_cast(N * sizeof(T) - 1); + } + } + const auto indices = Load(d, reinterpret_cast(index_bytes.get())); + auto expected = AllocateAligned(N); + uint8_t* expected_bytes = reinterpret_cast(expected.get()); + + // Byte indices wrap around + const size_t mod = HWY_MIN(N8, 256); + for (size_t block = 0; block < N8; block += 16) { + for (size_t i = 0; i < 16 && (block + i) < N8; ++i) { + const uint8_t index = index_bytes[block + i]; + expected_bytes[block + i] = in_bytes[(block + index) % mod]; + } + } + HWY_ASSERT_VEC_EQ(d, expected.get(), TableLookupBytes(in, indices)); + } +}; + +HWY_NOINLINE void TestAllTableLookupBytes() { + ForIntegerTypes(ForPartialVectors()); +} +struct TestTableLookupLanes { +#if HWY_TARGET == HWY_RVV + using Index = uint32_t; +#else + using Index = int32_t; +#endif + + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + auto idx = AllocateAligned(N); + std::fill(idx.get(), idx.get() + N, Index(0)); + auto expected = AllocateAligned(N); + const auto v = Iota(d, 1); + + if (N <= 8) { // Test all permutations + for (size_t i0 = 0; i0 < N; ++i0) { + idx[0] = static_cast(i0); + for (size_t i1 = 0; i1 < N; ++i1) { + idx[1] = static_cast(i1); + for (size_t i2 = 0; i2 < N; ++i2) { + idx[2] = static_cast(i2); + for (size_t i3 = 0; i3 < N; ++i3) { + idx[3] = static_cast(i3); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast(idx[i] + 1); // == v[idx[i]] + } + + const auto opaque = SetTableIndices(d, idx.get()); + const auto actual = TableLookupLanes(v, opaque); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + } + } + } + } + } else { + // Too many permutations to test exhaustively; choose one with repeated + // and cross-block indices and ensure indices do not exceed #lanes. + // For larger vectors, upper lanes will be zero. + HWY_ALIGN Index idx_source[16] = {1, 3, 2, 2, 8, 1, 7, 6, + 15, 14, 14, 15, 4, 9, 8, 5}; + for (size_t i = 0; i < N; ++i) { + idx[i] = (i < 16) ? idx_source[i] : 0; + // Avoid undefined results / asan error for scalar by capping indices. 
+ if (idx[i] >= static_cast(N)) { + idx[i] = static_cast(N - 1); + } + expected[i] = static_cast(idx[i] + 1); // == v[idx[i]] + } + + const auto opaque = SetTableIndices(d, idx.get()); + const auto actual = TableLookupLanes(v, opaque); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + } +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllTableLookupLanes() { + const ForFullVectors test; + test(uint32_t()); + test(int32_t()); + test(float()); +} + +struct TestInterleave { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TU = MakeUnsigned; + const size_t N = Lanes(d); + auto even_lanes = AllocateAligned(N); + auto odd_lanes = AllocateAligned(N); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast(2 * i + 0); + odd_lanes[i] = static_cast(2 * i + 1); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const size_t blockN = 16 / sizeof(T); + for (size_t i = 0; i < Lanes(d); ++i) { + const size_t block = i / blockN; + const size_t index = (i % blockN) + block * 2 * blockN; + expected[i] = static_cast(index & LimitsMax()); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveLower(even, odd)); + + for (size_t i = 0; i < Lanes(d); ++i) { + const size_t block = i / blockN; + expected[i] = T((i % blockN) + block * 2 * blockN + blockN); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveUpper(even, odd)); + } +}; + +HWY_NOINLINE void TestAllInterleave() { + // Not supported by HWY_SCALAR: Interleave(f32, f32) would return f32x2. + ForAllTypes(ForGE128Vectors()); +} + +struct TestZipLower { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using WideT = MakeWide; + static_assert(sizeof(T) * 2 == sizeof(WideT), "Must be double-width"); + static_assert(IsSigned() == IsSigned(), "Must have same sign"); + const size_t N = Lanes(d); + auto even_lanes = AllocateAligned(N); + auto odd_lanes = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast(2 * i + 0); + odd_lanes[i] = static_cast(2 * i + 1); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const Repartition dw; + auto expected = AllocateAligned(Lanes(dw)); + const WideT blockN = static_cast(16 / sizeof(WideT)); + for (size_t i = 0; i < Lanes(dw); ++i) { + const size_t block = i / blockN; + // Value of least-significant lane in lo-vector. 
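For T = uint8_t the expectation computed below is easy to see in isolation: zipping even-lane value lo with its paired odd-lane value lo + 1 produces the wide value ((lo + 1) << 8) + lo, i.e. the odd input lands in the high half. A standalone check:

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t lo = 2;  // an even_lanes value
  const uint8_t hi = 3;  // the paired odd_lanes value
  const uint16_t wide =
      static_cast<uint16_t>((static_cast<uint16_t>(hi) << 8) + lo);
  assert(wide == 0x0302);  // odd lane in the high byte, even in the low
  return 0;
}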
+ const WideT lo = + static_cast(2 * (i % blockN) + 4 * block * blockN); + const WideT kBits = static_cast(sizeof(T) * 8); + expected[i] = + static_cast((static_cast(lo + 1) << kBits) + lo); + } + HWY_ASSERT_VEC_EQ(dw, expected.get(), ZipLower(even, odd)); + } +}; + +struct TestZipUpper { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using WideT = MakeWide; + static_assert(sizeof(T) * 2 == sizeof(WideT), "Must be double-width"); + static_assert(IsSigned() == IsSigned(), "Must have same sign"); + const size_t N = Lanes(d); + auto even_lanes = AllocateAligned(N); + auto odd_lanes = AllocateAligned(N); + for (size_t i = 0; i < Lanes(d); ++i) { + even_lanes[i] = static_cast(2 * i + 0); + odd_lanes[i] = static_cast(2 * i + 1); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const Repartition dw; + auto expected = AllocateAligned(Lanes(dw)); + + constexpr WideT blockN = static_cast(16 / sizeof(WideT)); + for (size_t i = 0; i < Lanes(dw); ++i) { + const size_t block = i / blockN; + const WideT lo = + static_cast(2 * (i % blockN) + 4 * block * blockN); + const WideT kBits = static_cast(sizeof(T) * 8); + expected[i] = static_cast( + (static_cast(lo + 2 * blockN + 1) << kBits) + lo + 2 * blockN); + } + HWY_ASSERT_VEC_EQ(dw, expected.get(), ZipUpper(even, odd)); + } +}; + +HWY_NOINLINE void TestAllZip() { + const ForPartialVectors lower_unsigned; + // TODO(janwas): fix +#if HWY_TARGET != HWY_RVV + lower_unsigned(uint8_t()); +#endif + lower_unsigned(uint16_t()); +#if HWY_CAP_INTEGER64 + lower_unsigned(uint32_t()); // generates u64 +#endif + + const ForPartialVectors lower_signed; +#if HWY_TARGET != HWY_RVV + lower_signed(int8_t()); +#endif + lower_signed(int16_t()); +#if HWY_CAP_INTEGER64 + lower_signed(int32_t()); // generates i64 +#endif + + const ForGE128Vectors upper_unsigned; +#if HWY_TARGET != HWY_RVV + upper_unsigned(uint8_t()); +#endif + upper_unsigned(uint16_t()); +#if HWY_CAP_INTEGER64 + upper_unsigned(uint32_t()); // generates u64 +#endif + + const ForGE128Vectors upper_signed; +#if HWY_TARGET != HWY_RVV + upper_signed(int8_t()); +#endif + upper_signed(int16_t()); +#if HWY_CAP_INTEGER64 + upper_signed(int32_t()); // generates i64 +#endif + + // No float - concatenating f32 does not result in a f64 +} + +class TestSpecialShuffle32 { + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, 0); + +#define VERIFY_LANES_32(d, v, i3, i2, i1, i0) \ + VerifyLanes32((d), (v), (i3), (i2), (i1), (i0), __FILE__, __LINE__) + + VERIFY_LANES_32(d, Shuffle2301(v), 2, 3, 0, 1); + VERIFY_LANES_32(d, Shuffle1032(v), 1, 0, 3, 2); + VERIFY_LANES_32(d, Shuffle0321(v), 0, 3, 2, 1); + VERIFY_LANES_32(d, Shuffle2103(v), 2, 1, 0, 3); + VERIFY_LANES_32(d, Shuffle0123(v), 0, 1, 2, 3); + +#undef VERIFY_LANES_32 + } + + private: + template + HWY_NOINLINE void VerifyLanes32(D d, V v, const int i3, const int i2, + const int i1, const int i0, + const char* filename, const int line) { + using T = typename D::T; + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + Store(v, d, lanes.get()); + const std::string name = TypeName(lanes[0], N); + constexpr size_t kBlockN = 16 / sizeof(T); + for (int block = 0; block < static_cast(N); block += kBlockN) { + AssertEqual(T(block + i3), lanes[block + 3], name, filename, line); + AssertEqual(T(block + i2), lanes[block + 2], name, filename, line); + AssertEqual(T(block + i1), lanes[block + 1], name, filename, line); + AssertEqual(T(block + i0), lanes[block + 0], 
name, filename, line); + } + } +}; + +class TestSpecialShuffle64 { + public: + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, 0); + VerifyLanes64(d, Shuffle01(v), 0, 1, __FILE__, __LINE__); + } + + private: + template + HWY_NOINLINE void VerifyLanes64(D d, V v, const int i1, const int i0, + const char* filename, const int line) { + using T = typename D::T; + const size_t N = Lanes(d); + auto lanes = AllocateAligned(N); + Store(v, d, lanes.get()); + const std::string name = TypeName(lanes[0], N); + constexpr size_t kBlockN = 16 / sizeof(T); + for (int block = 0; block < static_cast(N); block += kBlockN) { + AssertEqual(T(block + i1), lanes[block + 1], name, filename, line); + AssertEqual(T(block + i0), lanes[block + 0], name, filename, line); + } + } +}; + +HWY_NOINLINE void TestAllSpecialShuffles() { + const ForGE128Vectors test32; + test32(uint32_t()); + test32(int32_t()); + test32(float()); + +#if HWY_CAP_INTEGER64 + const ForGE128Vectors test64; + test64(uint64_t()); + test64(int64_t()); +#endif + +#if HWY_CAP_FLOAT64 + const ForGE128Vectors test_d; + test_d(double()); +#endif +} + +struct TestConcatHalves { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // TODO(janwas): fix +#if HWY_TARGET != HWY_RVV + // Construct inputs such that interleaved halves == iota. + const auto expected = Iota(d, 1); + + const size_t N = Lanes(d); + auto lo = AllocateAligned(N); + auto hi = AllocateAligned(N); + size_t i; + for (i = 0; i < N / 2; ++i) { + lo[i] = static_cast(1 + i); + hi[i] = static_cast(lo[i] + T(N) / 2); + } + for (; i < N; ++i) { + lo[i] = hi[i] = 0; + } + + HWY_ASSERT_VEC_EQ(d, expected, + ConcatLowerLower(Load(d, hi.get()), Load(d, lo.get()))); + + // Same for high blocks. + for (i = 0; i < N / 2; ++i) { + lo[i] = hi[i] = 0; + } + for (; i < N; ++i) { + lo[i] = static_cast(1 + i - N / 2); + hi[i] = static_cast(lo[i] + T(N) / 2); + } + + HWY_ASSERT_VEC_EQ(d, expected, + ConcatUpperUpper(Load(d, hi.get()), Load(d, lo.get()))); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllConcatHalves() { + ForAllTypes(ForGE128Vectors()); +} + +struct TestConcatLowerUpper { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // TODO(janwas): fix +#if HWY_TARGET != HWY_RVV + const size_t N = Lanes(d); + // Middle part of Iota(1) == Iota(1 + N / 2). + const auto lo = Iota(d, 1); + const auto hi = Iota(d, 1 + N); + HWY_ASSERT_VEC_EQ(d, Iota(d, 1 + N / 2), ConcatLowerUpper(hi, lo)); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllConcatLowerUpper() { + ForAllTypes(ForGE128Vectors()); +} + +struct TestConcatUpperLower { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto lo = Iota(d, 1); + const auto hi = Iota(d, 1 + N); + auto expected = AllocateAligned(N); + size_t i = 0; + for (; i < N / 2; ++i) { + expected[i] = static_cast(1 + i); + } + for (; i < N; ++i) { + expected[i] = static_cast(1 + i + N); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatUpperLower(hi, lo)); + } +}; + +HWY_NOINLINE void TestAllConcatUpperLower() { + ForAllTypes(ForGE128Vectors()); +} + +struct TestOddEven { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto even = Iota(d, 1); + const auto odd = Iota(d, 1 + N); + auto expected = AllocateAligned(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast(1 + i + ((i & 1) ? 
N : 0)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), OddEven(odd, even)); + } +}; + +HWY_NOINLINE void TestAllOddEven() { + ForAllTypes(ForGE128Vectors()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_BEFORE_TEST(HwySwizzleTest); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllShiftBytes); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllShiftLanes); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllBroadcast); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllTableLookupBytes); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllTableLookupLanes); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllInterleave); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllZip); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllSpecialShuffles); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllConcatHalves); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllConcatLowerUpper); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllConcatUpperLower); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllOddEven); +} // namespace hwy +#endif diff --git a/third_party/highway/hwy/tests/test_util-inl.h b/third_party/highway/hwy/tests/test_util-inl.h new file mode 100644 index 000000000000..aa898c5c461a --- /dev/null +++ b/third_party/highway/hwy/tests/test_util-inl.h @@ -0,0 +1,580 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for non-SIMD portion of this header. +#ifndef HWY_TESTS_TEST_UTIL_H_ +#define HWY_TESTS_TEST_UTIL_H_ + +// Helper functions for use by *_test.cc. + +#include +#include +#include +#include + +#include +#include +#include // std::forward + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/highway.h" + +#include "gtest/gtest.h" + +namespace hwy { + +// The maximum vector size used in tests when defining test data. DEPRECATED. +constexpr size_t kTestMaxVectorSize = 64; + +// googletest before 1.10 didn't define INSTANTIATE_TEST_SUITE_P() but instead +// used INSTANTIATE_TEST_CASE_P which is now deprecated. +#ifdef INSTANTIATE_TEST_SUITE_P +#define HWY_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_SUITE_P +#else +#define HWY_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_CASE_P +#endif + +// Helper class to run parametric tests using the hwy target as parameter. To +// use this define the following in your test: +// class MyTestSuite : public TestWithParamTarget { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P(MyTestSuite); +// TEST_P(MyTestSuite, MyTest) { ... } +class TestWithParamTarget : public testing::TestWithParam { + protected: + void SetUp() override { SetSupportedTargetsForTest(GetParam()); } + + void TearDown() override { + // Check that the parametric test calls SupportedTargets() when the source + // was compiled with more than one target. In the single-target case only + // static dispatch will be used anyway. 
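The preprocessor guard below compiles the check in only when more than one target was built, reusing the x & (x - 1) trick from list_targets.cc: clearing the lowest set bit leaves zero exactly when at most one bit was set.

static_assert((0x4 & (0x4 - 1)) == 0, "one bit set -> expression is zero");
static_assert((0x5 & (0x5 - 1)) != 0, "several bits set -> nonzero");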
+#if (HWY_TARGETS & (HWY_TARGETS - 1)) != 0 + EXPECT_TRUE(SupportedTargetsCalledForTest()) + << "This hwy target parametric test doesn't use dynamic-dispatch and " + "doesn't need to be parametric."; +#endif + SetSupportedTargetsForTest(0); + } +}; + +// Function to convert the test parameter of a TestWithParamTarget for +// displaying it in the gtest test name. +static inline std::string TestParamTargetName( + const testing::TestParamInfo& info) { + return TargetName(info.param); +} + +#define HWY_TARGET_INSTANTIATE_TEST_SUITE_P(suite) \ + HWY_GTEST_INSTANTIATE_TEST_SUITE_P( \ + suite##Group, suite, \ + testing::ValuesIn(::hwy::SupportedAndGeneratedTargets()), \ + ::hwy::TestParamTargetName) + +// Helper class similar to TestWithParamTarget to run parametric tests that +// depend on the target and another parametric test. If you need to use multiple +// extra parameters use a std::tuple<> of them and ::testing::Generate(...) as +// the generator. To use this class define the following in your test: +// class MyTestSuite : public TestWithParamTargetT { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(MyTestSuite, ::testing::Range(0, 9)); +// TEST_P(MyTestSuite, MyTest) { ... GetParam() .... } +template +class TestWithParamTargetAndT + : public ::testing::TestWithParam> { + public: + // Expose the parametric type here so it can be used by the + // HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T macro. + using HwyParamType = T; + + protected: + void SetUp() override { + SetSupportedTargetsForTest(std::get<0>( + ::testing::TestWithParam>::GetParam())); + } + + void TearDown() override { + // Check that the parametric test calls SupportedTargets() when the source + // was compiled with more than one target. In the single-target case only + // static dispatch will be used anyway. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) != 0 + EXPECT_TRUE(SupportedTargetsCalledForTest()) + << "This hwy target parametric test doesn't use dynamic-dispatch and " + "doesn't need to be parametric."; +#endif + SetSupportedTargetsForTest(0); + } + + T GetParam() { + return std::get<1>( + ::testing::TestWithParam>::GetParam()); + } +}; + +template +std::string TestParamTargetNameAndT( + const testing::TestParamInfo>& info) { + return std::string(TargetName(std::get<0>(info.param))) + "_" + + ::testing::PrintToString(std::get<1>(info.param)); +} + +#define HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(suite, generator) \ + HWY_GTEST_INSTANTIATE_TEST_SUITE_P( \ + suite##Group, suite, \ + ::testing::Combine( \ + testing::ValuesIn(::hwy::SupportedAndGeneratedTargets()), \ + generator), \ + ::hwy::TestParamTargetNameAndT) + +// Helper macro to export a function and define a test that tests it. This is +// equivalent to do a HWY_EXPORT of a void(void) function and run it in a test: +// class MyTestSuite : public TestWithParamTarget { +// ... 
+// Helper macro to export a function and define a test that tests it. This is
+// equivalent to doing a HWY_EXPORT of a void(void) function and running it in
+// a test:
+//   class MyTestSuite : public TestWithParamTarget {
+//    ...
+//   };
+//   HWY_TARGET_INSTANTIATE_TEST_SUITE_P(MyTestSuite);
+//   HWY_EXPORT_AND_TEST_P(MyTestSuite, MyTest);
+#define HWY_EXPORT_AND_TEST_P(suite, func_name)                       \
+  HWY_EXPORT(func_name);                                              \
+  TEST_P(suite, func_name) { HWY_DYNAMIC_DISPATCH(func_name)(); }     \
+  static_assert(true, "For requiring trailing semicolon")
+
+#define HWY_EXPORT_AND_TEST_P_T(suite, func_name)                            \
+  HWY_EXPORT(func_name);                                                     \
+  TEST_P(suite, func_name) { HWY_DYNAMIC_DISPATCH(func_name)(GetParam()); }  \
+  static_assert(true, "For requiring trailing semicolon")
+
+#define HWY_BEFORE_TEST(suite)                      \
+  class suite : public hwy::TestWithParamTarget {}; \
+  HWY_TARGET_INSTANTIATE_TEST_SUITE_P(suite);       \
+  static_assert(true, "For requiring trailing semicolon")
+
+// 64-bit random generator (Xorshift128+). Much smaller state than
+// std::mt19937, which triggers a compiler bug.
+class RandomState {
+ public:
+  explicit RandomState(const uint64_t seed = 0x123456789ull) {
+    s0_ = SplitMix64(seed + 0x9E3779B97F4A7C15ull);
+    s1_ = SplitMix64(s0_);
+  }
+
+  HWY_INLINE uint64_t operator()() {
+    uint64_t s1 = s0_;
+    const uint64_t s0 = s1_;
+    const uint64_t bits = s1 + s0;
+    s0_ = s0;
+    s1 ^= s1 << 23;
+    s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
+    s1_ = s1;
+    return bits;
+  }
+
+ private:
+  static uint64_t SplitMix64(uint64_t z) {
+    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
+    z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
+    return z ^ (z >> 31);
+  }
+
+  uint64_t s0_;
+  uint64_t s1_;
+};
+
+static HWY_INLINE uint32_t Random32(RandomState* rng) {
+  return static_cast<uint32_t>((*rng)());
+}
+
+// Prevents the compiler from eliding the computations that led to "output".
+// Works by indicating to the compiler that "output" is being read and
+// modified. The +r constraint avoids unnecessary writes to memory, but only
+// works for built-in types.
+template <class T>
+inline void PreventElision(T&& output) {
+#if HWY_COMPILER_MSVC
+  (void)output;
+#else   // HWY_COMPILER_MSVC
+  asm volatile("" : "+r"(output) : : "memory");
+#endif  // HWY_COMPILER_MSVC
+}
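For illustration (an editorial sketch, not part of the patch), a test body might drive the generator above like this; the seed, iteration count, and variable names are arbitrary:

```cpp
// Sketch: draw pseudo-random values and keep the compiler from discarding
// the work. This mirrors how the *_test.cc files use RandomState.
hwy::RandomState rng(/*seed=*/12345);
uint64_t accum = 0;
for (size_t i = 0; i < 1000; ++i) {
  accum += rng();                // next 64-bit Xorshift128+ output
  accum += hwy::Random32(&rng);  // truncated 32-bit convenience form
}
hwy::PreventElision(accum);      // the loop cannot be optimized away
```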
+// Returns a name for the vector/part/scalar. The type prefix is u/i/f for
+// unsigned/signed/floating point, followed by the number of bits per lane;
+// then 'x' followed by the number of lanes. Example: u8x16. This is useful for
+// understanding which instantiation of a generic test failed.
+template <typename T>
+static inline std::string TypeName(T /*unused*/, size_t N) {
+  const char prefix = IsFloat<T>() ? 'f' : (IsSigned<T>() ? 'i' : 'u');
+  char name[64];
+  // Omit the xN suffix for scalars.
+  if (N == 1) {
+    snprintf(name, sizeof(name), "%c%zu", prefix, sizeof(T) * 8);
+  } else {
+    snprintf(name, sizeof(name), "%c%zux%zu", prefix, sizeof(T) * 8, N);
+  }
+  return name;
+}
+
+// String comparison
+
+template <typename T1, typename T2>
+inline bool BytesEqual(const T1* p1, const T2* p2, const size_t size,
+                       size_t* pos = nullptr) {
+  const uint8_t* bytes1 = reinterpret_cast<const uint8_t*>(p1);
+  const uint8_t* bytes2 = reinterpret_cast<const uint8_t*>(p2);
+  for (size_t i = 0; i < size; ++i) {
+    if (bytes1[i] != bytes2[i]) {
+      fprintf(stderr, "Mismatch at byte %zu of %zu: %d != %d (%s, %s)\n", i,
+              size, bytes1[i], bytes2[i], TypeName(T1(), 1).c_str(),
+              TypeName(T2(), 1).c_str());
+      if (pos != nullptr) {
+        *pos = i;
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+inline bool StringsEqual(const char* s1, const char* s2) {
+  while (*s1 == *s2++) {
+    if (*s1++ == '\0') return true;
+  }
+  return false;
+}
+
+}  // namespace hwy
+
+#endif  // HWY_TESTS_TEST_UTIL_H_
+
+// Per-target include guard
+#if defined(HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_
+#undef HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_
+#else
+#define HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_
+#endif
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+// Prints lanes around `lane`, in memory order.
+template <class D>
+HWY_NOINLINE void Print(const D d, const char* caption, const Vec<D> v,
+                        intptr_t lane = 0) {
+  using T = TFromD<D>;
+  const size_t N = Lanes(d);
+  auto lanes = AllocateAligned<T>(N);
+  Store(v, d, lanes.get());
+  const size_t begin = static_cast<size_t>(std::max<intptr_t>(0, lane - 2));
+  const size_t end = std::min(begin + 7, N);
+  fprintf(stderr, "%s %s [%zu+ ->]:\n  ", TypeName(T(), N).c_str(), caption,
+          begin);
+  for (size_t i = begin; i < end; ++i) {
+    fprintf(stderr, "%g,", double(lanes[i]));
+  }
+  if (begin >= end) fprintf(stderr, "(out of bounds)");
+  fprintf(stderr, "\n");
+}
+
+static HWY_NORETURN HWY_NOINLINE void NotifyFailure(
+    const char* filename, const int line, const char* type_name,
+    const size_t lane, const char* expected, const char* actual) {
+  hwy::Abort(filename, line,
+             "%s, %s lane %zu mismatch: expected '%s', got '%s'.\n",
+             hwy::TargetName(HWY_TARGET), type_name, lane, expected, actual);
+}
+
+template <class Out, class In>
+inline Out BitCast(const In& in) {
+  static_assert(sizeof(Out) == sizeof(In), "");
+  Out out;
+  CopyBytes<sizeof(Out)>(&in, &out);
+  return out;
+}
+
+// Computes the difference in units of last place between x and y.
+template <typename TF>
+MakeUnsigned<TF> ComputeUlpDelta(TF x, TF y) {
+  static_assert(IsFloat<TF>(), "Only makes sense for floating-point");
+  using TU = MakeUnsigned<TF>;
+
+  // Handle -0 == 0 and infinities.
+  if (x == y) return 0;
+
+  // Consider "equal" if both are NaN, so we can verify an expected NaN.
+  // Needs a special case because there are many possible NaN representations.
+  if (std::isnan(x) && std::isnan(y)) return 0;
+
+  // NOTE: no need to check for differing signs; they will result in large
+  // differences, which is fine, and we avoid overflow.
+
+  const TU ux = BitCast<TU>(x);
+  const TU uy = BitCast<TU>(y);
+  // Avoid unsigned->signed cast: 2's complement is only guaranteed by C++20.
+  return std::max(ux, uy) - std::min(ux, uy);
+}
+
+template <typename T, HWY_IF_NOT_FLOAT(T)>
+HWY_NOINLINE bool IsEqual(const T expected, const T actual) {
+  return expected == actual;
+}
+
+template <typename T, HWY_IF_FLOAT(T)>
+HWY_NOINLINE bool IsEqual(const T expected, const T actual) {
+  return ComputeUlpDelta(expected, actual) <= 1;
+}
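To make the tolerance concrete (an editorial aside with assumed values): floats one representable step apart compare equal under this scheme, while two steps apart do not.

```cpp
// Sketch: 1 ULP of slack means adjacent floats are "equal".
#include <cmath>  // std::nextafter

void UlpIntuition() {
  const float x = 1.0f;
  const float one_ulp = std::nextafter(x, 2.0f);        // ULP delta == 1
  const float two_ulp = std::nextafter(one_ulp, 2.0f);  // ULP delta == 2
  // IsEqual(x, one_ulp) -> true; IsEqual(x, two_ulp) -> false.
}
```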
+// Compare non-vector, non-string T.
+template <typename T>
+HWY_NOINLINE void AssertEqual(const T expected, const T actual,
+                              const std::string& type_name,
+                              const char* filename = "", const int line = -1,
+                              const size_t lane = 0) {
+  if (!IsEqual(expected, actual)) {
+    char expected_str[100];
+    snprintf(expected_str, sizeof(expected_str), "%g", double(expected));
+    char actual_str[100];
+    snprintf(actual_str, sizeof(actual_str), "%g", double(actual));
+    NotifyFailure(filename, line, type_name.c_str(), lane, expected_str,
+                  actual_str);
+  }
+}
+
+static HWY_NOINLINE HWY_MAYBE_UNUSED void AssertStringEqual(
+    const char* expected, const char* actual, const char* filename = "",
+    const int line = -1, const size_t lane = 0) {
+  if (!hwy::StringsEqual(expected, actual)) {
+    NotifyFailure(filename, line, "string", lane, expected, actual);
+  }
+}
+
+// Compare expected vector to vector.
+template <class D, class V>
+HWY_NOINLINE void AssertVecEqual(D d, const V expected, const V actual,
+                                 const char* filename, const int line) {
+  using T = TFromD<D>;
+  const size_t N = Lanes(d);
+  auto expected_lanes = AllocateAligned<T>(N);
+  auto actual_lanes = AllocateAligned<T>(N);
+  Store(expected, d, expected_lanes.get());
+  Store(actual, d, actual_lanes.get());
+  for (size_t i = 0; i < N; ++i) {
+    if (!IsEqual(expected_lanes[i], actual_lanes[i])) {
+      fprintf(stderr, "\n\n");
+      Print(d, "expect", expected, i);
+      Print(d, "actual", actual, i);
+
+      char expected_str[100];
+      snprintf(expected_str, sizeof(expected_str), "%g",
+               double(expected_lanes[i]));
+      char actual_str[100];
+      snprintf(actual_str, sizeof(actual_str), "%g", double(actual_lanes[i]));
+
+      NotifyFailure(filename, line, hwy::TypeName(T(), N).c_str(), i,
+                    expected_str, actual_str);
+    }
+  }
+}
+
+// Compare expected lanes to vector.
+template <class D>
+HWY_NOINLINE void AssertVecEqual(D d, const TFromD<D>* expected, Vec<D> actual,
+                                 const char* filename, int line) {
+  AssertVecEqual(d, LoadU(d, expected), actual, filename, line);
+}
+
+template <class D>
+HWY_NOINLINE void AssertMaskEqual(D d, Mask<D> a, Mask<D> b,
+                                  const char* filename, int line) {
+  AssertVecEqual(d, VecFromMask(d, a), VecFromMask(d, b), filename, line);
+
+  const std::string type_name = TypeName(TFromD<D>(), Lanes(d));
+  AssertEqual(CountTrue(a), CountTrue(b), type_name, filename, line, 0);
+  AssertEqual(AllTrue(a), AllTrue(b), type_name, filename, line, 0);
+  AssertEqual(AllFalse(a), AllFalse(b), type_name, filename, line, 0);
+
+  // TODO(janwas): StoreMaskBits
+}
+
+template <class D>
+HWY_NOINLINE Mask<D> MaskTrue(const D d) {
+  const auto v0 = Zero(d);
+  return Eq(v0, v0);
+}
+
+template <class D>
+HWY_NOINLINE Mask<D> MaskFalse(const D d) {
+  // Lt is only for signed types and we cannot yet cast mask types.
+  return Eq(Zero(d), Set(d, 1));
+}
+
+#ifndef HWY_ASSERT_EQ
+
+#define HWY_ASSERT_EQ(expected, actual) \
+  AssertEqual(expected, actual, hwy::TypeName(expected, 1), __FILE__, __LINE__)
+
+#define HWY_ASSERT_STRING_EQ(expected, actual) \
+  AssertStringEqual(expected, actual, __FILE__, __LINE__)
+
+#define HWY_ASSERT_VEC_EQ(d, expected, actual) \
+  AssertVecEqual(d, expected, actual, __FILE__, __LINE__)
+
+#define HWY_ASSERT_MASK_EQ(d, expected, actual) \
+  AssertMaskEqual(d, expected, actual, __FILE__, __LINE__)
+
+#endif  // HWY_ASSERT_EQ
+
+// Helpers for instantiating tests with combinations of lane types / counts.
+
+// For all powers of two in [kMinLanes, N * kMinLanes] (so that recursion stops
+// at N == 0)
+template <typename T, size_t N, size_t kMinLanes, class Test>
+struct ForeachSizeR {
+  static void Do() {
+    static_assert(N != 0, "End of recursion");
+    Test()(T(), Simd<T, N * kMinLanes>());
+    ForeachSizeR<T, N / 2, kMinLanes, Test>::Do();
+  }
+};
+
+// Base case to stop the recursion.
+template <typename T, size_t kMinLanes, class Test>
+struct ForeachSizeR<T, 0, kMinLanes, Test> {
+  static void Do() {}
+};
+
+// These adapters may be called directly, or via For*Types:
+
+// Calls Test for all powers of two in [kMinLanes, HWY_LANES(T) / kDivLanes].
+template <class Test, size_t kDivLanes = 1, size_t kMinLanes = 1>
+struct ForPartialVectors {
+  template <typename T>
+  void operator()(T /*unused*/) const {
+#if HWY_TARGET == HWY_RVV
+    // Only m1..8 for now, can ignore kMaxLanes because HWY_*_LANES are full.
+    ForeachSizeR<T, 8 / kDivLanes, HWY_LANES(T), Test>::Do();
+#else
+    ForeachSizeR<T, HWY_LANES(T) / kDivLanes / kMinLanes, kMinLanes,
+                 Test>::Do();
+#endif
+  }
+};
+
+// Calls Test for all vectors that can be demoted log2(kFactor) times.
+template <class Test, size_t kFactor>
+struct ForDemoteVectors {
+  template <typename T>
+  void operator()(T /*unused*/) const {
+#if HWY_TARGET == HWY_RVV
+    // Only m1..8 for now.
+    ForeachSizeR<T, 8 / kFactor, kFactor * HWY_LANES(T), Test>::Do();
+#else
+    ForeachSizeR<T, HWY_LANES(T) / kFactor, kFactor, Test>::Do();
+#endif
+  }
+};
+
+// Calls Test for all powers of two in [128 bits, max bits].
+template <class Test>
+struct ForGE128Vectors {
+  template <typename T>
+  void operator()(T /*unused*/) const {
+#if HWY_TARGET == HWY_RVV
+    ForeachSizeR<T, 8, HWY_LANES(T), Test>::Do();
+#else
+    ForeachSizeR<T, HWY_LANES(T) / (16 / sizeof(T)), (16 / sizeof(T)),
+                 Test>::Do();
+#endif
+  }
+};
+
+// Calls Test for all vectors that can be expanded by kFactor.
+template <class Test, size_t kFactor = 2>
+struct ForExtendableVectors {
+  template <typename T>
+  void operator()(T /*unused*/) const {
+#if HWY_TARGET == HWY_RVV
+    ForeachSizeR<T, 8 / kFactor, HWY_LANES(T), Test>::Do();
+#else
+    ForeachSizeR<T, HWY_LANES(T) / kFactor / (16 / sizeof(T)),
+                 (16 / sizeof(T)), Test>::Do();
+#endif
+  }
+};
+
+// Calls Test for full vectors only.
+template <class Test>
+struct ForFullVectors {
+  template <typename T>
+  void operator()(T t) const {
+#if HWY_TARGET == HWY_RVV
+    ForeachSizeR<T, 8, HWY_LANES(T), Test>::Do();
+    (void)t;
+#else
+    Test()(t, HWY_FULL(T)());
+#endif
+  }
+};
+
+// Type lists to shorten call sites:
+
+template <class Func>
+void ForSignedTypes(const Func& func) {
+  func(int8_t());
+  func(int16_t());
+  func(int32_t());
+#if HWY_CAP_INTEGER64
+  func(int64_t());
+#endif
+}
+
+template <class Func>
+void ForUnsignedTypes(const Func& func) {
+  func(uint8_t());
+  func(uint16_t());
+  func(uint32_t());
+#if HWY_CAP_INTEGER64
+  func(uint64_t());
+#endif
+}
+
+template <class Func>
+void ForIntegerTypes(const Func& func) {
+  ForSignedTypes(func);
+  ForUnsignedTypes(func);
+}
+
+template <class Func>
+void ForFloatTypes(const Func& func) {
+  func(float());
+#if HWY_CAP_FLOAT64
+  func(double());
+#endif
+}
+
+template <class Func>
+void ForAllTypes(const Func& func) {
+  ForIntegerTypes(func);
+  ForFloatTypes(func);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#endif  // per-target include guard
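Taken together, a typical test built on this header follows the shape below (an illustrative sketch; the functor, suite, and operation names are hypothetical, modeled on the usage comments above, and the code lives inside the per-target HWY_NAMESPACE block):

```cpp
// Sketch: one functor, instantiated for all lane types and vector sizes.
struct TestAddSelf {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const auto v = Set(d, T(1));
    HWY_ASSERT_VEC_EQ(d, Set(d, T(2)), Add(v, v));
  }
};

HWY_NOINLINE void TestAllAddSelf() {
  ForAllTypes(ForPartialVectors<TestAddSelf>());
}

// And once per translation unit, under #if HWY_ONCE:
//   HWY_BEFORE_TEST(MyTestSuite);
//   HWY_EXPORT_AND_TEST_P(MyTestSuite, TestAllAddSelf);
```

The test_util_test.cc file added next is a real instance of exactly this pattern.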
diff --git a/third_party/highway/hwy/tests/test_util_test.cc b/third_party/highway/hwy/tests/test_util_test.cc
new file mode 100644
index 000000000000..b0f5edf52afe
--- /dev/null
+++ b/third_party/highway/hwy/tests/test_util_test.cc
@@ -0,0 +1,102 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <string>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "tests/test_util_test.cc"
+#include "hwy/foreach_target.h"
+#include "hwy/highway.h"
+#include "hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+struct TestName {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T t, D d) {
+    char num[10];
+    std::string expected = IsFloat<T>() ? "f" : (IsSigned<T>() ? "i" : "u");
+    snprintf(num, sizeof(num), "%zu", sizeof(T) * 8);
+    expected += num;
+
+    const size_t N = Lanes(d);
+    if (N != 1) {
+      expected += 'x';
+      snprintf(num, sizeof(num), "%zu", N);
+      expected += num;
+    }
+    const std::string actual = TypeName(t, N);
+    if (expected != actual) {
+      NotifyFailure(__FILE__, __LINE__, expected.c_str(), 0, expected.c_str(),
+                    actual.c_str());
+    }
+  }
+};
+
+HWY_NOINLINE void TestAllName() { ForAllTypes(ForPartialVectors<TestName>()); }
+
+struct TestEqualInteger {
+  template <typename T>
+  HWY_NOINLINE void operator()(T /*t*/) const {
+    HWY_ASSERT(IsEqual(T(0), T(0)));
+    HWY_ASSERT(IsEqual(T(1), T(1)));
+    HWY_ASSERT(IsEqual(T(-1), T(-1)));
+    HWY_ASSERT(IsEqual(LimitsMin<T>(), LimitsMin<T>()));
+
+    HWY_ASSERT(!IsEqual(T(0), T(1)));
+    HWY_ASSERT(!IsEqual(T(1), T(0)));
+    HWY_ASSERT(!IsEqual(T(1), T(-1)));
+    HWY_ASSERT(!IsEqual(T(-1), T(1)));
+    HWY_ASSERT(!IsEqual(LimitsMin<T>(), LimitsMax<T>()));
+    HWY_ASSERT(!IsEqual(LimitsMax<T>(), LimitsMin<T>()));
+  }
+};
+
+struct TestEqualFloat {
+  template <typename T>
+  HWY_NOINLINE void operator()(T /*t*/) const {
+    HWY_ASSERT(IsEqual(T(0), T(0)));
+    HWY_ASSERT(IsEqual(T(1), T(1)));
+    HWY_ASSERT(IsEqual(T(-1), T(-1)));
+    HWY_ASSERT(IsEqual(MantissaEnd<T>(), MantissaEnd<T>()));
+
+    HWY_ASSERT(!IsEqual(T(0), T(1)));
+    HWY_ASSERT(!IsEqual(T(1), T(0)));
+    HWY_ASSERT(!IsEqual(T(1), T(-1)));
+    HWY_ASSERT(!IsEqual(T(-1), T(1)));
+    HWY_ASSERT(!IsEqual(LowestValue<T>(), HighestValue<T>()));
+    HWY_ASSERT(!IsEqual(HighestValue<T>(), LowestValue<T>()));
+  }
+};
+
+HWY_NOINLINE void TestAllEqual() {
+  ForIntegerTypes(TestEqualInteger());
+  ForFloatTypes(TestEqualFloat());
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+HWY_BEFORE_TEST(TestUtilTest);
+HWY_EXPORT_AND_TEST_P(TestUtilTest, TestAllName);
+HWY_EXPORT_AND_TEST_P(TestUtilTest, TestAllEqual);
+}  // namespace hwy
+#endif
diff --git a/third_party/highway/libhwy-test.pc.in b/third_party/highway/libhwy-test.pc.in
new file mode 100644
index 000000000000..827bb8e91b6d
--- /dev/null
+++ b/third_party/highway/libhwy-test.pc.in
@@ -0,0 +1,8 @@
+prefix=@CMAKE_INSTALL_PREFIX@
+includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@
+
+Name: libhwy-test
+Description: Efficient and performance-portable SIMD wrapper, test helpers.
+Requires: gtest +Version: @HWY_LIBRARY_VERSION@ +Cflags: -I${includedir} diff --git a/third_party/highway/libhwy.pc.in b/third_party/highway/libhwy.pc.in new file mode 100644 index 000000000000..2ada0e847cbf --- /dev/null +++ b/third_party/highway/libhwy.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ + +Name: libhwy +Description: Efficient and performance-portable SIMD wrapper +Version: @HWY_LIBRARY_VERSION@ +Libs: -L${libdir} -lhwy +Cflags: -I${includedir} diff --git a/third_party/highway/run_tests.bat b/third_party/highway/run_tests.bat new file mode 100644 index 000000000000..d38081046881 --- /dev/null +++ b/third_party/highway/run_tests.bat @@ -0,0 +1,20 @@ +@echo off +REM Switch directory of this batch file +cd %~dp0 + +if not exist build mkdir build + +cd build +cmake .. -G Ninja || goto error +ninja || goto error +ctest -j || goto error + +cd .. +echo Success +goto end + +:error +echo Failure +exit /b 1 + +:end diff --git a/third_party/highway/run_tests.sh b/third_party/highway/run_tests.sh new file mode 100644 index 000000000000..1c772cd5c9ed --- /dev/null +++ b/third_party/highway/run_tests.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Switch to directory of this script +MYDIR=$(dirname $(realpath "$0")) +cd "${MYDIR}" + +# Exit if anything fails +set -e + +mkdir -p build +cd build +cmake .. +make -j +ctest -j +echo Success diff --git a/third_party/jpeg-xl/.clang-format b/third_party/jpeg-xl/.clang-format new file mode 100644 index 000000000000..a61b61c56972 --- /dev/null +++ b/third_party/jpeg-xl/.clang-format @@ -0,0 +1,4 @@ +BasedOnStyle: Google +IncludeCategories: + - Regex: '^ +# instead . +# - modernize-return-braced-init-list: this often doesn't improve readability. +# - modernize-use-auto: is too aggressive towards using auto. +# - modernize-use-default-member-init: with a mix of constructors and default +# member initialization this can be confusing if enforced. +# - modernize-use-trailing-return-type: does not improve readability when used +# systematically. +# - modernize-use-using: typedefs are ok. +# +# - readability-else-after-return: It doesn't always improve readability. +# - readability-static-accessed-through-instance +# It is often more useful and readable to access a constant of a passed +# variable (like d.N) instead of using the type of the variable that could be +# long and complex. +# - readability-uppercase-literal-suffix: we write 1.0f, not 1.0F. 
+
+Checks: >-
+  bugprone-*,
+  clang-*,
+  -clang-diagnostic-unused-command-line-argument,
+  google-*,
+  modernize-*,
+  performance-*,
+  readability-*,
+  -google-readability-todo,
+  -modernize-deprecated-headers,
+  -modernize-return-braced-init-list,
+  -modernize-use-auto,
+  -modernize-use-default-member-init,
+  -modernize-use-trailing-return-type,
+  -modernize-use-using,
+  -readability-else-after-return,
+  -readability-function-cognitive-complexity,
+  -readability-static-accessed-through-instance,
+  -readability-uppercase-literal-suffix,
+
+
+WarningsAsErrors: >-
+  bugprone-argument-comment,
+  bugprone-macro-parentheses,
+  bugprone-suspicious-string-compare,
+  bugprone-use-after-move,
+  clang-*,
+  clang-analyzer-*,
+  -clang-diagnostic-unused-command-line-argument,
+  google-build-using-namespace,
+  google-explicit-constructor,
+  google-readability-braces-around-statements,
+  google-readability-namespace-comments,
+  modernize-use-override,
+  readability-inconsistent-declaration-parameter-name
+
+# We are only interested in the headers from this project, excluding
+# third_party/ and build/.
+HeaderFilterRegex: '^.*/(lib|tools)/.*\.h$'
+
+CheckOptions:
+  - key: readability-braces-around-statements.ShortStatementLines
+    value: '2'
+  - key: google-readability-braces-around-statements.ShortStatementLines
+    value: '2'
+  - key: readability-implicit-bool-conversion.AllowPointerConditions
+    value: '1'
+  - key: readability-implicit-bool-conversion.AllowIntegerConditions
+    value: '1'
diff --git a/third_party/jpeg-xl/CMakeLists.txt b/third_party/jpeg-xl/CMakeLists.txt
new file mode 100644
index 000000000000..c16fe5c647ef
--- /dev/null
+++ b/third_party/jpeg-xl/CMakeLists.txt
@@ -0,0 +1,334 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Ubuntu bionic ships with cmake 3.10.
+cmake_minimum_required(VERSION 3.10)
+
+# Honor VISIBILITY_INLINES_HIDDEN on all types of targets.
+if(POLICY CMP0063)
+  cmake_policy(SET CMP0063 NEW)
+endif()
+# Pass CMAKE_EXE_LINKER_FLAGS to CC and CXX compilers when testing if they
+# work.
+if(POLICY CMP0065)
+  cmake_policy(SET CMP0065 NEW)
+endif()
+
+# Set PIE flags for POSITION_INDEPENDENT_CODE targets, added in 3.14.
+if(POLICY CMP0083)
+  cmake_policy(SET CMP0083 NEW)
+endif()
+
+project(JPEGXL LANGUAGES C CXX)
+
+include(CheckCXXSourceCompiles)
+check_cxx_source_compiles(
+   "int main() {
+      #if !defined(__EMSCRIPTEN__)
+      static_assert(false, \"__EMSCRIPTEN__ is not defined\");
+      #endif
+      return 0;
+    }"
+  JPEGXL_EMSCRIPTEN
+)
+
+message(STATUS "CMAKE_SYSTEM_PROCESSOR is ${CMAKE_SYSTEM_PROCESSOR}")
+include(CheckCXXCompilerFlag)
+check_cxx_compiler_flag("-fsanitize=fuzzer-no-link" CXX_FUZZERS_SUPPORTED)
+check_cxx_compiler_flag("-Xclang -mconstructor-aliases" CXX_CONSTRUCTOR_ALIASES_SUPPORTED)
+
+# Enable PIE binaries by default if supported.
+include(CheckPIESupported OPTIONAL RESULT_VARIABLE CHECK_PIE_SUPPORTED)
+if(CHECK_PIE_SUPPORTED)
+  check_pie_supported(LANGUAGES CXX)
+  if(CMAKE_CXX_LINK_PIE_SUPPORTED)
+    set(CMAKE_POSITION_INDEPENDENT_CODE TRUE)
+  endif()
+endif()
+
+### Project build options:
+if(${CXX_FUZZERS_SUPPORTED})
+  # Enabled by default except on arm64, Windows and Apple builds.
+  set(ENABLE_FUZZERS_DEFAULT true)
+endif()
+find_package(PkgConfig)
+if(NOT APPLE AND NOT WIN32 AND NOT HAIKU AND ${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
+  pkg_check_modules(TCMallocMinimalVersionCheck QUIET IMPORTED_TARGET
+      libtcmalloc_minimal)
+  if(TCMallocMinimalVersionCheck_FOUND AND
+     NOT TCMallocMinimalVersionCheck_VERSION VERSION_EQUAL 2.8.0)
+    # Enabled by default except on Windows and Apple builds for
+    # tcmalloc != 2.8.0. tcmalloc 2.8.1 already has a fix for this issue.
+    set(ENABLE_TCMALLOC_DEFAULT true)
+  else()
+    message(STATUS
+        "tcmalloc version ${TCMallocMinimalVersionCheck_VERSION} -- "
+        "tcmalloc 2.8.0 disabled due to "
+        "https://github.com/gperftools/gperftools/issues/1204")
+  endif()
+endif()
+
+set(WARNINGS_AS_ERRORS_DEFAULT false)
+
+set(JPEGXL_ENABLE_FUZZERS ${ENABLE_FUZZERS_DEFAULT} CACHE BOOL
+    "Build JPEGXL fuzzer targets.")
+set(JPEGXL_ENABLE_DEVTOOLS false CACHE BOOL
+    "Build JPEGXL developer tools.")
+set(JPEGXL_ENABLE_MANPAGES true CACHE BOOL
+    "Build and install man pages for the command-line tools.")
+set(JPEGXL_ENABLE_BENCHMARK true CACHE BOOL
+    "Build JPEGXL benchmark tools.")
+set(JPEGXL_ENABLE_EXAMPLES true CACHE BOOL
+    "Build JPEGXL library usage examples.")
+set(JPEGXL_ENABLE_SJPEG true CACHE BOOL
+    "Build JPEGXL with support for encoding with sjpeg.")
+set(JPEGXL_ENABLE_OPENEXR true CACHE BOOL
+    "Build JPEGXL with support for OpenEXR if available.")
+set(JPEGXL_ENABLE_SKCMS true CACHE BOOL
+    "Build with skcms instead of lcms2.")
+set(JPEGXL_ENABLE_VIEWERS false CACHE BOOL
+    "Build JPEGXL viewer tools for evaluation.")
+set(JPEGXL_ENABLE_TCMALLOC ${ENABLE_TCMALLOC_DEFAULT} CACHE BOOL
+    "Build JPEGXL using gperftools (tcmalloc) allocator.")
+set(JPEGXL_ENABLE_PLUGINS false CACHE BOOL
+    "Build third-party plugins to support JPEG XL in other applications.")
+set(JPEGXL_ENABLE_COVERAGE false CACHE BOOL
+    "Enable code coverage tracking for libjxl. This also enables debug and disables optimizations.")
+set(JPEGXL_STATIC false CACHE BOOL
+    "Build tools as static binaries.")
+set(JPEGXL_WARNINGS_AS_ERRORS ${WARNINGS_AS_ERRORS_DEFAULT} CACHE BOOL
+    "Treat warnings as errors during compilation.")
+set(JPEGXL_DEP_LICENSE_DIR "" CACHE STRING
+    "Directory where to search for system dependencies \"copyright\" files.")
+set(JPEGXL_FORCE_NEON false CACHE BOOL
+    "Set flags to enable NEON in arm if not enabled by your toolchain.")
+
+
+# Force system dependencies.
+set(JPEGXL_FORCE_SYSTEM_GTEST false CACHE BOOL
+    "Force using system installed googletest (gtest/gmock) instead of third_party/googletest source.")
+set(JPEGXL_FORCE_SYSTEM_BROTLI false CACHE BOOL
+    "Force using system installed brotli instead of third_party/brotli source.")
+set(JPEGXL_FORCE_SYSTEM_HWY false CACHE BOOL
+    "Force using system installed highway (libhwy-dev) instead of third_party/highway source.")
+
+# Check minimum compiler versions. Older compilers are not supported and fail
+# with hard-to-understand errors.
+if (NOT ${CMAKE_C_COMPILER_ID} STREQUAL ${CMAKE_CXX_COMPILER_ID})
+  message(FATAL_ERROR "Different C/C++ compilers set: "
+          "${CMAKE_C_COMPILER_ID} vs ${CMAKE_CXX_COMPILER_ID}")
+endif()
+if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+  # Android NDK's toolchain.cmake fakes the clang version in
+  # CMAKE_CXX_COMPILER_VERSION with an incorrect number, so ignore this.
+  if (NOT ${CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION} MATCHES "clang"
+      AND ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 6)
+    message(FATAL_ERROR
+        "Minimum Clang version required is Clang 6, please update.")
+  endif()
+elseif (${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
+  if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 7)
+    message(FATAL_ERROR
+        "Minimum GCC version required is 7, please update.")
+  endif()
+endif()
+
+message(STATUS
+    "Compiled IDs C:${CMAKE_C_COMPILER_ID}, C++:${CMAKE_CXX_COMPILER_ID}")
+
+# CMAKE_EXPORT_COMPILE_COMMANDS is used to generate the compilation database
+# used by clang-tidy.
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+if(JPEGXL_STATIC)
+  set(CMAKE_FIND_LIBRARY_SUFFIXES .a)
+  set(BUILD_SHARED_LIBS 0)
+  set(CMAKE_EXE_LINKER_FLAGS
+      "${CMAKE_EXE_LINKER_FLAGS} -static -static-libgcc -static-libstdc++")
+  if (MINGW)
+    # In MINGW libstdc++ uses pthreads directly. When building a program
+    # statically (regardless of whether the source code uses pthread or not)
+    # the toolchain will add stdc++ and pthread to the linking step but stdc++
+    # will be linked statically while pthread will be linked dynamically.
+    # To avoid this and have pthread statically linked we need to pass it on
+    # the command line with "-Wl,-Bstatic -lpthread -Wl,-Bdynamic" but the
+    # linker will discard it if not used by anything else up to that point in
+    # the linker command line. If neither the program nor any dependency uses
+    # pthread directly, -lpthread is discarded and libstdc++ (added by the
+    # toolchain later) will then use the dynamic version. For this we also need
+    # to pass -lstdc++ explicitly before -lpthread. For pure C programs
+    # -lstdc++ will be discarded anyway.
+    # This adds these flags as dependencies for *all* targets. Adding this to
+    # CMAKE_EXE_LINKER_FLAGS instead would cause them to be included before any
+    # object files and therefore discarded.
+    link_libraries(-Wl,-Bstatic -lstdc++ -lpthread -Wl,-Bdynamic)
+  endif()  # MINGW
+endif()  # JPEGXL_STATIC
+
+if (MSVC)
+# TODO(janwas): add flags
+else ()
+
+# Global compiler flags for all targets here and in subdirectories.
+add_definitions(
+  # Avoid changing the binary based on the current time and date.
+  -D__DATE__="redacted"
+  -D__TIMESTAMP__="redacted"
+  -D__TIME__="redacted"
+)
+
+if("${JPEGXL_ENABLE_FUZZERS}" OR "${JPEGXL_ENABLE_COVERAGE}")
+  add_definitions(
+    -DJXL_ENABLE_FUZZERS
+  )
+endif()  # JPEGXL_ENABLE_FUZZERS
+
+# In CMake before 3.12 it is problematic to pass repeated flags like -Xclang.
+# For this reason we place them in CMAKE_CXX_FLAGS instead.
+# See https://gitlab.kitware.com/cmake/cmake/issues/15826
+
+# Machine flags.
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funwind-tables") +if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xclang -mrelax-all") +endif() +if ("${CXX_CONSTRUCTOR_ALIASES_SUPPORTED}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xclang -mconstructor-aliases") +endif() + +if(WIN32) +# Not supported by clang-cl, but frame pointers are default on Windows +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") +endif() + +# CPU flags - remove once we have NEON dynamic dispatch + +# TODO(janwas): this also matches M1, but only ARMv7 is intended/needed. +if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") +if(JPEGXL_FORCE_NEON) +# GCC requires these flags, otherwise __ARM_NEON is undefined. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \ + -mfpu=neon-vfpv4 -mfloat-abi=hard") +endif() +endif() + +# Force build with optimizations in release mode. +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2") + +add_compile_options( + # Ignore this to allow redefining __DATE__ and others. + -Wno-builtin-macro-redefined + + # Global warning settings. + -Wall +) + +if (JPEGXL_WARNINGS_AS_ERRORS) +add_compile_options(-Werror) +endif () +endif () # !MSVC + +include(GNUInstallDirs) + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_STANDARD_REQUIRED YES) + +add_subdirectory(third_party) + +set(THREADS_PREFER_PTHREAD_FLAG YES) +find_package(Threads REQUIRED) + +# Copy the JXL license file to the output build directory. +configure_file("${CMAKE_CURRENT_SOURCE_DIR}/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.jpeg-xl COPYONLY) + +# Enable tests regardless of where are they defined. +enable_testing() +include(CTest) + +# Libraries. +add_subdirectory(lib) + +if(BUILD_TESTING) +# Script to run tests over the source code in bash. +find_program (BASH_PROGRAM bash) +if(BASH_PROGRAM) + add_test( + NAME bash_test + COMMAND ${BASH_PROGRAM} ${CMAKE_CURRENT_SOURCE_DIR}/bash_test.sh) +endif() +endif() # BUILD_TESTING + +# Documentation generated by Doxygen +find_package(Doxygen) +if(DOXYGEN_FOUND) +set(DOXYGEN_GENERATE_HTML "YES") +set(DOXYGEN_GENERATE_XML "NO") +set(DOXYGEN_STRIP_FROM_PATH "${CMAKE_CURRENT_SOURCE_DIR}/include") +set(DOXYGEN_USE_MDFILE_AS_MAINPAGE "README.md") +set(DOXYGEN_WARN_AS_ERROR "YES") +doxygen_add_docs(doc + "${CMAKE_CURRENT_SOURCE_DIR}/lib/include/jxl" + "${CMAKE_CURRENT_SOURCE_DIR}/doc/api.txt" + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" + COMMENT "Generating C API documentation") +else() +# Create a "doc" target for compatibility since "doc" is not otherwise added to +# the build when doxygen is not installed. +add_custom_target(doc false + COMMENT "Error: Can't generate doc since Doxygen not installed.") +endif() # DOXYGEN_FOUND + +if(JPEGXL_ENABLE_MANPAGES) +find_package(Python COMPONENTS Interpreter) +if(Python_Interpreter_FOUND) + find_program(ASCIIDOC a2x) +endif() +if(NOT Python_Interpreter_FOUND OR "${ASCIIDOC}" STREQUAL "ASCIIDOC-NOTFOUND") + message(WARNING "asciidoc was not found, the man pages will not be installed.") +else() + set(MANPAGE_FILES "") + set(MANPAGES "") + foreach(PAGE IN ITEMS cjxl djxl) + # Invoking the Python interpreter ourselves instead of running the a2x binary + # directly is necessary on MSYS2, otherwise it is run through cmd.exe which + # does not recognize it. 
+    add_custom_command(
+      OUTPUT "${PAGE}.1"
+      COMMAND Python::Interpreter
+      ARGS "${ASCIIDOC}"
+        --format manpage --destination-dir="${CMAKE_CURRENT_BINARY_DIR}"
+        "${CMAKE_CURRENT_SOURCE_DIR}/doc/man/${PAGE}.txt"
+      MAIN_DEPENDENCY "${CMAKE_CURRENT_SOURCE_DIR}/doc/man/${PAGE}.txt")
+    list(APPEND MANPAGE_FILES "${CMAKE_CURRENT_BINARY_DIR}/${PAGE}.1")
+    list(APPEND MANPAGES "${PAGE}.1")
+  endforeach()
+  add_custom_target(manpages ALL DEPENDS ${MANPAGES})
+  install(FILES ${MANPAGE_FILES} DESTINATION share/man/man1)
+endif()
+endif()
+
+# Example usage code.
+if (${JPEGXL_ENABLE_EXAMPLES})
+add_subdirectory(examples)
+endif ()
+
+# Plugins for third-party software
+if (${JPEGXL_ENABLE_PLUGINS})
+add_subdirectory(plugins)
+endif ()
+
+# Binary tools
+add_subdirectory(tools)
diff --git a/third_party/jpeg-xl/CONTRIBUTING.md b/third_party/jpeg-xl/CONTRIBUTING.md
new file mode 100644
index 000000000000..3e3c9f9685db
--- /dev/null
+++ b/third_party/jpeg-xl/CONTRIBUTING.md
@@ -0,0 +1,10 @@
+# How to Contribute
+
+We are currently unable to accept patches to this project, but we'd very much
+appreciate it if you open a GitLab issue for any bug reports/feature requests.
+We will be happy to investigate and hopefully fix any issues.
+
+# Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
diff --git a/third_party/jpeg-xl/CONTRIBUTORS b/third_party/jpeg-xl/CONTRIBUTORS
new file mode 100644
index 000000000000..848096f9219b
--- /dev/null
+++ b/third_party/jpeg-xl/CONTRIBUTORS
@@ -0,0 +1,23 @@
+# This file lists individuals who made significant contributions to the JPEG XL
+# code base, such as design, adding features, performing experiments, ...
+# Small changes such as a small bugfix or fixing spelling errors are not
+# included. If you'd like to be included in this file thanks to a significant
+# contribution, feel free to send a pull request changing this file.
+Alex Deymo
+Alexander Rhatushnyak
+Evgenii Kliuchnikov
+Iulia-Maria Comșa
+Jan Wassenberg
+Jon Sneyers
+Jyrki Alakuijala
+Krzysztof Potempa
+Lode Vandevenne
+Luca Versari
+Martin Bruse
+Moritz Firsching
+Renata Khasanova
+Robert Obryk
+Sami Boukortt
+Sebastian Gomez-Gonzalez
+Thomas Fischbacher
+Zoltan Szabadka
diff --git a/third_party/jpeg-xl/LICENSE b/third_party/jpeg-xl/LICENSE
new file mode 100644
index 000000000000..6b0b1270ff0c
--- /dev/null
+++ b/third_party/jpeg-xl/LICENSE
@@ -0,0 +1,203 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/third_party/jpeg-xl/README.Haiku.md b/third_party/jpeg-xl/README.Haiku.md new file mode 100644 index 000000000000..88f24e91e2c2 --- /dev/null +++ b/third_party/jpeg-xl/README.Haiku.md @@ -0,0 +1,20 @@ +## Disclaimer + +Haiku builds are not officially supported, i.e. the build might not work at all, +some tests may fail and some sub-projects are excluded from build. 
+
+This manual outlines Haiku-specific setup. For general building and testing
+instructions see "[README](README.md)" and
+"[Building and Testing changes](doc/building_and_testing.md)".
+
+## Dependencies
+
+```shell
+pkgman install llvm9_clang ninja cmake doxygen libjpeg_turbo_devel
+```
+
+## Building
+
+```shell
+TEST_STACK_LIMIT=none CMAKE_FLAGS="-I/boot/system/develop/tools/lib/gcc/x86_64-unknown-haiku/8.3.0/include/c++ -I/boot/system/develop/tools/lib/gcc/x86_64-unknown-haiku/8.3.0/include/c++/x86_64-unknown-haiku" CMAKE_SHARED_LINKER_FLAGS="-shared -Xlinker -soname=libjpegxl.so -lpthread" ./ci.sh opt
+```
diff --git a/third_party/jpeg-xl/README.OSX.md b/third_party/jpeg-xl/README.OSX.md
new file mode 100644
index 000000000000..8c6dc5a397e4
--- /dev/null
+++ b/third_party/jpeg-xl/README.OSX.md
@@ -0,0 +1,41 @@
+## Disclaimer
+
+OSX builds have "best effort" support, i.e. the build might not work at all,
+some tests may fail and some sub-projects are excluded from the build.
+
+This manual outlines OSX-specific setup. For general building and testing
+instructions see "[README](README.md)" and
+"[Building and Testing changes](doc/building_and_testing.md)".
+
+[Homebrew](https://brew.sh/) is a popular package manager. The JPEG XL library
+and binaries can be installed with it:
+
+```bash
+brew install jpeg-xl
+```
+
+## Dependencies
+
+Make sure that `brew doctor` does not report serious problems and that an
+up-to-date version of Xcode is installed.
+
+Installing (actually, building) `clang` might take a couple of hours.
+
+```bash
+brew install llvm
+```
+
+```bash
+brew install coreutils cmake giflib jpeg-turbo libpng ninja zlib
+```
+
+Before building the project, check that `which clang` is
+`/usr/local/opt/llvm/bin/clang`, not the one provided by Xcode. If not, update
+the `PATH` environment variable.
+
+Also, setting `CMAKE_PREFIX_PATH` might be necessary for correct include path
+resolution, e.g.:
+
+```bash
+export CMAKE_PREFIX_PATH=`brew --prefix giflib`:`brew --prefix jpeg-turbo`:`brew --prefix libpng`:`brew --prefix zlib`
+```
\ No newline at end of file
diff --git a/third_party/jpeg-xl/README.md b/third_party/jpeg-xl/README.md
new file mode 100644
index 000000000000..5c386617564f
--- /dev/null
+++ b/third_party/jpeg-xl/README.md
@@ -0,0 +1,224 @@
+# JPEG XL reference implementation
+
+JXL logo
+
+This repository contains a reference implementation of JPEG XL (encoder and
+decoder), called `libjxl`.
+
+JPEG XL is in the final stages of standardization and its codestream format is
+frozen.
+
+The library's API, command-line options and tools in this repository are
+subject to change; however, files encoded with `cjxl` conform to the JPEG XL
+format specification and can be decoded with current and future `djxl`
+decoders or the `libjxl` decoding library.
+
+## Quick start guide
+
+For more details and other workflows see the "Advanced guide" below.
+
+### Checking out the code
+
+```bash
+git clone https://gitlab.com/wg1/jpeg-xl.git --recursive
+```
+
+This repository uses git submodules to handle some third party dependencies
+under `third_party/`; that's why it's important to pass `--recursive`. If you
+didn't check out with `--recursive`, or any submodule has changed, run:
+`git submodule update --init --recursive`.
+
+Important: If you downloaded a zip file or tarball from the web interface you
+won't get the needed submodules and the code will not compile. You can
+download these external dependencies from source by running `./deps.sh`. The
+git workflow described above is recommended instead.
+
+### Installing dependencies
+
+To install the required dependencies for compiling the code on a
+Debian/Ubuntu-based distribution, run:
+
+```bash
+sudo apt install cmake pkg-config libbrotli-dev
+```
+
+To install the optional dependencies for supporting other formats in the
+`cjxl`/`djxl` tools on a Debian/Ubuntu-based distribution, run:
+
+```bash
+sudo apt install libgif-dev libjpeg-dev libopenexr-dev libpng-dev libwebp-dev
+```
+
+We recommend using a recent Clang compiler (version 7 or newer); to use it,
+install clang and set the `CC` and `CXX` variables. For example, with clang-7:
+
+```bash
+sudo apt install clang-7
+export CC=clang-7 CXX=clang++-7
+```
+
+### Building
+
+```bash
+cd jpeg-xl
+mkdir build
+cd build
+cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF ..
+cmake --build . -- -j$(nproc)
+```
+
+The encoder/decoder tools will be available in the `build/tools` directory.
+
+### Installing
+
+```bash
+sudo cmake --install .
+```
+
+### Basic encoder/decoder
+
+To encode a source image to JPEG XL with default settings:
+
+```bash
+build/tools/cjxl input.png output.jxl
+```
+
+For more settings run `build/tools/cjxl --help`, or for a full list of options
+run `build/tools/cjxl -v -v --help`.
+
+To decode a JPEG XL file run:
+
+```bash
+build/tools/djxl input.jxl output.png
+```
+
+When possible, `cjxl`/`djxl` can read/write the following
+image formats: .exr, .gif, .jpeg/.jpg, .pfm, .pgm/.ppm, .pgx, .png.
+
+### Benchmarking
+
+For speed benchmarks on single images, with single- or multi-threaded
+decoding, `djxl` can print decoding speed information. See `djxl --help` for
+details on the decoding options, and note that the output image is optional
+for benchmarking purposes.
+
+For a more comprehensive comparison of compression density between multiple
+options, see the "Benchmarking with benchmark_xl" section below.
+
+## Advanced guide
+
+### Building with Docker
+
+We build a common environment based on Debian/Ubuntu using Docker. Other
+systems may have different combinations of versions and dependencies that
+have not been tested and may not work. For those cases we recommend using the
+Docker environment as explained in the
+[step by step guide](doc/developing_in_docker.md).
+
+### Building JPEG XL for developers
+
+For experienced developers, we also provide build instructions for an [up to
+date Debian-based Linux](doc/developing_in_debian.md) and [64-bit
+Windows](doc/developing_in_windows.md). If you encounter any difficulties,
+please use Docker instead.
+
+## Benchmarking with benchmark_xl
+
+We recommend `build/tools/benchmark_xl` as a convenient method for reading
+images or image sequences, encoding them using various codecs (jpeg jxl png
+webp), decoding the result, and computing objective quality metrics. An
+example invocation is:
+
+```bash
+build/tools/benchmark_xl --input "/path/*.png" --codec jxl:wombat:d1,jxl:cheetah:d2
+```
+
+Multiple comma-separated codecs are allowed. The characters after : are
+parameters for the codec, separated by colons; in this case they specify
+maximum target psychovisual distances of 1 and 2 (higher implies lower
+quality) and the encoder effort (see below). Other common parameters are
+`r0.5` (target bitrate 0.5 bits per pixel) and `q92` (quality 92, on a scale
+of 0-100, where higher is better). The `jxl` codec supports the following
+additional parameters:
+
+Speed: `falcon`, `cheetah`, `hare`, `wombat`, `squirrel`, `kitten`, `tortoise`
+control the encoder effort in ascending order.
This also affects memory usage:
+using lower effort will typically reduce memory consumption during encoding.
+
+* `falcon` disables all of the following tools.
+* `cheetah` enables coefficient reordering, context clustering, and heuristics
+  for selecting DCT sizes and quantization steps.
+* `hare` enables Gaborish filtering, chroma from luma, and an initial estimate
+  of quantization steps.
+* `wombat` enables error diffusion quantization and full DCT size selection
+  heuristics.
+* `squirrel` (default) enables dots, patches, and spline detection, and full
+  context clustering.
+* `kitten` optimizes the adaptive quantization for a psychovisual metric.
+* `tortoise` enables a more thorough adaptive quantization search.
+
+Mode: JPEG XL has two modes. The default is Var-DCT mode, which is suitable
+for lossy compression. The other mode is Modular mode, which is suitable for
+lossless compression. Modular mode can also do lossy compression (e.g.
+`jxl:m:q50`).
+
+* `m` activates modular mode.
+
+Other arguments to benchmark_xl include:
+
+* `--save_compressed`: save codestreams to `output_dir`.
+* `--save_decompressed`: save decompressed outputs to `output_dir`.
+* `--output_extension`: selects the format used to output decoded images.
+* `--num_threads`: number of codec instances that will independently
+  encode/decode images, or 0.
+* `--inner_threads`: how many threads each instance should use for parallel
+  encoding/decoding, or 0.
+* `--encode_reps`/`--decode_reps`: how many times to repeat encoding/decoding
+  each image, for more consistent measurements (we recommend 10).
+
+The benchmark output begins with a header:
+
+```
+Compr Input Compr Compr Compr Decomp Butteraugli
+Method Pixels Size BPP # MP/s MP/s Distance Error p norm BPP*pnorm Errors
+```
+
+`ComprMethod` lists each comma-separated codec. `InputPixels` is the number
+of pixels in the input image. `ComprSize` is the codestream size in bytes and
+`ComprBPP` the bitrate. `Compr MP/s` and `Decomp MP/s` are the
+compress/decompress throughput, in units of Megapixels/second.
+`Butteraugli Distance` indicates the maximum psychovisual error in the decoded
+image (larger is worse). `Error p norm` is a similar summary of the
+psychovisual error, but closer to an average, giving less weight to small
+low-quality regions. `BPP*pnorm` is the product of `ComprBPP` and
+`Error p norm`, which is a figure of merit for the codec (lower is better).
+`Errors` is nonzero if errors occurred while loading or encoding/decoding the
+image.
+
+## License
+
+This software is available under the Apache 2.0 license, which can be found in
+the [LICENSE](LICENSE) file.
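As a quick sanity check on these columns (an editorial aside with made-up numbers), `ComprBPP` and the figure of merit can be recomputed from the other fields:

```cpp
// Sketch: relationship between ComprSize, InputPixels, ComprBPP and BPP*pnorm.
// All values below are illustrative, not real benchmark output.
#include <cstdio>

int main() {
  const double input_pixels = 1000000.0;     // InputPixels (assumed)
  const double compr_size_bytes = 125000.0;  // ComprSize (assumed)
  const double error_p_norm = 0.8;           // Error p norm (assumed)

  const double bpp = compr_size_bytes * 8.0 / input_pixels;  // ComprBPP = 1.0
  const double merit = bpp * error_p_norm;                   // BPP*pnorm = 0.8
  std::printf("BPP=%.3f BPP*pnorm=%.3f (lower is better)\n", bpp, merit);
  return 0;
}
```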
+
+## Additional documentation
+
+### Codec description
+
+* [Introductory paper](https://www.spiedigitallibrary.org/proceedings/Download?fullDOI=10.1117%2F12.2529237) (open-access)
+* [XL Overview](doc/xl_overview.md) - a brief introduction to the source code modules
+* [JPEG XL white paper](http://ds.jpeg.org/whitepapers/jpeg-xl-whitepaper.pdf)
+* [JPEG XL website](https://jpeg.org/jpegxl/)
+* [Jon's JXL info page](https://sneyers.info/jxl/)
+
+### Development process
+* [Docker setup - **start here**](doc/developing_in_docker.md)
+* [Building on Debian](doc/developing_in_debian.md) - for experts only
+* [Building on Windows](doc/developing_in_windows.md) - for experts only
+* [More information on testing/build options](doc/building_and_testing.md)
+* [Git guide for JPEG XL](doc/developing_in_gitlab.md) - for developers only
+* [Building Web Assembly artifacts](doc/building_wasm.md)
+
+### Contact
+
+If you encounter a bug or other issue with the software, please open an Issue
+here.
+
+There is a [subreddit about JPEG XL](https://www.reddit.com/r/jpegxl/), and
+informal chatting with developers and early adopters of `libjxl` can be done
+on the [JPEG XL Discord server](https://discord.gg/DqkQgDRTFu).
diff --git a/third_party/jpeg-xl/SECURITY.md b/third_party/jpeg-xl/SECURITY.md
new file mode 100644
index 000000000000..e6616d18c0c9
--- /dev/null
+++ b/third_party/jpeg-xl/SECURITY.md
@@ -0,0 +1,37 @@
+# Security and Vulnerability Policy for JPEG XL
+
+The current focus of the reference implementation is to provide a vehicle for
+evaluating the JPEG XL codec compression density, quality, features and its
+actual performance on different platforms. With this focus in mind, we provide
+source code releases with improvements on performance and quality so
+developers can evaluate the codec.
+
+At this time, **we don't provide security and vulnerability support** for any
+of these releases. This means that the source code may contain bugs, including
+security bugs, that may be added or fixed between releases and will **not** be
+individually documented. All of these
+[releases](https://gitlab.com/wg1/jpeg-xl/-/releases) include the following
+note to that effect:
+
+* Note: This release is for evaluation purposes and may contain bugs,
+  including security bugs, that will *not* be individually documented when
+  fixed. Always prefer to use the latest release. Please provide feedback and
+  report bugs [here](https://gitlab.com/wg1/jpeg-xl/-/issues).
+
+To be clear, the fact that a release doesn't mention any CVE doesn't mean that
+no security issues in previous versions were fixed. You should assume that any
+previous release contains security issues if that's a concern for your use
+case.
+
+This, however, doesn't prevent you from evaluating the codec with your own
+trusted inputs, such as `.jxl` files you encoded yourself, or with appropriate
+measures for your application, like sandboxing, if processing untrusted
+inputs.
+
+## Future plans
+
+To help users and developers integrate this implementation into their
+software, we plan to provide support for security and vulnerability tracking
+of this implementation in the future.
+
+When we can provide such support, we will update this Policy with the details
+and expectations and clearly mention that fact in the release notes.
diff --git a/third_party/jpeg-xl/bash_test.sh b/third_party/jpeg-xl/bash_test.sh
new file mode 100644
index 000000000000..14869d63ce12
--- /dev/null
+++ b/third_party/jpeg-xl/bash_test.sh
@@ -0,0 +1,229 @@
+#!/bin/bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests implemented in bash. These typically run checks on the source code
+# rather than on the compiled binaries.
+
+MYDIR=$(dirname $(realpath "$0"))
+
+set -u
+
+test_includes() {
+  local ret=0
+  local f
+  for f in $(git ls-files | grep -E '(\.cc|\.cpp|\.h)$'); do
+    # Check that the public files (in the lib/include/ directory) don't use
+    # the full path to the public header, since users of the library will
+    # include the library as: #include "jxl/foobar.h".
+    if [[ "${f#lib/include/}" != "${f}" ]]; then
+      if grep -i -H -n -E '#include\s*[<"]lib/include/jxl' "$f" >&2; then
+        echo "Don't add \"include/\" to the include path of public headers." >&2
+        ret=1
+      fi
+    fi
+
+    if [[ "${f#third_party/}" == "$f" ]]; then
+      # $f is not in third_party/
+
+      # Check that local files don't use the full path to the third_party/
+      # directory, since the installed versions will not have that path.
+      # Add an exception for third_party/dirent.h.
+      if grep -v -F 'third_party/dirent.h' "$f" | \
+          grep -i -H -n -E '#include\s*[<"]third_party/' >&2 &&
+          [[ $ret -eq 0 ]]; then
+        cat >&2 <<EOF
+$f: Don't include files using the full third_party/ path; installed copies of
+these headers will not be under that directory.
+EOF
+        ret=1
+      fi
+    fi
+  done
+  return ${ret}
+}
+
+test_copyright() {
+  local ret=0
+  local f
+  for f in $(git ls-files | grep -E '(\.cc|\.cpp|\.h|\.sh|\.m|\.py)$'); do
+    if [[ "${f#third_party/}" == "$f" ]]; then
+      # $f is not in third_party/
+      if ! head -n 10 "$f" |
+          grep -F 'Copyright (c) the JPEG XL Project' >/dev/null ; then
+        echo "$f: Missing Copyright blob near the top of the file." >&2
+        ret=1
+      fi
+    fi
+  done
+  return ${ret}
+}
+
+# Check that "dec_" code doesn't depend on "enc_" headers.
+test_dec_enc_deps() {
+  local ret=0
+  local f
+  for f in $(git ls-files | grep -E '/dec_'); do
+    if [[ "${f#third_party/}" == "$f" ]]; then
+      # $f is not in third_party/
+      if grep -n -H -E "#include.*/enc_" "$f" >&2; then
+        echo "$f: Don't include \"enc_*\" files from \"dec_*\" files." >&2
+        ret=1
+      fi
+    fi
+  done
+  return ${ret}
+}
+
+# Check for git merge conflict markers.
+test_merge_conflict() {
+  local ret=0
+  TEXT_FILES='(\.cc|\.cpp|\.h|\.sh|\.m|\.py|\.md|\.txt|\.cmake)$'
+  for f in $(git ls-files | grep -E "${TEXT_FILES}"); do
+    if grep -E '^<<<<<<< ' "$f"; then
+      echo "$f: Found git merge conflict marker. Please resolve." >&2
+      ret=1
+    fi
+  done
+  return ${ret}
+}
+
+# Check that the library and the package have the same version. This prevents
+# accidentally having them out of sync.
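+# For illustration (hypothetical contents): if lib/CMakeLists.txt contains the
+# line "set(JPEGXL_MAJOR_VERSION 0)", then "get_version JPEGXL_MAJOR_VERSION"
+# below prints "0".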
+get_version() {
+  local varname=$1
+  local line=$(grep -F "set(${varname} " lib/CMakeLists.txt | head -n 1)
+  [[ -n "${line}" ]]
+  line="${line#set(${varname} }"
+  line="${line%)}"
+  echo "${line}"
+}
+
+test_version() {
+  local major=$(get_version JPEGXL_MAJOR_VERSION)
+  local minor=$(get_version JPEGXL_MINOR_VERSION)
+  local patch=$(get_version JPEGXL_PATCH_VERSION)
+  # Check that the version is not empty.
+  if [[ -z "${major}${minor}${patch}" ]]; then
+    echo "Couldn't parse the version from CMakeLists.txt" >&2
+    return 1
+  fi
+  local pkg_version=$(head -n 1 debian/changelog)
+  # Get only the part between the first "jpeg-xl (" and the following ")".
+  pkg_version="${pkg_version#jpeg-xl (}"
+  pkg_version="${pkg_version%%)*}"
+  if [[ -z "${pkg_version}" ]]; then
+    echo "Couldn't parse the version from the debian package" >&2
+    return 1
+  fi
+
+  local lib_version="${major}.${minor}.${patch}"
+  lib_version="${lib_version%.0}"
+  if [[ "${pkg_version}" != "${lib_version}"* ]]; then
+    echo "Debian package version (${pkg_version}) doesn't match library" \
+      "version (${lib_version})." >&2
+    return 1
+  fi
+  return 0
+}
+
+# Check that the SHA versions in deps.sh match the git submodules.
+test_deps_version() {
+  while IFS= read -r line; do
+    if [[ "${line:0:10}" != "[submodule" ]]; then
+      continue
+    fi
+    line="${line#[submodule \"}"
+    line="${line%\"]}"
+    local varname=$(tr '[:lower:]' '[:upper:]' <<< "${line}")
+    varname="${varname/\//_}"
+    if ! grep -F "${varname}=" deps.sh >/dev/null; then
+      # Ignore submodules that are not in deps.sh.
+      continue
+    fi
+    local deps_sha=$(grep -F "${varname}=" deps.sh | cut -f 2 -d '"')
+    [[ -n "${deps_sha}" ]]
+    local git_sha=$(git ls-tree -r HEAD "${line}" | cut -f 1 | cut -f 3 -d ' ')
+    if [[ "${deps_sha}" != "${git_sha}" ]]; then
+      cat >&2 <<EOF
+deps.sh: The SHA for submodule "${line}" is "${git_sha}" in git but deps.sh
+lists "${deps_sha}". Please update deps.sh to match the submodule.
+EOF
+      return 1
+    fi
+  done < .gitmodules
+  return 0
+}
+
+main() {
+  local ret=0
+  cd "${MYDIR}"
+
+  if ! git rev-parse >/dev/null 2>/dev/null; then
+    echo "Not a git checkout, skipping bash_test"
+    return 0
+  fi
+
+  IFS=$'\n'
+  for f in $(declare -F); do
+    local test_name=$(echo "$f" | cut -f 3 -d ' ')
+    # Runs all the local bash functions that start with "test_".
+    if [[ "${test_name}" == test_* ]]; then
+      echo "Test ${test_name}: Start"
+      if ${test_name}; then
+        echo "Test ${test_name}: PASS"
+      else
+        echo "Test ${test_name}: FAIL"
+        ret=1
+      fi
+    fi
+  done
+  return ${ret}
+}
+
+main "$@"
diff --git a/third_party/jpeg-xl/ci.sh b/third_party/jpeg-xl/ci.sh
new file mode 100644
index 000000000000..7e42aeccd92b
--- /dev/null
+++ b/third_party/jpeg-xl/ci.sh
@@ -0,0 +1,1330 @@
+#!/usr/bin/env bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Continuous integration helper module. This module is meant to be called from
+# the .gitlab-ci.yml file during the continuous integration build, as well as
+# from the command line for developers.
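+#
+# Example invocations (illustrative; "opt" and "release" are commands defined
+# below, and BUILD_DIR is optional):
+#   ./ci.sh opt
+#   BUILD_DIR=/tmp/build ./ci.sh release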
+ +set -eu + +OS=`uname -s` + +MYDIR=$(dirname $(realpath "$0")) + +### Environment parameters: +TEST_STACK_LIMIT="${TEST_STACK_LIMIT:-128}" +CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-RelWithDebInfo} +CMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH:-} +CMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER:-} +CMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER:-} +CMAKE_MAKE_PROGRAM=${CMAKE_MAKE_PROGRAM:-} +SKIP_TEST="${SKIP_TEST:-0}" +BUILD_TARGET="${BUILD_TARGET:-}" +ENABLE_WASM_SIMD="${ENABLE_WASM_SIMD:-0}" +if [[ -n "${BUILD_TARGET}" ]]; then + BUILD_DIR="${BUILD_DIR:-${MYDIR}/build-${BUILD_TARGET%%-*}}" +else + BUILD_DIR="${BUILD_DIR:-${MYDIR}/build}" +fi +# Whether we should post a message in the MR when the build fails. +POST_MESSAGE_ON_ERROR="${POST_MESSAGE_ON_ERROR:-1}" + +# Set default compilers to clang if not already set +export CC=${CC:-clang} +export CXX=${CXX:-clang++} + +# Time limit for the "fuzz" command in seconds (0 means no limit). +FUZZER_MAX_TIME="${FUZZER_MAX_TIME:-0}" + +SANITIZER="none" + +if [[ "${BUILD_TARGET}" == wasm* ]]; then + # Check that environment is setup for the WASM build target. + if [[ -z "${EMSCRIPTEN}" ]]; then + echo "'EMSCRIPTEN' is not defined. Use 'emconfigure' wrapper to setup WASM build environment" >&2 + return 1 + fi + # Remove the side-effect of "emconfigure" wrapper - it considers NodeJS environment. + unset EMMAKEN_JUST_CONFIGURE + EMS_TOOLCHAIN_FILE="${EMSCRIPTEN}/cmake/Modules/Platform/Emscripten.cmake" + if [[ -f "${EMS_TOOLCHAIN_FILE}" ]]; then + CMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE:-${EMS_TOOLCHAIN_FILE}} + else + echo "Warning: EMSCRIPTEN CMake module not found" >&2 + fi + CMAKE_CROSSCOMPILING_EMULATOR="${MYDIR}/js-wasm-wrapper.sh" +fi + +if [[ "${BUILD_TARGET%%-*}" == "x86_64" || + "${BUILD_TARGET%%-*}" == "i686" ]]; then + # Default to building all targets, even if compiler baseline is SSE4 + HWY_BASELINE_TARGETS=${HWY_BASELINE_TARGETS:-HWY_SCALAR} +else + HWY_BASELINE_TARGETS=${HWY_BASELINE_TARGETS:-} +fi + +# Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS +CMAKE_FLAGS=${CMAKE_FLAGS:-} +CMAKE_C_FLAGS="${CMAKE_C_FLAGS:-} ${CMAKE_FLAGS}" +CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS:-} ${CMAKE_FLAGS}" + +CMAKE_CROSSCOMPILING_EMULATOR=${CMAKE_CROSSCOMPILING_EMULATOR:-} +CMAKE_EXE_LINKER_FLAGS=${CMAKE_EXE_LINKER_FLAGS:-} +CMAKE_FIND_ROOT_PATH=${CMAKE_FIND_ROOT_PATH:-} +CMAKE_MODULE_LINKER_FLAGS=${CMAKE_MODULE_LINKER_FLAGS:-} +CMAKE_SHARED_LINKER_FLAGS=${CMAKE_SHARED_LINKER_FLAGS:-} +CMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE:-} + +if [[ "${ENABLE_WASM_SIMD}" -ne "0" ]]; then + CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS} -msimd128" + CMAKE_C_FLAGS="${CMAKE_C_FLAGS} -msimd128" + CMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS} -msimd128" +fi + +if [[ ! -z "${HWY_BASELINE_TARGETS}" ]]; then + CMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS} -DHWY_BASELINE_TARGETS=${HWY_BASELINE_TARGETS}" +fi + +# Version inferred from the CI variables. +CI_COMMIT_SHA=${CI_COMMIT_SHA:-} +JPEGXL_VERSION=${JPEGXL_VERSION:-${CI_COMMIT_SHA:0:8}} + +# Benchmark parameters +STORE_IMAGES=${STORE_IMAGES:-1} +BENCHMARK_CORPORA="${MYDIR}/third_party/corpora" + +# Local flags passed to sanitizers. 
+UBSAN_FLAGS=( + -fsanitize=alignment + -fsanitize=bool + -fsanitize=bounds + -fsanitize=builtin + -fsanitize=enum + -fsanitize=float-cast-overflow + -fsanitize=float-divide-by-zero + -fsanitize=integer-divide-by-zero + -fsanitize=null + -fsanitize=object-size + -fsanitize=pointer-overflow + -fsanitize=return + -fsanitize=returns-nonnull-attribute + -fsanitize=shift-base + -fsanitize=shift-exponent + -fsanitize=unreachable + -fsanitize=vla-bound + + -fno-sanitize-recover=undefined + # Brunsli uses unaligned accesses to uint32_t, so alignment is just a warning. + -fsanitize-recover=alignment +) +# -fsanitize=function doesn't work on aarch64 and arm. +if [[ "${BUILD_TARGET%%-*}" != "aarch64" && + "${BUILD_TARGET%%-*}" != "arm" ]]; then + UBSAN_FLAGS+=( + -fsanitize=function + ) +fi +if [[ "${BUILD_TARGET%%-*}" != "arm" ]]; then + UBSAN_FLAGS+=( + -fsanitize=signed-integer-overflow + ) +fi + +CLANG_TIDY_BIN=$(which clang-tidy-6.0 clang-tidy-7 clang-tidy-8 clang-tidy | head -n 1) +# Default to "cat" if "colordiff" is not installed or if stdout is not a tty. +if [[ -t 1 ]]; then + COLORDIFF_BIN=$(which colordiff cat | head -n 1) +else + COLORDIFF_BIN="cat" +fi +FIND_BIN=$(which gfind find | head -n 1) +# "false" will disable wine64 when not installed. This won't allow +# cross-compiling. +WINE_BIN=$(which wine64 false | head -n 1) + +CLANG_VERSION="${CLANG_VERSION:-}" +# Detect the clang version suffix and store it in CLANG_VERSION. For example, +# "6.0" for clang 6 or "7" for clang 7. +detect_clang_version() { + if [[ -n "${CLANG_VERSION}" ]]; then + return 0 + fi + local clang_version=$("${CC:-clang}" --version | head -n1) + clang_version=${clang_version#"Debian "} + local llvm_tag + case "${clang_version}" in + "clang version 6."*) + CLANG_VERSION="6.0" + ;; + "clang version "*) + # Any other clang version uses just the major version number. + local suffix="${clang_version#clang version }" + CLANG_VERSION="${suffix%%.*}" + ;; + "emcc"*) + # We can't use asan or msan in the emcc case. + ;; + *) + echo "Unknown clang version: ${clang_version}" >&2 + return 1 + esac +} + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -fr "${CLEANUP_FILES[@]}" + fi +} + +# Executed on exit. +on_exit() { + local retcode="$1" + # Always cleanup the CLEANUP_FILES. + cleanup + + # Post a message in the MR when requested with POST_MESSAGE_ON_ERROR but only + # if the run failed and we are not running from a MR pipeline. + if [[ ${retcode} -ne 0 && -n "${CI_BUILD_NAME:-}" && + -n "${POST_MESSAGE_ON_ERROR}" && -z "${CI_MERGE_REQUEST_ID:-}" && + "${CI_BUILD_REF_NAME}" = "master" ]]; then + load_mr_vars_from_commit + { set +xeu; } 2>/dev/null + local message="**Run ${CI_BUILD_NAME} @ ${CI_COMMIT_SHORT_SHA} failed.** + +Check the output of the job at ${CI_JOB_URL:-} to see if this was your problem. +If it was, please rollback this change or fix the problem ASAP, broken builds +slow down development. Check if the error already existed in the previous build +as well. + +Pipeline: ${CI_PIPELINE_URL} + +Previous build commit: ${CI_COMMIT_BEFORE_SHA} +" + cmd_post_mr_comment "${message}" + fi +} + +trap 'retcode=$?; { set +x; } 2>/dev/null; on_exit ${retcode}' INT TERM EXIT + + +# These variables are populated when calling merge_request_commits(). + +# The current hash at the top of the current branch or merge request branch (if +# running from a merge request pipeline). 
+MR_HEAD_SHA="" +# The common ancestor between the current commit and the tracked branch, such +# as master. This includes a list +MR_ANCESTOR_SHA="" + +# Populate MR_HEAD_SHA and MR_ANCESTOR_SHA. +merge_request_commits() { + { set +x; } 2>/dev/null + # CI_BUILD_REF is the reference currently being build in the CI workflow. + MR_HEAD_SHA=$(git -C "${MYDIR}" rev-parse -q "${CI_BUILD_REF:-HEAD}") + if [[ -z "${CI_MERGE_REQUEST_IID:-}" ]]; then + # We are in a local branch, not a merge request. + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" rev-parse -q HEAD@{upstream} || true) + else + # Merge request pipeline in CI. In this case the upstream is called "origin" + # but it refers to the forked project that's the source of the merge + # request. We need to get the target of the merge request, for which we need + # to query that repository using our CI_JOB_TOKEN. + echo "machine gitlab.com login gitlab-ci-token password ${CI_JOB_TOKEN}" \ + >> "${HOME}/.netrc" + git -C "${MYDIR}" fetch "${CI_MERGE_REQUEST_PROJECT_URL}" \ + "${CI_MERGE_REQUEST_TARGET_BRANCH_NAME}" + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" rev-parse -q FETCH_HEAD) + fi + if [[ -z "${MR_ANCESTOR_SHA}" ]]; then + echo "Warning, not tracking any branch, using the last commit in HEAD.">&2 + # This prints the return value with just HEAD. + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" rev-parse -q "${MR_HEAD_SHA}^") + else + MR_ANCESTOR_SHA=$(git -C "${MYDIR}" merge-base --all \ + "${MR_ANCESTOR_SHA}" "${MR_HEAD_SHA}") + fi + set -x +} + +# Load the MR iid from the landed commit message when running not from a +# merge request workflow. This is useful to post back results at the merge +# request when running pipelines from master. +load_mr_vars_from_commit() { + { set +x; } 2>/dev/null + if [[ -z "${CI_MERGE_REQUEST_IID:-}" ]]; then + local mr_iid=$(git rev-list --format=%B --max-count=1 HEAD | + grep -F "${CI_PROJECT_URL}" | grep -F "/merge_requests" | head -n 1) + # mr_iid contains a string like this if it matched: + # Part-of: + if [[ -n "${mr_iid}" ]]; then + mr_iid=$(echo "${mr_iid}" | + sed -E 's,^.*merge_requests/([0-9]+)>.*$,\1,') + CI_MERGE_REQUEST_IID="${mr_iid}" + CI_MERGE_REQUEST_PROJECT_ID=${CI_PROJECT_ID} + fi + fi + set -x +} + +# Posts a comment to the current merge request. +cmd_post_mr_comment() { + { set +x; } 2>/dev/null + local comment="$1" + if [[ -n "${BOT_TOKEN:-}" && -n "${CI_MERGE_REQUEST_IID:-}" ]]; then + local url="${CI_API_V4_URL}/projects/${CI_MERGE_REQUEST_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}/notes" + curl -X POST -g \ + -H "PRIVATE-TOKEN: ${BOT_TOKEN}" \ + --data-urlencode "body=${comment}" \ + --output /dev/null \ + "${url}" + fi + set -x +} + +# Set up and export the environment variables needed by the child processes. +export_env() { + if [[ "${BUILD_TARGET}" == *mingw32 ]]; then + # Wine needs to know the paths to the mingw dlls. These should be + # separated by ';'. + WINEPATH=$("${CC:-clang}" -print-search-dirs --target="${BUILD_TARGET}" \ + | grep -F 'libraries: =' | cut -f 2- -d '=' | tr ':' ';') + # We also need our own libraries in the wine path. 
+ local real_build_dir=$(realpath "${BUILD_DIR}") + # Some library .dll dependencies are installed in /bin: + export WINEPATH="${WINEPATH};${real_build_dir};${real_build_dir}/third_party/brotli;/usr/${BUILD_TARGET}/bin" + + local prefix="${BUILD_DIR}/wineprefix" + mkdir -p "${prefix}" + export WINEPREFIX=$(realpath "${prefix}") + fi + # Sanitizers need these variables to print and properly format the stack + # traces: + LLVM_SYMBOLIZER=$("${CC:-clang}" -print-prog-name=llvm-symbolizer || true) + if [[ -n "${LLVM_SYMBOLIZER}" ]]; then + export ASAN_SYMBOLIZER_PATH="${LLVM_SYMBOLIZER}" + export MSAN_SYMBOLIZER_PATH="${LLVM_SYMBOLIZER}" + export UBSAN_SYMBOLIZER_PATH="${LLVM_SYMBOLIZER}" + fi +} + +cmake_configure() { + export_env + + if [[ "${STACK_SIZE:-0}" == 1 ]]; then + # Dump the stack size of each function in the .stack_sizes section for + # analysis. + CMAKE_C_FLAGS+=" -fstack-size-section" + CMAKE_CXX_FLAGS+=" -fstack-size-section" + fi + + local args=( + -B"${BUILD_DIR}" -H"${MYDIR}" + -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" + -G Ninja + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" + -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" + -DCMAKE_TOOLCHAIN_FILE="${CMAKE_TOOLCHAIN_FILE}" + -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" + -DCMAKE_MODULE_LINKER_FLAGS="${CMAKE_MODULE_LINKER_FLAGS}" + -DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}" + -DJPEGXL_VERSION="${JPEGXL_VERSION}" + -DSANITIZER="${SANITIZER}" + # These are not enabled by default in cmake. + -DJPEGXL_ENABLE_VIEWERS=ON + -DJPEGXL_ENABLE_PLUGINS=ON + -DJPEGXL_ENABLE_DEVTOOLS=ON + ) + if [[ "${BUILD_TARGET}" != *mingw32 ]]; then + args+=( + -DJPEGXL_WARNINGS_AS_ERRORS=ON + ) + fi + if [[ -n "${BUILD_TARGET}" ]]; then + local system_name="Linux" + if [[ "${BUILD_TARGET}" == *mingw32 ]]; then + # When cross-compiling with mingw the target must be set to Windows and + # run programs with wine. + system_name="Windows" + args+=( + -DCMAKE_CROSSCOMPILING_EMULATOR="${WINE_BIN}" + # Normally CMake automatically defines MINGW=1 when building with the + # mingw compiler (x86_64-w64-mingw32-gcc) but we are normally compiling + # with clang. + -DMINGW=1 + ) + fi + # EMSCRIPTEN toolchain sets the right values itself + if [[ "${BUILD_TARGET}" != wasm* ]]; then + # If set, BUILD_TARGET must be the target triplet such as + # x86_64-unknown-linux-gnu. + args+=( + -DCMAKE_C_COMPILER_TARGET="${BUILD_TARGET}" + -DCMAKE_CXX_COMPILER_TARGET="${BUILD_TARGET}" + # Only the first element of the target triplet. + -DCMAKE_SYSTEM_PROCESSOR="${BUILD_TARGET%%-*}" + -DCMAKE_SYSTEM_NAME="${system_name}" + ) + else + # sjpeg confuses WASM SIMD with SSE. + args+=( + -DSJPEG_ENABLE_SIMD=OFF + ) + fi + args+=( + # These are needed to make googletest work when cross-compiling. + -DCMAKE_CROSSCOMPILING=1 + -DHAVE_STD_REGEX=0 + -DHAVE_POSIX_REGEX=0 + -DHAVE_GNU_POSIX_REGEX=0 + -DHAVE_STEADY_CLOCK=0 + -DHAVE_THREAD_SAFETY_ATTRIBUTES=0 + ) + if [[ -z "${CMAKE_FIND_ROOT_PATH}" ]]; then + # find_package() will look in this prefix for libraries. + CMAKE_FIND_ROOT_PATH="/usr/${BUILD_TARGET}" + fi + if [[ -z "${CMAKE_PREFIX_PATH}" ]]; then + CMAKE_PREFIX_PATH="/usr/${BUILD_TARGET}" + fi + # Use pkg-config for the target. If there's no pkg-config available for the + # target we can set the PKG_CONFIG_PATH to the appropriate path in most + # linux distributions. 
+ local pkg_config=$(which "${BUILD_TARGET}-pkg-config" || true) + if [[ -z "${pkg_config}" ]]; then + pkg_config=$(which pkg-config) + export PKG_CONFIG_LIBDIR="/usr/${BUILD_TARGET}/lib/pkgconfig" + fi + if [[ -n "${pkg_config}" ]]; then + args+=(-DPKG_CONFIG_EXECUTABLE="${pkg_config}") + fi + fi + if [[ -n "${CMAKE_CROSSCOMPILING_EMULATOR}" ]]; then + args+=( + -DCMAKE_CROSSCOMPILING_EMULATOR="${CMAKE_CROSSCOMPILING_EMULATOR}" + ) + fi + if [[ -n "${CMAKE_FIND_ROOT_PATH}" ]]; then + args+=( + -DCMAKE_FIND_ROOT_PATH="${CMAKE_FIND_ROOT_PATH}" + ) + fi + if [[ -n "${CMAKE_PREFIX_PATH}" ]]; then + args+=( + -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" + ) + fi + if [[ -n "${CMAKE_C_COMPILER_LAUNCHER}" ]]; then + args+=( + -DCMAKE_C_COMPILER_LAUNCHER="${CMAKE_C_COMPILER_LAUNCHER}" + ) + fi + if [[ -n "${CMAKE_CXX_COMPILER_LAUNCHER}" ]]; then + args+=( + -DCMAKE_CXX_COMPILER_LAUNCHER="${CMAKE_CXX_COMPILER_LAUNCHER}" + ) + fi + if [[ -n "${CMAKE_MAKE_PROGRAM}" ]]; then + args+=( + -DCMAKE_MAKE_PROGRAM="${CMAKE_MAKE_PROGRAM}" + ) + fi + cmake "${args[@]}" "$@" +} + +cmake_build_and_test() { + # gtest_discover_tests() runs the test binaries to discover the list of tests + # at build time, which fails under qemu. + ASAN_OPTIONS=detect_leaks=0 cmake --build "${BUILD_DIR}" -- all doc + # Pack test binaries if requested. + if [[ "${PACK_TEST:-}" == "1" ]]; then + (cd "${BUILD_DIR}" + ${FIND_BIN} -name '*.cmake' -a '!' -path '*CMakeFiles*' + ${FIND_BIN} -type d -name tests -a '!' -path '*CMakeFiles*' + ) | tar -C "${BUILD_DIR}" -cf "${BUILD_DIR}/tests.tar.xz" -T - \ + --use-compress-program="xz --threads=$(nproc --all || echo 1) -6" + du -h "${BUILD_DIR}/tests.tar.xz" + # Pack coverage data if also available. + touch "${BUILD_DIR}/gcno.sentinel" + (cd "${BUILD_DIR}"; echo gcno.sentinel; ${FIND_BIN} -name '*gcno') | \ + tar -C "${BUILD_DIR}" -cvf "${BUILD_DIR}/gcno.tar.xz" -T - \ + --use-compress-program="xz --threads=$(nproc --all || echo 1) -6" + fi + + if [[ "${SKIP_TEST}" -ne "1" ]]; then + (cd "${BUILD_DIR}" + export UBSAN_OPTIONS=print_stacktrace=1 + [[ "${TEST_STACK_LIMIT}" == "none" ]] || ulimit -s "${TEST_STACK_LIMIT}" + ctest -j $(nproc --all || echo 1) --output-on-failure) + fi +} + +# Configure the build to strip unused functions. This considerably reduces the +# output size, specially for tests which only use a small part of the whole +# library. +strip_dead_code() { + # Emscripten does tree shaking without any extra flags. + if [[ "${CMAKE_TOOLCHAIN_FILE##*/}" == "Emscripten.cmake" ]]; then + return 0 + fi + # -ffunction-sections, -fdata-sections and -Wl,--gc-sections effectively + # discard all unreachable code, reducing the code size. For this to work, we + # need to also pass --no-export-dynamic to prevent it from exporting all the + # internal symbols (like functions) making them all reachable and thus not a + # candidate for removal. 
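+  # For example (illustrative compile/link lines), the effect is:
+  #   clang++ -ffunction-sections -fdata-sections -c foo.cc
+  #   clang++ -Wl,--gc-sections -Wl,--no-export-dynamic foo.o -o foo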
+ CMAKE_CXX_FLAGS+=" -ffunction-sections -fdata-sections" + CMAKE_C_FLAGS+=" -ffunction-sections -fdata-sections" + if [[ "${OS}" == "Darwin" ]]; then + CMAKE_EXE_LINKER_FLAGS+=" -dead_strip" + CMAKE_SHARED_LINKER_FLAGS+=" -dead_strip" + else + CMAKE_EXE_LINKER_FLAGS+=" -Wl,--gc-sections -Wl,--no-export-dynamic" + CMAKE_SHARED_LINKER_FLAGS+=" -Wl,--gc-sections -Wl,--no-export-dynamic" + fi +} + +### Externally visible commands + +cmd_debug() { + CMAKE_BUILD_TYPE="Debug" + cmake_configure "$@" + cmake_build_and_test +} + +cmd_release() { + CMAKE_BUILD_TYPE="Release" + strip_dead_code + cmake_configure "$@" + cmake_build_and_test +} + +cmd_opt() { + CMAKE_BUILD_TYPE="RelWithDebInfo" + CMAKE_CXX_FLAGS+=" -DJXL_DEBUG_WARNING -DJXL_DEBUG_ON_ERROR" + cmake_configure "$@" + cmake_build_and_test +} + +cmd_coverage() { + # -O0 prohibits stack space reuse -> causes stack-overflow on dozens of tests. + TEST_STACK_LIMIT="none" + + cmd_release -DJPEGXL_ENABLE_COVERAGE=ON "$@" + + if [[ "${SKIP_TEST}" -ne "1" ]]; then + # If we didn't run the test we also don't print a coverage report. + cmd_coverage_report + fi +} + +cmd_coverage_report() { + LLVM_COV=$("${CC:-clang}" -print-prog-name=llvm-cov) + local real_build_dir=$(realpath "${BUILD_DIR}") + local gcovr_args=( + -r "${real_build_dir}" + --gcov-executable "${LLVM_COV} gcov" + # Only print coverage information for the jxl and fuif directories. The rest + # is not part of the code under test. + --filter '.*jxl/.*' + --exclude '.*_test.cc' + --object-directory "${real_build_dir}" + ) + + ( + cd "${real_build_dir}" + gcovr "${gcovr_args[@]}" --html --html-details \ + --output="${real_build_dir}/coverage.html" + gcovr "${gcovr_args[@]}" --print-summary | + tee "${real_build_dir}/coverage.txt" + gcovr "${gcovr_args[@]}" --xml --output="${real_build_dir}/coverage.xml" + ) +} + +cmd_test() { + export_env + # Unpack tests if needed. + if [[ -e "${BUILD_DIR}/tests.tar.xz" && ! -d "${BUILD_DIR}/tests" ]]; then + tar -C "${BUILD_DIR}" -Jxvf "${BUILD_DIR}/tests.tar.xz" + fi + if [[ -e "${BUILD_DIR}/gcno.tar.xz" && ! -d "${BUILD_DIR}/gcno.sentinel" ]]; then + tar -C "${BUILD_DIR}" -Jxvf "${BUILD_DIR}/gcno.tar.xz" + fi + (cd "${BUILD_DIR}" + export UBSAN_OPTIONS=print_stacktrace=1 + [[ "${TEST_STACK_LIMIT}" == "none" ]] || ulimit -s "${TEST_STACK_LIMIT}" + ctest -j $(nproc --all || echo 1) --output-on-failure "$@") +} + +cmd_gbench() { + export_env + (cd "${BUILD_DIR}" + export UBSAN_OPTIONS=print_stacktrace=1 + lib/jxl_gbench \ + --benchmark_counters_tabular=true \ + --benchmark_out_format=json \ + --benchmark_out=gbench.json "$@" + ) +} + +cmd_asan() { + SANITIZER="asan" + CMAKE_C_FLAGS+=" -DJXL_ENABLE_ASSERT=1 -g -DADDRESS_SANITIZER \ + -fsanitize=address ${UBSAN_FLAGS[@]}" + CMAKE_CXX_FLAGS+=" -DJXL_ENABLE_ASSERT=1 -g -DADDRESS_SANITIZER \ + -fsanitize=address ${UBSAN_FLAGS[@]}" + strip_dead_code + cmake_configure "$@" -DJPEGXL_ENABLE_TCMALLOC=OFF + cmake_build_and_test +} + +cmd_tsan() { + SANITIZER="tsan" + local tsan_args=( + -DJXL_ENABLE_ASSERT=1 + -g + -DTHREAD_SANITIZER + ${UBSAN_FLAGS[@]} + -fsanitize=thread + ) + CMAKE_C_FLAGS+=" ${tsan_args[@]}" + CMAKE_CXX_FLAGS+=" ${tsan_args[@]}" + + CMAKE_BUILD_TYPE="RelWithDebInfo" + cmake_configure "$@" -DJPEGXL_ENABLE_TCMALLOC=OFF + cmake_build_and_test +} + +cmd_msan() { + SANITIZER="msan" + detect_clang_version + local msan_prefix="${HOME}/.msan/${CLANG_VERSION}" + if [[ ! 
-d "${msan_prefix}" || -e "${msan_prefix}/lib/libc++abi.a" ]]; then + # Install msan libraries for this version if needed or if an older version + # with libc++abi was installed. + cmd_msan_install + fi + + local msan_c_flags=( + -fsanitize=memory + -fno-omit-frame-pointer + -fsanitize-memory-track-origins + + -DJXL_ENABLE_ASSERT=1 + -g + -DMEMORY_SANITIZER + + # Force gtest to not use the cxxbai. + -DGTEST_HAS_CXXABI_H_=0 + ) + local msan_cxx_flags=( + "${msan_c_flags[@]}" + + # Some C++ sources don't use the std at all, so the -stdlib=libc++ is unused + # in those cases. Ignore the warning. + -Wno-unused-command-line-argument + -stdlib=libc++ + + # We include the libc++ from the msan directory instead, so we don't want + # the std includes. + -nostdinc++ + -cxx-isystem"${msan_prefix}/include/c++/v1" + ) + + local msan_linker_flags=( + -L"${msan_prefix}"/lib + -Wl,-rpath -Wl,"${msan_prefix}"/lib/ + ) + + CMAKE_C_FLAGS+=" ${msan_c_flags[@]} ${UBSAN_FLAGS[@]}" + CMAKE_CXX_FLAGS+=" ${msan_cxx_flags[@]} ${UBSAN_FLAGS[@]}" + CMAKE_EXE_LINKER_FLAGS+=" ${msan_linker_flags[@]}" + CMAKE_MODULE_LINKER_FLAGS+=" ${msan_linker_flags[@]}" + CMAKE_SHARED_LINKER_FLAGS+=" ${msan_linker_flags[@]}" + strip_dead_code + cmake_configure "$@" \ + -DCMAKE_CROSSCOMPILING=1 -DRUN_HAVE_STD_REGEX=0 -DRUN_HAVE_POSIX_REGEX=0 \ + -DJPEGXL_ENABLE_TCMALLOC=OFF + cmake_build_and_test +} + +# Install libc++ libraries compiled with msan in the msan_prefix for the current +# compiler version. +cmd_msan_install() { + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + # Detect the llvm to install: + export CC="${CC:-clang}" + export CXX="${CXX:-clang++}" + detect_clang_version + local llvm_tag="llvmorg-${CLANG_VERSION}.0.0" + case "${CLANG_VERSION}" in + "6.0") + llvm_tag="llvmorg-6.0.1" + ;; + "7") + llvm_tag="llvmorg-7.0.1" + ;; + esac + local llvm_targz="${tmpdir}/${llvm_tag}.tar.gz" + curl -L --show-error -o "${llvm_targz}" \ + "https://github.com/llvm/llvm-project/archive/${llvm_tag}.tar.gz" + tar -C "${tmpdir}" -zxf "${llvm_targz}" + local llvm_root="${tmpdir}/llvm-project-${llvm_tag}" + + local msan_prefix="${HOME}/.msan/${CLANG_VERSION}" + rm -rf "${msan_prefix}" + + declare -A CMAKE_EXTRAS + CMAKE_EXTRAS[libcxx]="\ + -DLIBCXX_CXX_ABI=libstdc++ \ + -DLIBCXX_INSTALL_EXPERIMENTAL_LIBRARY=ON" + + for project in libcxx; do + local proj_build="${tmpdir}/build-${project}" + local proj_dir="${llvm_root}/${project}" + mkdir -p "${proj_build}" + cmake -B"${proj_build}" -H"${proj_dir}" \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_USE_SANITIZER=Memory \ + -DLLVM_PATH="${llvm_root}/llvm" \ + -DLLVM_CONFIG_PATH="$(which llvm-config llvm-config-7 llvm-config-6.0 | \ + head -n1)" \ + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" \ + -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" \ + -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" \ + -DCMAKE_SHARED_LINKER_FLAGS="${CMAKE_SHARED_LINKER_FLAGS}" \ + -DCMAKE_INSTALL_PREFIX="${msan_prefix}" \ + ${CMAKE_EXTRAS[${project}]} + cmake --build "${proj_build}" + ninja -C "${proj_build}" install + done +} + +cmd_fast_benchmark() { + local small_corpus_tar="${BENCHMARK_CORPORA}/jyrki-full.tar" + mkdir -p "${BENCHMARK_CORPORA}" + curl --show-error -o "${small_corpus_tar}" -z "${small_corpus_tar}" \ + "https://storage.googleapis.com/artifacts.jpegxl.appspot.com/corpora/jyrki-full.tar" + + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + tar -xf "${small_corpus_tar}" -C "${tmpdir}" + + run_benchmark "${tmpdir}" 1048576 +} + +cmd_benchmark() { + local 
nikon_corpus_tar="${BENCHMARK_CORPORA}/nikon-subset.tar" + mkdir -p "${BENCHMARK_CORPORA}" + curl --show-error -o "${nikon_corpus_tar}" -z "${nikon_corpus_tar}" \ + "https://storage.googleapis.com/artifacts.jpegxl.appspot.com/corpora/nikon-subset.tar" + + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + tar -xvf "${nikon_corpus_tar}" -C "${tmpdir}" + + local sem_id="jpegxl_benchmark-$$" + local nprocs=$(nproc --all || echo 1) + images=() + local filename + while IFS= read -r filename; do + # This removes the './' + filename="${filename:2}" + local mode + if [[ "${filename:0:4}" == "srgb" ]]; then + mode="RGB_D65_SRG_Rel_SRG" + elif [[ "${filename:0:5}" == "adobe" ]]; then + mode="RGB_D65_Ado_Rel_Ado" + else + echo "Unknown image colorspace: ${filename}" >&2 + exit 1 + fi + png_filename="${filename%.ppm}.png" + png_filename=$(echo "${png_filename}" | tr '/' '_') + sem --bg --id "${sem_id}" -j"${nprocs}" -- \ + "${BUILD_DIR}/tools/decode_and_encode" \ + "${tmpdir}/${filename}" "${mode}" "${tmpdir}/${png_filename}" + images+=( "${png_filename}" ) + done < <(cd "${tmpdir}"; ${FIND_BIN} . -name '*.ppm' -type f) + sem --id "${sem_id}" --wait + + # We need about 10 GiB per thread on these images. + run_benchmark "${tmpdir}" 10485760 +} + +get_mem_available() { + if [[ "${OS}" == "Darwin" ]]; then + echo $(vm_stat | grep -F 'Pages free:' | awk '{print $3 * 4}') + else + echo $(grep -F MemAvailable: /proc/meminfo | awk '{print $2}') + fi +} + +run_benchmark() { + local src_img_dir="$1" + local mem_per_thread="${2:-10485760}" + + local output_dir="${BUILD_DIR}/benchmark_results" + mkdir -p "${output_dir}" + + # The memory available at the beginning of the benchmark run in kB. The number + # of threads depends on the available memory, and the passed memory per + # thread. We also add a 2 GiB of constant memory. + local mem_available="$(get_mem_available)" + # Check that we actually have a MemAvailable value. + [[ -n "${mem_available}" ]] + local num_threads=$(( (${mem_available} - 1048576) / ${mem_per_thread} )) + if [[ ${num_threads} -le 0 ]]; then + num_threads=1 + fi + + local benchmark_args=( + --input "${src_img_dir}/*.png" + --codec=jpeg:yuv420:q85,webp:q80,jxl:fast:d1,jxl:fast:d1:downsampling=8,jxl:fast:d4,jxl:fast:d4:downsampling=8,jxl:m:cheetah:nl,jxl:cheetah:m,jxl:m:cheetah:P6,jxl:m:falcon:q80 + --output_dir "${output_dir}" + --noprofiler --show_progress + --num_threads="${num_threads}" + ) + if [[ "${STORE_IMAGES}" == "1" ]]; then + benchmark_args+=(--save_decompressed --save_compressed) + fi + ( + [[ "${TEST_STACK_LIMIT}" == "none" ]] || ulimit -s "${TEST_STACK_LIMIT}" + "${BUILD_DIR}/tools/benchmark_xl" "${benchmark_args[@]}" | \ + tee "${output_dir}/results.txt" + + # Check error code for benckmark_xl command. This will exit if not. + return ${PIPESTATUS[0]} + ) + + if [[ -n "${CI_BUILD_NAME:-}" ]]; then + { set +x; } 2>/dev/null + local message="Results for ${CI_BUILD_NAME} @ ${CI_COMMIT_SHORT_SHA} (job ${CI_JOB_URL:-}): + +$(cat "${output_dir}/results.txt") +" + cmd_post_mr_comment "${message}" + set -x + fi +} + +# Helper function to wait for the CPU temperature to cool down on ARM. +wait_for_temp() { + { set +x; } 2>/dev/null + local temp_limit=${1:-37000} + if [[ -z "${THERMAL_FILE:-}" ]]; then + echo "Must define the THERMAL_FILE with the thermal_zoneX/temp file" \ + "to read the temperature from. This is normally set in the runner." 
>&2 + exit 1 + fi + local org_temp=$(cat "${THERMAL_FILE}") + if [[ "${org_temp}" -ge "${temp_limit}" ]]; then + echo -n "Waiting for temp to get down from ${org_temp}... " + fi + local temp="${org_temp}" + while [[ "${temp}" -ge "${temp_limit}" ]]; do + sleep 1 + temp=$(cat "${THERMAL_FILE}") + done + if [[ "${org_temp}" -ge "${temp_limit}" ]]; then + echo "Done, temp=${temp}" + fi + set -x +} + +# Helper function to set the cpuset restriction of the current process. +cmd_cpuset() { + [[ "${SKIP_CPUSET:-}" != "1" ]] || return 0 + local newset="$1" + local mycpuset=$(cat /proc/self/cpuset) + mycpuset="/dev/cpuset${mycpuset}" + # Check that the directory exists: + [[ -d "${mycpuset}" ]] + if [[ -e "${mycpuset}/cpuset.cpus" ]]; then + echo "${newset}" >"${mycpuset}/cpuset.cpus" + else + echo "${newset}" >"${mycpuset}/cpus" + fi +} + +# Return the encoding/decoding speed from the Stats output. +_speed_from_output() { + local speed="$1" + local unit="${2:-MP/s}" + if [[ "${speed}" == *"${unit}"* ]]; then + speed="${speed%% ${unit}*}" + speed="${speed##* }" + echo "${speed}" + fi +} + + +# Run benchmarks on ARM for the big and little CPUs. +cmd_arm_benchmark() { + # Flags used for cjxl encoder with .png inputs + local jxl_png_benchmarks=( + # Lossy options: + "--epf=0 --distance=1.0 --speed=cheetah" + "--epf=2 --distance=1.0 --speed=cheetah" + "--epf=0 --distance=8.0 --speed=cheetah" + "--epf=1 --distance=8.0 --speed=cheetah" + "--epf=2 --distance=8.0 --speed=cheetah" + "--epf=3 --distance=8.0 --speed=cheetah" + "--modular -Q 90" + "--modular -Q 50" + # Lossless options: + "--modular" + "--modular -E 0 -I 0" + "--modular -P 5" + "--modular --responsive=1" + # Near-lossless options: + "--epf=0 --distance=0.3 --speed=fast" + "--modular -N 3 -I 0" + "--modular -Q 97" + ) + + # Flags used for cjxl encoder with .jpg inputs. These should do lossless + # JPEG recompression (of pixels or full jpeg). + local jxl_jpeg_benchmarks=( + "--num_reps=3" + ) + + local images=( + "third_party/testdata/imagecompression.info/flower_foveon.png" + ) + + local jpg_images=( + "third_party/testdata/imagecompression.info/flower_foveon.png.im_q85_420.jpg" + ) + + if [[ "${SKIP_CPUSET:-}" == "1" ]]; then + # Use a single cpu config in this case. + local cpu_confs=("?") + else + # Otherwise the CPU config comes from the environment: + local cpu_confs=( + "${RUNNER_CPU_LITTLE}" + "${RUNNER_CPU_BIG}" + # The CPU description is something like 3-7, so these configurations only + # take the first CPU of the group. + "${RUNNER_CPU_LITTLE%%-*}" + "${RUNNER_CPU_BIG%%-*}" + ) + # Check that RUNNER_CPU_ALL is defined. In the SKIP_CPUSET=1 case this will + # be ignored but still evaluated when calling cmd_cpuset. + [[ -n "${RUNNER_CPU_ALL}" ]] + fi + + local jpg_dirname="third_party/corpora/jpeg" + mkdir -p "${jpg_dirname}" + local jpg_qualities=( 50 80 95 ) + for src_img in "${images[@]}"; do + for q in "${jpg_qualities[@]}"; do + local jpeg_name="${jpg_dirname}/"$(basename "${src_img}" .png)"-q${q}.jpg" + convert -sampling-factor 1x1 -quality "${q}" \ + "${src_img}" "${jpeg_name}" + jpg_images+=("${jpeg_name}") + done + done + + local output_dir="${BUILD_DIR}/benchmark_results" + mkdir -p "${output_dir}" + local runs_file="${output_dir}/runs.txt" + + if [[ ! 
-e "${runs_file}" ]]; then + echo -e "binary\tflags\tsrc_img\tsrc size\tsrc pixels\tcpuset\tenc size (B)\tenc speed (MP/s)\tdec speed (MP/s)\tJPG dec speed (MP/s)\tJPG dec speed (MB/s)" | + tee -a "${runs_file}" + fi + + mkdir -p "${BUILD_DIR}/arm_benchmark" + local flags + local src_img + for src_img in "${jpg_images[@]}" "${images[@]}"; do + local src_img_hash=$(sha1sum "${src_img}" | cut -f 1 -d ' ') + local enc_binaries=("${BUILD_DIR}/tools/cjxl") + local src_ext="${src_img##*.}" + for enc_binary in "${enc_binaries[@]}"; do + local enc_binary_base=$(basename "${enc_binary}") + + # Select the list of flags to use for the current encoder/image pair. + local img_benchmarks + if [[ "${src_ext}" == "jpg" ]]; then + img_benchmarks=("${jxl_jpeg_benchmarks[@]}") + else + img_benchmarks=("${jxl_png_benchmarks[@]}") + fi + + for flags in "${img_benchmarks[@]}"; do + # Encoding step. + local enc_file_hash="${enc_binary_base} || $flags || ${src_img} || ${src_img_hash}" + enc_file_hash=$(echo "${enc_file_hash}" | sha1sum | cut -f 1 -d ' ') + local enc_file="${BUILD_DIR}/arm_benchmark/${enc_file_hash}.jxl" + + for cpu_conf in "${cpu_confs[@]}"; do + cmd_cpuset "${cpu_conf}" + # nproc returns the number of active CPUs, which is given by the cpuset + # mask. + local num_threads="$(nproc)" + + echo "Encoding with: ${enc_binary_base} img=${src_img} cpus=${cpu_conf} enc_flags=${flags}" + local enc_output + if [[ "${flags}" == *"modular"* ]]; then + # We don't benchmark encoding speed in this case. + if [[ ! -f "${enc_file}" ]]; then + cmd_cpuset "${RUNNER_CPU_ALL:-}" + "${enc_binary}" ${flags} "${src_img}" "${enc_file}.tmp" + mv "${enc_file}.tmp" "${enc_file}" + cmd_cpuset "${cpu_conf}" + fi + enc_output=" ?? MP/s" + else + wait_for_temp + enc_output=$("${enc_binary}" ${flags} "${src_img}" "${enc_file}.tmp" \ + 2>&1 | tee /dev/stderr | grep -F "MP/s [") + mv "${enc_file}.tmp" "${enc_file}" + fi + local enc_speed=$(_speed_from_output "${enc_output}") + local enc_size=$(stat -c "%s" "${enc_file}") + + echo "Decoding with: img=${src_img} cpus=${cpu_conf} enc_flags=${flags}" + + local dec_output + wait_for_temp + dec_output=$("${BUILD_DIR}/tools/djxl" "${enc_file}" \ + --num_reps=5 --num_threads="${num_threads}" 2>&1 | tee /dev/stderr | + grep -E "M[BP]/s \[") + local img_size=$(echo "${dec_output}" | cut -f 1 -d ',') + local img_size_x=$(echo "${img_size}" | cut -f 1 -d ' ') + local img_size_y=$(echo "${img_size}" | cut -f 3 -d ' ') + local img_size_px=$(( ${img_size_x} * ${img_size_y} )) + local dec_speed=$(_speed_from_output "${dec_output}") + + # For JPEG lossless recompression modes (where the original is a JPEG) + # decode to JPG as well. + local jpeg_dec_mps_speed="" + local jpeg_dec_mbs_speed="" + if [[ "${src_ext}" == "jpg" ]]; then + wait_for_temp + local dec_file="${BUILD_DIR}/arm_benchmark/${enc_file_hash}.jpg" + dec_output=$("${BUILD_DIR}/tools/djxl" "${enc_file}" \ + "${dec_file}" --num_reps=5 --num_threads="${num_threads}" 2>&1 | \ + tee /dev/stderr | grep -E "M[BP]/s \[") + local jpeg_dec_mps_speed=$(_speed_from_output "${dec_output}") + local jpeg_dec_mbs_speed=$(_speed_from_output "${dec_output}" MB/s) + if ! cmp --quiet "${src_img}" "${dec_file}"; then + # Add a start at the end to signal that the files are different. + jpeg_dec_mbs_speed+="*" + fi + fi + + # Record entry in a tab-separated file. 
+ local src_img_base=$(basename "${src_img}") + echo -e "${enc_binary_base}\t${flags}\t${src_img_base}\t${img_size}\t${img_size_px}\t${cpu_conf}\t${enc_size}\t${enc_speed}\t${dec_speed}\t${jpeg_dec_mps_speed}\t${jpeg_dec_mbs_speed}" | + tee -a "${runs_file}" + done + done + done + done + cmd_cpuset "${RUNNER_CPU_ALL:-}" + cat "${runs_file}" + + if [[ -n "${CI_BUILD_NAME:-}" ]]; then + load_mr_vars_from_commit + { set +x; } 2>/dev/null + local message="Results for ${CI_BUILD_NAME} @ ${CI_COMMIT_SHORT_SHA} (job ${CI_JOB_URL:-}): + +\`\`\` +$(column -t -s " " "${runs_file}") +\`\`\` +" + cmd_post_mr_comment "${message}" + set -x + fi +} + +# Generate a corpus and run the fuzzer on that corpus. +cmd_fuzz() { + local corpus_dir=$(realpath "${BUILD_DIR}/fuzzer_corpus") + local fuzzer_crash_dir=$(realpath "${BUILD_DIR}/fuzzer_crash") + mkdir -p "${corpus_dir}" "${fuzzer_crash_dir}" + # Generate step. + "${BUILD_DIR}/tools/fuzzer_corpus" "${corpus_dir}" + # Run step: + local nprocs=$(nproc --all || echo 1) + ( + cd "${BUILD_DIR}" + "tools/djxl_fuzzer" "${fuzzer_crash_dir}" "${corpus_dir}" \ + -max_total_time="${FUZZER_MAX_TIME}" -jobs=${nprocs} \ + -artifact_prefix="${fuzzer_crash_dir}/" + ) +} + +# Runs the linter (clang-format) on the pending CLs. +cmd_lint() { + merge_request_commits + { set +x; } 2>/dev/null + local versions=(${1:-6.0 7 8 9}) + local clang_format_bins=("${versions[@]/#/clang-format-}" clang-format) + local tmpdir=$(mktemp -d) + CLEANUP_FILES+=("${tmpdir}") + + local ret=0 + local build_patch="${tmpdir}/build_cleaner.patch" + if ! "${MYDIR}/tools/build_cleaner.py" >"${build_patch}"; then + ret=1 + echo "build_cleaner.py findings:" >&2 + "${COLORDIFF_BIN}" <"${build_patch}" + echo "Run \`tools/build_cleaner.py --update\` to apply them" >&2 + fi + + local installed=() + local clang_patch + local clang_format + for clang_format in "${clang_format_bins[@]}"; do + if ! which "${clang_format}" >/dev/null; then + continue + fi + installed+=("${clang_format}") + local tmppatch="${tmpdir}/${clang_format}.patch" + # We include in this linter all the changes including the uncommited changes + # to avoid printing changes already applied. + set -x + git -C "${MYDIR}" "${clang_format}" --binary "${clang_format}" \ + --style=file --diff "${MR_ANCESTOR_SHA}" -- >"${tmppatch}" + { set +x; } 2>/dev/null + + if grep -E '^--- ' "${tmppatch}">/dev/null; then + if [[ -n "${LINT_OUTPUT:-}" ]]; then + cp "${tmppatch}" "${LINT_OUTPUT}" + fi + clang_patch="${tmppatch}" + else + echo "clang-format check OK" >&2 + return ${ret} + fi + done + + if [[ ${#installed[@]} -eq 0 ]]; then + echo "You must install clang-format for \"git clang-format\"" >&2 + exit 1 + fi + + # clang-format is installed but found problems. + echo "clang-format findings:" >&2 + "${COLORDIFF_BIN}" < "${clang_patch}" + + echo "clang-format found issues in your patches from ${MR_ANCESTOR_SHA}" \ + "to the current patch. Run \`./ci.sh lint | patch -p1\` from the base" \ + "directory to apply them." >&2 + exit 1 +} + +# Runs clang-tidy on the pending CLs. If the "all" argument is passed it runs +# clang-tidy over all the source files instead. 
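+# Example (illustrative): "./ci.sh tidy" checks only the files touched since
+# the merge-request ancestor commit, while "./ci.sh tidy all" checks every
+# source file in the tree.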
+cmd_tidy() {
+  local what="${1:-}"
+
+  if [[ -z "${CLANG_TIDY_BIN}" ]]; then
+    echo "ERROR: You must install clang-tidy-7 or newer to use ci.sh tidy" >&2
+    exit 1
+  fi
+
+  local git_args=()
+  if [[ "${what}" == "all" ]]; then
+    git_args=(ls-files)
+    shift
+  else
+    merge_request_commits
+    git_args=(
+      diff-tree --no-commit-id --name-only -r "${MR_ANCESTOR_SHA}"
+      "${MR_HEAD_SHA}"
+    )
+  fi
+
+  # Clang-tidy needs the compilation database generated by cmake.
+  if [[ ! -e "${BUILD_DIR}/compile_commands.json" ]]; then
+    # Generate the build options in debug mode, since we need the debug asserts
+    # enabled for the clang-tidy analyzer to use them.
+    CMAKE_BUILD_TYPE="Debug"
+    cmake_configure
+    # Build the autogen targets to generate the .h files from the .ui files.
+    local autogen_targets=(
+      $(ninja -C "${BUILD_DIR}" -t targets | grep -F _autogen: |
+        cut -f 1 -d :)
+    )
+    if [[ ${#autogen_targets[@]} != 0 ]]; then
+      ninja -C "${BUILD_DIR}" "${autogen_targets[@]}"
+    fi
+  fi
+
+  cd "${MYDIR}"
+  local nprocs=$(nproc --all || echo 1)
+  local ret=0
+  if ! parallel -j"${nprocs}" --keep-order -- \
+      "${CLANG_TIDY_BIN}" -p "${BUILD_DIR}" -format-style=file -quiet "$@" {} \
+      < <(git "${git_args[@]}" | grep -E '(\.cc|\.cpp)$') \
+      >"${BUILD_DIR}/clang-tidy.txt"; then
+    ret=1
+  fi
+  { set +x; } 2>/dev/null
+  echo "Findings statistics:" >&2
+  grep -E ' \[[A-Za-z\.,\-]+\]' -o "${BUILD_DIR}/clang-tidy.txt" | sort \
+    | uniq -c >&2
+
+  if [[ $ret -ne 0 ]]; then
+    cat >&2 <<EOF
+clang-tidy reported findings; see ${BUILD_DIR}/clang-tidy.txt for the details.
+EOF
+  fi
+  return ${ret}
+}
+
+# Print stats about the built debian packages.
+cmd_debian_stats() {
+  { set +x; } 2>/dev/null
+  local debsdir="${BUILD_DIR}/debs"
+  local f
+  while IFS='' read -r -d '' f; do
+    echo "====================================================================="
+    echo "Package $f:"
+    dpkg --info $f
+    dpkg --contents $f
+  done < <(find "${BUILD_DIR}/debs" -maxdepth 1 -mindepth 1 -type f \
+           -name '*.deb' -print0)
+}
+
+build_debian_pkg() {
+  local srcdir="$1"
+  local srcpkg="$2"
+
+  local debsdir="${BUILD_DIR}/debs"
+  local builddir="${debsdir}/${srcpkg}"
+
+  # debuild doesn't have an easy way to build out of tree, so we make a copy
+  # of the tree with symlinks for the entries on the first level.
+  mkdir -p "${builddir}"
+  for f in $(find "${srcdir}" -mindepth 1 -maxdepth 1 -printf '%P\n'); do
+    if [[ ! -L "${builddir}/$f" ]]; then
+      rm -f "${builddir}/$f"
+      ln -s "${srcdir}/$f" "${builddir}/$f"
+    fi
+  done
+  (
+    cd "${builddir}"
+    debuild -b -uc -us
+  )
+}
+
+cmd_debian_build() {
+  local srcpkg="${1:-}"
+
+  case "${srcpkg}" in
+    jpeg-xl)
+      build_debian_pkg "${MYDIR}" "jpeg-xl"
+      ;;
+    highway)
+      build_debian_pkg "${MYDIR}/third_party/highway" "highway"
+      ;;
+    *)
+      echo "ERROR: Must pass a valid source package name to build." >&2
+      ;;
+  esac
+}
+
+main() {
+  local cmd="${1:-}"
+  if [[ -z "${cmd}" ]]; then
+    cat >&2 <<EOF
+Use: $0 <command> [arguments]...
+
+Where <command> is one of:
+  debug          Build and test in Debug mode.
+  release        Build and test in Release mode, stripping dead code.
+  opt            Build and test in RelWithDebInfo mode.
+  coverage       Build and test with coverage and print a coverage report.
+  coverage_report  Print the coverage report of an existing coverage build.
+  test           Run the tests of an already built tree.
+  gbench         Run the Google benchmark tests.
+  asan / msan / tsan  Build and test with the given sanitizer enabled.
+  msan_install   Install the libc++ libraries needed by the msan build.
+  fast_benchmark Run the benchmark over the small corpus.
+  benchmark      Run the benchmark over the full corpus.
+  arm_benchmark  Run the codec benchmarks on the big/little ARM CPUs.
+  fuzz           Generate a fuzzer corpus and run the fuzzer on it.
+  lint           Run the linter (clang-format) on the pending CLs.
+  tidy           Run clang-tidy on the pending CLs ("tidy all" for all files).
+  debian_build <srcpkg> Build the given source package.
+  debian_stats   Print stats about the built packages.
+
+You can pass some optional environment variables as well:
+ - BUILD_DIR: The output build directory (by default "$$repo/build")
+ - BUILD_TARGET: The target triplet used when cross-compiling.
+ - CMAKE_FLAGS: Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS.
+ - CMAKE_PREFIX_PATH: Installation prefixes to be searched by find_package.
+ - ENABLE_WASM_SIMD=1: enable experimental SIMD in the WASM build (only).
+ - FUZZER_MAX_TIME: "fuzz" command fuzzer running timeout in seconds.
+ - LINT_OUTPUT: Path to the output patch from the "lint" command.
+ - SKIP_CPUSET=1: Skip modifying the cpuset in the arm_benchmark.
+ - SKIP_TEST=1: Skip the test stage.
+ - STORE_IMAGES=0: Makes the benchmark discard the computed images.
+ - TEST_STACK_LIMIT: Stack size limit (ulimit -s) during tests, in KiB. + - STACK_SIZE=1: Generate binaries with the .stack_sizes sections. + +These optional environment variables are forwarded to the cmake call as +parameters: + - CMAKE_BUILD_TYPE + - CMAKE_C_FLAGS + - CMAKE_CXX_FLAGS + - CMAKE_C_COMPILER_LAUNCHER + - CMAKE_CXX_COMPILER_LAUNCHER + - CMAKE_CROSSCOMPILING_EMULATOR + - CMAKE_FIND_ROOT_PATH + - CMAKE_EXE_LINKER_FLAGS + - CMAKE_MAKE_PROGRAM + - CMAKE_MODULE_LINKER_FLAGS + - CMAKE_SHARED_LINKER_FLAGS + - CMAKE_TOOLCHAIN_FILE + +Example: + BUILD_DIR=/tmp/build $0 opt +EOF + exit 1 + fi + + cmd="cmd_${cmd}" + shift + set -x + "${cmd}" "$@" +} + +main "$@" diff --git a/third_party/jpeg-xl/debian/changelog b/third_party/jpeg-xl/debian/changelog new file mode 100644 index 000000000000..0b472e9c3a7b --- /dev/null +++ b/third_party/jpeg-xl/debian/changelog @@ -0,0 +1,86 @@ +jpeg-xl (0.3.7) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.7. + * Fix a rounding issue in 8-bit decoding. + + -- Sami Boukortt Mon, 29 Mar 2021 12:14:20 +0200 + +jpeg-xl (0.3.6) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.6. + * Fix a bug that could result in the generation of invalid codestreams as + well as failure to decode valid streams. + + -- Sami Boukortt Thu, 25 Mar 2021 17:40:58 +0100 + +jpeg-xl (0.3.5) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.5. + * Memory usage improvements. + * New encode-time options for faster decoding at the cost of quality. + * Faster decoding to 8-bit output with the C API. + * GIMP plugin: avoid the sRGB conversion dialog for sRGB images, do not show + a console window on Windows. + * Various bug fixes. + * Man pages for cjxl and djxl. + + -- Sami Boukortt Tue, 23 Mar 2021 15:20:44 +0100 + +jpeg-xl (0.3.4) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.4. + * Improved box parsing. + * Improved metadata handling. + * Performance and memory usage improvements. + + -- Sami Boukortt Tue, 16 Mar 2021 12:13:59 +0100 + +jpeg-xl (0.3.3) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.3. + * Performance improvements for small images. + * Add a (flag-protected) non-high-precision mode with better speed. + * Significantly speed up the PQ EOTF. + * Allow optional HDR tone mapping in djxl (--tone_map, --display_nits). + * Change the behavior of djxl -j to make it consistent with cjxl (#153). + * Improve image quality. + * Improve EXIF handling. + + -- Sami Boukortt Fri, 5 Mar 2021 19:15:26 +0100 + +jpeg-xl (0.3.2) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.2. + * Fix embedded ICC encoding regression #149. + + -- Alex Deymo Fri, 12 Feb 2021 21:00:12 +0100 + +jpeg-xl (0.3.1) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3.1. + + -- Alex Deymo Tue, 09 Feb 2021 09:48:43 +0100 + +jpeg-xl (0.3) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.3. + + -- Alex Deymo Wed, 27 Jan 2021 22:36:32 +0100 + +jpeg-xl (0.2) UNRELEASED; urgency=medium + + * Bump JPEG XL version to 0.2. + + -- Alex Deymo Wed, 23 Nov 2020 20:42:10 +0100 + +jpeg-xl (0.1) UNRELEASED; urgency=medium + + * JPEG XL format release candidate. + + -- Alex Deymo Fri, 13 Nov 2020 17:42:24 +0100 + +jpeg-xl (0.0.2-1) UNRELEASED; urgency=medium + + * Initial debian package. 
+ + -- Alex Deymo Tue, 27 Oct 2020 15:27:59 +0100 diff --git a/third_party/jpeg-xl/debian/compat b/third_party/jpeg-xl/debian/compat new file mode 100644 index 000000000000..f599e28b8ab0 --- /dev/null +++ b/third_party/jpeg-xl/debian/compat @@ -0,0 +1 @@ +10 diff --git a/third_party/jpeg-xl/debian/control b/third_party/jpeg-xl/debian/control new file mode 100644 index 000000000000..f4cde67674f6 --- /dev/null +++ b/third_party/jpeg-xl/debian/control @@ -0,0 +1,61 @@ +Source: jpeg-xl +Maintainer: JPEG XL Maintainers +Section: misc +Priority: optional +Standards-Version: 3.9.8 +Build-Depends: cmake, + debhelper (>= 9), + libbrotli-dev, + libgif-dev, + libgmock-dev, + libgoogle-perftools-dev, + libgtest-dev, + libhwy-dev, + libjpeg-dev, + libopenexr-dev, + libpng-dev, + libwebp-dev, + pkg-config, +Homepage: https://gitlab.com/wg1/jpeg-xl +Rules-Requires-Root: no + +Package: jxl +Architecture: any +Section: utils +Depends: ${misc:Depends}, ${shlibs:Depends} +Description: JPEG XL Image Coding System - "JXL" (command line utility) + The JPEG XL Image Coding System (ISO/IEC 18181) is a lossy and + lossless image compression format. It has a rich feature set and is + particularly optimized for responsive web environments, so that + content renders well on a wide range of devices. Moreover, it includes + several features that help transition from the legacy JPEG format. + . + This package installs the command line utilities. + +Package: libjxl-dev +Architecture: any +Section: libdevel +Depends: libjxl (= ${binary:Version}), ${misc:Depends} +Description: JPEG XL Image Coding System - "JXL" (development files) + The JPEG XL Image Coding System (ISO/IEC 18181) is a lossy and + lossless image compression format. It has a rich feature set and is + particularly optimized for responsive web environments, so that + content renders well on a wide range of devices. Moreover, it includes + several features that help transition from the legacy JPEG format. + . + This package installs development files. + +Package: libjxl +Architecture: any +Multi-Arch: same +Section: libs +Depends: ${shlibs:Depends}, ${misc:Depends} +Pre-Depends: ${misc:Pre-Depends} +Description: JPEG XL Image Coding System - "JXL" (shared libraries) + The JPEG XL Image Coding System (ISO/IEC 18181) is a lossy and + lossless image compression format. It has a rich feature set and is + particularly optimized for responsive web environments, so that + content renders well on a wide range of devices. Moreover, it includes + several features that help transition from the legacy JPEG format. + . + This package installs shared libraries. diff --git a/third_party/jpeg-xl/debian/copyright b/third_party/jpeg-xl/debian/copyright new file mode 100644 index 000000000000..c01ec81e6071 --- /dev/null +++ b/third_party/jpeg-xl/debian/copyright @@ -0,0 +1,236 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ +Upstream-Name: jpeg-xl + +Files: * +Copyright: 2020 the JPEG XL Project +License: Apache-2.0 + +Files: third_party/sjpeg/* +Copyright: 2017 Google, Inc +License: Apache-2.0 + +Files: third_party/lodepng/* +Copyright: 2005-2018 Lode Vandevenne +License: Zlib License + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + . 
+ Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + . + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + . + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + . + 3. This notice may not be removed or altered from any source + distribution. + +Files: third_party/skcms/* +Copyright: 2018 Google Inc. +License: BSD-3-clause + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + . + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of Google Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + . + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Files: third_party/testdata/imagecompression.info/* +Copyright: their respective owners. +License: License without any prohibitive copyright restrictions. + See https://imagecompression.info/test_images/ for details. + . + These Images are available without any prohibitive copyright restrictions. + . + These images are (c) there respective owners. You are granted full + redistribution and publication rights on these images provided: + . + 1. The origin of the pictures must not be misrepresented; you must not claim + that you took the original pictures. If you use, publish or redistribute them, + an acknowledgment would be appreciated but is not required. + 2. Altered versions must be plainly marked as such, and must not be + misinterpreted as being the originals. + 3. No payment is required for distribution of this material, it must be + available freely under the conditions stated here. That is, it is prohibited to + sell the material. + 4. This notice may not be removed or altered from any distribution. + +Files: third_party/testdata/pngsuite/* +Copyright: Willem van Schaik, 1996, 2011 +License: PngSuite License + See http://www.schaik.com/pngsuite/ for details. + . + Permission to use, copy, modify and distribute these images for any + purpose and without fee is hereby granted. 
+ +Files: third_party/testdata/raw.pixls/* +Copyright: their respective owners listed in https://raw.pixls.us/ +License: CC0-1.0 + +Files: third_party/testdata/raw.pixls/* +Copyright: their respective owners listed in https://www.wesaturate.com/ +License: CC0-1.0 + +Files: third_party/testdata/wide-gamut-tests/ +Copyright: github.com/codelogic/wide-gamut-tests authors. +License: Apache-2.0 + +License: Apache-2.0 + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + . + http://www.apache.org/licenses/LICENSE-2.0 + . + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + . + On Debian systems, the complete text of the Apache License, Version 2 + can be found in "/usr/share/common-licenses/Apache-2.0". + +License: CC0 + Creative Commons Zero v1.0 Universal + . + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL + SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN ATTORNEY-CLIENT + RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" + BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE USE OF THIS + DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER, AND DISCLAIMS + LIABILITY FOR DAMAGES RESULTING FROM THE USE OF THIS DOCUMENT OR THE + INFORMATION OR WORKS PROVIDED HEREUNDER. + . + Statement of Purpose + . + The laws of most jurisdictions throughout the world automatically confer + exclusive Copyright and Related Rights (defined below) upon the creator and + subsequent owner(s) (each and all, an "owner") of an original work of + authorship and/or a database (each, a "Work"). + . + Certain owners wish to permanently relinquish those rights to a Work for the + purpose of contributing to a commons of creative, cultural and scientific + works ("Commons") that the public can reliably and without fear of later + claims of infringement build upon, modify, incorporate in other works, reuse + and redistribute as freely as possible in any form whatsoever and for any + purposes, including without limitation commercial purposes. These owners may + contribute to the Commons to promote the ideal of a free culture and the + further production of creative, cultural and scientific works, or to gain + reputation or greater distribution for their Work in part through the use + and efforts of others. + . + For these and/or other purposes and motivations, and without any expectation + of additional consideration or compensation, the person associating CC0 with + a Work (the "Affirmer"), to the extent that he or she is an owner of + Copyright and Related Rights in the Work, voluntarily elects to apply CC0 to + the Work and publicly distribute the Work under its terms, with knowledge of + his or her Copyright and Related Rights in the Work and the meaning and + intended legal effect of CC0 on those rights. + . + 1. Copyright and Related Rights. A Work made available under CC0 may be + protected by copyright and related or neighboring rights ("Copyright and + Related Rights"). Copyright and Related Rights include, but are not limited + to, the following: + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. 
moral rights retained by the original author(s) and/or performer(s); + iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation thereof, + including any amended or successor version of such directive); and + vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national implementations + thereof. + . + 2. Waiver. To the greatest extent permitted by, but not in contravention of, + applicable law, Affirmer hereby overtly, fully, permanently, irrevocably and + unconditionally waives, abandons, and surrenders all of Affirmer's Copyright + and Related Rights and associated claims and causes of action, whether now + known or unknown (including existing as well as future claims and causes of + action), in the Work (i) in all territories worldwide, (ii) for the maximum + duration provided by applicable law or treaty (including future time + extensions), (iii) in any current or future medium and for any number of + copies, and (iv) for any purpose whatsoever, including without limitation + commercial, advertising or promotional purposes (the "Waiver"). Affirmer + makes the Waiver for the benefit of each member of the public at large and + to the detriment of Affirmer's heirs and successors, fully intending that + such Waiver shall not be subject to revocation, rescission, cancellation, + termination, or any other legal or equitable action to disrupt the quiet + enjoyment of the Work by the public as contemplated by Affirmer's express + Statement of Purpose. + . + 3. Public License Fallback. Should any part of the Waiver for any reason be + judged legally invalid or ineffective under applicable law, then the Waiver + shall be preserved to the maximum extent permitted taking into account + Affirmer's express Statement of Purpose. In addition, to the extent the + Waiver is so judged Affirmer hereby grants to each affected person a + royalty-free, non transferable, non sublicensable, non exclusive, + irrevocable and unconditional license to exercise Affirmer's Copyright and + Related Rights in the Work (i) in all territories worldwide, (ii) for the + maximum duration provided by applicable law or treaty (including future time + extensions), (iii) in any current or future medium and for any number of + copies, and (iv) for any purpose whatsoever, including without limitation + commercial, advertising or promotional purposes (the "License"). The License + shall be deemed effective as of the date CC0 was applied by Affirmer to the + Work. Should any part of the License for any reason be judged legally + invalid or ineffective under applicable law, such partial invalidity or + ineffectiveness shall not invalidate the remainder of the License, and in + such case Affirmer hereby affirms that he or she will not (i) exercise any + of his or her remaining Copyright and Related Rights in the Work or (ii) + assert any associated claims and causes of action with respect to the Work, + in either case contrary to Affirmer's express Statement of Purpose. + . + 4. 
Limitations and Disclaimers.
+   a. No trademark or patent rights held by Affirmer are waived, abandoned,
+   surrendered, licensed or otherwise affected by this document.
+   b. Affirmer offers the Work as-is and makes no representations or
+   warranties of any kind concerning the Work, express, implied, statutory or
+   otherwise, including without limitation warranties of title,
+   merchantability, fitness for a particular purpose, non infringement, or the
+   absence of latent or other defects, accuracy, or the present or absence of
+   errors, whether or not discoverable, all to the greatest extent permissible
+   under applicable law.
+   c. Affirmer disclaims responsibility for clearing rights of other persons
+   that may apply to the Work or any use thereof, including without limitation
+   any person's Copyright and Related Rights in the Work. Further, Affirmer
+   disclaims responsibility for obtaining any necessary consents, permissions
+   or other rights required for any use of the Work.
+   d. Affirmer understands and acknowledges that Creative Commons is not a
+   party to this document and has no duty or obligation with respect to this
+   CC0 or use of the Work.
+ .
+ For more information, please see:
+ <http://creativecommons.org/publicdomain/zero/1.0/>
+
diff --git a/third_party/jpeg-xl/debian/jxl.install b/third_party/jpeg-xl/debian/jxl.install
new file mode 100644
index 000000000000..8ffdce84e93a
--- /dev/null
+++ b/third_party/jpeg-xl/debian/jxl.install
@@ -0,0 +1 @@
+debian/tmp/usr/bin/*
diff --git a/third_party/jpeg-xl/debian/libjxl-dev.install b/third_party/jpeg-xl/debian/libjxl-dev.install
new file mode 100644
index 000000000000..26dbec0e7115
--- /dev/null
+++ b/third_party/jpeg-xl/debian/libjxl-dev.install
@@ -0,0 +1,4 @@
+debian/tmp/usr/include/jxl/*.h
+debian/tmp/usr/lib/*/*.a
+debian/tmp/usr/lib/*/*.so
+debian/tmp/usr/lib/*/pkgconfig/*.pc
diff --git a/third_party/jpeg-xl/debian/libjxl.install b/third_party/jpeg-xl/debian/libjxl.install
new file mode 100644
index 000000000000..f1f1b1b125a7
--- /dev/null
+++ b/third_party/jpeg-xl/debian/libjxl.install
@@ -0,0 +1 @@
+debian/tmp/usr/lib/*/libjxl*.so.*
diff --git a/third_party/jpeg-xl/debian/rules b/third_party/jpeg-xl/debian/rules
new file mode 100644
index 000000000000..9ce8e88e5dcd
--- /dev/null
+++ b/third_party/jpeg-xl/debian/rules
@@ -0,0 +1,12 @@
+#!/usr/bin/make -f
+%:
+	dh $@ --buildsystem=cmake
+
+override_dh_auto_configure:
+	# TODO(deymo): Remove the DCMAKE_BUILD_TYPE once builds without NDEBUG
+	# are as useful as Release builds.
+	dh_auto_configure -- \
+	  -DCMAKE_BUILD_TYPE=RelWithDebInfo \
+	  -DJPEGXL_FORCE_SYSTEM_GTEST=ON \
+	  -DJPEGXL_FORCE_SYSTEM_BROTLI=ON \
+	  -DJPEGXL_FORCE_SYSTEM_HWY=ON
diff --git a/third_party/jpeg-xl/debian/source/format b/third_party/jpeg-xl/debian/source/format
new file mode 100644
index 000000000000..163aaf8d82b6
--- /dev/null
+++ b/third_party/jpeg-xl/debian/source/format
@@ -0,0 +1 @@
+3.0 (quilt)
diff --git a/third_party/jpeg-xl/deps.sh b/third_party/jpeg-xl/deps.sh
new file mode 100644
index 000000000000..c28979a52799
--- /dev/null
+++ b/third_party/jpeg-xl/deps.sh
@@ -0,0 +1,89 @@
+#!/usr/bin/env bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file downloads the dependencies needed to build JPEG XL into
+# third_party/. These dependencies are normally pulled in as git submodules;
+# this script covers the case where the sources were obtained without a git
+# checkout (for example, from a release tarball).
+
+set -eu
+
+MYDIR=$(dirname $(realpath "$0"))
+
+# Git revisions we use for the given submodules. Update these whenever you
+# update a git submodule.
+THIRD_PARTY_HIGHWAY="ca1a57c342cd815053abfcffa29b44eaead4f20b"
+THIRD_PARTY_LODEPNG="48e5364ef48ec2408f44c727657ac1b6703185f8"
+THIRD_PARTY_SKCMS="64374756e03700d649f897dbd98c95e78c30c7da"
+THIRD_PARTY_SJPEG="868ab558fad70fcbe8863ba4e85179eeb81cc840"
+
+# Download the target revision from GitHub.
+download_github() {
+  local path="$1"
+  local project="$2"
+
+  local varname="${path^^}"
+  varname="${varname/\//_}"
+  local sha
+  eval "sha=\${${varname}}"
+
+  local down_dir="${MYDIR}/downloads"
+  local local_fn="${down_dir}/${sha}.tar.gz"
+  if [[ -e "${local_fn}" && -d "${MYDIR}/${path}" ]]; then
+    echo "${path} already up to date." >&2
+    return 0
+  fi
+
+  local url
+  local strip_components=0
+  if [[ "${project:0:4}" == "http" ]]; then
+    # "project" is a googlesource.com base url.
+    url="${project}${sha}.tar.gz"
+  else
+    # GitHub files have a top-level directory
+    strip_components=1
+    url="https://github.com/${project}/tarball/${sha}"
+  fi
+
+  echo "Downloading ${path} version ${sha}..." >&2
+  mkdir -p "${down_dir}"
+  curl -L --show-error -o "${local_fn}.tmp" "${url}"
+  mkdir -p "${MYDIR}/${path}"
+  tar -zxf "${local_fn}.tmp" -C "${MYDIR}/${path}" \
+    --strip-components="${strip_components}"
+  mv "${local_fn}.tmp" "${local_fn}"
+}
+
+
+main() {
+  if git -C "${MYDIR}" rev-parse; then
+    cat >&2 <<EOF
+Current directory is a git checkout; fetch the dependencies with git instead
+of using this script:
+
+  git submodule update --init --recursive
+
+EOF
+    exit 1
+  fi
+
+  download_github third_party/highway google/highway
+  download_github third_party/lodepng lvandeve/lodepng
+  download_github third_party/skcms \
+    "https://skia.googlesource.com/skcms/+archive/"
+  download_github third_party/sjpeg webmproject/sjpeg
+  echo "Done." >&2
+}
+
+main "$@"
diff --git a/third_party/jpeg-xl/docker/build.sh b/third_party/jpeg-xl/docker/build.sh
new file mode 100644
--- /dev/null
+++ b/third_party/jpeg-xl/docker/build.sh
+#!/usr/bin/env bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build the Docker images used by the JPEG XL builders.
+
+set -eu
+
+MYDIR=$(dirname $(realpath "$0"))
+
+declare -a TARGETS
+
+# Load the list of targets from the Dockerfile.* files in this directory.
+load_targets() {
+  local dockerfile
+  for dockerfile in "${MYDIR}"/Dockerfile.*; do
+    TARGETS+=("${dockerfile##*/Dockerfile.}")
+  done
+}
+
+usage() {
+  cat >&2 <<EOF
+Usage: $1 [target ...]
+
+Builds the given Docker image targets, or "all" for all of them. Available
+targets:
+EOF
+  local target
+  for target in "${TARGETS[@]}"; do
+    echo " * ${target}" >&2
+  done
+}
+
+build_target() {
+  local target="$1"
+
+  local dockerfile="${MYDIR}/Dockerfile.${target}"
+  # JPEG XL builder images are stored in the gcr.io/jpegxl project.
+  local tag="gcr.io/jpegxl/${target}"
+
+  echo "Building ${target}"
+  if ! sudo docker build --no-cache -t "${tag}" -f "${dockerfile}" "${MYDIR}" \
+      >"${target}.log" 2>&1; then
+    echo "${target} failed. See ${target}.log" >&2
+  else
+    echo "Done, to upload image run:" >&2
+    echo "  sudo docker push ${tag}"
+    if [[ "${JPEGXL_PUSH:-}" == "1" ]]; then
+      echo "sudo docker push ${tag}" >&2
+      sudo docker push "${tag}"
+      # The RepoDigest is only created after it is pushed.
+ local fulltag=$(sudo docker inspect --format="{{.RepoDigests}}" "${tag}") + fulltag="${fulltag#[}" + fulltag="${fulltag%]}" + echo "Updating .gitlab-ci.yml to ${fulltag}" >&2 + sed -E "s;${tag}@sha256:[0-9a-f]+;${fulltag};" \ + -i "${MYDIR}/../.gitlab-ci.yml" + fi + fi +} + +main() { + cd "${MYDIR}" + local target="${1:-}" + + load_targets + if [[ -z "${target}" ]]; then + usage $0 + exit 1 + fi + + if [[ "${target}" == "all" ]]; then + for target in "${TARGETS[@]}"; do + build_target "${target}" + done + else + for target in "$@"; do + build_target "${target}" + done + fi +} + +main "$@" diff --git a/third_party/jpeg-xl/docker/scripts/99_norecommends b/third_party/jpeg-xl/docker/scripts/99_norecommends new file mode 100644 index 000000000000..96d672811d87 --- /dev/null +++ b/third_party/jpeg-xl/docker/scripts/99_norecommends @@ -0,0 +1 @@ +APT::Install-Recommends "false"; diff --git a/third_party/jpeg-xl/docker/scripts/binutils_align_fix.patch b/third_party/jpeg-xl/docker/scripts/binutils_align_fix.patch new file mode 100644 index 000000000000..6066252db896 --- /dev/null +++ b/third_party/jpeg-xl/docker/scripts/binutils_align_fix.patch @@ -0,0 +1,28 @@ +Description: fix lack of alignment in relocations (crashes on mingw) +See https://sourceware.org/git/?p=binutils-gdb.git;a=patch;h=73af69e74974eaa155eec89867e3ccc77ab39f6d +From: Marc +Date: Fri, 9 Nov 2018 11:13:50 +0000 +Subject: [PATCH] Allow for compilers that do not produce aligned .rdat + sections in PE format files. + +--- a/upstream/ld/scripttempl/pe.sc 2020-05-12 18:45:12.000000000 +0200 ++++ b/upstream/ld/scripttempl/pe.sc 2020-05-12 18:47:12.000000000 +0200 +@@ -143,6 +143,7 @@ + .rdata ${RELOCATING+BLOCK(__section_alignment__)} : + { + ${R_RDATA} ++ . = ALIGN(4); + ${RELOCATING+__rt_psrelocs_start = .;} + ${RELOCATING+KEEP(*(.rdata_runtime_pseudo_reloc))} + ${RELOCATING+__rt_psrelocs_end = .;} +--- a/upstream/ld/scripttempl/pep.sc 2020-05-12 18:45:19.000000000 +0200 ++++ b/upstream/ld/scripttempl/pep.sc 2020-05-12 18:47:18.000000000 +0200 +@@ -143,6 +143,7 @@ + .rdata ${RELOCATING+BLOCK(__section_alignment__)} : + { + ${R_RDATA} ++ . = ALIGN(4); + ${RELOCATING+__rt_psrelocs_start = .;} + ${RELOCATING+KEEP(*(.rdata_runtime_pseudo_reloc))} + ${RELOCATING+__rt_psrelocs_end = .;} + diff --git a/third_party/jpeg-xl/docker/scripts/emsdk_install.sh b/third_party/jpeg-xl/docker/scripts/emsdk_install.sh new file mode 100644 index 000000000000..e0dac04e2366 --- /dev/null +++ b/third_party/jpeg-xl/docker/scripts/emsdk_install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +EMSDK_URL="https://github.com/emscripten-core/emsdk/archive/master.tar.gz" +EMSDK_DIR="/opt/emsdk" + +EMSDK_RELEASE="2.0.4" + +set -eu -x + +# Temporary files cleanup hooks. 
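+# (The trap below runs cleanup() on exit or interruption; the
+# "{ set +x; } 2>/dev/null" prefix silences xtrace first so the cleanup
+# commands themselves are not echoed.)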
+CLEANUP_FILES=()
+cleanup() {
+  if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then
+    rm -fr "${CLEANUP_FILES[@]}"
+  fi
+}
+trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT
+
+main() {
+  local workdir=$(mktemp -d --suffix=emsdk)
+  CLEANUP_FILES+=("${workdir}")
+
+  local emsdktar="${workdir}/emsdk.tar.gz"
+  curl --output "${emsdktar}" "${EMSDK_URL}" --location
+  mkdir -p "${EMSDK_DIR}"
+  tar -zxf "${emsdktar}" -C "${EMSDK_DIR}" --strip-components=1
+
+  cd "${EMSDK_DIR}"
+  ./emsdk install --shallow "${EMSDK_RELEASE}"
+  ./emsdk activate --embedded "${EMSDK_RELEASE}"
+}
+
+main "$@"
diff --git a/third_party/jpeg-xl/docker/scripts/jpegxl_builder.sh b/third_party/jpeg-xl/docker/scripts/jpegxl_builder.sh
new file mode 100644
index 000000000000..6ebffdb22714
--- /dev/null
+++ b/third_party/jpeg-xl/docker/scripts/jpegxl_builder.sh
@@ -0,0 +1,515 @@
+#!/usr/bin/env bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Main entry point for all the Dockerfiles for jpegxl-builder. This
+# centralized file helps share code and configuration between the
+# Dockerfiles.
+
+set -eux
+
+MYDIR=$(dirname $(realpath "$0"))
+
+# libjpeg-turbo.
+JPEG_TURBO_RELEASE="2.0.4"
+JPEG_TURBO_URL="https://github.com/libjpeg-turbo/libjpeg-turbo/archive/${JPEG_TURBO_RELEASE}.tar.gz"
+JPEG_TURBO_SHA256="7777c3c19762940cff42b3ba4d7cd5c52d1671b39a79532050c85efb99079064"
+
+# zlib (dependency of libpng)
+ZLIB_RELEASE="1.2.11"
+ZLIB_URL="https://www.zlib.net/zlib-${ZLIB_RELEASE}.tar.gz"
+ZLIB_SHA256="c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1"
+# The names in the .pc and the .dll generated don't match in zlib for Windows
+# because they use different .dll names in Windows. We avoid that by defining
+# UNIX=1. We also install all the .dll files to ${prefix}/lib instead of the
+# default ${prefix}/bin.
+ZLIB_FLAGS='-DUNIX=1 -DINSTALL_PKGCONFIG_DIR=/${CMAKE_INSTALL_PREFIX}/lib/pkgconfig -DINSTALL_BIN_DIR=/${CMAKE_INSTALL_PREFIX}/lib'
+
+# libpng
+LIBPNG_RELEASE="1.6.37"
+LIBPNG_URL="https://github.com/glennrp/libpng/archive/v${LIBPNG_RELEASE}.tar.gz"
+LIBPNG_SHA256="ca74a0dace179a8422187671aee97dd3892b53e168627145271cad5b5ac81307"
+
+# giflib
+GIFLIB_RELEASE="5.2.1"
+GIFLIB_URL="https://netcologne.dl.sourceforge.net/project/giflib/giflib-${GIFLIB_RELEASE}.tar.gz"
+GIFLIB_SHA256="31da5562f44c5f15d63340a09a4fd62b48c45620cd302f77a6d9acf0077879bd"
+
+# A patch needed to compile GIFLIB in mingw.
+GIFLIB_PATCH_URL="https://github.com/msys2/MINGW-packages/raw/3afde38fcee7b3ba2cafd97d76cca8f06934504f/mingw-w64-giflib/001-mingw-build.patch"
+GIFLIB_PATCH_SHA256="2b2262ddea87fc07be82e10aeb39eb699239f883c899aa18a16e4d4e40af8ec8"
+
+# webp
+WEBP_RELEASE="1.0.2"
+WEBP_URL="https://codeload.github.com/webmproject/libwebp/tar.gz/v${WEBP_RELEASE}"
+WEBP_SHA256="347cf85ddc3497832b5fa9eee62164a37b249c83adae0ba583093e039bf4881f"
+
+# Google benchmark
+BENCHMARK_RELEASE="1.5.2"
+BENCHMARK_URL="https://github.com/google/benchmark/archive/v${BENCHMARK_RELEASE}.tar.gz"
+BENCHMARK_SHA256="dccbdab796baa1043f04982147e67bb6e118fe610da2c65f88912d73987e700c"
+BENCHMARK_FLAGS="-DGOOGLETEST_PATH=${MYDIR}/../../third_party/googletest"
+# attribute(format(__MINGW_PRINTF_FORMAT, ...)) doesn't work in our
+# environment, so we disable the warning.
+BENCHMARK_FLAGS="-DCMAKE_BUILD_TYPE=Release -DBENCHMARK_ENABLE_TESTING=OFF \
+  -DCMAKE_CXX_FLAGS=-Wno-ignored-attributes \
+  -DCMAKE_POSITION_INDEPENDENT_CODE=ON"
+
+# V8
+V8_VERSION="8.7.230"
+
+# Temporary files cleanup hooks.
+CLEANUP_FILES=()
+cleanup() {
+  if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then
+    rm -fr "${CLEANUP_FILES[@]}"
+  fi
+}
+trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT
+
+# List of Ubuntu arch names supported by the builder (such as "i386").
+LIST_ARCHS=(
+  amd64
+  i386
+  arm64
+  armhf
+)
+
+# List of target triplets supported by the builder.
+LIST_TARGETS=(
+  x86_64-linux-gnu
+  i686-linux-gnu
+  arm-linux-gnueabihf
+  aarch64-linux-gnu
+)
+LIST_MINGW_TARGETS=(
+  i686-w64-mingw32
+  x86_64-w64-mingw32
+)
+LIST_WASM_TARGETS=(
+  wasm32
+)
+
+# Setup the apt repositories and supported architectures.
+setup_apt() {
+  apt-get update -y
+  apt-get install -y curl gnupg ca-certificates
+
+  apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 1E9377A2BA9EF27F
+
+  # node sources.
+  cat >/etc/apt/sources.list.d/nodesource.list <<EOF
+deb https://deb.nodesource.com/node_14.x bionic main
+deb-src https://deb.nodesource.com/node_14.x bionic main
+EOF
+  curl -s https://deb.nodesource.com/gpgkey/nodesource.gpg.key | apt-key add -
+
+  local port_list=()
+  local main_list=()
+  local ubarch
+  for ubarch in "${LIST_ARCHS[@]}"; do
+    if [[ "${ubarch}" != "amd64" ]]; then
+      # Add the architecture to dpkg so packages for it can be installed.
+      dpkg --add-architecture "${ubarch}"
+    fi
+    # amd64 and i386 come from the main mirrors; the other archs from ports.
+    if [[ "${ubarch}" == "amd64" || "${ubarch}" == "i386" ]]; then
+      main_list+=("${ubarch}")
+    else
+      port_list+=("${ubarch}")
+    fi
+  done
+
+  # Rewrite sources.list with arch-restricted entries.
+  local bkplist="/etc/apt/sources.list.bkp"
+  [[ -e "${bkplist}" ]] || mv /etc/apt/sources.list "${bkplist}"
+  local newlist="/etc/apt/sources.list.tmp"
+  rm -f "${newlist}"
+
+  port_list=$(echo "${port_list[@]}" | tr ' ' ,)
+  if [[ -n "${port_list}" ]]; then
+    local port_url="http://ports.ubuntu.com/ubuntu-ports/"
+    grep -v -E '^#' "${bkplist}" |
+      sed -E "s;^deb (http[^ ]+) (.*)\$;deb [arch=${port_list}] ${port_url} \\2\ndeb-src [arch=${port_list}] ${port_url} \\2;" \
+      >>"${newlist}"
+  fi
+
+  main_list=$(echo "${main_list[@]}" | tr ' ' ,)
+  grep -v -E '^#' "${bkplist}" |
+    sed -E "s;^deb (http[^ ]+) (.*)\$;deb [arch=${main_list}] \\1 \\2\ndeb-src [arch=${main_list}] \\1 \\2;" \
+    >>"${newlist}"
+  mv "${newlist}" /etc/apt/sources.list
+}
+
+install_pkgs() {
+  packages=(
+    # Native compilers (minimum for SIMD is clang-7)
+    clang-7 clang-format-7 clang-tidy-7
+
+    # TODO: Consider adding clang-8 to every builder:
+    # clang-8 clang-format-8 clang-tidy-8
+
+    # For cross-compiling to Windows with mingw.
+    mingw-w64
+    wine64
+    wine-binfmt
+
+    # Native tools.
+    bsdmainutils
+    cmake
+    extra-cmake-modules
+    git
+    llvm
+    nasm
+    ninja-build
+    parallel
+    pkg-config
+
+    # These are used by the ./ci.sh lint in the native builder.
+    clang-format-7
+    clang-format-8
+
+    # For coverage builds
+    gcovr
+
+    # For compiling giflib documentation.
+    xmlto
+
+    # Common libraries.
+    libstdc++-8-dev
+
+    # We don't use tcmalloc on archs other than amd64. This installs
+    # libgoogle-perftools4:amd64.
+    google-perftools
+
+    # NodeJS for running WASM tests
+    nodejs
+
+    # To generate API documentation.
+    doxygen
+
+    # Freezes the version that builds (passes tests). The newer version
+    # (2.30-21ubuntu1~18.04.4) claims to fix "On Intel Skylake
+    # (-march=native) generated avx512 instruction can be wrong",
+    # but the newly added tests do not pass. Perhaps the problem is
+    # that the mingw package is not updated.
+    binutils-source=2.30-15ubuntu1
+  )
+
+  # Install packages that are arch-dependent.
+  local ubarch
+  for ubarch in "${LIST_ARCHS[@]}"; do
+    packages+=(
+      # Library dependencies.
+      # These normally depend on the target architecture we are compiling
+      # for and can't usually be installed for multiple architectures at
+      # the same time.
+      libgif7:"${ubarch}"
+      libjpeg-dev:"${ubarch}"
+      libpng-dev:"${ubarch}"
+      libqt5x11extras5-dev:"${ubarch}"
+
+      libstdc++-8-dev:"${ubarch}"
+      qtbase5-dev:"${ubarch}"
+
+      # For OpenEXR:
+      libilmbase12:"${ubarch}"
+      libopenexr22:"${ubarch}"
+
+      # TCMalloc dependency
+      libunwind-dev:"${ubarch}"
+
+      # Cross-compiling tools per arch.
+      libc6-dev-"${ubarch}"-cross
+      libstdc++-8-dev-"${ubarch}"-cross
+    )
+  done
+
+  local target
+  for target in "${LIST_TARGETS[@]}"; do
+    # Per target cross-compiling tools.
+    if [[ "${target}" != "x86_64-linux-gnu" ]]; then
+      packages+=(
+        binutils-"${target}"
+        gcc-"${target}"
+      )
+    fi
+  done
+
+  # Install all the manual packages via "apt install" for the main arch. These
+  # will be installed for other archs via manual download and unpack.
+  apt install -y "${packages[@]}" "${UNPACK_PKGS[@]}"
+}
+
+# binutils <2.32 needs a patch.
+install_binutils() {
+  local workdir=$(mktemp -d --suffix=_install)
+  CLEANUP_FILES+=("${workdir}")
+  pushd "${workdir}"
+  apt source binutils-mingw-w64
+  apt -y build-dep binutils-mingw-w64
+  cd binutils-mingw-w64-8ubuntu1
+  cp "${MYDIR}/binutils_align_fix.patch" debian/patches
+  echo binutils_align_fix.patch >> debian/patches/series
+  dpkg-buildpackage -b
+  cd ..
+  dpkg -i *deb
+  popd
+}
+
+# Install a library from the source code for multiple targets.
+# Usage: install_from_source <package> [<target> ...]
+install_from_source() {
+  local package="$1"
+  shift
+
+  local url
+  eval "url=\${${package}_URL}"
+  local sha256
+  eval "sha256=\${${package}_SHA256}"
+  # Optional package flags
+  local pkgflags
+  eval "pkgflags=\${${package}_FLAGS:-}"
+
+  local workdir=$(mktemp -d --suffix=_install)
+  CLEANUP_FILES+=("${workdir}")
+
+  local tarfile="${workdir}"/$(basename "${url}")
+  curl -L --output "${tarfile}" "${url}"
+  if ! echo "${sha256} ${tarfile}" | sha256sum -c --status -; then
+    echo "SHA256 mismatch for ${url}: expected ${sha256} but found:"
+    sha256sum "${tarfile}"
+    exit 1
+  fi
+
+  local target
+  for target in "$@"; do
+    echo "Installing ${package} for target ${target} from ${url}"
+
+    local srcdir="${workdir}/source-${target}"
+    mkdir -p "${srcdir}"
+    tar -zxf "${tarfile}" -C "${srcdir}" --strip-components=1
+
+    local prefix="/usr"
+    if [[ "${target}" != "x86_64-linux-gnu" ]]; then
+      prefix="/usr/${target}"
+    fi
+
+    # Apply patches to buildfiles.
+    if [[ "${package}" == "GIFLIB" && "${target}" == *mingw32 ]]; then
+      # The GIFLIB Makefile has several problems, so we need to fix them
+      # here. We are using a patch from MSYS2 that already fixes the
+      # compilation for mingw.
+      local make_patch="${srcdir}/libgif.patch"
+      curl -L "${GIFLIB_PATCH_URL}" -o "${make_patch}"
+      echo "${GIFLIB_PATCH_SHA256} ${make_patch}" | sha256sum -c --status -
+      patch "${srcdir}/Makefile" < "${make_patch}"
+    elif [[ "${package}" == "LIBPNG" && "${target}" == wasm* ]]; then
+      # Cut the dependency on libm; there is a pull request to fix it, so
+      # this might not be needed in the future.
+      sed -i 's/APPLE/EMSCRIPTEN/g' "${srcdir}/CMakeLists.txt"
+    fi
+
+    local cmake_args=()
+    local export_args=("CC=clang-7" "CXX=clang++-7")
+    local cmake="cmake"
+    local make="make"
+    local system_name="Linux"
+    if [[ "${target}" == *mingw32 ]]; then
+      system_name="Windows"
+      # When compiling with clang, CMake doesn't detect that we are using
+      # mingw.
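+      # (Note: when cross-compiling, CMake also cannot run try_run() test
+      # binaries for the target, so feature-detection results such as the
+      # regex checks used by Google Benchmark are preset below rather than
+      # detected.)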
+      cmake_args+=(
+        -DMINGW=1
+        # Googletest needs this when cross-compiling to windows
+        -DCMAKE_CROSSCOMPILING=1
+        -DHAVE_STD_REGEX=0
+        -DHAVE_POSIX_REGEX=0
+        -DHAVE_GNU_POSIX_REGEX=0
+      )
+      local windres=$(which ${target}-windres || true)
+      if [[ -n "${windres}" ]]; then
+        cmake_args+=(-DCMAKE_RC_COMPILER="${windres}")
+      fi
+    fi
+    if [[ "${target}" == wasm* ]]; then
+      system_name="WASM"
+      cmake="emcmake cmake"
+      make="emmake make"
+      export_args=()
+      cmake_args+=(
+        -DCMAKE_FIND_ROOT_PATH="${prefix}"
+        -DCMAKE_PREFIX_PATH="${prefix}"
+      )
+      # Static and shared library link to the same file -> race condition.
+      nproc=1
+    else
+      nproc=`nproc --all`
+    fi
+    cmake_args+=(-DCMAKE_SYSTEM_NAME="${system_name}")
+
+    if [[ "${target}" != "x86_64-linux-gnu" ]]; then
+      # Cross-compiling.
+      cmake_args+=(
+        -DCMAKE_C_COMPILER_TARGET="${target}"
+        -DCMAKE_CXX_COMPILER_TARGET="${target}"
+        -DCMAKE_SYSTEM_PROCESSOR="${target%%-*}"
+      )
+    fi
+
+    if [[ -e "${srcdir}/CMakeLists.txt" ]]; then
+      # Most packages use cmake for building, which is easier to configure
+      # for cross-compiling.
+      if [[ "${package}" == "JPEG_TURBO" && "${target}" == wasm* ]]; then
+        # JT erroneously detects WASM CPU as i386 and tries to use asm.
+        # Wasm/Emscripten support for dynamic linking is incomplete; disable
+        # to avoid CMake warning.
+        cmake_args+=(-DWITH_SIMD=0 -DENABLE_SHARED=OFF)
+      fi
+      (
+        cd "${srcdir}"
+        export ${export_args[@]}
+        ${cmake} \
+          -DCMAKE_INSTALL_PREFIX="${prefix}" \
+          "${cmake_args[@]}" ${pkgflags}
+        ${make} -j${nproc}
+        ${make} install
+      )
+    elif [[ "${package}" == "GIFLIB" ]]; then
+      # GIFLIB doesn't yet have a cmake build system. There is a pull
+      # request in giflib for adding CMakeLists.txt so this might not be
+      # needed in the future.
+      (
+        cd "${srcdir}"
+        local giflib_make_flags=(
+          CFLAGS="-O2 --target=${target} -std=gnu99"
+          PREFIX="${prefix}"
+        )
+        if [[ "${target}" != wasm* ]]; then
+          giflib_make_flags+=(CC=clang-7)
+        fi
+        # giflib make dependencies are not properly set up so parallel
+        # building doesn't work for everything.
+        ${make} -j${nproc} libgif.a "${giflib_make_flags[@]}"
+        ${make} -j${nproc} all "${giflib_make_flags[@]}"
+        ${make} install "${giflib_make_flags[@]}"
+      )
+    else
+      echo "Don't know how to install ${package}"
+      exit 1
+    fi
+
+  done
+}
+
+# Packages that are manually unpacked for each architecture.
+UNPACK_PKGS=(
+  libgif-dev
+  libclang-common-7-dev
+
+  # For OpenEXR:
+  libilmbase-dev
+  libopenexr-dev
+
+  # TCMalloc
+  libgoogle-perftools-dev
+  libtcmalloc-minimal4
+  libgoogle-perftools4
+)
+
+# Main script entry point.
+main() {
+  cd "${MYDIR}"
+
+  # Configure the repositories with the sources for multi-arch cross
+  # compilation.
+  setup_apt
+  apt-get update -y
+  apt-get dist-upgrade -y
+
+  install_pkgs
+  install_binutils
+  apt clean
+
+  # Manually extract packages for the target arch that can't be installed
+  # directly at the same time as the native ones.
+  local ubarch
+  for ubarch in "${LIST_ARCHS[@]}"; do
+    if [[ "${ubarch}" != "amd64" ]]; then
+      local pkg
+      for pkg in "${UNPACK_PKGS[@]}"; do
+        apt download "${pkg}":"${ubarch}"
+        dpkg -x "${pkg}"_*_"${ubarch}".deb /
+      done
+    fi
+  done
+  # TODO: Add clang from the llvm repos. This is problematic since we are
+  # installing libclang-common-7-dev:"${ubarch}" from the ubuntu ports repos
+  # which is not available in the llvm repos so it might have a different
+  # version than the ubuntu ones.
+
+  # Remove the win32 libgcc version.
+  # The gcc-mingw-w64-x86-64 (and i686) packages install two libgcc
+  # versions:
+  #   /usr/lib/gcc/x86_64-w64-mingw32/7.3-posix
+  #   /usr/lib/gcc/x86_64-w64-mingw32/7.3-win32
+  # (the exact libgcc version number depends on the package version).
+  #
+  # Clang will pick the best libgcc, sorting by version, but there doesn't
+  # seem to be a way to specify one or the other, except by passing
+  # -nostdlib and setting all the include paths from the command line.
+  # To check which one is being used you can run:
+  #   clang++-7 --target=x86_64-w64-mingw32 -v -print-libgcc-file-name
+  # We need to use the "posix" versions for thread support, so here we
+  # just remove the other one.
+  local target
+  for target in "${LIST_MINGW_TARGETS[@]}"; do
+    update-alternatives --set "${target}-gcc" $(which "${target}-gcc-posix")
+    local gcc_win32_path=$("${target}-cpp-win32" -print-libgcc-file-name)
+    rm -rf $(dirname "${gcc_win32_path}")
+  done
+
+  # TODO: Add msan for the target when cross-compiling. This only installs it
+  # for amd64.
+  ./msan_install.sh
+
+  # Build and install qemu user-linux targets.
+  ./qemu_install.sh
+
+  # Install emscripten SDK.
+  ./emsdk_install.sh
+
+  # Setup environment for building WASM libraries from sources.
+  source /opt/emsdk/emsdk_env.sh
+
+  # Install some dependency libraries manually for the different targets.
+
+  install_from_source JPEG_TURBO "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}"
+  install_from_source ZLIB "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}"
+  install_from_source LIBPNG "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}"
+  install_from_source GIFLIB "${LIST_MINGW_TARGETS[@]}" "${LIST_WASM_TARGETS[@]}"
+  # webp in Ubuntu is relatively old so we install it from source for
+  # everybody.
+  install_from_source WEBP "${LIST_TARGETS[@]}" "${LIST_MINGW_TARGETS[@]}"
+
+  install_from_source BENCHMARK "${LIST_TARGETS[@]}" "${LIST_MINGW_TARGETS[@]}"
+
+  # Install v8. v8 has better WASM SIMD support than NodeJS 14.
+  # First we need the installer to install v8.
+  npm install jsvu -g
+  # Install the specific version.
+  HOME=/opt jsvu --os=linux64 "v8@${V8_VERSION}"
+  ln -s "/opt/.jsvu/v8-${V8_VERSION}" "/opt/.jsvu/v8"
+
+  # Cleanup.
+  find /var/lib/apt/lists/ -mindepth 1 -delete
+}
+
+main "$@"
diff --git a/third_party/jpeg-xl/docker/scripts/msan_install.sh b/third_party/jpeg-xl/docker/scripts/msan_install.sh
new file mode 100644
index 000000000000..10c9eb65c20c
--- /dev/null
+++ b/third_party/jpeg-xl/docker/scripts/msan_install.sh
@@ -0,0 +1,140 @@
+#!/usr/bin/env bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eu
+
+MYDIR=$(dirname $(realpath "$0"))
+
+# Convenience flag to pass both CMAKE_C_FLAGS and CMAKE_CXX_FLAGS
+CMAKE_FLAGS=${CMAKE_FLAGS:-}
+CMAKE_C_FLAGS=${CMAKE_C_FLAGS:-${CMAKE_FLAGS}}
+CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS:-${CMAKE_FLAGS}}
+CMAKE_EXE_LINKER_FLAGS=${CMAKE_EXE_LINKER_FLAGS:-}
+
+CLANG_VERSION="${CLANG_VERSION:-}"
+# Detect the clang version suffix and store it in CLANG_VERSION.
+# For example, "6.0" for clang 6 or "7" for clang 7.
+detect_clang_version() {
+  if [[ -n "${CLANG_VERSION}" ]]; then
+    return 0
+  fi
+  local clang_version=$("${CC:-clang}" --version | head -n1)
+  case "${clang_version}" in
+    "clang version 6."*)
+      CLANG_VERSION="6.0"
+      ;;
+    "clang version 7."*)
+      CLANG_VERSION="7"
+      ;;
+    "clang version 8."*)
+      CLANG_VERSION="8"
+      ;;
+    "clang version 9."*)
+      CLANG_VERSION="9"
+      ;;
+    *)
+      echo "Unknown clang version: ${clang_version}" >&2
+      return 1
+  esac
+}
+
+# Temporary files cleanup hooks.
+CLEANUP_FILES=()
+cleanup() {
+  if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then
+    rm -fr "${CLEANUP_FILES[@]}"
+  fi
+}
+trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT
+
+# Install libc++ libraries compiled with msan in the msan_prefix for the
+# current compiler version.
+cmd_msan_install() {
+  local tmpdir=$(mktemp -d)
+  CLEANUP_FILES+=("${tmpdir}")
+  # Detect the llvm to install:
+  export CC="${CC:-clang}"
+  export CXX="${CXX:-clang++}"
+  detect_clang_version
+  local llvm_tag
+  case "${CLANG_VERSION}" in
+    "6.0")
+      llvm_tag="llvmorg-6.0.1"
+      ;;
+    "7")
+      llvm_tag="llvmorg-7.0.1"
+      ;;
+    "8")
+      llvm_tag="llvmorg-8.0.0"
+      ;;
+    *)
+      echo "Unknown clang version: ${CLANG_VERSION}" >&2
+      return 1
+  esac
+  local llvm_targz="${tmpdir}/${llvm_tag}.tar.gz"
+  curl -L --show-error -o "${llvm_targz}" \
+    "https://github.com/llvm/llvm-project/archive/${llvm_tag}.tar.gz"
+  tar -C "${tmpdir}" -zxf "${llvm_targz}"
+  local llvm_root="${tmpdir}/llvm-project-${llvm_tag}"
+
+  local msan_prefix="${HOME}/.msan/${CLANG_VERSION}"
+  rm -rf "${msan_prefix}"
+
+  declare -A CMAKE_EXTRAS
+  CMAKE_EXTRAS[libcxx]="\
+    -DLIBCXX_CXX_ABI=libstdc++ \
+    -DLIBCXX_INSTALL_EXPERIMENTAL_LIBRARY=ON"
+
+  for project in libcxx; do
+    local proj_build="${tmpdir}/build-${project}"
+    local proj_dir="${llvm_root}/${project}"
+    mkdir -p "${proj_build}"
+    cmake -B"${proj_build}" -H"${proj_dir}" \
+      -G Ninja \
+      -DCMAKE_BUILD_TYPE=Release \
+      -DLLVM_USE_SANITIZER=Memory \
+      -DLLVM_PATH="${llvm_root}/llvm" \
+      -DLLVM_CONFIG_PATH="$(which llvm-config llvm-config-7 llvm-config-6.0 | \
+        head -n1)" \
+      -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" \
+      -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS}" \
+      -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" \
+      -DCMAKE_INSTALL_PREFIX="${msan_prefix}" \
+      ${CMAKE_EXTRAS[${project}]}
+    cmake --build "${proj_build}"
+    ninja -C "${proj_build}" install
+  done
+}
+
+main() {
+  set -x
+  for version in 6.0 7 8; do
+    if ! which "clang-${version}" >/dev/null; then
+      echo "Skipping msan install for clang version ${version}"
+      continue
+    fi
+    (
+      trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT
+      export CLANG_VERSION=${version}
+      export CC=clang-${version}
+      export CXX=clang++-${version}
+      cmd_msan_install
+    ) &
+  done
+  wait
+}
+
+main "$@"
diff --git a/third_party/jpeg-xl/docker/scripts/qemu_install.sh b/third_party/jpeg-xl/docker/scripts/qemu_install.sh
new file mode 100644
index 000000000000..fdd99951fe77
--- /dev/null
+++ b/third_party/jpeg-xl/docker/scripts/qemu_install.sh
@@ -0,0 +1,92 @@
+#!/usr/bin/env bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +QEMU_RELEASE="4.1.0" +QEMU_URL="https://download.qemu.org/qemu-${QEMU_RELEASE}.tar.xz" +QEMU_ARCHS=( + aarch64 + arm + i386 + # TODO: Consider adding these: + # aarch64_be + # mips64el + # mips64 + # mips + # ppc64 + # ppc +) + +# Ubuntu packages not installed that are needed to build qemu. +QEMU_BUILD_DEPS=( + libglib2.0-dev + libpixman-1-dev + flex + bison +) + +set -eu -x + +# Temporary files cleanup hooks. +CLEANUP_FILES=() +cleanup() { + if [[ ${#CLEANUP_FILES[@]} -ne 0 ]]; then + rm -fr "${CLEANUP_FILES[@]}" + fi +} +trap "{ set +x; } 2>/dev/null; cleanup" INT TERM EXIT + +main() { + local workdir=$(mktemp -d --suffix=qemu) + CLEANUP_FILES+=("${workdir}") + + apt install -y "${QEMU_BUILD_DEPS[@]}" + + local qemutar="${workdir}/qemu.tar.gz" + curl --output "${qemutar}" "${QEMU_URL}" + tar -Jxf "${qemutar}" -C "${workdir}" + local srcdir="${workdir}/qemu-${QEMU_RELEASE}" + + local builddir="${workdir}/build" + local prefixdir="${workdir}/prefix" + mkdir -p "${builddir}" + + # List of targets to build. + local targets="" + local make_targets=() + local target + for target in "${QEMU_ARCHS[@]}"; do + targets="${targets} ${target}-linux-user" + # Build just the linux-user targets. + make_targets+=("${target}-linux-user/all") + done + + cd "${builddir}" + "${srcdir}/configure" \ + --prefix="${prefixdir}" \ + --static --disable-system --enable-linux-user \ + --target-list="${targets}" + + make -j $(nproc --all || echo 1) "${make_targets[@]}" + + # Manually install these into the non-standard location. This script runs as + # root anyway. + for target in "${QEMU_ARCHS[@]}"; do + cp "${target}-linux-user/qemu-${target}" "/usr/bin/qemu-${target}-static" + done + + apt autoremove -y --purge "${QEMU_BUILD_DEPS[@]}" +} + +main "$@" diff --git a/third_party/jpeg-xl/examples/CMakeLists.txt b/third_party/jpeg-xl/examples/CMakeLists.txt new file mode 100644 index 000000000000..22c1a632753a --- /dev/null +++ b/third_party/jpeg-xl/examples/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(decode_oneshot decode_oneshot.cc) +target_link_libraries(decode_oneshot jxl_dec jxl_threads) +add_executable(encode_oneshot encode_oneshot.cc) +target_link_libraries(encode_oneshot jxl jxl_threads) + +add_executable(jxlinfo jxlinfo.c) +target_link_libraries(jxlinfo jxl) + +if(NOT ${SANITIZER} STREQUAL "none") + # Linking a C test binary with the C++ JPEG XL implementation when using + # address sanitizer is not well supported by clang 9, so force using clang++ + # for linking this test if a sanitizer is used. 
+  set_target_properties(jxlinfo PROPERTIES LINKER_LANGUAGE CXX)
+endif()  # SANITIZER != "none"
diff --git a/third_party/jpeg-xl/examples/decode_oneshot.cc b/third_party/jpeg-xl/examples/decode_oneshot.cc
new file mode 100644
index 000000000000..24488eacb6fa
--- /dev/null
+++ b/third_party/jpeg-xl/examples/decode_oneshot.cc
@@ -0,0 +1,252 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This C++ example decodes a JPEG XL image in one shot (all input bytes
+// available at once). The example outputs the pixels and color information
+// to a floating point image and an ICC profile on disk.
+
+#include <limits.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <vector>
+
+#include "jxl/decode.h"
+#include "jxl/decode_cxx.h"
+#include "jxl/thread_parallel_runner.h"
+#include "jxl/thread_parallel_runner_cxx.h"
+
+/** Decodes JPEG XL image to floating point pixels and ICC Profile. Pixels are
+ * stored as floating point, as interleaved RGBA (4 floating point values per
+ * pixel), line per line from top to bottom. Pixel values have nominal range
+ * 0..1 but may go beyond this range for HDR or wide gamut. The ICC profile
+ * describes the color format of the pixel data.
+ */
+bool DecodeJpegXlOneShot(const uint8_t* jxl, size_t size,
+                         std::vector<float>* pixels, size_t* xsize,
+                         size_t* ysize, std::vector<uint8_t>* icc_profile) {
+  // Multi-threaded parallel runner.
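+  // (JxlThreadParallelRunner is libjxl's stock JxlParallelRunner
+  // implementation backed by a thread pool;
+  // JxlThreadParallelRunnerDefaultNumWorkerThreads() picks a worker count
+  // based on the available hardware threads, and the nullptr selects the
+  // default memory manager.)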
+ auto runner = JxlThreadParallelRunnerMake( + nullptr, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + + auto dec = JxlDecoderMake(nullptr); + if (JXL_DEC_SUCCESS != + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)) { + fprintf(stderr, "JxlDecoderSubscribeEvents failed\n"); + return false; + } + + if (JXL_DEC_SUCCESS != JxlDecoderSetParallelRunner(dec.get(), + JxlThreadParallelRunner, + runner.get())) { + fprintf(stderr, "JxlDecoderSetParallelRunner failed\n"); + return false; + } + + JxlBasicInfo info; + JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + + JxlDecoderSetInput(dec.get(), jxl, size); + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + + if (status == JXL_DEC_ERROR) { + fprintf(stderr, "Decoder error\n"); + return false; + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + fprintf(stderr, "Error, already provided all input\n"); + return false; + } else if (status == JXL_DEC_BASIC_INFO) { + if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec.get(), &info)) { + fprintf(stderr, "JxlDecoderGetBasicInfo failed\n"); + return false; + } + *xsize = info.xsize; + *ysize = info.ysize; + } else if (status == JXL_DEC_COLOR_ENCODING) { + // Get the ICC color profile of the pixel data + size_t icc_size; + if (JXL_DEC_SUCCESS != + JxlDecoderGetICCProfileSize( + dec.get(), &format, JXL_COLOR_PROFILE_TARGET_DATA, &icc_size)) { + fprintf(stderr, "JxlDecoderGetICCProfileSize failed\n"); + return false; + } + icc_profile->resize(icc_size); + if (JXL_DEC_SUCCESS != JxlDecoderGetColorAsICCProfile( + dec.get(), &format, + JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile->data(), icc_profile->size())) { + fprintf(stderr, "JxlDecoderGetColorAsICCProfile failed\n"); + return false; + } + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + size_t buffer_size; + if (JXL_DEC_SUCCESS != + JxlDecoderImageOutBufferSize(dec.get(), &format, &buffer_size)) { + fprintf(stderr, "JxlDecoderImageOutBufferSize failed\n"); + return false; + } + if (buffer_size != *xsize * *ysize * 16) { + fprintf(stderr, "Invalid out buffer size %zu %zu\n", buffer_size, + *xsize * *ysize * 16); + return false; + } + pixels->resize(*xsize * *ysize * 4); + void* pixels_buffer = (void*)pixels->data(); + size_t pixels_buffer_size = pixels->size() * sizeof(float); + if (JXL_DEC_SUCCESS != JxlDecoderSetImageOutBuffer(dec.get(), &format, + pixels_buffer, + pixels_buffer_size)) { + fprintf(stderr, "JxlDecoderSetImageOutBuffer failed\n"); + return false; + } + } else if (status == JXL_DEC_FULL_IMAGE) { + // Nothing to do. Do not yet return. If the image is an animation, more + // full frames may be decoded. This example only keeps the last one. + } else if (status == JXL_DEC_SUCCESS) { + // All decoding successfully finished. + // It's not required to call JxlDecoderReleaseInput(dec.get()) here since + // the decoder will be destroyed. + return true; + } else { + fprintf(stderr, "Unknown decoder status\n"); + return false; + } + } +} + +/** Writes to .pfm file (Portable FloatMap). Gimp, tev viewer and ImageMagick + * support viewing this format. + * The input pixels are given as 32-bit floating point with 4-channel RGBA. + * The alpha channel will not be written since .pfm does not support it. 
+ */
+bool WritePFM(const char* filename, const float* pixels, size_t xsize,
+              size_t ysize) {
+  FILE* file = fopen(filename, "wb");
+  if (!file) {
+    fprintf(stderr, "Could not open %s for writing", filename);
+    return false;
+  }
+  uint32_t endian_test = 1;
+  uint8_t little_endian[4];
+  memcpy(little_endian, &endian_test, 4);
+
+  fprintf(file, "PF\n%d %d\n%s\n", (int)xsize, (int)ysize,
+          little_endian[0] ? "-1.0" : "1.0");
+  for (int y = ysize - 1; y >= 0; y--) {
+    for (size_t x = 0; x < xsize; x++) {
+      for (size_t c = 0; c < 3; c++) {
+        const float* f = &pixels[(y * xsize + x) * 4 + c];
+        fwrite(f, 4, 1, file);
+      }
+    }
+  }
+  if (fclose(file) != 0) {
+    return false;
+  }
+  return true;
+}
+
+bool LoadFile(const char* filename, std::vector<uint8_t>* out) {
+  FILE* file = fopen(filename, "rb");
+  if (!file) {
+    return false;
+  }
+
+  if (fseek(file, 0, SEEK_END) != 0) {
+    fclose(file);
+    return false;
+  }
+
+  long size = ftell(file);
+  // Avoid invalid file or directory.
+  if (size >= LONG_MAX || size < 0) {
+    fclose(file);
+    return false;
+  }
+
+  if (fseek(file, 0, SEEK_SET) != 0) {
+    fclose(file);
+    return false;
+  }
+
+  out->resize(size);
+  size_t readsize = fread(out->data(), 1, size, file);
+  if (fclose(file) != 0) {
+    return false;
+  }
+
+  return readsize == static_cast<size_t>(size);
+}
+
+bool WriteFile(const char* filename, const uint8_t* data, size_t size) {
+  FILE* file = fopen(filename, "wb");
+  if (!file) {
+    fprintf(stderr, "Could not open %s for writing", filename);
+    return false;
+  }
+  fwrite(data, 1, size, file);
+  if (fclose(file) != 0) {
+    return false;
+  }
+  return true;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc != 4) {
+    fprintf(stderr,
+            "Usage: %s <jxl> <pfm> <icc>\n"
+            "Where:\n"
+            "  jxl = input JPEG XL image filename\n"
+            "  pfm = output Portable FloatMap image filename\n"
+            "  icc = output ICC color profile filename\n"
+            "Output files will be overwritten.\n",
+            argv[0]);
+    return 1;
+  }
+
+  const char* jxl_filename = argv[1];
+  const char* pfm_filename = argv[2];
+  const char* icc_filename = argv[3];
+
+  std::vector<uint8_t> jxl;
+  if (!LoadFile(jxl_filename, &jxl)) {
+    fprintf(stderr, "couldn't load %s\n", jxl_filename);
+    return 1;
+  }
+
+  std::vector<float> pixels;
+  std::vector<uint8_t> icc_profile;
+  size_t xsize = 0, ysize = 0;
+  if (!DecodeJpegXlOneShot(jxl.data(), jxl.size(), &pixels, &xsize, &ysize,
+                           &icc_profile)) {
+    fprintf(stderr, "Error while decoding the jxl file\n");
+    return 1;
+  }
+  if (!WritePFM(pfm_filename, pixels.data(), xsize, ysize)) {
+    fprintf(stderr, "Error while writing the PFM image file\n");
+    return 1;
+  }
+  if (!WriteFile(icc_filename, icc_profile.data(), icc_profile.size())) {
+    fprintf(stderr, "Error while writing the ICC profile file\n");
+    return 1;
+  }
+  printf("Successfully wrote %s and %s\n", pfm_filename, icc_filename);
+  return 0;
+}
diff --git a/third_party/jpeg-xl/examples/encode_oneshot.cc b/third_party/jpeg-xl/examples/encode_oneshot.cc
new file mode 100644
index 000000000000..f0baea56ea0a
--- /dev/null
+++ b/third_party/jpeg-xl/examples/encode_oneshot.cc
@@ -0,0 +1,282 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This example encodes a file containing a floating point image to another
+// file containing JPEG XL image with a single frame.
+
+#include <limits.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "jxl/encode.h"
+#include "jxl/encode_cxx.h"
+#include "jxl/thread_parallel_runner.h"
+#include "jxl/thread_parallel_runner_cxx.h"
+
+/**
+ * Reads from .pfm file (Portable FloatMap)
+ *
+ * @param filename name of the file to read
+ * @param pixels vector to fill with loaded pixels as 32-bit floating point
+ * with 3-channel RGB
+ * @param xsize set to width of loaded image
+ * @param ysize set to height of loaded image
+ */
+bool ReadPFM(const char* filename, std::vector<float>* pixels, uint32_t* xsize,
+             uint32_t* ysize) {
+  FILE* file = fopen(filename, "rb");
+  if (!file) {
+    fprintf(stderr, "Could not open %s for reading.\n", filename);
+    return false;
+  }
+  uint32_t endian_test = 1;
+  uint8_t little_endian[4];
+  memcpy(little_endian, &endian_test, 4);
+
+  if (fseek(file, 0, SEEK_END) != 0) {
+    fclose(file);
+    return false;
+  }
+
+  long size = ftell(file);
+  // Avoid invalid file or directory.
+  if (size >= LONG_MAX || size < 0) {
+    fclose(file);
+    return false;
+  }
+
+  if (fseek(file, 0, SEEK_SET) != 0) {
+    fclose(file);
+    return false;
+  }
+
+  std::vector<char> data;
+  data.resize(size);
+
+  size_t readsize = fread(data.data(), 1, size, file);
+  if ((long)readsize != size) {
+    return false;
+  }
+  if (fclose(file) != 0) {
+    return false;
+  }
+
+  std::stringstream datastream;
+  std::string datastream_content(data.data(), data.size());
+  datastream.str(datastream_content);
+
+  std::string pf_token;
+  getline(datastream, pf_token, '\n');
+  if (pf_token != "PF") {
+    fprintf(stderr,
+            "%s doesn't seem to be a 3 channel Portable FloatMap file "
+            "(missing 'PF\\n' bytes).\n",
+            filename);
+    return false;
+  }
+
+  std::string xsize_token;
+  getline(datastream, xsize_token, ' ');
+  *xsize = std::stoi(xsize_token);
+
+  std::string ysize_token;
+  getline(datastream, ysize_token, '\n');
+  *ysize = std::stoi(ysize_token);
+
+  std::string endianness_token;
+  getline(datastream, endianness_token, '\n');
+  bool input_little_endian;
+  if (endianness_token == "1.0") {
+    input_little_endian = false;
+  } else if (endianness_token == "-1.0") {
+    input_little_endian = true;
+  } else {
+    fprintf(stderr,
+            "%s doesn't seem to be a Portable FloatMap file (endianness token "
+            "isn't '1.0' or '-1.0').\n",
+            filename);
+    return false;
+  }
+
+  size_t offset = pf_token.size() + 1 + xsize_token.size() + 1 +
+                  ysize_token.size() + 1 + endianness_token.size() + 1;
+
+  if (data.size() != *ysize * *xsize * 3 * 4 + offset) {
+    fprintf(stderr,
+            "%s doesn't seem to be a Portable FloatMap file (pixel data bytes "
+            "are %d, but expected %d * %d * 3 * 4 + %d (%d).\n",
+            filename, (int)data.size(), (int)*ysize, (int)*xsize, (int)offset,
+            (int)(*ysize * *xsize * 3 * 4 + offset));
+    return false;
+  }
+
+  if (!!little_endian[0] != input_little_endian) {
+    fprintf(stderr,
+            "%s has a different endianness than we do, conversion is not "
+            "supported.\n",
+            filename);
+    return false;
+  }
+
+  pixels->resize(*ysize * *xsize * 3);
+
+  for (int y = *ysize - 1; y >= 0; y--) {
+    for (int x = 0; x < (int)*xsize; x++) {
+      for (int c = 0; c < 3; c++) {
+        memcpy(pixels->data() + (y * *xsize + x) * 3 + c, data.data() + offset,
+               sizeof(float));
+        offset += sizeof(float);
+      }
+    }
+  }
+
+  return true;
+}
+
+/**
+ * Compresses the provided pixels.
+ *
+ * @param pixels input pixels
+ * @param xsize width of the input image
+ * @param ysize height of the input image
+ * @param compressed will be populated with the compressed bytes
+ */
+bool EncodeJxlOneshot(const std::vector<float>& pixels, const uint32_t xsize,
+                      const uint32_t ysize, std::vector<uint8_t>* compressed) {
+  auto enc = JxlEncoderMake(/*memory_manager=*/nullptr);
+  auto runner = JxlThreadParallelRunnerMake(
+      /*memory_manager=*/nullptr,
+      JxlThreadParallelRunnerDefaultNumWorkerThreads());
+  if (JXL_ENC_SUCCESS != JxlEncoderSetParallelRunner(enc.get(),
+                                                     JxlThreadParallelRunner,
+                                                     runner.get())) {
+    fprintf(stderr, "JxlEncoderSetParallelRunner failed\n");
+    return false;
+  }
+
+  JxlPixelFormat pixel_format = {3, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
+
+  JxlBasicInfo basic_info = {};
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  basic_info.bits_per_sample = 32;
+  basic_info.exponent_bits_per_sample = 8;
+  basic_info.alpha_exponent_bits = 0;
+  basic_info.alpha_bits = 0;
+  basic_info.uses_original_profile = JXL_FALSE;
+  if (JXL_ENC_SUCCESS != JxlEncoderSetBasicInfo(enc.get(), &basic_info)) {
+    fprintf(stderr, "JxlEncoderSetBasicInfo failed\n");
+    return false;
+  }
+
+  JxlColorEncoding color_encoding = {};
+  JxlColorEncodingSetToSRGB(&color_encoding,
+                            /*is_gray=*/pixel_format.num_channels < 3);
+  if (JXL_ENC_SUCCESS !=
+      JxlEncoderSetColorEncoding(enc.get(), &color_encoding)) {
+    fprintf(stderr, "JxlEncoderSetColorEncoding failed\n");
+    return false;
+  }
+
+  if (JXL_ENC_SUCCESS !=
+      JxlEncoderAddImageFrame(JxlEncoderOptionsCreate(enc.get(), nullptr),
+                              &pixel_format, (void*)pixels.data(),
+                              sizeof(float) * pixels.size())) {
+    fprintf(stderr, "JxlEncoderAddImageFrame failed\n");
+    return false;
+  }
+
+  compressed->resize(64);
+  uint8_t* next_out = compressed->data();
+  size_t avail_out = compressed->size() - (next_out - compressed->data());
+  JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT;
+  while (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+    process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out);
+    if (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+      size_t offset = next_out - compressed->data();
+      compressed->resize(compressed->size() * 2);
+      next_out = compressed->data() + offset;
+      avail_out = compressed->size() - offset;
+    }
+  }
+  compressed->resize(next_out - compressed->data());
+  if (JXL_ENC_SUCCESS != process_result) {
+    fprintf(stderr, "JxlEncoderProcessOutput failed\n");
+    return false;
+  }
+
+  return true;
+}
+
+/**
+ * Writes bytes to file.
+ */
+bool WriteFile(const std::vector<uint8_t>& bytes, const char* filename) {
+  FILE* file = fopen(filename, "wb");
+  if (!file) {
+    fprintf(stderr, "Could not open %s for writing\n", filename);
+    return false;
+  }
+  if (fwrite(bytes.data(), sizeof(uint8_t), bytes.size(), file) !=
+      bytes.size()) {
+    fprintf(stderr, "Could not write bytes to %s\n", filename);
+    return false;
+  }
+  if (fclose(file) != 0) {
+    fprintf(stderr, "Could not close %s\n", filename);
+    return false;
+  }
+  return true;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc != 3) {
+    fprintf(stderr,
+            "Usage: %s <pfm> <jxl>\n"
+            "Where:\n"
+            "  pfm = input Portable FloatMap image filename\n"
+            "  jxl = output JPEG XL image filename\n"
+            "Output files will be overwritten.\n",
+            argv[0]);
+    return 1;
+  }
+
+  const char* pfm_filename = argv[1];
+  const char* jxl_filename = argv[2];
+
+  std::vector<float> pixels;
+  uint32_t xsize;
+  uint32_t ysize;
+  if (!ReadPFM(pfm_filename, &pixels, &xsize, &ysize)) {
+    fprintf(stderr, "Couldn't load %s\n", pfm_filename);
+    return 2;
+  }
+
+  std::vector<uint8_t> compressed;
+  if (!EncodeJxlOneshot(pixels, xsize, ysize, &compressed)) {
+    fprintf(stderr, "Couldn't encode jxl\n");
+    return 3;
+  }
+
+  if (!WriteFile(compressed, jxl_filename)) {
+    fprintf(stderr, "Couldn't write jxl file\n");
+    return 4;
+  }
+
+  return 0;
+}
diff --git a/third_party/jpeg-xl/examples/jxlinfo.c b/third_party/jpeg-xl/examples/jxlinfo.c
new file mode 100644
index 000000000000..06a144f30f8b
--- /dev/null
+++ b/third_party/jpeg-xl/examples/jxlinfo.c
@@ -0,0 +1,247 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This example prints information from the main codestream header.
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "jxl/decode.h"
+
+int PrintBasicInfo(FILE* file) {
+  uint8_t* data = NULL;
+  size_t data_size = 0;
+  // In how large chunks to read from the file and try decoding the basic info.
+  const size_t chunk_size = 64;
+
+  JxlDecoder* dec = JxlDecoderCreate(NULL);
+  if (!dec) {
+    fprintf(stderr, "JxlDecoderCreate failed\n");
+    return 0;
+  }
+
+  JxlDecoderSetKeepOrientation(dec, 1);
+
+  if (JXL_DEC_SUCCESS !=
+      JxlDecoderSubscribeEvents(dec,
+                                JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)) {
+    fprintf(stderr, "JxlDecoderSubscribeEvents failed\n");
+    JxlDecoderDestroy(dec);
+    return 0;
+  }
+
+  JxlBasicInfo info;
+  int seen_basic_info = 0;
+
+  for (;;) {
+    // The first time, this will output JXL_DEC_NEED_MORE_INPUT because no
+    // input is set yet, this is ok since the input is set when handling this
+    // event.
+    JxlDecoderStatus status = JxlDecoderProcessInput(dec);
+
+    if (status == JXL_DEC_ERROR) {
+      fprintf(stderr, "Decoder error\n");
+      break;
+    } else if (status == JXL_DEC_NEED_MORE_INPUT) {
+      // The first time there is nothing to release and it returns 0, but that
+      // is ok.
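+      // (JxlDecoderSetInput does not copy the buffer, so `data` must stay
+      // alive and unchanged until JxlDecoderReleaseInput is called, which is
+      // why the buffer is only reallocated right after the release below.)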
+      size_t remaining = JxlDecoderReleaseInput(dec);
+      // move any remaining bytes to the front if necessary
+      if (remaining != 0) {
+        memmove(data, data + data_size - remaining, remaining);
+      }
+      // resize the buffer to append one more chunk of data
+      // TODO(lode): avoid unnecessary reallocations
+      data = (uint8_t*)realloc(data, remaining + chunk_size);
+      // append bytes read from the file behind the remaining bytes
+      size_t read_size = fread(data + remaining, 1, chunk_size, file);
+      data_size = remaining + read_size;
+      JxlDecoderSetInput(dec, data, data_size);
+    } else if (status == JXL_DEC_SUCCESS) {
+      // Finished all processing.
+      break;
+    } else if (status == JXL_DEC_BASIC_INFO) {
+      if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec, &info)) {
+        fprintf(stderr, "JxlDecoderGetBasicInfo failed\n");
+        break;
+      }
+
+      seen_basic_info = 1;
+
+      printf("xsize: %u\n", info.xsize);
+      printf("ysize: %u\n", info.ysize);
+      printf("have_container: %d\n", info.have_container);
+      printf("uses_original_profile: %d\n", info.uses_original_profile);
+      printf("bits_per_sample: %d\n", info.bits_per_sample);
+      printf("exponent_bits_per_sample: %d\n", info.exponent_bits_per_sample);
+      printf("intensity_target: %f\n", info.intensity_target);
+      printf("min_nits: %f\n", info.min_nits);
+      printf("relative_to_max_display: %d\n", info.relative_to_max_display);
+      printf("linear_below: %f\n", info.linear_below);
+      printf("have_preview: %d\n", info.have_preview);
+      if (info.have_preview) {
+        printf("preview xsize: %u\n", info.preview.xsize);
+        printf("preview ysize: %u\n", info.preview.ysize);
+      }
+      printf("have_animation: %d\n", info.have_animation);
+      if (info.have_animation) {
+        printf("ticks per second (numerator / denominator): %u / %u\n",
+               info.animation.tps_numerator, info.animation.tps_denominator);
+        printf("num_loops: %u\n", info.animation.num_loops);
+        printf("have_timecodes: %d\n", info.animation.have_timecodes);
+      }
+      printf("orientation: %d\n", info.orientation);
+      printf("num_extra_channels: %d\n", info.num_extra_channels);
+      printf("alpha_bits: %d\n", info.alpha_bits);
+      printf("alpha_exponent_bits: %d\n", info.alpha_exponent_bits);
+      printf("alpha_premultiplied: %d\n", info.alpha_premultiplied);
+
+      for (uint32_t i = 0; i < info.num_extra_channels; i++) {
+        JxlExtraChannelInfo extra;
+        if (JXL_DEC_SUCCESS != JxlDecoderGetExtraChannelInfo(dec, i, &extra)) {
+          fprintf(stderr, "JxlDecoderGetExtraChannelInfo failed\n");
+          break;
+        }
+        printf("extra channel: %u info:\n", i);
+        printf("  type: %d\n", extra.type);
+        printf("  bits_per_sample: %u\n", extra.bits_per_sample);
+        printf("  exponent_bits_per_sample: %u\n",
+               extra.exponent_bits_per_sample);
+        printf("  dim_shift: %u\n", extra.dim_shift);
+        printf("  name_length: %u\n", extra.name_length);
+        if (extra.name_length) {
+          char* name = malloc(extra.name_length + 1);
+          if (JXL_DEC_SUCCESS != JxlDecoderGetExtraChannelName(
+                                     dec, i, name, extra.name_length + 1)) {
+            fprintf(stderr, "JxlDecoderGetExtraChannelName failed\n");
+            free(name);
+            break;
+          }
+          // Print the name before freeing the buffer.
+          printf("  name: %s\n", name);
+          free(name);
+        }
+        printf("  alpha_associated: %d\n", extra.alpha_associated);
+        printf("  spot_color: %f %f %f %f\n", extra.spot_color[0],
+               extra.spot_color[1], extra.spot_color[2], extra.spot_color[3]);
+        printf("  cfa_channel: %u\n", extra.cfa_channel);
+      }
+    } else if (status == JXL_DEC_COLOR_ENCODING) {
+      JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0};
+      JxlColorProfileTarget targets[2] = {JXL_COLOR_PROFILE_TARGET_ORIGINAL,
+                                          JXL_COLOR_PROFILE_TARGET_DATA};
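+      // (JXL_COLOR_PROFILE_TARGET_ORIGINAL refers to the color profile of
+      // the original image in the codestream, while
+      // JXL_COLOR_PROFILE_TARGET_DATA describes the pixel data the decoder
+      // actually outputs; the two differ when the decoder performs a color
+      // conversion internally.)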
+      for (size_t i = 0; i < 2; i++) {
+        JxlColorProfileTarget target = targets[i];
+        if (info.uses_original_profile) {
+          if (target != JXL_COLOR_PROFILE_TARGET_ORIGINAL) continue;
+          printf("color profile:\n");
+        } else {
+          printf(target == JXL_COLOR_PROFILE_TARGET_ORIGINAL
+                     ? "original color profile:\n"
+                     : "data color profile:\n");
+        }
+
+        JxlColorEncoding color_encoding;
+        if (JXL_DEC_SUCCESS == JxlDecoderGetColorAsEncodedProfile(
+                                   dec, &format, target, &color_encoding)) {
+          printf("  format: JPEG XL encoded color profile\n");
+          printf("  color_space: %d\n", color_encoding.color_space);
+          printf("  white_point: %d\n", color_encoding.white_point);
+          printf("  white_point XY: %f %f\n", color_encoding.white_point_xy[0],
+                 color_encoding.white_point_xy[1]);
+          if (color_encoding.color_space == JXL_COLOR_SPACE_RGB ||
+              color_encoding.color_space == JXL_COLOR_SPACE_UNKNOWN) {
+            printf("  primaries: %d\n", color_encoding.primaries);
+            printf("  red primaries XY: %f %f\n",
+                   color_encoding.primaries_red_xy[0],
+                   color_encoding.primaries_red_xy[1]);
+            printf("  green primaries XY: %f %f\n",
+                   color_encoding.primaries_green_xy[0],
+                   color_encoding.primaries_green_xy[1]);
+            printf("  blue primaries XY: %f %f\n",
+                   color_encoding.primaries_blue_xy[0],
+                   color_encoding.primaries_blue_xy[1]);
+          }
+          printf("  transfer_function: %d\n", color_encoding.transfer_function);
+          if (color_encoding.transfer_function == JXL_TRANSFER_FUNCTION_GAMMA) {
+            printf("  transfer_function gamma: %f\n", color_encoding.gamma);
+          }
+          printf("  rendering_intent: %d\n", color_encoding.rendering_intent);
+        } else {
+          // The profile is not in JPEG XL encoded form, get as ICC profile
+          // instead.
+          printf("  format: ICC profile\n");
+          size_t profile_size;
+          if (JXL_DEC_SUCCESS != JxlDecoderGetICCProfileSize(
+                                     dec, &format, target, &profile_size)) {
+            fprintf(stderr, "JxlDecoderGetICCProfileSize failed\n");
+            continue;
+          }
+          printf("  ICC profile size: %zu\n", profile_size);
+          if (profile_size < 132) {
+            fprintf(stderr, "ICC profile too small\n");
+            continue;
+          }
+          uint8_t* profile = (uint8_t*)malloc(profile_size);
+          if (JXL_DEC_SUCCESS != JxlDecoderGetColorAsICCProfile(
+                                     dec, &format, target, profile,
+                                     profile_size)) {
+            fprintf(stderr, "JxlDecoderGetColorAsICCProfile failed\n");
+            free(profile);
+            continue;
+          }
+          printf("  CMM type: \"%.4s\"\n", profile + 4);
+          printf("  color space: \"%.4s\"\n", profile + 16);
+          printf("  rendering intent: %d\n", (int)profile[67]);
+          free(profile);
+        }
+      }
+      // This is the last expected event, no need to read the rest of the file.
+    } else {
+      fprintf(stderr, "Unexpected decoder status\n");
+      break;
+    }
+  }
+
+  JxlDecoderDestroy(dec);
+  free(data);
+
+  return seen_basic_info;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc != 2) {
+    fprintf(stderr,
+            "Usage: %s <jxl>\n"
+            "Where:\n"
+            "  jxl = input JPEG XL image filename\n",
+            argv[0]);
+    return 1;
+  }
+
+  const char* jxl_filename = argv[1];
+
+  FILE* file = fopen(jxl_filename, "rb");
+  if (!file) {
+    fprintf(stderr, "Failed to read file %s\n", jxl_filename);
+    return 1;
+  }
+
+  if (!PrintBasicInfo(file)) {
+    fprintf(stderr, "Couldn't print basic info\n");
+    return 1;
+  }
+
+  return 0;
+}
diff --git a/third_party/jpeg-xl/js-wasm-wrapper.sh b/third_party/jpeg-xl/js-wasm-wrapper.sh
new file mode 100644
index 000000000000..fd1ee9843945
--- /dev/null
+++ b/third_party/jpeg-xl/js-wasm-wrapper.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Continuous integration helper module. This module is meant to be called from
+# the .gitlab-ci.yml file during the continuous integration build, as well as
+# from the command line for developers.
+
+# This wrapper is used to enable WASM SIMD when running tests.
+# Unfortunately, it is impossible to pass the option directly via the
+# CMAKE_CROSSCOMPILING_EMULATOR variable.
+
+# Fall back to the default v8 binary if the override is not set.
+V8="${V8:-$(which v8)}"
+SCRIPT="$1"
+shift
+"${V8}" --experimental-wasm-simd "${SCRIPT}" -- "$@"
diff --git a/third_party/jpeg-xl/lib/CMakeLists.txt b/third_party/jpeg-xl/lib/CMakeLists.txt
new file mode 100644
index 000000000000..02f2f2e3a651
--- /dev/null
+++ b/third_party/jpeg-xl/lib/CMakeLists.txt
@@ -0,0 +1,156 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(JPEGXL_MAJOR_VERSION 0)
+set(JPEGXL_MINOR_VERSION 3)
+set(JPEGXL_PATCH_VERSION 7)
+set(JPEGXL_LIBRARY_VERSION
+    "${JPEGXL_MAJOR_VERSION}.${JPEGXL_MINOR_VERSION}.${JPEGXL_PATCH_VERSION}")
+
+# This is the library API compatibility version.
+set(JPEGXL_LIBRARY_SOVERSION "${JPEGXL_MAJOR_VERSION}")
+
+
+# List of warning and feature flags for our library and tests.
+if (MSVC)
+set(JPEGXL_INTERNAL_FLAGS
+  # TODO(janwas): add flags
+)
+else ()
+set(JPEGXL_INTERNAL_FLAGS
+  # F_FLAGS
+  -fmerge-all-constants
+  -fno-builtin-fwrite
+  -fno-builtin-fread
+
+  # WARN_FLAGS
+  -Wall
+  -Wextra
+  -Wc++11-compat
+  -Warray-bounds
+  -Wformat-security
+  -Wimplicit-fallthrough
+  -Wno-register  # Needed by public headers in lcms
+  -Wno-unused-function
+  -Wno-unused-parameter
+  -Wnon-virtual-dtor
+  -Woverloaded-virtual
+  -Wvla
+)
+
+# Warning flags supported by clang.
+if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+  list(APPEND JPEGXL_INTERNAL_FLAGS
+    -Wc++2a-extensions
+    -Wdeprecated-increment-bool
+    # TODO(deymo): Add -Wextra-semi once we update third_party/highway.
+    # -Wextra-semi
+    -Wfloat-overflow-conversion
+    -Wfloat-zero-conversion
+    -Wfor-loop-analysis
+    -Wgnu-redeclared-enum
+    -Winfinite-recursion
+    -Wliteral-conversion
+    -Wno-c++98-compat
+    -Wno-unused-command-line-argument
+    -Wprivate-header
+    -Wself-assign
+    -Wstring-conversion
+    -Wtautological-overlap-compare
+    -Wthread-safety-analysis
+    -Wundefined-func-template
+    -Wunreachable-code
+    -Wunused-comparison
+  )
+endif()  # Clang
+
+if (WIN32)
+  list(APPEND JPEGXL_INTERNAL_FLAGS
+    -Wno-c++98-compat-pedantic
+    -Wno-cast-align
+    -Wno-double-promotion
+    -Wno-float-equal
+    -Wno-format-nonliteral
+    -Wno-global-constructors
+    -Wno-language-extension-token
+    -Wno-missing-prototypes
+    -Wno-shadow
+    -Wno-shadow-field-in-constructor
+    -Wno-sign-conversion
+    -Wno-unused-member-function
+    -Wno-unused-template
+    -Wno-used-but-marked-unused
+    -Wno-zero-as-null-pointer-constant
+  )
+else()  # WIN32
+  list(APPEND JPEGXL_INTERNAL_FLAGS
+    -fsized-deallocation
+    -fno-exceptions
+
+    # Language flags
+    -fmath-errno
+  )
+
+  if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+    list(APPEND JPEGXL_INTERNAL_FLAGS
+      -fnew-alignment=8
+      -fno-cxx-exceptions
+      -fno-slp-vectorize
+      -fno-vectorize
+
+      -disable-free
+      -disable-llvm-verifier
+    )
+  endif()  # Clang
+endif()  # WIN32
+
+# Internal flags for coverage builds:
+if(JPEGXL_ENABLE_COVERAGE)
+set(JPEGXL_COVERAGE_FLAGS
+    -g -O0 -fprofile-arcs -ftest-coverage -DJXL_DISABLE_SLOW_TESTS
+    -DJXL_ENABLE_ASSERT=0 -DJXL_ENABLE_CHECK=0
+)
+endif()  # JPEGXL_ENABLE_COVERAGE
+endif()  # !MSVC
+
+# The jxl library definition.
+include(jxl.cmake)
+
+# Other libraries outside the core jxl library.
+include(jxl_extras.cmake)
+include(jxl_threads.cmake)
+
+# Install all the library headers from the source and the generated ones. There
+# is no distinction on which libraries use which header since it is expected
+# that all developer libraries are available together at build time.
+install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/jxl
+        DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}")
+install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/include/jxl
+        DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}")
+
+if(BUILD_TESTING)
+# Unittests
+cmake_policy(SET CMP0057 NEW)  # https://gitlab.kitware.com/cmake/cmake/issues/18198
+include(GoogleTest)
+
+# Tests for the jxl library.
+include(jxl_tests.cmake)
+
+# Google benchmark for the jxl library
+include(jxl_benchmark.cmake)
+
+# Profiler for libjxl
+include(jxl_profiler.cmake)
+
+endif()  # BUILD_TESTING
diff --git a/third_party/jpeg-xl/lib/extras/README.md b/third_party/jpeg-xl/lib/extras/README.md
new file mode 100644
index 000000000000..06a9b5ea0782
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/README.md
@@ -0,0 +1,5 @@
+## JPEG XL "extras"
+
+The files in this directory do not form part of the library or codec and are
+only used by tests or specific internal tools that have access to the internals
+of the library.
diff --git a/third_party/jpeg-xl/lib/extras/codec.cc b/third_party/jpeg-xl/lib/extras/codec.cc
new file mode 100644
index 000000000000..72c928af9113
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec.cc
@@ -0,0 +1,235 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/codec.h"
+
+#include "lib/jxl/base/file_io.h"
+#if JPEGXL_ENABLE_APNG
+#include "lib/extras/codec_apng.h"
+#endif
+#if JPEGXL_ENABLE_EXR
+#include "lib/extras/codec_exr.h"
+#endif
+#if JPEGXL_ENABLE_GIF
+#include "lib/extras/codec_gif.h"
+#endif
+#include "lib/extras/codec_jpg.h"
+#include "lib/extras/codec_pgx.h"
+#include "lib/extras/codec_png.h"
+#include "lib/extras/codec_pnm.h"
+#include "lib/extras/codec_psd.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+namespace {
+
+// Any valid encoding is larger (ensures codecs can read the first few bytes)
+constexpr size_t kMinBytes = 9;
+
+}  // namespace
+
+std::string ExtensionFromCodec(Codec codec, const bool is_gray,
+                               const size_t bits_per_sample) {
+  switch (codec) {
+    case Codec::kJPG:
+      return ".jpg";
+    case Codec::kPGX:
+      return ".pgx";
+    case Codec::kPNG:
+      return ".png";
+    case Codec::kPNM:
+      if (is_gray) return ".pgm";
+      return (bits_per_sample == 32) ? ".pfm" : ".ppm";
+    case Codec::kGIF:
+      return ".gif";
+    case Codec::kEXR:
+      return ".exr";
+    case Codec::kPSD:
+      return ".psd";
+    case Codec::kUnknown:
+      return std::string();
+  }
+  JXL_UNREACHABLE;
+  return std::string();
+}
+
+Codec CodecFromExtension(const std::string& extension,
+                         size_t* JXL_RESTRICT bits_per_sample) {
+  if (extension == ".png") return Codec::kPNG;
+
+  if (extension == ".jpg") return Codec::kJPG;
+  if (extension == ".jpeg") return Codec::kJPG;
+
+  if (extension == ".pgx") return Codec::kPGX;
+
+  if (extension == ".pbm") {
+    *bits_per_sample = 1;
+    return Codec::kPNM;
+  }
+  if (extension == ".pgm") return Codec::kPNM;
+  if (extension == ".ppm") return Codec::kPNM;
+  if (extension == ".pfm") {
+    *bits_per_sample = 32;
+    return Codec::kPNM;
+  }
+
+  if (extension == ".gif") return Codec::kGIF;
+
+  if (extension == ".exr") return Codec::kEXR;
+
+  if (extension == ".psd") return Codec::kPSD;
+
+  return Codec::kUnknown;
+}
+
+Status SetFromBytes(const Span<const uint8_t> bytes, CodecInOut* io,
+                    ThreadPool* pool, Codec* orig_codec) {
+  if (bytes.size() < kMinBytes) return JXL_FAILURE("Too few bytes");
+
+  io->metadata.m.bit_depth.bits_per_sample = 0;  // (For is-set check below)
+
+  Codec codec;
+  if (DecodeImagePNG(bytes, pool, io)) {
+    codec = Codec::kPNG;
+  }
+#if JPEGXL_ENABLE_APNG
+  else if (DecodeImageAPNG(bytes, pool, io)) {
+    codec = Codec::kPNG;
+  }
+#endif
+  else if (DecodeImagePGX(bytes, pool, io)) {
+    codec = Codec::kPGX;
+  } else if (DecodeImagePNM(bytes, pool, io)) {
+    codec = Codec::kPNM;
+  }
+#if JPEGXL_ENABLE_GIF
+  else if (DecodeImageGIF(bytes, pool, io)) {
+    codec = Codec::kGIF;
+  }
+#endif
+  else if (DecodeImageJPG(bytes, pool, io)) {
+    codec = Codec::kJPG;
+  } else if (DecodeImagePSD(bytes, pool, io)) {
+    codec = Codec::kPSD;
+  }
+#if JPEGXL_ENABLE_EXR
+  else if (DecodeImageEXR(bytes, pool, io)) {
+    codec = Codec::kEXR;
+  }
+#endif
+  else {
+    return JXL_FAILURE("Codecs failed to decode");
+  }
+  if (orig_codec) *orig_codec = codec;
+
+  io->CheckMetadata();
+  return true;
+}
+
+Status SetFromFile(const std::string& pathname, CodecInOut* io,
+                   ThreadPool* pool, Codec* orig_codec) {
+  PaddedBytes encoded;
+  JXL_RETURN_IF_ERROR(ReadFile(pathname, &encoded));
+  JXL_RETURN_IF_ERROR(
+      SetFromBytes(Span<const uint8_t>(encoded), io, pool, orig_codec));
+  return true;
+}
+
+Status Encode(const CodecInOut& io, const Codec codec,
+              const ColorEncoding& c_desired, size_t bits_per_sample,
+              PaddedBytes* bytes, ThreadPool* pool) {
+  JXL_CHECK(!io.Main().c_current().ICC().empty());
+  JXL_CHECK(!c_desired.ICC().empty());
+  io.CheckMetadata();
+  if (io.Main().IsJPEG() && codec != Codec::kJPG) {
+    return JXL_FAILURE(
+        "Output format has to be JPEG for losslessly recompressed JPEG "
+        "reconstruction");
+  }
+
+  switch (codec) {
+    case Codec::kPNG:
+      return EncodeImagePNG(&io, c_desired, bits_per_sample, pool, bytes);
+    case Codec::kJPG:
+#if JPEGXL_ENABLE_JPEG
+      return EncodeImageJPG(
+          &io, io.use_sjpeg ? JpegEncoder::kSJpeg : JpegEncoder::kLibJpeg,
+          io.jpeg_quality, YCbCrChromaSubsampling(), pool, bytes,
+          io.Main().IsJPEG() ? DecodeTarget::kQuantizedCoeffs
+                             : DecodeTarget::kPixels);
+#else
+      return JXL_FAILURE("JPEG XL was built without JPEG support");
+#endif
+    case Codec::kPNM:
+      return EncodeImagePNM(&io, c_desired, bits_per_sample, pool, bytes);
+    case Codec::kPGX:
+      return EncodeImagePGX(&io, c_desired, bits_per_sample, pool, bytes);
+    case Codec::kGIF:
+      return JXL_FAILURE("Encoding to GIF is not implemented");
+    case Codec::kPSD:
+      return EncodeImagePSD(&io, c_desired, bits_per_sample, pool, bytes);
+    case Codec::kEXR:
+#if JPEGXL_ENABLE_EXR
+      return EncodeImageEXR(&io, c_desired, pool, bytes);
+#else
+      return JXL_FAILURE("JPEG XL was built without OpenEXR support");
+#endif
+    case Codec::kUnknown:
+      return JXL_FAILURE("Cannot encode using Codec::kUnknown");
+  }
+
+  return JXL_FAILURE("Invalid codec");
+}
+
+Status EncodeToFile(const CodecInOut& io, const ColorEncoding& c_desired,
+                    size_t bits_per_sample, const std::string& pathname,
+                    ThreadPool* pool) {
+  const std::string extension = Extension(pathname);
+  const Codec codec = CodecFromExtension(extension, &bits_per_sample);
+
+  // Warn about incorrect usage of PBM/PGM/PGX/PPM - only the latter supports
+  // color, but CodecFromExtension lumps them all together.
+  if (codec == Codec::kPNM && extension != ".pfm") {
+    if (!io.Main().IsGray() && extension != ".ppm") {
+      JXL_WARNING("For color images, the filename should end with .ppm.\n");
+    } else if (io.Main().IsGray() && extension == ".ppm") {
+      JXL_WARNING(
+          "For grayscale images, the filename should not end with .ppm.\n");
+    }
+    if (bits_per_sample > 16) {
+      JXL_WARNING("PPM only supports up to 16 bits per sample");
+      bits_per_sample = 16;
+    }
+  } else if (codec == Codec::kPGX && !io.Main().IsGray()) {
+    JXL_WARNING("Storing color image to PGX - use .ppm extension instead.\n");
+  }
+  if (bits_per_sample > 16 && codec == Codec::kPNG) {
+    JXL_WARNING("PNG only supports up to 16 bits per sample");
+    bits_per_sample = 16;
+  }
+
+  PaddedBytes encoded;
+  return Encode(io, codec, c_desired, bits_per_sample, &encoded, pool) &&
+         WriteFile(encoded, pathname);
+}
+
+Status EncodeToFile(const CodecInOut& io, const std::string& pathname,
+                    ThreadPool* pool) {
+  // TODO(lode): need to take the floating_point_sample field into account
+  return EncodeToFile(io, io.metadata.m.color_encoding,
+                      io.metadata.m.bit_depth.bits_per_sample, pathname, pool);
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/extras/codec.h b/third_party/jpeg-xl/lib/extras/codec.h
new file mode 100644
index 000000000000..cdbf4ed43f3a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec.h
@@ -0,0 +1,94 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_EXTRAS_CODEC_H_
+#define LIB_EXTRAS_CODEC_H_
+
+// Facade for image encoders/decoders (PNG, PNM, ...).
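+//
+// Typical round-trip usage sketch (the filenames are hypothetical):
+//   CodecInOut io;
+//   JXL_RETURN_IF_ERROR(SetFromFile("input.png", &io));
+//   JXL_RETURN_IF_ERROR(EncodeToFile(io, "output.png"));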
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/field_encodings.h"  // MakeBit
+
+namespace jxl {
+
+// Codecs supported by CodecInOut::Encode.
+enum class Codec : uint32_t {
+  kUnknown,  // for CodecFromExtension
+  kPNG,
+  kPNM,
+  kPGX,
+  kJPG,
+  kGIF,
+  kEXR,
+  kPSD
+};
+
+static inline constexpr uint64_t EnumBits(Codec /*unused*/) {
+  // Return only fully-supported codecs (kGIF is decode-only).
+  return MakeBit(Codec::kPNM) | MakeBit(Codec::kPNG)
+#if JPEGXL_ENABLE_JPEG
+         | MakeBit(Codec::kJPG)
+#endif
+#if JPEGXL_ENABLE_EXR
+         | MakeBit(Codec::kEXR)
+#endif
+         | MakeBit(Codec::kPSD);
+}
+
+// Lower case ASCII including dot, e.g. ".png".
+std::string ExtensionFromCodec(Codec codec, bool is_gray,
+                               size_t bits_per_sample);
+
+// If and only if extension is ".pfm", *bits_per_sample is updated to 32 so
+// that Encode() would encode to PFM instead of PPM.
+Codec CodecFromExtension(const std::string& extension,
+                         size_t* JXL_RESTRICT bits_per_sample);
+
+// Decodes "bytes" and sets io->metadata.m.
+// dec_hints may specify the "color_space" (otherwise, defaults to sRGB).
+Status SetFromBytes(const Span<const uint8_t> bytes, CodecInOut* io,
+                    ThreadPool* pool = nullptr, Codec* orig_codec = nullptr);
+
+// Reads from file and calls SetFromBytes.
+Status SetFromFile(const std::string& pathname, CodecInOut* io,
+                   ThreadPool* pool = nullptr, Codec* orig_codec = nullptr);
+
+// Replaces "bytes" with an encoding of pixels transformed from c_current
+// color space to c_desired.
+Status Encode(const CodecInOut& io, Codec codec,
+              const ColorEncoding& c_desired, size_t bits_per_sample,
+              PaddedBytes* bytes, ThreadPool* pool = nullptr);
+
+// Deduces codec, calls Encode and writes to file.
+Status EncodeToFile(const CodecInOut& io, const ColorEncoding& c_desired,
+                    size_t bits_per_sample, const std::string& pathname,
+                    ThreadPool* pool = nullptr);
+// Same, but defaults to metadata.original color_encoding and bits_per_sample.
+Status EncodeToFile(const CodecInOut& io, const std::string& pathname,
+                    ThreadPool* pool = nullptr);
+
+}  // namespace jxl
+
+#endif  // LIB_EXTRAS_CODEC_H_
diff --git a/third_party/jpeg-xl/lib/extras/codec_apng.cc b/third_party/jpeg-xl/lib/extras/codec_apng.cc
new file mode 100644
index 000000000000..c6d666065c6d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_apng.cc
@@ -0,0 +1,427 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/codec_apng.h"
+
+// Parts of this code are taken from apngdis, which has the following license:
+/* APNG Disassembler 2.8
+ *
+ * Deconstructs APNG files into individual frames.
+ *
+ * http://apngdis.sourceforge.net
+ *
+ * Copyright (c) 2010-2015 Max Stepin
+ * maxst at users.sourceforge.net
+ *
+ * zlib license
+ * ------------
+ *
+ * This software is provided 'as-is', without any express or implied
+ * warranty. In no event will the authors be held liable for any damages
+ * arising from the use of this software.
+ *
+ * Permission is granted to anyone to use this software for any purpose,
+ * including commercial applications, and to alter it and redistribute it
+ * freely, subject to the following restrictions:
+ *
+ * 1. The origin of this software must not be misrepresented; you must not
+ *    claim that you wrote the original software. If you use this software
+ *    in a product, an acknowledgment in the product documentation would be
+ *    appreciated but is not required.
+ * 2. Altered source versions must be plainly marked as such, and must not be
+ *    misrepresented as being the original software.
+ * 3. This notice may not be removed or altered from any source distribution.
+ *
+ */
+
+#include <stdio.h>
+#include <string.h>
+
+#if defined(_WIN32) || defined(_WIN64)
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif  // NOMINMAX
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/luminance.h"
+#include "png.h" /* original (unpatched) libpng is ok */
+
+namespace jxl {
+
+namespace {
+#define notabc(c) ((c) < 65 || (c) > 122 || ((c) > 90 && (c) < 97))
+
+#define id_IHDR 0x52444849
+#define id_acTL 0x4C546361
+#define id_fcTL 0x4C546366
+#define id_IDAT 0x54414449
+#define id_fdAT 0x54416466
+#define id_IEND 0x444E4549
+
+struct CHUNK {
+  unsigned char* p;
+  unsigned int size;
+};
+struct APNGFrame {
+  unsigned char *p, **rows;
+  unsigned int w, h, delay_num, delay_den;
+};
+
+const unsigned long cMaxPNGSize = 1000000UL;
+const size_t kMaxPNGChunkSize = 100000000;  // 100 MB
+
+void info_fn(png_structp png_ptr, png_infop info_ptr) {
+  png_set_expand(png_ptr);
+  png_set_strip_16(png_ptr);
+  png_set_gray_to_rgb(png_ptr);
+  png_set_palette_to_rgb(png_ptr);
+  png_set_add_alpha(png_ptr, 0xff, PNG_FILLER_AFTER);
+  (void)png_set_interlace_handling(png_ptr);
+  png_read_update_info(png_ptr, info_ptr);
+}
+
+void row_fn(png_structp png_ptr, png_bytep new_row, png_uint_32 row_num,
+            int pass) {
+  APNGFrame* frame = (APNGFrame*)png_get_progressive_ptr(png_ptr);
+  png_progressive_combine_row(png_ptr, frame->rows[row_num], new_row);
+}
+
+inline unsigned int read_chunk(FILE* f, CHUNK* pChunk) {
+  unsigned char len[4];
+  pChunk->size = 0;
+  pChunk->p = 0;
+  if (fread(&len, 4, 1, f) == 1) {
+    const auto size = png_get_uint_32(len);
+    // Check first, to avoid overflow.
+    if (size > kMaxPNGChunkSize) {
+      JXL_WARNING("APNG chunk size is too big");
+      return 0;
+    }
+    pChunk->size = size + 12;
+    pChunk->p = new unsigned char[pChunk->size];
+    memcpy(pChunk->p, len, 4);
+    if (fread(pChunk->p + 4, pChunk->size - 4, 1, f) == 1)
+      return *(unsigned int*)(pChunk->p + 4);
+  }
+  return 0;
+}
+
+int processing_start(png_structp& png_ptr, png_infop& info_ptr,
+                     void* frame_ptr, bool hasInfo, CHUNK& chunkIHDR,
+                     std::vector<CHUNK>& chunksInfo) {
+  unsigned char header[8] = {137, 80, 78, 71, 13, 10, 26, 10};
+
+  png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
+  info_ptr = png_create_info_struct(png_ptr);
+  if (!png_ptr || !info_ptr) return 1;
+
+  if (setjmp(png_jmpbuf(png_ptr))) {
+    png_destroy_read_struct(&png_ptr, &info_ptr, 0);
+    return 1;
+  }
+
+  png_set_crc_action(png_ptr, PNG_CRC_QUIET_USE, PNG_CRC_QUIET_USE);
+  png_set_progressive_read_fn(png_ptr, frame_ptr, info_fn, row_fn, NULL);
+
+  png_process_data(png_ptr, info_ptr, header, 8);
+  png_process_data(png_ptr, info_ptr, chunkIHDR.p, chunkIHDR.size);
+
+  if (hasInfo) {
+    for (unsigned int i = 0; i < chunksInfo.size(); i++) {
+      png_process_data(png_ptr, info_ptr, chunksInfo[i].p, chunksInfo[i].size);
+    }
+  }
+  return 0;
+}
+
+int processing_data(png_structp png_ptr, png_infop info_ptr, unsigned char* p,
+                    unsigned int size) {
+  if (!png_ptr || !info_ptr) return 1;
+
+  if (setjmp(png_jmpbuf(png_ptr))) {
+    png_destroy_read_struct(&png_ptr, &info_ptr, 0);
+    return 1;
+  }
+
+  png_process_data(png_ptr, info_ptr, p, size);
+  return 0;
+}
+
+int processing_finish(png_structp png_ptr, png_infop info_ptr) {
+  unsigned char footer[12] = {0, 0, 0, 0, 73, 69, 78, 68, 174, 66, 96, 130};
+
+  if (!png_ptr || !info_ptr) return 1;
+
+  if (setjmp(png_jmpbuf(png_ptr))) {
+    png_destroy_read_struct(&png_ptr, &info_ptr, 0);
+    return 1;
+  }
+
+  png_process_data(png_ptr, info_ptr, footer, 12);
+  png_destroy_read_struct(&png_ptr, &info_ptr, 0);
+
+  return 0;
+}
+
+#if defined(_WIN32) || defined(_WIN64)
+FILE* fmemopen(void* buf, size_t size, const char* mode) {
+  char temp[999];
+  if (!GetTempPath(sizeof(temp), temp)) return nullptr;
+
+  char pathname[999];
+  if (!GetTempFileName(temp, "jpegxl", 0, pathname)) return nullptr;
+
+  FILE* f = fopen(pathname, "wb");
+  if (f == nullptr) return nullptr;
+  fwrite(buf, 1, size, f);
+  JXL_CHECK(fclose(f) == 0);
+
+  return fopen(pathname, mode);
+}
+
+#endif
+
+}  // namespace
+
+Status DecodeImageAPNG(Span<const uint8_t> bytes, ThreadPool* pool,
+                       CodecInOut* io) {
+  FILE* f;
+  unsigned int id, i, j, w, h, w0, h0, x0, y0;
+  unsigned int delay_num, delay_den, dop, bop, rowbytes, imagesize;
+  unsigned char sig[8];
+  png_structp png_ptr;
+  png_infop info_ptr;
+  CHUNK chunk;
+  CHUNK chunkIHDR;
+  std::vector<CHUNK> chunksInfo;
+  bool isAnimated = false;
+  bool skipFirst = false;
+  bool hasInfo = false;
+  APNGFrame frameRaw = {};
+
+  if (!(f = fmemopen((void*)bytes.data(), bytes.size(), "rb"))) {
+    return JXL_FAILURE("Failed to fmemopen");
+  }
+  // Not an aPNG => not an error
+  unsigned char png_signature[8] = {137, 80, 78, 71, 13, 10, 26, 10};
+  if (fread(sig, 1, 8, f) != 8 || memcmp(sig, png_signature, 8) != 0) {
+    fclose(f);
+    return false;
+  }
+  id = read_chunk(f, &chunkIHDR);
+
+  io->frames.clear();
+  io->dec_pixels = 0;
+  io->metadata.m.SetUintSamples(8);
+  io->metadata.m.SetAlphaBits(8);
+  io->metadata.m.color_encoding =
+      ColorEncoding::SRGB();  // todo: get data from png metadata
+  (void)io->dec_hints.Foreach(
+      [](const std::string& key, const std::string& /*value*/) {
+        JXL_WARNING("APNG decoder ignoring %s hint", key.c_str());
+        return true;
+      });
+
+  bool errorstate = true;
+  if (id == id_IHDR && chunkIHDR.size == 25) {
+    w0 = w = png_get_uint_32(chunkIHDR.p + 8);
+    h0 = h = png_get_uint_32(chunkIHDR.p + 12);
+
+    if (w > cMaxPNGSize || h > cMaxPNGSize) {
+      fclose(f);
+      return false;
+    }
+
+    x0 = 0;
+    y0 = 0;
+    delay_num = 1;
+    delay_den = 10;
+    dop = 0;
+    bop = 0;
+    rowbytes = w * 4;
+    imagesize = h * rowbytes;
+
+    frameRaw.p = new unsigned char[imagesize];
+    frameRaw.rows = new png_bytep[h * sizeof(png_bytep)];
+    for (j = 0; j < h; j++) frameRaw.rows[j] = frameRaw.p + j * rowbytes;
+
+    if (!processing_start(png_ptr, info_ptr, (void*)&frameRaw, hasInfo,
+                          chunkIHDR, chunksInfo)) {
+      bool last_base_was_none = true;
+      while (!feof(f)) {
+        id = read_chunk(f, &chunk);
+        if (!id) break;
+        JXL_ASSERT(chunk.p != nullptr);
+
+        if (id == id_acTL && !hasInfo && !isAnimated) {
+          isAnimated = true;
+          skipFirst = true;
+          io->metadata.m.have_animation = true;
+          io->metadata.m.animation.tps_numerator = 1000;
+        } else if (id == id_IEND ||
+                   (id == id_fcTL && (!hasInfo || isAnimated))) {
+          if (hasInfo) {
+            if (!processing_finish(png_ptr, info_ptr)) {
+              ImageBundle bundle(&io->metadata.m);
+              bundle.duration = delay_num * 1000 / delay_den;
+              bundle.origin.x0 = x0;
+              bundle.origin.y0 = y0;
+              // TODO(veluca): this could in principle be implemented.
+              if (last_base_was_none &&
+                  (x0 != 0 || y0 != 0 || w0 != w || h0 != h || bop != 0)) {
+                return JXL_FAILURE(
+                    "APNG with dispose-to-0 is not supported for non-full or "
+                    "blended frames");
+              }
+              switch (dop) {
+                case 0:
+                  bundle.use_for_next_frame = true;
+                  last_base_was_none = false;
+                  break;
+                case 2:
+                  bundle.use_for_next_frame = false;
+                  break;
+                default:
+                  bundle.use_for_next_frame = false;
+                  last_base_was_none = true;
+              }
+              bundle.blend = bop != 0;
+              io->dec_pixels += w0 * h0;
+
+              Image3F sub_frame(w0, h0);
+              ImageF sub_frame_alpha(w0, h0);
+              for (size_t y = 0; y < h0; ++y) {
+                float* const JXL_RESTRICT row_r = sub_frame.PlaneRow(0, y);
+                float* const JXL_RESTRICT row_g = sub_frame.PlaneRow(1, y);
+                float* const JXL_RESTRICT row_b = sub_frame.PlaneRow(2, y);
+                float* const JXL_RESTRICT row_alpha = sub_frame_alpha.Row(y);
+                uint8_t* const f = frameRaw.rows[y];
+                for (size_t x = 0; x < w0; ++x) {
+                  if (f[4 * x + 3] == 0) {
+                    row_alpha[x] = 0;
+                    row_r[x] = 0;
+                    row_g[x] = 0;
+                    row_b[x] = 0;
+                    continue;
+                  }
+                  row_r[x] = f[4 * x + 0] * (1.f / 255);
+                  row_g[x] = f[4 * x + 1] * (1.f / 255);
+                  row_b[x] = f[4 * x + 2] * (1.f / 255);
+                  row_alpha[x] = f[4 * x + 3] * (1.f / 255);
+                }
+              }
+              bundle.SetFromImage(std::move(sub_frame), ColorEncoding::SRGB());
+              bundle.SetAlpha(std::move(sub_frame_alpha),
+                              /*alpha_is_premultiplied=*/false);
+              io->frames.push_back(std::move(bundle));
+            } else {
+              delete[] chunk.p;
+              break;
+            }
+          }
+
+          if (id == id_IEND) {
+            errorstate = false;
+            break;
+          }
+          // At this point the old frame is done. Let's start a new one.
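+          // fcTL layout after the 4-byte length, tag, and sequence number:
+          // frame width, height, x offset, y offset (4 bytes each), the
+          // 16-bit delay numerator and denominator, then one byte each for
+          // the dispose and blend ops.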
+          w0 = png_get_uint_32(chunk.p + 12);
+          h0 = png_get_uint_32(chunk.p + 16);
+          x0 = png_get_uint_32(chunk.p + 20);
+          y0 = png_get_uint_32(chunk.p + 24);
+          delay_num = png_get_uint_16(chunk.p + 28);
+          delay_den = png_get_uint_16(chunk.p + 30);
+          dop = chunk.p[32];
+          bop = chunk.p[33];
+
+          if (w0 > cMaxPNGSize || h0 > cMaxPNGSize || x0 > cMaxPNGSize ||
+              y0 > cMaxPNGSize || x0 + w0 > w || y0 + h0 > h || dop > 2 ||
+              bop > 1) {
+            delete[] chunk.p;
+            break;
+          }
+
+          if (hasInfo) {
+            memcpy(chunkIHDR.p + 8, chunk.p + 12, 8);
+            if (processing_start(png_ptr, info_ptr, (void*)&frameRaw, hasInfo,
+                                 chunkIHDR, chunksInfo)) {
+              delete[] chunk.p;
+              break;
+            }
+          } else
+            skipFirst = false;
+
+          if (io->frames.size() == (skipFirst ? 1 : 0)) {
+            bop = 0;
+            if (dop == 2) dop = 1;
+          }
+        } else if (id == id_IDAT) {
+          hasInfo = true;
+          if (processing_data(png_ptr, info_ptr, chunk.p, chunk.size)) {
+            delete[] chunk.p;
+            break;
+          }
+        } else if (id == id_fdAT && isAnimated) {
+          png_save_uint_32(chunk.p + 4, chunk.size - 16);
+          memcpy(chunk.p + 8, "IDAT", 4);
+          if (processing_data(png_ptr, info_ptr, chunk.p + 4,
+                              chunk.size - 4)) {
+            delete[] chunk.p;
+            break;
+          }
+        } else if (notabc(chunk.p[4]) || notabc(chunk.p[5]) ||
+                   notabc(chunk.p[6]) || notabc(chunk.p[7])) {
+          delete[] chunk.p;
+          break;
+        } else if (!hasInfo) {
+          if (processing_data(png_ptr, info_ptr, chunk.p, chunk.size)) {
+            delete[] chunk.p;
+            break;
+          }
+          chunksInfo.push_back(chunk);
+          continue;
+        }
+        delete[] chunk.p;
+      }
+    }
+    delete[] frameRaw.rows;
+    delete[] frameRaw.p;
+  }
+
+  for (i = 0; i < chunksInfo.size(); i++) delete[] chunksInfo[i].p;
+
+  chunksInfo.clear();
+  delete[] chunkIHDR.p;
+
+  fclose(f);
+
+  if (errorstate) return false;
+  SetIntensityTarget(io);
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/extras/codec_apng.h b/third_party/jpeg-xl/lib/extras/codec_apng.h
new file mode 100644
index 000000000000..d3a7bbc8e998
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_apng.h
@@ -0,0 +1,36 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_EXTRAS_CODEC_APNG_H_
+#define LIB_EXTRAS_CODEC_APNG_H_
+
+// Decodes APNG images in memory.
+
+#include <stdint.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+
+namespace jxl {
+
+// Decodes `bytes` into `io`. io->dec_hints are ignored.
+Status DecodeImageAPNG(const Span<const uint8_t> bytes, ThreadPool* pool,
+                       CodecInOut* io);
+
+}  // namespace jxl
+
+#endif  // LIB_EXTRAS_CODEC_APNG_H_
diff --git a/third_party/jpeg-xl/lib/extras/codec_exr.cc b/third_party/jpeg-xl/lib/extras/codec_exr.cc
new file mode 100644
index 000000000000..7700d5c9fd87
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_exr.cc
@@ -0,0 +1,359 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/codec_exr.h"
+
+#include <ImfChromaticitiesAttribute.h>
+#include <ImfIO.h>
+#include <ImfRgbaFile.h>
+#include <ImfStandardAttributes.h>
+
+#include <vector>
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/color_management.h"
+
+namespace jxl {
+
+namespace {
+
+namespace OpenEXR = OPENEXR_IMF_NAMESPACE;
+namespace Imath = IMATH_NAMESPACE;
+
+// OpenEXR::Int64 is deprecated in favor of using uint64_t directly, but using
+// uint64_t as recommended causes build failures with previous OpenEXR versions
+// on macOS, where the definition for OpenEXR::Int64 was actually not
+// equivalent to uint64_t. This alternative should work in all cases.
+using ExrInt64 = decltype(std::declval<OpenEXR::IStream>().tellg());
+
+constexpr int kExrBitsPerSample = 16;
+constexpr int kExrAlphaBits = 16;
+
+float GetIntensityTarget(const CodecInOut& io,
+                         const OpenEXR::Header& exr_header) {
+  if (OpenEXR::hasWhiteLuminance(exr_header)) {
+    const float exr_luminance = OpenEXR::whiteLuminance(exr_header);
+    if (io.target_nits != 0) {
+      JXL_WARNING(
+          "overriding OpenEXR whiteLuminance of %g with user-specified value "
+          "of %g",
+          exr_luminance, io.target_nits);
+      return io.target_nits;
+    }
+    return exr_luminance;
+  }
+  if (io.target_nits != 0) {
+    return io.target_nits;
+  }
+  JXL_WARNING(
+      "no OpenEXR whiteLuminance tag found and no intensity_target specified, "
+      "defaulting to %g",
+      kDefaultIntensityTarget);
+  return kDefaultIntensityTarget;
+}
+
+size_t GetNumThreads(ThreadPool* pool) {
+  size_t exr_num_threads = 1;
+  RunOnPool(
+      pool, 0, 1,
+      [&](size_t num_threads) {
+        exr_num_threads = num_threads;
+        return true;
+      },
+      [&](const int /* task */, const int /*thread*/) {},
+      "DecodeImageEXRThreads");
+  return exr_num_threads;
+}
+
+class InMemoryIStream : public OpenEXR::IStream {
+ public:
+  // The data pointed to by `bytes` must outlive the InMemoryIStream.
+  explicit InMemoryIStream(const Span<const uint8_t> bytes)
+      : IStream(/*fileName=*/""), bytes_(bytes) {}
+
+  bool isMemoryMapped() const override { return true; }
+  char* readMemoryMapped(const int n) override {
+    JXL_ASSERT(pos_ + n <= bytes_.size());
+    char* const result =
+        const_cast<char*>(reinterpret_cast<const char*>(bytes_.data() + pos_));
+    pos_ += n;
+    return result;
+  }
+  bool read(char c[], const int n) override {
+    std::copy_n(readMemoryMapped(n), n, c);
+    return pos_ < bytes_.size();
+  }
+
+  ExrInt64 tellg() override { return pos_; }
+  void seekg(const ExrInt64 pos) override {
+    JXL_ASSERT(pos + 1 <= bytes_.size());
+    pos_ = pos;
+  }
+
+ private:
+  const Span<const uint8_t> bytes_;
+  size_t pos_ = 0;
+};
+
+class InMemoryOStream : public OpenEXR::OStream {
+ public:
+  // `bytes` must outlive the InMemoryOStream.
+  explicit InMemoryOStream(PaddedBytes* const bytes)
+      : OStream(/*fileName=*/""), bytes_(*bytes) {}
+
+  void write(const char c[], const int n) override {
+    if (bytes_.size() < pos_ + n) {
+      bytes_.resize(pos_ + n);
+    }
+    std::copy_n(c, n, bytes_.begin() + pos_);
+    pos_ += n;
+  }
+
+  ExrInt64 tellp() override { return pos_; }
+  void seekp(const ExrInt64 pos) override {
+    if (bytes_.size() + 1 < pos) {
+      bytes_.resize(pos - 1);
+    }
+    pos_ = pos;
+  }
+
+ private:
+  PaddedBytes& bytes_;
+  size_t pos_ = 0;
+};
+
+}  // namespace
+
+Status DecodeImageEXR(Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io) {
+  // Get the number of threads we should be using for OpenEXR.
+  // OpenEXR creates its own set of threads, independent from ours. `pool` is
+  // only used for converting from a buffer of OpenEXR::Rgba to Image3F.
+  // TODO(sboukortt): look into changing that with OpenEXR 2.3 which allows
+  // custom thread pools according to its changelog.
+  OpenEXR::setGlobalThreadCount(GetNumThreads(pool));
+
+  InMemoryIStream is(bytes);
+
+#ifdef __EXCEPTIONS
+  std::unique_ptr<OpenEXR::RgbaInputFile> input_ptr;
+  try {
+    input_ptr.reset(new OpenEXR::RgbaInputFile(is));
+  } catch (...) {
+    return JXL_FAILURE("OpenEXR failed to parse input");
+  }
+  OpenEXR::RgbaInputFile& input = *input_ptr;
+#else
+  OpenEXR::RgbaInputFile input(is);
+#endif
+
+  if ((input.channels() & OpenEXR::RgbaChannels::WRITE_RGB) !=
+      OpenEXR::RgbaChannels::WRITE_RGB) {
+    return JXL_FAILURE("only RGB OpenEXR files are supported");
+  }
+  const bool has_alpha = (input.channels() & OpenEXR::RgbaChannels::WRITE_A) ==
+                         OpenEXR::RgbaChannels::WRITE_A;
+
+  const float intensity_target = GetIntensityTarget(*io, input.header());
+
+  auto image_size = input.displayWindow().size();
+  // Size is computed as max - min, but both bounds are inclusive.
+  ++image_size.x;
+  ++image_size.y;
+  Image3F image(image_size.x, image_size.y);
+  ZeroFillImage(&image);
+  ImageF alpha;
+  if (has_alpha) {
+    alpha = ImageF(image_size.x, image_size.y);
+    FillImage(1.f, &alpha);
+  }
+
+  const int row_size = input.dataWindow().size().x + 1;
+  // Number of rows to read at a time.
+  // https://www.openexr.com/documentation/ReadingAndWritingImageFiles.pdf
+  // recommends reading the whole file at once.
+  const int y_chunk_size = input.displayWindow().size().y + 1;
+  std::vector<OpenEXR::Rgba> input_rows(row_size * y_chunk_size);
+  for (int start_y =
+           std::max(input.dataWindow().min.y, input.displayWindow().min.y);
+       start_y <=
+       std::min(input.dataWindow().max.y, input.displayWindow().max.y);
+       start_y += y_chunk_size) {
+    // Inclusive.
+    const int end_y = std::min(
+        start_y + y_chunk_size - 1,
+        std::min(input.dataWindow().max.y, input.displayWindow().max.y));
+    input.setFrameBuffer(
+        input_rows.data() - input.dataWindow().min.x - start_y * row_size,
+        /*xStride=*/1, /*yStride=*/row_size);
+    input.readPixels(start_y, end_y);
+    RunOnPool(
+        pool, start_y, end_y + 1, ThreadPool::SkipInit(),
+        [&](const int exr_y, const int /*thread*/) {
+          const int image_y = exr_y - input.displayWindow().min.y;
+          const OpenEXR::Rgba* const JXL_RESTRICT input_row =
+              &input_rows[(exr_y - start_y) * row_size];
+          float* const JXL_RESTRICT rows[] = {
+              image.PlaneRow(0, image_y),
+              image.PlaneRow(1, image_y),
+              image.PlaneRow(2, image_y),
+          };
+          float* const JXL_RESTRICT alpha_row =
+              has_alpha ? alpha.Row(image_y) : nullptr;
+          for (int exr_x = std::max(input.dataWindow().min.x,
+                                    input.displayWindow().min.x);
+               exr_x <=
+               std::min(input.dataWindow().max.x, input.displayWindow().max.x);
+               ++exr_x) {
+            const int image_x = exr_x - input.displayWindow().min.x;
+            const OpenEXR::Rgba& pixel =
+                input_row[exr_x - input.dataWindow().min.x];
+            rows[0][image_x] = pixel.r;
+            rows[1][image_x] = pixel.g;
+            rows[2][image_x] = pixel.b;
+            if (has_alpha) {
+              alpha_row[image_x] = pixel.a;
+            }
+          }
+        },
+        "DecodeImageEXR");
+  }
+
+  ColorEncoding color_encoding;
+  color_encoding.tf.SetTransferFunction(TransferFunction::kLinear);
+  color_encoding.SetColorSpace(ColorSpace::kRGB);
+  PrimariesCIExy primaries = ColorEncoding::SRGB().GetPrimaries();
+  CIExy white_point = ColorEncoding::SRGB().GetWhitePoint();
+  if (OpenEXR::hasChromaticities(input.header())) {
+    const auto& chromaticities = OpenEXR::chromaticities(input.header());
+    primaries.r.x = chromaticities.red.x;
+    primaries.r.y = chromaticities.red.y;
+    primaries.g.x = chromaticities.green.x;
+    primaries.g.y = chromaticities.green.y;
+    primaries.b.x = chromaticities.blue.x;
+    primaries.b.y = chromaticities.blue.y;
+    white_point.x = chromaticities.white.x;
+    white_point.y = chromaticities.white.y;
+  }
+  JXL_RETURN_IF_ERROR(color_encoding.SetPrimaries(primaries));
+  JXL_RETURN_IF_ERROR(color_encoding.SetWhitePoint(white_point));
+  JXL_RETURN_IF_ERROR(color_encoding.CreateICC());
+
+  io->metadata.m.bit_depth.bits_per_sample = kExrBitsPerSample;
+  // EXR uses binary16 or binary32 floating point format.
+  io->metadata.m.bit_depth.exponent_bits_per_sample =
+      kExrBitsPerSample == 16 ? 5 : 8;
+  io->metadata.m.bit_depth.floating_point_sample = true;
+  io->SetFromImage(std::move(image), color_encoding);
+  io->metadata.m.color_encoding = color_encoding;
+  io->metadata.m.SetIntensityTarget(intensity_target);
+  if (has_alpha) {
+    io->metadata.m.SetAlphaBits(kExrAlphaBits,
+                                /*alpha_is_premultiplied=*/true);
+    io->Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/true);
+  }
+  return true;
+}
+
+Status EncodeImageEXR(const CodecInOut* io, const ColorEncoding& c_desired,
+                      ThreadPool* pool, PaddedBytes* bytes) {
+  // As in `DecodeImageEXR`, `pool` is only used for pixel conversion, not for
+  // actual OpenEXR I/O.
+  OpenEXR::setGlobalThreadCount(GetNumThreads(pool));
+
+  ColorEncoding c_linear = c_desired;
+  c_linear.tf.SetTransferFunction(TransferFunction::kLinear);
+  JXL_RETURN_IF_ERROR(c_linear.CreateICC());
+  ImageMetadata metadata = io->metadata.m;
+  ImageBundle store(&metadata);
+  const ImageBundle* linear;
+  JXL_RETURN_IF_ERROR(
+      TransformIfNeeded(io->Main(), c_linear, pool, &store, &linear));
+
+  const bool has_alpha = io->Main().HasAlpha();
+  const bool alpha_is_premultiplied = io->Main().AlphaIsPremultiplied();
+
+  OpenEXR::Header header(io->xsize(), io->ysize());
+  const PrimariesCIExy& primaries = c_linear.HasPrimaries()
+                                        ? c_linear.GetPrimaries()
+                                        : ColorEncoding::SRGB().GetPrimaries();
+  OpenEXR::Chromaticities chromaticities;
+  chromaticities.red = Imath::V2f(primaries.r.x, primaries.r.y);
+  chromaticities.green = Imath::V2f(primaries.g.x, primaries.g.y);
+  chromaticities.blue = Imath::V2f(primaries.b.x, primaries.b.y);
+  chromaticities.white =
+      Imath::V2f(c_linear.GetWhitePoint().x, c_linear.GetWhitePoint().y);
+  OpenEXR::addChromaticities(header, chromaticities);
+  OpenEXR::addWhiteLuminance(header, io->metadata.m.IntensityTarget());
+
+  // Ensure that the destructor of RgbaOutputFile has run before we look at
+  // the size of `bytes`.
+  {
+    InMemoryOStream os(bytes);
+    OpenEXR::RgbaOutputFile output(
+        os, header, has_alpha ? OpenEXR::WRITE_RGBA : OpenEXR::WRITE_RGB);
+    // How many rows to write at once. Again, the OpenEXR documentation
+    // recommends writing the whole image in one call.
+    const int y_chunk_size = io->ysize();
+    std::vector<OpenEXR::Rgba> output_rows(io->xsize() * y_chunk_size);
+
+    for (size_t start_y = 0; start_y < io->ysize(); start_y += y_chunk_size) {
+      // Inclusive.
+      const size_t end_y =
+          std::min(start_y + y_chunk_size - 1, io->ysize() - 1);
+      output.setFrameBuffer(output_rows.data() - start_y * io->xsize(),
+                            /*xStride=*/1, /*yStride=*/io->xsize());
+      RunOnPool(
+          pool, start_y, end_y + 1, ThreadPool::SkipInit(),
+          [&](const int y, const int /*thread*/) {
+            const float* const JXL_RESTRICT input_rows[] = {
+                linear->color().ConstPlaneRow(0, y),
+                linear->color().ConstPlaneRow(1, y),
+                linear->color().ConstPlaneRow(2, y),
+            };
+            OpenEXR::Rgba* const JXL_RESTRICT row_data =
+                &output_rows[(y - start_y) * io->xsize()];
+            if (has_alpha) {
+              const float* const JXL_RESTRICT alpha_row =
+                  io->Main().alpha().ConstRow(y);
+              if (alpha_is_premultiplied) {
+                for (size_t x = 0; x < io->xsize(); ++x) {
+                  row_data[x] =
+                      OpenEXR::Rgba(input_rows[0][x], input_rows[1][x],
+                                    input_rows[2][x], alpha_row[x]);
+                }
+              } else {
+                for (size_t x = 0; x < io->xsize(); ++x) {
+                  row_data[x] = OpenEXR::Rgba(alpha_row[x] * input_rows[0][x],
+                                              alpha_row[x] * input_rows[1][x],
+                                              alpha_row[x] * input_rows[2][x],
+                                              alpha_row[x]);
+                }
+              }
+            } else {
+              for (size_t x = 0; x < io->xsize(); ++x) {
+                row_data[x] = OpenEXR::Rgba(input_rows[0][x], input_rows[1][x],
+                                            input_rows[2][x], 1.f);
+              }
+            }
+          },
+          "EncodeImageEXR");
+      output.writePixels(/*numScanLines=*/end_y - start_y + 1);
+    }
+  }
+
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/extras/codec_exr.h b/third_party/jpeg-xl/lib/extras/codec_exr.h
new file mode 100644
index 000000000000..5a28ccc5dca4
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_exr.h
@@ -0,0 +1,40 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_EXTRAS_CODEC_EXR_H_
+#define LIB_EXTRAS_CODEC_EXR_H_
+
+// Encodes OpenEXR images in memory.
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+
+namespace jxl {
+
+// Decodes `bytes` into `io`. io->dec_hints are ignored.
+Status DecodeImageEXR(Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io);
+
+// Transforms from io->c_current to `c_desired` (with the transfer function set
+// to linear as that is the OpenEXR convention) and encodes into `bytes`.
+Status EncodeImageEXR(const CodecInOut* io, const ColorEncoding& c_desired,
+                      ThreadPool* pool, PaddedBytes* bytes);
+
+}  // namespace jxl
+
+#endif  // LIB_EXTRAS_CODEC_EXR_H_
diff --git a/third_party/jpeg-xl/lib/extras/codec_gif.cc b/third_party/jpeg-xl/lib/extras/codec_gif.cc
new file mode 100644
index 000000000000..98bd16f10f19
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_gif.cc
@@ -0,0 +1,362 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/codec_gif.h"
+
+#include <gif_lib.h>
+#include <string.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/luminance.h"
+
+#ifdef MEMORY_SANITIZER
+#include "sanitizer/msan_interface.h"
+#endif
+
+namespace jxl {
+
+namespace {
+
+struct ReadState {
+  Span<const uint8_t> bytes;
+};
+
+struct DGifCloser {
+  void operator()(GifFileType* const ptr) const { DGifCloseFile(ptr, nullptr); }
+};
+using GifUniquePtr = std::unique_ptr<GifFileType, DGifCloser>;
+
+// Gif does not support partial transparency, so this considers anything non-0
+// as opaque.
+bool AllOpaque(const ImageF& alpha) {
+  for (size_t y = 0; y < alpha.ysize(); ++y) {
+    const float* const JXL_RESTRICT row = alpha.ConstRow(y);
+    for (size_t x = 0; x < alpha.xsize(); ++x) {
+      if (row[x] == 0.f) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+}  // namespace
+
+Status DecodeImageGIF(Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io) {
+  int error = GIF_OK;
+  ReadState state = {bytes};
+  const auto ReadFromSpan = [](GifFileType* const gif, GifByteType* const bytes,
+                               int n) {
+    ReadState* const state = reinterpret_cast<ReadState*>(gif->UserData);
+    // giflib API requires the input size `n` to be signed int.
+    if (static_cast<size_t>(n) > state->bytes.size()) {
+      n = state->bytes.size();
+    }
+    memcpy(bytes, state->bytes.data(), n);
+    state->bytes.remove_prefix(n);
+    return n;
+  };
+  GifUniquePtr gif(DGifOpen(&state, ReadFromSpan, &error));
+  if (gif == nullptr) {
+    if (error == D_GIF_ERR_NOT_GIF_FILE) {
+      // Not an error.
+      return false;
+    } else {
+      return JXL_FAILURE("Failed to read GIF: %s", GifErrorString(error));
+    }
+  }
+  error = DGifSlurp(gif.get());
+  if (error != GIF_OK) {
+    return JXL_FAILURE("Failed to read GIF: %s", GifErrorString(gif->Error));
+  }
+
+#ifdef MEMORY_SANITIZER
+  __msan_unpoison(gif.get(), sizeof(*gif));
+  if (gif->SColorMap) {
+    __msan_unpoison(gif->SColorMap, sizeof(*gif->SColorMap));
+    __msan_unpoison(gif->SColorMap->Colors, sizeof(*gif->SColorMap->Colors) *
+                                                gif->SColorMap->ColorCount);
+  }
+  __msan_unpoison(gif->SavedImages,
+                  sizeof(*gif->SavedImages) * gif->ImageCount);
+#endif
+
+  const SizeConstraints* constraints = &io->constraints;
+
+  JXL_RETURN_IF_ERROR(
+      VerifyDimensions(constraints, gif->SWidth, gif->SHeight));
+  uint64_t total_pixel_count =
+      static_cast<uint64_t>(gif->SWidth) * gif->SHeight;
+  for (int i = 0; i < gif->ImageCount; ++i) {
+    const SavedImage& image = gif->SavedImages[i];
+    uint32_t w = image.ImageDesc.Width;
+    uint32_t h = image.ImageDesc.Height;
+    JXL_RETURN_IF_ERROR(VerifyDimensions(constraints, w, h));
+    uint64_t pixel_count = static_cast<uint64_t>(w) * h;
+    if (total_pixel_count + pixel_count < total_pixel_count) {
+      return JXL_FAILURE("Image too big");
+    }
+    total_pixel_count += pixel_count;
+    if (total_pixel_count > constraints->dec_max_pixels) {
+      return JXL_FAILURE("Image too big");
+    }
+  }
+
+  if (!gif->SColorMap) {
+    for (int i = 0; i < gif->ImageCount; ++i) {
+      if (!gif->SavedImages[i].ImageDesc.ColorMap) {
+        return JXL_FAILURE("Missing GIF color map");
+      }
+    }
+  }
+
+  if (gif->ImageCount > 1) {
+    io->metadata.m.have_animation = true;
+    // Delays in GIF are specified in 100ths of a second.
+    io->metadata.m.animation.tps_numerator = 100;
+  }
+
+  io->frames.clear();
+  io->frames.reserve(gif->ImageCount);
+  io->dec_pixels = 0;
+
+  io->metadata.m.SetUintSamples(8);
+  io->metadata.m.color_encoding = ColorEncoding::SRGB();
+  io->metadata.m.SetAlphaBits(0);
+  (void)io->dec_hints.Foreach(
+      [](const std::string& key, const std::string& /*value*/) {
+        JXL_WARNING("GIF decoder ignoring %s hint", key.c_str());
+        return true;
+      });
+
+  Image3F canvas(gif->SWidth, gif->SHeight);
+  io->SetSize(gif->SWidth, gif->SHeight);
+  ImageF alpha(gif->SWidth, gif->SHeight);
+  GifColorType background_color;
+  if (gif->SColorMap == nullptr) {
+    background_color = {0, 0, 0};
+  } else {
+    if (gif->SBackGroundColor >= gif->SColorMap->ColorCount) {
+      return JXL_FAILURE("GIF specifies out-of-bounds background color");
+    }
+    background_color = gif->SColorMap->Colors[gif->SBackGroundColor];
+  }
+  FillPlane((1.f / 255) * background_color.Red, &canvas.Plane(0));
+  FillPlane((1.f / 255) * background_color.Green, &canvas.Plane(1));
+  FillPlane((1.f / 255) * background_color.Blue, &canvas.Plane(2));
+  ZeroFillImage(&alpha);
+
+  Rect previous_rect_if_restore_to_background;
+
+  bool has_alpha = false;
+  bool replace = true;
+  bool last_base_was_none = true;
+  for (int i = 0; i < gif->ImageCount; ++i) {
+    const SavedImage& image = gif->SavedImages[i];
+#ifdef MEMORY_SANITIZER
+    __msan_unpoison(image.RasterBits, sizeof(*image.RasterBits) *
+                                          image.ImageDesc.Width *
+                                          image.ImageDesc.Height);
+#endif
+    const Rect image_rect(image.ImageDesc.Left, image.ImageDesc.Top,
+                          image.ImageDesc.Width, image.ImageDesc.Height);
+    io->dec_pixels += image_rect.xsize() * image_rect.ysize();
+    Rect total_rect;
+    if (previous_rect_if_restore_to_background.xsize() != 0 ||
+        previous_rect_if_restore_to_background.ysize() != 0) {
+      const size_t xbegin = std::min(
+          image_rect.x0(), previous_rect_if_restore_to_background.x0());
+      const size_t ybegin = std::min(
+          image_rect.y0(), previous_rect_if_restore_to_background.y0());
+      const size_t xend =
+          std::max(image_rect.x0() + image_rect.xsize(),
+                   previous_rect_if_restore_to_background.x0() +
+                       previous_rect_if_restore_to_background.xsize());
+      const size_t yend =
+          std::max(image_rect.y0() + image_rect.ysize(),
+                   previous_rect_if_restore_to_background.y0() +
+                       previous_rect_if_restore_to_background.ysize());
+      total_rect = Rect(xbegin, ybegin, xend - xbegin, yend - ybegin);
+      previous_rect_if_restore_to_background = Rect();
+      replace = true;
+    } else {
+      total_rect = image_rect;
+      replace = false;
+    }
+    if (!image_rect.IsInside(canvas)) {
+      return JXL_FAILURE("GIF frame extends outside of the canvas");
+    }
+    const ColorMapObject* const color_map =
+        image.ImageDesc.ColorMap ? image.ImageDesc.ColorMap : gif->SColorMap;
+    JXL_CHECK(color_map);
+#ifdef MEMORY_SANITIZER
+    __msan_unpoison(color_map, sizeof(*color_map));
+    __msan_unpoison(color_map->Colors,
+                    sizeof(*color_map->Colors) * color_map->ColorCount);
+#endif
+    GraphicsControlBlock gcb;
+    DGifSavedExtensionToGCB(gif.get(), i, &gcb);
+#ifdef MEMORY_SANITIZER
+    __msan_unpoison(&gcb, sizeof(gcb));
+#endif
+
+    ImageBundle bundle(&io->metadata.m);
+    if (io->metadata.m.have_animation) {
+      bundle.duration = gcb.DelayTime;
+      bundle.origin.x0 = total_rect.x0();
+      bundle.origin.y0 = total_rect.y0();
+      if (last_base_was_none) {
+        replace = true;
+      }
+      bundle.blend = !replace;
+      // TODO(veluca): this could in principle be implemented.
+      if (last_base_was_none &&
+          (total_rect.x0() != 0 || total_rect.y0() != 0 ||
+           total_rect.xsize() != canvas.xsize() ||
+           total_rect.ysize() != canvas.ysize() || !replace)) {
+        return JXL_FAILURE(
+            "GIF with dispose-to-0 is not supported for non-full or "
+            "blended frames");
+      }
+      switch (gcb.DisposalMode) {
+        case DISPOSE_DO_NOT:
+        case DISPOSE_BACKGROUND:
+          bundle.use_for_next_frame = true;
+          last_base_was_none = false;
+          break;
+        case DISPOSE_PREVIOUS:
+          bundle.use_for_next_frame = false;
+          break;
+        default:
+          bundle.use_for_next_frame = false;
+          last_base_was_none = true;
+      }
+    }
+    Image3F frame = CopyImage(canvas);
+    ImageF frame_alpha = CopyImage(alpha);
+    for (size_t y = 0, byte_index = 0; y < image_rect.ysize(); ++y) {
+      float* const JXL_RESTRICT row_r = image_rect.Row(&frame.Plane(0), y);
+      float* const JXL_RESTRICT row_g = image_rect.Row(&frame.Plane(1), y);
+      float* const JXL_RESTRICT row_b = image_rect.Row(&frame.Plane(2), y);
+      float* const JXL_RESTRICT row_alpha = image_rect.Row(&frame_alpha, y);
+      for (size_t x = 0; x < image_rect.xsize(); ++x, ++byte_index) {
+        const GifByteType byte = image.RasterBits[byte_index];
+        if (byte >= color_map->ColorCount) {
+          return JXL_FAILURE("GIF color is out of bounds");
+        }
+        if (byte == gcb.TransparentColor) continue;
+        GifColorType color = color_map->Colors[byte];
+        row_alpha[x] = 1.f;
+        row_r[x] = (1.f / 255) * color.Red;
+        row_g[x] = (1.f / 255) * color.Green;
+        row_b[x] = (1.f / 255) * color.Blue;
+      }
+    }
+    Image3F sub_frame(total_rect.xsize(), total_rect.ysize());
+    ImageF sub_frame_alpha(total_rect.xsize(), total_rect.ysize());
+    bool blend_alpha = false;
+    if (replace) {
+      CopyImageTo(total_rect, frame, &sub_frame);
+      CopyImageTo(total_rect, frame_alpha, &sub_frame_alpha);
+    } else {
+      for (size_t y = 0, byte_index = 0; y < image_rect.ysize(); ++y) {
+        float* const JXL_RESTRICT row_r = sub_frame.PlaneRow(0, y);
+        float* const JXL_RESTRICT row_g = sub_frame.PlaneRow(1, y);
+        float* const JXL_RESTRICT row_b = sub_frame.PlaneRow(2, y);
row_alpha = sub_frame_alpha.Row(y); + for (size_t x = 0; x < image_rect.xsize(); ++x, ++byte_index) { + const GifByteType byte = image.RasterBits[byte_index]; + if (byte >= color_map->ColorCount) { + return JXL_FAILURE("GIF color is out of bounds"); + } + if (byte == gcb.TransparentColor) { + row_alpha[x] = 0; + row_r[x] = 0; + row_g[x] = 0; + row_b[x] = 0; + blend_alpha = + true; // need to use alpha channel if BlendMode blend is used + continue; + } + GifColorType color = color_map->Colors[byte]; + row_alpha[x] = 1.f; + row_r[x] = (1.f / 255) * color.Red; + row_g[x] = (1.f / 255) * color.Green; + row_b[x] = (1.f / 255) * color.Blue; + } + } + } + bundle.SetFromImage(std::move(sub_frame), ColorEncoding::SRGB()); + if (has_alpha || !AllOpaque(frame_alpha) || blend_alpha) { + if (!has_alpha) { + has_alpha = true; + io->metadata.m.SetAlphaBits(8); + for (ImageBundle& previous_frame : io->frames) { + ImageF previous_alpha(previous_frame.xsize(), previous_frame.ysize()); + FillImage(1.f, &previous_alpha); + previous_frame.SetAlpha(std::move(previous_alpha), + /*alpha_is_premultiplied=*/false); + } + } + bundle.SetAlpha(std::move(sub_frame_alpha), + /*alpha_is_premultiplied=*/false); + } + io->frames.push_back(std::move(bundle)); + switch (gcb.DisposalMode) { + case DISPOSE_DO_NOT: + canvas = std::move(frame); + alpha = std::move(frame_alpha); + break; + + case DISPOSE_BACKGROUND: + FillPlane((1.f / 255) * background_color.Red, &canvas.Plane(0), + image_rect); + FillPlane((1.f / 255) * background_color.Green, &canvas.Plane(1), + image_rect); + FillPlane((1.f / 255) * background_color.Blue, &canvas.Plane(2), + image_rect); + FillPlane(0.f, &alpha, image_rect); + previous_rect_if_restore_to_background = image_rect; + break; + + case DISPOSE_PREVIOUS: + break; + + case DISPOSAL_UNSPECIFIED: + default: + FillPlane((1.f / 255) * background_color.Red, &canvas.Plane(0)); + FillPlane((1.f / 255) * background_color.Green, + &canvas.Plane(1)); + FillPlane((1.f / 255) * background_color.Blue, &canvas.Plane(2)); + ZeroFillImage(&alpha); + } + } + + SetIntensityTarget(io); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/extras/codec_gif.h b/third_party/jpeg-xl/lib/extras/codec_gif.h new file mode 100644 index 000000000000..af9799512bdb --- /dev/null +++ b/third_party/jpeg-xl/lib/extras/codec_gif.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_EXTRAS_CODEC_GIF_H_ +#define LIB_EXTRAS_CODEC_GIF_H_ + +// Decodes GIF images in memory. + +#include <stdint.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { + +// Decodes `bytes` into `io`. io->dec_hints are ignored.
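+// A minimal decode sketch (illustrative only, not part of the upstream API;
+// assumes a PaddedBytes `data` holding a complete .gif file):
+//
+//   jxl::CodecInOut io;
+//   jxl::Span<const uint8_t> span(data.data(), data.size());
+//   if (jxl::DecodeImageGIF(span, /*pool=*/nullptr, &io)) {
+//     // io.frames holds one ImageBundle per GIF frame; for animations,
+//     // io.metadata.m.animation ticks 100 times per second.
+//   }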
+Status DecodeImageGIF(const Span<const uint8_t> bytes, ThreadPool* pool, + CodecInOut* io); + +} // namespace jxl + +#endif // LIB_EXTRAS_CODEC_GIF_H_ diff --git a/third_party/jpeg-xl/lib/extras/codec_jpg.cc b/third_party/jpeg-xl/lib/extras/codec_jpg.cc new file mode 100644 index 000000000000..eefc3d0dbc69 --- /dev/null +++ b/third_party/jpeg-xl/lib/extras/codec_jpg.cc @@ -0,0 +1,538 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/extras/codec_jpg.h" + +#include <stddef.h> +#include <stdio.h> + +#if JPEGXL_ENABLE_JPEG +// After stddef/stdio +#include <jpeglib.h> +#include <setjmp.h> +#include <stdint.h> +#endif // JPEGXL_ENABLE_JPEG + +#include <algorithm> +#include <iterator> +#include <numeric> +#include <utility> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/base/time.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/jpeg/enc_jpeg_data_reader.h" +#include "lib/jxl/luminance.h" +#if JPEGXL_ENABLE_SJPEG +#include "sjpeg.h" +#endif + +#ifdef MEMORY_SANITIZER +#include "sanitizer/msan_interface.h" +#endif + +namespace jxl { + +#if JPEGXL_ENABLE_JPEG +namespace { + +constexpr float kJPEGSampleMultiplier = MAXJSAMPLE; +constexpr unsigned char kICCSignature[12] = { + 0x49, 0x43, 0x43, 0x5F, 0x50, 0x52, 0x4F, 0x46, 0x49, 0x4C, 0x45, 0x00}; +constexpr int kICCMarker = JPEG_APP0 + 2; +constexpr size_t kMaxBytesInMarker = 65533; + +constexpr unsigned char kExifSignature[6] = {0x45, 0x78, 0x69, + 0x66, 0x00, 0x00}; +constexpr int kExifMarker = JPEG_APP0 + 1; + +constexpr float kJPEGSampleMin = 0; +constexpr float kJPEGSampleMax = MAXJSAMPLE; + +bool MarkerIsICC(const jpeg_saved_marker_ptr marker) { + return marker->marker == kICCMarker && + marker->data_length >= sizeof kICCSignature + 2 && + std::equal(std::begin(kICCSignature), std::end(kICCSignature), + marker->data); +} +bool MarkerIsExif(const jpeg_saved_marker_ptr marker) { + return marker->marker == kExifMarker && + marker->data_length >= sizeof kExifSignature + 2 && + std::equal(std::begin(kExifSignature), std::end(kExifSignature), + marker->data); +} + +Status ReadICCProfile(jpeg_decompress_struct* const cinfo, + PaddedBytes* const icc) { + constexpr size_t kICCSignatureSize = sizeof kICCSignature; + // ICC signature + uint8_t index + uint8_t max_index. + constexpr size_t kICCHeadSize = kICCSignatureSize + 2; + // Markers are 1-indexed, and we keep them that way in this vector to get a + // convenient 0 at the front for when we compute the offsets later.
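+  // Worked example (illustrative): three ICC chunks with payload lengths
+  // 100, 200 and 50 bytes may arrive in any order; the loop below fills
+  // marker_lengths = {0, 100, 200, 50}. The later partial_sum turns this
+  // into offsets = {0, 100, 300, 350}, so chunk k is copied to byte offset
+  // offsets[k - 1] and the reassembled profile is offsets.back() bytes.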
+ std::vector<size_t> marker_lengths; + int num_markers = 0; + int seen_markers_count = 0; + bool has_num_markers = false; + for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; + marker = marker->next) { +#ifdef MEMORY_SANITIZER + // marker is initialized by libjpeg, which we are not instrumenting with + // msan. + __msan_unpoison(marker, sizeof(*marker)); + __msan_unpoison(marker->data, marker->data_length); +#endif + if (!MarkerIsICC(marker)) continue; + + const int current_marker = marker->data[kICCSignatureSize]; + if (current_marker == 0) { + return JXL_FAILURE("inconsistent JPEG ICC marker numbering"); + } + const int current_num_markers = marker->data[kICCSignatureSize + 1]; + if (current_marker > current_num_markers) { + return JXL_FAILURE("inconsistent JPEG ICC marker numbering"); + } + if (has_num_markers) { + if (current_num_markers != num_markers) { + return JXL_FAILURE("inconsistent numbers of JPEG ICC markers"); + } + } else { + num_markers = current_num_markers; + has_num_markers = true; + marker_lengths.resize(num_markers + 1); + } + + size_t marker_length = marker->data_length - kICCHeadSize; + + if (marker_length == 0) { + // NB: if we allow empty chunks, then the next check is incorrect. + return JXL_FAILURE("Empty ICC chunk"); + } + + if (marker_lengths[current_marker] != 0) { + return JXL_FAILURE("duplicate JPEG ICC marker number"); + } + marker_lengths[current_marker] = marker_length; + seen_markers_count++; + } + + if (marker_lengths.empty()) { + // Not an error. + return false; + } + + if (seen_markers_count != num_markers) { + JXL_DASSERT(has_num_markers); + return JXL_FAILURE("Incomplete set of ICC chunks"); + } + + std::vector<size_t> offsets = std::move(marker_lengths); + std::partial_sum(offsets.begin(), offsets.end(), offsets.begin()); + icc->resize(offsets.back()); + + for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; + marker = marker->next) { + if (!MarkerIsICC(marker)) continue; + const uint8_t* first = marker->data + kICCHeadSize; + uint8_t current_marker = marker->data[kICCSignatureSize]; + size_t offset = offsets[current_marker - 1]; + size_t marker_length = offsets[current_marker] - offset; + std::copy_n(first, marker_length, icc->data() + offset); + } + + return true; +} + +void ReadExif(jpeg_decompress_struct* const cinfo, PaddedBytes* const exif) { + constexpr size_t kExifSignatureSize = sizeof kExifSignature; + for (jpeg_saved_marker_ptr marker = cinfo->marker_list; marker != nullptr; + marker = marker->next) { +#ifdef MEMORY_SANITIZER + // marker is initialized by libjpeg, which we are not instrumenting with + // msan.
+ __msan_unpoison(marker, sizeof(*marker)); + __msan_unpoison(marker->data, marker->data_length); +#endif + if (!MarkerIsExif(marker)) continue; + size_t marker_length = marker->data_length - kExifSignatureSize; + exif->resize(marker_length); + std::copy_n(marker->data + kExifSignatureSize, marker_length, exif->data()); + return; + } +} + +// TODO (jon): take orientation into account when writing jpeg output +// TODO (jon): write Exif blob also in sjpeg encoding +// TODO (jon): overwrite orientation in Exif blob to avoid double orientation + +void WriteICCProfile(jpeg_compress_struct* const cinfo, + const PaddedBytes& icc) { + constexpr size_t kMaxIccBytesInMarker = + kMaxBytesInMarker - sizeof kICCSignature - 2; + const int num_markers = + static_cast<int>(DivCeil(icc.size(), kMaxIccBytesInMarker)); + size_t begin = 0; + for (int current_marker = 0; current_marker < num_markers; ++current_marker) { + const size_t length = std::min(kMaxIccBytesInMarker, icc.size() - begin); + jpeg_write_m_header( + cinfo, kICCMarker, + static_cast<unsigned int>(length + sizeof kICCSignature + 2)); + for (const unsigned char c : kICCSignature) { + jpeg_write_m_byte(cinfo, c); + } + jpeg_write_m_byte(cinfo, current_marker + 1); + jpeg_write_m_byte(cinfo, num_markers); + for (size_t i = 0; i < length; ++i) { + jpeg_write_m_byte(cinfo, icc[begin]); + ++begin; + } + } +} +void WriteExif(jpeg_compress_struct* const cinfo, const PaddedBytes& exif) { + if (exif.size() < 4) return; + jpeg_write_m_header( + cinfo, kExifMarker, + static_cast<unsigned int>(exif.size() - 4 + sizeof kExifSignature)); + for (const unsigned char c : kExifSignature) { + jpeg_write_m_byte(cinfo, c); + } + for (size_t i = 4; i < exif.size(); ++i) { + jpeg_write_m_byte(cinfo, exif[i]); + } +} + +Status SetChromaSubsampling(const YCbCrChromaSubsampling& chroma_subsampling, + jpeg_compress_struct* const cinfo) { + for (size_t i = 0; i < 3; i++) { + cinfo->comp_info[i].h_samp_factor = + 1 << (chroma_subsampling.MaxHShift() - + chroma_subsampling.HShift(i < 2 ? i ^ 1 : i)); + cinfo->comp_info[i].v_samp_factor = + 1 << (chroma_subsampling.MaxVShift() - + chroma_subsampling.VShift(i < 2 ? i ^ 1 : i)); + } + return true; +} + +void MyErrorExit(j_common_ptr cinfo) { + jmp_buf* env = static_cast<jmp_buf*>(cinfo->client_data); + (*cinfo->err->output_message)(cinfo); + jpeg_destroy_decompress(reinterpret_cast<j_decompress_ptr>(cinfo)); + longjmp(*env, 1); +} + +void MyOutputMessage(j_common_ptr cinfo) { +#if JXL_DEBUG_WARNING == 1 + char buf[JMSG_LENGTH_MAX]; + (*cinfo->err->format_message)(cinfo, buf); + JXL_WARNING("%s", buf); +#endif +} + +} // namespace +#endif // JPEGXL_ENABLE_JPEG + +Status DecodeImageJPG(const Span<const uint8_t> bytes, ThreadPool* pool, + CodecInOut* io, double* const elapsed_deinterleave) { + if (elapsed_deinterleave != nullptr) *elapsed_deinterleave = 0; + // Don't do anything for non-JPEG files (no need to report an error) + if (!IsJPG(bytes)) return false; + const DecodeTarget target = io->dec_target; + + // Use brunsli JPEG decoder to read quantized coefficients. + if (target == DecodeTarget::kQuantizedCoeffs) { + return jxl::jpeg::DecodeImageJPG(bytes, io); + } + +#if JPEGXL_ENABLE_JPEG + // TODO(veluca): use JPEGData also for pixels? + + // We need to declare all the non-trivial destructor local variables before + // the call to setjmp().
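+  // Rationale: longjmp() unwinds without running destructors, so everything
+  // that owns memory must already be constructed when setjmp() is reached,
+  // and nothing with a non-trivial destructor may live between setjmp() and
+  // a potential longjmp(). A condensed sketch of the pattern used below:
+  //
+  //   PaddedBytes icc;                  // owners declared first
+  //   jmp_buf env;
+  //   if (setjmp(env)) return false;    // MyErrorExit() longjmps back here
+  //   cinfo.client_data = &env;         // hands env to the error handler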
+ ColorEncoding color_encoding; + PaddedBytes icc; + Image3F image; + std::unique_ptr<JSAMPLE[]> row; + ImageBundle bundle(&io->metadata.m); + + const auto try_catch_block = [&]() -> bool { + jpeg_decompress_struct cinfo; +#ifdef MEMORY_SANITIZER + // cinfo is initialized by libjpeg, which we are not instrumenting with + // msan, therefore we need to initialize cinfo here. + memset(&cinfo, 0, sizeof(cinfo)); +#endif + // Setup error handling in jpeg library so we can deal with broken jpegs in + // the fuzzer. + jpeg_error_mgr jerr; + jmp_buf env; + cinfo.err = jpeg_std_error(&jerr); + jerr.error_exit = &MyErrorExit; + jerr.output_message = &MyOutputMessage; + if (setjmp(env)) { + return false; + } + cinfo.client_data = static_cast<void*>(&env); + + jpeg_create_decompress(&cinfo); + jpeg_mem_src(&cinfo, reinterpret_cast<const unsigned char*>(bytes.data()), + bytes.size()); + jpeg_save_markers(&cinfo, kICCMarker, 0xFFFF); + jpeg_save_markers(&cinfo, kExifMarker, 0xFFFF); + jpeg_read_header(&cinfo, TRUE); + if (!VerifyDimensions(&io->constraints, cinfo.image_width, + cinfo.image_height)) { + jpeg_abort_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + return JXL_FAILURE("image too big"); + } + if (ReadICCProfile(&cinfo, &icc)) { + if (!color_encoding.SetICC(std::move(icc))) { + jpeg_abort_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + return JXL_FAILURE("read an invalid ICC profile"); + } + } else { + color_encoding = ColorEncoding::SRGB(cinfo.output_components == 1); + } + ReadExif(&cinfo, &io->blobs.exif); + io->metadata.m.SetUintSamples(BITS_IN_JSAMPLE); + io->metadata.m.color_encoding = color_encoding; + int nbcomp = cinfo.num_components; + if (nbcomp != 1 && nbcomp != 3) { + jpeg_abort_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + return JXL_FAILURE("unsupported number of components (%d) in JPEG", + cinfo.output_components); + } + (void)io->dec_hints.Foreach( + [](const std::string& key, const std::string& /*value*/) { + JXL_WARNING("JPEG decoder ignoring %s hint", key.c_str()); + return true; + }); + + jpeg_start_decompress(&cinfo); + JXL_ASSERT(cinfo.output_components == nbcomp); + image = Image3F(cinfo.image_width, cinfo.image_height); + row.reset(new JSAMPLE[cinfo.output_components * cinfo.image_width]); + for (size_t y = 0; y < image.ysize(); ++y) { + JSAMPROW rows[] = {row.get()}; + jpeg_read_scanlines(&cinfo, rows, 1); +#ifdef MEMORY_SANITIZER + __msan_unpoison(row.get(), sizeof(JSAMPLE) * cinfo.output_components * + cinfo.image_width); +#endif + auto start = Now(); + float* const JXL_RESTRICT output_row[] = { + image.PlaneRow(0, y), image.PlaneRow(1, y), image.PlaneRow(2, y)}; + if (cinfo.output_components == 1) { + for (size_t x = 0; x < image.xsize(); ++x) { + output_row[0][x] = output_row[1][x] = output_row[2][x] = + row[x] * (1.f / kJPEGSampleMultiplier); + } + } else { // 3 components + for (size_t x = 0; x < image.xsize(); ++x) { + for (size_t c = 0; c < 3; ++c) { + output_row[c][x] = row[3 * x + c] * (1.f / kJPEGSampleMultiplier); + } + } + } + auto end = Now(); + if (elapsed_deinterleave != nullptr) { + *elapsed_deinterleave += end - start; + } + } + io->SetFromImage(std::move(image), color_encoding); + + jpeg_finish_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + io->dec_pixels = io->xsize() * io->ysize(); + return true; + }; + + return try_catch_block(); +#else // JPEGXL_ENABLE_JPEG + return JXL_FAILURE("JPEG decoding not enabled at build time."); +#endif // JPEGXL_ENABLE_JPEG +} + +#if JPEGXL_ENABLE_JPEG +Status EncodeWithLibJpeg(const ImageBundle* ib, const CodecInOut*
io, + size_t quality, + const YCbCrChromaSubsampling& chroma_subsampling, + PaddedBytes* bytes) { + jpeg_compress_struct cinfo; +#ifdef MEMORY_SANITIZER + // cinfo is initialized by libjpeg, which we are not instrumenting with + // msan. + __msan_unpoison(&cinfo, sizeof(cinfo)); +#endif + jpeg_error_mgr jerr; + cinfo.err = jpeg_std_error(&jerr); + jpeg_create_compress(&cinfo); + unsigned char* buffer = nullptr; + unsigned long size = 0; + jpeg_mem_dest(&cinfo, &buffer, &size); + cinfo.image_width = ib->xsize(); + cinfo.image_height = ib->ysize(); + if (ib->IsGray()) { + cinfo.input_components = 1; + cinfo.in_color_space = JCS_GRAYSCALE; + } else { + cinfo.input_components = 3; + cinfo.in_color_space = JCS_RGB; + } + jpeg_set_defaults(&cinfo); + cinfo.optimize_coding = TRUE; + if (cinfo.input_components == 3) { + JXL_RETURN_IF_ERROR(SetChromaSubsampling(chroma_subsampling, &cinfo)); + } + jpeg_set_quality(&cinfo, quality, TRUE); + jpeg_start_compress(&cinfo, TRUE); + if (!ib->IsSRGB()) { + WriteICCProfile(&cinfo, ib->c_current().ICC()); + } + WriteExif(&cinfo, io->blobs.exif); + if (cinfo.input_components > 3 || cinfo.input_components < 0) + return JXL_FAILURE("invalid numbers of components"); + + std::unique_ptr<JSAMPLE[]> row( + new JSAMPLE[cinfo.input_components * cinfo.image_width]); + for (size_t y = 0; y < ib->ysize(); ++y) { + const float* const JXL_RESTRICT input_row[3] = { + ib->color().ConstPlaneRow(0, y), ib->color().ConstPlaneRow(1, y), + ib->color().ConstPlaneRow(2, y)}; + for (size_t x = 0; x < ib->xsize(); ++x) { + for (size_t c = 0; c < static_cast<size_t>(cinfo.input_components); ++c) { + JXL_RETURN_IF_ERROR(c < 3); + row[cinfo.input_components * x + c] = static_cast<JSAMPLE>( + std::max(std::min(kJPEGSampleMultiplier * input_row[c][x] + .5f, + kJPEGSampleMax), + kJPEGSampleMin)); + } + } + JSAMPROW rows[] = {row.get()}; + jpeg_write_scanlines(&cinfo, rows, 1); + } + jpeg_finish_compress(&cinfo); + jpeg_destroy_compress(&cinfo); + bytes->resize(size); +#ifdef MEMORY_SANITIZER + // Compressed image data is initialized by libjpeg, which we are not + // instrumenting with msan.
+ __msan_unpoison(buffer, size); +#endif + std::copy_n(buffer, size, bytes->data()); + std::free(buffer); + return true; +} + +Status EncodeWithSJpeg(const ImageBundle* ib, size_t quality, + const YCbCrChromaSubsampling& chroma_subsampling, + PaddedBytes* bytes) { +#if !JPEGXL_ENABLE_SJPEG + return JXL_FAILURE("JPEG XL was built without sjpeg support"); +#else + sjpeg::EncoderParam param(quality); + if (!ib->IsSRGB()) { + param.iccp.assign(ib->metadata()->color_encoding.ICC().begin(), + ib->metadata()->color_encoding.ICC().end()); + } + if (chroma_subsampling.Is444()) { + param.yuv_mode = SJPEG_YUV_444; + } else if (chroma_subsampling.Is420()) { + param.yuv_mode = SJPEG_YUV_SHARP; + } else { + return JXL_FAILURE("sjpeg does not support this chroma subsampling mode"); + } + std::vector<uint8_t> rgb; + rgb.reserve(ib->xsize() * ib->ysize() * 3); + for (size_t y = 0; y < ib->ysize(); ++y) { + const float* const rows[] = { + ib->color().ConstPlaneRow(0, y), + ib->color().ConstPlaneRow(1, y), + ib->color().ConstPlaneRow(2, y), + }; + for (size_t x = 0; x < ib->xsize(); ++x) { + for (const float* const row : rows) { + rgb.push_back(static_cast<uint8_t>( + std::max(0.f, std::min(255.f, roundf(255.f * row[x]))))); + } + } + } + std::string output; + JXL_RETURN_IF_ERROR(sjpeg::Encode(rgb.data(), ib->xsize(), ib->ysize(), + ib->xsize() * 3, param, &output)); + bytes->assign( + reinterpret_cast<const uint8_t*>(output.data()), + reinterpret_cast<const uint8_t*>(output.data() + output.size())); + return true; +#endif +} +#endif // JPEGXL_ENABLE_JPEG + +Status EncodeImageJPG(const CodecInOut* io, JpegEncoder encoder, size_t quality, + YCbCrChromaSubsampling chroma_subsampling, + ThreadPool* pool, PaddedBytes* bytes, + const DecodeTarget target) { + if (io->Main().HasAlpha()) { + return JXL_FAILURE("alpha is not supported"); + } + if (quality > 100) { + return JXL_FAILURE("please specify a 0-100 JPEG quality"); + } + + if (target == DecodeTarget::kQuantizedCoeffs) { + auto write = [&bytes](const uint8_t* buf, size_t len) { + bytes->append(buf, buf + len); + return len; + }; + return jpeg::WriteJpeg(*io->Main().jpeg_data, write); + } + +#if JPEGXL_ENABLE_JPEG + const ImageBundle* ib; + ImageMetadata metadata = io->metadata.m; + ImageBundle ib_store(&metadata); + JXL_RETURN_IF_ERROR(TransformIfNeeded( + io->Main(), io->metadata.m.color_encoding, pool, &ib_store, &ib)); + + switch (encoder) { + case JpegEncoder::kLibJpeg: + JXL_RETURN_IF_ERROR( + EncodeWithLibJpeg(ib, io, quality, chroma_subsampling, bytes)); + break; + case JpegEncoder::kSJpeg: + JXL_RETURN_IF_ERROR( + EncodeWithSJpeg(ib, quality, chroma_subsampling, bytes)); + break; + default: + return JXL_FAILURE("tried to use an unknown JPEG encoder"); + } + + return true; +#else // JPEGXL_ENABLE_JPEG + return JXL_FAILURE("JPEG pixel encoding not enabled at build time"); +#endif // JPEGXL_ENABLE_JPEG +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/extras/codec_jpg.h b/third_party/jpeg-xl/lib/extras/codec_jpg.h new file mode 100644 index 000000000000..b08f071d2a9f --- /dev/null +++ b/third_party/jpeg-xl/lib/extras/codec_jpg.h @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_EXTRAS_CODEC_JPG_H_ +#define LIB_EXTRAS_CODEC_JPG_H_ + +// Encodes JPG pixels and metadata in memory. + +#include <stdint.h> + +#include "lib/extras/codec.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" + +namespace jxl { + +enum class JpegEncoder { + kLibJpeg, + kSJpeg, +}; + +static inline bool IsJPG(const Span<const uint8_t> bytes) { + if (bytes.size() < 2) return false; + if (bytes[0] != 0xFF || bytes[1] != 0xD8) return false; + return true; +} + +// Decodes `bytes` into `io`. io->dec_hints are ignored. +// `elapsed_deinterleave`, if non-null, will be set to the time (in seconds) +// that it took to deinterleave the raw JSAMPLEs to planar floats. +Status DecodeImageJPG(Span<const uint8_t> bytes, ThreadPool* pool, + CodecInOut* io, double* elapsed_deinterleave = nullptr); + +// Encodes into `bytes`. +Status EncodeImageJPG(const CodecInOut* io, JpegEncoder encoder, size_t quality, + YCbCrChromaSubsampling chroma_subsampling, + ThreadPool* pool, PaddedBytes* bytes, + DecodeTarget target = DecodeTarget::kPixels); + +} // namespace jxl + +#endif // LIB_EXTRAS_CODEC_JPG_H_ diff --git a/third_party/jpeg-xl/lib/extras/codec_pgx.cc b/third_party/jpeg-xl/lib/extras/codec_pgx.cc new file mode 100644 index 000000000000..dcd913b532d2 --- /dev/null +++ b/third_party/jpeg-xl/lib/extras/codec_pgx.cc @@ -0,0 +1,366 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/extras/codec_pgx.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/fields.h" // AllDefault +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/luminance.h" + +namespace jxl { +namespace { + +struct HeaderPGX { + // NOTE: PGX is always grayscale + size_t xsize; + size_t ysize; + size_t bits_per_sample; + bool big_endian; + bool is_signed; +}; + +class Parser { + public: + explicit Parser(const Span<const uint8_t> bytes) + : pos_(bytes.data()), end_(pos_ + bytes.size()) {} + + // Sets "pos" to the first non-header byte/pixel on success. + Status ParseHeader(HeaderPGX* header, const uint8_t** pos) { + // codec.cc ensures we have at least two bytes => no range check here.
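+    // Header layout, e.g. "PG ML + 8 2 3\n": "PG" magic, "ML" (big endian)
+    // or "LM" (little endian), '+' (unsigned) or '-' (signed), then bits
+    // per sample, width and height, terminated by a line break.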
+ if (pos_[0] != 'P' || pos_[1] != 'G') return false; + pos_ += 2; + return ParseHeaderPGX(header, pos); + } + + // Exposed for testing + Status ParseUnsigned(size_t* number) { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before number"); + if (!IsDigit(*pos_)) return JXL_FAILURE("PGX: expected unsigned number"); + + *number = 0; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number *= 10; + *number += *pos_ - '0'; + ++pos_; + } + + return true; + } + + private: + static bool IsDigit(const uint8_t c) { return '0' <= c && c <= '9'; } + static bool IsLineBreak(const uint8_t c) { return c == '\r' || c == '\n'; } + static bool IsWhitespace(const uint8_t c) { + return IsLineBreak(c) || c == '\t' || c == ' '; + } + + Status SkipSpace() { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before space"); + const uint8_t c = *pos_; + if (c != ' ') return JXL_FAILURE("PGX: expected space"); + ++pos_; + return true; + } + + Status SkipLineBreak() { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before line break"); + // Line break can be either "\n" (0a) or "\r\n" (0d 0a). + if (*pos_ == '\n') { + pos_++; + return true; + } else if (*pos_ == '\r' && pos_ + 1 != end_ && *(pos_ + 1) == '\n') { + pos_ += 2; + return true; + } + return JXL_FAILURE("PGX: expected line break"); + } + + Status SkipSingleWhitespace() { + if (pos_ == end_) return JXL_FAILURE("PGX: reached end before whitespace"); + if (!IsWhitespace(*pos_)) return JXL_FAILURE("PGX: expected whitespace"); + ++pos_; + return true; + } + + Status ParseHeaderPGX(HeaderPGX* header, const uint8_t** pos) { + JXL_RETURN_IF_ERROR(SkipSpace()); + if (pos_ + 2 > end_) return JXL_FAILURE("PGX: header too small"); + if (*pos_ == 'M' && *(pos_ + 1) == 'L') { + header->big_endian = true; + } else if (*pos_ == 'L' && *(pos_ + 1) == 'M') { + header->big_endian = false; + } else { + return JXL_FAILURE("PGX: invalid endianness"); + } + pos_ += 2; + JXL_RETURN_IF_ERROR(SkipSpace()); + if (pos_ == end_) return JXL_FAILURE("PGX: header too small"); + if (*pos_ == '+') { + header->is_signed = false; + } else if (*pos_ == '-') { + header->is_signed = true; + } else { + return JXL_FAILURE("PGX: invalid signedness"); + } + pos_++; + // Skip optional space + if (pos_ < end_ && *pos_ == ' ') pos_++; + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->bits_per_sample)); + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + // 0xa, or 0xd 0xa. + JXL_RETURN_IF_ERROR(SkipLineBreak()); + + if (header->bits_per_sample > 16) { + return JXL_FAILURE("PGX: >16 bits not yet supported"); + } + // TODO(lode): support signed integers. This may require changing the way + // external_image works. + if (header->is_signed) { + return JXL_FAILURE("PGX: signed not yet supported"); + } + + size_t numpixels = header->xsize * header->ysize; + size_t bytes_per_pixel = header->bits_per_sample <= 8 + ? 1 + : header->bits_per_sample <= 16 ? 
2 : 4; + if (pos_ + numpixels * bytes_per_pixel > end_) { + return JXL_FAILURE("PGX: data too small"); + } + + *pos = pos_; + return true; + } + + const uint8_t* pos_; + const uint8_t* const end_; +}; + +constexpr size_t kMaxHeaderSize = 200; + +Status EncodeHeader(const ImageBundle& ib, const size_t bits_per_sample, + char* header, int* JXL_RESTRICT chars_written) { + if (ib.HasAlpha()) return JXL_FAILURE("PGX: can't store alpha"); + if (!ib.IsGray()) return JXL_FAILURE("PGX: must be grayscale"); + // TODO(lode): verify other bit depths: for other bit depths such as 1 or 4 + // bits, have a test case to verify it works correctly. For bits > 16, we may + // need to change the way external_image works. + if (bits_per_sample != 8 && bits_per_sample != 16) { + return JXL_FAILURE("PGX: bits other than 8 or 16 not yet supported"); + } + + // Use ML (Big Endian), LM may not be well supported by all decoders. + snprintf(header, kMaxHeaderSize, "PG ML + %zu %zu %zu\n%n", bits_per_sample, + ib.xsize(), ib.ysize(), chars_written); + return true; +} + +Status ApplyHints(CodecInOut* io) { + bool got_color_space = false; + + JXL_RETURN_IF_ERROR(io->dec_hints.Foreach( + [io, &got_color_space](const std::string& key, + const std::string& value) -> Status { + ColorEncoding* c_original = &io->metadata.m.color_encoding; + if (key == "color_space") { + if (!ParseDescription(value, c_original) || + !c_original->CreateICC()) { + return JXL_FAILURE("PGX: Failed to apply color_space"); + } + + if (!io->metadata.m.color_encoding.IsGray()) { + return JXL_FAILURE("PGX: color_space hint must be grayscale"); + } + + got_color_space = true; + } else if (key == "icc_pathname") { + PaddedBytes icc; + JXL_RETURN_IF_ERROR(ReadFile(value, &icc)); + JXL_RETURN_IF_ERROR(c_original->SetICC(std::move(icc))); + got_color_space = true; + } else { + JXL_WARNING("PGX decoder ignoring %s hint", key.c_str()); + } + return true; + })); + + if (!got_color_space) { + JXL_WARNING("PGX: no color_space/icc_pathname given, assuming sRGB"); + JXL_RETURN_IF_ERROR( + io->metadata.m.color_encoding.SetSRGB(ColorSpace::kGray)); + } + + return true; +} + +template <typename T> +void ExpectNear(T a, T b, T precision) { + JXL_CHECK(std::abs(a - b) <= precision); +} + +Span<const uint8_t> MakeSpan(const char* str) { + return Span<const uint8_t>(reinterpret_cast<const uint8_t*>(str), + strlen(str)); +} + +} // namespace + +Status DecodeImagePGX(const Span<const uint8_t> bytes, ThreadPool* pool, + CodecInOut* io) { + Parser parser(bytes); + HeaderPGX header = {}; + const uint8_t* pos; + if (!parser.ParseHeader(&header, &pos)) return false; + JXL_RETURN_IF_ERROR( + VerifyDimensions(&io->constraints, header.xsize, header.ysize)); + if (header.bits_per_sample == 0 || header.bits_per_sample > 32) { + return JXL_FAILURE("PGX: bits_per_sample invalid"); + } + + JXL_RETURN_IF_ERROR(ApplyHints(io)); + io->metadata.m.SetUintSamples(header.bits_per_sample); + io->metadata.m.SetAlphaBits(0); + io->dec_pixels = header.xsize * header.ysize; + io->SetSize(header.xsize, header.ysize); + io->frames.clear(); + io->frames.reserve(1); + ImageBundle ib(&io->metadata.m); + + const bool has_alpha = false; + const bool flipped_y = false; + const Span<const uint8_t> span(pos, bytes.data() + bytes.size() - pos); + JXL_RETURN_IF_ERROR(ConvertFromExternal( + span, header.xsize, header.ysize, io->metadata.m.color_encoding, + has_alpha, + /*alpha_is_premultiplied=*/false, + io->metadata.m.bit_depth.bits_per_sample, + header.big_endian ?
JXL_BIG_ENDIAN : JXL_LITTLE_ENDIAN, flipped_y, pool, + &ib)); + io->frames.push_back(std::move(ib)); + SetIntensityTarget(io); + return true; +} + +Status EncodeImagePGX(const CodecInOut* io, const ColorEncoding& c_desired, + size_t bits_per_sample, ThreadPool* pool, + PaddedBytes* bytes) { + if (!Bundle::AllDefault(io->metadata.m)) { + JXL_WARNING("PGX encoder ignoring metadata - use a different codec"); + } + if (!c_desired.IsSRGB()) { + JXL_WARNING( + "PGX encoder cannot store custom ICC profile; decoder\n" + "will need hint key=color_space to get the same values"); + } + + ImageBundle ib = io->Main().Copy(); + + ImageMetadata metadata = io->metadata.m; + ImageBundle store(&metadata); + const ImageBundle* transformed; + JXL_RETURN_IF_ERROR( + TransformIfNeeded(ib, c_desired, pool, &store, &transformed)); + PaddedBytes pixels(ib.xsize() * ib.ysize() * + (bits_per_sample / kBitsPerByte)); + size_t stride = ib.xsize() * (bits_per_sample / kBitsPerByte); + JXL_RETURN_IF_ERROR(ConvertToExternal( + *transformed, bits_per_sample, + /*float_out=*/false, + /*num_channels=*/1, JXL_BIG_ENDIAN, stride, pool, pixels.data(), + pixels.size(), metadata.GetOrientation())); + + char header[kMaxHeaderSize]; + int header_size = 0; + JXL_RETURN_IF_ERROR(EncodeHeader(ib, bits_per_sample, header, &header_size)); + + bytes->resize(static_cast(header_size) + pixels.size()); + memcpy(bytes->data(), header, static_cast(header_size)); + memcpy(bytes->data() + header_size, pixels.data(), pixels.size()); + + return true; +} + +void TestCodecPGX() { + { + std::string pgx = "PG ML + 8 2 3\npixels"; + + CodecInOut io; + ThreadPool* pool = nullptr; + + Status ok = DecodeImagePGX(MakeSpan(pgx.c_str()), pool, &io); + JXL_CHECK(ok == true); + + ScaleImage(255.f, io.Main().color()); + + JXL_CHECK(!io.metadata.m.bit_depth.floating_point_sample); + JXL_CHECK(io.metadata.m.bit_depth.bits_per_sample == 8); + JXL_CHECK(io.metadata.m.color_encoding.IsGray()); + JXL_CHECK(io.xsize() == 2); + JXL_CHECK(io.ysize() == 3); + float eps = 1e-5; + ExpectNear('p', io.Main().color()->Plane(0).Row(0)[0], eps); + ExpectNear('i', io.Main().color()->Plane(0).Row(0)[1], eps); + ExpectNear('x', io.Main().color()->Plane(0).Row(1)[0], eps); + ExpectNear('e', io.Main().color()->Plane(0).Row(1)[1], eps); + ExpectNear('l', io.Main().color()->Plane(0).Row(2)[0], eps); + ExpectNear('s', io.Main().color()->Plane(0).Row(2)[1], eps); + } + + { + std::string pgx = "PG ML + 16 2 3\np_i_x_e_l_s_"; + + CodecInOut io; + ThreadPool* pool = nullptr; + + Status ok = DecodeImagePGX(MakeSpan(pgx.c_str()), pool, &io); + JXL_CHECK(ok == true); + + ScaleImage(255.f, io.Main().color()); + + JXL_CHECK(!io.metadata.m.bit_depth.floating_point_sample); + JXL_CHECK(io.metadata.m.bit_depth.bits_per_sample == 16); + JXL_CHECK(io.metadata.m.color_encoding.IsGray()); + JXL_CHECK(io.xsize() == 2); + JXL_CHECK(io.ysize() == 3); + float eps = 1e-7; + const auto& plane = io.Main().color()->Plane(0); + ExpectNear(256.0f * 'p' + '_', plane.Row(0)[0] * 257, eps); + ExpectNear(256.0f * 'i' + '_', plane.Row(0)[1] * 257, eps); + ExpectNear(256.0f * 'x' + '_', plane.Row(1)[0] * 257, eps); + ExpectNear(256.0f * 'e' + '_', plane.Row(1)[1] * 257, eps); + ExpectNear(256.0f * 'l' + '_', plane.Row(2)[0] * 257, eps); + ExpectNear(256.0f * 's' + '_', plane.Row(2)[1] * 257, eps); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/extras/codec_pgx.h b/third_party/jpeg-xl/lib/extras/codec_pgx.h new file mode 100644 index 000000000000..328214b47c86 --- /dev/null +++ 
b/third_party/jpeg-xl/lib/extras/codec_pgx.h @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_EXTRAS_CODEC_PGX_H_ +#define LIB_EXTRAS_CODEC_PGX_H_ + +// Encodes/decodes PGX pixels in memory. + +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" + +namespace jxl { + +// Decodes `bytes` into `io`. io->dec_hints may specify "color_space", which +// defaults to sRGB. +Status DecodeImagePGX(const Span bytes, ThreadPool* pool, + CodecInOut* io); + +// Transforms from io->c_current to `c_desired` and encodes into `bytes`. +Status EncodeImagePGX(const CodecInOut* io, const ColorEncoding& c_desired, + size_t bits_per_sample, ThreadPool* pool, + PaddedBytes* bytes); + +void TestCodecPGX(); +} // namespace jxl + +#endif // LIB_EXTRAS_CODEC_PGX_H_ diff --git a/third_party/jpeg-xl/lib/extras/codec_png.cc b/third_party/jpeg-xl/lib/extras/codec_png.cc new file mode 100644 index 000000000000..ed5dd964ac49 --- /dev/null +++ b/third_party/jpeg-xl/lib/extras/codec_png.cc @@ -0,0 +1,858 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/extras/codec_png.h" + +#include +#include +#include +#include + +// Lodepng library: +#include + +#include +#include +#include +#include +#include + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/luminance.h" + +namespace jxl { +namespace { + +#define JXL_PNG_VERBOSE 0 + +// Retrieves XMP and EXIF/IPTC from itext and text. +class BlobsReaderPNG { + public: + static Status Decode(const LodePNGInfo& info, Blobs* blobs) { + for (unsigned idx_itext = 0; idx_itext < info.itext_num; ++idx_itext) { + // We trust these are properly null-terminated by LodePNG. 
+ const char* key = info.itext_keys[idx_itext]; + const char* value = info.itext_strings[idx_itext]; + if (strstr(key, "XML:com.adobe.xmp")) { + blobs->xmp.resize(strlen(value)); // safe, see above + memcpy(blobs->xmp.data(), value, blobs->xmp.size()); + } + } + + for (unsigned idx_text = 0; idx_text < info.text_num; ++idx_text) { + // We trust these are properly null-terminated by LodePNG. + const char* key = info.text_keys[idx_text]; + const char* value = info.text_strings[idx_text]; + std::string type; + PaddedBytes bytes; + + // Handle text chunks annotated with key "Raw profile type ####", with + // #### a type, which may contain metadata. + const char* kKey = "Raw profile type "; + if (strncmp(key, kKey, strlen(kKey)) != 0) continue; + + if (!MaybeDecodeBase16(key, value, &type, &bytes)) { + JXL_WARNING("Couldn't parse 'Raw format type' text chunk"); + continue; + } + if (type == "exif") { + if (!blobs->exif.empty()) { + JXL_WARNING("overwriting EXIF (%zu bytes) with base16 (%zu bytes)", + blobs->exif.size(), bytes.size()); + } + blobs->exif = std::move(bytes); + } else if (type == "iptc") { + // TODO (jon): Deal with IPTC in some way + } else if (type == "8bim") { + // TODO (jon): Deal with 8bim in some way + } else if (type == "xmp") { + if (!blobs->xmp.empty()) { + JXL_WARNING("overwriting XMP (%zu bytes) with base16 (%zu bytes)", + blobs->xmp.size(), bytes.size()); + } + blobs->xmp = std::move(bytes); + } else { + JXL_WARNING( + "Unknown type in 'Raw format type' text chunk: %s: %zu bytes", + type.c_str(), bytes.size()); + } + } + + return true; + } + + private: + // Returns false if invalid. + static JXL_INLINE Status DecodeNibble(const char c, + uint32_t* JXL_RESTRICT nibble) { + if ('a' <= c && c <= 'f') { + *nibble = 10 + c - 'a'; + } else if ('0' <= c && c <= '9') { + *nibble = c - '0'; + } else { + *nibble = 0; + return JXL_FAILURE("Invalid metadata nibble"); + } + JXL_ASSERT(*nibble < 16); + return true; + } + + // Parses a PNG text chunk with key of the form "Raw profile type ####", with + // #### a type. + // Returns whether it could successfully parse the content. + // We trust key and encoded are null-terminated because they come from + // LodePNG. + static Status MaybeDecodeBase16(const char* key, const char* encoded, + std::string* type, PaddedBytes* bytes) { + const char* encoded_end = encoded + strlen(encoded); + + const char* kKey = "Raw profile type "; + if (strncmp(key, kKey, strlen(kKey)) != 0) return false; + *type = key + strlen(kKey); + const size_t kMaxTypeLen = 20; + if (type->length() > kMaxTypeLen) return false; // Type too long + + // Header: freeform string and number of bytes + unsigned long bytes_to_decode; + int header_len; + std::vector description((encoded_end - encoded) + 1); + const int fields = sscanf(encoded, "\n%[^\n]\n%8lu%n", description.data(), + &bytes_to_decode, &header_len); + if (fields != 2) return false; // Failed to decode metadata header + JXL_ASSERT(bytes->empty()); + bytes->reserve(bytes_to_decode); + + // Encoding: base16 with newline after 72 chars. 
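+    // Worked example (illustrative): a two-byte EXIF payload {0x12, 0xAB}
+    // arrives as the text value "\nexif\n       2\n12ab\n". The sscanf
+    // above consumes "\nexif\n       2"; the loop below then expects a
+    // newline before every 36-byte (72 hex character) group, decodes two
+    // nibbles per byte, and finally requires a single trailing newline.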
+ const char* pos = encoded + header_len; + for (size_t i = 0; i < bytes_to_decode; ++i) { + if (i % 36 == 0) { + if (pos + 1 >= encoded_end) return false; // Truncated base16 1 + if (*pos != '\n') return false; // Expected newline + ++pos; + } + + if (pos + 2 >= encoded_end) return false; // Truncated base16 2; + uint32_t nibble0, nibble1; + JXL_RETURN_IF_ERROR(DecodeNibble(pos[0], &nibble0)); + JXL_RETURN_IF_ERROR(DecodeNibble(pos[1], &nibble1)); + bytes->push_back(static_cast((nibble0 << 4) + nibble1)); + pos += 2; + } + if (pos + 1 != encoded_end) return false; // Too many encoded bytes + if (pos[0] != '\n') return false; // Incorrect metadata terminator + return true; + } +}; + +// Stores XMP and EXIF/IPTC into itext and text. +class BlobsWriterPNG { + public: + static Status Encode(const Blobs& blobs, LodePNGInfo* JXL_RESTRICT info) { + if (!blobs.exif.empty()) { + JXL_RETURN_IF_ERROR(EncodeBase16("exif", blobs.exif, info)); + } + if (!blobs.iptc.empty()) { + JXL_RETURN_IF_ERROR(EncodeBase16("iptc", blobs.iptc, info)); + } + + if (!blobs.xmp.empty()) { + JXL_RETURN_IF_ERROR(EncodeBase16("xmp", blobs.xmp, info)); + + // Below is the official way, but it does not seem to work in ImageMagick. + // Exiv2 and exiftool are OK with either way of encoding XMP. + if (/* DISABLES CODE */ (0)) { + const char* key = "XML:com.adobe.xmp"; + const std::string text(reinterpret_cast(blobs.xmp.data()), + blobs.xmp.size()); + if (lodepng_add_itext(info, key, "", "", text.c_str()) != 0) { + return JXL_FAILURE("Failed to add itext"); + } + } + } + + return true; + } + + private: + static JXL_INLINE char EncodeNibble(const uint8_t nibble) { + JXL_ASSERT(nibble < 16); + return (nibble < 10) ? '0' + nibble : 'a' + nibble - 10; + } + + static Status EncodeBase16(const std::string& type, const PaddedBytes& bytes, + LodePNGInfo* JXL_RESTRICT info) { + // Encoding: base16 with newline after 72 chars. + const size_t base16_size = + 2 * bytes.size() + DivCeil(bytes.size(), size_t(36)) + 1; + std::string base16; + base16.reserve(base16_size); + for (size_t i = 0; i < bytes.size(); ++i) { + if (i % 36 == 0) base16.push_back('\n'); + base16.push_back(EncodeNibble(bytes[i] >> 4)); + base16.push_back(EncodeNibble(bytes[i] & 0x0F)); + } + base16.push_back('\n'); + JXL_ASSERT(base16.length() == base16_size); + + char key[30]; + snprintf(key, sizeof(key), "Raw profile type %s", type.c_str()); + + char header[30]; + snprintf(header, sizeof(header), "\n%s\n%8zu", type.c_str(), bytes.size()); + + const std::string& encoded = std::string(header) + base16; + if (lodepng_add_text(info, key, encoded.c_str()) != 0) { + return JXL_FAILURE("Failed to add text"); + } + + return true; + } +}; + +// Retrieves ColorEncoding from PNG chunks. +class ColorEncodingReaderPNG { + public: + // Fills original->color_encoding or returns false. + Status operator()(const Span bytes, const bool is_gray, + CodecInOut* io) { + ColorEncoding* c_original = &io->metadata.m.color_encoding; + JXL_RETURN_IF_ERROR(Decode(bytes, &io->blobs)); + + const ColorSpace color_space = + is_gray ? ColorSpace::kGray : ColorSpace::kRGB; + + if (have_pq_) { + c_original->SetColorSpace(color_space); + c_original->white_point = WhitePoint::kD65; + c_original->primaries = Primaries::k2100; + c_original->tf.SetTransferFunction(TransferFunction::kPQ); + c_original->rendering_intent = RenderingIntent::kRelative; + if (c_original->CreateICC()) return true; + JXL_WARNING("Failed to synthesize BT.2100 PQ"); + // Else: try the actual ICC profile. 
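+      // (Overall precedence in this function: synthesized PQ first, then
+      // the embedded ICC profile, then an sRGB chunk, then gAMA/cHRM, and
+      // finally an sRGB fallback when nothing valid is present.)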
+ } + + // ICC overrides anything else if present. + if (c_original->SetICC(std::move(icc_))) { + if (have_srgb_) { + JXL_WARNING("Invalid PNG with both sRGB and ICC; ignoring sRGB"); + } + if (is_gray != c_original->IsGray()) { + return JXL_FAILURE("Mismatch between ICC and PNG header grayscale"); + } + return true; // it's fine to ignore gAMA/cHRM. + } + + // PNG requires that sRGB override gAMA/cHRM. + if (have_srgb_) { + return c_original->SetSRGB(color_space, rendering_intent_); + } + + // Try to create a custom profile: + + c_original->SetColorSpace(color_space); + + // Attempt to set whitepoint and primaries if there is a cHRM chunk, or else + // use default sRGB (the PNG then is device-dependent). + // In case of grayscale, do not attempt to set the primaries and ignore the + // ones the PNG image has (but still set the white point). + if (!have_chrm_ || !c_original->SetWhitePoint(white_point_) || + (!is_gray && !c_original->SetPrimaries(primaries_))) { +#if JXL_PNG_VERBOSE >= 1 + JXL_WARNING("No (valid) cHRM, assuming sRGB"); +#endif + c_original->white_point = WhitePoint::kD65; + c_original->primaries = Primaries::kSRGB; + } + + if (!have_gama_ || !c_original->tf.SetGamma(gamma_)) { +#if JXL_PNG_VERBOSE >= 1 + JXL_WARNING("No (valid) gAMA nor sRGB, assuming sRGB"); +#endif + c_original->tf.SetTransferFunction(TransferFunction::kSRGB); + } + + c_original->rendering_intent = RenderingIntent::kRelative; + if (c_original->CreateICC()) return true; + + JXL_WARNING( + "DATA LOSS: unable to create an ICC profile for PNG gAMA/cHRM.\n" + "Image pixels will be interpreted as sRGB. Please add an ICC \n" + "profile to the input image"); + return c_original->SetSRGB(color_space); + } + + // Whether the image has any color profile information (ICC chunk, sRGB + // chunk, cHRM chunk, and so on), or has no color information chunks at all. + bool HaveColorProfile() const { + return have_pq_ || have_srgb_ || have_gama_ || have_chrm_ || have_icc_; + } + + private: + Status DecodeICC(const unsigned char* const payload, + const size_t payload_size) { + if (payload_size == 0) return JXL_FAILURE("Empty ICC payload"); + const unsigned char* pos = payload; + const unsigned char* end = payload + payload_size; + + // Profile name + if (*pos == '\0') return JXL_FAILURE("Expected ICC name"); + for (size_t i = 0;; ++i) { + if (i == 80) return JXL_FAILURE("ICC profile name too long"); + if (pos == end) return JXL_FAILURE("Not enough bytes for ICC name"); + if (*pos++ == '\0') break; + } + + // Special case for BT.2100 PQ (https://w3c.github.io/png-hdr-pq/) - try to + // synthesize the profile because table-based curves are less accurate. + // strcmp is safe because we already verified the string is 0-terminated. + if (!strcmp(reinterpret_cast(payload), "ITUR_2100_PQ_FULL")) { + have_pq_ = true; + } + + // Skip over compression method (only one is allowed) + if (pos == end) return JXL_FAILURE("Not enough bytes for ICC method"); + if (*pos++ != 0) return JXL_FAILURE("Unsupported ICC method"); + + // Decompress + unsigned char* icc_buf = nullptr; + size_t icc_size = 0; + LodePNGDecompressSettings settings; + lodepng_decompress_settings_init(&settings); + const unsigned err = lodepng_zlib_decompress( + &icc_buf, &icc_size, pos, payload_size - (pos - payload), &settings); + if (err == 0) { + icc_.resize(icc_size); + memcpy(icc_.data(), icc_buf, icc_size); + } + free(icc_buf); + have_icc_ = true; + return true; + } + + // Returns floating-point value from the PNG encoding (times 10^5). 
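+  // Example: PNG stores these values as 32-bit integers scaled by 100000,
+  // so a gAMA payload of 45455 decodes to gamma 0.45455 (~1/2.2) and a
+  // cHRM D65 white point x of 31270 decodes to 0.3127.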
+ static double F64FromU32(const uint32_t x) { + return static_cast(x) * 1E-5; + } + + Status DecodeSRGB(const unsigned char* payload, const size_t payload_size) { + if (payload_size != 1) return JXL_FAILURE("Wrong sRGB size"); + // (PNG uses the same values as ICC.) + rendering_intent_ = static_cast(payload[0]); + have_srgb_ = true; + return true; + } + + Status DecodeGAMA(const unsigned char* payload, const size_t payload_size) { + if (payload_size != 4) return JXL_FAILURE("Wrong gAMA size"); + gamma_ = F64FromU32(LoadBE32(payload)); + have_gama_ = true; + return true; + } + + Status DecodeCHRM(const unsigned char* payload, const size_t payload_size) { + if (payload_size != 32) return JXL_FAILURE("Wrong cHRM size"); + white_point_.x = F64FromU32(LoadBE32(payload + 0)); + white_point_.y = F64FromU32(LoadBE32(payload + 4)); + primaries_.r.x = F64FromU32(LoadBE32(payload + 8)); + primaries_.r.y = F64FromU32(LoadBE32(payload + 12)); + primaries_.g.x = F64FromU32(LoadBE32(payload + 16)); + primaries_.g.y = F64FromU32(LoadBE32(payload + 20)); + primaries_.b.x = F64FromU32(LoadBE32(payload + 24)); + primaries_.b.y = F64FromU32(LoadBE32(payload + 28)); + have_chrm_ = true; + return true; + } + + Status DecodeEXIF(const unsigned char* payload, const size_t payload_size, + Blobs* blobs) { + // If we already have EXIF, keep the larger one. + if (blobs->exif.size() > payload_size) return true; + blobs->exif.resize(payload_size); + memcpy(blobs->exif.data(), payload, payload_size); + return true; + } + + Status Decode(const Span bytes, Blobs* blobs) { + // Look for colorimetry and text chunks in the PNG image. The PNG chunks + // begin after the PNG magic header of 8 bytes. + const unsigned char* chunk = bytes.data() + 8; + const unsigned char* end = bytes.data() + bytes.size(); + for (;;) { + // chunk points to the first field of a PNG chunk. The chunk has + // respectively 4 bytes of length, 4 bytes type, length bytes of data, + // 4 bytes CRC. + if (chunk + 4 >= end) { + break; // Regular end reached. + } + + char type_char[5]; + if (chunk + 8 >= end) { + JXL_NOTIFY_ERROR("PNG: malformed chunk"); + break; + } + lodepng_chunk_type(type_char, chunk); + std::string type = type_char; + + if (type == "acTL" || type == "fcTL" || type == "fdAT") { + // this is an APNG file, without proper handling we would just return + // the first frame, so for now codec_apng handles animation until the + // animation chunk handling is added here + return false; + } + if (type == "eXIf" || type == "iCCP" || type == "sRGB" || + type == "gAMA" || type == "cHRM") { + const unsigned char* payload = lodepng_chunk_data_const(chunk); + const size_t payload_size = lodepng_chunk_length(chunk); + // The entire chunk needs also 4 bytes of CRC after the payload. 
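+        // e.g. a gAMA chunk occupies 4 (length) + 4 ("gAMA") + 4 (payload)
+        // + 4 (CRC) = 16 bytes, so payload + payload_size + 4 must still
+        // lie inside the buffer before the CRC check below may read it.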
+ if (payload + payload_size + 4 >= end) { + JXL_NOTIFY_ERROR("PNG: truncated chunk"); + break; + } + if (lodepng_chunk_check_crc(chunk) != 0) { + JXL_NOTIFY_ERROR("CRC mismatch in unknown PNG chunk"); + chunk = lodepng_chunk_next_const(chunk, end); + continue; + } + + if (type == "eXIf") { + JXL_RETURN_IF_ERROR(DecodeEXIF(payload, payload_size, blobs)); + } else if (type == "iCCP") { + JXL_RETURN_IF_ERROR(DecodeICC(payload, payload_size)); + } else if (type == "sRGB") { + JXL_RETURN_IF_ERROR(DecodeSRGB(payload, payload_size)); + } else if (type == "gAMA") { + JXL_RETURN_IF_ERROR(DecodeGAMA(payload, payload_size)); + } else if (type == "cHRM") { + JXL_RETURN_IF_ERROR(DecodeCHRM(payload, payload_size)); + } + } + + chunk = lodepng_chunk_next_const(chunk, end); + } + return true; + } + + PaddedBytes icc_; + + bool have_pq_ = false; + bool have_srgb_ = false; + bool have_gama_ = false; + bool have_chrm_ = false; + bool have_icc_ = false; + + // Only valid if have_srgb_: + RenderingIntent rendering_intent_; + + // Only valid if have_gama_: + double gamma_; + + // Only valid if have_chrm_: + CIExy white_point_; + PrimariesCIExy primaries_; +}; + +Status ApplyHints(const bool is_gray, CodecInOut* io) { + bool got_color_space = false; + + JXL_RETURN_IF_ERROR(io->dec_hints.Foreach( + [is_gray, io, &got_color_space](const std::string& key, + const std::string& value) -> Status { + ColorEncoding* c_original = &io->metadata.m.color_encoding; + if (key == "color_space") { + if (!ParseDescription(value, c_original) || + !c_original->CreateICC()) { + return JXL_FAILURE("PNG: Failed to apply color_space"); + } + + if (is_gray != io->metadata.m.color_encoding.IsGray()) { + return JXL_FAILURE( + "PNG: mismatch between file and color_space hint"); + } + + got_color_space = true; + } else if (key == "icc_pathname") { + PaddedBytes icc; + JXL_RETURN_IF_ERROR(ReadFile(value, &icc)); + JXL_RETURN_IF_ERROR(c_original->SetICC(std::move(icc))); + got_color_space = true; + } else { + JXL_WARNING("PNG decoder ignoring %s hint", key.c_str()); + } + return true; + })); + + if (!got_color_space) { + JXL_WARNING("PNG: no color_space/icc_pathname given, assuming sRGB"); + JXL_RETURN_IF_ERROR(io->metadata.m.color_encoding.SetSRGB( + is_gray ? ColorSpace::kGray : ColorSpace::kRGB)); + } + + return true; +} + +// Stores ColorEncoding into PNG chunks. +class ColorEncodingWriterPNG { + public: + static Status Encode(const ColorEncoding& c, LodePNGInfo* JXL_RESTRICT info) { + // Prefer to only write sRGB - smaller. + if (c.IsSRGB()) { + JXL_RETURN_IF_ERROR(AddSRGB(c, info)); + // PNG recommends not including both sRGB and iCCP, so skip the latter. + } else { + JXL_ASSERT(!c.ICC().empty()); + JXL_RETURN_IF_ERROR(AddICC(c.ICC(), info)); + } + + // gAMA and cHRM are always allowed but will be overridden by sRGB/iCCP. + JXL_RETURN_IF_ERROR(MaybeAddGAMA(c, info)); + JXL_RETURN_IF_ERROR(MaybeAddCHRM(c, info)); + return true; + } + + private: + static Status AddChunk(const char* type, const PaddedBytes& payload, + LodePNGInfo* JXL_RESTRICT info) { + // Ignore original location/order of chunks; place them in the first group. 
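+    // (LodePNG keeps three groups of unknown chunks: group 0 is written
+    // before PLTE, group 1 before IDAT, group 2 after IDAT. Colorimetry
+    // chunks must precede the image data, hence the first group.)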
+ if (lodepng_chunk_create(&info->unknown_chunks_data[0], + &info->unknown_chunks_size[0], payload.size(), + type, payload.data()) != 0) { + return JXL_FAILURE("Failed to add chunk"); + } + return true; + } + + static Status AddICC(const PaddedBytes& icc, LodePNGInfo* JXL_RESTRICT info) { + LodePNGCompressSettings settings; + lodepng_compress_settings_init(&settings); + unsigned char* out = nullptr; + size_t out_size = 0; + if (lodepng_zlib_compress(&out, &out_size, icc.data(), icc.size(), + &settings) != 0) { + return JXL_FAILURE("Failed to compress ICC"); + } + + PaddedBytes payload; + payload.resize(3 + out_size); + // TODO(janwas): use special name if PQ + payload[0] = '1'; // profile name + payload[1] = '\0'; + payload[2] = 0; // compression method (zlib) + memcpy(&payload[3], out, out_size); + free(out); + + return AddChunk("iCCP", payload, info); + } + + static Status AddSRGB(const ColorEncoding& c, + LodePNGInfo* JXL_RESTRICT info) { + PaddedBytes payload; + payload.push_back(static_cast(c.rendering_intent)); + return AddChunk("sRGB", payload, info); + } + + // Returns PNG encoding of floating-point value (times 10^5). + static uint32_t U32FromF64(const double x) { + return static_cast(roundf(x * 1E5)); + } + + static Status MaybeAddGAMA(const ColorEncoding& c, + LodePNGInfo* JXL_RESTRICT info) { + if (!c.tf.IsGamma()) return true; + const double gamma = c.tf.GetGamma(); + + PaddedBytes payload(4); + StoreBE32(U32FromF64(gamma), payload.data()); + return AddChunk("gAMA", payload, info); + } + + static Status MaybeAddCHRM(const ColorEncoding& c, + LodePNGInfo* JXL_RESTRICT info) { + // TODO(lode): remove this, PNG can also have cHRM for P3, sRGB, ... + if (c.white_point != WhitePoint::kCustom && + c.primaries != Primaries::kCustom) { + return true; + } + + const CIExy white_point = c.GetWhitePoint(); + // A PNG image stores both whitepoint and primaries in the cHRM chunk, but + // for grayscale images we don't have primaries. It does not matter what + // values are stored in the PNG though (all colors are a multiple of the + // whitepoint), so choose default ones. See + // http://www.libpng.org/pub/png/spec/1.2/PNG-Chunks.html section 4.2.2.1. + const PrimariesCIExy primaries = + c.IsGray() ? ColorEncoding().GetPrimaries() : c.GetPrimaries(); + + PaddedBytes payload(32); + StoreBE32(U32FromF64(white_point.x), &payload[0]); + StoreBE32(U32FromF64(white_point.y), &payload[4]); + StoreBE32(U32FromF64(primaries.r.x), &payload[8]); + StoreBE32(U32FromF64(primaries.r.y), &payload[12]); + StoreBE32(U32FromF64(primaries.g.x), &payload[16]); + StoreBE32(U32FromF64(primaries.g.y), &payload[20]); + StoreBE32(U32FromF64(primaries.b.x), &payload[24]); + StoreBE32(U32FromF64(primaries.b.y), &payload[28]); + return AddChunk("cHRM", payload, info); + } +}; + +// RAII - ensures state is freed even if returning early. +struct PNGState { + PNGState() { lodepng_state_init(&s); } + ~PNGState() { lodepng_state_cleanup(&s); } + + LodePNGState s; +}; + +Status CheckGray(const LodePNGColorMode& mode, bool has_icc, bool* is_gray) { + switch (mode.colortype) { + case LCT_GREY: + case LCT_GREY_ALPHA: + *is_gray = true; + return true; + + case LCT_RGB: + case LCT_RGBA: + *is_gray = false; + return true; + + case LCT_PALETTE: { + if (has_icc) { + // If an ICC profile is present, the PNG specification requires + // palette to be intepreted as RGB colored, not grayscale, so we must + // output color in that case and unfortunately can't optimize it to + // gray if the palette only has gray entries. 
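+        // (LodePNG exposes the palette as RGBA quadruples, so entry i is
+        // gray exactly when palette[4 * i] == palette[4 * i + 1] ==
+        // palette[4 * i + 2]; the else branch below scans for that.)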
+        *is_gray = false;
+        return true;
+      } else {
+        *is_gray = true;
+        for (size_t i = 0; i < mode.palettesize; i++) {
+          if (mode.palette[i * 4] != mode.palette[i * 4 + 1] ||
+              mode.palette[i * 4] != mode.palette[i * 4 + 2]) {
+            *is_gray = false;
+            break;
+          }
+        }
+        return true;
+      }
+    }
+
+    default:
+      *is_gray = false;
+      return JXL_FAILURE("Unexpected PNG color type");
+  }
+}
+
+Status CheckAlpha(const LodePNGColorMode& mode, bool* has_alpha) {
+  if (mode.key_defined) {
+    // Color key marks a single color as transparent.
+    *has_alpha = true;
+    return true;
+  }
+
+  switch (mode.colortype) {
+    case LCT_GREY:
+    case LCT_RGB:
+      *has_alpha = false;
+      return true;
+
+    case LCT_GREY_ALPHA:
+    case LCT_RGBA:
+      *has_alpha = true;
+      return true;
+
+    case LCT_PALETTE: {
+      *has_alpha = false;
+      for (size_t i = 0; i < mode.palettesize; i++) {
+        // PNG palettes are always 8-bit.
+        if (mode.palette[i * 4 + 3] != 255) {
+          *has_alpha = true;
+          break;
+        }
+      }
+      return true;
+    }
+
+    default:
+      *has_alpha = false;
+      return JXL_FAILURE("Unexpected PNG color type");
+  }
+}
+
+LodePNGColorType MakeType(const bool is_gray, const bool has_alpha) {
+  if (is_gray) {
+    return has_alpha ? LCT_GREY_ALPHA : LCT_GREY;
+  }
+  return has_alpha ? LCT_RGBA : LCT_RGB;
+}
+
+// Inspects first chunk of the given type and updates state with the
+// information when the chunk is relevant and present in the file.
+Status InspectChunkType(const Span<const uint8_t> bytes,
+                        const std::string& type, LodePNGState* state) {
+  const unsigned char* chunk = lodepng_chunk_find_const(
+      bytes.data(), bytes.data() + bytes.size(), type.c_str());
+  if (chunk && lodepng_inspect_chunk(state, chunk - bytes.data(), bytes.data(),
+                                     bytes.size()) != 0) {
+    return JXL_FAILURE("Invalid chunk \"%s\" in PNG image", type.c_str());
+  }
+  return true;
+}
+
+}  // namespace
+
+Status DecodeImagePNG(const Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io) {
+  unsigned w, h;
+  PNGState state;
+  if (lodepng_inspect(&w, &h, &state.s, bytes.data(), bytes.size()) != 0) {
+    return false;  // not an error - just wrong format
+  }
+  JXL_RETURN_IF_ERROR(VerifyDimensions(&io->constraints, w, h));
+  io->SetSize(w, h);
+  // Palette RGB values
+  if (!InspectChunkType(bytes, "PLTE", &state.s)) {
+    return false;
+  }
+  // Transparent color key, or palette transparency
+  if (!InspectChunkType(bytes, "tRNS", &state.s)) {
+    return false;
+  }
+  // ICC profile
+  if (!InspectChunkType(bytes, "iCCP", &state.s)) {
+    return false;
+  }
+  const LodePNGColorMode& color_mode = state.s.info_png.color;
+  bool has_icc = state.s.info_png.iccp_defined;
+
+  bool is_gray, has_alpha;
+  JXL_RETURN_IF_ERROR(CheckGray(color_mode, has_icc, &is_gray));
+  JXL_RETURN_IF_ERROR(CheckAlpha(color_mode, &has_alpha));
+  // We want LodePNG to promote 1/2/4 bit pixels to 8.
+  size_t bits_per_sample = std::max(color_mode.bitdepth, 8u);
+  if (bits_per_sample != 8 && bits_per_sample != 16) {
+    return JXL_FAILURE("Unexpected PNG bit depth");
+  }
+  io->metadata.m.SetUintSamples(static_cast<uint32_t>(bits_per_sample));
+  io->metadata.m.SetAlphaBits(
+      has_alpha ? io->metadata.m.bit_depth.bits_per_sample : 0);
+
+  // Always decode to 8/16-bit RGB/RGBA, not LCT_PALETTE.
+  state.s.info_raw.bitdepth = static_cast<unsigned>(bits_per_sample);
+  state.s.info_raw.colortype = MakeType(is_gray, has_alpha);
+  unsigned char* out = nullptr;
+  const unsigned err =
+      lodepng_decode(&out, &w, &h, &state.s, bytes.data(), bytes.size());
+  // Automatically call free(out) on return.
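+  // (A minimal sketch of the RAII idiom on the next line: lodepng allocates
+  // `out` with malloc, so it must be released via free() rather than
+  // delete[]. Binding any malloc'd `buf` to scope exit looks like
+  //   std::unique_ptr<unsigned char, void (*)(void*)> guard{buf, free};
+  // so every early-return path above/below releases the buffer.)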
+  std::unique_ptr<unsigned char, void (*)(void*)> out_ptr{out, free};
+  if (err != 0) {
+    return JXL_FAILURE("PNG decode failed: %s", lodepng_error_text(err));
+  }
+
+  if (!BlobsReaderPNG::Decode(state.s.info_png, &io->blobs)) {
+    JXL_WARNING("PNG metadata may be incomplete");
+  }
+  ColorEncodingReaderPNG reader;
+  JXL_RETURN_IF_ERROR(reader(bytes, is_gray, io));
+#if JXL_PNG_VERBOSE >= 1
+  printf("PNG read %s\n", Description(io->metadata.m.color_encoding).c_str());
+#endif
+
+  const size_t num_channels = (is_gray ? 1 : 3) + has_alpha;
+  const size_t out_size =
+      w * h * num_channels * bits_per_sample / kBitsPerByte;
+
+  const JxlEndianness endianness = JXL_BIG_ENDIAN;  // PNG requirement
+  const Span<const uint8_t> span(out, out_size);
+  const bool ok =
+      ConvertFromExternal(span, w, h, io->metadata.m.color_encoding, has_alpha,
+                          /*alpha_is_premultiplied=*/false,
+                          io->metadata.m.bit_depth.bits_per_sample, endianness,
+                          /*flipped_y=*/false, pool, &io->Main());
+  JXL_RETURN_IF_ERROR(ok);
+  io->dec_pixels = w * h;
+  io->metadata.m.bit_depth.bits_per_sample = io->Main().DetectRealBitdepth();
+  SetIntensityTarget(io);
+  if (!reader.HaveColorProfile()) {
+    JXL_RETURN_IF_ERROR(ApplyHints(is_gray, io));
+  } else {
+    (void)io->dec_hints.Foreach(
+        [](const std::string& key, const std::string& /*value*/) {
+          JXL_WARNING("PNG decoder ignoring %s hint", key.c_str());
+          return true;
+        });
+  }
+  return true;
+}
+
+Status EncodeImagePNG(const CodecInOut* io, const ColorEncoding& c_desired,
+                      size_t bits_per_sample, ThreadPool* pool,
+                      PaddedBytes* bytes) {
+  if (bits_per_sample > 8) {
+    bits_per_sample = 16;
+  } else if (bits_per_sample < 8) {
+    // PNG can also do 4, 2, and 1 bits per sample, but it isn't implemented
+    bits_per_sample = 8;
+  }
+  ImageBundle ib = io->Main().Copy();
+  const size_t alpha_bits = ib.HasAlpha() ? bits_per_sample : 0;
+  ImageMetadata metadata = io->metadata.m;
+  ImageBundle store(&metadata);
+  const ImageBundle* transformed;
+  JXL_RETURN_IF_ERROR(
+      TransformIfNeeded(ib, c_desired, pool, &store, &transformed));
+  size_t stride = ib.oriented_xsize() *
+                  DivCeil(c_desired.Channels() * bits_per_sample + alpha_bits,
+                          kBitsPerByte);
+  PaddedBytes raw_bytes(stride * ib.oriented_ysize());
+  JXL_RETURN_IF_ERROR(ConvertToExternal(
+      *transformed, bits_per_sample, /*float_out=*/false,
+      c_desired.Channels() + (ib.HasAlpha() ? 1 : 0), JXL_BIG_ENDIAN, stride,
+      pool, raw_bytes.data(), raw_bytes.size(), metadata.GetOrientation()));
+
+  PNGState state;
+  // For maximum compatibility, still store 8-bit even if pixels are all zero.
+  state.s.encoder.auto_convert = 0;
+
+  LodePNGInfo* info = &state.s.info_png;
+  info->color.bitdepth = bits_per_sample;
+  info->color.colortype = MakeType(ib.IsGray(), ib.HasAlpha());
+  state.s.info_raw = info->color;
+
+  JXL_RETURN_IF_ERROR(ColorEncodingWriterPNG::Encode(c_desired, info));
+  JXL_RETURN_IF_ERROR(BlobsWriterPNG::Encode(io->blobs, info));
+
+  unsigned char* out = nullptr;
+  size_t out_size = 0;
+  const unsigned err =
+      lodepng_encode(&out, &out_size, raw_bytes.data(), ib.oriented_xsize(),
+                     ib.oriented_ysize(), &state.s);
+  // Automatically call free(out) on return.
+  std::unique_ptr<unsigned char, void (*)(void*)> out_ptr{out, free};
+  if (err != 0) {
+    return JXL_FAILURE("Failed to encode PNG: %s", lodepng_error_text(err));
+  }
+  bytes->resize(out_size);
+  memcpy(bytes->data(), out, out_size);
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/extras/codec_png.h b/third_party/jpeg-xl/lib/extras/codec_png.h
new file mode 100644
index 000000000000..c86bc74eb2f8
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_png.h
@@ -0,0 +1,46 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_EXTRAS_CODEC_PNG_H_
+#define LIB_EXTRAS_CODEC_PNG_H_
+
+// Encodes/decodes PNG pixels and metadata in memory.
+
+#include <stddef.h>
+#include <stdint.h>
+
+// TODO(janwas): workaround for incorrect Win64 codegen (cause unknown)
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+
+namespace jxl {
+
+// Decodes `bytes` into `io`. io->dec_hints are ignored.
+Status DecodeImagePNG(const Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io);
+
+// Transforms from io->c_current to `c_desired` and encodes into `bytes`.
+Status EncodeImagePNG(const CodecInOut* io, const ColorEncoding& c_desired,
+                      size_t bits_per_sample, ThreadPool* pool,
+                      PaddedBytes* bytes);
+
+}  // namespace jxl
+
+#endif  // LIB_EXTRAS_CODEC_PNG_H_
diff --git a/third_party/jpeg-xl/lib/extras/codec_pnm.cc b/third_party/jpeg-xl/lib/extras/codec_pnm.cc
new file mode 100644
index 000000000000..7f82630a523f
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_pnm.cc
@@ -0,0 +1,461 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/codec_pnm.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/file_io.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/dec_external_image.h"
+#include "lib/jxl/enc_external_image.h"
+#include "lib/jxl/fields.h"  // AllDefault
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/luminance.h"
+
+namespace jxl {
+namespace {
+
+struct HeaderPNM {
+  size_t xsize;
+  size_t ysize;
+  bool is_bit;   // PBM
+  bool is_gray;  // PGM
+  size_t bits_per_sample;
+  bool floating_point;
+  bool big_endian;
+};
+
+class Parser {
+ public:
+  explicit Parser(const Span<const uint8_t> bytes)
+      : pos_(bytes.data()), end_(pos_ + bytes.size()) {}
+
+  // Sets "pos" to the first non-header byte/pixel on success.
+  Status ParseHeader(HeaderPNM* header, const uint8_t** pos) {
+    // codec.cc ensures we have at least two bytes => no range check here.
+    if (pos_[0] != 'P') return false;
+    const uint8_t type = pos_[1];
+    pos_ += 2;
+
+    header->is_bit = false;
+
+    switch (type) {
+      case '4':
+        header->is_bit = true;
+        header->is_gray = true;
+        header->bits_per_sample = 1;
+        return ParseHeaderPNM(header, pos);
+
+      case '5':
+        header->is_gray = true;
+        return ParseHeaderPNM(header, pos);
+
+      case '6':
+        header->is_gray = false;
+        return ParseHeaderPNM(header, pos);
+
+      case 'F':
+        header->is_gray = false;
+        return ParseHeaderPFM(header, pos);
+
+      case 'f':
+        header->is_gray = true;
+        return ParseHeaderPFM(header, pos);
+    }
+    return false;
+  }
+
+  // Exposed for testing
+  Status ParseUnsigned(size_t* number) {
+    if (pos_ == end_) return JXL_FAILURE("PNM: reached end before number");
+    if (!IsDigit(*pos_)) return JXL_FAILURE("PNM: expected unsigned number");
+
+    *number = 0;
+    while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') {
+      *number *= 10;
+      *number += *pos_ - '0';
+      ++pos_;
+    }
+
+    return true;
+  }
+
+  Status ParseSigned(double* number) {
+    if (pos_ == end_) return JXL_FAILURE("PNM: reached end before signed");
+
+    if (*pos_ != '-' && *pos_ != '+' && !IsDigit(*pos_)) {
+      return JXL_FAILURE("PNM: expected signed number");
+    }
+
+    // Skip sign
+    const bool is_neg = *pos_ == '-';
+    if (is_neg || *pos_ == '+') {
+      ++pos_;
+      if (pos_ == end_) return JXL_FAILURE("PNM: reached end before digits");
+    }
+
+    // Leading digits
+    *number = 0.0;
+    while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') {
+      *number *= 10;
+      *number += *pos_ - '0';
+      ++pos_;
+    }
+
+    // Decimal places?
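+    // (Worked example: for "3.25" the loop above leaves *number == 3.0, and
+    // the branch below adds 2 * 0.1 + 5 * 0.01, yielding 3.25.)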
+ if (pos_ < end_ && *pos_ == '.') { + ++pos_; + double place = 0.1; + while (pos_ < end_ && *pos_ >= '0' && *pos_ <= '9') { + *number += (*pos_ - '0') * place; + place *= 0.1; + ++pos_; + } + } + + if (is_neg) *number = -*number; + return true; + } + + private: + static bool IsDigit(const uint8_t c) { return '0' <= c && c <= '9'; } + static bool IsLineBreak(const uint8_t c) { return c == '\r' || c == '\n'; } + static bool IsWhitespace(const uint8_t c) { + return IsLineBreak(c) || c == '\t' || c == ' '; + } + + Status SkipBlank() { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before blank"); + const uint8_t c = *pos_; + if (c != ' ' && c != '\n') return JXL_FAILURE("PNM: expected blank"); + ++pos_; + return true; + } + + Status SkipSingleWhitespace() { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before whitespace"); + if (!IsWhitespace(*pos_)) return JXL_FAILURE("PNM: expected whitespace"); + ++pos_; + return true; + } + + Status SkipWhitespace() { + if (pos_ == end_) return JXL_FAILURE("PNM: reached end before whitespace"); + if (!IsWhitespace(*pos_) && *pos_ != '#') { + return JXL_FAILURE("PNM: expected whitespace/comment"); + } + + while (pos_ < end_ && IsWhitespace(*pos_)) { + ++pos_; + } + + // Comment(s) + while (pos_ != end_ && *pos_ == '#') { + while (pos_ != end_ && !IsLineBreak(*pos_)) { + ++pos_; + } + // Newline(s) + while (pos_ != end_ && IsLineBreak(*pos_)) pos_++; + } + + while (pos_ < end_ && IsWhitespace(*pos_)) { + ++pos_; + } + return true; + } + + Status ParseHeaderPNM(HeaderPNM* header, const uint8_t** pos) { + JXL_RETURN_IF_ERROR(SkipWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + + JXL_RETURN_IF_ERROR(SkipWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + + if (!header->is_bit) { + JXL_RETURN_IF_ERROR(SkipWhitespace()); + size_t max_val; + JXL_RETURN_IF_ERROR(ParseUnsigned(&max_val)); + if (max_val == 0 || max_val >= 65536) { + return JXL_FAILURE("PNM: bad MaxVal"); + } + header->bits_per_sample = CeilLog2Nonzero(max_val); + } + header->floating_point = false; + header->big_endian = true; + + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + + *pos = pos_; + return true; + } + + Status ParseHeaderPFM(HeaderPNM* header, const uint8_t** pos) { + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->xsize)); + + JXL_RETURN_IF_ERROR(SkipBlank()); + JXL_RETURN_IF_ERROR(ParseUnsigned(&header->ysize)); + + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + // The scale has no meaning as multiplier, only its sign is used to + // indicate endianness. All software expects nominal range 0..1. + double scale; + JXL_RETURN_IF_ERROR(ParseSigned(&scale)); + header->big_endian = scale >= 0.0; + header->bits_per_sample = 32; + header->floating_point = true; + + JXL_RETURN_IF_ERROR(SkipSingleWhitespace()); + + *pos = pos_; + return true; + } + + const uint8_t* pos_; + const uint8_t* const end_; +}; + +constexpr size_t kMaxHeaderSize = 200; + +Status EncodeHeader(const ImageBundle& ib, const size_t bits_per_sample, + const bool little_endian, char* header, + int* JXL_RESTRICT chars_written) { + if (ib.HasAlpha()) return JXL_FAILURE("PNM: can't store alpha"); + + if (bits_per_sample == 32) { // PFM + const char type = ib.IsGray() ? 'f' : 'F'; + const double scale = little_endian ? 
-1.0 : 1.0;
+    snprintf(header, kMaxHeaderSize, "P%c\n%zu %zu\n%.1f\n%n", type,
+             ib.oriented_xsize(), ib.oriented_ysize(), scale, chars_written);
+  } else if (bits_per_sample == 1) {  // PBM
+    if (!ib.IsGray()) {
+      return JXL_FAILURE("Cannot encode color as PBM");
+    }
+    snprintf(header, kMaxHeaderSize, "P4\n%zu %zu\n%n", ib.oriented_xsize(),
+             ib.oriented_ysize(), chars_written);
+  } else {  // PGM/PPM
+    const uint32_t max_val = (1U << bits_per_sample) - 1;
+    if (max_val >= 65536) return JXL_FAILURE("PNM cannot have > 16 bits");
+    const char type = ib.IsGray() ? '5' : '6';
+    snprintf(header, kMaxHeaderSize, "P%c\n%zu %zu\n%u\n%n", type,
+             ib.oriented_xsize(), ib.oriented_ysize(), max_val, chars_written);
+  }
+  return true;
+}
+
+Status ApplyHints(const bool is_gray, CodecInOut* io) {
+  bool got_color_space = false;
+
+  JXL_RETURN_IF_ERROR(io->dec_hints.Foreach(
+      [is_gray, io, &got_color_space](const std::string& key,
+                                      const std::string& value) -> Status {
+        ColorEncoding* c_original = &io->metadata.m.color_encoding;
+        if (key == "color_space") {
+          if (!ParseDescription(value, c_original) ||
+              !c_original->CreateICC()) {
+            return JXL_FAILURE("PNM: Failed to apply color_space");
+          }
+
+          if (is_gray != io->metadata.m.color_encoding.IsGray()) {
+            return JXL_FAILURE(
+                "PNM: mismatch between file and color_space hint");
+          }
+
+          got_color_space = true;
+        } else if (key == "icc_pathname") {
+          PaddedBytes icc;
+          JXL_RETURN_IF_ERROR(ReadFile(value, &icc));
+          JXL_RETURN_IF_ERROR(c_original->SetICC(std::move(icc)));
+          got_color_space = true;
+        } else {
+          JXL_WARNING("PNM decoder ignoring %s hint", key.c_str());
+        }
+        return true;
+      }));
+
+  if (!got_color_space) {
+    JXL_WARNING("PNM: no color_space/icc_pathname given, assuming sRGB");
+    JXL_RETURN_IF_ERROR(io->metadata.m.color_encoding.SetSRGB(
+        is_gray ? ColorSpace::kGray : ColorSpace::kRGB));
+  }
+
+  return true;
+}
+
+Span<const uint8_t> MakeSpan(const char* str) {
+  return Span<const uint8_t>(reinterpret_cast<const uint8_t*>(str),
+                             strlen(str));
+}
+
+// Flip the image vertically for loading/saving PFM files which have the
+// scanlines inverted.
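+// (PFM stores the bottom scanline first: an image with rows A,B,C on screen
+// is laid out C,B,A in the file, so one in-place swap per row pair, as done
+// below, converts in either direction.)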
+void VerticallyFlipImage(Image3F* const image) {
+  for (int c = 0; c < 3; c++) {
+    for (size_t y = 0; y < image->ysize() / 2; y++) {
+      float* first_row = image->PlaneRow(c, y);
+      float* other_row = image->PlaneRow(c, image->ysize() - y - 1);
+      for (size_t x = 0; x < image->xsize(); ++x) {
+        float tmp = first_row[x];
+        first_row[x] = other_row[x];
+        other_row[x] = tmp;
+      }
+    }
+  }
+}
+
+}  // namespace
+
+Status DecodeImagePNM(const Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io) {
+  Parser parser(bytes);
+  HeaderPNM header = {};
+  const uint8_t* pos = nullptr;
+  if (!parser.ParseHeader(&header, &pos)) return false;
+  JXL_RETURN_IF_ERROR(
+      VerifyDimensions(&io->constraints, header.xsize, header.ysize));
+
+  if (header.bits_per_sample == 0 || header.bits_per_sample > 32) {
+    return JXL_FAILURE("PNM: bits_per_sample invalid");
+  }
+
+  JXL_RETURN_IF_ERROR(ApplyHints(header.is_gray, io));
+  if (header.floating_point) {
+    io->metadata.m.SetFloat32Samples();
+  } else {
+    io->metadata.m.SetUintSamples(header.bits_per_sample);
+  }
+  io->metadata.m.SetAlphaBits(0);
+  io->dec_pixels = header.xsize * header.ysize;
+
+  const bool flipped_y = header.bits_per_sample == 32;  // PFMs are flipped
+  const Span<const uint8_t> span(pos, bytes.data() + bytes.size() - pos);
+  JXL_RETURN_IF_ERROR(ConvertFromExternal(
+      span, header.xsize, header.ysize, io->metadata.m.color_encoding,
+      /*has_alpha=*/false, /*alpha_is_premultiplied=*/false,
+      io->metadata.m.bit_depth.bits_per_sample,
+      header.big_endian ? JXL_BIG_ENDIAN : JXL_LITTLE_ENDIAN, flipped_y, pool,
+      &io->Main()));
+  if (!header.floating_point) {
+    io->metadata.m.bit_depth.bits_per_sample = io->Main().DetectRealBitdepth();
+  }
+  io->SetSize(header.xsize, header.ysize);
+  SetIntensityTarget(io);
+  return true;
+}
+
+Status EncodeImagePNM(const CodecInOut* io, const ColorEncoding& c_desired,
+                      size_t bits_per_sample, ThreadPool* pool,
+                      PaddedBytes* bytes) {
+  const bool floating_point = bits_per_sample > 16;
+  // Choose native for PFM; PGM/PPM require big-endian (N/A for PBM)
+  const JxlEndianness endianness =
+      floating_point ? JXL_NATIVE_ENDIAN : JXL_BIG_ENDIAN;
+
+  ImageMetadata metadata_copy = io->metadata.m;
+  // AllDefault sets all_default, which can cause a race condition.
+  if (!Bundle::AllDefault(metadata_copy)) {
+    JXL_WARNING("PNM encoder ignoring metadata - use a different codec");
+  }
+  if (!c_desired.IsSRGB()) {
+    JXL_WARNING(
+        "PNM encoder cannot store custom ICC profile; decoder\n"
+        "will need hint key=color_space to get the same values");
+  }
+
+  ImageBundle ib = io->Main().Copy();
+  // In case of PFM the image must be flipped upside down since that format
+  // is designed that way.
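+  // (A PFM header reads e.g. "PF\n640 480\n-1.0\n": "PF" = color, "Pf" =
+  // grayscale, and the sign of the scale field encodes endianness, negative
+  // meaning little-endian.)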
+  const ImageBundle* to_color_transform = &ib;
+  ImageBundle flipped;
+  if (floating_point) {
+    flipped = ib.Copy();
+    VerticallyFlipImage(flipped.color());
+    to_color_transform = &flipped;
+  }
+  ImageMetadata metadata = io->metadata.m;
+  ImageBundle store(&metadata);
+  const ImageBundle* transformed;
+  JXL_RETURN_IF_ERROR(TransformIfNeeded(*to_color_transform, c_desired, pool,
+                                        &store, &transformed));
+  size_t stride = ib.oriented_xsize() *
+                  (c_desired.Channels() * bits_per_sample) / kBitsPerByte;
+  PaddedBytes pixels(stride * ib.oriented_ysize());
+  JXL_RETURN_IF_ERROR(ConvertToExternal(
+      *transformed, bits_per_sample, floating_point, c_desired.Channels(),
+      endianness, stride, pool, pixels.data(), pixels.size(),
+      metadata.GetOrientation()));
+
+  char header[kMaxHeaderSize];
+  int header_size = 0;
+  bool is_little_endian = endianness == JXL_LITTLE_ENDIAN ||
+                          (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian());
+  JXL_RETURN_IF_ERROR(EncodeHeader(*transformed, bits_per_sample,
+                                   is_little_endian, header, &header_size));
+
+  bytes->resize(static_cast<size_t>(header_size) + pixels.size());
+  memcpy(bytes->data(), header, static_cast<size_t>(header_size));
+  memcpy(bytes->data() + header_size, pixels.data(), pixels.size());
+
+  return true;
+}
+
+void TestCodecPNM() {
+  size_t u = 77777;  // Initialized to wrong value.
+  double d = 77.77;
+// Failing to parse invalid strings results in a crash if `JXL_CRASH_ON_ERROR`
+// is defined and hence the tests fail. Therefore we only run these tests if
+// `JXL_CRASH_ON_ERROR` is not defined.
+#ifndef JXL_CRASH_ON_ERROR
+  JXL_CHECK(false == Parser(MakeSpan("")).ParseUnsigned(&u));
+  JXL_CHECK(false == Parser(MakeSpan("+")).ParseUnsigned(&u));
+  JXL_CHECK(false == Parser(MakeSpan("-")).ParseUnsigned(&u));
+  JXL_CHECK(false == Parser(MakeSpan("A")).ParseUnsigned(&u));
+
+  JXL_CHECK(false == Parser(MakeSpan("")).ParseSigned(&d));
+  JXL_CHECK(false == Parser(MakeSpan("+")).ParseSigned(&d));
+  JXL_CHECK(false == Parser(MakeSpan("-")).ParseSigned(&d));
+  JXL_CHECK(false == Parser(MakeSpan("A")).ParseSigned(&d));
+#endif
+  JXL_CHECK(true == Parser(MakeSpan("1")).ParseUnsigned(&u));
+  JXL_CHECK(u == 1);
+
+  JXL_CHECK(true == Parser(MakeSpan("32")).ParseUnsigned(&u));
+  JXL_CHECK(u == 32);
+
+  JXL_CHECK(true == Parser(MakeSpan("1")).ParseSigned(&d));
+  JXL_CHECK(d == 1.0);
+  JXL_CHECK(true == Parser(MakeSpan("+2")).ParseSigned(&d));
+  JXL_CHECK(d == 2.0);
+  JXL_CHECK(true == Parser(MakeSpan("-3")).ParseSigned(&d));
+  JXL_CHECK(std::abs(d - -3.0) < 1E-15);
+  JXL_CHECK(true == Parser(MakeSpan("3.141592")).ParseSigned(&d));
+  JXL_CHECK(std::abs(d - 3.141592) < 1E-15);
+  JXL_CHECK(true == Parser(MakeSpan("-3.141592")).ParseSigned(&d));
+  JXL_CHECK(std::abs(d - -3.141592) < 1E-15);
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/extras/codec_pnm.h b/third_party/jpeg-xl/lib/extras/codec_pnm.h
new file mode 100644
index 000000000000..750083d70083
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_pnm.h
@@ -0,0 +1,49 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_EXTRAS_CODEC_PNM_H_
+#define LIB_EXTRAS_CODEC_PNM_H_
+
+// Encodes/decodes PBM/PGM/PPM/PFM pixels in memory.
+
+#include <stddef.h>
+#include <stdint.h>
+
+// TODO(janwas): workaround for incorrect Win64 codegen (cause unknown)
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+
+namespace jxl {
+
+// Decodes `bytes` into `io`. io->dec_hints may specify "color_space", which
+// defaults to sRGB.
+Status DecodeImagePNM(const Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io);
+
+// Transforms from io->c_current to `c_desired` and encodes into `bytes`.
+Status EncodeImagePNM(const CodecInOut* io, const ColorEncoding& c_desired,
+                      size_t bits_per_sample, ThreadPool* pool,
+                      PaddedBytes* bytes);
+
+void TestCodecPNM();
+
+}  // namespace jxl
+
+#endif  // LIB_EXTRAS_CODEC_PNM_H_
diff --git a/third_party/jpeg-xl/lib/extras/codec_psd.cc b/third_party/jpeg-xl/lib/extras/codec_psd.cc
new file mode 100644
index 000000000000..65be508d667c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_psd.cc
@@ -0,0 +1,587 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/codec_psd.h"
+
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <algorithm>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/file_io.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/fields.h"  // AllDefault
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/luminance.h"
+
+namespace jxl {
+namespace {
+
+uint64_t get_be_int(int bytes, const uint8_t*& pos, const uint8_t* maxpos) {
+  uint64_t r = 0;
+  if (pos + bytes <= maxpos) {
+    if (bytes == 1) {
+      r = *pos;
+    } else if (bytes == 2) {
+      r = LoadBE16(pos);
+    } else if (bytes == 4) {
+      r = LoadBE32(pos);
+    } else if (bytes == 8) {
+      r = LoadBE64(pos);
+    }
+  }
+  pos += bytes;
+  return r;
+}
+
+// Copies up to n bytes, without reading from maxpos (the STL-style end).
+void safe_copy(const uint8_t* JXL_RESTRICT pos,
+               const uint8_t* JXL_RESTRICT maxpos, char* JXL_RESTRICT out,
+               size_t n) {
+  for (size_t i = 0; i < n; ++i) {
+    if (pos + i >= maxpos) return;
+    out[i] = pos[i];
+  }
+}
+
+// maxpos is the STL-style end! The valid range is up to [pos, maxpos).
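+// (Usage sketch: safe_strncmp(pos, maxpos, "8BPS", 4) returns 0 only when
+// all four signature bytes are in range and match; a truncated buffer
+// compares as a mismatch instead of reading out of bounds.)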
+int safe_strncmp(const uint8_t* pos, const uint8_t* maxpos, const char* s2,
+                 size_t n) {
+  if (pos + n > maxpos) return 1;
+  return strncmp((const char*)pos, s2, n);
+}
+constexpr int PSD_VERBOSITY = 1;
+
+Status decode_layer(const uint8_t*& pos, const uint8_t* maxpos,
+                    ImageBundle& layer, std::vector<int> chans,
+                    std::vector<bool> invert, int w, int h, int version,
+                    int colormodel, bool is_layer, int depth) {
+  int compression_method = 2;
+  int nb_channels = chans.size();
+  JXL_DEBUG_V(PSD_VERBOSITY,
+              "Trying to decode layer with dimensions %ix%i and %i channels",
+              w, h, nb_channels);
+  if (w <= 0 || h <= 0) return JXL_FAILURE("PSD: empty layer");
+  for (int c = 0; c < nb_channels; c++) {
+    // skip nop byte padding
+    while (pos < maxpos && *pos == 128) pos++;
+    JXL_DEBUG_V(PSD_VERBOSITY, "Channel %i (pos %zu)", c, (size_t)pos);
+    // Merged image stores all channels together (same compression method)
+    // Layers store channel per channel
+    if (is_layer || c == 0) {
+      compression_method = get_be_int(2, pos, maxpos);
+      JXL_DEBUG_V(PSD_VERBOSITY, "compression method: %i", compression_method);
+      if (compression_method > 1 || compression_method < 0) {
+        return JXL_FAILURE("PSD: can't handle compression method %i",
+                           compression_method);
+      }
+    }
+
+    if (!is_layer && c < colormodel) {
+      // skip to the extra channels
+      if (compression_method == 0) {
+        pos += w * h * (depth >> 3) * colormodel;
+        c = colormodel - 1;
+        continue;
+      }
+      size_t skip_amount = 0;
+      for (int i = 0; i < nb_channels; i++) {
+        if (i < colormodel) {
+          for (int y = 0; y < h; y++) {
+            skip_amount += get_be_int(2 * version, pos, maxpos);
+          }
+        } else {
+          pos += h * 2 * version;
+        }
+      }
+      pos += skip_amount;
+      c = colormodel - 1;
+      continue;
+    }
+    if (is_layer || c == 0) {
+      // skip the line-counts, we don't need them
+      if (compression_method == 1) {
+        pos += h * (is_layer ? 1 : nb_channels) * 2 *
+               version;  // PSB uses 4 bytes per rowsize instead of 2
+      }
+    }
+    int c_id = chans[c];
+    if (c_id < 0) continue;  // skip
+    if (static_cast<size_t>(c_id) >= 3 + layer.extra_channels().size())
+      return JXL_FAILURE("PSD: can't handle channel id %i", c_id);
+    ImageF& ch = (c_id < 3 ?
layer.color()->Plane(c_id)
+                           : layer.extra_channels()[c_id - 3]);
+
+    for (int y = 0; y < h; y++) {
+      if (pos > maxpos) return JXL_FAILURE("PSD: premature end of input");
+      float* const JXL_RESTRICT row = ch.Row(y);
+      if (compression_method == 0) {
+        // uncompressed is easy
+        if (depth == 8) {
+          for (int x = 0; x < w; x++) {
+            row[x] = get_be_int(1, pos, maxpos) * (1.f / 255.f);
+          }
+        } else if (depth == 16) {
+          for (int x = 0; x < w; x++) {
+            row[x] = get_be_int(2, pos, maxpos) * (1.f / 65535.f);
+          }
+        } else if (depth == 32) {
+          for (int x = 0; x < w; x++) {
+            uint32_t f = get_be_int(4, pos, maxpos);
+            memcpy(&row[x], &f, 4);
+          }
+        }
+      } else {
+        // RLE is not that hard
+        if (depth != 8)
+          return JXL_FAILURE("PSD: did not expect RLE with depth>1");
+        for (int x = 0; x < w;) {
+          if (pos >= maxpos) return JXL_FAILURE("PSD: out of bounds");
+          int8_t rle = *pos++;
+          if (rle <= 0) {
+            if (rle == -128) continue;  // nop
+            int count = 1 - rle;
+            float v = get_be_int(1, pos, maxpos) * (1.f / 255.f);
+            while (count && x < w) {
+              row[x] = v;
+              count--;
+              x++;
+            }
+            if (count) return JXL_FAILURE("PSD: row overflow");
+          } else {
+            int count = 1 + rle;
+            while (count && x < w) {
+              row[x] = get_be_int(1, pos, maxpos) * (1.f / 255.f);
+              count--;
+              x++;
+            }
+            if (count) return JXL_FAILURE("PSD: row overflow");
+          }
+        }
+      }
+      if (invert[c]) {
+        // sometimes 0 means full ink
+        for (int x = 0; x < w; x++) {
+          row[x] = 1.f - row[x];
+        }
+      }
+    }
+    JXL_DEBUG_V(PSD_VERBOSITY, "Channel %i read.", c);
+  }
+
+  return true;
+}
+
+}  // namespace
+
+Status DecodeImagePSD(const Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io) {
+  const uint8_t* pos = bytes.data();
+  const uint8_t* maxpos = bytes.data() + bytes.size();
+  if (safe_strncmp(pos, maxpos, "8BPS", 4)) return false;  // not a PSD file
+  JXL_DEBUG_V(PSD_VERBOSITY, "trying psd decode");
+  pos += 4;
+  int version = get_be_int(2, pos, maxpos);
+  JXL_DEBUG_V(PSD_VERBOSITY, "Version=%i", version);
+  if (version < 1 || version > 2)
+    return JXL_FAILURE("PSD: unknown format version");
+  // PSD = version 1, PSB = version 2
+  pos += 6;
+  int nb_channels = get_be_int(2, pos, maxpos);
+  size_t ysize = get_be_int(4, pos, maxpos);
+  size_t xsize = get_be_int(4, pos, maxpos);
+  const SizeConstraints* constraints = &io->constraints;
+  JXL_RETURN_IF_ERROR(VerifyDimensions(constraints, xsize, ysize));
+  uint64_t total_pixel_count = static_cast<uint64_t>(xsize) * ysize;
+  int bitdepth = get_be_int(2, pos, maxpos);
+  if (bitdepth != 8 && bitdepth != 16 && bitdepth != 32) {
+    return JXL_FAILURE("PSD: bit depth %i invalid or not supported", bitdepth);
+  }
+  if (bitdepth == 32) {
+    io->metadata.m.SetFloat32Samples();
+  } else {
+    io->metadata.m.SetUintSamples(bitdepth);
+  }
+  int colormodel = get_be_int(2, pos, maxpos);
+  // 1 = Grayscale, 3 = RGB, 4 = CMYK
+  if (colormodel != 1 && colormodel != 3 && colormodel != 4)
+    return JXL_FAILURE("PSD: unsupported color model");
+
+  int real_nb_channels = colormodel;
+  std::vector<std::vector<float>> spotcolor;
+
+  if (get_be_int(4, pos, maxpos))
+    return JXL_FAILURE("PSD: Unsupported color mode section");
+
+  bool hasmergeddata = true;
+  bool have_alpha = false;
+  bool merged_has_alpha = false;
+  size_t metalength = get_be_int(4, pos, maxpos);
+  const uint8_t* metaoffset = pos;
+  while (pos < metaoffset + metalength) {
+    char header[5] = "????";
+    safe_copy(pos, maxpos, header, 4);
+    if (memcmp(header, "8BIM", 4) != 0) {
+      return JXL_FAILURE("PSD: Unexpected image resource header: %s", header);
+    }
+    pos += 4;
+    int id = get_be_int(2, pos, maxpos);
+    int namelength =
get_be_int(1, pos, maxpos);
+    pos += namelength;
+    if (!(namelength & 1)) pos++;  // padding to even length
+    size_t blocklength = get_be_int(4, pos, maxpos);
+    // JXL_DEBUG_V(PSD_VERBOSITY, "block id: %i | block length: %zu", id,
+    //             blocklength);
+    if (pos > maxpos) return JXL_FAILURE("PSD: Unexpected end of file");
+    if (id == 1039) {  // ICC profile
+      size_t delta = maxpos - pos;
+      if (delta < blocklength) {
+        return JXL_FAILURE("PSD: Invalid block length");
+      }
+      PaddedBytes icc;
+      icc.resize(blocklength);
+      memcpy(icc.data(), pos, blocklength);
+      if (!io->metadata.m.color_encoding.SetICC(std::move(icc))) {
+        return JXL_FAILURE("PSD: Invalid color profile");
+      }
+    } else if (id == 1057) {  // compatibility mode or not?
+      if (get_be_int(4, pos, maxpos) != 1) {
+        return JXL_FAILURE("PSD: expected version=1 in id=1057 resource block");
+      }
+      hasmergeddata = get_be_int(1, pos, maxpos);
+      pos++;
+      blocklength -= 6;  // already skipped these bytes
+    } else if (id == 1077) {  // spot colors
+      int version = get_be_int(4, pos, maxpos);
+      if (version != 1) {
+        return JXL_FAILURE(
+            "PSD: expected DisplayInfo version 1, got version %i", version);
+      }
+      int spotcolorcount = nb_channels - colormodel;
+      JXL_DEBUG_V(PSD_VERBOSITY, "Reading %i spot colors. %zu", spotcolorcount,
+                  blocklength);
+      for (int k = 0; k < spotcolorcount; k++) {
+        int colorspace = get_be_int(2, pos, maxpos);
+        if ((colormodel == 3 && colorspace != 0) ||
+            (colormodel == 4 && colorspace != 2)) {
+          return JXL_FAILURE(
+              "PSD: cannot handle spot colors in different color spaces than "
+              "image itself");
+        }
+        if (colorspace == 2) JXL_WARNING("PSD: K ignored in CMYK spot color");
+        std::vector<float> color;
+        color.push_back(get_be_int(2, pos, maxpos) / 65535.f);  // R or C
+        color.push_back(get_be_int(2, pos, maxpos) / 65535.f);  // G or M
+        color.push_back(get_be_int(2, pos, maxpos) / 65535.f);  // B or Y
+        color.push_back(get_be_int(2, pos, maxpos) / 65535.f);  // ignored or K
+        color.push_back(get_be_int(2, pos, maxpos) /
+                        100.f);  // solidity (alpha, basically)
+        int kind = get_be_int(1, pos, maxpos);
+        JXL_DEBUG_V(PSD_VERBOSITY, "Kind=%i", kind);
+        color.push_back(kind);
+        spotcolor.push_back(color);
+        if (kind == 2) {
+          JXL_DEBUG_V(PSD_VERBOSITY, "Actual spot color");
+        } else if (kind == 1) {
+          JXL_DEBUG_V(PSD_VERBOSITY, "Mask (alpha) channel");
+        } else if (kind == 0) {
+          JXL_DEBUG_V(PSD_VERBOSITY, "Selection (alpha) channel");
+        } else {
+          return JXL_FAILURE("PSD: Unknown extra channel type");
+        }
+      }
+      if (blocklength & 1) pos++;
+      blocklength = 0;
+    }
+    pos += blocklength;
+    if (blocklength & 1) pos++;  // padding again
+  }
+
+  size_t layerlength = get_be_int(4 * version, pos, maxpos);
+  const uint8_t* after_layers_pos = pos + layerlength;
+  if (after_layers_pos < pos) return JXL_FAILURE("PSD: invalid layer length");
+  if (layerlength) {
+    pos += 4 * version;  // don't care about layerinfolength
+    JXL_DEBUG_V(PSD_VERBOSITY, "Layer section length: %zu", layerlength);
+    int layercount = static_cast<int16_t>(get_be_int(2, pos, maxpos));
+    JXL_DEBUG_V(PSD_VERBOSITY, "Layer count: %i", layercount);
+    io->frames.clear();
+
+    if (layercount == 0) {
+      if (get_be_int(2, pos, maxpos) != 0) {
+        return JXL_FAILURE(
+            "PSD: Expected zero padding before additional layer info");
+      }
+      while (pos < after_layers_pos) {
+        if (safe_strncmp(pos, maxpos, "8BIM", 4) &&
+            safe_strncmp(pos, maxpos, "8B64", 4))
+          return JXL_FAILURE("PSD: Unexpected layer info signature");
+        pos += 4;
+        const uint8_t* tpos = pos;
+        pos += 4;
+        size_t blocklength = get_be_int(4 *
version, pos, maxpos);
+        JXL_DEBUG_V(PSD_VERBOSITY, "Length=%zu", blocklength);
+        if (blocklength > 0) {
+          if (pos >= maxpos) return JXL_FAILURE("PSD: Unexpected end of file");
+          size_t delta = maxpos - pos;
+          if (delta < blocklength) {
+            return JXL_FAILURE("PSD: Invalid block length");
+          }
+        }
+        if (!safe_strncmp(tpos, maxpos, "Layr", 4) ||
+            !safe_strncmp(tpos, maxpos, "Lr16", 4) ||
+            !safe_strncmp(tpos, maxpos, "Lr32", 4)) {
+          layercount = static_cast<int16_t>(get_be_int(2, pos, maxpos));
+          if (layercount < 0) {
+            return JXL_FAILURE("PSD: Invalid layer count");
+          }
+          JXL_DEBUG_V(PSD_VERBOSITY, "Real layer count: %i", layercount);
+          break;
+        }
+        if (!safe_strncmp(tpos, maxpos, "Mtrn", 4) ||
+            !safe_strncmp(tpos, maxpos, "Mt16", 4) ||
+            !safe_strncmp(tpos, maxpos, "Mt32", 4)) {
+          JXL_DEBUG_V(PSD_VERBOSITY, "Merged layer has transparency channel");
+          if (nb_channels > real_nb_channels) {
+            real_nb_channels++;
+            have_alpha = true;
+            merged_has_alpha = true;
+          }
+        }
+        pos += blocklength;
+      }
+    } else if (layercount < 0) {
+      // negative layer count indicates merged has alpha and it is to be shown
+      if (nb_channels > real_nb_channels) {
+        real_nb_channels++;
+        have_alpha = true;
+        merged_has_alpha = true;
+      }
+      layercount = -layercount;
+    } else {
+      // multiple layers implies there is alpha
+      real_nb_channels++;
+      have_alpha = true;
+    }
+
+    ExtraChannelInfo info;
+    info.bit_depth.bits_per_sample = bitdepth;
+    info.dim_shift = 0;
+
+    if (colormodel == 4) {  // cmyk
+      info.type = ExtraChannel::kBlack;
+      io->metadata.m.extra_channel_info.push_back(info);
+    }
+    if (have_alpha) {
+      JXL_DEBUG_V(PSD_VERBOSITY, "Have alpha");
+      info.type = ExtraChannel::kAlpha;
+      info.alpha_associated =
+          false;  // true? PSD is not consistent with this, need to check
+      io->metadata.m.extra_channel_info.push_back(info);
+    }
+    if (merged_has_alpha && !spotcolor.empty() && spotcolor[0][5] == 1) {
+      // first alpha channel
+      spotcolor.erase(spotcolor.begin());
+    }
+    for (size_t i = 0; i < spotcolor.size(); i++) {
+      real_nb_channels++;
+      if (spotcolor[i][5] == 2) {
+        info.type = ExtraChannel::kSpotColor;
+        info.spot_color[0] = spotcolor[i][0];
+        info.spot_color[1] = spotcolor[i][1];
+        info.spot_color[2] = spotcolor[i][2];
+        info.spot_color[3] = spotcolor[i][4];
+      } else if (spotcolor[i][5] == 1) {
+        info.type = ExtraChannel::kAlpha;
+      } else if (spotcolor[i][5] == 0) {
+        info.type = ExtraChannel::kSelectionMask;
+      } else
+        return JXL_FAILURE("PSD: unhandled extra channel");
+      io->metadata.m.extra_channel_info.push_back(info);
+    }
+    std::vector<std::vector<int>> layer_chan_id;
+    std::vector<size_t> layer_offsets(layercount + 1, 0);
+    std::vector<bool> is_real_layer(layercount, false);
+    for (int l = 0; l < layercount; l++) {
+      ImageBundle layer(&io->metadata.m);
+      layer.duration = 0;
+      layer.blend = (l > 0);
+
+      layer.use_for_next_frame = (l + 1 < layercount);
+      layer.origin.y0 = get_be_int(4, pos, maxpos);
+      layer.origin.x0 = get_be_int(4, pos, maxpos);
+      size_t height = get_be_int(4, pos, maxpos) - layer.origin.y0;
+      size_t width = get_be_int(4, pos, maxpos) - layer.origin.x0;
+      JXL_DEBUG_V(PSD_VERBOSITY, "Layer %i: %zu x %zu at origin (%i, %i)", l,
+                  width, height, layer.origin.x0, layer.origin.y0);
+      int nb_chs = get_be_int(2, pos, maxpos);
+      JXL_DEBUG_V(PSD_VERBOSITY, "  channels: %i", nb_chs);
+      std::vector<int> chan_ids;
+      layer_offsets[l + 1] = layer_offsets[l];
+      for (int lc = 0; lc < nb_chs; lc++) {
+        int id = get_be_int(2, pos, maxpos);
+        JXL_DEBUG_V(PSD_VERBOSITY, "  id=%i", id);
+        if (id == 65535) {
+          chan_ids.push_back(colormodel);  // alpha
+        } else if (id
== 65534) {
+          chan_ids.push_back(-1);  // layer mask, ignored
+        } else {
+          chan_ids.push_back(id);  // color channel
+        }
+        layer_offsets[l + 1] += get_be_int(4 * version, pos, maxpos);
+      }
+      layer_chan_id.push_back(chan_ids);
+      if (safe_strncmp(pos, maxpos, "8BIM", 4))
+        return JXL_FAILURE("PSD: Layer %i: Unexpected signature (not 8BIM)", l);
+      pos += 4;
+      if (safe_strncmp(pos, maxpos, "norm", 4)) {
+        return JXL_FAILURE(
+            "PSD: Layer %i: Cannot handle non-default blend mode", l);
+      }
+      pos += 4;
+      int opacity = get_be_int(1, pos, maxpos);
+      if (opacity < 100) {
+        JXL_WARNING(
+            "PSD: ignoring opacity of semi-transparent layer %i (opacity=%i)",
+            l, opacity);
+      }
+      pos++;  // clipping
+      int flags = get_be_int(1, pos, maxpos);
+      pos++;
+      bool invisible = (flags & 2);
+      if (invisible) {
+        if (l + 1 < layercount) {
+          layer.blend = false;
+          layer.use_for_next_frame = false;
+        } else {
+          // TODO: instead add dummy last frame?
+          JXL_WARNING("PSD: invisible top layer was made visible");
+        }
+      }
+      size_t extradata = get_be_int(4, pos, maxpos);
+      JXL_DEBUG_V(PSD_VERBOSITY, "  extradata: %zu bytes", extradata);
+      const uint8_t* after_extra = pos + extradata;
+      // TODO: deal with non-empty layer masks
+      pos += get_be_int(4, pos, maxpos);  // skip layer mask data
+      pos += get_be_int(4, pos, maxpos);  // skip layer blend range data
+      size_t namelength = get_be_int(1, pos, maxpos);
+      size_t delta = maxpos - pos;
+      if (delta < namelength) return JXL_FAILURE("PSD: Invalid block length");
+      char lname[256] = {};
+      memcpy(lname, pos, namelength);
+      lname[namelength] = 0;
+      JXL_DEBUG_V(PSD_VERBOSITY, "  name: %s", lname);
+      pos = after_extra;
+      if (width == 0 || height == 0) {
+        JXL_DEBUG_V(PSD_VERBOSITY,
+                    "  NOT A REAL LAYER");  // probably layer group
+        continue;
+      }
+      is_real_layer[l] = true;
+      JXL_RETURN_IF_ERROR(VerifyDimensions(constraints, width, height));
+      uint64_t pixel_count = static_cast<uint64_t>(width) * height;
+      if (!SafeAdd(total_pixel_count, pixel_count, total_pixel_count)) {
+        return JXL_FAILURE("Image too big");
+      }
+      if (total_pixel_count > constraints->dec_max_pixels) {
+        return JXL_FAILURE("Image too big");
+      }
+      Image3F rgb(width, height);
+      layer.SetFromImage(std::move(rgb), io->metadata.m.color_encoding);
+      std::vector<ImageF> ec;
+      for (const auto& ec_meta : layer.metadata()->extra_channel_info) {
+        ImageF extra(width, height);
+        if (ec_meta.type == ExtraChannel::kAlpha) {
+          FillPlane(1.0f, &extra, Rect(extra));  // opaque
+        } else {
+          ZeroFillPlane(&extra, Rect(extra));  // zeroes
+        }
+        ec.push_back(std::move(extra));
+      }
+      if (!ec.empty()) layer.SetExtraChannels(std::move(ec));
+      layer.name = lname;
+      io->dec_pixels += layer.xsize() * layer.ysize();
+      io->frames.push_back(std::move(layer));
+    }
+
+    std::vector<bool> invert(real_nb_channels, false);
+    int il = 0;
+    const uint8_t* bpos = pos;
+    for (int l = 0; l < layercount; l++) {
+      if (!is_real_layer[l]) continue;
+      pos = bpos + layer_offsets[l];
+      if (pos < bpos) return JXL_FAILURE("PSD: invalid layer offset");
+      JXL_DEBUG_V(PSD_VERBOSITY, "At position %i (%zu)",
+                  (int)(pos - bytes.data()), (size_t)pos);
+      ImageBundle& layer = io->frames[il++];
+      JXL_RETURN_IF_ERROR(decode_layer(pos, maxpos, layer, layer_chan_id[l],
+                                       invert, layer.xsize(), layer.ysize(),
+                                       version, colormodel, true, bitdepth));
+    }
+  } else
+    return JXL_FAILURE("PSD: no layer data found");
+
+  if (!hasmergeddata && !spotcolor.empty()) {
+    return JXL_FAILURE("PSD: extra channel data declared but not found");
+  }
+
+  if (io->frames.empty()) return JXL_FAILURE("PSD: no layers");
+
+  if
 (!spotcolor.empty()) {
+    // PSD only has spot colors / extra alpha/mask data in the merged image.
+    // We don't redundantly store the merged image, so we put it in the first
+    // layer (the next layers will kAdd zeroes to it).
+    pos = after_layers_pos;
+    ImageBundle& layer = io->frames[0];
+    std::vector<int> chan_id(real_nb_channels);
+    std::iota(chan_id.begin(), chan_id.end(), 0);
+    std::vector<bool> invert(real_nb_channels, false);
+    if (!merged_has_alpha) {
+      chan_id.erase(chan_id.begin() + colormodel);
+      invert.erase(invert.begin() + colormodel);
+    } else
+      colormodel++;
+    for (size_t i = colormodel; i < invert.size(); i++) {
+      if (spotcolor[i - colormodel][5] == 2) invert[i] = true;
+      if (spotcolor[i - colormodel][5] == 0) invert[i] = true;
+    }
+    JXL_RETURN_IF_ERROR(decode_layer(pos, maxpos, layer, chan_id, invert,
+                                     layer.xsize(), layer.ysize(), version,
+                                     colormodel, false, bitdepth));
+  }
+  io->SetSize(xsize, ysize);
+
+  SetIntensityTarget(io);
+
+  return true;
+}
+
+Status EncodeImagePSD(const CodecInOut* io, const ColorEncoding& c_desired,
+                      size_t bits_per_sample, ThreadPool* pool,
+                      PaddedBytes* bytes) {
+  return JXL_FAILURE("PSD encoding not yet implemented");
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/extras/codec_psd.h b/third_party/jpeg-xl/lib/extras/codec_psd.h
new file mode 100644
index 000000000000..d98a8e9241c4
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_psd.h
@@ -0,0 +1,42 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_EXTRAS_CODEC_PSD_H_
+#define LIB_EXTRAS_CODEC_PSD_H_
+
+// Decodes Photoshop PSD/PSB, preserving the layers
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+
+namespace jxl {
+
+// Decodes `bytes` into `io`.
+Status DecodeImagePSD(const Span<const uint8_t> bytes, ThreadPool* pool,
+                      CodecInOut* io);
+
+// Not implemented yet
+Status EncodeImagePSD(const CodecInOut* io, const ColorEncoding& c_desired,
+                      size_t bits_per_sample, ThreadPool* pool,
+                      PaddedBytes* bytes);
+
+}  // namespace jxl
+
+#endif  // LIB_EXTRAS_CODEC_PSD_H_
diff --git a/third_party/jpeg-xl/lib/extras/codec_test.cc b/third_party/jpeg-xl/lib/extras/codec_test.cc
new file mode 100644
index 000000000000..11593c3a818a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/codec_test.cc
@@ -0,0 +1,384 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/codec.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "lib/extras/codec_pgx.h"
+#include "lib/extras/codec_pnm.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_test_utils.h"
+#include "lib/jxl/luminance.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+namespace {
+
+CodecInOut CreateTestImage(const size_t xsize, const size_t ysize,
+                           const bool is_gray, const bool add_alpha,
+                           const size_t bits_per_sample,
+                           const ColorEncoding& c_native) {
+  Image3F image(xsize, ysize);
+  std::mt19937_64 rng(129);
+  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+  if (is_gray) {
+    for (size_t y = 0; y < ysize; ++y) {
+      float* JXL_RESTRICT row0 = image.PlaneRow(0, y);
+      float* JXL_RESTRICT row1 = image.PlaneRow(1, y);
+      float* JXL_RESTRICT row2 = image.PlaneRow(2, y);
+      for (size_t x = 0; x < xsize; ++x) {
+        row0[x] = row1[x] = row2[x] = dist(rng);
+      }
+    }
+  } else {
+    RandomFillImage(&image, 1.0f);
+  }
+  CodecInOut io;
+
+  if (bits_per_sample == 32) {
+    io.metadata.m.SetFloat32Samples();
+  } else {
+    io.metadata.m.SetUintSamples(bits_per_sample);
+  }
+  io.metadata.m.color_encoding = c_native;
+  io.SetFromImage(std::move(image), c_native);
+  if (add_alpha) {
+    ImageF alpha(xsize, ysize);
+    RandomFillImage(&alpha, 1.f);
+    io.metadata.m.SetAlphaBits(bits_per_sample <= 8 ? 8 : 16);
+    io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false);
+  }
+  return io;
+}
+
+// Ensures reading a newly written file leads to the same image pixels.
+void TestRoundTrip(Codec codec, const size_t xsize, const size_t ysize,
+                   const bool is_gray, const bool add_alpha,
+                   const size_t bits_per_sample, ThreadPool* pool) {
+  // JPEG encoding is not lossless.
+  if (codec == Codec::kJPG) return;
+  if (codec == Codec::kPNM && add_alpha) return;
+  // Our EXR codec always uses 16-bit premultiplied alpha, does not support
+  // grayscale, and somehow does not have sufficient precision for this test.
+  if (codec == Codec::kEXR) return;
+  printf("Codec %s bps:%zu gr:%d al:%d\n",
+         ExtensionFromCodec(codec, is_gray, bits_per_sample).c_str(),
+         bits_per_sample, is_gray, add_alpha);
+
+  ColorEncoding c_native;
+  c_native.SetColorSpace(is_gray ? ColorSpace::kGray : ColorSpace::kRGB);
+  // Note: this must not be wider than c_external, otherwise gamut clipping
+  // will cause large round-trip errors.
+  c_native.primaries = Primaries::kP3;
+  c_native.tf.SetTransferFunction(TransferFunction::kLinear);
+  JXL_CHECK(c_native.CreateICC());
+
+  // Generally store same color space to reduce round trip errors..
+  ColorEncoding c_external = c_native;
+  // .. unless we have enough precision for some transforms.
+  if (bits_per_sample >= 16) {
+    c_external.white_point = WhitePoint::kE;
+    c_external.primaries = Primaries::k2100;
+    c_external.tf.SetTransferFunction(TransferFunction::kSRGB);
+  }
+  JXL_CHECK(c_external.CreateICC());
+
+  const CodecInOut io = CreateTestImage(xsize, ysize, is_gray, add_alpha,
+                                        bits_per_sample, c_native);
+  const ImageBundle& ib1 = io.Main();
+
+  PaddedBytes encoded;
+  JXL_CHECK(Encode(io, codec, c_external, bits_per_sample, &encoded, pool));
+
+  CodecInOut io2;
+  io2.target_nits = io.metadata.m.IntensityTarget();
+  // Only for PNM because PNG will warn about ignoring them.
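+  // (dec_hints carry key/value strings to the decoder, e.g.
+  //   io2.dec_hints.Add("color_space", "RGB_D65_SRG_Rel_SRG");
+  // would request sRGB; here the description of c_external is passed so the
+  // PNM round trip can restore the exact encoding.)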
+  if (codec == Codec::kPNM) {
+    io2.dec_hints.Add("color_space", Description(c_external));
+  }
+  JXL_CHECK(SetFromBytes(Span<const uint8_t>(encoded), &io2, pool));
+  ImageBundle& ib2 = io2.Main();
+
+  EXPECT_EQ(Description(c_external),
+            Description(io2.metadata.m.color_encoding));
+
+  // See c_external above - for low bits_per_sample the encoded space is
+  // already the same.
+  if (bits_per_sample < 16) {
+    EXPECT_EQ(Description(ib1.c_current()), Description(ib2.c_current()));
+  }
+
+  if (add_alpha) {
+    EXPECT_TRUE(SamePixels(ib1.alpha(), *ib2.alpha()));
+  }
+
+  JXL_CHECK(ib2.TransformTo(ib1.c_current(), pool));
+
+  double max_l1, max_rel;
+  // Round-trip tolerances must be higher than in external_image_test because
+  // codecs do not support unbounded ranges.
+#if JPEGXL_ENABLE_SKCMS
+  if (bits_per_sample <= 12) {
+    max_l1 = 0.5;
+    max_rel = 6E-3;
+  } else {
+    max_l1 = 1E-3;
+    max_rel = 5E-4;
+  }
+#else   // JPEGXL_ENABLE_SKCMS
+  if (bits_per_sample <= 12) {
+    max_l1 = 0.5;
+    max_rel = 6E-3;
+  } else if (bits_per_sample == 16) {
+    max_l1 = 3E-3;
+    max_rel = 1E-4;
+  } else {
+#ifdef __ARM_ARCH
+    // pow() implementation in arm is a bit less precise than in x86 and
+    // therefore we need a bigger error margin in this case.
+    max_l1 = 1E-7;
+    max_rel = 1E-4;
+#else
+    max_l1 = 1E-7;
+    max_rel = 1E-5;
+#endif
+  }
+#endif  // JPEGXL_ENABLE_SKCMS
+
+  VerifyRelativeError(ib1.color(), *ib2.color(), max_l1, max_rel);
+}
+
+#if 0
+TEST(CodecTest, TestRoundTrip) {
+  ThreadPoolInternal pool(12);
+
+  const size_t xsize = 7;
+  const size_t ysize = 4;
+
+  for (Codec codec : Values<Codec>()) {
+    for (int bits_per_sample : {8, 10, 12, 16, 32}) {
+      for (bool is_gray : {false, true}) {
+        for (bool add_alpha : {false, true}) {
+          TestRoundTrip(codec, xsize, ysize, is_gray, add_alpha,
+                        static_cast<size_t>(bits_per_sample), &pool);
+        }
+      }
+    }
+  }
+}
+#endif
+
+CodecInOut DecodeRoundtrip(const std::string& pathname, Codec expected_codec,
+                           ThreadPool* pool,
+                           const DecoderHints& dec_hints = DecoderHints()) {
+  CodecInOut io;
+  io.dec_hints = dec_hints;
+  const PaddedBytes orig = ReadTestData(pathname);
+  JXL_CHECK(SetFromBytes(Span<const uint8_t>(orig), &io, pool));
+  const ImageBundle& ib1 = io.Main();
+
+  // Encode/Decode again to make sure Encode carries through all metadata.
+  PaddedBytes encoded;
+  JXL_CHECK(Encode(io, expected_codec, io.metadata.m.color_encoding,
+                   io.metadata.m.bit_depth.bits_per_sample, &encoded, pool));
+
+  CodecInOut io2;
+  io2.dec_hints = dec_hints;
+  JXL_CHECK(SetFromBytes(Span<const uint8_t>(encoded), &io2, pool));
+  const ImageBundle& ib2 = io2.Main();
+  EXPECT_EQ(Description(ib1.metadata()->color_encoding),
+            Description(ib2.metadata()->color_encoding));
+  EXPECT_EQ(Description(ib1.c_current()), Description(ib2.c_current()));
+
+  size_t bits_per_sample = io2.metadata.m.bit_depth.bits_per_sample;
+
+  // "Same" pixels?
+  double max_l1 = bits_per_sample <= 12 ? 1.3 : 2E-3;
+  double max_rel = bits_per_sample <= 12 ? 6E-3 : 1E-4;
+  if (ib1.metadata()->color_encoding.IsGray()) {
+    max_rel *= 2.0;
+  } else if (ib1.metadata()->color_encoding.primaries != Primaries::kSRGB) {
+    // Need more tolerance for large gamuts (anything but sRGB)
+    max_l1 *= 1.5;
+    max_rel *= 3.0;
+  }
+  VerifyRelativeError(ib1.color(), ib2.color(), max_l1, max_rel);
+
+  // Simulate the encoder removing profile and decoder restoring it.
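+  // (WantICC() reports whether the exact ICC bytes must round-trip; when it
+  // is false the profile is fully described by the enum fields, so it can be
+  // dropped and regenerated, which is what CreateICC() below checks.)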
+ if (!ib2.metadata()->color_encoding.WantICC()) { + io2.metadata.m.color_encoding.InternalRemoveICC(); + EXPECT_TRUE(io2.metadata.m.color_encoding.CreateICC()); + } + + return io2; +} + +#if 0 +TEST(CodecTest, TestMetadataSRGB) { + ThreadPoolInternal pool(12); + + const char* paths[] = {"raw.pixls/DJI-FC6310-16bit_srgb8_v4_krita.png", + "raw.pixls/Google-Pixel2XL-16bit_srgb8_v4_krita.png", + "raw.pixls/HUAWEI-EVA-L09-16bit_srgb8_dt.png", + "raw.pixls/Nikon-D300-12bit_srgb8_dt.png", + "raw.pixls/Sony-DSC-RX1RM2-14bit_srgb8_v4_krita.png"}; + for (const char* relative_pathname : paths) { + const CodecInOut io = + DecodeRoundtrip(relative_pathname, Codec::kPNG, &pool); + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + + EXPECT_EQ(64, io.xsize()); + EXPECT_EQ(64, io.ysize()); + EXPECT_FALSE(io.metadata.m.HasAlpha()); + + const ColorEncoding& c_original = io.metadata.m.color_encoding; + EXPECT_FALSE(c_original.ICC().empty()); + EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace()); + EXPECT_EQ(WhitePoint::kD65, c_original.white_point); + EXPECT_EQ(Primaries::kSRGB, c_original.primaries); + EXPECT_TRUE(c_original.tf.IsSRGB()); + } +} + +TEST(CodecTest, TestMetadataLinear) { + ThreadPoolInternal pool(12); + + const char* paths[3] = { + "raw.pixls/Google-Pixel2XL-16bit_acescg_g1_v4_krita.png", + "raw.pixls/HUAWEI-EVA-L09-16bit_709_g1_dt.png", + "raw.pixls/Nikon-D300-12bit_2020_g1_dt.png", + }; + const WhitePoint white_points[3] = {WhitePoint::kCustom, WhitePoint::kD65, + WhitePoint::kD65}; + const Primaries primaries[3] = {Primaries::kCustom, Primaries::kSRGB, + Primaries::k2100}; + + for (size_t i = 0; i < 3; ++i) { + const CodecInOut io = DecodeRoundtrip(paths[i], Codec::kPNG, &pool); + EXPECT_EQ(16, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + + EXPECT_EQ(64, io.xsize()); + EXPECT_EQ(64, io.ysize()); + EXPECT_FALSE(io.metadata.m.HasAlpha()); + + const ColorEncoding& c_original = io.metadata.m.color_encoding; + EXPECT_FALSE(c_original.ICC().empty()); + EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace()); + EXPECT_EQ(white_points[i], c_original.white_point); + EXPECT_EQ(primaries[i], c_original.primaries); + EXPECT_TRUE(c_original.tf.IsLinear()); + } +} + +TEST(CodecTest, TestMetadataICC) { + ThreadPoolInternal pool(12); + + const char* paths[] = { + "raw.pixls/DJI-FC6310-16bit_709_v4_krita.png", + "raw.pixls/Sony-DSC-RX1RM2-14bit_709_v4_krita.png", + }; + for (const char* relative_pathname : paths) { + const CodecInOut io = + DecodeRoundtrip(relative_pathname, Codec::kPNG, &pool); + EXPECT_GE(16, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_LE(14, io.metadata.m.bit_depth.bits_per_sample); + + EXPECT_EQ(64, io.xsize()); + EXPECT_EQ(64, io.ysize()); + EXPECT_FALSE(io.metadata.m.HasAlpha()); + + const ColorEncoding& c_original = io.metadata.m.color_encoding; + EXPECT_FALSE(c_original.ICC().empty()); + EXPECT_EQ(RenderingIntent::kPerceptual, c_original.rendering_intent); + EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace()); + EXPECT_EQ(WhitePoint::kD65, c_original.white_point); + EXPECT_EQ(Primaries::kSRGB, c_original.primaries); + EXPECT_EQ(TransferFunction::k709, c_original.tf.GetTransferFunction()); + } +} + +TEST(CodecTest, TestPNGSuite) { + ThreadPoolInternal pool(12); + + // Ensure we can load PNG with text, 
Japanese UTF-8, compressed text.
+  (void)DecodeRoundtrip("pngsuite/ct1n0g04.png", Codec::kPNG, &pool);
+  (void)DecodeRoundtrip("pngsuite/ctjn0g04.png", Codec::kPNG, &pool);
+  (void)DecodeRoundtrip("pngsuite/ctzn0g04.png", Codec::kPNG, &pool);
+
+  // Extract gAMA
+  const CodecInOut b1 =
+      DecodeRoundtrip("pngsuite/g10n3p04.png", Codec::kPNG, &pool);
+  EXPECT_TRUE(b1.metadata.color_encoding.tf.IsLinear());
+
+  // Extract cHRM
+  const CodecInOut b_p =
+      DecodeRoundtrip("pngsuite/ccwn2c08.png", Codec::kPNG, &pool);
+  EXPECT_EQ(Primaries::kSRGB, b_p.metadata.color_encoding.primaries);
+  EXPECT_EQ(WhitePoint::kD65, b_p.metadata.color_encoding.white_point);
+
+  // Extract EXIF from (new-style) dedicated chunk
+  const CodecInOut b_exif =
+      DecodeRoundtrip("pngsuite/exif2c08.png", Codec::kPNG, &pool);
+  EXPECT_EQ(978, b_exif.blobs.exif.size());
+}
+#endif
+
+void VerifyWideGamutMetadata(const std::string& relative_pathname,
+                             const Primaries primaries, ThreadPool* pool) {
+  const CodecInOut io = DecodeRoundtrip(relative_pathname, Codec::kPNG, pool);
+
+  EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample);
+  EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample);
+  EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample);
+
+  const ColorEncoding& c_original = io.metadata.m.color_encoding;
+  EXPECT_FALSE(c_original.ICC().empty());
+  EXPECT_EQ(RenderingIntent::kAbsolute, c_original.rendering_intent);
+  EXPECT_EQ(ColorSpace::kRGB, c_original.GetColorSpace());
+  EXPECT_EQ(WhitePoint::kD65, c_original.white_point);
+  EXPECT_EQ(primaries, c_original.primaries);
+}
+
+TEST(CodecTest, TestWideGamut) {
+  ThreadPoolInternal pool(12);
+  // VerifyWideGamutMetadata("wide-gamut-tests/P3-sRGB-color-bars.png",
+  //                         Primaries::kP3, &pool);
+  VerifyWideGamutMetadata("wide-gamut-tests/P3-sRGB-color-ring.png",
+                          Primaries::kP3, &pool);
+  // VerifyWideGamutMetadata("wide-gamut-tests/R2020-sRGB-color-bars.png",
+  //                         Primaries::k2100, &pool);
+  // VerifyWideGamutMetadata("wide-gamut-tests/R2020-sRGB-color-ring.png",
+  //                         Primaries::k2100, &pool);
+}
+
+TEST(CodecTest, TestPNM) { TestCodecPNM(); }
+TEST(CodecTest, TestPGX) { TestCodecPGX(); }
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/extras/tone_mapping.cc b/third_party/jpeg-xl/lib/extras/tone_mapping.cc
new file mode 100644
index 000000000000..8fcc58597d5a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/tone_mapping.cc
@@ -0,0 +1,169 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/extras/tone_mapping.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/extras/tone_mapping.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/transfer_functions-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+Status ToneMapFrame(const std::pair<float, float> display_nits,
+                    ImageBundle* const ib, ThreadPool* const pool) {
+  // Perform tone mapping as described in Report ITU-R BT.2390-8, section 5.4
+  // (pp. 23-25).
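+  // Hedged summary of the report's EETF in its own notation (matching the
+  // ks / T / P lambdas and the black-level lift `b` below):
+  //   KS = 1.5 * maxLum - 0.5,  T(A) = (A - KS) / (1 - KS), and for E1 > KS
+  //   P(E1) = (2T^3 - 3T^2 + 1) * KS + (T^3 - 2T^2 + T) * (1 - KS)
+  //           + (-2T^3 + 3T^2) * maxLum,
+  //   with E2 = E1 for E1 <= KS, then E3 = E2 + b * (1 - E2)^4, all in
+  //   normalized PQ space; the reference is: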
+ // https://www.itu.int/pub/R-REP-BT.2390-8-2020 + + HWY_FULL(float) df; + using V = decltype(Zero(df)); + + ColorEncoding linear_rec2020; + linear_rec2020.SetColorSpace(ColorSpace::kRGB); + linear_rec2020.primaries = Primaries::k2100; + linear_rec2020.white_point = WhitePoint::kD65; + linear_rec2020.tf.SetTransferFunction(TransferFunction::kLinear); + JXL_RETURN_IF_ERROR(linear_rec2020.CreateICC()); + JXL_RETURN_IF_ERROR(ib->TransformTo(linear_rec2020, pool)); + + const auto eotf_inv = [&df](const V luminance) -> V { + return TF_PQ().EncodedFromDisplay(df, luminance * Set(df, 1. / 10000)); + }; + + const V pq_mastering_min = + eotf_inv(Set(df, ib->metadata()->tone_mapping.min_nits)); + const V pq_mastering_max = + eotf_inv(Set(df, ib->metadata()->tone_mapping.intensity_target)); + const V pq_mastering_range = pq_mastering_max - pq_mastering_min; + const V inv_pq_mastering_range = + Set(df, 1) / (pq_mastering_max - pq_mastering_min); + const V min_lum = (eotf_inv(Set(df, display_nits.first)) - pq_mastering_min) * + inv_pq_mastering_range; + const V max_lum = + (eotf_inv(Set(df, display_nits.second)) - pq_mastering_min) * + inv_pq_mastering_range; + const V ks = MulAdd(Set(df, 1.5f), max_lum, Set(df, -0.5f)); + const V b = min_lum; + + const V inv_one_minus_ks = Set(df, 1) / Max(Set(df, 1e-6f), Set(df, 1) - ks); + const auto T = [ks, inv_one_minus_ks](const V a) { + return (a - ks) * inv_one_minus_ks; + }; + const auto P = [&T, &df, ks, max_lum](const V b) { + const V t_b = T(b); + const V t_b_2 = t_b * t_b; + const V t_b_3 = t_b_2 * t_b; + return MulAdd( + MulAdd(Set(df, 2), t_b_3, MulAdd(Set(df, -3), t_b_2, Set(df, 1))), ks, + MulAdd(t_b_3 + MulAdd(Set(df, -2), t_b_2, t_b), Set(df, 1) - ks, + MulAdd(Set(df, -2), t_b_3, Set(df, 3) * t_b_2) * max_lum)); + }; + + const V inv_max_display_nits = Set(df, 1 / display_nits.second); + + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ib->ysize(), ThreadPool::SkipInit(), + [&](const int y, const int thread) { + float* const JXL_RESTRICT row_r = ib->color()->PlaneRow(0, y); + float* const JXL_RESTRICT row_g = ib->color()->PlaneRow(1, y); + float* const JXL_RESTRICT row_b = ib->color()->PlaneRow(2, y); + for (size_t x = 0; x < ib->xsize(); x += Lanes(df)) { + V red = Load(df, row_r + x); + V green = Load(df, row_g + x); + V blue = Load(df, row_b + x); + const V luminance = Set(df, ib->metadata()->IntensityTarget()) * + (MulAdd(Set(df, 0.2627f), red, + MulAdd(Set(df, 0.6780f), green, + Set(df, 0.0593f) * blue))); + const V normalized_pq = + Min(Set(df, 1.f), (eotf_inv(luminance) - pq_mastering_min) * + inv_pq_mastering_range); + const V e2 = + IfThenElse(normalized_pq < ks, normalized_pq, P(normalized_pq)); + const V one_minus_e2 = Set(df, 1) - e2; + const V one_minus_e2_2 = one_minus_e2 * one_minus_e2; + const V one_minus_e2_4 = one_minus_e2_2 * one_minus_e2_2; + const V e3 = MulAdd(b, one_minus_e2_4, e2); + const V e4 = MulAdd(e3, pq_mastering_range, pq_mastering_min); + const V new_luminance = + Min(Set(df, display_nits.second), + ZeroIfNegative(Set(df, 10000) * + TF_PQ().DisplayFromEncoded(df, e4))); + + const V ratio = new_luminance / luminance; + const V multiplier = ratio * + Set(df, ib->metadata()->IntensityTarget()) * + inv_max_display_nits; + + red *= multiplier; + green *= multiplier; + blue *= multiplier; + + const V gray = new_luminance * inv_max_display_nits; + + // Desaturate out-of-gamut pixels. 
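+          // (Hedged derivation note, not upstream text: for a channel value
+          // v, the mix v + m * (gray - v) stays within [0, 1] exactly when m
+          // lies between v / (v - gray) and (v - 1) / (v - gray), which are
+          // bound1 and bound2 below; gray_mix is successively clamped into
+          // each channel's interval, then into [0, 1].)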
+          V gray_mix = Zero(df);
+          for (const V val : {red, green, blue}) {
+            const V inv_val_minus_gray = Set(df, 1) / (val - gray);
+            const V bound1 = val * inv_val_minus_gray;
+            const V bound2 = bound1 - inv_val_minus_gray;
+            const V min_bound = Min(bound1, bound2);
+            const V max_bound = Max(bound1, bound2);
+            gray_mix = Clamp(gray_mix, min_bound, max_bound);
+          }
+          gray_mix = Clamp(gray_mix, Zero(df), Set(df, 1));
+          for (V* const val : {&red, &green, &blue}) {
+            *val = IfThenElse(luminance < Set(df, 1e-6), gray,
+                              MulAdd(gray_mix, gray - *val, *val));
+          }
+
+          Store(red, df, row_r + x);
+          Store(green, df, row_g + x);
+          Store(blue, df, row_b + x);
+        }
+      },
+      "ToneMap"));
+
+  return true;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+namespace {
+HWY_EXPORT(ToneMapFrame);
+}
+
+Status ToneMapTo(const std::pair<float, float> display_nits,
+                 CodecInOut* const io, ThreadPool* const pool) {
+  const auto tone_map_frame = HWY_DYNAMIC_DISPATCH(ToneMapFrame);
+  for (ImageBundle& ib : io->frames) {
+    JXL_RETURN_IF_ERROR(tone_map_frame(display_nits, &ib, pool));
+  }
+  io->metadata.m.SetIntensityTarget(display_nits.second);
+  return true;
+}
+
+}  // namespace jxl
+#endif
diff --git a/third_party/jpeg-xl/lib/extras/tone_mapping.h b/third_party/jpeg-xl/lib/extras/tone_mapping.h
new file mode 100644
index 000000000000..886223739063
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/tone_mapping.h
@@ -0,0 +1,27 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_EXTRAS_TONE_MAPPING_H_
+#define LIB_EXTRAS_TONE_MAPPING_H_
+
+#include "lib/jxl/codec_in_out.h"
+
+namespace jxl {
+
+Status ToneMapTo(std::pair<float, float> display_nits, CodecInOut* io,
+                 ThreadPool* pool = nullptr);
+
+}  // namespace jxl
+
+#endif  // LIB_EXTRAS_TONE_MAPPING_H_
diff --git a/third_party/jpeg-xl/lib/extras/tone_mapping_gbench.cc b/third_party/jpeg-xl/lib/extras/tone_mapping_gbench.cc
new file mode 100644
index 000000000000..d73b8fb82a25
--- /dev/null
+++ b/third_party/jpeg-xl/lib/extras/tone_mapping_gbench.cc
@@ -0,0 +1,54 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
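+
+// Hedged usage sketch (not from upstream): given a decoded HDR CodecInOut
+// `io`, tone mapping for a 0.1-to-100-nit SDR display amounts to
+//   JXL_CHECK(ToneMapTo({0.1f, 100.f}, &io, pool));
+// after which io.metadata.m.IntensityTarget() reports 100.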
+
+#include "benchmark/benchmark.h"
+#include "lib/extras/codec.h"
+#include "lib/extras/tone_mapping.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+
+static void BM_ToneMapping(benchmark::State& state) {
+  CodecInOut image;
+  const PaddedBytes image_bytes =
+      ReadTestData("imagecompression.info/flower_foveon.png");
+  JXL_CHECK(SetFromBytes(Span<const uint8_t>(image_bytes), &image));
+
+  // Convert to linear Rec. 2020 so that `ToneMapTo` doesn't have to and we
+  // mainly measure the tone mapping itself.
+  ColorEncoding linear_rec2020;
+  linear_rec2020.SetColorSpace(ColorSpace::kRGB);
+  linear_rec2020.primaries = Primaries::k2100;
+  linear_rec2020.white_point = WhitePoint::kD65;
+  linear_rec2020.tf.SetTransferFunction(TransferFunction::kLinear);
+  JXL_CHECK(linear_rec2020.CreateICC());
+  JXL_CHECK(image.TransformTo(linear_rec2020));
+
+  for (auto _ : state) {
+    state.PauseTiming();
+    CodecInOut tone_mapping_input;
+    tone_mapping_input.SetFromImage(CopyImage(*image.Main().color()),
+                                    image.Main().c_current());
+    tone_mapping_input.metadata.m.SetIntensityTarget(
+        image.metadata.m.IntensityTarget());
+    state.ResumeTiming();
+
+    JXL_CHECK(ToneMapTo({0.1, 100}, &tone_mapping_input));
+  }
+
+  state.SetItemsProcessed(state.iterations() * image.xsize() * image.ysize());
+}
+BENCHMARK(BM_ToneMapping);
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/include/jxl/butteraugli.h b/third_party/jpeg-xl/lib/include/jxl/butteraugli.h
new file mode 100644
index 000000000000..2681bb5660cf
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/butteraugli.h
@@ -0,0 +1,165 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** @file butteraugli.h
+ * @brief Butteraugli API for JPEG XL.
+ */
+
+#ifndef JXL_BUTTERAUGLI_H_
+#define JXL_BUTTERAUGLI_H_
+
+#if defined(__cplusplus) || defined(c_plusplus)
+extern "C" {
+#endif
+
+#include "jxl/jxl_export.h"
+#include "jxl/memory_manager.h"
+#include "jxl/parallel_runner.h"
+#include "jxl/types.h"
+
+/**
+ * Opaque structure that holds a butteraugli API.
+ *
+ * Allocated and initialized with JxlButteraugliApiCreate().
+ * Cleaned up and deallocated with JxlButteraugliApiDestroy().
+ */
+typedef struct JxlButteraugliApiStruct JxlButteraugliApi;
+
+/**
+ * Opaque structure that holds intermediary butteraugli results.
+ *
+ * Allocated and initialized with JxlButteraugliCompute().
+ * Cleaned up and deallocated with JxlButteraugliResultDestroy().
+ */
+typedef struct JxlButteraugliResultStruct JxlButteraugliResult;
+
+/**
+ * Deinitializes and frees JxlButteraugliResult instance.
+ *
+ * @param result instance to be cleaned up and deallocated.
+ */
+JXL_EXPORT void JxlButteraugliResultDestroy(JxlButteraugliResult* result);
+
+/**
+ * Creates an instance of JxlButteraugliApi and initializes it.
+ *
+ * @p memory_manager will be used for all the library dynamic allocations made
+ * from this instance. The parameter may be NULL, in which case the default
+ * allocator will be used.
See jxl/memory_manager.h for details.
+ *
+ * @param memory_manager custom allocator function. It may be NULL. The memory
+ * manager will be copied internally.
+ * @return @c NULL if the instance can not be allocated or initialized
+ * @return pointer to initialized JxlButteraugliApi otherwise
+ */
+JXL_EXPORT JxlButteraugliApi* JxlButteraugliApiCreate(
+    const JxlMemoryManager* memory_manager);
+
+/**
+ * Set the parallel runner for multithreading.
+ *
+ * @param api api instance.
+ * @param parallel_runner function pointer to runner for multithreading. A
+ * multithreaded runner should be set to reach fast performance.
+ * @param parallel_runner_opaque opaque pointer for parallel_runner.
+ */
+JXL_EXPORT void JxlButteraugliApiSetParallelRunner(
+    JxlButteraugliApi* api, JxlParallelRunner parallel_runner,
+    void* parallel_runner_opaque);
+
+/**
+ * Set the hf_asymmetry option for butteraugli.
+ *
+ * @param api api instance.
+ * @param v new hf_asymmetry value.
+ */
+JXL_EXPORT void JxlButteraugliApiSetHFAsymmetry(JxlButteraugliApi* api,
+                                                float v);
+
+/**
+ * Set the intensity_target option for butteraugli.
+ *
+ * @param api api instance.
+ * @param v new intensity_target value.
+ */
+JXL_EXPORT void JxlButteraugliApiSetIntensityTarget(JxlButteraugliApi* api,
+                                                    float v);
+
+/**
+ * Deinitializes and frees JxlButteraugliApi instance.
+ *
+ * @param api instance to be cleaned up and deallocated.
+ */
+JXL_EXPORT void JxlButteraugliApiDestroy(JxlButteraugliApi* api);
+
+/**
+ * Computes intermediary butteraugli result between an original image and a
+ * distortion.
+ *
+ * @param api api instance for this computation.
+ * @param xsize width of the compared images.
+ * @param ysize height of the compared images.
+ * @param pixel_format_orig pixel format for original image.
+ * @param buffer_orig pixel data for original image.
+ * @param size_orig size of buffer_orig in bytes.
+ * @param pixel_format_dist pixel format for distortion.
+ * @param buffer_dist pixel data for distortion.
+ * @param size_dist size of buffer_dist in bytes.
+ * @return @c NULL if the results can not be computed or initialized.
+ * @return pointer to initialized and computed intermediary result.
+ */
+JXL_EXPORT JxlButteraugliResult* JxlButteraugliCompute(
+    const JxlButteraugliApi* api, uint32_t xsize, uint32_t ysize,
+    const JxlPixelFormat* pixel_format_orig, const void* buffer_orig,
+    size_t size_orig, const JxlPixelFormat* pixel_format_dist,
+    const void* buffer_dist, size_t size_dist);
+
+/**
+ * Computes butteraugli max distance based on an intermediary butteraugli
+ * result.
+ *
+ * @param result intermediary result instance.
+ * @return max distance.
+ */
+JXL_EXPORT float JxlButteraugliResultGetMaxDistance(
+    const JxlButteraugliResult* result);
+
+/**
+ * Computes a butteraugli distance based on an intermediary butteraugli result.
+ *
+ * @param result intermediary result instance.
+ * @param pnorm pnorm to calculate.
+ * @return distance using the given pnorm.
+ */
+JXL_EXPORT float JxlButteraugliResultGetDistance(
+    const JxlButteraugliResult* result, float pnorm);
+
+/**
+ * Get a pointer to the distmap in the result.
+ *
+ * @param result intermediary result instance.
+ * @param buffer will be set to the distmap. The distance value for (x,y) will
+ * be available at buffer + y * row_stride + x.
+ * @param row_stride will be set to the row stride of the distmap.
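+ *
+ * A hedged usage sketch (assumes a previously computed `result`; names as
+ * declared in this header):
+ *
+ *   const float* distmap;
+ *   uint32_t row_stride;
+ *   JxlButteraugliResultGetDistmap(result, &distmap, &row_stride);
+ *   float d_xy = distmap[y * row_stride + x];  // per-pixel distance at (x, y)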
+ */
+JXL_EXPORT void JxlButteraugliResultGetDistmap(
+    const JxlButteraugliResult* result, const float** buffer,
+    uint32_t* row_stride);
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_BUTTERAUGLI_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/butteraugli_cxx.h b/third_party/jpeg-xl/lib/include/jxl/butteraugli_cxx.h
new file mode 100644
index 000000000000..8fa236fa0953
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/butteraugli_cxx.h
@@ -0,0 +1,64 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/// @file butteraugli_cxx.h
+/// @brief C++ header-only helper for @ref butteraugli.h.
+///
+/// There's no binary library associated with the header since this is a header
+/// only library.
+
+#ifndef JXL_BUTTERAUGLI_CXX_H_
+#define JXL_BUTTERAUGLI_CXX_H_
+
+#include <memory>
+
+#include "jxl/butteraugli.h"
+
+#if !(defined(__cplusplus) || defined(c_plusplus))
+#error "This is a C++-only header. Use jxl/butteraugli.h from C sources."
+#endif
+
+/// Struct to call JxlButteraugliApiDestroy from the JxlButteraugliApiPtr
+/// unique_ptr.
+struct JxlButteraugliApiDestroyStruct {
+  /// Calls @ref JxlButteraugliApiDestroy() on the passed api.
+  void operator()(JxlButteraugliApi* api) { JxlButteraugliApiDestroy(api); }
+};
+
+/// std::unique_ptr<> type that calls JxlButteraugliApiDestroy() when releasing
+/// the pointer.
+///
+/// Use this helper type from C++ sources to ensure the api is destroyed and
+/// its internal resources released.
+typedef std::unique_ptr<JxlButteraugliApi, JxlButteraugliApiDestroyStruct>
+    JxlButteraugliApiPtr;
+
+/// Struct to call JxlButteraugliResultDestroy from the JxlButteraugliResultPtr
+/// unique_ptr.
+struct JxlButteraugliResultDestroyStruct {
+  /// Calls @ref JxlButteraugliResultDestroy() on the passed result object.
+  void operator()(JxlButteraugliResult* result) {
+    JxlButteraugliResultDestroy(result);
+  }
+};
+
+/// std::unique_ptr<> type that calls JxlButteraugliResultDestroy() when
+/// releasing the pointer.
+///
+/// Use this helper type from C++ sources to ensure the result object is
+/// destroyed and its internal resources released.
+typedef std::unique_ptr<JxlButteraugliResult, JxlButteraugliResultDestroyStruct>
+    JxlButteraugliResultPtr;
+
+#endif  // JXL_BUTTERAUGLI_CXX_H_
diff --git a/third_party/jpeg-xl/lib/include/jxl/codestream_header.h b/third_party/jpeg-xl/lib/include/jxl/codestream_header.h
new file mode 100644
index 000000000000..bf084ee54279
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/codestream_header.h
@@ -0,0 +1,320 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file codestream_header.h + * @brief Definitions of structs and enums for the metadata from the JPEG XL + * codestream headers (signature, metadata, preview dimensions, ...), excluding + * color encoding which is in color_encoding.h. + */ + +#ifndef JXL_CODESTREAM_HEADER_H_ +#define JXL_CODESTREAM_HEADER_H_ + +#include +#include + +#include "jxl/color_encoding.h" +#include "jxl/types.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** Image orientation metadata. + * Values 1..8 match the EXIF definitions. + * The name indicates the operation to perform to transform from the encoded + * image to the display image. + */ +typedef enum { + JXL_ORIENT_IDENTITY = 1, + JXL_ORIENT_FLIP_HORIZONTAL = 2, + JXL_ORIENT_ROTATE_180 = 3, + JXL_ORIENT_FLIP_VERTICAL = 4, + JXL_ORIENT_TRANSPOSE = 5, + JXL_ORIENT_ROTATE_90_CW = 6, + JXL_ORIENT_ANTI_TRANSPOSE = 7, + JXL_ORIENT_ROTATE_90_CCW = 8, +} JxlOrientation; + +/** Given type of an extra channel. + */ +typedef enum { + JXL_CHANNEL_ALPHA, + JXL_CHANNEL_DEPTH, + JXL_CHANNEL_SPOT_COLOR, + JXL_CHANNEL_SELECTION_MASK, + JXL_CHANNEL_BLACK, + JXL_CHANNEL_CFA, + JXL_CHANNEL_THERMAL, + JXL_CHANNEL_RESERVED0, + JXL_CHANNEL_RESERVED1, + JXL_CHANNEL_RESERVED2, + JXL_CHANNEL_RESERVED3, + JXL_CHANNEL_RESERVED4, + JXL_CHANNEL_RESERVED5, + JXL_CHANNEL_RESERVED6, + JXL_CHANNEL_RESERVED7, + JXL_CHANNEL_UNKNOWN, + JXL_CHANNEL_OPTIONAL +} JxlExtraChannelType; + +/** The codestream preview header */ +typedef struct { + /** Preview width in pixels */ + uint32_t xsize; + + /** Preview height in pixels */ + uint32_t ysize; +} JxlPreviewHeader; + +/** The codestream animation header, optionally present in the beginning of + * the codestream, and if it is it applies to all animation frames, unlike + * JxlFrameHeader which applies to an individual frame. + */ +typedef struct { + /** Numerator of ticks per second of a single animation frame time unit */ + uint32_t tps_numerator; + + /** Denominator of ticks per second of a single animation frame time unit */ + uint32_t tps_denominator; + + /** Amount of animation loops, or 0 to repeat infinitely */ + uint32_t num_loops; + + /** Whether animation time codes are present at animation frames in the + * codestream */ + JXL_BOOL have_timecodes; +} JxlAnimationHeader; + +/** Basic image information. This information is available from the file + * signature and first part of the codestream header. + */ +typedef struct JxlBasicInfo { + /* TODO(lode): need additional fields for (transcoded) JPEG? For reusable + * fields orientation must be read from Exif APP1. For has_icc_profile: must + * look up where ICC profile is guaranteed to be in a JPEG file to be able to + * indicate this. */ + + /* TODO(lode): make struct packed, and/or make this opaque struct with getter + * functions (still separate struct from opaque decoder) */ + + /** Whether the codestream is embedded in the container format. If true, + * metadata information and extensions may be available in addition to the + * codestream. + */ + JXL_BOOL have_container; + + /** Width of the image in pixels, before applying orientation. + */ + uint32_t xsize; + + /** Height of the image in pixels, before applying orientation. + */ + uint32_t ysize; + + /** Original image color channel bit depth. + */ + uint32_t bits_per_sample; + + /** Original image color channel floating point exponent bits, or 0 if they + * are unsigned integer. 
For example, if the original data is half-precision
+   * (binary16) floating point, bits_per_sample is 16 and
+   * exponent_bits_per_sample is 5, and so on for other floating point
+   * precisions.
+   */
+  uint32_t exponent_bits_per_sample;
+
+  /** Upper bound on the intensity level present in the image in nits. For
+   * unsigned integer pixel encodings, this is the brightness of the largest
+   * representable value. The image does not necessarily contain a pixel
+   * actually this bright. An encoder is allowed to set 255 for SDR images
+   * without computing a histogram.
+   */
+  float intensity_target;
+
+  /** Lower bound on the intensity level present in the image. This may be
+   * loose, i.e. lower than the actual darkest pixel. When tone mapping, a
+   * decoder will map [min_nits, intensity_target] to the display range.
+   */
+  float min_nits;
+
+  /** See the description of @see linear_below.
+   */
+  JXL_BOOL relative_to_max_display;
+
+  /** The tone mapping will leave unchanged (linear mapping) any pixels whose
+   * brightness is strictly below this. The interpretation depends on
+   * relative_to_max_display. If true, this is a ratio [0, 1] of the maximum
+   * display brightness [nits], otherwise an absolute brightness [nits].
+   */
+  float linear_below;
+
+  /** Whether the data in the codestream is encoded in the original color
+   * profile that is attached to the codestream metadata header, or is
+   * encoded in an internally supported absolute color space (which the decoder
+   * can always convert to linear or non-linear sRGB or to XYB). If the original
+   * profile is used, the decoder outputs pixel data in the color space matching
+   * that profile, but doesn't convert it to any other color space. If the
+   * original profile is not used, the decoder only outputs the data as sRGB
+   * (linear if outputting to floating point, nonlinear with standard sRGB
+   * transfer function if outputting to unsigned integers) but will not convert
+   * it to the original color profile. The decoder also does not convert to
+   * the target display color profile, but instead will always indicate which
+   * color profile the returned pixel data is encoded in when using @see
+   * JXL_COLOR_PROFILE_TARGET_DATA so that a CMS can be used to convert the
+   * data.
+   */
+  JXL_BOOL uses_original_profile;
+
+  /** Indicates a preview image exists near the beginning of the codestream.
+   * The preview itself or its dimensions are not included in the basic info.
+   */
+  JXL_BOOL have_preview;
+
+  /** Indicates animation frames exist in the codestream. The animation
+   * information is not included in the basic info.
+   */
+  JXL_BOOL have_animation;
+
+  /** Image orientation, value 1-8 matching the values used by JEITA CP-3451C
+   * (Exif version 2.3).
+   */
+  JxlOrientation orientation;
+
+  /** Number of color channels encoded in the image; this is either 1 for
+   * grayscale data, or 3 for colored data. This count does not include
+   * the alpha channel or other extra channels. To check presence of an alpha
+   * channel, such as in the case of RGBA color, check alpha_bits != 0.
+   * If and only if this is 1, the JxlColorSpace in the JxlColorEncoding is
+   * JXL_COLOR_SPACE_GRAY.
+   */
+  uint32_t num_color_channels;
+
+  /** Number of additional image channels. This includes the main alpha channel,
+   * but can also include additional channels such as depth, additional alpha
+   * channels, spot colors, and so on. Information about the extra channels
+   * can be queried with JxlDecoderGetExtraChannelInfo.
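+   * (Illustrative, not normative: a typical RGBA image decodes with
+   * num_color_channels == 3 and num_extra_channels == 1, the single extra
+   * channel having type JXL_CHANNEL_ALPHA.)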
The main alpha channel, + * if it exists, also has its information available in the alpha_bits, + * alpha_exponent_bits and alpha_premultiplied fields in this JxlBasicInfo. + */ + uint32_t num_extra_channels; + + /** Bit depth of the encoded alpha channel, or 0 if there is no alpha channel. + */ + uint32_t alpha_bits; + + /** Alpha channel floating point exponent bits, or 0 if they are unsigned + * integer. + */ + uint32_t alpha_exponent_bits; + + /** Whether the alpha channel is premultiplied + */ + JXL_BOOL alpha_premultiplied; + + /** Dimensions of encoded preview image, only used if have_preview is + * JXL_TRUE. + */ + JxlPreviewHeader preview; + + /** Animation header with global animation properties for all frames, only + * used if have_animation is JXL_TRUE. + */ + JxlAnimationHeader animation; +} JxlBasicInfo; + +/** Information for a single extra channel. + */ +typedef struct { + /** Given type of an extra channel. + */ + JxlExtraChannelType type; + + /** Total bits per sample for this channel. + */ + uint32_t bits_per_sample; + + /** Floating point exponent bits per channel, or 0 if they are unsigned + * integer. + */ + uint32_t exponent_bits_per_sample; + + /** The exponent the channel is downsampled by on each axis. + * TODO(lode): expand this comment to match the JPEG XL specification, + * specify how to upscale, how to round the size computation, and to which + * extra channels this field applies. + */ + uint32_t dim_shift; + + /** Length of the extra channel name in bytes, or 0 if no name. + * Excludes null termination character. + */ + uint32_t name_length; + + /** Whether alpha channel uses premultiplied alpha. Only applicable if + * type is JXL_CHANNEL_ALPHA. + */ + JXL_BOOL alpha_associated; + + /** Spot color of the current spot channel in linear RGBA. Only applicable if + * type is JXL_CHANNEL_SPOT_COLOR. + */ + float spot_color[4]; + + /** Only applicable if type is JXL_CHANNEL_CFA. + * TODO(lode): add comment about the meaning of this field. + */ + uint32_t cfa_channel; +} JxlExtraChannelInfo; + +/* TODO(lode): add API to get the codestream header extensions. */ +/** Extensions in the codestream header. */ +typedef struct { + /** Extension bits. */ + uint64_t extensions; +} JxlHeaderExtensions; + +/** The header of one displayed frame. */ +typedef struct { + /** How long to wait after rendering in ticks. The duration in seconds of a + * tick is given by tps_numerator and tps_denominator in JxlAnimationHeader. + */ + uint32_t duration; + + /** SMPTE timecode of the current frame in form 0xHHMMSSFF, or 0. The bits are + * interpreted from most-significant to least-significant as hour, minute, + * second, and frame. If timecode is nonzero, it is strictly larger than that + * of a previous frame with nonzero duration. These values are only available + * if have_timecodes in JxlAnimationHeader is JXL_TRUE. + * This value is only used if have_timecodes in JxlAnimationHeader is + * JXL_TRUE. + */ + uint32_t timecode; + + /** Length of the frame name in bytes, or 0 if no name. + * Excludes null termination character. + */ + uint32_t name_length; + + /** Indicates this is the last animation frame. 
+   */
+  JXL_BOOL is_last;
+} JxlFrameHeader;
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_CODESTREAM_HEADER_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/color_encoding.h b/third_party/jpeg-xl/lib/include/jxl/color_encoding.h
new file mode 100644
index 000000000000..818463082d37
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/color_encoding.h
@@ -0,0 +1,173 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** @file color_encoding.h
+ * @brief Color Encoding definitions used by JPEG XL.
+ * All CIE units are for the standard 1931 2 degree observer.
+ */
+
+#ifndef JXL_COLOR_ENCODING_H_
+#define JXL_COLOR_ENCODING_H_
+
+#include <stdint.h>
+
+#include "jxl/types.h"
+
+#if defined(__cplusplus) || defined(c_plusplus)
+extern "C" {
+#endif
+
+/** Color space of the image data. */
+typedef enum {
+  /** Tristimulus RGB */
+  JXL_COLOR_SPACE_RGB,
+  /** Luminance based, the primaries in JxlColorEncoding must be ignored. This
+   * value implies that num_color_channels in JxlBasicInfo is 1, any other value
+   * implies num_color_channels is 3. */
+  JXL_COLOR_SPACE_GRAY,
+  /** XYB (opsin) color space */
+  JXL_COLOR_SPACE_XYB,
+  /** None of the other table entries describe the color space appropriately */
+  JXL_COLOR_SPACE_UNKNOWN,
+} JxlColorSpace;
+
+/** Built-in whitepoints for color encoding. Numeric values match CICP (Rec.
+ * ITU-T H.273 | ISO/IEC 23091-2:2019(E)). */
+typedef enum {
+  /** CIE Standard Illuminant D65: 0.3127, 0.3290 */
+  JXL_WHITE_POINT_D65 = 1,
+  /** Custom white point stored in JxlColorEncoding white_point_xy. */
+  JXL_WHITE_POINT_CUSTOM = 2,
+  /** CIE Standard Illuminant E (equal-energy): 1/3, 1/3 */
+  JXL_WHITE_POINT_E = 10,
+  /** DCI-P3 from SMPTE RP 431-2: 0.314, 0.351 */
+  JXL_WHITE_POINT_DCI = 11,
+} JxlWhitePoint;
+
+/** Built-in primaries for color encoding. Numeric values match CICP (Rec. ITU-T
+ * H.273 | ISO/IEC 23091-2:2019(E)). */
+typedef enum {
+  /** The CIE xy values of the red, green and blue primaries are: 0.639998686,
+      0.330010138; 0.300003784, 0.600003357; 0.150002046, 0.059997204 */
+  JXL_PRIMARIES_SRGB = 1,
+  /** Custom primaries stored in JxlColorEncoding primaries_red_xy,
+      primaries_green_xy and primaries_blue_xy. */
+  JXL_PRIMARIES_CUSTOM = 2,
+  /** As specified in Rec. ITU-R BT.2100-1 */
+  JXL_PRIMARIES_2100 = 9,
+  /** As specified in SMPTE RP 431-2 */
+  JXL_PRIMARIES_P3 = 11,
+} JxlPrimaries;
+
+/** Built-in transfer functions for color encoding. Numeric values match CICP
+ * (Rec. ITU-T H.273 | ISO/IEC 23091-2:2019(E)) unless specified otherwise. */
+typedef enum {
+  /** As specified in Rec. ITU-R BT.709-6 */
+  JXL_TRANSFER_FUNCTION_709 = 1,
+  /** None of the other table entries describe the transfer function.
*/
+  JXL_TRANSFER_FUNCTION_UNKNOWN = 2,
+  /** The gamma exponent is 1 */
+  JXL_TRANSFER_FUNCTION_LINEAR = 8,
+  /** As specified in IEC 61966-2-1 sRGB */
+  JXL_TRANSFER_FUNCTION_SRGB = 13,
+  /** As specified in SMPTE ST 2084 (PQ, Rec. ITU-R BT.2100-1) */
+  JXL_TRANSFER_FUNCTION_PQ = 16,
+  /** As specified in SMPTE ST 428-1 */
+  JXL_TRANSFER_FUNCTION_DCI = 17,
+  /** As specified in Rec. ITU-R BT.2100-1 (HLG) */
+  JXL_TRANSFER_FUNCTION_HLG = 18,
+  /** Transfer function follows power law given by the gamma value in
+      JxlColorEncoding. Not a CICP value. */
+  JXL_TRANSFER_FUNCTION_GAMMA = 65535,
+} JxlTransferFunction;
+
+/** Rendering intent for color encoding, as specified in ISO 15076-1:2010 */
+typedef enum {
+  /** vendor-specific */
+  JXL_RENDERING_INTENT_PERCEPTUAL = 0,
+  /** media-relative */
+  JXL_RENDERING_INTENT_RELATIVE,
+  /** vendor-specific */
+  JXL_RENDERING_INTENT_SATURATION,
+  /** ICC-absolute */
+  JXL_RENDERING_INTENT_ABSOLUTE,
+} JxlRenderingIntent;
+
+/** Color encoding of the image as structured information.
+ */
+typedef struct {
+  /** Color space of the image data.
+   */
+  JxlColorSpace color_space;
+
+  /** Built-in white point. If this value is JXL_WHITE_POINT_CUSTOM, must
+   * use the numerical whitepoint values from white_point_xy.
+   */
+  JxlWhitePoint white_point;
+
+  /** Numerical whitepoint values in CIE xy space. */
+  double white_point_xy[2];
+
+  /** Built-in RGB primaries. If this value is JXL_PRIMARIES_CUSTOM, must
+   * use the numerical primaries values below. This field and the custom values
+   * below are unused and must be ignored if the color space is
+   * JXL_COLOR_SPACE_GRAY or JXL_COLOR_SPACE_XYB.
+   */
+  JxlPrimaries primaries;
+
+  /** Numerical red primary values in CIE xy space. */
+  double primaries_red_xy[2];
+
+  /** Numerical green primary values in CIE xy space. */
+  double primaries_green_xy[2];
+
+  /** Numerical blue primary values in CIE xy space. */
+  double primaries_blue_xy[2];
+
+  /** Transfer function if have_gamma is 0 */
+  JxlTransferFunction transfer_function;
+
+  /** Gamma value used when transfer_function is JXL_TRANSFER_FUNCTION_GAMMA
+   */
+  double gamma;
+
+  /** Rendering intent defined for the color profile. */
+  JxlRenderingIntent rendering_intent;
+} JxlColorEncoding;
+
+/** Color transform used for the XYB encoding. This affects how the internal
+ * XYB color format is converted, and is not needed unless XYB color is used.
+ */
+typedef struct {
+  /** Inverse opsin matrix.
+   */
+  float opsin_inv_matrix[3][3];
+
+  /** Opsin bias for opsin matrix. This affects how the internal XYB color
+   * format is converted, and is not needed unless XYB color is used.
+   */
+  float opsin_biases[3];
+
+  /** Quantization bias for opsin matrix. This affects how the internal XYB
+   * color format is converted, and is not needed unless XYB color is used.
+   */
+  float quant_biases[3];
+} JxlInverseOpsinMatrix;
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_COLOR_ENCODING_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/decode.h b/third_party/jpeg-xl/lib/include/jxl/decode.h
new file mode 100644
index 000000000000..73a33d6f49d2
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/decode.h
@@ -0,0 +1,811 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file decode.h + * @brief Decoding API for JPEG XL. + */ + +#ifndef JXL_DECODE_H_ +#define JXL_DECODE_H_ + +#include +#include + +#include "jxl/codestream_header.h" +#include "jxl/color_encoding.h" +#include "jxl/jxl_export.h" +#include "jxl/memory_manager.h" +#include "jxl/parallel_runner.h" +#include "jxl/types.h" + +#if defined(__cplusplus) || defined(c_plusplus) +extern "C" { +#endif + +/** + * Decoder library version. + * + * @return the decoder library version as an integer: + * MAJOR_VERSION * 1000000 + MINOR_VERSION * 1000 + PATCH_VERSION. For example, + * version 1.2.3 would return 1002003. + */ +JXL_EXPORT uint32_t JxlDecoderVersion(void); + +/** The result of JxlSignatureCheck. + */ +typedef enum { + /** Not enough bytes were passed to determine if a valid signature was found. + */ + JXL_SIG_NOT_ENOUGH_BYTES = 0, + + /** No valid JPEGXL header was found. */ + JXL_SIG_INVALID = 1, + + /** A valid JPEG XL codestream signature was found, that is a JPEG XL image + * without container. + */ + JXL_SIG_CODESTREAM = 2, + + /** A valid container signature was found, that is a JPEG XL image embedded + * in a box format container. + */ + JXL_SIG_CONTAINER = 3, +} JxlSignature; + +/** + * JPEG XL signature identification. + * + * Checks if the passed buffer contains a valid JPEG XL signature. The passed @p + * buf of size + * @p size doesn't need to be a full image, only the beginning of the file. + * + * @return a flag indicating if a JPEG XL signature was found and what type. + * - JXL_SIG_NOT_ENOUGH_BYTES not enough bytes were passed to determine + * if a valid signature is there. + * - JXL_SIG_INVALID: no valid signature found for JPEG XL decoding. + * - JXL_SIG_CODESTREAM a valid JPEG XL codestream signature was found. + * - JXL_SIG_CONTAINER a valid JPEG XL container signature was found. + */ +JXL_EXPORT JxlSignature JxlSignatureCheck(const uint8_t* buf, size_t len); + +/** + * Opaque structure that holds the JPEGXL decoder. + * + * Allocated and initialized with JxlDecoderCreate(). + * Cleaned up and deallocated with JxlDecoderDestroy(). + */ +typedef struct JxlDecoderStruct JxlDecoder; + +/** + * Creates an instance of JxlDecoder and initializes it. + * + * @p memory_manager will be used for all the library dynamic allocations made + * from this instance. The parameter may be NULL, in which case the default + * allocator will be used. See jpegxl/memory_manager.h for details. + * + * @param memory_manager custom allocator function. It may be NULL. The memory + * manager will be copied internally. + * @return @c NULL if the instance can not be allocated or initialized + * @return pointer to initialized JxlDecoder otherwise + */ +JXL_EXPORT JxlDecoder* JxlDecoderCreate(const JxlMemoryManager* memory_manager); + +/** + * Re-initializes a JxlDecoder instance, so it can be re-used for decoding + * another image. All state and settings are reset as if the object was + * newly created with JxlDecoderCreate, but the memory manager is kept. + * + * @param dec instance to be re-initialized. 
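+ *
+ * A hedged re-use sketch (names as declared in this header; not from the
+ * upstream documentation):
+ *
+ *   JxlDecoder* dec = JxlDecoderCreate(NULL);
+ *   // ... decode the first image ...
+ *   JxlDecoderReset(dec);  // same memory manager, fresh decoding state
+ *   // ... feed the next image via JxlDecoderSetInput/JxlDecoderProcessInput
+ *   JxlDecoderDestroy(dec);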
+ */
+JXL_EXPORT void JxlDecoderReset(JxlDecoder* dec);
+
+/**
+ * Deinitializes and frees JxlDecoder instance.
+ *
+ * @param dec instance to be cleaned up and deallocated.
+ */
+JXL_EXPORT void JxlDecoderDestroy(JxlDecoder* dec);
+
+/**
+ * Return value for JxlDecoderProcessInput.
+ * The values above 0x40 are optional informative events that can be subscribed
+ * to; they are never returned if they have not been registered with
+ * JxlDecoderSubscribeEvents.
+ */
+typedef enum {
+  /** Function call finished successfully, or decoding is finished and there is
+   * nothing more to be done.
+   */
+  JXL_DEC_SUCCESS = 0,
+
+  /** An error occurred, for example invalid input file or out of memory.
+   * TODO(lode): add function to get error information from decoder.
+   */
+  JXL_DEC_ERROR = 1,
+
+  /** The decoder needs more input bytes to continue. In the next
+   * JxlDecoderProcessInput call, next_in and avail_in must point to more
+   * bytes to continue. If *avail_in is not 0, the new bytes must be appended
+   * after the *avail_in remaining unprocessed bytes.
+   */
+  JXL_DEC_NEED_MORE_INPUT = 2,
+
+  /** The decoder is able to decode a preview image and requests setting a
+   * preview output buffer using JxlDecoderSetPreviewOutBuffer. This occurs if
+   * JXL_DEC_PREVIEW_IMAGE is requested and it is possible to decode a preview
+   * image from the codestream and the preview out buffer was not yet set. There
+   * is at most one preview image in a codestream.
+   */
+  JXL_DEC_NEED_PREVIEW_OUT_BUFFER = 3,
+
+  /** The decoder is able to decode a DC image and requests setting a DC output
+   * buffer using JxlDecoderSetDCOutBuffer. This occurs if JXL_DEC_DC_IMAGE is
+   * requested and it is possible to decode a DC image from the codestream and
+   * the DC out buffer was not yet set. This event re-occurs for new frames
+   * if there are multiple animation frames.
+   */
+  JXL_DEC_NEED_DC_OUT_BUFFER = 4,
+
+  /** The decoder requests an output buffer to store the full resolution image,
+   * which can be set with JxlDecoderSetImageOutBuffer or with
+   * JxlDecoderSetImageOutCallback. This event re-occurs for new frames if there
+   * are multiple animation frames and requires setting an output again.
+   */
+  JXL_DEC_NEED_IMAGE_OUT_BUFFER = 5,
+
+  /** Informative event by JxlDecoderProcessInput: JPEG reconstruction buffer is
+   * too small for reconstructed JPEG codestream to fit.
+   * JxlDecoderSetJPEGBuffer must be called again to make room for remaining
+   * bytes. This event may occur multiple times after
+   * JXL_DEC_JPEG_RECONSTRUCTION.
+   */
+  JXL_DEC_JPEG_NEED_MORE_OUTPUT = 6,
+
+  /** Informative event by JxlDecoderProcessInput: basic information such as
+   * image dimensions and extra channels. This event occurs max once per image.
+   */
+  JXL_DEC_BASIC_INFO = 0x40,
+
+  /** Informative event by JxlDecoderProcessInput: user extensions of the
+   * codestream header. This event occurs max once per image and always later
+   * than JXL_DEC_BASIC_INFO and earlier than any pixel data.
+   */
+  JXL_DEC_EXTENSIONS = 0x80,
+
+  /** Informative event by JxlDecoderProcessInput: color encoding or ICC
+   * profile from the codestream header. This event occurs max once per image
+   * and always later than JXL_DEC_BASIC_INFO and earlier than any pixel
+   * data.
+   */
+  JXL_DEC_COLOR_ENCODING = 0x100,
+
+  /** Informative event by JxlDecoderProcessInput: Preview image, a small
+   * frame, decoded. This event can only happen if the image has a preview
+   * frame encoded.
This event occurs max once for the codestream and always + * later than JXL_DEC_COLOR_ENCODING and before JXL_DEC_FRAME. + * This event is different than JXL_DEC_PREVIEW_HEADER because the latter only + * outputs the dimensions of the preview image. + */ + JXL_DEC_PREVIEW_IMAGE = 0x200, + + /** Informative event by JxlDecoderProcessInput: Beginning of a frame. + * JxlDecoderGetFrameHeader can be used at this point. A note on frames: + * a JPEG XL image can have internal frames that are not intended to be + * displayed (e.g. used for compositing a final frame), but this only returns + * displayed frames. A displayed frame either has an animation duration or is + * the only or last frame in the image. This event occurs max once per + * displayed frame, always later than JXL_DEC_COLOR_ENCODING, and always + * earlier than any pixel data. While JPEG XL supports encoding a single frame + * as the composition of multiple internal sub-frames also called frames, this + * event is not indicated for the internal frames. + */ + JXL_DEC_FRAME = 0x400, + + /** Informative event by JxlDecoderProcessInput: DC image, 8x8 sub-sampled + * frame, decoded. It is not guaranteed that the decoder will always return DC + * separately, but when it does it will do so before outputting the full + * frame. JxlDecoderSetDCOutBuffer must be used after getting the basic + * image information to be able to get the DC pixels, if not this return + * status only indicates we're past this point in the codestream. This event + * occurs max once per frame and always later than JXL_DEC_FRAME_HEADER + * and other header events and earlier than full resolution pixel data. + */ + JXL_DEC_DC_IMAGE = 0x800, + + /** Informative event by JxlDecoderProcessInput: full frame decoded. + * JxlDecoderSetImageOutBuffer must be used after getting the basic image + * information to be able to get the image pixels, if not this return status + * only indicates we're past this point in the codestream. This event occurs + * max once per frame and always later than JXL_DEC_DC_IMAGE. + */ + JXL_DEC_FULL_IMAGE = 0x1000, + + /** Informative event by JxlDecoderProcessInput: JPEG reconstruction data + * decoded. JxlDecoderSetJPEGBuffer may be used to set a JPEG + * reconstruction buffer after getting the JPEG reconstruction data. If a JPEG + * reconstruction buffer is set a byte stream identical to the JPEG codestream + * used to encode the image will be written to the JPEG reconstruction buffer + * instead of pixels to the image out buffer. This event occurs max once per + * image and always before JXL_DEC_FULL_IMAGE. + */ + JXL_DEC_JPEG_RECONSTRUCTION = 0x2000, +} JxlDecoderStatus; + +/** + * Get the default pixel format for this decoder. + * + * Requires that the decoder can produce JxlBasicInfo. + * + * @param dec JxlDecoder to query when creating the recommended pixel format. + * @param format JxlPixelFormat to populate with the recommended settings for + * the data loaded into this decoder. + * @return JXL_DEC_SUCCESS if no error, JXL_DEC_NEED_MORE_INPUT if the + * basic info isn't yet available, and JXL_DEC_ERROR otherwise. + */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderDefaultPixelFormat(const JxlDecoder* dec, JxlPixelFormat* format); + +/** + * Set the parallel runner for multithreading. May only be set before starting + * decoding. + * + * @param dec decoder object + * @param parallel_runner function pointer to runner for multithreading. It may + * be NULL to use the default, single-threaded, runner. 
A multithreaded
+ * runner should be set to reach fast performance.
+ * @param parallel_runner_opaque opaque pointer for parallel_runner.
+ * @return JXL_DEC_SUCCESS if the runner was set, JXL_DEC_ERROR
+ * otherwise (the previous runner remains set).
+ */
+JXL_EXPORT JxlDecoderStatus
+JxlDecoderSetParallelRunner(JxlDecoder* dec, JxlParallelRunner parallel_runner,
+                            void* parallel_runner_opaque);
+
+/**
+ * Returns a hint indicating how many more bytes the decoder is expected to
+ * need to make JxlDecoderGetBasicInfo available after the next
+ * JxlDecoderProcessInput call. This is a suggested large enough value for
+ * the *avail_in parameter, but it is not guaranteed to be an upper bound nor
+ * a lower bound.
+ * Can be used before the first JxlDecoderProcessInput call, and is correct
+ * the first time in most cases. If not, JxlDecoderSizeHintBasicInfo can be
+ * called again to get an updated hint.
+ *
+ * @param dec decoder object
+ * @return the size hint in bytes if the basic info is not yet fully decoded.
+ * @return 0 when the basic info is already available.
+ */
+JXL_EXPORT size_t JxlDecoderSizeHintBasicInfo(const JxlDecoder* dec);
+
+/** Select for which informative events (JXL_DEC_BASIC_INFO, etc...) the
+ * decoder should return with a status. It is not required to subscribe to any
+ * events; data can still be requested from the decoder as soon as it is
+ * available. By default, the decoder is subscribed to no events
+ * (events_wanted == 0), and the decoder will then only return when it cannot
+ * continue because it needs more input data or more output buffer. This
+ * function may only be called before using JxlDecoderProcessInput.
+ *
+ * @param dec decoder object
+ * @param events_wanted bitfield of desired events.
+ * @return JXL_DEC_SUCCESS if no error, JXL_DEC_ERROR otherwise.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderSubscribeEvents(JxlDecoder* dec,
+                                                      int events_wanted);
+
+/** Enables or disables preserving of original orientation. Some images are
+ * encoded with an orientation tag indicating the image is rotated and/or
+ * mirrored (here called the original orientation).
+ *
+ * *) If keep_orientation is JXL_FALSE (the default): the decoder will perform
+ * work to undo the transformation. This ensures the decoded pixels will not
+ * be rotated or mirrored. The decoder will always set the orientation field
+ * of the JxlBasicInfo to JXL_ORIENT_IDENTITY to match the returned pixel data.
+ * The decoder may also swap xsize and ysize in the JxlBasicInfo compared to the
+ * values inside of the codestream, to correctly match the decoded pixel data,
+ * e.g. when a 90 degree rotation was performed.
+ *
+ * *) If this option is JXL_TRUE: then the image is returned as-is, which may be
+ * rotated or mirrored, and the user must check the orientation field in
+ * JxlBasicInfo after decoding to correctly interpret the decoded pixel data.
+ * This may be faster to decode since the decoder doesn't have to apply the
+ * transformation, but can cause wrong display of the image if the orientation
+ * tag is not correctly taken into account by the user.
+ *
+ * By default, this option is disabled, and the decoder automatically corrects
+ * the orientation.
+ *
+ * @see JxlBasicInfo for the orientation field, and @see JxlOrientation for the
+ * possible values.
+ *
+ * @param dec decoder object
+ * @param keep_orientation JXL_TRUE to enable, JXL_FALSE to disable.
+ * @return JXL_DEC_SUCCESS if no error, JXL_DEC_ERROR otherwise.
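+ *
+ * (Illustrative example, not from the upstream documentation: for an image
+ * tagged JXL_ORIENT_ROTATE_90_CW, the default JXL_FALSE yields upright pixels,
+ * with orientation reported as JXL_ORIENT_IDENTITY and xsize/ysize swapped if
+ * needed, while JXL_TRUE yields the stored pixels and leaves the tag to the
+ * caller.)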
+ */ +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetKeepOrientation(JxlDecoder* dec, JXL_BOOL keep_orientation); + +/** + * Decodes JPEG XL file using the available bytes. Requires input has been + * set with JxlDecoderSetInput. After JxlDecoderProcessInput, input can + * optionally be released with JxlDecoderReleaseInput and then set again to + * next bytes in the stream. JxlDecoderReleaseInput returns how many bytes are + * not yet processed, before a next call to JxlDecoderProcessInput all + * unprocessed bytes must be provided again (the address need not match, but the + * contents must), and more bytes may be concatenated after the unprocessed + * bytes. + * + * The returned status indicates whether the decoder needs more input bytes, or + * more output buffer for a certain type of output data. No matter what the + * returned status is (other than JXL_DEC_ERROR), new information, such as + * JxlDecoderGetBasicInfo, may have become available after this call. When + * the return value is not JXL_DEC_ERROR or JXL_DEC_SUCCESS, the decoding + * requires more JxlDecoderProcessInput calls to continue. + * + * @param dec decoder object + * @return JXL_DEC_SUCCESS when decoding finished and all events handled. + * @return JXL_DEC_ERROR when decoding failed, e.g. invalid codestream. + * TODO(lode) document the input data mechanism + * @return JXL_DEC_NEED_MORE_INPUT more input data is necessary. + * @return JXL_DEC_BASIC_INFO when basic info such as image dimensions is + * available and this informative event is subscribed to. + * @return JXL_DEC_EXTENSIONS when JPEG XL codestream user extensions are + * available and this informative event is subscribed to. + * @return JXL_DEC_COLOR_ENCODING when color profile information is + * available and this informative event is subscribed to. + * @return JXL_DEC_PREVIEW_IMAGE when preview pixel information is available and + * output in the preview buffer. + * @return JXL_DEC_DC_IMAGE when DC pixel information (8x8 downscaled version + * of the image) is available and output in the DC buffer. + * @return JXL_DEC_FULL_IMAGE when all pixel information at highest detail is + * available and has been output in the pixel buffer. + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderProcessInput(JxlDecoder* dec); + +/** + * Sets input data for JxlDecoderProcessInput. The data is owned by the caller + * and may be used by the decoder until JxlDecoderReleaseInput is called or + * the decoder is destroyed or reset so must be kept alive until then. + * @param dec decoder object + * @param data pointer to next bytes to read from + * @param size amount of bytes available starting from data + * @return JXL_DEC_ERROR if input was already set without releasing, + * JXL_DEC_SUCCESS otherwise + */ +JXL_EXPORT JxlDecoderStatus JxlDecoderSetInput(JxlDecoder* dec, + const uint8_t* data, + size_t size); + +/** + * Releases input which was provided with JxlDecoderSetInput. Between + * JxlDecoderProcessInput and JxlDecoderReleaseInput, the user may not alter + * the data in the buffer. Calling JxlDecoderReleaseInput is required whenever + * any input is already set and new input needs to be added with + * JxlDecoderSetInput, but is not required before JxlDecoderDestroy or + * JxlDecoderReset. Calling JxlDecoderReleaseInput when no input is set is + * not an error and returns 0. 
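+ *
+ * A hedged streaming sketch (chunked input, error handling omitted; names as
+ * declared in this header):
+ *
+ *   JxlDecoderSetInput(dec, chunk, chunk_size);
+ *   JxlDecoderStatus status = JxlDecoderProcessInput(dec);
+ *   if (status == JXL_DEC_NEED_MORE_INPUT) {
+ *     size_t remaining = JxlDecoderReleaseInput(dec);
+ *     // Re-supply the last `remaining` unprocessed bytes, with new data
+ *     // appended after them, before the next JxlDecoderSetInput call.
+ *   }
+ *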
+ * @param dec decoder object
+ * @return the amount of bytes the decoder has not yet processed that are
+ * still remaining in the data set by JxlDecoderSetInput, or 0 if no input is
+ * set or JxlDecoderReleaseInput was already called. For a next call to
+ * JxlDecoderProcessInput, the buffer must start with these unprocessed bytes.
+ * This value doesn't provide information about how many bytes the decoder
+ * truly processed internally or how large the original JPEG XL codestream or
+ * file are.
+ */
+JXL_EXPORT size_t JxlDecoderReleaseInput(JxlDecoder* dec);
+
+/**
+ * Outputs the basic image information, such as image dimensions, bit depth and
+ * all other JxlBasicInfo fields, if available.
+ *
+ * @param dec decoder object
+ * @param info struct to copy the information into, or NULL to only check
+ * whether the information is available through the return value.
+ * @return JXL_DEC_SUCCESS if the value is available,
+ * JXL_DEC_NEED_MORE_INPUT if not yet available, JXL_DEC_ERROR in case
+ * of other error conditions.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderGetBasicInfo(const JxlDecoder* dec,
+                                                   JxlBasicInfo* info);
+
+/**
+ * Outputs information for extra channel at the given index. The index must be
+ * smaller than num_extra_channels in the associated JxlBasicInfo.
+ *
+ * @param dec decoder object
+ * @param index index of the extra channel to query.
+ * @param info struct to copy the information into, or NULL to only check
+ * whether the information is available through the return value.
+ * @return JXL_DEC_SUCCESS if the value is available,
+ * JXL_DEC_NEED_MORE_INPUT if not yet available, JXL_DEC_ERROR in case
+ * of other error conditions.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderGetExtraChannelInfo(
+    const JxlDecoder* dec, size_t index, JxlExtraChannelInfo* info);
+
+/**
+ * Outputs name for extra channel at the given index in UTF-8. The index must be
+ * smaller than num_extra_channels in the associated JxlBasicInfo. The buffer
+ * for name must have at least name_length + 1 bytes allocated, gotten from
+ * the associated JxlExtraChannelInfo.
+ *
+ * @param dec decoder object
+ * @param index index of the extra channel to query.
+ * @param name buffer to copy the name into
+ * @param size size of the name buffer in bytes
+ * @return JXL_DEC_SUCCESS if the value is available,
+ * JXL_DEC_NEED_MORE_INPUT if not yet available, JXL_DEC_ERROR in case
+ * of other error conditions.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderGetExtraChannelName(const JxlDecoder* dec,
+                                                          size_t index,
+                                                          char* name,
+                                                          size_t size);
+
+/** Defines which color profile to get: the profile from the codestream
+ * metadata header, which represents the color profile of the original image,
+ * or the color profile from the pixel data received by the decoder. Both are
+ * the same if the basic info has uses_original_profile set.
+ */
+typedef enum {
+  /** Get the color profile of the original image from the metadata.
+   */
+  JXL_COLOR_PROFILE_TARGET_ORIGINAL = 0,
+
+  /** Get the color profile of the pixel data the decoder outputs. */
+  JXL_COLOR_PROFILE_TARGET_DATA = 1,
+} JxlColorProfileTarget;
+
+/**
+ * Outputs the color profile as JPEG XL encoded structured data, if available.
+ * This is an alternative to an ICC Profile, which can represent a more limited
+ * amount of color spaces, but represents them exactly through enum values.
+ *
+ * It is often possible to use JxlDecoderGetColorAsICCProfile as an
+ * alternative anyway.
+ * The following scenarios are possible:
+ * - The JPEG XL image has an attached ICC Profile; in that case, the encoded
+ *   structured data is not available, this function will return an error
+ *   status, and you must use JxlDecoderGetColorAsICCProfile instead.
+ * - The JPEG XL image has an encoded structured color profile, and it
+ *   represents an RGB or grayscale color space. This function will return it.
+ *   You can still use JxlDecoderGetColorAsICCProfile as well as an
+ *   alternative if desired, though depending on which RGB color space is
+ *   represented, the ICC profile may be a close approximation. It is also not
+ *   always feasible to deduce from an ICC profile which named color space it
+ *   exactly represents, if any, as it can represent any arbitrary space.
+ * - The JPEG XL image has an encoded structured color profile, and it indicates
+ *   an unknown or xyb color space. In that case,
+ *   JxlDecoderGetColorAsICCProfile is not available.
+ *
+ * If you wish to render the image using a system that supports ICC profiles,
+ * use JxlDecoderGetColorAsICCProfile first. If you're looking for a specific
+ * color space possibly indicated in the JPEG XL image, use
+ * JxlDecoderGetColorAsEncodedProfile first.
+ *
+ * @param dec decoder object
+ * @param format pixel format to output the data to. Only used for
+ * JXL_COLOR_PROFILE_TARGET_DATA, may be nullptr otherwise.
+ * @param target whether to get the original color profile from the metadata
+ * or the color profile of the decoded pixels.
+ * @param color_encoding struct to copy the information into, or NULL to only
+ * check whether the information is available through the return value.
+ * @return JXL_DEC_SUCCESS if the data is available and returned,
+ * JXL_DEC_NEED_MORE_INPUT if not yet available, JXL_DEC_ERROR in case
+ * the encoded structured color profile does not exist in the codestream.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderGetColorAsEncodedProfile(
+    const JxlDecoder* dec, const JxlPixelFormat* format,
+    JxlColorProfileTarget target, JxlColorEncoding* color_encoding);
+
+/**
+ * Outputs the size in bytes of the ICC profile returned by
+ * JxlDecoderGetColorAsICCProfile, if available, or indicates there is none
+ * available. In most cases, the image will have an ICC profile available, but
+ * if it does not, JxlDecoderGetColorAsEncodedProfile must be used instead.
+ * @see JxlDecoderGetColorAsEncodedProfile for more information. The ICC
+ * profile is either the exact ICC profile attached to the codestream metadata,
+ * or a close approximation generated from JPEG XL encoded structured data,
+ * depending on what is encoded in the codestream.
+ *
+ * @param dec decoder object
+ * @param format pixel format to output the data to. Only used for
+ * JXL_COLOR_PROFILE_TARGET_DATA, may be nullptr otherwise.
+ * @param target whether to get the original color profile from the metadata
+ * or the color profile of the decoded pixels.
+ * @param size variable to output the size into, or NULL to only check the
+ * return status.
+ * @return JXL_DEC_SUCCESS if the ICC profile is available,
+ * JXL_DEC_NEED_MORE_INPUT if the decoder has not yet received enough
+ * input data to determine whether an ICC profile is available or what its
+ * size is, JXL_DEC_ERROR in case the ICC profile is not available and
+ * cannot be generated.
+ */
+JXL_EXPORT JxlDecoderStatus
+JxlDecoderGetICCProfileSize(const JxlDecoder* dec, const JxlPixelFormat* format,
+                            JxlColorProfileTarget target, size_t* size);
+
+/**
+ * Outputs ICC profile if available. The profile is only available if
+ * JxlDecoderGetICCProfileSize returns success. The output buffer must have
+ * at least as many bytes as given by JxlDecoderGetICCProfileSize.
+ *
+ * @param dec decoder object
+ * @param format pixel format to output the data to. Only used for
+ * JXL_COLOR_PROFILE_TARGET_DATA, may be nullptr otherwise.
+ * @param target whether to get the original color profile from the metadata
+ * or the color profile of the decoded pixels.
+ * @param icc_profile buffer to copy the ICC profile into
+ * @param size size of the icc_profile buffer in bytes
+ * @return JXL_DEC_SUCCESS if the profile was successfully returned,
+ * JXL_DEC_NEED_MORE_INPUT if not yet available,
+ * JXL_DEC_ERROR if the profile doesn't exist or the output size is not
+ * large enough.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderGetColorAsICCProfile(
+    const JxlDecoder* dec, const JxlPixelFormat* format,
+    JxlColorProfileTarget target, uint8_t* icc_profile, size_t size);
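+
+/* The following is an illustrative sketch, not part of this header: the
+ * two-step ICC query for the original color profile, typically performed
+ * after the JXL_DEC_COLOR_ENCODING event. The format argument may be NULL for
+ * JXL_COLOR_PROFILE_TARGET_ORIGINAL.
+
+   size_t icc_size;
+   if (JxlDecoderGetICCProfileSize(dec, NULL, JXL_COLOR_PROFILE_TARGET_ORIGINAL,
+                                   &icc_size) == JXL_DEC_SUCCESS) {
+     uint8_t* icc = (uint8_t*)malloc(icc_size);
+     if (icc && JxlDecoderGetColorAsICCProfile(
+                    dec, NULL, JXL_COLOR_PROFILE_TARGET_ORIGINAL, icc,
+                    icc_size) == JXL_DEC_SUCCESS) {
+       // icc now holds icc_size bytes of ICC profile data.
+     }
+     free(icc);
+   }
+ */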
+
+/**
+ * Returns the minimum size in bytes of the preview image output pixel buffer
+ * for the given format. This is the buffer for JxlDecoderSetPreviewOutBuffer.
+ * Requires that the preview header information is available in the decoder.
+ *
+ * @param dec decoder object
+ * @param format format of pixels
+ * @param size output value, buffer size in bytes
+ * @return JXL_DEC_SUCCESS on success, JXL_DEC_ERROR on error, such as
+ * information not available yet.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderPreviewOutBufferSize(
+    const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size);
+
+/**
+ * Sets the buffer to write the small resolution preview image
+ * to. The size of the buffer must be at least as large as given by
+ * JxlDecoderPreviewOutBufferSize. The buffer follows the format described by
+ * JxlPixelFormat. The preview image dimensions are given by the
+ * JxlPreviewHeader. The buffer is owned by the caller.
+ *
+ * @param dec decoder object
+ * @param format format of pixels. Object owned by user and its contents are
+ * copied internally.
+ * @param buffer buffer type to output the pixel data to
+ * @param size size of buffer in bytes
+ * @return JXL_DEC_SUCCESS on success, JXL_DEC_ERROR on error, such as
+ * size too small.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderSetPreviewOutBuffer(
+    JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size);
+
+/**
+ * Outputs the information from the frame, such as the duration when
+ * have_animation is set. This function can be called when JXL_DEC_FRAME has
+ * occurred for the current frame, even when have_animation in the
+ * JxlBasicInfo is JXL_FALSE.
+ *
+ * @param dec decoder object
+ * @param header struct to copy the information into, or NULL to only check
+ * whether the information is available through the return value.
+ * @return JXL_DEC_SUCCESS if the value is available,
+ * JXL_DEC_NEED_MORE_INPUT if not yet available, JXL_DEC_ERROR in case
+ * of other error conditions.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderGetFrameHeader(const JxlDecoder* dec,
+                                                     JxlFrameHeader* header);
+
+/**
+ * Outputs the name of the current frame. The buffer
+ * for name must have at least name_length + 1 bytes allocated, gotten from
+ * the associated JxlFrameHeader.
+ *
+ * @param dec decoder object
+ * @param name buffer to copy the name into
+ * @param size size of the name buffer in bytes, including the zero termination
+ * character, so this must be at least JxlFrameHeader.name_length + 1.
+ * @return JXL_DEC_SUCCESS if the value is available,
+ * JXL_DEC_NEED_MORE_INPUT if not yet available, JXL_DEC_ERROR in case
+ * of other error conditions.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderGetFrameName(const JxlDecoder* dec,
+                                                   char* name, size_t size);
+
+/**
+ * Returns the minimum size in bytes of the DC image output buffer
+ * for the given format. This is the buffer for JxlDecoderSetDCOutBuffer.
+ * Requires that the basic image information is available in the decoder.
+ *
+ * @param dec decoder object
+ * @param format format of pixels
+ * @param size output value, buffer size in bytes
+ * @return JXL_DEC_SUCCESS on success, JXL_DEC_ERROR on error, such as
+ * information not available yet.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderDCOutBufferSize(
+    const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size);
+
+/**
+ * Sets the buffer to write the lower resolution (8x8 sub-sampled) DC image
+ * to. The size of the buffer must be at least as large as given by
+ * JxlDecoderDCOutBufferSize. The buffer follows the format described by
+ * JxlPixelFormat. The DC image has dimensions ceil(xsize / 8) * ceil(ysize /
+ * 8). The buffer is owned by the caller.
+ *
+ * @param dec decoder object
+ * @param format format of pixels. Object owned by user and its contents are
+ * copied internally.
+ * @param buffer buffer type to output the pixel data to
+ * @param size size of buffer in bytes
+ * @return JXL_DEC_SUCCESS on success, JXL_DEC_ERROR on error, such as
+ * size too small.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderSetDCOutBuffer(
+    JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size);
+
+/**
+ * Returns the minimum size in bytes of the image output pixel buffer for the
+ * given format. This is the buffer for JxlDecoderSetImageOutBuffer. Requires
+ * that the basic image information is available in the decoder.
+ *
+ * @param dec decoder object
+ * @param format format of the pixels.
+ * @param size output value, buffer size in bytes
+ * @return JXL_DEC_SUCCESS on success, JXL_DEC_ERROR on error, such as
+ * information not available yet.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderImageOutBufferSize(
+    const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size);
+
+/**
+ * Sets output buffer for reconstructed JPEG codestream.
+ *
+ * The data is owned by the caller
+ * and may be used by the decoder until JxlDecoderReleaseJPEGBuffer is called or
+ * the decoder is destroyed or reset, so it must be kept alive until then.
+ *
+ * @param dec decoder object
+ * @param data pointer to next bytes to write to
+ * @param size amount of bytes available starting from data
+ * @return JXL_DEC_ERROR if a buffer was already set without releasing,
+ * JXL_DEC_SUCCESS otherwise
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderSetJPEGBuffer(JxlDecoder* dec,
+                                                    uint8_t* data, size_t size);
+
+/**
+ * Releases buffer which was provided with JxlDecoderSetJPEGBuffer.
+ *
+ * Calling JxlDecoderReleaseJPEGBuffer is required whenever
+ * a buffer is already set and a new buffer needs to be added with
+ * JxlDecoderSetJPEGBuffer, but is not required before JxlDecoderDestroy or
+ * JxlDecoderReset.
+ *
+ * Calling JxlDecoderReleaseJPEGBuffer when no buffer is set is
+ * not an error and returns 0.
+ *
+ * @param dec decoder object
+ * @return the amount of bytes of the data set by JxlDecoderSetJPEGBuffer that
+ * the decoder has not yet written to, or 0 if no buffer is set or
+ * JxlDecoderReleaseJPEGBuffer was already called.
+ */
+JXL_EXPORT size_t JxlDecoderReleaseJPEGBuffer(JxlDecoder* dec);
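+
+/* The following is an illustrative sketch, not part of this header: growing
+ * the JPEG reconstruction buffer when the decoder runs out of space. It
+ * assumes caller-managed jpeg_buf/jpeg_size and that the decoder just
+ * reported it needs more JPEG output space.
+
+   size_t remaining = JxlDecoderReleaseJPEGBuffer(dec);
+   size_t used = jpeg_size - remaining;
+   jpeg_size *= 2;
+   jpeg_buf = (uint8_t*)realloc(jpeg_buf, jpeg_size);
+   JxlDecoderSetJPEGBuffer(dec, jpeg_buf + used, jpeg_size - used);
+ */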
+
+/**
+ * Sets the buffer to write the full resolution image to. This can be set when
+ * the JXL_DEC_FRAME event occurs, must be set when the
+ * JXL_DEC_NEED_IMAGE_OUT_BUFFER event occurs, and applies only for the current
+ * frame. The size of the buffer must be at least as large as given by
+ * JxlDecoderImageOutBufferSize. The buffer follows the format described by
+ * JxlPixelFormat. The buffer is owned by the caller.
+ *
+ * @param dec decoder object
+ * @param format format of the pixels. Object owned by user and its contents
+ * are copied internally.
+ * @param buffer buffer type to output the pixel data to
+ * @param size size of buffer in bytes
+ * @return JXL_DEC_SUCCESS on success, JXL_DEC_ERROR on error, such as
+ * size too small.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderSetImageOutBuffer(
+    JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size);
+
+/**
+ * Callback function type for JxlDecoderSetImageOutCallback. @see
+ * JxlDecoderSetImageOutCallback for usage.
+ *
+ * The callback may be called simultaneously by different threads when using a
+ * threaded parallel runner, on different pixels.
+ *
+ * @param opaque optional user data, as given to JxlDecoderSetImageOutCallback.
+ * @param x horizontal position of leftmost pixel of the pixel data.
+ * @param y vertical position of the pixel data.
+ * @param xsize amount of pixels included in the pixel data, horizontally.
+ * This is not the same as xsize of the full image, it may be smaller.
+ * @param pixels pixel data as a horizontal stripe, in the format passed to
+ * JxlDecoderSetImageOutCallback. The memory is not owned by the user, and is
+ * only valid during the time the callback is running.
+ */
+typedef void (*JxlImageOutCallback)(void* opaque, size_t x, size_t y,
+                                    size_t xsize, const void* pixels);
+
+/**
+ * Sets pixel output callback. This is an alternative to
+ * JxlDecoderSetImageOutBuffer. This can be set when the JXL_DEC_FRAME event
+ * occurs, must be set when the JXL_DEC_NEED_IMAGE_OUT_BUFFER event occurs, and
+ * applies only for the current frame. Only one of JxlDecoderSetImageOutBuffer
+ * or JxlDecoderSetImageOutCallback may be used for the same frame, not both at
+ * the same time.
+ *
+ * The callback will be called multiple times, to receive the image
+ * data in small chunks. The callback receives a horizontal stripe of pixel
+ * data, 1 pixel high, xsize pixels wide, called a scanline. The xsize here is
+ * not the same as the full image width, the scanline may be a partial section,
+ * and xsize may differ between calls. The user can then process and/or copy the
+ * partial scanline to an image buffer. The callback may be called
+ * simultaneously by different threads when using a threaded parallel runner, on
+ * different pixels.
+ *
+ * If JxlDecoderFlushImage is not used, then each pixel will be visited exactly
+ * once by the different callback calls, during processing with one or more
+ * JxlDecoderProcessInput calls. These pixels are decoded to full detail, they
+ * are not part of a lower resolution or lower quality progressive pass, but the
+ * final pass.
+ *
+ * If JxlDecoderFlushImage is used, then in addition each pixel will be visited
+ * zero times or once during the blocking JxlDecoderFlushImage call. Pixels
+ * visited as a result of JxlDecoderFlushImage may represent a lower resolution
+ * or lower quality intermediate progressive pass of the image. Any visited
+ * pixel will be of a quality at least as good or better than previous visits of
+ * this pixel. A pixel may be visited zero times if it cannot be decoded yet
+ * or if it was already decoded to full precision (this behavior is not
+ * guaranteed).
+ *
+ * @param dec decoder object
+ * @param format format of the pixels. Object owned by user and its contents
+ * are copied internally.
+ * @param callback the callback function receiving partial scanlines of pixel
+ * data.
+ * @param opaque optional user data, which will be passed on to the callback,
+ * may be NULL.
+ * @return JXL_DEC_SUCCESS on success, JXL_DEC_ERROR on error, such as
+ * JxlDecoderSetImageOutBuffer already set.
+ */
+JXL_EXPORT JxlDecoderStatus
+JxlDecoderSetImageOutCallback(JxlDecoder* dec, const JxlPixelFormat* format,
+                              JxlImageOutCallback callback, void* opaque);
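+
+/* The following is an illustrative sketch, not part of this header: a
+ * scanline callback copying each stripe into a caller-owned interleaved RGBA8
+ * image. The Image struct and its stride field are assumptions of this
+ * sketch, and memcpy needs <string.h>.
+
+   typedef struct {
+     uint8_t* pixels;  // xsize * ysize * 4 bytes
+     size_t stride;    // bytes per row, here xsize * 4
+   } Image;
+
+   static void OnScanline(void* opaque, size_t x, size_t y, size_t xsize,
+                          const void* pixels) {
+     Image* img = (Image*)opaque;
+     memcpy(img->pixels + y * img->stride + x * 4, pixels, xsize * 4);
+   }
+
+   // JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
+   // JxlDecoderSetImageOutCallback(dec, &format, OnScanline, &img);
+ */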
+
+/* TODO(lode): add way to output extra channels */
+
+/**
+ * Outputs a progressive step towards the decoded image so far when only
+ * partial input was received. If the flush was successful, the buffer set with
+ * JxlDecoderSetImageOutBuffer will contain partial image data.
+ *
+ * Can be called when JxlDecoderProcessInput returns JXL_DEC_NEED_MORE_INPUT,
+ * after the JXL_DEC_FRAME event already occurred and before the
+ * JXL_DEC_FULL_IMAGE event occurred for a frame.
+ *
+ * @param dec decoder object
+ * @return JXL_DEC_SUCCESS if image data was flushed to the output buffer, or
+ * JXL_DEC_ERROR when no flush was done, e.g. if not enough image data was
+ * available yet even for flush, or no output buffer was set yet. An error is
+ * not fatal; it only indicates no flushed image is available now, and regular
+ * decoding can still be performed.
+ */
+JXL_EXPORT JxlDecoderStatus JxlDecoderFlushImage(JxlDecoder* dec);
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_DECODE_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/decode_cxx.h b/third_party/jpeg-xl/lib/include/jxl/decode_cxx.h
new file mode 100644
index 000000000000..0fba691c07e1
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/decode_cxx.h
@@ -0,0 +1,61 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/// @file decode_cxx.h
+/// @brief C++ header-only helper for @ref decode.h.
+///
+/// There's no binary library associated with the header since this is a header
+/// only library.
+
+#ifndef JXL_DECODE_CXX_H_
+#define JXL_DECODE_CXX_H_
+
+#include <memory>
+
+#include "jxl/decode.h"
+
+#if !(defined(__cplusplus) || defined(c_plusplus))
+#error "This is a C++ only header. Use jxl/decode.h from C sources."
+#endif
+
+/// Struct to call JxlDecoderDestroy from the JxlDecoderPtr unique_ptr.
+struct JxlDecoderDestroyStruct {
+  /// Calls @ref JxlDecoderDestroy() on the passed decoder.
+  void operator()(JxlDecoder* decoder) { JxlDecoderDestroy(decoder); }
+};
+
+/// std::unique_ptr<> type that calls JxlDecoderDestroy() when releasing the
+/// decoder.
+///
+/// Use this helper type from C++ sources to ensure the decoder is destroyed
+/// and its internal resources released.
+typedef std::unique_ptr<JxlDecoder, JxlDecoderDestroyStruct> JxlDecoderPtr;
+
+/// Creates an instance of JxlDecoder into a JxlDecoderPtr and initializes it.
+///
+/// This function returns a unique_ptr that will call JxlDecoderDestroy() when
+/// releasing the pointer. See @ref JxlDecoderCreate for details on the
+/// instance creation.
+///
+/// @param memory_manager custom allocator function. It may be NULL. The memory
+/// manager will be copied internally.
+/// @return a @c NULL JxlDecoderPtr if the instance can not be allocated or
+/// initialized
+/// @return initialized JxlDecoderPtr instance otherwise.
+static inline JxlDecoderPtr JxlDecoderMake(
+    const JxlMemoryManager* memory_manager) {
+  return JxlDecoderPtr(JxlDecoderCreate(memory_manager));
+}
+
+#endif  // JXL_DECODE_CXX_H_
diff --git a/third_party/jpeg-xl/lib/include/jxl/encode.h b/third_party/jpeg-xl/lib/include/jxl/encode.h
new file mode 100644
index 000000000000..d85cbb96705a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/encode.h
@@ -0,0 +1,380 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** @file encode.h
+ * @brief Encoding API for JPEG XL.
+ */
+
+#ifndef JXL_ENCODE_H_
+#define JXL_ENCODE_H_
+
+#include "jxl/decode.h"
+#include "jxl/jxl_export.h"
+#include "jxl/memory_manager.h"
+#include "jxl/parallel_runner.h"
+
+#if defined(__cplusplus) || defined(c_plusplus)
+extern "C" {
+#endif
+
+/**
+ * Encoder library version.
+ *
+ * @return the encoder library version as an integer:
+ * MAJOR_VERSION * 1000000 + MINOR_VERSION * 1000 + PATCH_VERSION. For example,
+ * version 1.2.3 would return 1002003.
+ */
+JXL_EXPORT uint32_t JxlEncoderVersion(void);
+
+/**
+ * Opaque structure that holds the JPEG XL encoder.
+ *
+ * Allocated and initialized with JxlEncoderCreate().
+ * Cleaned up and deallocated with JxlEncoderDestroy().
+ */
+typedef struct JxlEncoderStruct JxlEncoder;
+
+/**
+ * Opaque structure that holds frame specific encoding options for a JPEG XL
+ * encoder.
+ *
+ * Allocated and initialized with JxlEncoderOptionsCreate().
+ * Cleaned up and deallocated when the encoder is destroyed with
+ * JxlEncoderDestroy().
+ */
+typedef struct JxlEncoderOptionsStruct JxlEncoderOptions;
+
+/**
+ * Return value for multiple encoder functions.
+ */
+typedef enum {
+  /** Function call finished successfully, or encoding is finished and there is
+   * nothing more to be done.
+   */
+  JXL_ENC_SUCCESS = 0,
+
+  /** An error occurred, for example out of memory.
+   */
+  JXL_ENC_ERROR = 1,
+
+  /** The encoder needs more output buffer to continue encoding.
+   */
+  JXL_ENC_NEED_MORE_OUTPUT = 2,
+
+  /** The encoder doesn't (yet) support this.
+   */
+  JXL_ENC_NOT_SUPPORTED = 3,
+
+} JxlEncoderStatus;
+
+/**
+ * Creates an instance of JxlEncoder and initializes it.
+ *
+ * @p memory_manager will be used for all the library dynamic allocations made
+ * from this instance. The parameter may be NULL, in which case the default
+ * allocator will be used. See jpegxl/memory_manager.h for details.
+ *
+ * @param memory_manager custom allocator function. It may be NULL. The memory
+ * manager will be copied internally.
+ * @return @c NULL if the instance can not be allocated or initialized
+ * @return pointer to initialized JxlEncoder otherwise
+ */
+JXL_EXPORT JxlEncoder* JxlEncoderCreate(const JxlMemoryManager* memory_manager);
+
+/**
+ * Re-initializes a JxlEncoder instance, so it can be re-used for encoding
+ * another image. All state and settings are reset as if the object was
+ * newly created with JxlEncoderCreate, but the memory manager is kept.
+ *
+ * @param enc instance to be re-initialized.
+ */
+JXL_EXPORT void JxlEncoderReset(JxlEncoder* enc);
+
+/**
+ * Deinitializes and frees a JxlEncoder instance.
+ *
+ * @param enc instance to be cleaned up and deallocated.
+ */
+JXL_EXPORT void JxlEncoderDestroy(JxlEncoder* enc);
+
+/**
+ * Set the parallel runner for multithreading. May only be set before starting
+ * encoding.
+ *
+ * @param enc encoder object.
+ * @param parallel_runner function pointer to runner for multithreading. It may
+ * be NULL to use the default, single-threaded, runner. A multithreaded
+ * runner should be set to reach fast performance.
+ * @param parallel_runner_opaque opaque pointer for parallel_runner.
+ * @return JXL_ENC_SUCCESS if the runner was set, JXL_ENC_ERROR
+ * otherwise (the previous runner remains set).
+ */
+JXL_EXPORT JxlEncoderStatus
+JxlEncoderSetParallelRunner(JxlEncoder* enc, JxlParallelRunner parallel_runner,
+                            void* parallel_runner_opaque);
+
+/**
+ * Encodes a JPEG XL file using the available bytes. @p *avail_out indicates
+ * how many output bytes are available, and @p *next_out points to the position
+ * in the output buffer to write them to. *avail_out will be decremented by the
+ * amount of bytes that have been written by the encoder and *next_out will be
+ * incremented by the same amount, so *next_out will now point at the start of
+ * the *avail_out bytes that are still free.
+ *
+ * The returned status indicates whether the encoder needs more output bytes.
+ * When the return value is not JXL_ENC_ERROR or JXL_ENC_SUCCESS, the encoding
+ * requires more JxlEncoderProcessOutput calls to continue.
+ *
+ * @param enc encoder object.
+ * @param next_out pointer to next bytes to write to.
+ * @param avail_out amount of bytes available starting from *next_out.
+ * @return JXL_ENC_SUCCESS when encoding finished and all events handled.
+ * @return JXL_ENC_ERROR when encoding failed, e.g. invalid input.
+ * @return JXL_ENC_NEED_MORE_OUTPUT more output buffer is necessary.
+ */
+JXL_EXPORT JxlEncoderStatus JxlEncoderProcessOutput(JxlEncoder* enc,
+                                                    uint8_t** next_out,
+                                                    size_t* avail_out);
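+
+/* The following is an illustrative sketch, not part of this header: draining
+ * the encoder through JxlEncoderProcessOutput with a small stack buffer. It
+ * assumes frames were already added, JxlEncoderCloseInput was called, and a
+ * hypothetical FILE* f is open for writing.
+
+   uint8_t buf[4096];
+   JxlEncoderStatus status;
+   do {
+     uint8_t* next_out = buf;
+     size_t avail_out = sizeof(buf);
+     status = JxlEncoderProcessOutput(enc, &next_out, &avail_out);
+     fwrite(buf, 1, (size_t)(next_out - buf), f);  // bytes actually produced
+   } while (status == JXL_ENC_NEED_MORE_OUTPUT);
+   // status is now JXL_ENC_SUCCESS or JXL_ENC_ERROR.
+ */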
+
+/**
+ * Sets the buffer to read JPEG encoded bytes from for the next frame to encode.
+ *
+ * If the encoder is set to store JPEG reconstruction metadata using @ref
+ * JxlEncoderStoreJPEGMetadata and a single JPEG frame is added, it will be
+ * possible to losslessly reconstruct the JPEG codestream.
+ *
+ * @param options set of encoder options to use when encoding the frame.
+ * @param buffer bytes to read JPEG from. Owned by the caller and its contents
+ * are copied internally.
+ * @param size size of buffer in bytes.
+ * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error
+ */
+JXL_EXPORT JxlEncoderStatus JxlEncoderAddJPEGFrame(
+    const JxlEncoderOptions* options, const uint8_t* buffer, size_t size);
+
+/**
+ * Sets the buffer to read pixels from for the next image to encode. Must call
+ * JxlEncoderSetDimensions before JxlEncoderAddImageFrame.
+ *
+ * Currently only some pixel formats are supported:
+ * - JXL_TYPE_UINT8
+ * - JXL_TYPE_UINT16
+ * - JXL_TYPE_FLOAT, with nominal range 0..1
+ *
+ * The color profile of the pixels depends on the value of uses_original_profile
+ * in the JxlBasicInfo. If true, the pixels are assumed to be encoded in the
+ * original profile that is set with JxlEncoderSetColorEncoding or
+ * JxlEncoderSetICCProfile. If false, the pixels are assumed to be nonlinear
+ * sRGB for integer data types (JXL_TYPE_UINT8 and JXL_TYPE_UINT16), and linear
+ * sRGB for floating point data types (JXL_TYPE_FLOAT).
+ *
+ * @param options set of encoder options to use when encoding the frame.
+ * @param pixel_format format for pixels. Object owned by the caller and its
+ * contents are copied internally.
+ * @param buffer buffer type to input the pixel data from. Owned by the caller
+ * and its contents are copied internally.
+ * @param size size of buffer in bytes.
+ * @return JXL_ENC_SUCCESS on success, JXL_ENC_ERROR on error
+ */
+JXL_EXPORT JxlEncoderStatus JxlEncoderAddImageFrame(
+    const JxlEncoderOptions* options, const JxlPixelFormat* pixel_format,
+    const void* buffer, size_t size);
+
+/**
+ * Declares that this encoder will not encode anything further.
+ *
+ * Must be called between JxlEncoderAddImageFrame/JPEGFrame of the last frame
+ * and the next call to JxlEncoderProcessOutput, or JxlEncoderProcessOutput
+ * won't output the last frame correctly.
+ *
+ * @param enc encoder object.
+ */
+JXL_EXPORT void JxlEncoderCloseInput(JxlEncoder* enc);
+
+/**
+ * Sets the original color encoding of the image encoded by this encoder. This
+ * is an alternative to JxlEncoderSetICCProfile and only one of these two must
+ * be used. This one sets the color encoding as a @ref JxlColorEncoding, while
+ * the other sets it as ICC binary data.
+ *
+ * @param enc encoder object.
+ * @param color color encoding. Object owned by the caller and its contents are
+ * copied internally.
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR or
+ * JXL_ENC_NOT_SUPPORTED otherwise
+ */
+JXL_EXPORT JxlEncoderStatus
+JxlEncoderSetColorEncoding(JxlEncoder* enc, const JxlColorEncoding* color);
+
+/**
+ * Sets the original color encoding of the image encoded by this encoder as an
+ * ICC color profile. This is an alternative to JxlEncoderSetColorEncoding and
+ * only one of these two must be used. This one sets the color encoding as ICC
+ * binary data, while the other defines it as a @ref JxlColorEncoding.
+ *
+ * @param enc encoder object.
+ * @param icc_profile bytes of the original ICC profile
+ * @param size size of the icc_profile buffer in bytes
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR or
+ * JXL_ENC_NOT_SUPPORTED otherwise
+ */
+JXL_EXPORT JxlEncoderStatus JxlEncoderSetICCProfile(JxlEncoder* enc,
+                                                    const uint8_t* icc_profile,
+                                                    size_t size);
+
+/**
+ * Sets the global metadata of the image encoded by this encoder.
+ *
+ * @param enc encoder object.
+ * @param info global image metadata. Object owned by the caller and its
+ * contents are copied internally.
+ * @return JXL_ENC_SUCCESS if the operation was successful,
+ * JXL_ENC_ERROR or JXL_ENC_NOT_SUPPORTED otherwise
+ */
+JXL_EXPORT JxlEncoderStatus JxlEncoderSetBasicInfo(JxlEncoder* enc,
+                                                   const JxlBasicInfo* info);
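+
+/* The following is an illustrative sketch, not part of this header: filling
+ * in global metadata for a small 8-bit RGB image. Only a few JxlBasicInfo
+ * fields are set; this sketch assumes zero-initialization of the remaining
+ * fields is acceptable for the caller (memset needs <string.h>).
+
+   JxlBasicInfo info;
+   memset(&info, 0, sizeof(info));
+   info.xsize = 640;
+   info.ysize = 480;
+   info.bits_per_sample = 8;
+   info.num_color_channels = 3;
+   JxlEncoderSetBasicInfo(enc, &info);
+
+   JxlColorEncoding color;
+   JxlColorEncodingSetToSRGB(&color, JXL_FALSE);
+   JxlEncoderSetColorEncoding(enc, &color);
+ */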
+
+/**
+ * Configure the encoder to store JPEG reconstruction metadata in the JPEG XL
+ * container.
+ *
+ * The encoder must be configured to use the JPEG XL container format using @ref
+ * JxlEncoderUseContainer for this to have any effect.
+ *
+ * If this is set to true and a single JPEG frame is added, it will be
+ * possible to losslessly reconstruct the JPEG codestream.
+ *
+ * @param enc encoder object.
+ * @param store_jpeg_metadata true if the encoder should store JPEG metadata.
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR
+ * otherwise.
+ */
+JXL_EXPORT JxlEncoderStatus
+JxlEncoderStoreJPEGMetadata(JxlEncoder* enc, JXL_BOOL store_jpeg_metadata);
+
+/**
+ * Configure the encoder to use the JPEG XL container format.
+ *
+ * Using the JPEG XL container format allows storing metadata such as JPEG
+ * reconstruction (@ref JxlEncoderStoreJPEGMetadata) or other metadata like
+ * EXIF, but it adds a few bytes to the encoded file for container headers even
+ * if there is no extra metadata.
+ *
+ * @param enc encoder object.
+ * @param use_container true if the encoder should output the JPEG XL container
+ * format.
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR
+ * otherwise.
+ */
+JXL_EXPORT JxlEncoderStatus JxlEncoderUseContainer(JxlEncoder* enc,
+                                                   JXL_BOOL use_container);
+
+/**
+ * Sets lossless/lossy mode for the provided options. Default is lossy.
+ *
+ * @param options set of encoder options to update with the new mode
+ * @param lossless whether the options should be lossless
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR
+ * otherwise.
+ */
+JXL_EXPORT JxlEncoderStatus
+JxlEncoderOptionsSetLossless(JxlEncoderOptions* options, JXL_BOOL lossless);
+
+/**
+ * Set the decoding speed tier for the provided options. Minimum is 0 (highest
+ * quality), and maximum is 4 (lowest quality). Default is 0.
+ *
+ * @param options set of encoder options to update with the new decoding speed
+ * tier.
+ * @param tier the decoding speed tier to set.
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR
+ * otherwise.
+ */
+JXL_EXPORT JxlEncoderStatus
+JxlEncoderOptionsSetDecodingSpeed(JxlEncoderOptions* options, int tier);
+
+/**
+ * Sets encoder effort/speed level without affecting decoding speed. Valid
+ * values are, from faster to slower speed: 3:falcon 4:cheetah 5:hare 6:wombat
+ * 7:squirrel 8:kitten 9:tortoise. Default: squirrel (7).
+ *
+ * @param options set of encoder options to update with the new mode.
+ * @param effort the effort value to set.
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR
+ * otherwise.
+ */
+JXL_EXPORT JxlEncoderStatus
+JxlEncoderOptionsSetEffort(JxlEncoderOptions* options, int effort);
+
+/**
+ * Sets the distance level for lossy compression: target max butteraugli
+ * distance, lower = higher quality. Range: 0 .. 15.
+ * 0.0 = mathematically lossless (however, use JxlEncoderOptionsSetLossless to
+ * use true lossless).
+ * 1.0 = visually lossless.
+ * Recommended range: 0.5 .. 3.0.
+ * Default value: 1.0.
+ * If JxlEncoderOptionsSetLossless is used, this value is unused and implied
+ * to be 0.
+ *
+ * @param options set of encoder options to update with the new mode.
+ * @param distance the distance value to set.
+ * @return JXL_ENC_SUCCESS if the operation was successful, JXL_ENC_ERROR
+ * otherwise.
+ */
+JXL_EXPORT JxlEncoderStatus
+JxlEncoderOptionsSetDistance(JxlEncoderOptions* options, float distance);
+
+/**
+ * Create a new set of encoder options, with all values initially copied from
+ * the @p source options, or set to default if @p source is NULL.
+ *
+ * The returned pointer is an opaque struct tied to the encoder and it will be
+ * deallocated by the encoder when JxlEncoderDestroy() is called. For functions
+ * taking both a @ref JxlEncoder and a @ref JxlEncoderOptions, only
+ * JxlEncoderOptions created with this function for the same encoder instance
+ * can be used.
+ *
+ * @param enc encoder object.
+ * @param source source options to copy initial values from, or NULL to get
+ * options initialized to defaults.
+ * @return the opaque struct pointer identifying a new set of encoder options.
+ */
+JXL_EXPORT JxlEncoderOptions* JxlEncoderOptionsCreate(
+    JxlEncoder* enc, const JxlEncoderOptions* source);
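+
+/* The following is an illustrative sketch, not part of this header: creating
+ * per-frame options on an existing encoder and tuning them with the setters
+ * above before adding a frame.
+
+   JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc, NULL);
+   JxlEncoderOptionsSetEffort(options, 7);       // squirrel, the default
+   JxlEncoderOptionsSetDistance(options, 1.0f);  // visually lossless
+   // Or, for true lossless instead of a distance target:
+   // JxlEncoderOptionsSetLossless(options, JXL_TRUE);
+ */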
+
+/**
+ * Sets a color encoding to be sRGB.
+ *
+ * @param color_encoding color encoding instance.
+ * @param is_gray whether the color encoding should be gray scale or color.
+ */
+JXL_EXPORT void JxlColorEncodingSetToSRGB(JxlColorEncoding* color_encoding,
+                                          JXL_BOOL is_gray);
+
+/**
+ * Sets a color encoding to be linear sRGB.
+ *
+ * @param color_encoding color encoding instance.
+ * @param is_gray whether the color encoding should be gray scale or color.
+ */
+JXL_EXPORT void JxlColorEncodingSetToLinearSRGB(
+    JxlColorEncoding* color_encoding, JXL_BOOL is_gray);
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_ENCODE_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/encode_cxx.h b/third_party/jpeg-xl/lib/include/jxl/encode_cxx.h
new file mode 100644
index 000000000000..a5a963044ff2
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/encode_cxx.h
@@ -0,0 +1,61 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/// @file encode_cxx.h
+/// @brief C++ header-only helper for @ref encode.h.
+///
+/// There's no binary library associated with the header since this is a header
+/// only library.
+
+#ifndef JXL_ENCODE_CXX_H_
+#define JXL_ENCODE_CXX_H_
+
+#include <memory>
+
+#include "jxl/encode.h"
+
+#if !(defined(__cplusplus) || defined(c_plusplus))
+#error "This is a C++ only header. Use jxl/encode.h from C sources."
+#endif
+
+/// Struct to call JxlEncoderDestroy from the JxlEncoderPtr unique_ptr.
+struct JxlEncoderDestroyStruct {
+  /// Calls @ref JxlEncoderDestroy() on the passed encoder.
+  void operator()(JxlEncoder* encoder) { JxlEncoderDestroy(encoder); }
+};
+
+/// std::unique_ptr<> type that calls JxlEncoderDestroy() when releasing the
+/// encoder.
+///
+/// Use this helper type from C++ sources to ensure the encoder is destroyed
+/// and its internal resources released.
+typedef std::unique_ptr<JxlEncoder, JxlEncoderDestroyStruct> JxlEncoderPtr;
+
+/// Creates an instance of JxlEncoder into a JxlEncoderPtr and initializes it.
+///
+/// This function returns a unique_ptr that will call JxlEncoderDestroy() when
+/// releasing the pointer. See @ref JxlEncoderCreate for details on the
+/// instance creation.
+///
+/// @param memory_manager custom allocator function. It may be NULL. The memory
+/// manager will be copied internally.
+/// @return a @c NULL JxlEncoderPtr if the instance can not be allocated or
+/// initialized
+/// @return initialized JxlEncoderPtr instance otherwise.
+static inline JxlEncoderPtr JxlEncoderMake(
+    const JxlMemoryManager* memory_manager) {
+  return JxlEncoderPtr(JxlEncoderCreate(memory_manager));
+}
+
+#endif  // JXL_ENCODE_CXX_H_
diff --git a/third_party/jpeg-xl/lib/include/jxl/memory_manager.h b/third_party/jpeg-xl/lib/include/jxl/memory_manager.h
new file mode 100644
index 000000000000..6826ffbed913
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/memory_manager.h
@@ -0,0 +1,76 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** @file memory_manager.h
+ * @brief Abstraction functions used by JPEG XL to allocate memory.
+ */
+
+#ifndef JXL_MEMORY_MANAGER_H_
+#define JXL_MEMORY_MANAGER_H_
+
+#include <stddef.h>
+
+#if defined(__cplusplus) || defined(c_plusplus)
+extern "C" {
+#endif
+
+/**
+ * Allocating function for a memory region of a given size.
+ *
+ * Allocates a contiguous memory region of size @p size bytes. The returned
+ * memory may not be aligned to a specific size or initialized at all.
+ *
+ * @param opaque custom memory manager handle provided by the caller.
+ * @param size in bytes of the requested memory region.
+ * @returns @c 0 if the memory can not be allocated,
+ * @returns pointer to the memory otherwise.
+ */
+typedef void* (*jpegxl_alloc_func)(void* opaque, size_t size);
+
+/**
+ * Deallocating function pointer type.
+ *
+ * This function @b MUST do nothing if @p address is @c 0.
+ *
+ * @param opaque custom memory manager handle provided by the caller.
+ * @param address memory region pointer returned by ::jpegxl_alloc_func, or @c 0
+ */
+typedef void (*jpegxl_free_func)(void* opaque, void* address);
+
+/**
+ * Memory Manager struct.
+ * These functions, when provided by the caller, will be used to handle memory
+ * allocations.
+ */
+typedef struct JxlMemoryManagerStruct {
+  /** The opaque pointer that will be passed as the first parameter to all the
+   * functions in this struct. */
+  void* opaque;
+
+  /** Memory allocation function. This can be NULL if and only if the free()
+   * member in this struct is also NULL. All dynamic memory will be allocated
+   * and freed with these functions if they are not NULL. */
+  jpegxl_alloc_func alloc;
+  /** Free function matching the alloc() member. */
+  jpegxl_free_func free;
+
+  /* TODO(deymo): Add cache-aligned alloc/free functions here.
+   */
+} JxlMemoryManager;
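+
+/* The following is an illustrative sketch, not part of this header: a manager
+ * that simply forwards to malloc/free and ignores the opaque handle. Note
+ * that free() on a NULL address is a no-op, as required of a
+ * jpegxl_free_func.
+
+   static void* MyAlloc(void* opaque, size_t size) { return malloc(size); }
+   static void MyFree(void* opaque, void* address) { free(address); }
+
+   // JxlMemoryManager mm = {NULL, MyAlloc, MyFree};
+   // would then be passed to e.g. JxlDecoderCreate(&mm).
+ */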
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_MEMORY_MANAGER_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/parallel_runner.h b/third_party/jpeg-xl/lib/include/jxl/parallel_runner.h
new file mode 100644
index 000000000000..0532ea01d342
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/parallel_runner.h
@@ -0,0 +1,160 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file parallel_runner.h
+ */
+
+/** API for running data operations in parallel in a multi-threaded environment.
+ * This module allows the JPEG XL caller to define their own way of creating and
+ * assigning threads.
+ *
+ * The JxlParallelRunner function type defines a parallel data processing
+ * runner that may be implemented by the caller to allow the library to process
+ * in multiple threads. The multi-threaded processing in this library only
+ * requires running the same function over each number of a range, possibly
+ * running each call in a different thread. The JPEG XL caller is responsible
+ * for implementing this logic using the thread APIs available in their system.
+ * For convenience, a C++ implementation based on std::thread is provided in
+ * jpegxl/parallel_runner_thread.h (part of the jpegxl_threads library).
+ *
+ * Thread pools usually store small numbers of heterogeneous tasks in a queue.
+ * When tasks are identical or differ only by an integer input parameter, it is
+ * much faster to store just one function of an integer parameter and call it
+ * for each value. Conventional vector-of-tasks can be run in parallel using a
+ * lambda function adapter that simply calls task_funcs[task].
+ *
+ * If no multi-threading is desired, a @c NULL value of JxlParallelRunner
+ * will use an internal implementation without multi-threading.
+ */
+
+#ifndef JXL_PARALLEL_RUNNER_H_
+#define JXL_PARALLEL_RUNNER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#if defined(__cplusplus) || defined(c_plusplus)
+extern "C" {
+#endif
+
+/** Return code used in the JxlParallel* functions as return value. A value
+ * of 0 means success and any other value means error. The special value
+ * JXL_PARALLEL_RET_RUNNER_ERROR can be used by the runner to indicate any
+ * other error.
+ */
+typedef int JxlParallelRetCode;
+
+/**
+ * General error returned by the JxlParallelRunInit function to indicate
+ * an error.
+ */
+#define JXL_PARALLEL_RET_RUNNER_ERROR (-1)
+
+/**
+ * Parallel run initialization callback. See JxlParallelRunner for details.
+ *
+ * This function MUST be called by the JxlParallelRunner only once, on the
+ * same thread that called JxlParallelRunner, before any parallel execution.
+ * The purpose of this call is to provide the maximum number of threads that the
+ * JxlParallelRunner will use, which can be used by JPEG XL to allocate
+ * per-thread storage if needed.
+ *
+ * @param jpegxl_opaque the @p jpegxl_opaque handle provided to
+ * JxlParallelRunner() must be passed here.
+ * @param num_threads the maximum number of threads. This value must be
+ * positive.
+ * @returns 0 if the initialization process was successful.
+ * @returns an error code if there was an error, which should be returned by
+ * JxlParallelRunner().
+ */
+typedef JxlParallelRetCode (*JxlParallelRunInit)(void* jpegxl_opaque,
+                                                 size_t num_threads);
+
+/**
+ * Parallel run data processing callback. See JxlParallelRunner for details.
+ *
+ * This function MUST be called once for every number in the range [start_range,
+ * end_range) (including start_range but not including end_range) passing this
+ * number as the @p value. Calls for different values may be executed from
+ * different threads in parallel.
+ *
+ * @param jpegxl_opaque the @p jpegxl_opaque handle provided to
+ * JxlParallelRunner() must be passed here.
+ * @param value the number in the range [start_range, end_range) of the call.
+ * @param thread_id the thread number where this function is being called from.
+ * This must be lower than the @p num_threads value passed to
+ * JxlParallelRunInit.
+ */
+typedef void (*JxlParallelRunFunction)(void* jpegxl_opaque, uint32_t value,
+                                       size_t thread_id);
+
+/**
+ * JxlParallelRunner function type. A parallel runner implementation can be
+ * provided by a JPEG XL caller to allow running computations in multiple
+ * threads. This function must call the initialization function @p init in the
+ * same thread that called it and then call the passed @p func once for every
+ * number in the range [start_range, end_range) (including start_range but not
+ * including end_range), possibly from multiple different threads in parallel.
+ *
+ * The JxlParallelRunner function does not need to be re-entrant. This means
+ * that the same JxlParallelRunner function with the same runner_opaque
+ * provided parameter will not be called from the library from either @p init or
+ * @p func in the same decoder or encoder instance. However, a single decoding
+ * or encoding instance may call the provided JxlParallelRunner multiple
+ * times for different parts of the decoding or encoding process.
+ *
+ * @returns 0 if the @p init call succeeded (returned 0) and no other error
+ * occurred in the runner code.
+ * @returns JXL_PARALLEL_RET_RUNNER_ERROR if an error occurred in the runner
+ * code, for example, setting up the threads.
+ * @returns the return value of @p init() if non-zero.
+ */
+typedef JxlParallelRetCode (*JxlParallelRunner)(
+    void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
+    JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range);
+
+/* The following is an example of a JxlParallelRunner that doesn't use any
+ * multi-threading. Note that this implementation doesn't store any state
+ * between multiple calls of the ExampleSequentialRunner function, so the
+ * runner_opaque value is not used.
+
+  JxlParallelRetCode ExampleSequentialRunner(void* runner_opaque,
+                                             void* jpegxl_opaque,
+                                             JxlParallelRunInit init,
+                                             JxlParallelRunFunction func,
+                                             uint32_t start_range,
+                                             uint32_t end_range) {
+    // We only use one thread (the currently running thread).
+    JxlParallelRetCode init_ret = (*init)(jpegxl_opaque, 1);
+    if (init_ret != 0) return init_ret;
+
+    // In case of other initialization error (for example when initializing the
+    // threads) one can return JXL_PARALLEL_RET_RUNNER_ERROR.
+
+    for (uint32_t i = start_range; i < end_range; i++) {
+      // Every call is in the thread number 0. These don't need to be in any
+      // order.
+      (*func)(jpegxl_opaque, i, 0);
+    }
+    return 0;
+  }
+ */
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_PARALLEL_RUNNER_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner.h b/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner.h
new file mode 100644
index 000000000000..e052273e0486
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner.h
@@ -0,0 +1,78 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** @file thread_parallel_runner.h
+ * @brief implementation using std::thread of a ::JxlParallelRunner.
+ */
+
+/** Implementation of JxlParallelRunner that can be used to enable
+ * multithreading when using the JPEG XL library. This uses std::thread
+ * internally and related synchronization functions. The number of threads
+ * created is fixed at construction time and the threads are re-used for every
+ * ThreadParallelRunner::Runner call. Only one concurrent
+ * JxlThreadParallelRunner call per instance is allowed at a time.
+ *
+ * This is a scalable, lower-overhead thread pool runner, especially suitable
+ * for data-parallel computations in the fork-join model, where clients need to
+ * know when all tasks have completed.
+ *
+ * This thread pool can efficiently load-balance millions of tasks using an
+ * atomic counter, thus avoiding per-task virtual or system calls. With 48
+ * hyperthreads and 1M tasks that add to an atomic counter, overall runtime is
+ * 10-20x higher when using std::async, and ~200x for a queue-based thread
+ * pool.
+ */
+
+#ifndef JXL_THREAD_PARALLEL_RUNNER_H_
+#define JXL_THREAD_PARALLEL_RUNNER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "jxl/jxl_threads_export.h"
+#include "jxl/memory_manager.h"
+#include "jxl/parallel_runner.h"
+
+#if defined(__cplusplus) || defined(c_plusplus)
+extern "C" {
+#endif
+
+/** Parallel runner internally using std::thread. Use as JxlParallelRunner.
+ */
+JXL_THREADS_EXPORT JxlParallelRetCode JxlThreadParallelRunner(
+    void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
+    JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range);
+
+/** Creates the runner for JxlThreadParallelRunner. Use as the opaque
+ * runner.
+ */
+JXL_THREADS_EXPORT void* JxlThreadParallelRunnerCreate(
+    const JxlMemoryManager* memory_manager, size_t num_worker_threads);
+
+/** Destroys the runner created by JxlThreadParallelRunnerCreate.
+ */
+JXL_THREADS_EXPORT void JxlThreadParallelRunnerDestroy(void* runner_opaque);
+
+/** Returns a default num_worker_threads value for
+ * JxlThreadParallelRunnerCreate.
+ */
+JXL_THREADS_EXPORT size_t JxlThreadParallelRunnerDefaultNumWorkerThreads();
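+
+/* The following is an illustrative sketch, not part of this header: wiring
+ * the threaded runner into a decoder. JxlDecoderSetParallelRunner is declared
+ * in jxl/decode.h, and dec is assumed to be an existing JxlDecoder*.
+
+   void* runner = JxlThreadParallelRunnerCreate(
+       NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads());
+   JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner);
+   // ... run the decode loop ...
+   JxlThreadParallelRunnerDestroy(runner);
+ */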
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_THREAD_PARALLEL_RUNNER_H_ */
diff --git a/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner_cxx.h b/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner_cxx.h
new file mode 100644
index 000000000000..30ed579521a0
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/thread_parallel_runner_cxx.h
@@ -0,0 +1,68 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/// @file thread_parallel_runner_cxx.h
+/// @brief C++ header-only helper for @ref thread_parallel_runner.h.
+///
+/// There's no binary library associated with the header since this is a header
+/// only library.
+
+#ifndef JXL_THREAD_PARALLEL_RUNNER_CXX_H_
+#define JXL_THREAD_PARALLEL_RUNNER_CXX_H_
+
+#include <memory>
+
+#include "jxl/thread_parallel_runner.h"
+
+#if !(defined(__cplusplus) || defined(c_plusplus))
+#error \
+    "This is a C++ only header. Use jxl/jxl_thread_parallel_runner.h from C" \
+    " sources."
+#endif
+
+/// Struct to call JxlThreadParallelRunnerDestroy from the
+/// JxlThreadParallelRunnerPtr unique_ptr.
+struct JxlThreadParallelRunnerDestroyStruct {
+  /// Calls @ref JxlThreadParallelRunnerDestroy() on the passed runner.
+  void operator()(void* runner) { JxlThreadParallelRunnerDestroy(runner); }
+};
+
+/// std::unique_ptr<> type that calls JxlThreadParallelRunnerDestroy() when
+/// releasing the runner.
+///
+/// Use this helper type from C++ sources to ensure the runner is destroyed
+/// and its internal resources released.
+typedef std::unique_ptr<void, JxlThreadParallelRunnerDestroyStruct>
+    JxlThreadParallelRunnerPtr;
+
+/// Creates an instance of JxlThreadParallelRunner into a
+/// JxlThreadParallelRunnerPtr and initializes it.
+///
+/// This function returns a unique_ptr that will call
+/// JxlThreadParallelRunnerDestroy() when releasing the pointer. See @ref
+/// JxlThreadParallelRunnerCreate for details on the instance creation.
+///
+/// @param memory_manager custom allocator function. It may be NULL. The memory
+/// manager will be copied internally.
+/// @param num_worker_threads the number of worker threads to create.
+/// @return a @c NULL JxlThreadParallelRunnerPtr if the instance can not be
+/// allocated or initialized
+/// @return initialized JxlThreadParallelRunnerPtr instance otherwise.
+static inline JxlThreadParallelRunnerPtr JxlThreadParallelRunnerMake(
+    const JxlMemoryManager* memory_manager, size_t num_worker_threads) {
+  return JxlThreadParallelRunnerPtr(
+      JxlThreadParallelRunnerCreate(memory_manager, num_worker_threads));
+}
+
+#endif  // JXL_THREAD_PARALLEL_RUNNER_CXX_H_
diff --git a/third_party/jpeg-xl/lib/include/jxl/types.h b/third_party/jpeg-xl/lib/include/jxl/types.h
new file mode 100644
index 000000000000..797a7d4b3c2e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/include/jxl/types.h
@@ -0,0 +1,125 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** @file types.h
+ * @brief Data types for the JPEG XL API, for both encoding and decoding.
+ */
+
+#ifndef JXL_TYPES_H_
+#define JXL_TYPES_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#if defined(__cplusplus) || defined(c_plusplus)
+extern "C" {
+#endif
+
+/**
+ * A portable @c bool replacement.
+ *
+ * ::JXL_BOOL is a "documentation" type: actually it is @c int, but in API it
+ * denotes a type, whose only values are ::JXL_TRUE and ::JXL_FALSE.
+ */
+#define JXL_BOOL int
+/** Portable @c true replacement. */
+#define JXL_TRUE 1
+/** Portable @c false replacement. */
+#define JXL_FALSE 0
+
+/** Data type for the sample values per channel per pixel.
+ */
+typedef enum {
+  /** Use 32-bit single-precision floating point values, with range 0.0-1.0
+   * (within gamut, may go outside this range for wide color gamut). Floating
+   * point output, either JXL_TYPE_FLOAT or JXL_TYPE_FLOAT16, is recommended
+   * for HDR and wide gamut images when color profile conversion is required. */
+  JXL_TYPE_FLOAT = 0,
+
+  /** Use 1-bit packed in uint8_t, first pixel in LSB, padded to uint8_t per
+   * row.
+   * TODO(lode): support first in MSB, other padding.
+   */
+  JXL_TYPE_BOOLEAN,
+
+  /** Use type uint8_t. May clip wide color gamut data.
+   */
+  JXL_TYPE_UINT8,
+
+  /** Use type uint16_t. May clip wide color gamut data.
+   */
+  JXL_TYPE_UINT16,
+
+  /** Use type uint32_t. May clip wide color gamut data.
+   */
+  JXL_TYPE_UINT32,
+
+  /** Use 16-bit IEEE 754 half-precision floating point values */
+  JXL_TYPE_FLOAT16,
+} JxlDataType;
+
+/** Ordering of multi-byte data.
+ */
+typedef enum {
+  /** Use the endianness of the system, either little endian or big endian,
+   * without forcing either specific endianness. Do not use if pixel data
+   * should be exported to a well defined format.
+   */
+  JXL_NATIVE_ENDIAN = 0,
+  /** Force little endian */
+  JXL_LITTLE_ENDIAN = 1,
+  /** Force big endian */
+  JXL_BIG_ENDIAN = 2,
+} JxlEndianness;
+
+/** Data type for the sample values per channel per pixel for the output buffer
+ * for pixels. This is not necessarily the same as the data type encoded in the
+ * codestream. The channels are interleaved per pixel. The pixels are
+ * organized row by row, left to right, top to bottom.
+ * TODO(lode): implement padding / alignment (row stride)
+ * TODO(lode): support different channel orders if needed (RGB, BGR, ...)
+ */
+typedef struct {
+  /** Amount of channels available in a pixel buffer.
+   * 1: single-channel data, e.g. grayscale
+   * 2: single-channel + alpha
+   * 3: trichromatic, e.g. RGB
+   * 4: trichromatic + alpha
+   * TODO(lode): this needs finetuning. It is not yet defined how the user
+   * chooses output color space. CMYK+alpha needs 5 channels.
+   */
+  uint32_t num_channels;
+
+  /** Data type of each channel.
+   */
+  JxlDataType data_type;
+
+  /** Whether multi-byte data types are represented in big endian or little
+   * endian format. This applies to JXL_TYPE_UINT16, JXL_TYPE_UINT32
+   * and JXL_TYPE_FLOAT.
+   */
+  JxlEndianness endianness;
+
+  /** Align scanlines to a multiple of align bytes, or 0 to require no
+   * alignment at all (which has the same effect as value 1)
+   */
+  size_t align;
+} JxlPixelFormat;
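+
+/* The following is an illustrative sketch, not part of this header: two
+ * common JxlPixelFormat initializations, matching the field order above.
+
+   // Interleaved RGBA, 8 bits per sample, native byte order, no alignment:
+   JxlPixelFormat rgba8 = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
+
+   // Interleaved RGB, 16 bits per sample, forced big endian:
+   JxlPixelFormat rgb16 = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+ */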
+
+#if defined(__cplusplus) || defined(c_plusplus)
+}
+#endif
+
+#endif /* JXL_TYPES_H_ */
diff --git a/third_party/jpeg-xl/lib/jxl.cmake b/third_party/jpeg-xl/lib/jxl.cmake
new file mode 100644
index 000000000000..81db26a7bd41
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl.cmake
@@ -0,0 +1,543 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lists all source files for the JPEG XL decoder library. These are also used
+# by the encoder: the encoder uses both dec and enc source files, while the
+# decoder uses only dec source files.
+# TODO(lode): further prune these files and move to JPEGXL_INTERNAL_SOURCES_ENC:
+# only those files that the decoder absolutely needs, and are not
+# only for encoding, should be listed here.
+set(JPEGXL_INTERNAL_SOURCES_DEC + jxl/ac_context.h + jxl/ac_strategy.cc + jxl/ac_strategy.h + jxl/alpha.cc + jxl/alpha.h + jxl/ans_common.cc + jxl/ans_common.h + jxl/ans_params.h + jxl/aux_out.cc + jxl/aux_out.h + jxl/aux_out_fwd.h + jxl/base/arch_macros.h + jxl/base/bits.h + jxl/base/byte_order.h + jxl/base/cache_aligned.cc + jxl/base/cache_aligned.h + jxl/base/compiler_specific.h + jxl/base/data_parallel.cc + jxl/base/data_parallel.h + jxl/base/descriptive_statistics.cc + jxl/base/descriptive_statistics.h + jxl/base/file_io.h + jxl/base/iaca.h + jxl/base/os_macros.h + jxl/base/override.h + jxl/base/padded_bytes.cc + jxl/base/padded_bytes.h + jxl/base/profiler.h + jxl/base/robust_statistics.h + jxl/base/span.h + jxl/base/status.cc + jxl/base/status.h + jxl/base/thread_pool_internal.h + jxl/base/time.cc + jxl/base/time.h + jxl/blending.cc + jxl/blending.h + jxl/chroma_from_luma.cc + jxl/chroma_from_luma.h + jxl/codec_in_out.h + jxl/coeff_order.cc + jxl/coeff_order.h + jxl/coeff_order_fwd.h + jxl/color_encoding_internal.cc + jxl/color_encoding_internal.h + jxl/color_management.cc + jxl/color_management.h + jxl/common.h + jxl/compressed_dc.cc + jxl/compressed_dc.h + jxl/convolve-inl.h + jxl/convolve.cc + jxl/convolve.h + jxl/dct-inl.h + jxl/dct_block-inl.h + jxl/dct_scales.cc + jxl/dct_scales.h + jxl/dct_util.h + jxl/dec_ans.cc + jxl/dec_ans.h + jxl/dec_bit_reader.h + jxl/dec_cache.h + jxl/dec_context_map.cc + jxl/dec_context_map.h + jxl/dec_external_image.cc + jxl/dec_external_image.h + jxl/dec_frame.cc + jxl/dec_frame.h + jxl/dec_group.cc + jxl/dec_group.h + jxl/dec_group_border.cc + jxl/dec_group_border.h + jxl/dec_huffman.cc + jxl/dec_huffman.h + jxl/dec_modular.cc + jxl/dec_modular.h + jxl/dec_noise.cc + jxl/dec_noise.h + jxl/dec_params.h + jxl/dec_patch_dictionary.cc + jxl/dec_patch_dictionary.h + jxl/dec_reconstruct.cc + jxl/dec_reconstruct.h + jxl/dec_transforms-inl.h + jxl/dec_upsample.cc + jxl/dec_upsample.h + jxl/dec_xyb-inl.h + jxl/dec_xyb.cc + jxl/dec_xyb.h + jxl/decode.cc + jxl/enc_bit_writer.cc + jxl/enc_bit_writer.h + jxl/entropy_coder.cc + jxl/entropy_coder.h + jxl/epf.cc + jxl/epf.h + jxl/fast_math-inl.h + jxl/field_encodings.h + jxl/fields.cc + jxl/fields.h + jxl/filters.cc + jxl/filters.h + jxl/filters_internal.h + jxl/frame_header.cc + jxl/frame_header.h + jxl/gauss_blur.cc + jxl/gauss_blur.h + jxl/headers.cc + jxl/headers.h + jxl/huffman_table.cc + jxl/huffman_table.h + jxl/icc_codec.cc + jxl/icc_codec.h + jxl/icc_codec_common.cc + jxl/icc_codec_common.h + jxl/image.cc + jxl/image.h + jxl/image_bundle.cc + jxl/image_bundle.h + jxl/image_metadata.cc + jxl/image_metadata.h + jxl/image_ops.h + jxl/jpeg/dec_jpeg_data.cc + jxl/jpeg/dec_jpeg_data.h + jxl/jpeg/dec_jpeg_data_writer.cc + jxl/jpeg/dec_jpeg_data_writer.h + jxl/jpeg/dec_jpeg_output_chunk.h + jxl/jpeg/dec_jpeg_serialization_state.h + jxl/jpeg/jpeg_data.cc + jxl/jpeg/jpeg_data.h + jxl/jxl_inspection.h + jxl/lehmer_code.h + jxl/linalg.h + jxl/loop_filter.cc + jxl/loop_filter.h + jxl/luminance.cc + jxl/luminance.h + jxl/memory_manager_internal.cc + jxl/memory_manager_internal.h + jxl/modular/encoding/context_predict.h + jxl/modular/encoding/dec_ma.cc + jxl/modular/encoding/dec_ma.h + jxl/modular/encoding/encoding.cc + jxl/modular/encoding/encoding.h + jxl/modular/encoding/ma_common.h + jxl/modular/modular_image.cc + jxl/modular/modular_image.h + jxl/modular/options.h + jxl/modular/transform/near-lossless.h + jxl/modular/transform/palette.h + jxl/modular/transform/squeeze.h + jxl/modular/transform/subtractgreen.h + 
jxl/modular/transform/transform.cc + jxl/modular/transform/transform.h + jxl/noise.h + jxl/noise_distributions.h + jxl/opsin_params.cc + jxl/opsin_params.h + jxl/passes_state.cc + jxl/passes_state.h + jxl/patch_dictionary_internal.h + jxl/quant_weights.cc + jxl/quant_weights.h + jxl/quantizer-inl.h + jxl/quantizer.cc + jxl/quantizer.h + jxl/rational_polynomial-inl.h + jxl/splines.cc + jxl/splines.h + jxl/toc.cc + jxl/toc.h + jxl/transfer_functions-inl.h + jxl/transpose-inl.h + jxl/xorshift128plus-inl.h +) + +# List of source files only needed by the encoder or by tools (including +# decoding tools), but not by the decoder library. +set(JPEGXL_INTERNAL_SOURCES_ENC + jxl/butteraugli/butteraugli.cc + jxl/butteraugli/butteraugli.h + jxl/butteraugli_wrapper.cc + jxl/dec_file.cc + jxl/dec_file.h + jxl/enc_ac_strategy.cc + jxl/enc_ac_strategy.h + jxl/enc_adaptive_quantization.cc + jxl/enc_adaptive_quantization.h + jxl/enc_ans.cc + jxl/enc_ans.h + jxl/enc_ans_params.h + jxl/enc_ar_control_field.cc + jxl/enc_ar_control_field.h + jxl/enc_butteraugli_comparator.cc + jxl/enc_butteraugli_comparator.h + jxl/enc_butteraugli_pnorm.cc + jxl/enc_butteraugli_pnorm.h + jxl/enc_cache.cc + jxl/enc_cache.h + jxl/enc_chroma_from_luma.cc + jxl/enc_chroma_from_luma.h + jxl/enc_cluster.cc + jxl/enc_cluster.h + jxl/enc_coeff_order.cc + jxl/enc_coeff_order.h + jxl/enc_color_management.cc + jxl/enc_color_management.h + jxl/enc_comparator.cc + jxl/enc_comparator.h + jxl/enc_context_map.cc + jxl/enc_context_map.h + jxl/enc_detect_dots.cc + jxl/enc_detect_dots.h + jxl/enc_dot_dictionary.cc + jxl/enc_dot_dictionary.h + jxl/enc_entropy_coder.cc + jxl/enc_entropy_coder.h + jxl/enc_external_image.cc + jxl/enc_external_image.h + jxl/enc_fast_heuristics.cc + jxl/enc_file.cc + jxl/enc_file.h + jxl/enc_frame.cc + jxl/enc_frame.h + jxl/enc_gamma_correct.h + jxl/enc_group.cc + jxl/enc_group.h + jxl/enc_heuristics.cc + jxl/enc_heuristics.h + jxl/enc_huffman.cc + jxl/enc_huffman.h + jxl/enc_icc_codec.cc + jxl/enc_icc_codec.h + jxl/enc_image_bundle.cc + jxl/enc_image_bundle.h + jxl/enc_modular.cc + jxl/enc_modular.h + jxl/enc_noise.cc + jxl/enc_noise.h + jxl/enc_params.h + jxl/enc_patch_dictionary.cc + jxl/enc_patch_dictionary.h + jxl/enc_quant_weights.cc + jxl/enc_quant_weights.h + jxl/enc_splines.cc + jxl/enc_splines.h + jxl/enc_toc.cc + jxl/enc_toc.h + jxl/enc_transforms-inl.h + jxl/enc_transforms.cc + jxl/enc_transforms.h + jxl/enc_xyb.cc + jxl/enc_xyb.h + jxl/encode.cc + jxl/encode_internal.h + jxl/gaborish.cc + jxl/gaborish.h + jxl/huffman_tree.cc + jxl/huffman_tree.h + jxl/jpeg/enc_jpeg_data.cc + jxl/jpeg/enc_jpeg_data.h + jxl/jpeg/enc_jpeg_data_reader.cc + jxl/jpeg/enc_jpeg_data_reader.h + jxl/jpeg/enc_jpeg_huffman_decode.cc + jxl/jpeg/enc_jpeg_huffman_decode.h + jxl/linalg.cc + jxl/modular/encoding/enc_encoding.cc + jxl/modular/encoding/enc_encoding.h + jxl/modular/encoding/enc_ma.cc + jxl/modular/encoding/enc_ma.h + jxl/optimize.cc + jxl/optimize.h + jxl/progressive_split.cc + jxl/progressive_split.h +) + +set(JPEGXL_DEC_INTERNAL_LIBS + brotlidec-static + brotlicommon-static + hwy +) + +set(JPEGXL_INTERNAL_LIBS + ${JPEGXL_DEC_INTERNAL_LIBS} + brotlienc-static + Threads::Threads +) + +# strips the -static suffix from all the elements in LIST +function(strip_static OUTPUT_VAR LIB_LIST) + foreach(lib IN LISTS ${LIB_LIST}) + string(REGEX REPLACE "-static$" "" lib "${lib}") + list(APPEND out_list "${lib}") + endforeach() + set(${OUTPUT_VAR} ${out_list} PARENT_SCOPE) +endfunction() + +if (JPEGXL_ENABLE_SKCMS) + list(APPEND 
JPEGXL_INTERNAL_FLAGS -DJPEGXL_ENABLE_SKCMS=1) + list(APPEND JPEGXL_INTERNAL_LIBS skcms) +else () + list(APPEND JPEGXL_INTERNAL_LIBS lcms2) +endif () + +set(OBJ_COMPILE_DEFINITIONS + JPEGXL_MAJOR_VERSION=${JPEGXL_MAJOR_VERSION} + JPEGXL_MINOR_VERSION=${JPEGXL_MINOR_VERSION} + JPEGXL_PATCH_VERSION=${JPEGXL_PATCH_VERSION} + # Used to determine if we are building the library when defined or just + # including the library when not defined. This is public so libjxl shared + # library gets this define too. + JXL_INTERNAL_LIBRARY_BUILD +) + +# Decoder-only object library +add_library(jxl_dec-obj OBJECT ${JPEGXL_INTERNAL_SOURCES_DEC}) +target_compile_options(jxl_dec-obj PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(jxl_dec-obj PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET jxl_dec-obj PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(jxl_dec-obj PUBLIC + ${PROJECT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + $ + $ +) +target_compile_definitions(jxl_dec-obj PUBLIC + ${OBJ_COMPILE_DEFINITIONS} +) + +# Object library. This is used to hold the set of objects and properties. +add_library(jxl_enc-obj OBJECT ${JPEGXL_INTERNAL_SOURCES_ENC}) +target_compile_options(jxl_enc-obj PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(jxl_enc-obj PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET jxl_enc-obj PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(jxl_enc-obj PUBLIC + ${PROJECT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + $ + $ +) +target_compile_definitions(jxl_enc-obj PUBLIC + ${OBJ_COMPILE_DEFINITIONS} +) + +#TODO(lode): don't depend on CMS for the core library +if (JPEGXL_ENABLE_SKCMS) + target_include_directories(jxl_enc-obj PRIVATE + $ + ) + target_include_directories(jxl_dec-obj PRIVATE + $ + ) +else () + target_include_directories(jxl_enc-obj PRIVATE + $ + ) + target_include_directories(jxl_dec-obj PRIVATE + $ + ) +endif () + +# Headers for exporting/importing public headers +include(GenerateExportHeader) +# TODO(deymo): Add these visibility properties to the static dependencies of +# jxl_{dec,enc}-obj since those are currently compiled with the default +# visibility. +set_target_properties(jxl_dec-obj PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 + DEFINE_SYMBOL JXL_INTERNAL_LIBRARY_BUILD +) +target_include_directories(jxl_dec-obj PUBLIC + ${CMAKE_CURRENT_BINARY_DIR}/include) + +set_target_properties(jxl_enc-obj PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 + DEFINE_SYMBOL JXL_INTERNAL_LIBRARY_BUILD +) +generate_export_header(jxl_enc-obj + BASE_NAME JXL + EXPORT_FILE_NAME include/jxl/jxl_export.h) +target_include_directories(jxl_enc-obj PUBLIC + ${CMAKE_CURRENT_BINARY_DIR}/include) + +# Private static library. This exposes all the internal functions and is used +# for tests. +# TODO(lode): this library is missing symbols, more encoder-only code needs to +# be moved to JPEGXL_INTERNAL_SOURCES_ENC before this works +add_library(jxl_dec-static STATIC + $ +) +target_link_libraries(jxl_dec-static + PUBLIC ${JPEGXL_COVERAGE_FLAGS} ${JPEGXL_DEC_INTERNAL_LIBS} hwy) +target_include_directories(jxl_dec-static PUBLIC + "${PROJECT_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include") + +# Private static library. This exposes all the internal functions and is used +# for tests. 
+# TODO(lode): once the source files are correctly split so that it is possible +# to do, remove $<TARGET_OBJECTS:jxl_dec-obj> here and depend on jxl_dec-static +add_library(jxl-static STATIC + $<TARGET_OBJECTS:jxl_dec-obj> + $<TARGET_OBJECTS:jxl_enc-obj> +) +target_link_libraries(jxl-static + PUBLIC ${JPEGXL_COVERAGE_FLAGS} ${JPEGXL_INTERNAL_LIBS} hwy) +target_include_directories(jxl-static PUBLIC + "${PROJECT_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include") + +# JXL_EXPORT is defined to "__declspec(dllimport)" automatically by CMake +# in Windows builds when including headers from the C API and compiling from +# outside the jxl library. This is required when using the shared library, +# however on Windows it causes the function to not be found when linking +# against the static library. The define JXL_EXPORT= here forces it to not +# use dllimport in tests and other tools that require the static library. +target_compile_definitions(jxl-static INTERFACE -DJXL_EXPORT=) +target_compile_definitions(jxl_dec-static INTERFACE -DJXL_EXPORT=) + +# TODO(deymo): Move TCMalloc linkage to the tools/ directory since the library +# shouldn't do any allocs anyway. +if(${JPEGXL_ENABLE_TCMALLOC}) + pkg_check_modules(TCMallocMinimal REQUIRED IMPORTED_TARGET + libtcmalloc_minimal) + # tcmalloc 2.8 has concurrency issues that make it sometimes return nullptr + # for large allocs. See https://github.com/gperftools/gperftools/issues/1204 + # for details. + if(TCMallocMinimal_VERSION VERSION_EQUAL 2.8) + message(FATAL_ERROR + "tcmalloc version 2.8 has a concurrency bug. You have installed " + "version ${TCMallocMinimal_VERSION}, please either downgrade tcmalloc " + "to version 2.7, upgrade to 2.8.1 or newer or pass " + "-DJPEGXL_ENABLE_TCMALLOC=OFF to the jpeg-xl cmake line. See the " + "following bug for details:\n" + " https://github.com/gperftools/gperftools/issues/1204\n") + endif() + target_link_libraries(jxl-static PUBLIC PkgConfig::TCMallocMinimal) +endif() # JPEGXL_ENABLE_TCMALLOC + +# Install the static library too, but as a jxl.a file without the -static +# suffix, except on Windows. +if (NOT WIN32) + set_target_properties(jxl-static PROPERTIES OUTPUT_NAME "jxl") + set_target_properties(jxl_dec-static PROPERTIES OUTPUT_NAME "jxl_dec") +endif() +install(TARGETS jxl-static DESTINATION ${CMAKE_INSTALL_LIBDIR}) +install(TARGETS jxl_dec-static DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +if (((NOT DEFINED "${TARGET_SUPPORTS_SHARED_LIBS}") OR + TARGET_SUPPORTS_SHARED_LIBS) AND NOT JPEGXL_STATIC) + +# Public shared library. +add_library(jxl SHARED + $<TARGET_OBJECTS:jxl_dec-obj> + $<TARGET_OBJECTS:jxl_enc-obj>) +strip_static(JPEGXL_INTERNAL_SHARED_LIBS JPEGXL_INTERNAL_LIBS) +target_link_libraries(jxl PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +target_link_libraries(jxl PRIVATE ${JPEGXL_INTERNAL_SHARED_LIBS}) +# Shared library include path contains only the "include/" paths. +target_include_directories(jxl PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include") +set_target_properties(jxl PROPERTIES + VERSION ${JPEGXL_LIBRARY_VERSION} + SOVERSION ${JPEGXL_LIBRARY_SOVERSION} + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") + +# Public shared decoder library. +add_library(jxl_dec SHARED $<TARGET_OBJECTS:jxl_dec-obj>) +strip_static(JPEGXL_DEC_INTERNAL_SHARED_LIBS JPEGXL_DEC_INTERNAL_LIBS) +target_link_libraries(jxl_dec PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +target_link_libraries(jxl_dec PRIVATE ${JPEGXL_DEC_INTERNAL_SHARED_LIBS}) +# Shared library include path contains only the "include/" paths.
+target_include_directories(jxl_dec PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include") +set_target_properties(jxl_dec PROPERTIES + VERSION ${JPEGXL_LIBRARY_VERSION} + SOVERSION ${JPEGXL_LIBRARY_SOVERSION} + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") + +# Add a jxl.version file as a version script to tag symbols with the +# appropriate version number. This script is also used to limit what's exposed +# in the shared library from the static dependencies bundled here. +foreach(target IN ITEMS jxl jxl_dec) + set_target_properties(${target} PROPERTIES + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl.version) + if(APPLE) + set_property(TARGET ${target} APPEND_STRING PROPERTY + LINK_FLAGS "-Wl,-exported_symbols_list,${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl_osx.syms") + elseif(WIN32) + # Nothing needed here, we use __declspec(dllexport) (jxl_export.h) + else() + set_property(TARGET ${target} APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/jxl/jxl.version") + endif() # APPLE +endforeach() + +# Only install libjxl shared library. The libjxl_dec is not installed since it +# contains symbols also in libjxl which would conflict if programs try to use +# both. +install(TARGETS jxl + DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +# Add a pkg-config file for libjxl. +set(JPEGXL_LIBRARY_REQUIRES + "libhwy libbrotlicommon libbrotlienc libbrotlidec") +configure_file("${CMAKE_CURRENT_SOURCE_DIR}/jxl/libjxl.pc.in" + "libjxl.pc" @ONLY) +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/libjxl.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + +else() +add_library(jxl ALIAS jxl-static) +add_library(jxl_dec ALIAS jxl_dec-static) +endif() # TARGET_SUPPORTS_SHARED_LIBS AND NOT JPEGXL_STATIC diff --git a/third_party/jpeg-xl/lib/jxl/ac_context.h b/third_party/jpeg-xl/lib/jxl/ac_context.h new file mode 100644 index 000000000000..a5822d2043c6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_context.h @@ -0,0 +1,158 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_AC_CONTEXT_H_ +#define LIB_JXL_AC_CONTEXT_H_ + +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" + +namespace jxl { + +// Block context used for scanning order, number of non-zeros, AC coefficients. +// Equal to the channel. +constexpr uint32_t kDCTOrderContextStart = 0; + +// The number of predicted nonzeros goes from 0 to 1008. We use +// ceil(log2(predicted+1)) as a context for the number of nonzeros, so from 0 to +// 10, inclusive. 
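Editorial check, before the constant declared next: the comment above gives contexts 0 through 10 for up to 1008 predicted non-zeros, while kNonZeroBuckets = 37 appears to come from the NonZeroContext() mapping defined further down in this header, whose largest output for a count clamped to 64 is 4 + 64/2 = 36. A small standalone verification, assuming nothing beyond the formulas quoted from this file:

    #include <cassert>
    #include <cmath>

    int main() {
      // ceil(log2(predicted + 1)) for predicted in [0, 1008] peaks at 10.
      assert(static_cast<int>(std::ceil(std::log2(1008 + 1.0))) == 10);
      // NonZeroContext(): non_zeros < 8 ? non_zeros : 4 + non_zeros / 2,
      // with non_zeros clamped to 64, so the maximum context is 36.
      int max_ctx = 0;
      for (int non_zeros = 0; non_zeros <= 64; ++non_zeros) {
        int ctx = non_zeros < 8 ? non_zeros : 4 + non_zeros / 2;
        if (ctx > max_ctx) max_ctx = ctx;
      }
      assert(max_ctx == 36);  // hence 37 buckets
      return 0;
    }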
+constexpr uint32_t kNonZeroBuckets = 37; + +static const uint16_t kCoeffFreqContext[64] = { + 0xBAD, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, + 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, + 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, +}; + +static const uint16_t kCoeffNumNonzeroContext[64] = { + 0xBAD, 0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, + 152, 152, 152, 152, 152, 152, 152, 152, 180, 180, 180, 180, 180, + 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, +}; + +// Supremum of ZeroDensityContext(x, y) + 1, when x + y < 64. +constexpr int kZeroDensityContextCount = 458; +// Supremum of ZeroDensityContext(x, y) + 1. +constexpr int kZeroDensityContextLimit = 474; + +/* This function is used for entropy-sources pre-clustering. + * + * Ideally, each combination of |nonzeros_left| and |k| should go to its own + * bucket; but it implies (64 * 63 / 2) == 2016 buckets. If there is another + * dimension (e.g. block context), then the number of primary clusters becomes + * too big. + * + * To solve this problem, |nonzeros_left| and |k| values are clustered. It is + * known that their sum is at most 64, consequently, the total number of + * buckets is at most A(64) * B(64). + */ +// TODO(user): investigate why disabling pre-clustering makes the entropy code +// less dense. Perhaps we would need to add an HQ clustering algorithm that +// would be able to squeeze better by spending more CPU cycles. +static JXL_INLINE size_t ZeroDensityContext(size_t nonzeros_left, size_t k, + size_t covered_blocks, + size_t log2_covered_blocks, + size_t prev) { + JXL_DASSERT((1u << log2_covered_blocks) == covered_blocks); + nonzeros_left = (nonzeros_left + covered_blocks - 1) >> log2_covered_blocks; + k >>= log2_covered_blocks; + JXL_DASSERT(k > 0); + JXL_DASSERT(k < 64); + JXL_DASSERT(nonzeros_left > 0); + // Asserting nonzeros_left + k < 65 here causes crashes in debug mode with + // invalid input, since the (hot) decoding loop does not check this condition. + // As no out-of-bound memory reads are issued even if that condition is + // broken, we check this simpler condition which holds anyway. The decoder + // will still mark a file in which that condition happens as not valid at the + // end of the decoding loop, as `nzeros` will not be `0`. + JXL_DASSERT(nonzeros_left < 64); + return (kCoeffNumNonzeroContext[nonzeros_left] + kCoeffFreqContext[k]) * 2 + + prev; +} + +struct BlockCtxMap { + std::vector<int> dc_thresholds[3]; + std::vector<uint32_t> qf_thresholds; + std::vector<uint8_t> ctx_map; + size_t num_ctxs, num_dc_ctxs; + + static constexpr uint8_t kDefaultCtxMap[] = { + // Default ctx map clusters all the large transforms together. + 0, 1, 2, 2, 3, 3, 4, 5, 6, 6, 6, 6, 6, // + 7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, // + 7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, // + }; + static_assert(3 * kNumOrders == + sizeof(kDefaultCtxMap) / sizeof *kDefaultCtxMap, + "Update default context map"); + + size_t Context(int dc_idx, uint32_t qf, size_t ord, size_t c) const { + size_t qf_idx = 0; + for (uint32_t t : qf_thresholds) { + if (qf > t) qf_idx++; + } + size_t idx = c < 2 ?
 c ^ 1 : 2; + idx = idx * kNumOrders + ord; + idx = idx * (qf_thresholds.size() + 1) + qf_idx; + idx = idx * num_dc_ctxs + dc_idx; + return ctx_map[idx]; + } + // Non-zero context is based on number of non-zeros and block context. + // For better clustering, contexts with the same number of non-zeros are + // grouped. + constexpr uint32_t ZeroDensityContextsOffset(uint32_t block_ctx) const { + return num_ctxs * kNonZeroBuckets + kZeroDensityContextCount * block_ctx; + } + + // Context map for AC coefficients consists of 2 blocks: + // |num_ctxs x : context for number of non-zeros in the block + // kNonZeroBuckets| computed from block context and predicted + // value (based on top and left values) + // |num_ctxs x : context for AC coefficient symbols, + // kZeroDensityContextCount| computed from block context, + // number of non-zeros left and + // index in scan order + constexpr uint32_t NumACContexts() const { + return num_ctxs * (kNonZeroBuckets + kZeroDensityContextCount); + } + + // Non-zero context is based on number of non-zeros and block context. + // For better clustering, contexts with the same number of non-zeros are + // grouped. + inline uint32_t NonZeroContext(uint32_t non_zeros, uint32_t block_ctx) const { + uint32_t ctx; + if (non_zeros >= 64) non_zeros = 64; + if (non_zeros < 8) { + ctx = non_zeros; + } else { + ctx = 4 + non_zeros / 2; + } + return ctx * num_ctxs + block_ctx; + } + + BlockCtxMap() { + ctx_map.assign(std::begin(kDefaultCtxMap), std::end(kDefaultCtxMap)); + num_ctxs = *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; + num_dc_ctxs = 1; + } +}; + +} // namespace jxl + +#endif // LIB_JXL_AC_CONTEXT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/ac_strategy.cc b/third_party/jpeg-xl/lib/jxl/ac_strategy.cc new file mode 100644 index 000000000000..7da56e8e5064 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_strategy.cc @@ -0,0 +1,119 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/ac_strategy.h" + +#include <string.h> + +#include <algorithm> +#include <numeric> // iota +#include <type_traits> +#include <utility> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +// Tries to generalize zig-zag order to non-square blocks. Surprisingly, in a +// square block the frequency along the (i + j == const) diagonals is roughly +// the same. For historical reasons, consecutive diagonals are traversed +// in alternating directions - the so-called "zig-zag" (or "snake") order.
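Editorial sketch before the constructor that implements the generalized traversal: for the plain square 8x8 case the scheme reduces to the classic JPEG zig-zag. The snippet below, which is illustrative only and not part of the patch, builds an order/lut pair the same way (order[k] is the raster position of the k-th coefficient in scan order, lut is its inverse permutation) and checks the round-trip property:

    #include <cassert>
    #include <cstddef>
    #include <utility>

    int main() {
      int order[64], lut[64];
      size_t cur = 0;
      for (size_t i = 0; i <= 14; i++) {  // diagonals with x + y == i
        size_t lo = i < 8 ? 0 : i - 7, hi = i < 8 ? i : 7;
        for (size_t j = lo; j <= hi; j++) {
          size_t x = j, y = i - j;
          if (i % 2) std::swap(x, y);  // alternate direction per diagonal
          lut[y * 8 + x] = cur;
          order[cur++] = y * 8 + x;
        }
      }
      assert(cur == 64);
      assert(order[0] == 0 && order[1] == 1 && order[2] == 8);  // JPEG-style start
      for (size_t p = 0; p < 64; p++) assert(lut[order[p]] == static_cast<int>(p));
      return 0;
    }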
+AcStrategy::CoeffOrderAndLut::CoeffOrderAndLut() { + for (size_t s = 0; s < AcStrategy::kNumValidStrategies; s++) { + const AcStrategy acs = AcStrategy::FromRawStrategy(s); + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + JXL_ASSERT((AcStrategy::CoeffOrderAndLut::kOffset[s + 1] - + AcStrategy::CoeffOrderAndLut::kOffset[s]) == cx * cy); + coeff_order_t* JXL_RESTRICT order_start = + order + AcStrategy::CoeffOrderAndLut::kOffset[s] * kDCTBlockSize; + coeff_order_t* JXL_RESTRICT lut_start = + lut + AcStrategy::CoeffOrderAndLut::kOffset[s] * kDCTBlockSize; + + // CoefficientLayout ensures cx >= cy. + // We compute the zigzag order for a cx x cx block, then discard all the + // lines that are not multiple of the ratio between cx and cy. + size_t xs = cx / cy; + size_t xsm = xs - 1; + size_t xss = CeilLog2Nonzero(xs); + // First half of the block + size_t cur = cx * cy; + for (size_t i = 0; i < cx * kBlockDim; i++) { + for (size_t j = 0; j <= i; j++) { + size_t x = j; + size_t y = i - j; + if (i % 2) std::swap(x, y); + if ((y & xsm) != 0) continue; + y >>= xss; + size_t val = 0; + if (x < cx && y < cy) { + val = y * cx + x; + } else { + val = cur++; + } + lut_start[y * cx * kBlockDim + x] = val; + order_start[val] = y * cx * kBlockDim + x; + } + } + // Second half + for (size_t ip = cx * kBlockDim - 1; ip > 0; ip--) { + size_t i = ip - 1; + for (size_t j = 0; j <= i; j++) { + size_t x = cx * kBlockDim - 1 - (i - j); + size_t y = cx * kBlockDim - 1 - j; + if (i % 2) std::swap(x, y); + if ((y & xsm) != 0) continue; + y >>= xss; + size_t val = cur++; + lut_start[y * cx * kBlockDim + x] = val; + order_start[val] = y * cx * kBlockDim + x; + } + } + } +} + +const AcStrategy::CoeffOrderAndLut* AcStrategy::CoeffOrder() { + static AcStrategy::CoeffOrderAndLut* order = + new AcStrategy::CoeffOrderAndLut(); + return order; +} + +// These definitions are needed before C++17. +constexpr size_t AcStrategy::kMaxCoeffBlocks; +constexpr size_t AcStrategy::kMaxBlockDim; +constexpr size_t AcStrategy::kMaxCoeffArea; +constexpr size_t AcStrategy::CoeffOrderAndLut::kOffset[]; + +AcStrategyImage::AcStrategyImage(size_t xsize, size_t ysize) + : layers_(xsize, ysize) { + row_ = layers_.Row(0); + stride_ = layers_.PixelsPerRow(); +} + +size_t AcStrategyImage::CountBlocks(AcStrategy::Type type) const { + size_t ret = 0; + for (size_t y = 0; y < layers_.ysize(); y++) { + const uint8_t* JXL_RESTRICT row = layers_.ConstRow(y); + for (size_t x = 0; x < layers_.xsize(); x++) { + if (row[x] == ((static_cast(type) << 1) | 1)) ret++; + } + } + return ret; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/ac_strategy.h b/third_party/jpeg-xl/lib/jxl/ac_strategy.h new file mode 100644 index 000000000000..eb89d2daf7cb --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_strategy.h @@ -0,0 +1,296 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef LIB_JXL_AC_STRATEGY_H_ +#define LIB_JXL_AC_STRATEGY_H_ + +#include +#include + +#include // kMaxVectorSize + +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" + +// Defines the different kinds of transforms, and heuristics to choose between +// them. +// `AcStrategy` represents what transform should be used, and which sub-block of +// that transform we are currently in. Note that DCT4x4 is applied on all four +// 4x4 sub-blocks of an 8x8 block. +// `AcStrategyImage` defines which strategy should be used for each 8x8 block +// of the image. The highest 4 bits represent the strategy to be used, the +// lowest 4 represent the index of the block inside that strategy. + +namespace jxl { + +class AcStrategy { + public: + // Extremal values for the number of blocks/coefficients of a single strategy. + static constexpr size_t kMaxCoeffBlocks = 32; + static constexpr size_t kMaxBlockDim = kBlockDim * kMaxCoeffBlocks; + // Maximum number of coefficients in a block. Guaranteed to be a multiple of + // the vector size. + static constexpr size_t kMaxCoeffArea = kMaxBlockDim * kMaxBlockDim; + static_assert((kMaxCoeffArea * sizeof(float)) % hwy::kMaxVectorSize == 0, + "Coefficient area is not a multiple of vector size"); + + // Raw strategy types. + enum Type : uint32_t { + // Regular block size DCT + DCT = 0, + // Encode pixels without transforming + IDENTITY = 1, + // Use 2-by-2 DCT + DCT2X2 = 2, + // Use 4-by-4 DCT + DCT4X4 = 3, + // Use 16-by-16 DCT + DCT16X16 = 4, + // Use 32-by-32 DCT + DCT32X32 = 5, + // Use 16-by-8 DCT + DCT16X8 = 6, + // Use 8-by-16 DCT + DCT8X16 = 7, + // Use 32-by-8 DCT + DCT32X8 = 8, + // Use 8-by-32 DCT + DCT8X32 = 9, + // Use 32-by-16 DCT + DCT32X16 = 10, + // Use 16-by-32 DCT + DCT16X32 = 11, + // 4x8 and 8x4 DCT + DCT4X8 = 12, + DCT8X4 = 13, + // Corner-DCT. + AFV0 = 14, + AFV1 = 15, + AFV2 = 16, + AFV3 = 17, + // Larger DCTs + DCT64X64 = 18, + DCT64X32 = 19, + DCT32X64 = 20, + DCT128X128 = 21, + DCT128X64 = 22, + DCT64X128 = 23, + DCT256X256 = 24, + DCT256X128 = 25, + DCT128X256 = 26, + // Marker for num of valid strategies. + kNumValidStrategies + }; + + static constexpr uint32_t TypeBit(const Type type) { + return 1u << static_cast(type); + } + + // Returns true if this block is the first 8x8 block (i.e. top-left) of a + // possibly multi-block strategy. + JXL_INLINE bool IsFirstBlock() const { return is_first_; } + + JXL_INLINE bool IsMultiblock() const { + constexpr uint32_t bits = + TypeBit(Type::DCT16X16) | TypeBit(Type::DCT32X32) | + TypeBit(Type::DCT16X8) | TypeBit(Type::DCT8X16) | + TypeBit(Type::DCT32X8) | TypeBit(Type::DCT8X32) | + TypeBit(Type::DCT16X32) | TypeBit(Type::DCT32X16) | + TypeBit(Type::DCT32X64) | TypeBit(Type::DCT64X32) | + TypeBit(Type::DCT64X64) | TypeBit(DCT64X128) | TypeBit(DCT128X64) | + TypeBit(DCT128X128) | TypeBit(DCT128X256) | TypeBit(DCT256X128) | + TypeBit(DCT256X256); + JXL_DASSERT(Strategy() < kNumValidStrategies); + return ((1u << static_cast(Strategy())) & bits) != 0; + } + + // Returns the raw strategy value. Should only be used for tokenization. 
+ JXL_INLINE uint8_t RawStrategy() const { + return static_cast<uint8_t>(strategy_); + } + + JXL_INLINE Type Strategy() const { return strategy_; } + + // Inverse check + static JXL_INLINE constexpr bool IsRawStrategyValid(int raw_strategy) { + return raw_strategy < static_cast<int>(kNumValidStrategies) && + raw_strategy >= 0; + } + static JXL_INLINE AcStrategy FromRawStrategy(uint8_t raw_strategy) { + return FromRawStrategy(static_cast<Type>(raw_strategy)); + } + static JXL_INLINE AcStrategy FromRawStrategy(Type raw_strategy) { + JXL_DASSERT(IsRawStrategyValid(static_cast<int>(raw_strategy))); + return AcStrategy(raw_strategy, /*is_first=*/true); + } + + // "Natural order" means the order of increasing "anisotropic" frequency of + // the continuous version of the DCT basis. + // Round-trip, for any given strategy s: + // X = NaturalCoeffOrder(s)[NaturalCoeffOrderLut(s)[X]] + // X = NaturalCoeffOrderLut(s)[NaturalCoeffOrder(s)[X]] + JXL_INLINE const coeff_order_t* NaturalCoeffOrder() const { + return CoeffOrder()->order + + CoeffOrderAndLut::kOffset[RawStrategy()] * kDCTBlockSize; + } + + JXL_INLINE const coeff_order_t* NaturalCoeffOrderLut() const { + return CoeffOrder()->lut + + CoeffOrderAndLut::kOffset[RawStrategy()] * kDCTBlockSize; + } + + // Number of 8x8 blocks that this strategy will cover. 0 for non-top-left + // blocks inside a multi-block transform. + JXL_INLINE size_t covered_blocks_x() const { + static constexpr uint8_t kLut[] = {1, 1, 1, 1, 2, 4, 1, 2, 1, + 4, 2, 4, 1, 1, 1, 1, 1, 1, + 8, 4, 8, 16, 8, 16, 32, 16, 32}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + JXL_INLINE size_t covered_blocks_y() const { + static constexpr uint8_t kLut[] = {1, 1, 1, 1, 2, 4, 2, 1, 4, + 1, 4, 2, 1, 1, 1, 1, 1, 1, + 8, 8, 4, 16, 16, 8, 32, 32, 16}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + JXL_INLINE size_t log2_covered_blocks() const { + static constexpr uint8_t kLut[] = {0, 0, 0, 0, 2, 4, 1, 1, 2, + 2, 3, 3, 0, 0, 0, 0, 0, 0, + 6, 5, 5, 8, 7, 7, 10, 9, 9}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + struct CoeffOrderAndLut { + // Those offsets get multiplied by kDCTBlockSize. + // TODO(veluca): reduce this array by merging together the same order type. + static constexpr size_t kOffset[kNumValidStrategies + 1] = { + 0, 1, 2, 3, 4, 8, 24, 26, 28, 32, 36, 44, 52, 53, + 54, 55, 56, 57, 58, 122, 154, 186, 442, 570, 698, 1722, 2234, 2746, + }; + static constexpr size_t kTotalTableSize = + kOffset[kNumValidStrategies] * kDCTBlockSize; + coeff_order_t order[kTotalTableSize]; + coeff_order_t lut[kTotalTableSize]; + + private: + CoeffOrderAndLut(); + friend class AcStrategy; + }; + + private: + friend class AcStrategyRow; + JXL_INLINE AcStrategy(Type strategy, bool is_first) + : strategy_(strategy), is_first_(is_first) { + JXL_DASSERT(IsMultiblock() || is_first == true); + } + + Type strategy_; + bool is_first_; + + static const CoeffOrderAndLut* CoeffOrder(); +}; + +// Class to use a certain row of the AC strategy.
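Editorial note before AcStrategyRow, which the comment above introduces: throughout this header and ac_strategy.cc, the per-block byte is packed as (raw strategy << 1) | is_first, matching the ">> 1" and "& 1" unpacking in AcStrategyRow::operator[] below and the "(type << 1) | 1" writes in AcStrategyImage; the "highest 4 bits / lowest 4 bits" wording in the file-level comment appears to describe an older layout. A tiny sketch of the convention, with DCT16X8 (raw value 6) as an arbitrary example:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t raw = 6;                 // AcStrategy::Type::DCT16X8
      const uint8_t packed = (raw << 1) | 1;  // top-left block: bit 0 set
      assert((packed >> 1) == raw);           // strategy id recovered
      assert((packed & 1) == 1);              // first-block flag recovered
      return 0;
    }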
+class AcStrategyRow { + public: + explicit AcStrategyRow(const uint8_t* row) : row_(row) {} + AcStrategy operator[](size_t x) const { + return AcStrategy(static_cast(row_[x] >> 1), row_[x] & 1); + } + + private: + const uint8_t* JXL_RESTRICT row_; +}; + +class AcStrategyImage { + public: + AcStrategyImage() = default; + AcStrategyImage(size_t xsize, size_t ysize); + AcStrategyImage(AcStrategyImage&&) = default; + AcStrategyImage& operator=(AcStrategyImage&&) = default; + + void FillDCT8(const Rect& rect) { + FillPlane((static_cast(AcStrategy::Type::DCT) << 1) | 1, + &layers_, rect); + } + void FillDCT8() { FillDCT8(Rect(layers_)); } + + void FillInvalid() { FillImage(INVALID, &layers_); } + + void Set(size_t x, size_t y, AcStrategy::Type type) { +#if JXL_ENABLE_ASSERT + AcStrategy acs = AcStrategy::FromRawStrategy(type); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(y + acs.covered_blocks_y() <= layers_.ysize()); + JXL_ASSERT(x + acs.covered_blocks_x() <= layers_.xsize()); + JXL_CHECK(SetNoBoundsCheck(x, y, type, /*check=*/false)); + } + + Status SetNoBoundsCheck(size_t x, size_t y, AcStrategy::Type type, + bool check = true) { + AcStrategy acs = AcStrategy::FromRawStrategy(type); + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + size_t pos = (y + iy) * stride_ + x + ix; + if (check && row_[pos] != INVALID) { + return JXL_FAILURE("Invalid AC strategy: block overlap"); + } + row_[pos] = + (static_cast(type) << 1) | ((iy | ix) == 0 ? 1 : 0); + } + } + return true; + } + + bool IsValid(size_t x, size_t y) { return row_[y * stride_ + x] != INVALID; } + + AcStrategyRow ConstRow(size_t y, size_t x_prefix = 0) const { + return AcStrategyRow(layers_.ConstRow(y) + x_prefix); + } + + AcStrategyRow ConstRow(const Rect& rect, size_t y) const { + return ConstRow(rect.y0() + y, rect.x0()); + } + + size_t PixelsPerRow() const { return layers_.PixelsPerRow(); } + + size_t xsize() const { return layers_.xsize(); } + size_t ysize() const { return layers_.ysize(); } + + // Count the number of blocks of a given type. + size_t CountBlocks(AcStrategy::Type type) const; + + private: + ImageB layers_; + uint8_t* JXL_RESTRICT row_; + size_t stride_; + + // A value that does not represent a valid combined AC strategy + // value. Used as a sentinel. + static constexpr uint8_t INVALID = 0xFF; +}; + +} // namespace jxl + +#endif // LIB_JXL_AC_STRATEGY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/ac_strategy_test.cc b/third_party/jpeg-xl/lib/jxl/ac_strategy_test.cc new file mode 100644 index 000000000000..5990eebee7b0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_strategy_test.cc @@ -0,0 +1,234 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lib/jxl/ac_strategy.h" + +#include + +#include +#include +#include // HWY_ALIGN_MAX +#include +#include + +#include "lib/jxl/common.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_transforms_testonly.h" +#include "lib/jxl/enc_transforms.h" + +namespace jxl { +namespace { + +// Test that DCT -> IDCT is a noop. +class AcStrategyRoundtrip : public ::hwy::TestWithParamTargetAndT { + protected: + void Run() { + const AcStrategy::Type type = static_cast(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + + auto mem = hwy::AllocateAligned(4 * AcStrategy::kMaxCoeffArea); + float* scratch_space = mem.get(); + float* coeffs = scratch_space + AcStrategy::kMaxCoeffArea; + float* idct = coeffs + AcStrategy::kMaxCoeffArea; + + for (size_t i = 0; i < std::min(1024u, 64u << acs.log2_covered_blocks()); + i++) { + float* input = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(input, AcStrategy::kMaxCoeffArea, 0); + input[i] = 0.2f; + TransformFromPixels(type, input, acs.covered_blocks_x() * 8, coeffs, + scratch_space); + ASSERT_NEAR(coeffs[0], 0.2 / (64 << acs.log2_covered_blocks()), 1e-6) + << " i = " << i; + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + for (size_t j = 0; j < 64u << acs.log2_covered_blocks(); j++) { + ASSERT_NEAR(idct[j], j == i ? 0.2f : 0, 2e-6) + << "j = " << j << " i = " << i << " acs " << type; + } + } + // Test DC. + std::fill_n(idct, AcStrategy::kMaxCoeffArea, 0); + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + float* dc = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2; + LowestFrequenciesFromDC(type, dc, acs.covered_blocks_x() * 8, coeffs); + DCFromLowestFrequencies(type, coeffs, idct, acs.covered_blocks_x() * 8); + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2; + for (size_t j = 0; j < 64u << acs.log2_covered_blocks(); j++) { + ASSERT_NEAR(idct[j], dc[j], 1e-6) + << "j = " << j << " x = " << x << " y = " << y << " acs " << type; + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyRoundtrip, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyRoundtrip, Test) { Run(); } + +// Test that DC(2x2) -> DCT coefficients -> IDCT -> downsampled IDCT is a noop. 
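Editorial note on the DC expectation in AcStrategyRoundtrip above, before the downsampling variant that the preceding comment introduces: the test asserts coeffs[0] == 0.2 / (64 << log2_covered_blocks) for an input holding a single 0.2 sample, i.e. the DC term of the forward transform behaves as the mean over all N = 64 << log2_covered_blocks covered pixels. A short check of that arithmetic, with log2_covered_blocks = 1 as an arbitrary example:

    #include <cassert>
    #include <cmath>

    int main() {
      const int log2_covered_blocks = 1;        // e.g. a 16x8 transform (2 blocks)
      const int N = 64 << log2_covered_blocks;  // 128 covered pixels
      float sum = 0.2f;                         // one sample at 0.2, rest zero
      const float mean = sum / N;               // what the test expects in coeffs[0]
      assert(std::fabs(mean - 0.2f / 128.0f) < 1e-9f);
      return 0;
    }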
+class AcStrategyRoundtripDownsample + : public ::hwy::TestWithParamTargetAndT { + protected: + void Run() { + const AcStrategy::Type type = static_cast(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + + auto mem = hwy::AllocateAligned(4 * AcStrategy::kMaxCoeffArea); + float* scratch_space = mem.get(); + float* coeffs = scratch_space + AcStrategy::kMaxCoeffArea; + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0.0f); + float* idct = coeffs + AcStrategy::kMaxCoeffArea; + + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + float* dc = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2f; + LowestFrequenciesFromDC(type, dc, acs.covered_blocks_x() * 8, coeffs); + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0.0f); + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2f; + // Downsample + for (size_t dy = 0; dy < acs.covered_blocks_y(); dy++) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); dx++) { + float sum = 0; + for (size_t iy = 0; iy < 8; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + sum += idct[(dy * 8 + iy) * 8 * acs.covered_blocks_x() + + dx * 8 + ix]; + } + } + sum /= 64.0f; + ASSERT_NEAR(sum, dc[dy * 8 * acs.covered_blocks_x() + dx], 1e-6) + << "acs " << type; + } + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyRoundtripDownsample, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyRoundtripDownsample, Test) { Run(); } + +// Test that IDCT(block with zeros in the non-topleft corner) -> downsampled +// IDCT is the same as IDCT -> DC(2x2) of the same block. 
+class AcStrategyDownsample : public ::hwy::TestWithParamTargetAndT { + protected: + void Run() { + const AcStrategy::Type type = static_cast(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + size_t cx = acs.covered_blocks_y(); + size_t cy = acs.covered_blocks_x(); + CoefficientLayout(&cy, &cx); + + auto mem = hwy::AllocateAligned(4 * AcStrategy::kMaxCoeffArea); + float* scratch_space = mem.get(); + float* idct = scratch_space + AcStrategy::kMaxCoeffArea; + float* idct_acs_downsampled = idct + AcStrategy::kMaxCoeffArea; + + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx; x++) { + float* coeffs = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0); + coeffs[y * cx * 8 + x] = 0.2f; + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0); + coeffs[y * cx * 8 + x] = 0.2f; + DCFromLowestFrequencies(type, coeffs, idct_acs_downsampled, + acs.covered_blocks_x() * 8); + // Downsample + for (size_t dy = 0; dy < acs.covered_blocks_y(); dy++) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); dx++) { + float sum = 0; + for (size_t iy = 0; iy < 8; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + sum += idct[(dy * 8 + iy) * 8 * acs.covered_blocks_x() + + dx * 8 + ix]; + } + } + sum /= 64; + ASSERT_NEAR( + sum, idct_acs_downsampled[dy * 8 * acs.covered_blocks_x() + dx], + 1e-6) + << " acs " << type; + } + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyDownsample, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyDownsample, Test) { Run(); } + +class AcStrategyTargetTest : public ::hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(AcStrategyTargetTest); + +TEST_P(AcStrategyTargetTest, RoundtripAFVDCT) { + HWY_ALIGN_MAX float idct[16]; + for (size_t i = 0; i < 16; i++) { + HWY_ALIGN_MAX float pixels[16] = {}; + pixels[i] = 1; + HWY_ALIGN_MAX float coeffs[16] = {}; + + AFVDCT4x4(pixels, coeffs); + AFVIDCT4x4(coeffs, idct); + for (size_t j = 0; j < 16; j++) { + EXPECT_NEAR(idct[j], pixels[j], 1e-6); + } + } +} + +TEST_P(AcStrategyTargetTest, BenchmarkAFV) { + const AcStrategy::Type type = AcStrategy::Type::AFV0; + HWY_ALIGN_MAX float pixels[64] = {1}; + HWY_ALIGN_MAX float coeffs[64] = {}; + HWY_ALIGN_MAX float scratch_space[64] = {}; + for (size_t i = 0; i < 1 << 14; i++) { + TransformToPixels(type, coeffs, pixels, 8, scratch_space); + TransformFromPixels(type, pixels, 8, coeffs, scratch_space); + } + EXPECT_NEAR(pixels[0], 0.0, 1E-6); +} + +TEST_P(AcStrategyTargetTest, BenchmarkAFVDCT) { + HWY_ALIGN_MAX float pixels[64] = {1}; + HWY_ALIGN_MAX float coeffs[64] = {}; + for (size_t i = 0; i < 1 << 14; i++) { + AFVDCT4x4(pixels, coeffs); + AFVIDCT4x4(coeffs, pixels); + } + EXPECT_NEAR(pixels[0], 1.0, 1E-6); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/adaptive_reconstruction_test.cc b/third_party/jpeg-xl/lib/jxl/adaptive_reconstruction_test.cc new file mode 100644 index 000000000000..d157a572b117 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/adaptive_reconstruction_test.cc @@ -0,0 +1,193 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_reconstruct.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/test_utils.h" + +namespace jxl { +namespace { + +const size_t xsize = 16; +const size_t ysize = 8; + +void GenerateFlat(const float background, const float foreground, + std::vector* images) { + for (size_t c = 0; c < Image3F::kNumPlanes; ++c) { + Image3F in(xsize, ysize); + // Plane c = foreground, all others = background. + for (size_t y = 0; y < ysize; ++y) { + float* rows[3] = {in.PlaneRow(0, y), in.PlaneRow(1, y), + in.PlaneRow(2, y)}; + for (size_t x = 0; x < xsize; ++x) { + rows[0][x] = rows[1][x] = rows[2][x] = background; + rows[c][x] = foreground; + } + } + images->push_back(std::move(in)); + } +} + +// Single foreground point at any position in any channel +void GeneratePoints(const float background, const float foreground, + std::vector* images) { + for (size_t c = 0; c < Image3F::kNumPlanes; ++c) { + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + Image3F in(xsize, ysize); + FillImage(background, &in); + in.PlaneRow(c, y)[x] = foreground; + images->push_back(std::move(in)); + } + } + } +} + +void GenerateHorzEdges(const float background, const float foreground, + std::vector* images) { + for (size_t c = 0; c < Image3F::kNumPlanes; ++c) { + // Begin of foreground rows + for (size_t y = 1; y < ysize; ++y) { + Image3F in(xsize, ysize); + FillImage(background, &in); + for (size_t iy = y; iy < ysize; ++iy) { + std::fill(in.PlaneRow(c, iy), in.PlaneRow(c, iy) + xsize, foreground); + } + images->push_back(std::move(in)); + } + } +} + +void GenerateVertEdges(const float background, const float foreground, + std::vector* images) { + for (size_t c = 0; c < Image3F::kNumPlanes; ++c) { + // Begin of foreground columns + for (size_t x = 1; x < xsize; ++x) { + Image3F in(xsize, ysize); + FillImage(background, &in); + for (size_t iy = 0; iy < ysize; ++iy) { + float* JXL_RESTRICT row = in.PlaneRow(c, iy); + for (size_t ix = x; ix < xsize; ++ix) { + row[ix] = foreground; + } + } + images->push_back(std::move(in)); + } + } +} + +void DumpTestImage(const char* name, const Image3F& img) { + fprintf(stderr, "Image %s:\n", name); + for (size_t y = 0; y < img.ysize(); ++y) { + const float* row_x = img.ConstPlaneRow(0, y); + const float* row_y = img.ConstPlaneRow(1, y); + const float* row_b = img.ConstPlaneRow(2, y); + for (size_t x = 0; x < img.xsize(); ++x) { + fprintf(stderr, "%5.1f|%5.1f|%5.1f ", row_x[x], row_y[x], row_b[x]); + } + fprintf(stderr, "\n"); + } + fprintf(stderr, "\n"); +} + +// Ensures input remains unchanged by filter - verifies the edge-preserving +// nature of the filter because inputs are piecewise constant. 
+void EnsureUnchanged(const float background, const float foreground, + uint32_t epf_iters) { + std::vector images; + GenerateFlat(background, foreground, &images); + GeneratePoints(background, foreground, &images); + GenerateHorzEdges(background, foreground, &images); + GenerateVertEdges(background, foreground, &images); + + CodecMetadata metadata; + JXL_CHECK(metadata.size.Set(xsize, ysize)); + metadata.m.xyb_encoded = false; + FrameHeader frame_header(&metadata); + // Ensure no CT is applied + frame_header.color_transform = ColorTransform::kNone; + LoopFilter& lf = frame_header.loop_filter; + lf.gab = false; + lf.epf_iters = epf_iters; + FrameDimensions frame_dim = frame_header.ToFrameDimensions(); + + jxl::PassesDecoderState state; + JXL_CHECK( + jxl::InitializePassesSharedState(frame_header, &state.shared_storage)); + state.Init(); + state.InitForAC(/*pool=*/nullptr); + + state.filter_weights.Init(lf, frame_dim); + FillImage(-0.5f, &state.filter_weights.sigma); + + for (size_t idx_image = 0; idx_image < images.size(); ++idx_image) { + const Image3F& in = images[idx_image]; + state.decoded = CopyImage(in); + + ImageBundle out(&metadata.m); + out.SetFromImage(CopyImage(in), ColorEncoding::LinearSRGB()); + FillImage(-99.f, out.color()); // Initialized with garbage. + Image3F padded = PadImageMirror(in, 2 * kBlockDim, 0); + // Call with `force_fir` set to true to force to apply filters to all of the + // input image. + JXL_CHECK(FinalizeFrameDecoding(&out, &state, /*pool=*/nullptr, + /*force_fir=*/true, + /*skip_blending=*/true)); + +#if JXL_HIGH_PRECISION + VerifyRelativeError(in, *out.color(), 1E-3, 1E-4); +#else + VerifyRelativeError(in, *out.color(), 1E-2, 1E-2); +#endif + if (testing::Test::HasFatalFailure()) { + DumpTestImage("in", in); + DumpTestImage("out", *out.color()); + } + } +} + +} // namespace + +class AdaptiveReconstructionTest : public testing::TestWithParam {}; + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(EPFItersGroup, AdaptiveReconstructionTest, + testing::Values(1, 2, 3), + testing::PrintToStringParamName()); + +TEST_P(AdaptiveReconstructionTest, TestBright) { + EnsureUnchanged(1.0f, 128.0f, GetParam()); +} +TEST_P(AdaptiveReconstructionTest, TestDark) { + EnsureUnchanged(128.0f, 1.0f, GetParam()); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/alpha.cc b/third_party/jpeg-xl/lib/jxl/alpha.cc new file mode 100644 index 000000000000..0b3eece16b61 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/alpha.cc @@ -0,0 +1,119 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lib/jxl/alpha.h" + +#include + +namespace jxl { + +void PerformAlphaBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels, + bool alpha_is_premultiplied) { + if (alpha_is_premultiplied) { + for (size_t x = 0; x < num_pixels; ++x) { + out.r[x] = (fg.r[x] + bg.r[x] * (1.f - fg.a[x])); + out.g[x] = (fg.g[x] + bg.g[x] * (1.f - fg.a[x])); + out.b[x] = (fg.b[x] + bg.b[x] * (1.f - fg.a[x])); + out.a[x] = (1.f - (1.f - fg.a[x]) * (1.f - bg.a[x])); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + const float new_a = 1.f - (1.f - fg.a[x]) * (1.f - bg.a[x]); + const float rnew_a = (new_a > 0 ? 1.f / new_a : 0.f); + out.r[x] = + (fg.r[x] * fg.a[x] + bg.r[x] * bg.a[x] * (1.f - fg.a[x])) * rnew_a; + out.g[x] = + (fg.g[x] * fg.a[x] + bg.g[x] * bg.a[x] * (1.f - fg.a[x])) * rnew_a; + out.b[x] = + (fg.b[x] * fg.a[x] + bg.b[x] * bg.a[x] * (1.f - fg.a[x])) * rnew_a; + out.a[x] = new_a; + } + } +} +void PerformAlphaBlending(const float* bg, const float* bga, const float* fg, + const float* fga, float* out, size_t num_pixels, + bool alpha_is_premultiplied) { + if (alpha_is_premultiplied) { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = (fg[x] + bg[x] * (1.f - fga[x])); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + const float new_a = 1.f - (1.f - fga[x]) * (1.f - bga[x]); + const float rnew_a = (new_a > 0 ? 1.f / new_a : 0.f); + out[x] = (fg[x] * fga[x] + bg[x] * bga[x] * (1.f - fga[x])) * rnew_a; + } + } +} + +void PerformAlphaWeightedAdd(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + out.r[x] = bg.r[x] + fg.r[x] * fg.a[x]; + out.g[x] = bg.g[x] + fg.g[x] * fg.a[x]; + out.b[x] = bg.b[x] + fg.b[x] * fg.a[x]; + out.a[x] = bg.a[x]; + } +} +void PerformAlphaWeightedAdd(const float* bg, const float* fg, const float* fga, + float* out, size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] + fg[x] * fga[x]; + } +} + +void PerformMulBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + out.r[x] = bg.r[x] * fg.r[x]; + out.g[x] = bg.g[x] * fg.g[x]; + out.b[x] = bg.b[x] * fg.b[x]; + out.a[x] = bg.a[x] * fg.a[x]; + } +} +void PerformMulBlending(const float* bg, const float* fg, float* out, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] * fg[x]; + } +} + +void PremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + const float multiplier = std::max(kSmallAlpha, a[x]); + r[x] *= multiplier; + g[x] *= multiplier; + b[x] *= multiplier; + } +} + +void UnpremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + const float multiplier = 1.f / std::max(kSmallAlpha, a[x]); + r[x] *= multiplier; + g[x] *= multiplier; + b[x] *= multiplier; + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/alpha.h b/third_party/jpeg-xl/lib/jxl/alpha.h new file mode 100644 index 000000000000..71f531176eac --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/alpha.h @@ -0,0 +1,81 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the 
Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ALPHA_H_ +#define LIB_JXL_ALPHA_H_ + +#include +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// A very small value to avoid divisions by zero when converting to +// unpremultiplied alpha. Page 21 of the technical introduction to OpenEXR +// (https://www.openexr.com/documentation/TechnicalIntroduction.pdf) recommends +// "a power of two" that is "less than half of the smallest positive 16-bit +// floating-point value". That smallest value happens to be the denormal number +// 2^-24, so 2^-26 should be a good choice. +static constexpr float kSmallAlpha = 1.f / (1u << 26u); + +struct AlphaBlendingInputLayer { + const float* r; + const float* g; + const float* b; + const float* a; +}; + +struct AlphaBlendingOutput { + float* r; + float* g; + float* b; + float* a; +}; + +// Note: The pointers in `out` are allowed to alias those in `bg` or `fg`. +// No pointer shall be null. +void PerformAlphaBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels, + bool alpha_is_premultiplied); +// Single plane alpha blending +void PerformAlphaBlending(const float* bg, const float* bga, const float* fg, + const float* fga, float* out, size_t num_pixels, + bool alpha_is_premultiplied); + +void PerformAlphaWeightedAdd(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels); +void PerformAlphaWeightedAdd(const float* bg, const float* fg, const float* fga, + float* out, size_t num_pixels); + +void PerformMulBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels); +void PerformMulBlending(const float* bg, const float* fg, float* out, + size_t num_pixels); + +void PremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels); +void UnpremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels); + +} // namespace jxl + +#endif // LIB_JXL_ALPHA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/alpha_test.cc b/third_party/jpeg-xl/lib/jxl/alpha_test.cc new file mode 100644 index 000000000000..6f2f9aa78cf7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/alpha_test.cc @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/alpha.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace jxl {
+namespace {
+
+using ::testing::_;
+using ::testing::ElementsAre;
+using ::testing::FloatNear;
+
+TEST(AlphaTest, BlendingWithNonPremultiplied) {
+  const float bg_rgb[3] = {100, 110, 120};
+  const float bg_a = 180.f / 255;
+  const float fg_rgb[3] = {25, 21, 23};
+  const float fg_a = 15420.f / 65535;
+  float out_rgb[3];
+  float out_a;
+  PerformAlphaBlending(
+      /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a},
+      /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a},
+      /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1,
+      /*alpha_is_premultiplied=*/false);
+  EXPECT_THAT(out_rgb,
+              ElementsAre(FloatNear(77.2f, .05f), FloatNear(83.0f, .05f),
+                          FloatNear(90.6f, .05f)));
+  EXPECT_NEAR(out_a, 3174.f / 4095, 1e-5);
+}
+
+TEST(AlphaTest, BlendingWithPremultiplied) {
+  const float bg_rgb[3] = {100, 110, 120};
+  const float bg_a = 180.f / 255;
+  const float fg_rgb[3] = {25, 21, 23};
+  const float fg_a = 15420.f / 65535;
+  float out_rgb[3];
+  float out_a;
+  PerformAlphaBlending(
+      /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a},
+      /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a},
+      /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1,
+      /*alpha_is_premultiplied=*/true);
+  EXPECT_THAT(out_rgb,
+              ElementsAre(FloatNear(101.5f, .05f), FloatNear(105.1f, .05f),
+                          FloatNear(114.8f, .05f)));
+  EXPECT_NEAR(out_a, 3174.f / 4095, 1e-5);
+}
+
+TEST(AlphaTest, AlphaWeightedAdd) {
+  const float bg_rgb[3] = {100, 110, 120};
+  const float bg_a = 180.f / 255;
+  const float fg_rgb[3] = {25, 21, 23};
+  const float fg_a = 1.f / 4;
+  float out_rgb[3];
+  float out_a;
+  PerformAlphaWeightedAdd(
+      /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a},
+      /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a},
+      /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1);
+  EXPECT_THAT(out_rgb, ElementsAre(FloatNear(100.f + 25.f / 4, .05f),
+                                   FloatNear(110.f + 21.f / 4, .05f),
+                                   FloatNear(120.f + 23.f / 4, .05f)));
+  EXPECT_EQ(out_a, bg_a);
+}
+
+TEST(AlphaTest, Mul) {
+  const float bg_rgb[3] = {100, 110, 120};
+  const float bg_a = 180.f / 255;
+  const float fg_rgb[3] = {25, 21, 23};
+  const float fg_a = 1.f / 4;
+  float out_rgb[3];
+  float out_a;
+  PerformMulBlending(
+      /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a},
+      /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a},
+      /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1);
+  EXPECT_THAT(out_rgb, ElementsAre(FloatNear(100.f * 25.f, .05f),
+                                   FloatNear(110.f * 21.f, .05f),
+                                   FloatNear(120.f * 23.f, .05f)));
+  EXPECT_NEAR(out_a, bg_a * fg_a, 1e-5);
+}
+
+TEST(AlphaTest, PremultiplyAndUnpremultiply) {
+  const float alpha[] = {0.f, 63.f / 255, 127.f / 255, 1.f};
+  float r[] = {120, 130, 140, 150};
+  float g[] = {124, 134, 144, 154};
+  float b[] = {127, 137, 147, 157};
+
+  PremultiplyAlpha(r, g, b, alpha, 4);
+  EXPECT_THAT(
+      r, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(130 * 63.f / 255, 1e-5f),
+                     FloatNear(140 * 127.f / 255, 1e-5f), 150));
+  EXPECT_THAT(
+      g, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(134 * 63.f / 255, 1e-5f),
+                     FloatNear(144 * 127.f / 255, 1e-5f), 154));
+  EXPECT_THAT(
+      b, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(137 * 63.f / 255, 1e-5f),
+                     FloatNear(147 * 127.f / 255, 1e-5f), 157));
+
+  UnpremultiplyAlpha(r, g, b, alpha, 4);
+  EXPECT_THAT(r, ElementsAre(FloatNear(120, 1e-4f), FloatNear(130, 1e-4f),
+                             FloatNear(140, 1e-4f), FloatNear(150, 1e-4f)));
+  EXPECT_THAT(g, ElementsAre(FloatNear(124, 1e-4f), FloatNear(134, 1e-4f),
+                             FloatNear(144, 1e-4f), FloatNear(154, 1e-4f)));
+  EXPECT_THAT(b, ElementsAre(FloatNear(127, 1e-4f), FloatNear(137, 1e-4f),
+                             FloatNear(147, 1e-4f), FloatNear(157, 1e-4f)));
+}
+
+TEST(AlphaTest, UnpremultiplyAndPremultiply) {
+  const float alpha[] = {0.f, 63.f / 255, 127.f / 255, 1.f};
+  float r[] = {50, 60, 70, 80};
+  float g[] = {54, 64, 74, 84};
+  float b[] = {57, 67, 77, 87};
+
+  UnpremultiplyAlpha(r, g, b, alpha, 4);
+  EXPECT_THAT(r, ElementsAre(_, FloatNear(60 * 255.f / 63, 1e-4f),
+                             FloatNear(70 * 255.f / 127, 1e-4f), 80));
+  EXPECT_THAT(g, ElementsAre(_, FloatNear(64 * 255.f / 63, 1e-4f),
+                             FloatNear(74 * 255.f / 127, 1e-4f), 84));
+  EXPECT_THAT(b, ElementsAre(_, FloatNear(67 * 255.f / 63, 1e-4f),
+                             FloatNear(77 * 255.f / 127, 1e-4f), 87));
+
+  PremultiplyAlpha(r, g, b, alpha, 4);
+  EXPECT_THAT(r, ElementsAre(FloatNear(50, 1e-4f), FloatNear(60, 1e-4f),
+                             FloatNear(70, 1e-4f), FloatNear(80, 1e-4f)));
+  EXPECT_THAT(g, ElementsAre(FloatNear(54, 1e-4f), FloatNear(64, 1e-4f),
+                             FloatNear(74, 1e-4f), FloatNear(84, 1e-4f)));
+  EXPECT_THAT(b, ElementsAre(FloatNear(57, 1e-4f), FloatNear(67, 1e-4f),
+                             FloatNear(77, 1e-4f), FloatNear(87, 1e-4f)));
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/ans_common.cc b/third_party/jpeg-xl/lib/jxl/ans_common.cc
new file mode 100644
index 000000000000..12e2b84344d6
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/ans_common.cc
@@ -0,0 +1,157 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/ans_common.h"
+
+#include <numeric>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+std::vector<int32_t> CreateFlatHistogram(int length, int total_count) {
+  JXL_ASSERT(length > 0);
+  JXL_ASSERT(length <= total_count);
+  const int count = total_count / length;
+  std::vector<int32_t> result(length, count);
+  const int rem_counts = total_count % length;
+  for (int i = 0; i < rem_counts; ++i) {
+    ++result[i];
+  }
+  return result;
+}
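CreateFlatHistogram spreads total_count as evenly as possible, giving the first total_count % length bins one extra count. A quick illustrative check (not part of the patch):

    // CreateFlatHistogram(5, 12): 12 / 5 = 2 per bin, remainder 2 goes to
    // the first two bins, so the result is {3, 3, 2, 2, 2} and sums to 12.
    std::vector<int32_t> h = CreateFlatHistogram(5, 12);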
+
+// First, all trailing non-occurring symbols are removed from the
+// distribution; if this leaves the distribution empty, a dummy symbol with
+// max weight is added. This ensures that the resulting distribution sums to
+// total table size. Then, `entry_size` is chosen to be the largest power of
+// two so that `table_size` = ANS_TAB_SIZE/`entry_size` is at least as big as
+// the distribution size.
+// Note that each entry will only ever contain two different symbols, and
+// consecutive ranges of offsets, which allows us to use a compact
+// representation.
+// Each entry is initialized with only the (symbol=i, offset) pairs; then
+// positions for which the entry overflows (i.e. distribution[i] > entry_size)
+// or is not full are computed, and put into a stack in increasing order.
+// Missing symbols in the distribution are padded with 0 (because `table_size`
+// >= number of symbols). The `cutoff` value for each entry is initialized to
+// the number of occupied slots in that entry (i.e. `distributions[i]`). While
+// the overflowing-symbol stack is not empty (which implies that the
+// underflowing-symbol stack also is not), the top overfull and underfull
+// positions are popped from the stack; the empty slots in the underfull entry
+// are then filled with as many slots as needed from the overfull entry; such
+// slots are placed after the slots in the overfull entry, and `offsets[1]` is
+// computed accordingly. The formerly underfull entry is thus now neither
+// underfull nor overfull, and represents exactly two symbols. The overfull
+// entry might be either overfull or underfull, and is pushed into the
+// corresponding stack.
+void InitAliasTable(std::vector<int32_t> distribution, uint32_t range,
+                    size_t log_alpha_size, AliasTable::Entry* JXL_RESTRICT a) {
+  while (!distribution.empty() && distribution.back() == 0) {
+    distribution.pop_back();
+  }
+  // Ensure that a valid table is always returned, even for an empty
+  // alphabet. Otherwise, a specially-crafted stream might crash the
+  // decoder.
+  if (distribution.empty()) {
+    distribution.emplace_back(range);
+  }
+  const size_t table_size = 1 << log_alpha_size;
+#if JXL_ENABLE_ASSERT
+  int sum = std::accumulate(distribution.begin(), distribution.end(), 0);
+#endif  // JXL_ENABLE_ASSERT
+  JXL_ASSERT(static_cast<uint32_t>(sum) == range);
+  // range must be a power of two
+  JXL_ASSERT((range & (range - 1)) == 0);
+  JXL_ASSERT(distribution.size() <= table_size);
+  JXL_ASSERT(table_size <= range);
+  const uint32_t entry_size = range >> log_alpha_size;  // this is exact
+  // Special case for single-symbol distributions that ensures the state does
+  // not change when decoding from such a distribution. Note that, since we
+  // hardcode offset0 == 0, it is not straightforward (if at all possible) to
+  // fix the general case to produce this result.
+  for (size_t sym = 0; sym < distribution.size(); sym++) {
+    if (distribution[sym] == ANS_TAB_SIZE) {
+      for (size_t i = 0; i < table_size; i++) {
+        a[i].right_value = sym;
+        a[i].cutoff = 0;
+        a[i].offsets1 = entry_size * i;
+        a[i].freq0 = 0;
+        a[i].freq1_xor_freq0 = ANS_TAB_SIZE;
+      }
+      return;
+    }
+  }
+  std::vector<uint32_t> underfull_posn;
+  std::vector<uint32_t> overfull_posn;
+  std::vector<uint32_t> cutoffs(1 << log_alpha_size);
+  // Initialize entries.
+  for (size_t i = 0; i < distribution.size(); i++) {
+    cutoffs[i] = distribution[i];
+    if (cutoffs[i] > entry_size) {
+      overfull_posn.push_back(i);
+    } else if (cutoffs[i] < entry_size) {
+      underfull_posn.push_back(i);
+    }
+  }
+  for (uint32_t i = distribution.size(); i < table_size; i++) {
+    cutoffs[i] = 0;
+    underfull_posn.push_back(i);
+  }
+  // Reassign overflow/underflow values.
+  while (!overfull_posn.empty()) {
+    uint32_t overfull_i = overfull_posn.back();
+    overfull_posn.pop_back();
+    JXL_ASSERT(!underfull_posn.empty());
+    uint32_t underfull_i = underfull_posn.back();
+    underfull_posn.pop_back();
+    uint32_t underfull_by = entry_size - cutoffs[underfull_i];
+    cutoffs[overfull_i] -= underfull_by;
+    // overfull positions have their original symbols
+    a[underfull_i].right_value = overfull_i;
+    a[underfull_i].offsets1 = cutoffs[overfull_i];
+    // Slots in the right part of entry underfull_i were taken from the end
+    // of the symbols in entry overfull_i.
+    if (cutoffs[overfull_i] < entry_size) {
+      underfull_posn.push_back(overfull_i);
+    } else if (cutoffs[overfull_i] > entry_size) {
+      overfull_posn.push_back(overfull_i);
+    }
+  }
+  for (uint32_t i = 0; i < table_size; i++) {
+    // cutoffs[i] is properly initialized but the clang-analyzer doesn't infer
+    // it since it is partially initialized across two for-loops.
+    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
+    if (cutoffs[i] == entry_size) {
+      a[i].right_value = i;
+      a[i].offsets1 = 0;
+      a[i].cutoff = 0;
+    } else {
+      // Note that, if cutoff is not equal to entry_size,
+      // a[i].offsets1 was initialized with (overfull cutoff) -
+      // (entry_size - a[i].cutoff). Thus, subtracting
+      // a[i].cutoff cannot make it negative.
+      a[i].offsets1 -= cutoffs[i];
+      a[i].cutoff = cutoffs[i];
+    }
+    const size_t freq0 = i < distribution.size() ? distribution[i] : 0;
+    const size_t i1 = a[i].right_value;
+    const size_t freq1 = i1 < distribution.size() ? distribution[i1] : 0;
+    a[i].freq0 = static_cast<uint16_t>(freq0);
+    a[i].freq1_xor_freq0 = static_cast<uint16_t>(freq1 ^ freq0);
+  }
+}
+
+}  // namespace jxl
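To make the stack-based redistribution concrete, here is a hedged toy walkthrough with deliberately small numbers (the real tables use range = ANS_TAB_SIZE):

    // Illustrative only. range = 8, log_alpha_size = 2, so table_size = 4
    // and entry_size = 2. distribution = {5, 1, 1, 1}:
    //   entry 0 starts overfull (5 > 2); entries 1, 2, 3 start underfull.
    //   Each underfull entry takes one slot from entry 0 and records
    //   right_value = 0, until cutoffs[0] drops from 5 to 2.
    // Resulting mapping of the 8 input values:
    //   0, 1    -> symbol 0, offsets 0, 1   (entry 0 is entirely symbol 0)
    //   2, 4, 6 -> symbols 1, 2, 3, offset 0 (below cutoff = 1)
    //   3, 5, 7 -> symbol 0, offsets 2, 3, 4 (at or above the cutoff)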
diff --git a/third_party/jpeg-xl/lib/jxl/ans_common.h b/third_party/jpeg-xl/lib/jxl/ans_common.h
new file mode 100644
index 000000000000..ed8de099100c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/ans_common.h
@@ -0,0 +1,152 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ANS_COMMON_H_
+#define LIB_JXL_ANS_COMMON_H_
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <hwy/cache_control.h>  // Prefetch
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+// Returns the precision (number of bits) that should be used to store
+// a histogram count such that Log2Floor(count) == logcount.
+static JXL_INLINE uint32_t GetPopulationCountPrecision(uint32_t logcount,
+                                                       uint32_t shift) {
+  int32_t r = std::min<int>(
+      logcount, int(shift) - int((ANS_LOG_TAB_SIZE - logcount) >> 1));
+  if (r < 0) return 0;
+  return r;
+}
+
+// Returns a histogram where the counts are positive, differ by at most 1,
+// and add up to total_count. The bigger counts (if any) are at the beginning
+// of the histogram.
+std::vector<int32_t> CreateFlatHistogram(int length, int total_count);
+
+// An alias table implements a mapping from the [0, ANS_TAB_SIZE) range into
+// the [0, ANS_MAX_ALPHABET_SIZE) range, satisfying the following conditions:
+// - each symbol occurs as many times as specified by any valid distribution
+//   of frequencies of the symbols. A valid distribution here is an array of
+//   ANS_MAX_ALPHABET_SIZE that contains numbers in the range
+//   [0, ANS_TAB_SIZE], and whose sum is ANS_TAB_SIZE.
+// - lookups can be done in constant time, and also return how many smaller
+//   input values map into the same symbol, according to some well-defined
+//   order of input values.
+// - the space used by the alias table is given by a small constant times the
+//   index of the largest symbol with nonzero probability in the distribution.
+// Each of the entries in the table covers a range of `entry_size` values in
+// the [0, ANS_TAB_SIZE) range; consecutive entries represent consecutive
+// sub-ranges. In the range covered by entry `i`, the first `cutoff` values
+// map to symbol `i`, while the others map to symbol `right_value`.
+//
+// TODO(veluca): consider making the order used for computing offsets easier
+// to define - it is currently defined by the algorithm to compute the alias
+// table. Beware of breaking the implicit assumption that symbols that come
+// after the cutoff value should have an offset at least as big as the cutoff.
+
+struct AliasTable {
+  struct Symbol {
+    size_t value;
+    size_t offset;
+    size_t freq;
+  };
+
+// Working set size matters here (~64 tables x 256 entries).
+// offsets0 is always zero (beginning of [0] side among the same symbol).
+// offsets1 is an offset of (pos >= cutoff) side decremented by cutoff.
+#pragma pack(push, 1)
+  struct Entry {
+    uint8_t cutoff;       // < kEntrySizeMinus1 when used by ANS.
+    uint8_t right_value;  // < alphabet size.
+    uint16_t freq0;
+
+    // Only used if `greater` (see Lookup)
+    uint16_t offsets1;         // <= ANS_TAB_SIZE
+    uint16_t freq1_xor_freq0;  // for branchless ternary in Lookup
+  };
+#pragma pack(pop)
+
+  // Dividing `value` by `entry_size` determines `i`, the entry which is
+  // responsible for the input. If the remainder is below `cutoff`, then the
+  // mapped symbol is `i`; since `offsets[0]` stores the number of
+  // occurrences of `i` "before" the start of this entry, the offset of the
+  // input will be `offsets[0] + remainder`. If the remainder is above
+  // cutoff, the mapped symbol is `right_value`; since `offsets[1]` stores
+  // the number of occurrences of `right_value` "before" this entry, minus
+  // the `cutoff` value, the input offset is then `remainder + offsets[1]`.
+  static JXL_INLINE Symbol Lookup(const Entry* JXL_RESTRICT table,
+                                  size_t value, size_t log_entry_size,
+                                  size_t entry_size_minus_1) {
+    const size_t i = value >> log_entry_size;
+    const size_t pos = value & entry_size_minus_1;
+
+#if JXL_BYTE_ORDER_LITTLE
+    uint64_t entry;
+    memcpy(&entry, &table[i].cutoff, sizeof(entry));
+    const size_t cutoff = entry & 0xFF;              // = MOVZX
+    const size_t right_value = (entry >> 8) & 0xFF;  // = MOVZX
+    const size_t freq0 = (entry >> 16) & 0xFFFF;
+#else
+    // Generates multiple loads with complex addressing.
+    const size_t cutoff = table[i].cutoff;
+    const size_t right_value = table[i].right_value;
+    const size_t freq0 = table[i].freq0;
+#endif
+
+    const bool greater = pos >= cutoff;
+
+#if JXL_BYTE_ORDER_LITTLE
+    const uint64_t conditional = greater ? entry : 0;  // = CMOV
+    const size_t offsets1_or_0 = (conditional >> 32) & 0xFFFF;
+    const size_t freq1_xor_freq0_or_0 = conditional >> 48;
+#else
+    const size_t offsets1_or_0 = greater ? table[i].offsets1 : 0;
+    const size_t freq1_xor_freq0_or_0 = greater ? table[i].freq1_xor_freq0 : 0;
+#endif
+
+    // WARNING: moving this code may interfere with CMOV heuristics.
+    Symbol s;
+    s.value = greater ? right_value : i;
+    s.offset = offsets1_or_0 + pos;
+    s.freq = freq0 ^ freq1_xor_freq0_or_0;  // = greater ? freq1 : freq0
+    // XOR avoids implementation-defined conversion from unsigned to signed.
+    // Alternatives considered: BEXTR is 2 cycles on HSW, SET+shift causes
+    // spills, simple ternary has a long dependency chain.
+
+    return s;
+  }
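Continuing the toy table from the walkthrough above (entry_size = 2 means log_entry_size = 1 and entry_size_minus_1 = 1), a lookup of value 3 would proceed as sketched below; the numbers and the `table` variable are illustrative, not part of the patch:

    // i = 3 >> 1 = 1 and pos = 3 & 1 = 1; pos >= cutoff (1), so the symbol
    // is right_value = 0, offset = offsets1 + pos = 1 + 1 = 2, freq = 5.
    AliasTable::Symbol s = AliasTable::Lookup(table, /*value=*/3,
                                              /*log_entry_size=*/1,
                                              /*entry_size_minus_1=*/1);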
+
+  static HWY_INLINE void Prefetch(const Entry* JXL_RESTRICT table,
+                                  size_t value, size_t log_entry_size) {
+    const size_t i = value >> log_entry_size;
+    hwy::Prefetch(table + i);
+  }
+};
+
+// Computes an alias table for a given distribution.
+void InitAliasTable(std::vector<int32_t> distribution, uint32_t range,
+                    size_t log_alpha_size, AliasTable::Entry* JXL_RESTRICT a);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ANS_COMMON_H_
diff --git a/third_party/jpeg-xl/lib/jxl/ans_common_test.cc b/third_party/jpeg-xl/lib/jxl/ans_common_test.cc
new file mode 100644
index 000000000000..4f9e1b8bb3ec
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/ans_common_test.cc
@@ -0,0 +1,52 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/ans_common.h"
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/ans_params.h"
+
+namespace jxl {
+namespace {
+
+void VerifyAliasDistribution(const std::vector<int32_t>& distribution,
+                             uint32_t range) {
+  constexpr size_t log_alpha_size = 8;
+  AliasTable::Entry table[1 << log_alpha_size];
+  InitAliasTable(distribution, range, log_alpha_size, table);
+  std::vector<std::vector<uint32_t>> offsets(distribution.size());
+  for (uint32_t i = 0; i < range; i++) {
+    AliasTable::Symbol s = AliasTable::Lookup(
+        table, i, ANS_LOG_TAB_SIZE - 8, (1 << (ANS_LOG_TAB_SIZE - 8)) - 1);
+    offsets[s.value].push_back(s.offset);
+  }
+  for (uint32_t i = 0; i < distribution.size(); i++) {
+    ASSERT_EQ(distribution[i], offsets[i].size());
+    std::sort(offsets[i].begin(), offsets[i].end());
+    for (uint32_t j = 0; j < offsets[i].size(); j++) {
+      ASSERT_EQ(offsets[i][j], j);
+    }
+  }
+}
+
+TEST(ANSCommonTest, AliasDistributionSmoke) {
+  VerifyAliasDistribution({ANS_TAB_SIZE / 2, ANS_TAB_SIZE / 2}, ANS_TAB_SIZE);
+  VerifyAliasDistribution({ANS_TAB_SIZE}, ANS_TAB_SIZE);
+  VerifyAliasDistribution({0, 0, 0, ANS_TAB_SIZE, 0}, ANS_TAB_SIZE);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/ans_params.h b/third_party/jpeg-xl/lib/jxl/ans_params.h
new file mode 100644
index 000000000000..43a5b7c096f5
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/ans_params.h
@@ -0,0 +1,45 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ANS_PARAMS_H_
+#define LIB_JXL_ANS_PARAMS_H_
+
+// Common parameters that are needed for both the ANS entropy encoding and
+// decoding methods.
+
+#include <stdint.h>
+#include <stdlib.h>
+
+namespace jxl {
+
+// TODO(veluca): decide if 12 is the best constant here (valid range is up to
+// 16). This requires recomputing the Huffman tables in {enc,dec}_ans.cc.
+// 14 gives a 0.2% improvement at d1 and makes d8 slightly worse. This is
+// likely not worth the increase in encoder complexity.
+#define ANS_LOG_TAB_SIZE 12u
+#define ANS_TAB_SIZE (1 << ANS_LOG_TAB_SIZE)
+#define ANS_TAB_MASK (ANS_TAB_SIZE - 1)
+
+// Largest possible symbol to be encoded by either ANS or prefix coding.
+#define PREFIX_MAX_ALPHABET_SIZE 4096
+#define ANS_MAX_ALPHABET_SIZE 256
+
+// Max number of bits for prefix coding.
+#define PREFIX_MAX_BITS 15
+
+#define ANS_SIGNATURE 0x13  // Initial state, used as CRC.
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ANS_PARAMS_H_
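A quick sanity sketch of the derived constants (illustrative, not part of the patch):

    static_assert(ANS_TAB_SIZE == 4096, "1 << 12");
    static_assert(ANS_TAB_MASK == 0xFFF, "low 12 bits select a table slot");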
diff --git a/third_party/jpeg-xl/lib/jxl/ans_test.cc b/third_party/jpeg-xl/lib/jxl/ans_test.cc
new file mode 100644
index 000000000000..b28de62f837a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/ans_test.cc
@@ -0,0 +1,199 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <random>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_bit_writer.h"
+
+namespace jxl {
+namespace {
+
+void RoundtripTestcase(int n_histograms, int alphabet_size,
+                       const std::vector<Token>& input_values) {
+  constexpr uint16_t kMagic1 = 0x9e33;
+  constexpr uint16_t kMagic2 = 0x8b04;
+
+  BitWriter writer;
+  // Space for magic bytes.
+  BitWriter::Allotment allotment_magic1(&writer, 16);
+  writer.Write(16, kMagic1);
+  ReclaimAndCharge(&writer, &allotment_magic1, 0, nullptr);
+
+  std::vector<uint8_t> context_map;
+  EntropyEncodingData codes;
+  std::vector<std::vector<Token>> input_values_vec;
+  input_values_vec.push_back(input_values);
+
+  BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec,
+                           &codes, &context_map, &writer, 0, nullptr);
+  WriteTokens(input_values_vec[0], codes, context_map, &writer, 0, nullptr);
+
+  // Magic bytes + padding
+  BitWriter::Allotment allotment_magic2(&writer, 24);
+  writer.Write(16, kMagic2);
+  writer.ZeroPadToByte();
+  ReclaimAndCharge(&writer, &allotment_magic2, 0, nullptr);
+
+  // We do not truncate the output. Reading past the end reads out zeroes
+  // anyway.
+  BitReader br(writer.GetSpan());
+
+  ASSERT_EQ(br.ReadBits(16), kMagic1);
+
+  std::vector<uint8_t> dec_context_map;
+  ANSCode decoded_codes;
+  ASSERT_TRUE(
+      DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map));
+  ASSERT_EQ(dec_context_map, context_map);
+  ANSSymbolReader reader(&decoded_codes, &br);
+
+  for (const Token& symbol : input_values) {
+    uint32_t read_symbol =
+        reader.ReadHybridUint(symbol.context, &br, dec_context_map);
+    ASSERT_EQ(read_symbol, symbol.value);
+  }
+  ASSERT_TRUE(reader.CheckANSFinalState());
+
+  ASSERT_EQ(br.ReadBits(16), kMagic2);
+  EXPECT_TRUE(br.Close());
+}
+
+TEST(ANSTest, EmptyRoundtrip) {
+  RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, std::vector<Token>());
+}
+
+TEST(ANSTest, SingleSymbolRoundtrip) {
+  for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) {
+    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, {{0, i}});
+  }
+  for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) {
+    RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE,
+                      std::vector<Token>(1024, {0, i}));
+  }
+}
+
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+constexpr size_t kReps = 10;
+#else
+constexpr size_t kReps = 100;
+#endif
+
+void RoundtripRandomStream(int alphabet_size, size_t reps = kReps,
+                           size_t num = 1 << 18) {
+  constexpr int kNumHistograms = 3;
+  std::mt19937_64 rng;
+  for (size_t i = 0; i < reps; i++) {
+    std::vector<Token> symbols;
+    for (size_t j = 0; j < num; j++) {
+      int context =
+          std::uniform_int_distribution<>(0, kNumHistograms - 1)(rng);
+      int value = std::uniform_int_distribution<>(0, alphabet_size - 1)(rng);
+      symbols.emplace_back(context, value);
+    }
+    RoundtripTestcase(kNumHistograms, alphabet_size, symbols);
+  }
+}
+
+void RoundtripRandomUnbalancedStream(int alphabet_size) {
+  constexpr int kNumHistograms = 3;
+  constexpr int kPrecision = 1 << 10;
+  std::mt19937_64 rng;
+  for (int i = 0; i < 100; i++) {
+    std::vector<int> distributions[kNumHistograms];
+    for (int j = 0; j < kNumHistograms; j++) {
+      distributions[j].resize(kPrecision);
+      int symbol = 0;
+      int remaining = 1;
+      for (int k = 0; k < kPrecision; k++) {
+        if (remaining == 0) {
+          if (symbol < alphabet_size - 1) symbol++;
+          // There is no meaning behind this distribution: it's anything that
+          // will create a nonuniform distribution and won't have too few
+          // symbols usually. Also we want different distributions we get to
+          // be sufficiently dissimilar.
+          remaining =
+              std::uniform_int_distribution<>(0, (kPrecision - k) / 1)(rng);
+        }
+        distributions[j][k] = symbol;
+        remaining--;
+      }
+    }
+    std::vector<Token> symbols;
+    for (int j = 0; j < 1 << 18; j++) {
+      int context =
+          std::uniform_int_distribution<>(0, kNumHistograms - 1)(rng);
+      int value = distributions[context][std::uniform_int_distribution<>(
+          0, kPrecision - 1)(rng)];
+      symbols.emplace_back(context, value);
+    }
+    RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols);
+  }
+}
+
+TEST(ANSTest, RandomStreamRoundtrip3Small) { RoundtripRandomStream(3, 1, 16); }
+
+TEST(ANSTest, RandomStreamRoundtrip3) { RoundtripRandomStream(3); }
+
+TEST(ANSTest, RandomStreamRoundtripBig) {
+  RoundtripRandomStream(ANS_MAX_ALPHABET_SIZE);
+}
+
+TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) {
+  RoundtripRandomUnbalancedStream(3);
+}
+
+TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) {
+  RoundtripRandomUnbalancedStream(ANS_MAX_ALPHABET_SIZE);
+}
+
+TEST(ANSTest, UintConfigRoundtrip) {
+  for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) {
+    std::vector<HybridUintConfig> uint_config, uint_config_dec;
+    for (size_t i = 0; i < log_alpha_size; i++) {
+      for (size_t j = 0; j <= i; j++) {
+        for (size_t k = 0; k <= i - j; k++) {
+          uint_config.emplace_back(i, j, k);
+        }
+      }
+    }
+    uint_config.emplace_back(log_alpha_size, 0, 0);
+    uint_config_dec.resize(uint_config.size());
+    BitWriter writer;
+    BitWriter::Allotment allotment(&writer, 10 * uint_config.size());
+    EncodeUintConfigs(uint_config, &writer, log_alpha_size);
+    ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+    writer.ZeroPadToByte();
+    BitReader br(writer.GetSpan());
+    EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &uint_config_dec, &br));
+    EXPECT_TRUE(br.Close());
+    for (size_t i = 0; i < uint_config.size(); i++) {
+      EXPECT_EQ(uint_config[i].split_token, uint_config_dec[i].split_token);
+      EXPECT_EQ(uint_config[i].msb_in_token, uint_config_dec[i].msb_in_token);
+      EXPECT_EQ(uint_config[i].lsb_in_token, uint_config_dec[i].lsb_in_token);
+    }
+  }
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/aux_out.cc b/third_party/jpeg-xl/lib/jxl/aux_out.cc
new file mode 100644
index 000000000000..d0b9d97aa383
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/aux_out.cc
@@ -0,0 +1,105 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lib/jxl/aux_out.h" + +#include + +#include // accumulate + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +void AuxOut::Print(size_t num_inputs) const { + if (num_inputs == 0) return; + + LayerTotals all_layers; + for (size_t i = 0; i < layers.size(); ++i) { + all_layers.Assimilate(layers[i]); + } + + printf("Average butteraugli iters: %10.2f\n", + num_butteraugli_iters * 1.0 / num_inputs); + + for (size_t i = 0; i < layers.size(); ++i) { + if (layers[i].total_bits != 0) { + printf("Total layer bits %-10s\t", LayerName(i)); + printf("%10f%%", 100.0 * layers[i].total_bits / all_layers.total_bits); + layers[i].Print(num_inputs); + } + } + printf("Total image size "); + all_layers.Print(num_inputs); + + const uint32_t dc_pred_total = + std::accumulate(dc_pred_usage.begin(), dc_pred_usage.end(), 0u); + const uint32_t dc_pred_total_xb = + std::accumulate(dc_pred_usage_xb.begin(), dc_pred_usage_xb.end(), 0u); + if (dc_pred_total + dc_pred_total_xb != 0) { + printf("\nDC pred Y XB:\n"); + for (size_t i = 0; i < dc_pred_usage.size(); ++i) { + printf(" %6u (%5.2f%%) %6u (%5.2f%%)\n", dc_pred_usage[i], + 100.0 * dc_pred_usage[i] / dc_pred_total, dc_pred_usage_xb[i], + 100.0 * dc_pred_usage_xb[i] / dc_pred_total_xb); + } + } + + size_t total_blocks = 0; + size_t total_positions = 0; + if (total_blocks != 0 && total_positions != 0) { + printf("\n\t\t Blocks\t\tPositions\t\t\tBlocks/Position\n"); + printf(" Total:\t\t %7zu\t\t %7zu \t\t\t%10f%%\n\n", total_blocks, + total_positions, 100.0 * total_blocks / total_positions); + } +} + +void AuxOut::DumpCoeffImage(const char* label, + const Image3S& coeff_image) const { + JXL_ASSERT(coeff_image.xsize() % 64 == 0); + Image3S reshuffled(coeff_image.xsize() / 8, coeff_image.ysize() * 8); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < coeff_image.ysize(); y++) { + for (size_t x = 0; x < coeff_image.xsize(); x += 64) { + for (size_t i = 0; i < 64; i++) { + reshuffled.PlaneRow(c, 8 * y + i / 8)[x / 8 + i % 8] = + coeff_image.PlaneRow(c, y)[x + i]; + } + } + } + } + DumpImage(label, reshuffled); +} + +void ReclaimAndCharge(BitWriter* JXL_RESTRICT writer, + BitWriter::Allotment* JXL_RESTRICT allotment, + size_t layer, AuxOut* JXL_RESTRICT aux_out) { + size_t used_bits, unused_bits; + allotment->PrivateReclaim(writer, &used_bits, &unused_bits); + +#if 0 + printf("Layer %s bits: max %zu used %zu unused %zu\n", LayerName(layer), + allotment->MaxBits(), used_bits, unused_bits); +#endif + + // This may be a nested call with aux_out == null. Whenever we know that + // aux_out is null, we can call ReclaimUnused directly. + if (aux_out != nullptr) { + aux_out->layers[layer].total_bits += used_bits; + aux_out->layers[layer].histogram_bits += allotment->HistogramBits(); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/aux_out.h b/third_party/jpeg-xl/lib/jxl/aux_out.h new file mode 100644 index 000000000000..30c77d044bed --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/aux_out.h @@ -0,0 +1,320 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/third_party/jpeg-xl/lib/jxl/aux_out.h b/third_party/jpeg-xl/lib/jxl/aux_out.h
new file mode 100644
index 000000000000..30c77d044bed
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/aux_out.h
@@ -0,0 +1,320 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_AUX_OUT_H_
+#define LIB_JXL_AUX_OUT_H_
+
+// Optional output information for debugging and analyzing size usage.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <array>
+#include <functional>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_xyb.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/jxl_inspection.h"
+
+namespace jxl {
+
+// For LayerName and AuxOut::layers[] index. Order does not matter.
+enum {
+  kLayerHeader = 0,
+  kLayerTOC,
+  kLayerNoise,
+  kLayerQuant,
+  kLayerDequantTables,
+  kLayerOrder,
+  kLayerDC,
+  kLayerControlFields,
+  kLayerAC,
+  kLayerACTokens,
+  kLayerDictionary,
+  kLayerDots,
+  kLayerSplines,
+  kLayerLossless,
+  kLayerModularGlobal,
+  kLayerModularDcGroup,
+  kLayerModularAcGroup,
+  kLayerModularTree,
+  kLayerAlpha,
+  kLayerDepth,
+  kLayerExtraChannels,
+  kNumImageLayers
+};
+
+static inline const char* LayerName(size_t layer) {
+  switch (layer) {
+    case kLayerHeader:
+      return "headers";
+    case kLayerTOC:
+      return "TOC";
+    case kLayerNoise:
+      return "noise";
+    case kLayerQuant:
+      return "quantizer";
+    case kLayerDequantTables:
+      return "quant tables";
+    case kLayerOrder:
+      return "order";
+    case kLayerDC:
+      return "DC";
+    case kLayerControlFields:
+      return "ControlFields";
+    case kLayerAC:
+      return "AC";
+    case kLayerACTokens:
+      return "ACTokens";
+    case kLayerDictionary:
+      return "dictionary";
+    case kLayerDots:
+      return "dots";
+    case kLayerSplines:
+      return "splines";
+    case kLayerLossless:
+      return "lossless";
+    case kLayerModularGlobal:
+      return "modularGlobal";
+    case kLayerModularDcGroup:
+      return "modularDcGroup";
+    case kLayerModularAcGroup:
+      return "modularAcGroup";
+    case kLayerModularTree:
+      return "modularTree";
+    case kLayerAlpha:
+      return "alpha";
+    case kLayerDepth:
+      return "depth";
+    case kLayerExtraChannels:
+      return "extra channels";
+    default:
+      JXL_ABORT("Invalid layer %zu\n", layer);
+  }
+}
+
+// Statistics gathered during compression or decompression.
+struct AuxOut {
+ private:
+  struct LayerTotals {
+    void Assimilate(const LayerTotals& victim) {
+      num_clustered_histograms += victim.num_clustered_histograms;
+      histogram_bits += victim.histogram_bits;
+      extra_bits += victim.extra_bits;
+      total_bits += victim.total_bits;
+      clustered_entropy += victim.clustered_entropy;
+    }
+    void Print(size_t num_inputs) const {
+      printf("%10zd", total_bits);
+      if (histogram_bits != 0) {
+        printf(" [c/i:%6.2f | hst:%8zd | ex:%8zd | h+c+e:%12.3f",
+               num_clustered_histograms * 1.0 / num_inputs,
+               histogram_bits >> 3, extra_bits >> 3,
+               (histogram_bits + clustered_entropy + extra_bits) / 8.0);
+        printf("]");
+      }
+      printf("\n");
+    }
+    size_t num_clustered_histograms = 0;
+    size_t extra_bits = 0;
+
+    // Set via BitsWritten below
+    size_t histogram_bits = 0;
+    size_t total_bits = 0;
+
+    double clustered_entropy = 0.0;
+  };
+
+ public:
+  AuxOut() = default;
+  AuxOut(const AuxOut&) = default;
+
+  void Assimilate(const AuxOut& victim) {
+    for (size_t i = 0; i < layers.size(); ++i) {
+      layers[i].Assimilate(victim.layers[i]);
+    }
+    num_blocks += victim.num_blocks;
+    num_dct2_blocks += victim.num_dct2_blocks;
+    num_dct4_blocks += victim.num_dct4_blocks;
+    num_dct4x8_blocks += victim.num_dct4x8_blocks;
+    num_afv_blocks += victim.num_afv_blocks;
+    num_dct8_blocks += victim.num_dct8_blocks;
+    num_dct8x16_blocks += victim.num_dct8x16_blocks;
+    num_dct8x32_blocks += victim.num_dct8x32_blocks;
+    num_dct16_blocks += victim.num_dct16_blocks;
+    num_dct16x32_blocks += victim.num_dct16x32_blocks;
+    num_dct32_blocks += victim.num_dct32_blocks;
+    num_butteraugli_iters += victim.num_butteraugli_iters;
+    for (size_t i = 0; i < dc_pred_usage.size(); ++i) {
+      dc_pred_usage[i] += victim.dc_pred_usage[i];
+      dc_pred_usage_xb[i] += victim.dc_pred_usage_xb[i];
+    }
+  }
+
+  void Print(size_t num_inputs) const;
+
+  template <typename T>
+  void DumpImage(const char* label, const Image3<T>& image) const {
+    if (!dump_image) return;
+    if (debug_prefix.empty()) return;
+    std::ostringstream pathname;
+    pathname << debug_prefix << label << ".png";
+    CodecInOut io;
+    // Always save to 16-bit png.
+    io.metadata.m.SetUintSamples(16);
+    io.metadata.m.color_encoding = ColorEncoding::SRGB();
+    io.SetFromImage(ConvertToFloat(image), io.metadata.m.color_encoding);
+    (void)dump_image(io, pathname.str());
+  }
+  template <typename T>
+  void DumpImage(const char* label, const Plane<T>& image) {
+    DumpImage(label,
+              Image3<T>(CopyImage(image), CopyImage(image), CopyImage(image)));
+  }
+
+  template <typename T>
+  void DumpXybImage(const char* label, const Image3<T>& image) const {
+    if (!dump_image) return;
+    if (debug_prefix.empty()) return;
+    std::ostringstream pathname;
+    pathname << debug_prefix << label << ".png";
+
+    Image3F linear(image.xsize(), image.ysize());
+    OpsinParams opsin_params;
+    opsin_params.Init(kDefaultIntensityTarget);
+    OpsinToLinear(image, Rect(linear), nullptr, &linear, opsin_params);
+
+    CodecInOut io;
+    io.metadata.m.SetUintSamples(16);
+    io.metadata.m.color_encoding = ColorEncoding::LinearSRGB();
+    io.SetFromImage(std::move(linear), io.metadata.m.color_encoding);
+
+    (void)dump_image(io, pathname.str());
+  }
+
+  // Normalizes all the channels to range 0-255, creating a false-color image
+  // which allows seeing the information from non-RGB channels in an RGB debug
+  // image.
+  template <typename T>
+  void DumpImageNormalized(const char* label, const Image3<T>& image) const {
+    std::array<T, 3> min;
+    std::array<T, 3> max;
+    Image3MinMax(image, &min, &max);
+    Image3B normalized(image.xsize(), image.ysize());
+    for (size_t c = 0; c < 3; ++c) {
+      float mul = min[c] == max[c] ? 0 : (255.0f / (max[c] - min[c]));
+      for (size_t y = 0; y < image.ysize(); ++y) {
+        const T* JXL_RESTRICT row_in = image.ConstPlaneRow(c, y);
+        uint8_t* JXL_RESTRICT row_out = normalized.PlaneRow(c, y);
+        for (size_t x = 0; x < image.xsize(); ++x) {
+          row_out[x] = static_cast<uint8_t>((row_in[x] - min[c]) * mul);
+        }
+      }
+    }
+    DumpImage(label, normalized);
+  }
+
+  template <typename T>
+  void DumpPlaneNormalized(const char* label, const Plane<T>& image) const {
+    T min;
+    T max;
+    ImageMinMax(image, &min, &max);
+    Image3B normalized(image.xsize(), image.ysize());
+    for (size_t c = 0; c < 3; ++c) {
+      float mul = min == max ? 0 : (255.0f / (max - min));
+      for (size_t y = 0; y < image.ysize(); ++y) {
+        const T* JXL_RESTRICT row_in = image.ConstRow(y);
+        uint8_t* JXL_RESTRICT row_out = normalized.PlaneRow(c, y);
+        for (size_t x = 0; x < image.xsize(); ++x) {
+          row_out[x] = static_cast<uint8_t>((row_in[x] - min) * mul);
+        }
+      }
+    }
+    DumpImage(label, normalized);
+  }
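Both helpers apply the same per-channel min-max normalization; a standalone sketch of that mapping (the helper name and inlined scaling are illustrative, not part of the patch):

    // Maps [min, max] linearly onto [0, 255]; e.g. v = 2 with min = -2 and
    // max = 6 gives (2 - (-2)) * (255 / 8) = 127, mid-gray in the debug image.
    uint8_t NormalizeSample(float v, float min, float max) {
      const float mul = (min == max) ? 0.f : 255.0f / (max - min);
      return static_cast<uint8_t>((v - min) * mul);
    }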
+
+  // This dumps coefficients as a 16-bit PNG with coefficients of a block
+  // placed in the area that would contain that block in a normal image. To
+  // view the resulting image manually, rescale intensities by using:
+  // $ convert -auto-level IMAGE.PNG - | display -
+  void DumpCoeffImage(const char* label, const Image3S& coeff_image) const;
+
+  void SetInspectorImage3F(const jxl::InspectorImage3F& inspector) {
+    inspector_image3f_ = inspector;
+  }
+
+  // Allows hooking intermediate data inspection into various places of the
+  // processing pipeline. Returns true iff processing should proceed.
+  bool InspectImage3F(const char* label, const Image3F& image) {
+    if (inspector_image3f_ != nullptr) {
+      return inspector_image3f_(label, image);
+    }
+    return true;
+  }
+
+  std::array<LayerTotals, kNumImageLayers> layers;
+  size_t num_blocks = 0;
+
+  // Number of blocks that use larger DCT (set by ac_strategy).
+  size_t num_dct2_blocks = 0;
+  size_t num_dct4_blocks = 0;
+  size_t num_dct4x8_blocks = 0;
+  size_t num_afv_blocks = 0;
+  size_t num_dct8_blocks = 0;
+  size_t num_dct8x16_blocks = 0;
+  size_t num_dct8x32_blocks = 0;
+  size_t num_dct16_blocks = 0;
+  size_t num_dct16x32_blocks = 0;
+  size_t num_dct32_blocks = 0;
+
+  std::array<uint32_t, 8> dc_pred_usage = {0};
+  std::array<uint32_t, 8> dc_pred_usage_xb = {0};
+
+  int num_butteraugli_iters = 0;
+
+  // If not empty, additional debugging information (e.g. debug images) is
+  // saved in files with this prefix.
+  std::string debug_prefix;
+
+  // By how much the decoded image was downsampled relative to the encoded
+  // image.
+  size_t downsampling = 1;
+
+  jxl::InspectorImage3F inspector_image3f_;
+
+  std::function<Status(const CodecInOut&, const std::string&)> dump_image =
+      nullptr;
+};
+
+// Used to skip image creation if they won't be written to debug directory.
+static inline bool WantDebugOutput(const AuxOut* aux_out) {
+  // Need valid pointer and filename.
+  return aux_out != nullptr && !aux_out->debug_prefix.empty();
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_AUX_OUT_H_
diff --git a/third_party/jpeg-xl/lib/jxl/aux_out_fwd.h b/third_party/jpeg-xl/lib/jxl/aux_out_fwd.h
new file mode 100644
index 000000000000..5f7639aec7e3
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/aux_out_fwd.h
@@ -0,0 +1,37 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_AUX_OUT_FWD_H_
+#define LIB_JXL_AUX_OUT_FWD_H_
+
+#include <stddef.h>
+
+#include "lib/jxl/enc_bit_writer.h"
+
+namespace jxl {
+
+struct AuxOut;
+
+// Helper function that ensures the `bits_written` are charged to `layer` in
+// `aux_out`. Example usage:
+//   BitWriter::Allotment allotment(&writer, max_bits);
+//   writer.Write(..); writer.Write(..);
+//   ReclaimAndCharge(&writer, &allotment, layer, aux_out);
+void ReclaimAndCharge(BitWriter* JXL_RESTRICT writer,
+                      BitWriter::Allotment* JXL_RESTRICT allotment,
+                      size_t layer, AuxOut* JXL_RESTRICT aux_out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_AUX_OUT_FWD_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/arch_macros.h b/third_party/jpeg-xl/lib/jxl/base/arch_macros.h
new file mode 100644
index 000000000000..3adae2c835a3
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/arch_macros.h
@@ -0,0 +1,42 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_ARCH_MACROS_H_
+#define LIB_JXL_BASE_ARCH_MACROS_H_
+
+// Defines the JXL_ARCH_* macros.
+
+namespace jxl {
+
+#if defined(__x86_64__) || defined(_M_X64)
+#define JXL_ARCH_X64 1
+#else
+#define JXL_ARCH_X64 0
+#endif
+
+#if defined(__powerpc64__) || defined(_M_PPC)
+#define JXL_ARCH_PPC 1
+#else
+#define JXL_ARCH_PPC 0
+#endif
+
+#if defined(__aarch64__) || defined(__arm__)
+#define JXL_ARCH_ARM 1
+#else
+#define JXL_ARCH_ARM 0
+#endif
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_ARCH_MACROS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/bits.h b/third_party/jpeg-xl/lib/jxl/base/bits.h
new file mode 100644
index 000000000000..9f15daffdb56
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/bits.h
@@ -0,0 +1,156 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_BITS_H_
+#define LIB_JXL_BASE_BITS_H_
+
+// Specialized instructions for processing register-sized bit arrays.
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+
+#if JXL_COMPILER_MSVC
+#include <intrin.h>
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+namespace jxl {
+
+// Empty struct used as a size tag type.
+template <size_t N>
+struct SizeTag {};
+
+template <typename T>
+constexpr bool IsSigned() {
+  return T(0) > T(-1);
+}
+
+// Undefined results for x == 0.
+static JXL_INLINE JXL_MAYBE_UNUSED size_t
+Num0BitsAboveMS1Bit_Nonzero(SizeTag<4> /* tag */, const uint32_t x) {
+  JXL_DASSERT(x != 0);
+#if JXL_COMPILER_MSVC
+  unsigned long index;
+  _BitScanReverse(&index, x);
+  return 31 - index;
+#else
+  return static_cast<size_t>(__builtin_clz(x));
+#endif
+}
+static JXL_INLINE JXL_MAYBE_UNUSED size_t
+Num0BitsAboveMS1Bit_Nonzero(SizeTag<8> /* tag */, const uint64_t x) {
+  JXL_DASSERT(x != 0);
+#if JXL_COMPILER_MSVC
+#if JXL_ARCH_X64
+  unsigned long index;
+  _BitScanReverse64(&index, x);
+  return 63 - index;
+#else   // JXL_ARCH_X64
+  // _BitScanReverse64 not available
+  uint32_t msb = static_cast<uint32_t>(x >> 32u);
+  unsigned long index;
+  if (msb == 0) {
+    uint32_t lsb = static_cast<uint32_t>(x & 0xFFFFFFFF);
+    _BitScanReverse(&index, lsb);
+    return 63 - index;
+  } else {
+    _BitScanReverse(&index, msb);
+    return 31 - index;
+  }
+#endif  // JXL_ARCH_X64
+#else
+  return static_cast<size_t>(__builtin_clzll(x));
+#endif
+}
+template <typename T>
+static JXL_INLINE JXL_MAYBE_UNUSED size_t
+Num0BitsAboveMS1Bit_Nonzero(const T x) {
+  static_assert(!IsSigned<T>(), "Num0BitsAboveMS1Bit_Nonzero: use unsigned");
+  return Num0BitsAboveMS1Bit_Nonzero(SizeTag<sizeof(T)>(), x);
+}
+
+// Undefined results for x == 0.
+static JXL_INLINE JXL_MAYBE_UNUSED size_t
+Num0BitsBelowLS1Bit_Nonzero(SizeTag<4> /* tag */, const uint32_t x) {
+  JXL_DASSERT(x != 0);
+#if JXL_COMPILER_MSVC
+  unsigned long index;
+  _BitScanForward(&index, x);
+  return index;
+#else
+  return static_cast<size_t>(__builtin_ctz(x));
+#endif
+}
+static JXL_INLINE JXL_MAYBE_UNUSED size_t
+Num0BitsBelowLS1Bit_Nonzero(SizeTag<8> /* tag */, const uint64_t x) {
+  JXL_DASSERT(x != 0);
+#if JXL_COMPILER_MSVC
+#if JXL_ARCH_X64
+  unsigned long index;
+  _BitScanForward64(&index, x);
+  return index;
+#else   // JXL_ARCH_X64
+  // _BitScanForward64 not available
+  uint32_t lsb = static_cast<uint32_t>(x & 0xFFFFFFFF);
+  unsigned long index;
+  if (lsb == 0) {
+    uint32_t msb = static_cast<uint32_t>(x >> 32u);
+    _BitScanForward(&index, msb);
+    return 32 + index;
+  } else {
+    _BitScanForward(&index, lsb);
+    return index;
+  }
+#endif  // JXL_ARCH_X64
+#else
+  return static_cast<size_t>(__builtin_ctzll(x));
+#endif
+}
+template <typename T>
+static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsBelowLS1Bit_Nonzero(T x) {
+  static_assert(!IsSigned<T>(), "Num0BitsBelowLS1Bit_Nonzero: use unsigned");
+  return Num0BitsBelowLS1Bit_Nonzero(SizeTag<sizeof(T)>(), x);
+}
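Illustrative expected values for the bit-scan helpers above (assert-style sketch, not part of the patch):

    assert(Num0BitsAboveMS1Bit_Nonzero(uint32_t{0x10}) == 27);  // bit 4 set:
                                                                // 27 zeros above
    assert(Num0BitsBelowLS1Bit_Nonzero(uint32_t{0x18}) == 3);   // lowest set
                                                                // bit is bit 3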
+
+// Returns bit width for x == 0.
+template <typename T>
+static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsAboveMS1Bit(const T x) {
+  return (x == 0) ? sizeof(T) * 8 : Num0BitsAboveMS1Bit_Nonzero(x);
+}
+
+// Returns bit width for x == 0.
+template <typename T>
+static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsBelowLS1Bit(const T x) {
+  return (x == 0) ? sizeof(T) * 8 : Num0BitsBelowLS1Bit_Nonzero(x);
+}
+
+// Returns base-2 logarithm, rounded down.
+template <typename T>
+static JXL_INLINE JXL_MAYBE_UNUSED size_t FloorLog2Nonzero(const T x) {
+  return (sizeof(T) * 8 - 1) ^ Num0BitsAboveMS1Bit_Nonzero(x);
+}
+
+// Returns base-2 logarithm, rounded up.
+template <typename T>
+static JXL_INLINE JXL_MAYBE_UNUSED size_t CeilLog2Nonzero(const T x) {
+  const size_t floor_log2 = FloorLog2Nonzero(x);
+  if ((x & (x - 1)) == 0) return floor_log2;  // power of two
+  return floor_log2 + 1;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_BITS_H_
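And a couple of illustrative data points for the logarithm helpers (sketch, not part of the patch):

    assert(FloorLog2Nonzero(12u) == 3);  // 31 ^ clz(12) = 31 ^ 28 = 3
    assert(CeilLog2Nonzero(12u) == 4);   // 12 is not a power of two
    assert(CeilLog2Nonzero(16u) == 4);   // powers of two round to themselves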
diff --git a/third_party/jpeg-xl/lib/jxl/base/byte_order.h b/third_party/jpeg-xl/lib/jxl/base/byte_order.h
new file mode 100644
index 000000000000..5c5b9969a547
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/byte_order.h
@@ -0,0 +1,292 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_BYTE_ORDER_H_
+#define LIB_JXL_BASE_BYTE_ORDER_H_
+
+#include <stdint.h>
+#include <string.h>  // memcpy
+
+#include "lib/jxl/base/compiler_specific.h"
+
+#if JXL_COMPILER_MSVC
+#include <intrin.h>  // _byteswap_*
+#endif
+
+#if (defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))
+#define JXL_BYTE_ORDER_LITTLE 1
+#else
+// This means that we don't know that the byte order is little endian, in
+// this case we use endian-neutral code that works for both little- and
+// big-endian.
+#define JXL_BYTE_ORDER_LITTLE 0
+#endif
+
+// Returns whether the system is little-endian (least-significant byte first).
+#if JXL_BYTE_ORDER_LITTLE
+static constexpr bool IsLittleEndian() { return true; }
+#else
+static inline bool IsLittleEndian() {
+  const uint32_t multibyte = 1;
+  uint8_t byte;
+  memcpy(&byte, &multibyte, 1);
+  return byte == 1;
+}
+#endif
+
+#if JXL_COMPILER_MSVC
+#define JXL_BSWAP32(x) _byteswap_ulong(x)
+#define JXL_BSWAP64(x) _byteswap_uint64(x)
+#else
+#define JXL_BSWAP32(x) __builtin_bswap32(x)
+#define JXL_BSWAP64(x) __builtin_bswap64(x)
+#endif
+
+static JXL_INLINE uint32_t LoadBE16(const uint8_t* p) {
+  const uint32_t byte1 = p[0];
+  const uint32_t byte0 = p[1];
+  return (byte1 << 8) | byte0;
+}
+
+static JXL_INLINE uint32_t LoadLE16(const uint8_t* p) {
+  const uint32_t byte0 = p[0];
+  const uint32_t byte1 = p[1];
+  return (byte1 << 8) | byte0;
+}
+
+static JXL_INLINE uint32_t LoadBE24(const uint8_t* p) {
+  const uint32_t byte2 = p[0];
+  const uint32_t byte1 = p[1];
+  const uint32_t byte0 = p[2];
+  return (byte2 << 16) | (byte1 << 8) | byte0;
+}
+
+static JXL_INLINE uint32_t LoadLE24(const uint8_t* p) {
+  const uint32_t byte0 = p[0];
+  const uint32_t byte1 = p[1];
+  const uint32_t byte2 = p[2];
+  return (byte2 << 16) | (byte1 << 8) | byte0;
+}
+
+static JXL_INLINE uint32_t LoadBE32(const uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  uint32_t big;
+  memcpy(&big, p, 4);
+  return JXL_BSWAP32(big);
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  const uint32_t byte3 = p[0];
+  const uint32_t byte2 = p[1];
+  const uint32_t byte1 = p[2];
+  const uint32_t byte0 = p[3];
+  return (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0;
+#endif
+}
+
+static JXL_INLINE uint64_t LoadBE64(const uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  uint64_t big;
+  memcpy(&big, p, 8);
+  return JXL_BSWAP64(big);
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  const uint64_t byte7 = p[0];
+  const uint64_t byte6 = p[1];
+  const uint64_t byte5 = p[2];
+  const uint64_t byte4 = p[3];
+  const uint64_t byte3 = p[4];
+  const uint64_t byte2 = p[5];
+  const uint64_t byte1 = p[6];
+  const uint64_t byte0 = p[7];
+  return (byte7 << 56ull) | (byte6 << 48ull) | (byte5 << 40ull) |
+         (byte4 << 32ull) | (byte3 << 24ull) | (byte2 << 16ull) |
+         (byte1 << 8ull) | byte0;
+#endif
+}
+
+static JXL_INLINE uint32_t LoadLE32(const uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  uint32_t little;
+  memcpy(&little, p, 4);
+  return little;
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  const uint32_t byte0 = p[0];
+  const uint32_t byte1 = p[1];
+  const uint32_t byte2 = p[2];
+  const uint32_t byte3 = p[3];
+  return (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0;
+#endif
+}
+
+static JXL_INLINE uint64_t LoadLE64(const uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  uint64_t little;
+  memcpy(&little, p, 8);
+  return little;
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  const uint64_t byte0 = p[0];
+  const uint64_t byte1 = p[1];
+  const uint64_t byte2 = p[2];
+  const uint64_t byte3 = p[3];
+  const uint64_t byte4 = p[4];
+  const uint64_t byte5 = p[5];
+  const uint64_t byte6 = p[6];
+  const uint64_t byte7 = p[7];
+  return (byte7 << 56) | (byte6 << 48) | (byte5 << 40) | (byte4 << 32) |
+         (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0;
+#endif
+}
+
+static JXL_INLINE void StoreBE16(const uint32_t native, uint8_t* p) {
+  p[0] = (native >> 8) & 0xFF;
+  p[1] = native & 0xFF;
+}
+
+static JXL_INLINE void StoreLE16(const uint32_t native, uint8_t* p) {
+  p[1] = (native >> 8) & 0xFF;
+  p[0] = native & 0xFF;
+}
+
+static JXL_INLINE void StoreBE24(const uint32_t native, uint8_t* p) {
+  p[0] = (native >> 16) & 0xFF;
+  p[1] = (native >> 8) & 0xFF;
+  p[2] = native & 0xFF;
+}
+
+static JXL_INLINE void StoreLE24(const uint32_t native, uint8_t* p) {
+  p[2] = (native >> 16) & 0xFF;
+  p[1] = (native >> 8) & 0xFF;
+  p[0] = native & 0xFF;
+}
+
+static JXL_INLINE void StoreBE32(const uint32_t native, uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  const uint32_t big = JXL_BSWAP32(native);
+  memcpy(p, &big, 4);
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  p[0] = native >> 24;
+  p[1] = (native >> 16) & 0xFF;
+  p[2] = (native >> 8) & 0xFF;
+  p[3] = native & 0xFF;
+#endif
+}
+
+static JXL_INLINE void StoreBE64(const uint64_t native, uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  const uint64_t big = JXL_BSWAP64(native);
+  memcpy(p, &big, 8);
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  p[0] = native >> 56ull;
+  p[1] = (native >> 48ull) & 0xFF;
+  p[2] = (native >> 40ull) & 0xFF;
+  p[3] = (native >> 32ull) & 0xFF;
+  p[4] = (native >> 24ull) & 0xFF;
+  p[5] = (native >> 16ull) & 0xFF;
+  p[6] = (native >> 8ull) & 0xFF;
+  p[7] = native & 0xFF;
+#endif
+}
+
+static JXL_INLINE void StoreLE32(const uint32_t native, uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  const uint32_t little = native;
+  memcpy(p, &little, 4);
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  p[3] = native >> 24;
+  p[2] = (native >> 16) & 0xFF;
+  p[1] = (native >> 8) & 0xFF;
+  p[0] = native & 0xFF;
+#endif
+}
+
+static JXL_INLINE void StoreLE64(const uint64_t native, uint8_t* p) {
+#if JXL_BYTE_ORDER_LITTLE
+  const uint64_t little = native;
+  memcpy(p, &little, 8);
+#else
+  // Byte-order-independent - can't assume this machine is big endian.
+  p[7] = native >> 56;
+  p[6] = (native >> 48) & 0xFF;
+  p[5] = (native >> 40) & 0xFF;
+  p[4] = (native >> 32) & 0xFF;
+  p[3] = (native >> 24) & 0xFF;
+  p[2] = (native >> 16) & 0xFF;
+  p[1] = (native >> 8) & 0xFF;
+  p[0] = native & 0xFF;
+#endif
+}
+
+// Big/Little Endian order.
+struct OrderBE {};
+struct OrderLE {};
+
+// Wrappers for calling from generic code.
+static JXL_INLINE void Store16(OrderBE /*tag*/, const uint32_t native,
+                               uint8_t* p) {
+  return StoreBE16(native, p);
+}
+
+static JXL_INLINE void Store16(OrderLE /*tag*/, const uint32_t native,
+                               uint8_t* p) {
+  return StoreLE16(native, p);
+}
+
+static JXL_INLINE void Store24(OrderBE /*tag*/, const uint32_t native,
+                               uint8_t* p) {
+  return StoreBE24(native, p);
+}
+
+static JXL_INLINE void Store24(OrderLE /*tag*/, const uint32_t native,
+                               uint8_t* p) {
+  return StoreLE24(native, p);
+}
+static JXL_INLINE void Store32(OrderBE /*tag*/, const uint32_t native,
+                               uint8_t* p) {
+  return StoreBE32(native, p);
+}
+
+static JXL_INLINE void Store32(OrderLE /*tag*/, const uint32_t native,
+                               uint8_t* p) {
+  return StoreLE32(native, p);
+}
+
+static JXL_INLINE uint32_t Load16(OrderBE /*tag*/, const uint8_t* p) {
+  return LoadBE16(p);
+}
+
+static JXL_INLINE uint32_t Load16(OrderLE /*tag*/, const uint8_t* p) {
+  return LoadLE16(p);
+}
+
+static JXL_INLINE uint32_t Load24(OrderBE /*tag*/, const uint8_t* p) {
+  return LoadBE24(p);
+}
+
+static JXL_INLINE uint32_t Load24(OrderLE /*tag*/, const uint8_t* p) {
+  return LoadLE24(p);
+}
+static JXL_INLINE uint32_t Load32(OrderBE /*tag*/, const uint8_t* p) {
+  return LoadBE32(p);
+}
+
+static JXL_INLINE uint32_t Load32(OrderLE /*tag*/, const uint8_t* p) {
+  return LoadLE32(p);
+}
+
+#endif  // LIB_JXL_BASE_BYTE_ORDER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/cache_aligned.cc b/third_party/jpeg-xl/lib/jxl/base/cache_aligned.cc
new file mode 100644
index 000000000000..a16d759cba50
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/cache_aligned.cc
@@ -0,0 +1,163 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/cache_aligned.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+// Disabled: slower than malloc + alignment.
+#define JXL_USE_MMAP 0
+
+#if JXL_USE_MMAP
+#include <sys/mman.h>
+#endif
+
+#include <algorithm>  // std::max
+#include <atomic>
+#include <hwy/base.h>  // kMaxVectorSize
+#include <limits>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+namespace {
+
+#pragma pack(push, 1)
+struct AllocationHeader {
+  void* allocated;
+  size_t allocated_size;
+  uint8_t left_padding[hwy::kMaxVectorSize];
+};
+#pragma pack(pop)
+
+std::atomic<uint64_t> num_allocations{0};
+std::atomic<uint64_t> bytes_in_use{0};
+std::atomic<uint64_t> max_bytes_in_use{0};
+
+}  // namespace
+
+// Avoids linker errors in pre-C++17 builds.
+constexpr size_t CacheAligned::kPointerSize;
+constexpr size_t CacheAligned::kCacheLineSize;
+constexpr size_t CacheAligned::kAlignment;
+constexpr size_t CacheAligned::kAlias;
+
+void CacheAligned::PrintStats() {
+  printf("Allocations: %zu (max bytes in use: %E)\n",
+         size_t(num_allocations.load(std::memory_order_relaxed)),
+         double(max_bytes_in_use.load(std::memory_order_relaxed)));
+}
+
+size_t CacheAligned::NextOffset() {
+  static std::atomic<uint32_t> next{0};
+  constexpr uint32_t kGroups = CacheAligned::kAlias / CacheAligned::kAlignment;
+  const uint32_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups;
+  return CacheAligned::kAlignment * group;
+}
+
+void* CacheAligned::Allocate(const size_t payload_size, size_t offset) {
+  JXL_ASSERT(payload_size <= std::numeric_limits<size_t>::max() / 2);
+  JXL_ASSERT((offset % kAlignment == 0) && offset <= kAlias);
+
+  // What: | misalign | unused | AllocationHeader |payload
+  // Size: |<= kAlias | offset |                  |payload_size
+  //       ^allocated.^aligned.^header............^payload
+  // The header must immediately precede payload, which must remain aligned.
+  // To avoid wasting space, the header resides at the end of `unused`,
+  // which therefore cannot be empty (offset == 0).
+  if (offset == 0) {
+    offset = kAlignment;  // = round_up(sizeof(AllocationHeader), kAlignment)
+    static_assert(sizeof(AllocationHeader) <= kAlignment, "Else: round up");
+  }
+
+#if JXL_USE_MMAP
+  const size_t allocated_size = offset + payload_size;
+  const int flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE;
+  void* allocated =
+      mmap(nullptr, allocated_size, PROT_READ | PROT_WRITE, flags, -1, 0);
+  if (allocated == MAP_FAILED) return nullptr;
+  const uintptr_t aligned = reinterpret_cast<uintptr_t>(allocated);
+#else
+  const size_t allocated_size = kAlias + offset + payload_size;
+  void* allocated = malloc(allocated_size);
+  if (allocated == nullptr) return nullptr;
+  // Always round up even if already aligned - we already asked for kAlias
+  // extra bytes and there's no way to give them back.
+  uintptr_t aligned = reinterpret_cast<uintptr_t>(allocated) + kAlias;
+  static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2");
+  static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias");
+  aligned &= ~(kAlias - 1);
+#endif
+
+#if 0
+  // No effect.
+  uintptr_t page_aligned = reinterpret_cast<uintptr_t>(allocated);
+  page_aligned &= ~(4096 - 1);
+  if (madvise(reinterpret_cast<void*>(page_aligned), allocated_size,
+              MADV_WILLNEED) != 0) {
+    JXL_NOTIFY_ERROR("madvise failed");
+  }
+#elif 0
+  // INCREASES both first and subsequent decode times.
+  if (mlock(allocated, allocated_size) != 0) {
+    JXL_NOTIFY_ERROR("mlock failed");
+  }
+#endif
+
+  // Update statistics (#allocations and max bytes in use)
+  num_allocations.fetch_add(1, std::memory_order_relaxed);
+  const uint64_t prev_bytes =
+      bytes_in_use.fetch_add(allocated_size, std::memory_order_acq_rel);
+  uint64_t expected_max = max_bytes_in_use.load(std::memory_order_acquire);
+  for (;;) {
+    const uint64_t desired =
+        std::max(expected_max, prev_bytes + allocated_size);
+    if (max_bytes_in_use.compare_exchange_strong(expected_max, desired,
+                                                 std::memory_order_acq_rel)) {
+      break;
+    }
+  }
+
+  const uintptr_t payload = aligned + offset;  // still aligned
+
+  // Stash `allocated` and payload_size inside header for use by Free().
+  AllocationHeader* header = reinterpret_cast<AllocationHeader*>(payload) - 1;
+  header->allocated = allocated;
+  header->allocated_size = allocated_size;
+
+  return JXL_ASSUME_ALIGNED(reinterpret_cast<void*>(payload), 64);
+}
+
+void CacheAligned::Free(const void* aligned_pointer) {
+  if (aligned_pointer == nullptr) {
+    return;
+  }
+  const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer);
+  JXL_ASSERT(payload % kAlignment == 0);
+  const AllocationHeader* header =
+      reinterpret_cast<const AllocationHeader*>(payload) - 1;
+
+  // Subtract (2's complement negation).
+  bytes_in_use.fetch_add(~header->allocated_size + 1,
+                         std::memory_order_acq_rel);
+
+#if JXL_USE_MMAP
+  munmap(header->allocated, header->allocated_size);
+#else
+  free(header->allocated);
+#endif
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/base/cache_aligned.h b/third_party/jpeg-xl/lib/jxl/base/cache_aligned.h
new file mode 100644
index 000000000000..fe70c21fbea0
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/cache_aligned.h
@@ -0,0 +1,83 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_CACHE_ALIGNED_H_
+#define LIB_JXL_BASE_CACHE_ALIGNED_H_
+
+// Memory allocator with support for alignment + misalignment.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <memory>
+
+#include "lib/jxl/base/compiler_specific.h"
+
+namespace jxl {
+
+// Functions that depend on the cache line size.
+class CacheAligned {
+ public:
+  static void PrintStats();
+
+  static constexpr size_t kPointerSize = sizeof(void*);
+  static constexpr size_t kCacheLineSize = 64;
+  // To avoid RFOs, match L2 fill size (pairs of lines).
+  static constexpr size_t kAlignment = 2 * kCacheLineSize;
+  // Minimum multiple for which cache set conflicts and/or loads blocked by
+  // preceding stores can occur.
+  static constexpr size_t kAlias = 2048;
+
+  // Returns a 'random' (cyclical) offset suitable for Allocate.
+  static size_t NextOffset();
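How NextOffset() above and the Allocate()/Free() pair declared just below fit together, as an illustrative sketch (not patch content; the 128/2048 figures follow from the constants above):

  // void* p = jxl::CacheAligned::Allocate(bytes, jxl::CacheAligned::NextOffset());
  // if (p != nullptr) {
  //   // p is aligned to kAlignment (128) and congruent to the chosen
  //   // offset mod kAlias (2048), staggering hot buffers across cache sets.
  //   jxl::CacheAligned::Free(p);
  // }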
+  // Returns null or memory whose address is congruent to `offset` (mod
+  // kAlias). This reduces cache conflicts and load/store stalls, especially
+  // with large allocations that would otherwise have similar alignments. At
+  // least `payload_size` (which can be zero) bytes will be accessible.
+  static void* Allocate(size_t payload_size, size_t offset);
+
+  static void* Allocate(const size_t payload_size) {
+    return Allocate(payload_size, NextOffset());
+  }
+
+  static void Free(const void* aligned_pointer);
+};
+
+// Avoids the need for a function pointer (deleter) in CacheAlignedUniquePtr.
+struct CacheAlignedDeleter {
+  void operator()(uint8_t* aligned_pointer) const {
+    return CacheAligned::Free(aligned_pointer);
+  }
+};
+
+using CacheAlignedUniquePtr = std::unique_ptr<uint8_t[], CacheAlignedDeleter>;
+
+// Does not invoke constructors.
+static inline CacheAlignedUniquePtr AllocateArray(const size_t bytes) {
+  return CacheAlignedUniquePtr(
+      static_cast<uint8_t*>(CacheAligned::Allocate(bytes)),
+      CacheAlignedDeleter());
+}
+
+static inline CacheAlignedUniquePtr AllocateArray(const size_t bytes,
+                                                  const size_t offset) {
+  return CacheAlignedUniquePtr(
+      static_cast<uint8_t*>(CacheAligned::Allocate(bytes, offset)),
+      CacheAlignedDeleter());
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_CACHE_ALIGNED_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/compiler_specific.h b/third_party/jpeg-xl/lib/jxl/base/compiler_specific.h
new file mode 100644
index 000000000000..135de43de9e3
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/compiler_specific.h
@@ -0,0 +1,152 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_COMPILER_SPECIFIC_H_
+#define LIB_JXL_BASE_COMPILER_SPECIFIC_H_
+
+// Macros for compiler version + nonstandard keywords, e.g. __builtin_expect.
+
+#include <stdint.h>
+
+// #if is shorter and safer than #ifdef. *_VERSION are zero if not detected,
+// otherwise 100 * major + minor version. Note that other packages check for
+// #ifdef COMPILER_MSVC, so we cannot use that same name.
+
+#ifdef _MSC_VER
+#define JXL_COMPILER_MSVC _MSC_VER
+#else
+#define JXL_COMPILER_MSVC 0
+#endif
+
+#ifdef __GNUC__
+#define JXL_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__)
+#else
+#define JXL_COMPILER_GCC 0
+#endif
+
+#ifdef __clang__
+#define JXL_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__)
+// Clang pretends to be GCC for compatibility.
+#undef JXL_COMPILER_GCC
+#define JXL_COMPILER_GCC 0
+#else
+#define JXL_COMPILER_CLANG 0
+#endif
+
+#if JXL_COMPILER_MSVC
+#define JXL_RESTRICT __restrict
+#elif JXL_COMPILER_GCC || JXL_COMPILER_CLANG
+#define JXL_RESTRICT __restrict__
+#else
+#define JXL_RESTRICT
+#endif
+
+#if JXL_COMPILER_MSVC
+#define JXL_INLINE __forceinline
+#define JXL_NOINLINE __declspec(noinline)
+#else
+#define JXL_INLINE inline __attribute__((always_inline))
+#define JXL_NOINLINE __attribute__((noinline))
+#endif
+
+#if JXL_COMPILER_MSVC
+#define JXL_NORETURN __declspec(noreturn)
+#elif JXL_COMPILER_GCC || JXL_COMPILER_CLANG
+#define JXL_NORETURN __attribute__((noreturn))
+#endif
+
+#if JXL_COMPILER_MSVC
+#define JXL_UNREACHABLE __assume(false)
+#elif JXL_COMPILER_CLANG || JXL_COMPILER_GCC >= 405
+#define JXL_UNREACHABLE __builtin_unreachable()
+#else
+#define JXL_UNREACHABLE
+#endif
+
+#if JXL_COMPILER_MSVC
+#define JXL_MAYBE_UNUSED
+#else
+// Encountered "attribute list cannot appear here" when using the C++17
+// [[maybe_unused]], so only use the old style attribute for now.
+#define JXL_MAYBE_UNUSED __attribute__((unused))
+#endif
+
+#if JXL_COMPILER_MSVC
+// Unsupported, __assume is not the same.
+#define JXL_LIKELY(expr) expr
+#define JXL_UNLIKELY(expr) expr
+#else
+#define JXL_LIKELY(expr) __builtin_expect(!!(expr), 1)
+#define JXL_UNLIKELY(expr) __builtin_expect(!!(expr), 0)
+#endif
+
+#if JXL_COMPILER_MSVC
+#include <intrin.h>
+
+#pragma intrinsic(_ReadWriteBarrier)
+#define JXL_COMPILER_FENCE _ReadWriteBarrier()
+#elif JXL_COMPILER_GCC || JXL_COMPILER_CLANG
+#define JXL_COMPILER_FENCE asm volatile("" : : : "memory")
+#else
+#define JXL_COMPILER_FENCE
+#endif
+
+// Returns a void* pointer which the compiler then assumes is N-byte aligned.
+// Example: float* JXL_RESTRICT aligned = (float*)JXL_ASSUME_ALIGNED(in, 32);
+//
+// The assignment semantics are required by GCC/Clang. ICC provides an in-place
+// __assume_aligned, whereas MSVC's __assume appears unsuitable.
+#if JXL_COMPILER_CLANG
+// Early versions of Clang did not support __builtin_assume_aligned.
+#define JXL_HAS_ASSUME_ALIGNED __has_builtin(__builtin_assume_aligned)
+#elif JXL_COMPILER_GCC
+#define JXL_HAS_ASSUME_ALIGNED 1
+#else
+#define JXL_HAS_ASSUME_ALIGNED 0
+#endif
+
+#if JXL_HAS_ASSUME_ALIGNED
+#define JXL_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align))
+#else
+#define JXL_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */
+#endif
+
+#ifdef __has_attribute
+#define JXL_HAVE_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define JXL_HAVE_ATTRIBUTE(x) 0
+#endif
+
+// Raises warnings if the function return value is unused. Should appear as
+// the first part of a function definition/declaration.
+#if JXL_HAVE_ATTRIBUTE(nodiscard)
+#define JXL_MUST_USE_RESULT [[nodiscard]]
+#elif JXL_COMPILER_CLANG && JXL_HAVE_ATTRIBUTE(warn_unused_result)
+#define JXL_MUST_USE_RESULT __attribute__((warn_unused_result))
+#else
+#define JXL_MUST_USE_RESULT
+#endif
+
+#if JXL_HAVE_ATTRIBUTE(__format__)
+#define JXL_FORMAT(idx_fmt, idx_arg) \
+  __attribute__((__format__(__printf__, idx_fmt, idx_arg)))
+#else
+#define JXL_FORMAT(idx_fmt, idx_arg)
+#endif
+
+#if JXL_COMPILER_MSVC
+using ssize_t = intptr_t;
+#endif
+
+#endif  // LIB_JXL_BASE_COMPILER_SPECIFIC_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/data_parallel.cc b/third_party/jpeg-xl/lib/jxl/base/data_parallel.cc
new file mode 100644
index 000000000000..fb1692df208a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/data_parallel.cc
@@ -0,0 +1,52 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/data_parallel.h"
+
+#define DATA_PARALLEL_TRACE 0
+
+#if DATA_PARALLEL_TRACE
+#include <stdio.h>
+
+#include "lib/jxl/base/time.h"
+#endif  // DATA_PARALLEL_TRACE
+
+namespace jxl {
+
+// static
+JxlParallelRetCode ThreadPool::SequentialRunnerStatic(
+    void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
+    JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) {
+  JxlParallelRetCode init_ret = (*init)(jpegxl_opaque, 1);
+  if (init_ret != 0) return init_ret;
+
+  for (uint32_t i = start_range; i < end_range; i++) {
+    (*func)(jpegxl_opaque, i, 0);
+  }
+  return 0;
+}
+
+#if DATA_PARALLEL_TRACE
+void TraceRunBegin(const char* /*caller*/, double* t0) { *t0 = Now(); }
+
+void TraceRunEnd(const char* caller, double t0) {
+  const double elapsed = Now() - t0;
+  fprintf(stderr, "%27s: %5.1f ms\n", caller, elapsed * 1E3);
+}
+#else
+void TraceRunBegin(const char* /*caller*/, double* /*t0*/) {}
+void TraceRunEnd(const char* /*caller*/, double /*t0*/) {}
+#endif
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/base/data_parallel.h b/third_party/jpeg-xl/lib/jxl/base/data_parallel.h
new file mode 100644
index 000000000000..9a47a59fdaf9
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/data_parallel.h
@@ -0,0 +1,170 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_DATA_PARALLEL_H_
+#define LIB_JXL_BASE_DATA_PARALLEL_H_
+
+// Portable, low-overhead C++11 ThreadPool alternative to OpenMP for
+// data-parallel computations.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "jxl/parallel_runner.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+class ThreadPool {
+ public:
+  // Use this type as an InitFunc to skip the initialization step in Run().
+  // When this is used the return value of Run() is always true and does not
+  // need to be checked.
+  struct SkipInit {};
+
+  ThreadPool(JxlParallelRunner runner, void* runner_opaque)
+      : runner_(runner ? runner : &ThreadPool::SequentialRunnerStatic),
+        runner_opaque_(runner ? runner_opaque : static_cast<void*>(this)) {}
+
+  ThreadPool(const ThreadPool&) = delete;
+  ThreadPool& operator=(const ThreadPool&) = delete;
+
+  // Runs init_func(num_threads) followed by data_func(task, thread) on worker
+  // thread(s) for every task in [begin, end). init_func() must return a Status
+  // indicating whether the initialization succeeded.
+  // "thread" is an integer smaller than num_threads.
+  // Not thread-safe - no two calls to Run may overlap.
+  // Subsequent calls will reuse the same threads.
+  //
+  // Precondition: begin <= end.
+  template <class InitFunc, class DataFunc>
+  Status Run(uint32_t begin, uint32_t end, const InitFunc& init_func,
+             const DataFunc& data_func, const char* caller = "") {
+    JXL_ASSERT(begin <= end);
+    if (begin == end) return true;
+    RunCallState<InitFunc, DataFunc> call_state(init_func, data_func);
+    // The runner_ uses the C convention and returns 0 on success, so we
+    // convert that into a Status.
+    return (*runner_)(runner_opaque_, static_cast<void*>(&call_state),
+                      &call_state.CallInitFunc, &call_state.CallDataFunc, begin,
+                      end) == 0;
+  }
+
+  // Specialization that returns bool when SkipInit is used.
+  template <class DataFunc>
+  bool Run(uint32_t begin, uint32_t end, const SkipInit /* tag */,
+           const DataFunc& data_func, const char* caller = "") {
+    return Run(begin, end, ReturnTrueInit, data_func, caller);
+  }
+
+ private:
+  static Status ReturnTrueInit(size_t num_threads) { return true; }
+
+  // Class holding the state of a Run() call, passed to the runner_ as an
+  // opaque_jpegxl pointer.
+  template <class InitFunc, class DataFunc>
+  class RunCallState final {
+   public:
+    RunCallState(const InitFunc& init_func, const DataFunc& data_func)
+        : init_func_(init_func), data_func_(data_func) {}
+
+    // JxlParallelRunInit interface.
+    static int CallInitFunc(void* jpegxl_opaque, size_t num_threads) {
+      const auto* self = static_cast<RunCallState*>(jpegxl_opaque);
+      // Returns -1 when the internal init function returns a false Status to
+      // indicate an error.
+      return self->init_func_(num_threads) ? 0 : -1;
+    }
+
+    // JxlParallelRunFunction interface.
+    static void CallDataFunc(void* jpegxl_opaque, uint32_t value,
+                             size_t thread_id) {
+      const auto* self = static_cast<RunCallState*>(jpegxl_opaque);
+      return self->data_func_(value, thread_id);
+    }
+
+   private:
+    const InitFunc& init_func_;
+    const DataFunc& data_func_;
+  };
+
+  // Default JxlParallelRunner used when no runner is provided by the
+  // caller. This runner doesn't use any threading and thread_id is always 0.
+  static JxlParallelRetCode SequentialRunnerStatic(
+      void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
+      JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range);
+
+  // The caller-supplied runner function and its opaque void*.
+  const JxlParallelRunner runner_;
+  void* const runner_opaque_;
+};
+
+void TraceRunBegin(const char* caller, double* t0);
+void TraceRunEnd(const char* caller, double t0);
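To make the calling convention concrete, here is a minimal usage sketch (illustrative, not part of the patch) of RunOnPool, declared just below. One lambda runs once per pool with the thread count, the other once per task; passing a null pool falls back to the sequential runner:

#include <cstdint>
#include <vector>

#include "lib/jxl/base/data_parallel.h"

// Fills one row per task; `thread` can index per-thread scratch space.
bool FillRows(std::vector<float>* rows, uint32_t num_rows,
              jxl::ThreadPool* pool) {
  const auto init = [](size_t num_threads) -> jxl::Status { return true; };
  const auto per_task = [rows](uint32_t task, size_t thread) {
    (*rows)[task] = static_cast<float>(task);
  };
  return jxl::RunOnPool(pool, 0, num_rows, init, per_task, "FillRows");
}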
+
+// TODO(deymo): Convert the return value to a Status when not using SkipInit.
+template <class InitFunc, class DataFunc>
+bool RunOnPool(ThreadPool* pool, const uint32_t begin, const uint32_t end,
+               const InitFunc& init_func, const DataFunc& data_func,
+               const char* caller) {
+  Status ret = true;
+  double t0;
+  TraceRunBegin(caller, &t0);
+  if (pool == nullptr) {
+    ThreadPool default_pool(nullptr, nullptr);
+    ret = default_pool.Run(begin, end, init_func, data_func, caller);
+  } else {
+    ret = pool->Run(begin, end, init_func, data_func, caller);
+  }
+  TraceRunEnd(caller, t0);
+  return ret;
+}
+
+// Accelerates multiple unsigned 32-bit divisions with the same divisor by
+// precomputing a multiplier. This is useful for splitting a contiguous range
+// of indices (the task index) into 2D indices. Exhaustively tested on
+// dividends up to 4M with non-power-of-two divisors up to 2K.
+class Divider {
+ public:
+  // "d" is the divisor (what to divide by).
+  explicit Divider(const uint32_t d) : shift_(FloorLog2Nonzero(d)) {
+    // Power of two divisors (including 1) are not supported because it is
+    // more efficient to special-case them at a higher level.
+    JXL_ASSERT((d & (d - 1)) != 0);
+
+    // ceil_log2 = floor_log2 + 1 because we ruled out powers of two above.
+    const uint64_t next_pow2 = 1ULL << (shift_ + 1);
+
+    mul_ = ((next_pow2 - d) << 32) / d + 1;
+  }
+
+  // "n" is the numerator (what is being divided).
+  inline uint32_t operator()(const uint32_t n) const {
+    // Algorithm from "Division by Invariant Integers using Multiplication".
+    // Its "sh1" is hardcoded to 1 because we don't need to handle d=1.
+    const uint32_t hi = (uint64_t(mul_) * n) >> 32;
+    return (hi + ((n - hi) >> 1)) >> shift_;
+  }
+
+ private:
+  uint32_t mul_;
+  const int shift_;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_DATA_PARALLEL_H_
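A worked example of the Divider above (a sketch, not part of the patch): splitting a linear task index into (y, x) coordinates for a grid of width w, which is the 2D-index use case its comment describes. Per the assertion in the constructor, w must not be a power of two:

#include <cassert>
#include <cstdint>

#include "lib/jxl/base/data_parallel.h"  // jxl::Divider

int main() {
  const uint32_t w = 1000;      // non-power-of-two width, as required
  const jxl::Divider div_w(w);  // precomputes the magic multiplier once
  for (uint32_t task = 0; task < 4000; ++task) {
    const uint32_t y = div_w(task);   // task / w without a hardware divide
    const uint32_t x = task - y * w;  // task % w
    assert(y == task / w && x == task % w);
  }
  return 0;
}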
diff --git a/third_party/jpeg-xl/lib/jxl/base/descriptive_statistics.cc b/third_party/jpeg-xl/lib/jxl/base/descriptive_statistics.cc
new file mode 100644
index 000000000000..4414571d1700
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/descriptive_statistics.cc
@@ -0,0 +1,111 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/descriptive_statistics.h"
+
+#include <stdio.h>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+void Stats::Assimilate(const Stats& other) {
+  const int64_t total_n = n_ + other.n_;
+  if (total_n == 0) return;  // Nothing to do; prevents div by zero.
+
+  min_ = std::min(min_, other.min_);
+  max_ = std::max(max_, other.max_);
+
+  product_ *= other.product_;
+
+  const double product_n = n_ * other.n_;
+  const double n2 = n_ * n_;
+  const double other_n2 = other.n_ * other.n_;
+  // Warning: multiplying int64 can overflow here.
+  const double total_n2 = static_cast<double>(total_n) * total_n;
+  const double total_n3 = total_n2 * total_n;
+  // Precompute reciprocal for speed - used at least twice.
+  const double inv_total_n = 1.0 / total_n;
+  const double inv_total_n2 = 1.0 / total_n2;
+
+  const double delta = other.m1_ - m1_;
+  const double delta2 = delta * delta;
+  const double delta3 = delta * delta2;
+  const double delta4 = delta2 * delta2;
+
+  m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n;
+
+  const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n;
+
+  const double new_m3 =
+      m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 +
+      3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n;
+
+  m4_ += other.m4_ +
+         delta4 * product_n * (n2 - product_n + other_n2) / total_n3 +
+         6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 +
+         4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n;
+
+  m2_ = new_m2;
+  m3_ = new_m3;
+  n_ = total_n;
+}
+
+std::string Stats::ToString(int exclude) const {
+  if (Count() == 0) return std::string("(none)");
+
+  char buf[300];
+  size_t pos = 0;
+  int ret;  // snprintf - bytes written or negative for error.
+
+  if ((exclude & kNoCount) == 0) {
+    ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%6zu ",
+                   static_cast<size_t>(Count()));
+    JXL_ASSERT(ret > 0);
+    pos += ret;
+  }
+
+  if ((exclude & kNoMeanSD) == 0) {
+    ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%9.6f SD=%8.5f ", Mean(),
+                   StandardDeviation());
+    JXL_ASSERT(ret > 0);
+    pos += ret;
+  }
+
+  if ((exclude & kNoMinMax) == 0) {
+    ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%8.5f Max=%8.5f ", Min(),
+                   Max());
+    JXL_ASSERT(ret > 0);
+    pos += ret;
+  }
+
+  if ((exclude & kNoSkewKurt) == 0) {
+    ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ",
+                   Skewness(), Kurtosis());
+    JXL_ASSERT(ret > 0);
+    pos += ret;
+  }
+
+  if ((exclude & kNoGeomean) == 0) {
+    ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ",
+                   GeometricMean());
+    JXL_ASSERT(ret > 0);
+    pos += ret;
+  }
+
+  JXL_ASSERT(pos < sizeof(buf));
+  return buf;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/base/descriptive_statistics.h b/third_party/jpeg-xl/lib/jxl/base/descriptive_statistics.h
new file mode 100644
index 000000000000..9e70a2facea6
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/descriptive_statistics.h
@@ -0,0 +1,135 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_DESCRIPTIVE_STATISTICS_H_
+#define LIB_JXL_BASE_DESCRIPTIVE_STATISTICS_H_
+
+// For analyzing the range/distribution of scalars.
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+
+namespace jxl {
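A usage sketch (illustrative, not from the patch) for the Stats class defined below: feed samples with Notify(), merge per-thread accumulators with Assimilate(), then print a summary:

#include <cstdio>

#include "lib/jxl/base/descriptive_statistics.h"

int main() {
  jxl::Stats a, b;  // e.g. one accumulator per thread
  for (int i = 0; i < 100; ++i) a.Notify(static_cast<float>(i));
  for (int i = 100; i < 200; ++i) b.Notify(static_cast<float>(i));
  a.Assimilate(b);  // combines the four moments as if all samples hit `a`
  printf("%s\n", a.ToString(jxl::Stats::kNoGeomean).c_str());
  return 0;
}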
+// Descriptive statistics of a variable (4 moments).
+class Stats {
+ public:
+  void Notify(const float x) {
+    ++n_;
+
+    min_ = std::min(min_, x);
+    max_ = std::max(max_, x);
+
+    product_ *= x;
+
+    // Online moments. Reference: https://goo.gl/9ha694
+    const double d = x - m1_;
+    const double d_div_n = d / n_;
+    const double d2n1_div_n = d * (n_ - 1) * d_div_n;
+    const int64_t n_poly = n_ * n_ - 3 * n_ + 3;
+    m1_ += d_div_n;
+    m4_ += d_div_n * (d_div_n * (d2n1_div_n * n_poly + 6.0 * m2_) - 4.0 * m3_);
+    m3_ += d_div_n * (d2n1_div_n * (n_ - 2) - 3.0 * m2_);
+    m2_ += d2n1_div_n;
+  }
+
+  void Assimilate(const Stats& other);
+
+  int64_t Count() const { return n_; }
+
+  float Min() const { return min_; }
+  float Max() const { return max_; }
+
+  double GeometricMean() const {
+    return n_ == 0 ? 0.0 : std::pow(product_, 1.0 / n_);
+  }
+
+  double Mean() const { return m1_; }
+  // Same as Mu2. Assumes n_ is large.
+  double SampleVariance() const {
+    return n_ == 0 ? 0.0 : m2_ / static_cast<double>(n_);
+  }
+  // Unbiased estimator for population variance even for smaller n_.
+  double Variance() const {
+    if (n_ == 0) return 0.0;
+    if (n_ == 1) return m2_;
+    return m2_ / static_cast<double>(n_ - 1);
+  }
+  double StandardDeviation() const { return std::sqrt(Variance()); }
+  // Near zero for normal distributions; if positive on a unimodal
+  // distribution, the right tail is fatter. Assumes n_ is large.
+  double SampleSkewness() const {
+    if (std::abs(m2_) < 1E-7) return 0.0;
+    return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
+  }
+  // Corrected for bias (same as Wikipedia and Minitab but not Excel).
+  double Skewness() const {
+    if (n_ == 0) return 0.0;
+    const double biased = SampleSkewness();
+    const double r = (n_ - 1.0) / n_;
+    return biased * std::pow(r, 1.5);
+  }
+  // Near zero for normal distributions; smaller values indicate fewer/smaller
+  // outliers and larger indicates more/larger outliers. Assumes n_ is large.
+  double SampleKurtosis() const {
+    if (std::abs(m2_) < 1E-7) return 0.0;
+    return m4_ * n_ / (m2_ * m2_);
+  }
+  // Corrected for bias (same as Wikipedia and Minitab but not Excel).
+  double Kurtosis() const {
+    if (n_ == 0) return 0.0;
+    const double biased = SampleKurtosis();
+    const double r = (n_ - 1.0) / n_;
+    return biased * r * r;
+  }
+
+  // Central moments, useful for "method of moments"-based parameter
+  // estimation of a mixture of two Gaussians. Assumes Count() != 0.
+  double Mu1() const { return m1_; }
+  double Mu2() const { return m2_ / static_cast<double>(n_); }
+  double Mu3() const { return m3_ / static_cast<double>(n_); }
+  double Mu4() const { return m4_ / static_cast<double>(n_); }
+
+  // Which statistics to EXCLUDE in ToString
+  enum {
+    kNoCount = 1,
+    kNoMeanSD = 2,
+    kNoMinMax = 4,
+    kNoSkewKurt = 8,
+    kNoGeomean = 16
+  };
+
+  std::string ToString(int exclude = 0) const;
+
+ private:
+  int64_t n_ = 0;  // signed for faster conversion + safe subtraction
+
+  float min_ = 1E30f;
+  float max_ = -1E30f;
+
+  double product_ = 1.0;
+
+  // Moments
+  double m1_ = 0.0;
+  double m2_ = 0.0;
+  double m3_ = 0.0;
+  double m4_ = 0.0;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_DESCRIPTIVE_STATISTICS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/file_io.h b/third_party/jpeg-xl/lib/jxl/base/file_io.h
new file mode 100644
index 000000000000..c3529f82b6f9
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/file_io.h
@@ -0,0 +1,121 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_FILE_IO_H_
+#define LIB_JXL_BASE_FILE_IO_H_
+
+// Helper functions for reading/writing files.
+
+#include <stdio.h>
+#include <sys/stat.h>
+
+#include <string>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+// Returns extension including the dot, or empty string if none. Assumes
+// filename is not a hidden file (e.g. ".bashrc"). May be called with a
+// pathname if the filename contains a dot and/or no other path component
+// does.
+static inline std::string Extension(const std::string& filename) {
+  const size_t pos = filename.rfind('.');
+  if (pos == std::string::npos) return std::string();
+  return filename.substr(pos);
+}
+
+// RAII, ensures files are closed even when returning early.
+class FileWrapper {
+ public:
+  FileWrapper(const FileWrapper& other) = delete;
+  FileWrapper& operator=(const FileWrapper& other) = delete;
+
+  explicit FileWrapper(const std::string& pathname, const char* mode)
+      : file_(fopen(pathname.c_str(), mode)) {}
+
+  ~FileWrapper() {
+    if (file_ != nullptr) {
+      const int err = fclose(file_);
+      JXL_CHECK(err == 0);
+    }
+  }
+
+  // We intend to use FileWrapper as a replacement of FILE.
+  // NOLINTNEXTLINE(google-explicit-constructor)
+  operator FILE*() const { return file_; }
+
+ private:
+  FILE* const file_;
+};
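A quick usage sketch (illustrative, not patch content) for the ReadFile/WriteFile templates that follow; PaddedBytes works as the ContainerType, and JXL_RETURN_IF_ERROR comes from status.h:

#include <string>

#include "lib/jxl/base/file_io.h"
#include "lib/jxl/base/padded_bytes.h"

jxl::Status CopyFile(const std::string& from, const std::string& to) {
  jxl::PaddedBytes bytes;
  // ReadFile resizes `bytes` to the file size and fills it.
  JXL_RETURN_IF_ERROR(jxl::ReadFile(from, &bytes));
  return jxl::WriteFile(bytes, to);
}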
+template <class ContainerType>
+static inline Status ReadFile(const std::string& pathname,
+                              ContainerType* JXL_RESTRICT bytes) {
+  FileWrapper f(pathname, "rb");
+  if (f == nullptr) return JXL_FAILURE("Failed to open file for reading");
+
+  // Ensure it is a regular file
+#ifdef _WIN32
+  struct __stat64 s = {};
+  const int err = _stat64(pathname.c_str(), &s);
+  const bool is_file = (s.st_mode & S_IFREG) != 0;
+#else
+  struct stat s = {};
+  const int err = stat(pathname.c_str(), &s);
+  const bool is_file = S_ISREG(s.st_mode);
+#endif
+  if (err != 0) return JXL_FAILURE("Failed to obtain file status");
+  if (!is_file) return JXL_FAILURE("Not a file");
+
+  // Get size of file in bytes
+  const int64_t size = s.st_size;
+  if (size <= 0) return JXL_FAILURE("Empty or invalid file size");
+  bytes->resize(static_cast<size_t>(size));
+
+  size_t pos = 0;
+  while (pos < bytes->size()) {
+    // Needed in case ContainerType is std::string, whose data() is const.
+    char* bytes_writable = reinterpret_cast<char*>(&(*bytes)[0]);
+    const size_t bytes_read =
+        fread(bytes_writable + pos, 1, bytes->size() - pos, f);
+    if (bytes_read == 0) return JXL_FAILURE("Failed to read");
+    pos += bytes_read;
+  }
+  JXL_ASSERT(pos == bytes->size());
+  return true;
+}
+
+template <class ContainerType>
+static inline Status WriteFile(const ContainerType& bytes,
+                               const std::string& pathname) {
+  FileWrapper f(pathname, "wb");
+  if (f == nullptr) return JXL_FAILURE("Failed to open file for writing");
+
+  size_t pos = 0;
+  while (pos < bytes.size()) {
+    const size_t bytes_written =
+        fwrite(bytes.data() + pos, 1, bytes.size() - pos, f);
+    if (bytes_written == 0) return JXL_FAILURE("Failed to write");
+    pos += bytes_written;
+  }
+  JXL_ASSERT(pos == bytes.size());
+
+  return true;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_FILE_IO_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/iaca.h b/third_party/jpeg-xl/lib/jxl/base/iaca.h
new file mode 100644
index 000000000000..aad6f5d4e4d2
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/iaca.h
@@ -0,0 +1,74 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_IACA_H_
+#define LIB_JXL_BASE_IACA_H_
+
+#include "lib/jxl/base/compiler_specific.h"
+
+// IACA (Intel's Code Analyzer) analyzes instruction latencies, but only for
+// code between special markers. These functions embed such markers in an
+// executable, but only for reading via IACA - they deliberately trigger a
+// crash if executed to ensure they are removed in normal builds.
+
+#ifndef JXL_IACA_ENABLED
+#define JXL_IACA_ENABLED 0
+#endif
+
+namespace jxl {
+
+// Call before the region of interest.
+static JXL_INLINE void BeginIACA() {
+#if JXL_IACA_ENABLED && (JXL_COMPILER_GCC || JXL_COMPILER_CLANG)
+  asm volatile(
+      // UD2 "instruction" raises an invalid opcode exception.
+      ".byte 0x0F, 0x0B\n\t"
+      // Magic sequence recognized by IACA (MOV + addr32 fs:NOP). This actually
+      // clobbers EBX, but we don't care because the code won't be run, and we
+      // want IACA to observe the same code the compiler would have generated
+      // without this marker.
+      "movl $111, %%ebx\n\t"
+      ".byte 0x64, 0x67, 0x90\n\t"
+      :
+      :
+      // (Allegedly) clobbering memory may prevent reordering.
+      : "memory");
+#endif
+}
+
+// Call after the region of interest.
+static JXL_INLINE void EndIACA() {
+#if JXL_IACA_ENABLED && (JXL_COMPILER_GCC || JXL_COMPILER_CLANG)
+  asm volatile(
+      // See above.
+      "movl $222, %%ebx\n\t"
+      ".byte 0x64, 0x67, 0x90\n\t"
+      // UD2
+      ".byte 0x0F, 0x0B\n\t"
+      :
+      :
+      // (Allegedly) clobbering memory may prevent reordering.
+      : "memory");
+#endif
+}
+
+// Add to a scope to mark a region.
+struct ScopeIACA { + JXL_INLINE ScopeIACA() { BeginIACA(); } + JXL_INLINE ~ScopeIACA() { EndIACA(); } +}; + +} // namespace jxl + +#endif // LIB_JXL_BASE_IACA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/os_macros.h b/third_party/jpeg-xl/lib/jxl/base/os_macros.h new file mode 100644 index 000000000000..8dd1b535fcae --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/os_macros.h @@ -0,0 +1,50 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_BASE_OS_MACROS_H_ +#define LIB_JXL_BASE_OS_MACROS_H_ + +// Defines the JXL_OS_* macros. + +#if defined(_WIN32) || defined(_WIN64) +#define JXL_OS_WIN 1 +#else +#define JXL_OS_WIN 0 +#endif + +#ifdef __linux__ +#define JXL_OS_LINUX 1 +#else +#define JXL_OS_LINUX 0 +#endif + +#ifdef __MACH__ +#define JXL_OS_MAC 1 +#else +#define JXL_OS_MAC 0 +#endif + +#ifdef __FreeBSD__ +#define JXL_OS_FREEBSD 1 +#else +#define JXL_OS_FREEBSD 0 +#endif + +#ifdef __HAIKU__ +#define JXL_OS_HAIKU 1 +#else +#define JXL_OS_HAIKU 0 +#endif + +#endif // LIB_JXL_BASE_OS_MACROS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/override.h b/third_party/jpeg-xl/lib/jxl/base/override.h new file mode 100644 index 000000000000..3b1f0b8b4dde --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/override.h @@ -0,0 +1,38 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_BASE_OVERRIDE_H_ +#define LIB_JXL_BASE_OVERRIDE_H_ + +// 'Trool' for command line arguments: force enable/disable, or use default. + +namespace jxl { + +// No effect if kDefault, otherwise forces a feature (typically a FrameHeader +// flag) on or off. +enum class Override : int { kOn = 1, kOff = 0, kDefault = -1 }; + +static inline Override OverrideFromBool(bool flag) { + return flag ? Override::kOn : Override::kOff; +} + +static inline bool ApplyOverride(Override o, bool default_condition) { + if (o == Override::kOn) return true; + if (o == Override::kOff) return false; + return default_condition; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_OVERRIDE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/padded_bytes.cc b/third_party/jpeg-xl/lib/jxl/base/padded_bytes.cc new file mode 100644 index 000000000000..055d10418c56 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/padded_bytes.cc @@ -0,0 +1,72 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/padded_bytes.h"
+
+namespace jxl {
+
+void PaddedBytes::IncreaseCapacityTo(size_t capacity) {
+  JXL_ASSERT(capacity > capacity_);
+
+  size_t new_capacity = std::max(capacity, 3 * capacity_ / 2);
+  new_capacity = std::max<size_t>(64, new_capacity);
+
+  // BitWriter writes up to 7 bytes past the end.
+  CacheAlignedUniquePtr new_data = AllocateArray(new_capacity + 8);
+  if (new_data == nullptr) {
+    // Allocation failed, discard all data to ensure this is noticed.
+    size_ = capacity_ = 0;
+    return;
+  }
+
+  if (data_ == nullptr) {
+    // First allocation: ensure first byte is initialized (won't be copied).
+    new_data[0] = 0;
+  } else {
+    // Subsequent resize: copy existing data to new location.
+    memcpy(new_data.get(), data_.get(), size_);
+    // Ensure that the first new byte is initialized, to allow write_bits to
+    // safely append to the newly-resized PaddedBytes.
+    new_data[size_] = 0;
+  }
+
+  capacity_ = new_capacity;
+  std::swap(new_data, data_);
+}
+
+void PaddedBytes::assign(const uint8_t* new_begin, const uint8_t* new_end) {
+  JXL_DASSERT(new_begin <= new_end);
+  const size_t new_size = static_cast<size_t>(new_end - new_begin);
+
+  // memcpy requires non-overlapping ranges, and resizing might invalidate the
+  // new range. Neither happens if the new range is completely to the left or
+  // right of the _allocated_ range (irrespective of size_).
+  const uint8_t* allocated_end = begin() + capacity_;
+  const bool outside = new_end <= begin() || new_begin >= allocated_end;
+  if (outside) {
+    resize(new_size);  // grow or shrink
+    memcpy(data(), new_begin, new_size);
+    return;
+  }
+
+  // There is overlap. The new size cannot be larger because we own the memory
+  // and the new range cannot include anything outside the allocated range.
+  JXL_ASSERT(new_size <= capacity_);
+
+  // memmove allows overlap and capacity_ is sufficient.
+  memmove(data(), new_begin, new_size);
+  size_ = new_size;  // shrink
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/base/padded_bytes.h b/third_party/jpeg-xl/lib/jxl/base/padded_bytes.h
new file mode 100644
index 000000000000..b40b119431ac
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/padded_bytes.h
@@ -0,0 +1,204 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
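Before the class itself, a usage sketch (illustrative, not part of the patch) of the PaddedBytes defined in this header; note that plain resize() only guarantees the first newly exposed byte is zeroed:

#include <cstdint>

#include "lib/jxl/base/padded_bytes.h"

int main() {
  jxl::PaddedBytes bytes;
  bytes.push_back(0xFF);  // amortized O(1); capacity grows by >= 1.5x
  bytes.resize(16);       // new bytes NOT all initialized (unlike vector)
  bytes.resize(32, 0);    // this overload zero-fills the new tail
  const uint8_t tag[] = {1, 2, 3};
  bytes.append(tag, tag + sizeof(tag));
  return bytes.empty() ? 1 : 0;
}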
+
+#ifndef LIB_JXL_BASE_PADDED_BYTES_H_
+#define LIB_JXL_BASE_PADDED_BYTES_H_
+
+// std::vector replacement with padding to reduce bounds checks in WriteBits
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>  // memcpy
+
+#include <algorithm>  // max
+#include <initializer_list>
+#include <utility>  // swap
+
+#include "lib/jxl/base/cache_aligned.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+// Provides a subset of the std::vector interface with some differences:
+// - allows BitWriter to write 64 bits at a time without bounds checking;
+// - ONLY zero-initializes the first byte (required by BitWriter);
+// - ensures cache-line alignment.
+class PaddedBytes {
+ public:
+  // Required for output params.
+  PaddedBytes() : size_(0), capacity_(0) {}
+
+  explicit PaddedBytes(size_t size) : size_(size), capacity_(0) {
+    if (size != 0) IncreaseCapacityTo(size);
+  }
+
+  PaddedBytes(size_t size, uint8_t value) : size_(size), capacity_(0) {
+    if (size != 0) {
+      IncreaseCapacityTo(size);
+    }
+    if (size_ != 0) {
+      memset(data(), value, size);
+    }
+  }
+
+  PaddedBytes(const PaddedBytes& other) : size_(other.size_), capacity_(0) {
+    if (size_ != 0) IncreaseCapacityTo(size_);
+    if (data() != nullptr) memcpy(data(), other.data(), size_);
+  }
+  PaddedBytes& operator=(const PaddedBytes& other) {
+    // Self-assignment is safe.
+    resize(other.size());
+    if (data() != nullptr) memmove(data(), other.data(), size_);
+    return *this;
+  }
+
+  // default is not OK - need to set other.size_ to 0!
+  PaddedBytes(PaddedBytes&& other) noexcept
+      : size_(other.size_),
+        capacity_(other.capacity_),
+        data_(std::move(other.data_)) {
+    other.size_ = other.capacity_ = 0;
+  }
+  PaddedBytes& operator=(PaddedBytes&& other) noexcept {
+    size_ = other.size_;
+    capacity_ = other.capacity_;
+    data_ = std::move(other.data_);
+
+    if (&other != this) {
+      other.size_ = other.capacity_ = 0;
+    }
+    return *this;
+  }
+
+  void swap(PaddedBytes& other) {
+    std::swap(size_, other.size_);
+    std::swap(capacity_, other.capacity_);
+    std::swap(data_, other.data_);
+  }
+
+  void reserve(size_t capacity) {
+    if (capacity > capacity_) IncreaseCapacityTo(capacity);
+  }
+
+  // NOTE: unlike vector, this does not initialize the new data!
+  // However, we guarantee that write_bits can safely append after
+  // the resize, as we zero-initialize the first new byte of data.
+  // If size < capacity(), does not invalidate the memory.
+  void resize(size_t size) {
+    if (size > capacity_) IncreaseCapacityTo(size);
+    size_ = (data() == nullptr) ? 0 : size;
+  }
+
+  // resize(size) plus explicit initialization of the new data with `value`.
+  void resize(size_t size, uint8_t value) {
+    size_t old_size = size_;
+    resize(size);
+    if (size_ > old_size) {
+      memset(data() + old_size, value, size_ - old_size);
+    }
+  }
+
+  // Amortized constant complexity due to exponential growth.
+  void push_back(uint8_t x) {
+    if (size_ == capacity_) {
+      IncreaseCapacityTo(capacity_ + 1);
+      if (data() == nullptr) return;
+    }
+
+    data_[size_++] = x;
+  }
+
+  size_t size() const { return size_; }
+  size_t capacity() const { return capacity_; }
+
+  uint8_t* data() { return data_.get(); }
+  const uint8_t* data() const { return data_.get(); }
+
+  // std::vector operations implemented in terms of the public interface above.
+
+  void clear() { resize(0); }
+  bool empty() const { return size() == 0; }
+
+  void assign(std::initializer_list<uint8_t> il) {
+    resize(il.size());
+    memcpy(data(), il.begin(), il.size());
+  }
+
+  // Replaces data() with [new_begin, new_end); potentially reallocates.
+  void assign(const uint8_t* new_begin, const uint8_t* new_end);
+
+  uint8_t* begin() { return data(); }
+  const uint8_t* begin() const { return data(); }
+  uint8_t* end() { return begin() + size(); }
+  const uint8_t* end() const { return begin() + size(); }
+
+  uint8_t& operator[](const size_t i) {
+    BoundsCheck(i);
+    return data()[i];
+  }
+  const uint8_t& operator[](const size_t i) const {
+    BoundsCheck(i);
+    return data()[i];
+  }
+
+  uint8_t& back() {
+    JXL_ASSERT(size() != 0);
+    return data()[size() - 1];
+  }
+  const uint8_t& back() const {
+    JXL_ASSERT(size() != 0);
+    return data()[size() - 1];
+  }
+
+  template <typename T>
+  void append(const T& other) {
+    append(reinterpret_cast<const uint8_t*>(other.data()),
+           reinterpret_cast<const uint8_t*>(other.data()) + other.size());
+  }
+
+  void append(const uint8_t* begin, const uint8_t* end) {
+    size_t old_size = size();
+    resize(size() + (end - begin));
+    memcpy(data() + old_size, begin, end - begin);
+  }
+
+ private:
+  void BoundsCheck(size_t i) const {
+    // <= is safe due to padding and required by BitWriter.
+    JXL_ASSERT(i <= size());
+  }
+
+  // Copies existing data to newly allocated "data_". If allocation fails,
+  // data() == nullptr and size_ = capacity_ = 0.
+  // The new capacity will be at least 1.5 times the old capacity. This
+  // ensures that we avoid quadratic behaviour.
+  void IncreaseCapacityTo(size_t capacity);
+
+  size_t size_;
+  size_t capacity_;
+  CacheAlignedUniquePtr data_;
+};
+
+template <typename T>
+static inline void Append(const T& s, PaddedBytes* out,
+                          size_t* JXL_RESTRICT byte_pos) {
+  memcpy(out->data() + *byte_pos, s.data(), s.size());
+  *byte_pos += s.size();
+  JXL_CHECK(*byte_pos <= out->size());
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_PADDED_BYTES_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/profiler.h b/third_party/jpeg-xl/lib/jxl/base/profiler.h
new file mode 100644
index 000000000000..c78b2a395588
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/profiler.h
@@ -0,0 +1,41 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_PROFILER_H_
+#define LIB_JXL_BASE_PROFILER_H_
+
+// High precision, low overhead time measurements. Returns exact call counts
+// and total elapsed time for user-defined 'zones' (code regions, i.e. C++
+// scopes).
+//
+// To use the profiler you must define PROFILER_ENABLED and link against the
+// libjxl_profiler library.
+
+// If zero, this file has no effect and no measurements will be recorded.
+#ifndef PROFILER_ENABLED
+#define PROFILER_ENABLED 0
+#endif  // PROFILER_ENABLED
+
+#if PROFILER_ENABLED
+
+#include "lib/profiler/profiler.h"
+
+#else  // !PROFILER_ENABLED
+
+#define PROFILER_ZONE(name)
+#define PROFILER_FUNC
+#define PROFILER_PRINT_RESULTS()
+
+#endif  // PROFILER_ENABLED
+
+#endif  // LIB_JXL_BASE_PROFILER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/base/robust_statistics.h b/third_party/jpeg-xl/lib/jxl/base/robust_statistics.h
new file mode 100644
index 000000000000..ef31d4f042a5
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/robust_statistics.h
@@ -0,0 +1,369 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_ROBUST_STATISTICS_H_
+#define LIB_JXL_BASE_ROBUST_STATISTICS_H_
+
+// Robust statistics: Mode, Median, MedianAbsoluteDeviation.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+template <typename T>
+T Geomean(const T* items, size_t count) {
+  double product = 1.0;
+  for (size_t i = 0; i < count; ++i) {
+    product *= items[i];
+  }
+  return static_cast<T>(std::pow(product, 1.0 / count));
+}
+
+// Round up for integers
+template <typename T, typename std::enable_if<
+                          std::numeric_limits<T>::is_integer>::type* = nullptr>
+inline T Half(T x) {
+  return (x + 1) / 2;
+}
+
+// Mul is faster than div.
+template <typename T,
+          typename std::enable_if<
+              !std::numeric_limits<T>::is_integer>::type* = nullptr>
+inline T Half(T x) {
+  return x * 0.5;
+}
+
+// Returns the median value. Side effect: values <= median will appear before,
+// values >= median after the middle index.
+// Guarantees average speed O(num_values).
+template <typename T>
+T Median(T* samples, const size_t num_samples) {
+  JXL_ASSERT(num_samples != 0);
+  std::nth_element(samples, samples + num_samples / 2, samples + num_samples);
+  T result = samples[num_samples / 2];
+  // If even size, find the largest element in the partially sorted lower half
+  // and average it with `result`.
+  if ((num_samples & 1) == 0) {
+    T biggest = *std::max_element(samples, samples + num_samples / 2);
+    result = Half(result + biggest);
+  }
+  return result;
+}
+
+template <typename T>
+T Median(std::vector<T>* samples) {
+  return Median(samples->data(), samples->size());
+}
+
+template <typename T>
+static inline T Median3(const T a, const T b, const T c) {
+  return std::max(std::min(a, b), std::min(c, std::max(a, b)));
+}
+
+template <typename T>
+static inline T Median5(const T a, const T b, const T c, const T d,
+                        const T e) {
+  return Median3(e, std::max(std::min(a, b), std::min(c, d)),
+                 std::min(std::max(a, b), std::max(c, d)));
+}
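A small sketch (illustrative only) of the order-statistic helpers above and the MedianAbsoluteDeviation that follows; note that Median() partially reorders its input in place:

#include <cstdio>
#include <vector>

#include "lib/jxl/base/robust_statistics.h"

int main() {
  std::vector<float> v = {1.0f, 2.0f, 100.0f, 3.0f, 4.0f};
  const float median = jxl::Median(&v);  // 3.0f; v is partially reordered
  // Robust spread: the outlier 100 barely moves the MAD (1.0f here),
  // unlike the standard deviation.
  const float mad = jxl::MedianAbsoluteDeviation(v, median);
  printf("median=%f mad=%f\n", median, mad);
  return 0;
}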
+// Returns a robust measure of variability.
+template <typename T>
+T MedianAbsoluteDeviation(const T* samples, const size_t num_samples,
+                          const T median) {
+  JXL_ASSERT(num_samples != 0);
+  std::vector<T> abs_deviations;
+  abs_deviations.reserve(num_samples);
+  for (size_t i = 0; i < num_samples; ++i) {
+    abs_deviations.push_back(std::abs(samples[i] - median));
+  }
+  return Median(&abs_deviations);
+}
+
+template <typename T>
+T MedianAbsoluteDeviation(const std::vector<T>& samples, const T median) {
+  return MedianAbsoluteDeviation(samples.data(), samples.size(), median);
+}
+
+// Half{Range/Sample}Mode are implementations of "Robust estimators of the
+// mode and skewness of continuous data". The mode is less affected by
+// outliers in highly-skewed distributions than the median.
+
+// Robust estimator of the mode for data given as sorted values.
+// O(N*logN), N=num_values.
+class HalfSampleMode {
+ public:
+  // Returns mode. "sorted" must be in ascending order.
+  template <typename T>
+  T operator()(const T* const JXL_RESTRICT sorted,
+               const size_t num_values) const {
+    int64_t center = num_values / 2;
+    int64_t width = num_values;
+
+    // Zoom in on modal intervals of decreasing width. Stop before we reach
+    // width=1, i.e. single values, for which there is no "slope".
+    while (width > 2) {
+      // Round up so we can still reach the outer edges of odd widths.
+      width = Half(width);
+
+      center = CenterOfIntervalWithMinSlope(sorted, num_values, center, width);
+    }
+
+    return sorted[center];  // mode := middle value in modal interval.
+  }
+
+ private:
+  // Returns center of the densest region [c-radius, c+radius].
+  template <typename T>
+  static JXL_INLINE int64_t CenterOfIntervalWithMinSlope(
+      const T* JXL_RESTRICT sorted, const int64_t total_values,
+      const int64_t center, const int64_t width) {
+    const int64_t radius = Half(width);
+
+    auto compute_slope = [radius, total_values, sorted](
+                             int64_t c, int64_t* actual_center = nullptr) {
+      // For symmetry, check 2*radius+1 values, i.e. [min, max].
+      const int64_t min = std::max(c - radius, int64_t(0));
+      const int64_t max = std::min(c + radius, total_values - 1);
+      JXL_ASSERT(min < max);
+      JXL_ASSERT(sorted[min] <=
+                 sorted[max] + std::numeric_limits<T>::epsilon());
+      const float dx = max - min + 1;
+      const float slope = (sorted[max] - sorted[min]) / dx;
+
+      if (actual_center != nullptr) {
+        // c may be out of bounds, so return center of the clamped bounds.
+        *actual_center = Half(min + max);
+      }
+      return slope;
+    };
+
+    // First find min_slope for all centers.
+    float min_slope = std::numeric_limits<float>::max();
+    for (int64_t c = center - radius; c <= center + radius; ++c) {
+      min_slope = std::min(min_slope, compute_slope(c));
+    }
+
+    // Candidates := centers with slope ~= min_slope.
+    std::vector<int64_t> candidates;
+    for (int64_t c = center - radius; c <= center + radius; ++c) {
+      int64_t actual_center;
+      const float slope = compute_slope(c, &actual_center);
+      if (slope <= min_slope * 1.001f) {
+        candidates.push_back(actual_center);
+      }
+    }
+
+    // Keep the median.
+    JXL_ASSERT(!candidates.empty());
+    if (candidates.size() == 1) return candidates[0];
+    return Median(&candidates);
+  }
+};
+
+// Robust estimator of the mode for data given as a CDF.
+// O(N*logN), N=num_bins.
+class HalfRangeMode {
+ public:
+  // Returns mode expressed as a histogram bin index. "cdf" must be weakly
+  // monotonically increasing, e.g. from std::partial_sum.
+  int operator()(const uint32_t* JXL_RESTRICT cdf,
+                 const size_t num_bins) const {
+    int center = num_bins / 2;
+    int width = num_bins;
+
+    // Zoom in on modal intervals of decreasing width. Stop before we reach
+    // width=1, i.e. original bins, because those are noisy.
+    while (width > 2) {
+      // Round up so we can still reach the outer edges of odd widths.
+      width = Half(width);
+
+      center = CenterOfIntervalWithMaxDensity(cdf, num_bins, center, width);
+    }
+
+    return center;  // mode := midpoint of modal interval.
+  }
+
+ private:
+  // Returns center of the densest interval [c-radius, c+radius].
+  static JXL_INLINE int CenterOfIntervalWithMaxDensity(
+      const uint32_t* JXL_RESTRICT cdf, const int total_bins, const int center,
+      const int width) {
+    const int radius = Half(width);
+
+    auto compute_density = [radius, total_bins, cdf](
+                               int c, int* actual_center = nullptr) {
+      // For symmetry, check 2*radius+1 bins, i.e. [min, max].
+      const int min = std::max(c - radius, 1);  // for -1 below
+      const int max = std::min(c + radius, total_bins - 1);
+      JXL_ASSERT(min < max);
+      JXL_ASSERT(cdf[min] <= cdf[max - 1]);
+      const int num_bins = max - min + 1;
+      // Sum over [min, max] == CDF(max) - CDF(min-1).
+      const float density = float(cdf[max] - cdf[min - 1]) / num_bins;
+
+      if (actual_center != nullptr) {
+        // c may be out of bounds, so take center of the clamped bounds.
+        *actual_center = Half(min + max);
+      }
+      return density;
+    };
+
+    // First find max_density for all centers.
+    float max_density = 0.0f;
+    for (int c = center - radius; c <= center + radius; ++c) {
+      max_density = std::max(max_density, compute_density(c));
+    }
+
+    // Candidates := centers with density ~= max_density.
+    std::vector<int> candidates;
+    for (int c = center - radius; c <= center + radius; ++c) {
+      int actual_center;
+      const float density = compute_density(c, &actual_center);
+      if (density >= max_density * 0.999f) {
+        candidates.push_back(actual_center);
+      }
+    }
+
+    // Keep the median.
+    JXL_ASSERT(!candidates.empty());
+    if (candidates.size() == 1) return candidates[0];
+    return Median(&candidates);
+  }
+};
+
+// Sorts integral values in ascending order. About 3x faster than std::sort
+// for input distributions with very few unique values.
+template <typename T>
+void CountingSort(T* begin, T* end) {
+  // Unique values and their frequency (similar to flat_map).
+  using Unique = std::pair<T, size_t>;
+  std::vector<Unique> unique;
+  for (const T* p = begin; p != end; ++p) {
+    const T value = *p;
+    const auto pos =
+        std::find_if(unique.begin(), unique.end(),
+                     [value](const Unique& u) { return u.first == value; });
+    if (pos == unique.end()) {
+      unique.push_back(std::make_pair(*p, 1));
+    } else {
+      ++pos->second;
+    }
+  }
+
+  // Sort in ascending order of value (pair.first).
+  std::sort(unique.begin(), unique.end());
+
+  // Write that many copies of each unique value to the array.
+  T* JXL_RESTRICT p = begin;
+  for (const auto& value_count : unique) {
+    std::fill(p, p + value_count.second, value_count.first);
+    p += value_count.second;
+  }
+  JXL_ASSERT(p == end);
+}
+
+struct Bivariate {
+  Bivariate(float x, float y) : x(x), y(y) {}
+  float x;
+  float y;
+};
+
+class Line {
+ public:
+  constexpr Line(const float slope, const float intercept)
+      : slope_(slope), intercept_(intercept) {}
+
+  constexpr float slope() const { return slope_; }
+  constexpr float intercept() const { return intercept_; }
+
+  // Robust line fit using Siegel's repeated-median algorithm.
+  explicit Line(const std::vector<Bivariate>& points) {
+    const size_t N = points.size();
+    // This straightforward N^2 implementation is OK for small N.
+    JXL_ASSERT(N < 10 * 1000);
+
+    // One for every point i.
+    std::vector<float> medians;
+    medians.reserve(N);
+
+    // One for every j != i.
+
+struct Bivariate {
+  Bivariate(float x, float y) : x(x), y(y) {}
+  float x;
+  float y;
+};
+
+class Line {
+ public:
+  constexpr Line(const float slope, const float intercept)
+      : slope_(slope), intercept_(intercept) {}
+
+  constexpr float slope() const { return slope_; }
+  constexpr float intercept() const { return intercept_; }
+
+  // Robust line fit using Siegel's repeated-median algorithm.
+  explicit Line(const std::vector<Bivariate>& points) {
+    const size_t N = points.size();
+    // This straightforward N^2 implementation is OK for small N.
+    JXL_ASSERT(N < 10 * 1000);
+
+    // One for every point i.
+    std::vector<float> medians;
+    medians.reserve(N);
+
+    // One for every j != i. Never cleared to avoid reallocation.
+    std::vector<float> slopes(N - 1);
+
+    for (size_t i = 0; i < N; ++i) {
+      // Index within slopes[] (avoids the hole where j == i).
+      size_t idx_slope = 0;
+
+      for (size_t j = 0; j < N; ++j) {
+        if (j == i) continue;
+
+        const float dy = points[j].y - points[i].y;
+        const float dx = points[j].x - points[i].x;
+        JXL_ASSERT(std::abs(dx) > 1E-7f);  // x must be distinct
+        slopes[idx_slope++] = dy / dx;
+      }
+      JXL_ASSERT(idx_slope == N - 1);
+
+      const float median = Median(&slopes);
+      medians.push_back(median);
+    }
+
+    slope_ = Median(&medians);
+
+    // Solve for intercept, overwriting medians[].
+    for (size_t i = 0; i < N; ++i) {
+      medians[i] = points[i].y - slope_ * points[i].x;
+    }
+    intercept_ = Median(&medians);
+  }
+
+  constexpr float operator()(float x) const { return x * slope_ + intercept_; }
+
+ private:
+  float slope_;
+  float intercept_;
+};
+
+static inline void EvaluateQuality(const Line& line,
+                                   const std::vector<Bivariate>& points,
+                                   float* JXL_RESTRICT max_l1,
+                                   float* JXL_RESTRICT median_abs_deviation) {
+  // For computing median_abs_deviation.
+  std::vector<float> abs_deviations;
+  abs_deviations.reserve(points.size());
+
+  *max_l1 = 0.0f;
+  for (const Bivariate& point : points) {
+    const float l1 = std::abs(line(point.x) - point.y);
+    *max_l1 = std::max(*max_l1, l1);
+    abs_deviations.push_back(l1);
+  }
+
+  *median_abs_deviation = Median(&abs_deviations);
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_ROBUST_STATISTICS_H_
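The Line constructor above is Siegel's repeated-median fit: the median, over all points i, of the median slope from point i to every other point, which tolerates up to roughly 50% outliers. A minimal usage sketch (the data points here are illustrative, not from the patch):

    std::vector<jxl::Bivariate> pts = {{0.f, 1.f}, {1.f, 3.f}, {2.f, 5.f}, {3.f, 100.f}};
    jxl::Line fit(pts);         // the outlier at x=3 barely moves the fit
    const float y = fit(1.5f);  // evaluates slope * 1.5 + intercept, ~4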
diff --git a/third_party/jpeg-xl/lib/jxl/base/span.h b/third_party/jpeg-xl/lib/jxl/base/span.h
new file mode 100644
index 000000000000..12e58e20ab92
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/span.h
@@ -0,0 +1,67 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_SPAN_H_
+#define LIB_JXL_BASE_SPAN_H_
+
+// Span (array view) is a non-owning container that provides cheap "cut"
+// operations and could be used as "ArrayLike" data source for PaddedBytes.
+
+#include <stddef.h>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+template <typename T>
+class Span {
+ public:
+  constexpr Span() noexcept : Span(nullptr, 0) {}
+
+  constexpr Span(T* array, size_t length) noexcept
+      : ptr_(array), len_(length) {}
+
+  template <size_t N>
+  explicit constexpr Span(T (&a)[N]) noexcept : Span(a, N) {}
+
+  template <typename ArrayLike>
+  explicit constexpr Span(const ArrayLike& other) noexcept
+      : Span(reinterpret_cast<T*>(other.data()), other.size()) {
+    static_assert(sizeof(*other.data()) == sizeof(T),
+                  "Incompatible type of source.");
+  }
+
+  constexpr T* data() const noexcept { return ptr_; }
+
+  constexpr size_t size() const noexcept { return len_; }
+
+  constexpr T& operator[](size_t i) const noexcept {
+    // MSVC 2015 accepts this as constexpr, but not ptr_[i]
+    return *(data() + i);
+  }
+
+  void remove_prefix(size_t n) noexcept {
+    JXL_ASSERT(size() >= n);
+    ptr_ += n;
+    len_ -= n;
+  }
+
+ private:
+  T* ptr_;
+  size_t len_;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_SPAN_H_
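A quick sketch of the intended use (buffer contents illustrative): the Span does not own memory, and remove_prefix is the cheap "cut" mentioned in the header comment:

    std::vector<uint8_t> bytes = {0xDE, 0xAD, 0xBE, 0xEF};
    jxl::Span<const uint8_t> view(bytes.data(), bytes.size());
    view.remove_prefix(2);  // view now covers {0xBE, 0xEF}; bytes is untouched
    JXL_ASSERT(view.size() == 2 && view[0] == 0xBE);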
diff --git a/third_party/jpeg-xl/lib/jxl/base/status.cc b/third_party/jpeg-xl/lib/jxl/base/status.cc
new file mode 100644
index 000000000000..3927b95bda84
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/status.cc
@@ -0,0 +1,55 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/status.h"
+
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <string>
+
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+#include "sanitizer/common_interface_defs.h"  // __sanitizer_print_stack_trace
+#endif                                        // defined(*_SANITIZER)
+
+namespace jxl {
+
+bool Debug(const char* format, ...) {
+  va_list args;
+  va_start(args, format);
+  vfprintf(stderr, format, args);
+  va_end(args);
+  return false;
+}
+
+bool Abort() {
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+  // If compiled with any sanitizer, print a stack trace. This call doesn't
+  // crash the program; instead, the trap below crashes it, which also lets
+  // gdb break there.
+  __sanitizer_print_stack_trace();
+#endif  // defined(*_SANITIZER)
+
+#if JXL_COMPILER_MSVC
+  __debugbreak();
+  abort();
+#else
+  __builtin_trap();
+#endif
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/base/status.h b/third_party/jpeg-xl/lib/jxl/base/status.h
new file mode 100644
index 000000000000..5c675a8f6a32
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/status.h
@@ -0,0 +1,300 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_STATUS_H_
+#define LIB_JXL_BASE_STATUS_H_
+
+// Error handling: Status return type + helper macros.
+
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+
+namespace jxl {
+
+// Uncomment to abort when JXL_FAILURE or JXL_STATUS with a fatal error is
+// reached:
+// #define JXL_CRASH_ON_ERROR
+
+#ifndef JXL_ENABLE_ASSERT
+#define JXL_ENABLE_ASSERT 1
+#endif
+
+#ifndef JXL_ENABLE_CHECK
+#define JXL_ENABLE_CHECK 1
+#endif
+
+// Pass -DJXL_DEBUG_ON_ERROR at compile time to print debug messages when a
+// function returns JXL_FAILURE or calls JXL_NOTIFY_ERROR. Note that this is
+// irrelevant if you also pass -DJXL_CRASH_ON_ERROR.
+#if defined(JXL_DEBUG_ON_ERROR) || defined(JXL_CRASH_ON_ERROR)
+#undef JXL_DEBUG_ON_ERROR
+#define JXL_DEBUG_ON_ERROR 1
+#else  // JXL_DEBUG_ON_ERROR || JXL_CRASH_ON_ERROR
+#ifdef NDEBUG
+#define JXL_DEBUG_ON_ERROR 0
+#else  // NDEBUG
+#define JXL_DEBUG_ON_ERROR 1
+#endif  // NDEBUG
+#endif  // JXL_DEBUG_ON_ERROR || JXL_CRASH_ON_ERROR
+
+// Pass -DJXL_DEBUG_ON_ALL_ERROR at compile time to print debug messages on
+// all error statuses (fatal and non-fatal). This implies JXL_DEBUG_ON_ERROR.
+#if defined(JXL_DEBUG_ON_ALL_ERROR)
+#undef JXL_DEBUG_ON_ALL_ERROR
+#define JXL_DEBUG_ON_ALL_ERROR 1
+// JXL_DEBUG_ON_ALL_ERROR implies JXL_DEBUG_ON_ERROR too.
+#undef JXL_DEBUG_ON_ERROR
+#define JXL_DEBUG_ON_ERROR 1
+#else  // JXL_DEBUG_ON_ALL_ERROR
+#define JXL_DEBUG_ON_ALL_ERROR 0
+#endif  // JXL_DEBUG_ON_ALL_ERROR
+
+// The Verbose level for the library
+#ifndef JXL_DEBUG_V_LEVEL
+#define JXL_DEBUG_V_LEVEL 0
+#endif  // JXL_DEBUG_V_LEVEL
+
+// Print a debug message on standard error. You should use the JXL_DEBUG macro
+// instead of calling Debug directly. This function returns false, so it can
+// be used as a return value in JXL_FAILURE.
+JXL_FORMAT(1, 2)
+bool Debug(const char* format, ...);
+
+// Print a debug message on standard error if "enabled" is true. "enabled" is
+// normally a macro that evaluates to 0 or 1 at compile time, so the Debug
+// function is never called and optimized out in release builds. Note that the
+// arguments are compiled but not evaluated when enabled is false. The format
+// string must be an explicit string in the call, for example:
+//   JXL_DEBUG(JXL_DEBUG_MYMODULE, "my module message: %d", some_var);
+// Add a header at the top of your module's .cc or .h file (depending on
+// whether you have JXL_DEBUG calls from the .h as well) like this:
+//   #ifndef JXL_DEBUG_MYMODULE
+//   #define JXL_DEBUG_MYMODULE 0
+//   #endif  // JXL_DEBUG_MYMODULE
+#define JXL_DEBUG(enabled, format, ...)                         \
+  do {                                                          \
+    if (enabled) {                                              \
+      ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, \
+                   ##__VA_ARGS__);                              \
+    }                                                           \
+  } while (0)
+
+// JXL_DEBUG version that prints the debug message if the global verbose level
+// defined at compile time by JXL_DEBUG_V_LEVEL is greater than or equal to
+// the passed level.
+#define JXL_DEBUG_V(level, format, ...) \
+  JXL_DEBUG(level <= JXL_DEBUG_V_LEVEL, format, ##__VA_ARGS__)
+
+// Warnings (via JXL_WARNING) are enabled by default in debug builds, i.e.
+// whenever NDEBUG is not defined.
+#ifdef JXL_DEBUG_WARNING
+#undef JXL_DEBUG_WARNING
+#define JXL_DEBUG_WARNING 1
+#else  // JXL_DEBUG_WARNING
+#ifdef NDEBUG
+#define JXL_DEBUG_WARNING 0
+#else  // JXL_DEBUG_WARNING
+#define JXL_DEBUG_WARNING 1
+#endif  // NDEBUG
+#endif  // JXL_DEBUG_WARNING
+#define JXL_WARNING(format, ...) \
+  JXL_DEBUG(JXL_DEBUG_WARNING, format, ##__VA_ARGS__)
+
+// Exits the program after printing a stack trace when possible.
+JXL_NORETURN bool Abort();
+
+// Exits the program after printing file/line plus a formatted string.
+#define JXL_ABORT(format, ...)                                          \
+  (::jxl::Debug(("%s:%d: JXL_ABORT: " format "\n"), __FILE__, __LINE__, \
+                ##__VA_ARGS__),                                         \
+   ::jxl::Abort())
+
+// Does not guarantee running the code, use only for debug mode checks.
+#if JXL_ENABLE_ASSERT
+#define JXL_ASSERT(condition)                      \
+  do {                                             \
+    if (!(condition)) {                            \
+      JXL_DEBUG(true, "JXL_ASSERT: %s", #condition); \
+      ::jxl::Abort();                              \
+    }                                              \
+  } while (0)
+#else
+#define JXL_ASSERT(condition) \
+  do {                        \
+  } while (0)
+#endif
+
+// Define JXL_IS_DEBUG_BUILD to denote asan, msan and other debug builds,
+// but not opt or release.
+#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) || \
+    defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) || \
+    defined(__clang_analyzer__)
+#define JXL_IS_DEBUG_BUILD 1
+#else
+#define JXL_IS_DEBUG_BUILD 0
+#endif
+
+// Same as above, but only runs in debug builds (builds where NDEBUG is not
+// defined). This is useful for slower asserts that we want to run more rarely
+// than usual. These will run on asan, msan and other debug builds, but not in
+// opt or release.
+#if JXL_IS_DEBUG_BUILD
+#define JXL_DASSERT(condition)                        \
+  do {                                                \
+    if (!(condition)) {                               \
+      JXL_DEBUG(true, "JXL_DASSERT: %s", #condition); \
+      ::jxl::Abort();                                 \
+    }                                                 \
+  } while (0)
+#else
+#define JXL_DASSERT(condition) \
+  do {                         \
+  } while (0)
+#endif
+
+// Always runs the condition, so it can be used for non-debug calls.
+#if JXL_ENABLE_CHECK
+#define JXL_CHECK(condition)                        \
+  do {                                              \
+    if (!(condition)) {                             \
+      JXL_DEBUG(true, "JXL_CHECK: %s", #condition); \
+      ::jxl::Abort();                               \
+    }                                               \
+  } while (0)
+#else
+#define JXL_CHECK(condition) \
+  do {                       \
+    (void)(condition);       \
+  } while (0)
+#endif
+
+// A jxl::Status value from a StatusCode or Status, which prints a debug
+// message when enabled.
+#define JXL_STATUS(status, format, ...)                                        \
+  ::jxl::StatusMessage(::jxl::Status(status), "%s:%d: " format "\n", __FILE__, \
+                       __LINE__, ##__VA_ARGS__)
+
+// Notify of an error but discard the resulting Status value. This is only
+// useful for debug builds or when building with JXL_CRASH_ON_ERROR.
+#define JXL_NOTIFY_ERROR(format, ...)                                      \
+  (void)JXL_STATUS(::jxl::StatusCode::kGenericError, "JXL_ERROR: " format, \
+                   ##__VA_ARGS__)
+
+// An error Status with a message. The JXL_STATUS() macro already returns a
+// Status object with a kGenericError code, but the comma operator helps with
+// clang-tidy inference and potentially with optimizations.
+#define JXL_FAILURE(format, ...)                                              \
+  ((void)JXL_STATUS(::jxl::StatusCode::kGenericError, "JXL_FAILURE: " format, \
+                    ##__VA_ARGS__),                                           \
+   ::jxl::Status(::jxl::StatusCode::kGenericError))
+
+// Always evaluates the status exactly once, so it can be used for non-debug
+// calls. Returns from the current context if the passed Status expression is
+// an error (fatal or non-fatal). The return value is the passed Status.
+#define JXL_RETURN_IF_ERROR(status)                                       \
+  do {                                                                    \
+    ::jxl::Status jxl_return_if_error_status = (status);                  \
+    if (!jxl_return_if_error_status) {                                    \
+      (void)::jxl::StatusMessage(                                         \
+          jxl_return_if_error_status,                                     \
+          "%s:%d: JXL_RETURN_IF_ERROR code=%d: %s\n", __FILE__, __LINE__, \
+          static_cast<int>(jxl_return_if_error_status.code()), #status);  \
+      return jxl_return_if_error_status;                                  \
+    }                                                                     \
+  } while (0)
+
+// As above, but without calling StatusMessage. Intended for bundles (see
+// fields.h), which have numerous call sites (-> relevant for code size) and
+// do not want to generate excessive messages when decoding partial headers.
+#define JXL_QUIET_RETURN_IF_ERROR(status)                \
+  do {                                                   \
+    ::jxl::Status jxl_return_if_error_status = (status); \
+    if (!jxl_return_if_error_status) {                   \
+      return jxl_return_if_error_status;                 \
+    }                                                    \
+  } while (0)
+
+enum class StatusCode : int32_t {
+  // Non-fatal errors (negative values).
+  kNotEnoughBytes = -1,
+
+  // The only non-error status code.
+  kOk = 0,
+
+  // Fatal errors (positive values).
+  kGenericError = 1,
+};
+
+// Drop-in replacement for bool that raises compiler warnings if not used
+// after being returned from a function. Example:
+//   Status LoadFile(...) { return true; }
+// is more compact than
+//   bool JXL_MUST_USE_RESULT LoadFile(...) { return true; }
+// In case of error, the status carries an extra error code in its value,
+// which is split between fatal and non-fatal error codes.
+class JXL_MUST_USE_RESULT Status {
+ public:
+  // We want an implicit constructor from bool so that functions returning
+  // Status can simply return "true" or "false". "true" means kOk while
+  // "false" means a generic fatal error.
+  // NOLINTNEXTLINE(google-explicit-constructor)
+  constexpr Status(bool ok)
+      : code_(ok ? StatusCode::kOk : StatusCode::kGenericError) {}
+
+  // NOLINTNEXTLINE(google-explicit-constructor)
+  constexpr Status(StatusCode code) : code_(code) {}
+
+  // We also want an implicit cast to bool to check the return values of
+  // functions.
+  // NOLINTNEXTLINE(google-explicit-constructor)
+  constexpr operator bool() const { return code_ == StatusCode::kOk; }
+
+  constexpr StatusCode code() const { return code_; }
+
+  // Returns whether the status code is a fatal error.
+  constexpr bool IsFatalError() const {
+    return static_cast<int32_t>(code_) > 0;
+  }
+
+ private:
+  StatusCode code_;
+};
+
+// Helper function to create a Status and print the debug message or abort
+// when needed.
+inline JXL_FORMAT(2, 3) Status
+    StatusMessage(const Status status, const char* format, ...) {
+  // This block will be optimized out when JXL_DEBUG_ON_ERROR and
+  // JXL_DEBUG_ON_ALL_ERROR are both disabled.
+  if ((JXL_DEBUG_ON_ERROR && status.IsFatalError()) ||
+      (JXL_DEBUG_ON_ALL_ERROR && !status)) {
+    va_list args;
+    va_start(args, format);
+    vfprintf(stderr, format, args);
+    va_end(args);
+  }
+#ifdef JXL_CRASH_ON_ERROR
+  // JXL_CRASH_ON_ERROR means to Abort() on fatal errors.
+  if (status.IsFatalError()) {
+    Abort();
+  }
+#endif  // JXL_CRASH_ON_ERROR
+  return status;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_STATUS_H_
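A minimal sketch of the intended call pattern (the function names here are hypothetical): returning `true`, a StatusCode, or JXL_FAILURE(...) all produce a Status, and JXL_RETURN_IF_ERROR propagates both fatal and non-fatal errors to the caller:

    jxl::Status ParseHeader(size_t available_bytes) {
      if (available_bytes < 4) return jxl::StatusCode::kNotEnoughBytes;  // non-fatal
      if (available_bytes > 1024) return JXL_FAILURE("header too large");  // fatal
      return true;  // kOk
    }

    jxl::Status Decode(size_t available_bytes) {
      JXL_RETURN_IF_ERROR(ParseHeader(available_bytes));
      return true;
    }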
diff --git a/third_party/jpeg-xl/lib/jxl/base/thread_pool_internal.h b/third_party/jpeg-xl/lib/jxl/base/thread_pool_internal.h
new file mode 100644
index 000000000000..307ae06deb3d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/thread_pool_internal.h
@@ -0,0 +1,61 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_THREAD_POOL_INTERNAL_H_
+#define LIB_JXL_BASE_THREAD_POOL_INTERNAL_H_
+
+#include <stddef.h>
+
+#include <thread>
+
+#include "jxl/parallel_runner.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/threads/thread_parallel_runner_internal.h"
+
+namespace jxl {
+
+// Helper class providing an internal ThreadPool-like object backed by real
+// threads. This is only suitable for tests or tools that access the internal
+// API of JPEG XL; in other cases the caller provides a JxlParallelRunner()
+// for handling this. This class uses jpegxl::ThreadParallelRunner (from the
+// jpegxl_threads library). For interface details, see
+// jpegxl::ThreadParallelRunner.
+class ThreadPoolInternal : public ThreadPool {
+ public:
+  // Starts the given number of worker threads and blocks until they are
+  // ready. "num_worker_threads" defaults to one per hyperthread. If zero,
+  // all tasks run on the main thread.
+  explicit ThreadPoolInternal(
+      int num_worker_threads = std::thread::hardware_concurrency())
+      : ThreadPool(&jpegxl::ThreadParallelRunner::Runner,
+                   static_cast<void*>(&runner_)),
+        runner_(num_worker_threads) {}
+
+  ThreadPoolInternal(const ThreadPoolInternal&) = delete;
+  ThreadPoolInternal& operator=(const ThreadPoolInternal&) = delete;
+
+  size_t NumThreads() const { return runner_.NumThreads(); }
+  size_t NumWorkerThreads() const { return runner_.NumWorkerThreads(); }
+
+  template <class Func>
+  void RunOnEachThread(const Func& func) {
+    runner_.RunOnEachThread(func);
+  }
+
+ private:
+  jpegxl::ThreadParallelRunner runner_;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_THREAD_POOL_INTERNAL_H_
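For orientation, the tests later in this patch drive the pool through ThreadPool::Run, roughly like this (sketch; the lambda body is illustrative):

    jxl::ThreadPoolInternal pool(4);  // four worker threads
    pool.Run(0, 100, jxl::ThreadPool::SkipInit(),
             [](const int task, const int /*thread*/) {
               // process work item `task`
             });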
diff --git a/third_party/jpeg-xl/lib/jxl/base/time.cc b/third_party/jpeg-xl/lib/jxl/base/time.cc
new file mode 100644
index 000000000000..6ef2295bf7e6
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/time.cc
@@ -0,0 +1,69 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/time.h"
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <ctime>
+
+#include "lib/jxl/base/os_macros.h"  // for JXL_OS_*
+
+#if JXL_OS_WIN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif  // NOMINMAX
+#include <windows.h>
+#endif  // JXL_OS_WIN
+
+#if JXL_OS_MAC
+#include <mach/mach.h>
+#include <mach/mach_time.h>
+#endif  // JXL_OS_MAC
+
+#if JXL_OS_HAIKU
+#include <OS.h>
+#endif  // JXL_OS_HAIKU
+
+namespace jxl {
+
+double Now() {
+#if JXL_OS_WIN
+  LARGE_INTEGER counter;
+  (void)QueryPerformanceCounter(&counter);
+  LARGE_INTEGER freq;
+  (void)QueryPerformanceFrequency(&freq);
+  return double(counter.QuadPart) / freq.QuadPart;
+#elif JXL_OS_MAC
+  const auto t = mach_absolute_time();
+  // On OSX/iOS the elapsed time is in CPU-specific time units; query the
+  // time base information to convert it back to nanoseconds.
+  // See https://developer.apple.com/library/mac/qa/qa1398/_index.html
+  static mach_timebase_info_data_t timebase;
+  if (timebase.denom == 0) {
+    (void)mach_timebase_info(&timebase);
+  }
+  return double(t) * timebase.numer / timebase.denom * 1E-9;
+#elif JXL_OS_HAIKU
+  return double(system_time_nsecs()) * 1E-9;
+#else
+  timespec t;
+  clock_gettime(CLOCK_MONOTONIC, &t);
+  return t.tv_sec + t.tv_nsec * 1E-9;
+#endif
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/base/time.h b/third_party/jpeg-xl/lib/jxl/base/time.h
new file mode 100644
index 000000000000..8dc6f3657f39
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/base/time.h
@@ -0,0 +1,28 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BASE_TIME_H_
+#define LIB_JXL_BASE_TIME_H_
+
+// OS-specific function for timing.
+
+namespace jxl {
+
+// Returns current time [seconds] from a monotonic clock with unspecified
+// starting point - only suitable for computing elapsed time.
+double Now();
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BASE_TIME_H_
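Typical elapsed-time measurement with this helper (sketch):

    const double t0 = jxl::Now();
    // ... code under measurement ...
    const double elapsed_seconds = jxl::Now() - t0;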
diff --git a/third_party/jpeg-xl/lib/jxl/bit_reader_test.cc b/third_party/jpeg-xl/lib/jxl/bit_reader_test.cc
new file mode 100644
index 000000000000..88cee95095c3
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/bit_reader_test.cc
@@ -0,0 +1,269 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <string.h>
+
+#include <random>
+#include <utility>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_bit_writer.h"
+
+namespace jxl {
+namespace {
+
+TEST(BitReaderTest, ExtendsWithZeroes) {
+  for (size_t size = 4; size < 32; ++size) {
+    std::vector<uint8_t> data(size, 0xff);
+
+    for (size_t n_bytes = 0; n_bytes < size; n_bytes++) {
+      BitReader br(Span<const uint8_t>(data.data(), n_bytes));
+      // Read all the bits
+      for (size_t i = 0; i < n_bytes * kBitsPerByte; i++) {
+        ASSERT_EQ(br.ReadBits(1), 1) << "n_bytes=" << n_bytes << " i=" << i;
+      }
+
+      // PEEK more than the declared size - all will be zero. Cannot consume.
+      for (size_t i = 0; i < BitReader::kMaxBitsPerCall; i++) {
+        ASSERT_EQ(br.PeekBits(i), 0)
+            << "size=" << size << " n_bytes=" << n_bytes << " i=" << i;
+      }
+
+      EXPECT_TRUE(br.Close());
+    }
+  }
+}
+
+struct Symbol {
+  uint32_t num_bits;
+  uint32_t value;
+};
+
+// Reading from output gives the same values.
+TEST(BitReaderTest, TestRoundTrip) {
+  ThreadPoolInternal pool(8);
+  pool.Run(0, 1000, ThreadPool::SkipInit(),
+           [](const int task, const int /* thread */) {
+             constexpr size_t kMaxBits = 8000;
+             BitWriter writer;
+             BitWriter::Allotment allotment(&writer, kMaxBits);
+
+             std::vector<Symbol> symbols;
+             symbols.reserve(1000);
+
+             std::mt19937 rng(55537 + 129 * task);
+             std::uniform_int_distribution<> dist(1, 32);  // closed interval
+
+             for (;;) {
+               const uint32_t num_bits = dist(rng);
+               if (writer.BitsWritten() + num_bits > kMaxBits) break;
+               const uint32_t value = rng() >> (32 - num_bits);
+               symbols.push_back({num_bits, value});
+               writer.Write(num_bits, value);
+             }
+
+             writer.ZeroPadToByte();
+             ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+             BitReader reader(writer.GetSpan());
+             for (const Symbol& s : symbols) {
+               EXPECT_EQ(s.value, reader.ReadBits(s.num_bits));
+             }
+             EXPECT_TRUE(reader.Close());
+           });
+}
+
+// SkipBits is the same as reading that many bits.
+TEST(BitReaderTest, TestSkip) {
+  ThreadPoolInternal pool(8);
+  pool.Run(
+      0, 96, ThreadPool::SkipInit(),
+      [](const int task, const int /* thread */) {
+        constexpr size_t kSize = 100;
+
+        for (size_t skip = 0; skip < 128; ++skip) {
+          BitWriter writer;
+          BitWriter::Allotment allotment(&writer, kSize * kBitsPerByte);
+          // Start with "task" 1-bits.
+          for (int i = 0; i < task; ++i) {
+            writer.Write(1, 1);
+          }
+
+          // Write 0-bits that we will skip over
+          for (size_t i = 0; i < skip; ++i) {
+            writer.Write(1, 0);
+          }
+
+          // Write terminator bits '101'
+          writer.Write(3, 5);
+          EXPECT_EQ(task + skip + 3, writer.BitsWritten());
+          writer.ZeroPadToByte();
+          AuxOut aux_out;
+          ReclaimAndCharge(&writer, &allotment, 0, &aux_out);
+          EXPECT_LT(aux_out.layers[0].total_bits, kSize * 8);
+
+          BitReader reader1(writer.GetSpan());
+          BitReader reader2(writer.GetSpan());
+          // Verify initial 1-bits
+          for (int i = 0; i < task; ++i) {
+            EXPECT_EQ(1, reader1.ReadBits(1));
+            EXPECT_EQ(1, reader2.ReadBits(1));
+          }
+
+          // SkipBits or manually read "skip" bits
+          reader1.SkipBits(skip);
+          for (size_t i = 0; i < skip; ++i) {
+            EXPECT_EQ(0, reader2.ReadBits(1)) << " skip=" << skip << " i=" << i;
+          }
+          EXPECT_EQ(reader1.TotalBitsConsumed(), reader2.TotalBitsConsumed());
+
+          // Ensure both readers see the terminator bits.
+          EXPECT_EQ(5, reader1.ReadBits(3));
+          EXPECT_EQ(5, reader2.ReadBits(3));
+
+          EXPECT_TRUE(reader1.Close());
+          EXPECT_TRUE(reader2.Close());
+        }
+      });
+}
+
+// Verifies byte order and different groupings of bits.
+TEST(BitReaderTest, TestOrder) {
+  constexpr size_t kMaxBits = 16;
+
+  // u(1) - bits written into LSBs of first byte
+  {
+    BitWriter writer;
+    BitWriter::Allotment allotment(&writer, kMaxBits);
+    for (size_t i = 0; i < 5; ++i) {
+      writer.Write(1, 1);
+    }
+    for (size_t i = 0; i < 5; ++i) {
+      writer.Write(1, 0);
+    }
+    for (size_t i = 0; i < 6; ++i) {
+      writer.Write(1, 1);
+    }
+
+    writer.ZeroPadToByte();
+    ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+    BitReader reader(writer.GetSpan());
+    EXPECT_EQ(0x1F, reader.ReadFixedBits<8>());
+    EXPECT_EQ(0xFC, reader.ReadFixedBits<8>());
+    EXPECT_TRUE(reader.Close());
+  }
+
+  // u(8) - get bytes in the same order
+  {
+    BitWriter writer;
+    BitWriter::Allotment allotment(&writer, kMaxBits);
+    writer.Write(8, 0xF8);
+    writer.Write(8, 0x3F);
+
+    writer.ZeroPadToByte();
+    ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+    BitReader reader(writer.GetSpan());
+    EXPECT_EQ(0xF8, reader.ReadFixedBits<8>());
+    EXPECT_EQ(0x3F, reader.ReadFixedBits<8>());
+    EXPECT_TRUE(reader.Close());
+  }
+
+  // u(16) - little-endian bytes
+  {
+    BitWriter writer;
+    BitWriter::Allotment allotment(&writer, kMaxBits);
+    writer.Write(16, 0xF83F);
+
+    writer.ZeroPadToByte();
+    ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+    BitReader reader(writer.GetSpan());
+    EXPECT_EQ(0x3F, reader.ReadFixedBits<8>());
+    EXPECT_EQ(0xF8, reader.ReadFixedBits<8>());
+    EXPECT_TRUE(reader.Close());
+  }
+
+  // Non-byte-aligned, mixed sizes
+  {
+    BitWriter writer;
+    BitWriter::Allotment allotment(&writer, kMaxBits);
+    writer.Write(1, 1);
+    writer.Write(3, 6);
+    writer.Write(8, 0xDB);
+    writer.Write(4, 8);
+
+    writer.ZeroPadToByte();
+    ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+    BitReader reader(writer.GetSpan());
+    EXPECT_EQ(0xBD, reader.ReadFixedBits<8>());
+    EXPECT_EQ(0x8D, reader.ReadFixedBits<8>());
+    EXPECT_TRUE(reader.Close());
+  }
+}
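+
+// Walkthrough of the last case above (illustrative): bits are packed
+// LSB-first, so Write(1, 1), Write(3, 6 = 0b110) and Write(8, 0xDB) fill
+// bit 0 with 1, bits 1..3 with 0,1,1 and bits 4..7 with 1,1,0,1, making the
+// first byte 0b10111101 = 0xBD; the remaining 0xDB bits plus Write(4, 8)
+// yield the second byte, 0x8D.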
+ } + + EXPECT_EQ(16, reader2.TotalBitsConsumed()); + EXPECT_EQ(3U, reader2.ReadFixedBits<8>()); + EXPECT_EQ(24, reader2.TotalBitsConsumed()); + + EXPECT_TRUE(reader2.Close()); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/bits_test.cc b/third_party/jpeg-xl/lib/jxl/bits_test.cc new file mode 100644 index 000000000000..132bcb5535c6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/bits_test.cc @@ -0,0 +1,88 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/base/bits.h" + +#include "gtest/gtest.h" + +namespace jxl { +namespace { + +TEST(BitsTest, TestNumZeroBits) { + // Zero input is well-defined. + EXPECT_EQ(32, Num0BitsAboveMS1Bit(0u)); + EXPECT_EQ(64, Num0BitsAboveMS1Bit(0ull)); + EXPECT_EQ(32, Num0BitsBelowLS1Bit(0u)); + EXPECT_EQ(64, Num0BitsBelowLS1Bit(0ull)); + + EXPECT_EQ(31, Num0BitsAboveMS1Bit(1u)); + EXPECT_EQ(30, Num0BitsAboveMS1Bit(2u)); + EXPECT_EQ(63, Num0BitsAboveMS1Bit(1ull)); + EXPECT_EQ(62, Num0BitsAboveMS1Bit(2ull)); + + EXPECT_EQ(0, Num0BitsBelowLS1Bit(1u)); + EXPECT_EQ(0, Num0BitsBelowLS1Bit(1ull)); + EXPECT_EQ(1, Num0BitsBelowLS1Bit(2u)); + EXPECT_EQ(1, Num0BitsBelowLS1Bit(2ull)); + + EXPECT_EQ(0, Num0BitsAboveMS1Bit(0x80000000u)); + EXPECT_EQ(0, Num0BitsAboveMS1Bit(0x8000000000000000ull)); + EXPECT_EQ(31, Num0BitsBelowLS1Bit(0x80000000u)); + EXPECT_EQ(63, Num0BitsBelowLS1Bit(0x8000000000000000ull)); +} + +TEST(BitsTest, TestFloorLog2) { + // for input = [1, 7] + const int expected[7] = {0, 1, 1, 2, 2, 2, 2}; + for (uint32_t i = 1; i <= 7; ++i) { + EXPECT_EQ(expected[i - 1], FloorLog2Nonzero(i)) << " " << i; + EXPECT_EQ(expected[i - 1], FloorLog2Nonzero(uint64_t(i))) << " " << i; + } + + EXPECT_EQ(31, FloorLog2Nonzero(0x80000000u)); + EXPECT_EQ(31, FloorLog2Nonzero(0x80000001u)); + EXPECT_EQ(31, FloorLog2Nonzero(0xFFFFFFFFu)); + + EXPECT_EQ(31, FloorLog2Nonzero(0x80000000ull)); + EXPECT_EQ(31, FloorLog2Nonzero(0x80000001ull)); + EXPECT_EQ(31, FloorLog2Nonzero(0xFFFFFFFFull)); + + EXPECT_EQ(63, FloorLog2Nonzero(0x8000000000000000ull)); + EXPECT_EQ(63, FloorLog2Nonzero(0x8000000000000001ull)); + EXPECT_EQ(63, FloorLog2Nonzero(0xFFFFFFFFFFFFFFFFull)); +} + +TEST(BitsTest, TestCeilLog2) { + // for input = [1, 7] + const int expected[7] = {0, 1, 2, 2, 3, 3, 3}; + for (uint32_t i = 1; i <= 7; ++i) { + EXPECT_EQ(expected[i - 1], CeilLog2Nonzero(i)) << " " << i; + EXPECT_EQ(expected[i - 1], CeilLog2Nonzero(uint64_t(i))) << " " << i; + } + + EXPECT_EQ(31, CeilLog2Nonzero(0x80000000u)); + EXPECT_EQ(32, CeilLog2Nonzero(0x80000001u)); + EXPECT_EQ(32, CeilLog2Nonzero(0xFFFFFFFFu)); + + EXPECT_EQ(31, CeilLog2Nonzero(0x80000000ull)); + EXPECT_EQ(32, CeilLog2Nonzero(0x80000001ull)); + EXPECT_EQ(32, CeilLog2Nonzero(0xFFFFFFFFull)); + + EXPECT_EQ(63, CeilLog2Nonzero(0x8000000000000000ull)); + EXPECT_EQ(64, CeilLog2Nonzero(0x8000000000000001ull)); + EXPECT_EQ(64, CeilLog2Nonzero(0xFFFFFFFFFFFFFFFFull)); +} + +} // namespace +} // namespace jxl diff --git 
diff --git a/third_party/jpeg-xl/lib/jxl/blending.cc b/third_party/jpeg-xl/lib/jxl/blending.cc
new file mode 100644
index 000000000000..a3f408d1c6c1
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/blending.cc
@@ -0,0 +1,375 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/blending.h"
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+
+Status ImageBlender::PrepareBlending(PassesDecoderState* dec_state,
+                                     ImageBundle* foreground,
+                                     ImageBundle* output) {
+  const PassesSharedState& state = *dec_state->shared;
+  // No need to blend anything in this case.
+  if (!(state.frame_header.frame_type == FrameType::kRegularFrame ||
+        state.frame_header.frame_type == FrameType::kSkipProgressive)) {
+    *output = std::move(*foreground);
+    done_ = true;
+    return true;
+  }
+  info_ = state.frame_header.blending_info;
+  bool replace_all = (info_.mode == BlendMode::kReplace);
+  // This value should be 0 if there is no alpha channel.
+  first_alpha_ = 0;
+  const std::vector<ExtraChannelInfo>& extra_channels =
+      state.metadata->m.extra_channel_info;
+  for (size_t i = 0; i < extra_channels.size(); i++) {
+    if (extra_channels[i].type == jxl::ExtraChannel::kAlpha) {
+      first_alpha_ = i;
+      break;
+    }
+  }
+
+  ec_info_ = &state.frame_header.extra_channel_blending_info;
+  if (info_.mode != BlendMode::kReplace &&
+      info_.alpha_channel != first_alpha_) {
+    return JXL_FAILURE(
+        "Blending using non-first alpha channel not yet implemented");
+  }
+  for (const auto& ec_i : *ec_info_) {
+    if (ec_i.mode != BlendMode::kReplace) {
+      replace_all = false;
+    }
+    if (info_.source != ec_i.source)
+      return JXL_FAILURE(
+          "Blending from different sources not yet implemented");
+  }
+
+  // Replace the full frame: nothing to do.
+  if (!state.frame_header.custom_size_or_origin && replace_all) {
+    *output = std::move(*foreground);
+    done_ = true;
+    return true;
+  }
+
+  size_t image_xsize = state.frame_header.nonserialized_metadata->xsize();
+  size_t image_ysize = state.frame_header.nonserialized_metadata->ysize();
+
+  if ((dec_state->pre_color_transform_frame.xsize() != 0) &&
+      ((image_xsize != foreground->xsize()) ||
+       (image_ysize != foreground->ysize()))) {
+    // Extra channels are going to be resized. Make a copy.
+    if (foreground->HasExtraChannels()) {
+      dec_state->pre_color_transform_ec.clear();
+      for (const auto& ec : foreground->extra_channels()) {
+        dec_state->pre_color_transform_ec.emplace_back(CopyImage(ec));
+      }
+    }
+  }
+
+  // the rect in the canvas that needs to be updated
+  cropbox_ = Rect(0, 0, image_xsize, image_ysize);
+  // the rect of this frame that overlaps with the canvas
+  overlap_ = cropbox_;
+  // Image to write to.
+  if (state.frame_header.custom_size_or_origin) {
+    o_ = foreground->origin;
+    int x0 = (o_.x0 >= 0 ? o_.x0 : 0);
+    int y0 = (o_.y0 >= 0 ? o_.y0 : 0);
+    int xsize = foreground->xsize();
+    if (o_.x0 < 0) xsize += o_.x0;
+    int ysize = foreground->ysize();
+    if (o_.y0 < 0) ysize += o_.y0;
+    xsize = Clamp1(xsize, 0, (int)cropbox_.xsize() - x0);
+    ysize = Clamp1(ysize, 0, (int)cropbox_.ysize() - y0);
+    cropbox_ = Rect(x0, y0, xsize, ysize);
+    x0 = (o_.x0 < 0 ? -o_.x0 : 0);
+    y0 = (o_.y0 < 0 ? -o_.y0 : 0);
+    overlap_ = Rect(x0, y0, xsize, ysize);
+  }
+  if (overlap_.xsize() == image_xsize && overlap_.ysize() == image_ysize &&
+      replace_all) {
+    // The frame is larger than the image and fully replaces it; this is OK,
+    // we just need to crop.
+    *output = foreground->Copy();
+    output->RemoveColor();
+    std::vector<ImageF>* ec = nullptr;
+    size_t num_ec = 0;
+    if (foreground->HasExtraChannels()) {
+      num_ec = foreground->extra_channels().size();
+      ec = &output->extra_channels();
+      ec->clear();
+    }
+    Image3F croppedcolor(image_xsize, image_ysize);
+    Rect crop(-foreground->origin.x0, -foreground->origin.y0, image_xsize,
+              image_ysize);
+    CopyImageTo(crop, *foreground->color(), &croppedcolor);
+    output->SetFromImage(std::move(croppedcolor), foreground->c_current());
+    for (size_t i = 0; i < num_ec; i++) {
+      const auto& ec_meta = foreground->metadata()->extra_channel_info[i];
+      if (ec_meta.dim_shift != 0) {
+        return JXL_FAILURE(
+            "Blending of downsampled extra channels is not yet implemented");
+      }
+      ImageF cropped_ec(image_xsize, image_ysize);
+      CopyImageTo(crop, foreground->extra_channels()[i], &cropped_ec);
+      ec->push_back(std::move(cropped_ec));
+    }
+    done_ = true;
+    return true;
+  }
+
+  ImageBundle& bg = *state.reference_frames[info_.source].frame;
+  if (bg.xsize() == 0 && bg.ysize() == 0) {
+    // There is no background; assume it to be all zeroes.
+    ImageBundle empty(foreground->metadata());
+    Image3F color(image_xsize, image_ysize);
+    ZeroFillImage(&color);
+    empty.SetFromImage(std::move(color), foreground->c_current());
+    if (foreground->HasExtraChannels()) {
+      std::vector<ImageF> ec;
+      for (const auto& ec_meta : foreground->metadata()->extra_channel_info) {
+        ImageF eci(ec_meta.Size(image_xsize), ec_meta.Size(image_ysize));
+        ZeroFillImage(&eci);
+        ec.push_back(std::move(eci));
+      }
+      empty.SetExtraChannels(std::move(ec));
+    }
+    bg = std::move(empty);
+  } else if (state.reference_frames[info_.source].ib_is_in_xyb) {
+    return JXL_FAILURE(
+        "Trying to blend XYB reference frame %i and non-XYB frame",
+        info_.source);
+  }
+
+  if (bg.xsize() != image_xsize || bg.ysize() != image_ysize ||
+      bg.origin.x0 != 0 || bg.origin.y0 != 0) {
+    return JXL_FAILURE("Trying to use a %zux%zu crop as a background",
+                       bg.xsize(), bg.ysize());
+  }
+  if (state.metadata->m.xyb_encoded) {
+    if (!dec_state->output_encoding_info.color_encoding_is_original) {
+      return JXL_FAILURE("Blending in unsupported color space");
+    }
+  }
foreground", + foreground->xsize(), foreground->ysize()); + } + } + } + + dest_ = output; + if (state.frame_header.CanBeReferenced() && + &bg == &state.reference_frames[state.frame_header.save_as_reference] + .storage) { + *dest_ = std::move(bg); + } else { + *dest_ = bg.Copy(); + } + + return true; +} + +ImageBlender::RectBlender ImageBlender::PrepareRect( + const Rect& rect, const ImageBundle& foreground) const { + if (done_) return RectBlender(true); + RectBlender blender(false); + blender.info_ = info_; + blender.dest_ = dest_; + blender.ec_info_ = ec_info_; + blender.first_alpha_ = first_alpha_; + + blender.current_overlap_ = rect.Intersection(overlap_); + if (blender.current_overlap_.xsize() == 0 || + blender.current_overlap_.ysize() == 0) { + blender.done_ = true; + return blender; + } + + blender.current_cropbox_ = + Rect(o_.x0 + blender.current_overlap_.x0(), + o_.y0 + blender.current_overlap_.y0(), + blender.current_overlap_.xsize(), blender.current_overlap_.ysize()); + Image3F cropped_foreground(blender.current_overlap_.xsize(), + blender.current_overlap_.ysize()); + CopyImageTo(blender.current_overlap_, foreground.color(), + &cropped_foreground); + blender.foreground_ = ImageBundle(dest_->metadata()); + blender.foreground_.SetFromImage(std::move(cropped_foreground), + foreground.c_current()); + const auto& eci = foreground.metadata()->extra_channel_info; + if (!eci.empty()) { + std::vector ec; + for (size_t i = 0; i < eci.size(); ++i) { + ImageF ec_image(eci[i].Size(blender.current_overlap_.xsize()), + eci[i].Size(blender.current_overlap_.ysize())); + CopyImageTo(blender.current_overlap_, foreground.extra_channels()[i], + &ec_image); + ec.push_back(std::move(ec_image)); + } + blender.foreground_.SetExtraChannels(std::move(ec)); + } + + // Turn current_overlap_ from being relative to the full foreground to being + // relative to the rect. + blender.current_overlap_ = + Rect(blender.current_overlap_.x0() - rect.x0(), + blender.current_overlap_.y0() - rect.y0(), + blender.current_overlap_.xsize(), blender.current_overlap_.ysize()); + + return blender; +} + +Status ImageBlender::RectBlender::DoBlending(size_t y) const { + if (done_ || y < current_overlap_.y0() || + y >= current_overlap_.y0() + current_overlap_.ysize()) { + return true; + } + y -= current_overlap_.y0(); + Rect cropbox_row = current_cropbox_.Line(y); + Rect overlap_row = Rect(0, y, current_overlap_.xsize(), 1); + + // Blend extra channels first so that we use the pre-blending alpha. 
+  for (size_t i = 0; i < ec_info_->size(); i++) {
+    if (i == first_alpha_) continue;
+    if ((*ec_info_)[i].mode == BlendMode::kAdd) {
+      AddTo(overlap_row, foreground_.extra_channels()[i], cropbox_row,
+            &dest_->extra_channels()[i]);
+    } else if ((*ec_info_)[i].mode == BlendMode::kBlend) {
+      if ((*ec_info_)[i].alpha_channel != first_alpha_)
+        return JXL_FAILURE("Not implemented: blending using non-first alpha");
+      bool is_premultiplied = foreground_.AlphaIsPremultiplied();
+      const float* JXL_RESTRICT a1 =
+          overlap_row.ConstRow(foreground_.alpha(), 0);
+      float* JXL_RESTRICT p1 = overlap_row.Row(&dest_->extra_channels()[i], 0);
+      const float* JXL_RESTRICT a = cropbox_row.ConstRow(*dest_->alpha(), 0);
+      float* JXL_RESTRICT p = cropbox_row.Row(&dest_->extra_channels()[i], 0);
+      PerformAlphaBlending(p, a, p1, a1, p, cropbox_row.xsize(),
+                           is_premultiplied);
+    } else if ((*ec_info_)[i].mode == BlendMode::kAlphaWeightedAdd) {
+      if ((*ec_info_)[i].alpha_channel != first_alpha_)
+        return JXL_FAILURE("Not implemented: blending using non-first alpha");
+      const float* JXL_RESTRICT a1 =
+          overlap_row.ConstRow(foreground_.alpha(), 0);
+      float* JXL_RESTRICT p1 = overlap_row.Row(&dest_->extra_channels()[i], 0);
+      float* JXL_RESTRICT p = cropbox_row.Row(&dest_->extra_channels()[i], 0);
+      PerformAlphaWeightedAdd(p, p1, a1, p, cropbox_row.xsize());
+    } else if ((*ec_info_)[i].mode == BlendMode::kMul) {
+      if ((*ec_info_)[i].alpha_channel != first_alpha_)
+        return JXL_FAILURE("Not implemented: blending using non-first alpha");
+      float* JXL_RESTRICT p1 = overlap_row.Row(&dest_->extra_channels()[i], 0);
+      float* JXL_RESTRICT p = cropbox_row.Row(&dest_->extra_channels()[i], 0);
+      PerformMulBlending(p, p1, p, cropbox_row.xsize());
+    } else if ((*ec_info_)[i].mode == BlendMode::kReplace) {
+      CopyImageTo(overlap_row, foreground_.extra_channels()[i], cropbox_row,
+                  &dest_->extra_channels()[i]);
+    } else {
+      return JXL_FAILURE("Blend mode not implemented for extra channel %zu",
+                         i);
+    }
+  }
+
+  if (info_.mode == BlendMode::kAdd) {
+    for (int p = 0; p < 3; p++) {
+      AddTo(overlap_row, foreground_.color().Plane(p), cropbox_row,
+            &dest_->color()->Plane(p));
+    }
+    if (foreground_.HasAlpha()) {
+      AddTo(overlap_row, foreground_.alpha(), cropbox_row, dest_->alpha());
+    }
+  } else if (info_.mode == BlendMode::kBlend
+             // blend without alpha is just replace
+             && foreground_.HasAlpha()) {
+    bool is_premultiplied = foreground_.AlphaIsPremultiplied();
+    // Foreground.
+    const float* JXL_RESTRICT a1 =
+        overlap_row.ConstRow(foreground_.alpha(), 0);
+    const float* JXL_RESTRICT r1 =
+        overlap_row.ConstRow(foreground_.color().Plane(0), 0);
+    const float* JXL_RESTRICT g1 =
+        overlap_row.ConstRow(foreground_.color().Plane(1), 0);
+    const float* JXL_RESTRICT b1 =
+        overlap_row.ConstRow(foreground_.color().Plane(2), 0);
+    // Background & destination.
+    float* JXL_RESTRICT a = cropbox_row.Row(dest_->alpha(), 0);
+    float* JXL_RESTRICT r = cropbox_row.Row(&dest_->color()->Plane(0), 0);
+    float* JXL_RESTRICT g = cropbox_row.Row(&dest_->color()->Plane(1), 0);
+    float* JXL_RESTRICT b = cropbox_row.Row(&dest_->color()->Plane(2), 0);
+    PerformAlphaBlending(/*bg=*/{r, g, b, a}, /*fg=*/{r1, g1, b1, a1},
+                         /*out=*/{r, g, b, a}, cropbox_row.xsize(),
+                         is_premultiplied);
+  } else if (info_.mode == BlendMode::kAlphaWeightedAdd) {
+    // Foreground.
+    const float* JXL_RESTRICT a1 =
+        overlap_row.ConstRow(foreground_.alpha(), 0);
+    const float* JXL_RESTRICT r1 =
+        overlap_row.ConstRow(foreground_.color().Plane(0), 0);
+    const float* JXL_RESTRICT g1 =
+        overlap_row.ConstRow(foreground_.color().Plane(1), 0);
+    const float* JXL_RESTRICT b1 =
+        overlap_row.ConstRow(foreground_.color().Plane(2), 0);
+    // Background & destination.
+    float* JXL_RESTRICT a = cropbox_row.Row(dest_->alpha(), 0);
+    float* JXL_RESTRICT r = cropbox_row.Row(&dest_->color()->Plane(0), 0);
+    float* JXL_RESTRICT g = cropbox_row.Row(&dest_->color()->Plane(1), 0);
+    float* JXL_RESTRICT b = cropbox_row.Row(&dest_->color()->Plane(2), 0);
+    PerformAlphaWeightedAdd(/*bg=*/{r, g, b, a}, /*fg=*/{r1, g1, b1, a1},
+                            /*out=*/{r, g, b, a}, cropbox_row.xsize());
+  } else if (info_.mode == BlendMode::kMul) {
+    // Foreground.
+    const float* JXL_RESTRICT a1 =
+        overlap_row.ConstRow(foreground_.alpha(), 0);
+    const float* JXL_RESTRICT r1 =
+        overlap_row.ConstRow(foreground_.color().Plane(0), 0);
+    const float* JXL_RESTRICT g1 =
+        overlap_row.ConstRow(foreground_.color().Plane(1), 0);
+    const float* JXL_RESTRICT b1 =
+        overlap_row.ConstRow(foreground_.color().Plane(2), 0);
+    // Background & destination.
+    float* JXL_RESTRICT a = cropbox_row.Row(dest_->alpha(), 0);
+    float* JXL_RESTRICT r = cropbox_row.Row(&dest_->color()->Plane(0), 0);
+    float* JXL_RESTRICT g = cropbox_row.Row(&dest_->color()->Plane(1), 0);
+    float* JXL_RESTRICT b = cropbox_row.Row(&dest_->color()->Plane(2), 0);
+    PerformMulBlending(/*bg=*/{r, g, b, a}, /*fg=*/{r1, g1, b1, a1},
+                       /*out=*/{r, g, b, a}, cropbox_row.xsize());
+  } else {  // kReplace
+    CopyImageTo(overlap_row, foreground_.color(), cropbox_row, dest_->color());
+    if (foreground_.HasAlpha()) {
+      CopyImageTo(overlap_row, foreground_.alpha(), cropbox_row,
+                  dest_->alpha());
+    }
+  }
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/blending.h b/third_party/jpeg-xl/lib/jxl/blending.h
new file mode 100644
index 000000000000..7c0eb377cd4a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/blending.h
@@ -0,0 +1,74 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_BLENDING_H_
+#define LIB_JXL_BLENDING_H_
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+
+class ImageBlender {
+ public:
+  class RectBlender {
+   public:
+    // Does the blending for a given row of the rect passed to
+    // ImageBlender::PrepareRect. It is safe to have parallel calls to
+    // DoBlending.
+    Status DoBlending(size_t y) const;
+
+    // If this returns true, then nothing needs to be done for this rect and
+    // DoBlending can be skipped (but does not have to).
+    bool done() const { return done_; }
+
+   private:
+    friend class ImageBlender;
+    explicit RectBlender(bool done) : done_(done) {}
+
+    bool done_;
+    Rect current_overlap_;
+    Rect current_cropbox_;
+    ImageBundle foreground_;
+    BlendingInfo info_;
+    ImageBundle* dest_;
+    const std::vector<BlendingInfo>* ec_info_;
+    size_t first_alpha_;
+  };
+
+  Status PrepareBlending(PassesDecoderState* dec_state,
+                         ImageBundle* foreground, ImageBundle* output);
+  // rect is relative to foreground.
+  RectBlender PrepareRect(const Rect& rect,
+                          const ImageBundle& foreground) const;
+
+  // If this returns true, then it is not necessary to call further methods on
+  // this ImageBlender to achieve blending, although it is not forbidden either
+  // (those methods will just return immediately in that case).
+  bool done() const { return done_; }
+
+ private:
+  BlendingInfo info_;
+  // Destination, as well as background before DoBlending is called.
+  ImageBundle* dest_;
+  Rect cropbox_;
+  Rect overlap_;
+  bool done_ = false;
+  const std::vector<BlendingInfo>* ec_info_;
+  size_t first_alpha_;
+  FrameOrigin o_{};
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BLENDING_H_
diff --git a/third_party/jpeg-xl/lib/jxl/blending_test.cc b/third_party/jpeg-xl/lib/jxl/blending_test.cc
new file mode 100644
index 000000000000..fec88d083217
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/blending_test.cc
@@ -0,0 +1,52 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/blending.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/dec_file.h"
+#include "lib/jxl/image_test_utils.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+namespace {
+
+using ::testing::SizeIs;
+
+TEST(BlendingTest, Crops) {
+  ThreadPool* pool = nullptr;
+
+  const PaddedBytes compressed =
+      ReadTestData("jxl/blending/cropped_traffic_light.jxl");
+  DecompressParams dparams;
+  CodecInOut decoded;
+  ASSERT_TRUE(DecodeFile(dparams, compressed, &decoded, pool));
+  ASSERT_THAT(decoded.frames, SizeIs(4));
+
+  int i = 0;
+  for (const ImageBundle& ib : decoded.frames) {
+    std::ostringstream filename;
+    filename << "jxl/blending/cropped_traffic_light_frame-" << i << ".png";
+    const PaddedBytes compressed_frame = ReadTestData(filename.str());
+    CodecInOut frame;
+    ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(compressed_frame), &frame));
+    EXPECT_TRUE(SamePixels(ib.color(), *frame.Main().color()));
+    ++i;
+  }
+}
+
+}  // namespace
+}  // namespace jxl
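For the kBlend paths above, PerformAlphaBlending (from lib/jxl/alpha.h, included by blending.cc) performs standard source-over compositing. For non-premultiplied alpha that is, per sample (notation only, not the library's code):

    out_a = fg_a + bg_a * (1 - fg_a)
    out_c = (fg_c * fg_a + bg_c * bg_a * (1 - fg_a)) / out_a   // 0 when out_a == 0

With premultiplied alpha the division drops out, which is what the is_premultiplied flag selects.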
diff --git a/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.cc b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.cc
new file mode 100644
index 000000000000..f808fb864602
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.cc
@@ -0,0 +1,2153 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Author: Jyrki Alakuijala (jyrki.alakuijala@gmail.com)
+//
+// The physical architecture of butteraugli is based on the following naming
+// convention:
+//   * Opsin - dynamics of the photosensitive chemicals in the retina
+//             with their immediate electrical processing
+//   * Xyb - hybrid opponent/trichromatic color space
+//     x is roughly red-subtract-green.
+//     y is yellow.
+//     b is blue.
+//     Xyb values are computed from Opsin mixing, not directly from rgb.
+//   * Mask - for visual masking
+//   * Hf - color modeling for spatially high-frequency features
+//   * Lf - color modeling for spatially low-frequency features
+//   * Diffmap - to cluster and build an image of error between the images
+//   * Blur - to hold the smoothing code
+
+#include "lib/jxl/butteraugli/butteraugli.h"
+
+#include <math.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <new>
+#include <vector>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/butteraugli/butteraugli.cc"
+#include <hwy/foreach_target.h>
+
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#if PROFILER_ENABLED
+#include "lib/jxl/base/time.h"
+#endif  // PROFILER_ENABLED
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/fast_math-inl.h"
+#include "lib/jxl/gauss_blur.h"
+#include "lib/jxl/image_ops.h"
+
+#ifndef JXL_BUTTERAUGLI_ONCE
+#define JXL_BUTTERAUGLI_ONCE
+
+namespace jxl {
+
+std::vector<float> ComputeKernel(float sigma) {
+  const float m = 2.25;  // Accuracy increases when m is increased.
+  const double scaler = -1.0 / (2.0 * sigma * sigma);
+  const int diff = std::max<int>(1, m * std::fabs(sigma));
+  std::vector<float> kernel(2 * diff + 1);
+  for (int i = -diff; i <= diff; ++i) {
+    kernel[i + diff] = std::exp(scaler * i * i);
+  }
+  return kernel;
+}
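+
+// Worked example (illustrative): sigma = 1.5 gives
+// diff = max(1, int(2.25 * 1.5)) = 3, i.e. a 7-tap kernel, which is why the
+// size-7 case in the convolution below has a dedicated fast path. The kernel
+// is not normalized here; callers divide by the summed weights.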
+
+void ConvolveBorderColumn(const ImageF& in, const std::vector<float>& kernel,
+                          const size_t x, float* BUTTERAUGLI_RESTRICT row_out) {
+  const size_t offset = kernel.size() / 2;
+  int minx = x < offset ? 0 : x - offset;
+  int maxx = std::min(in.xsize() - 1, x + offset);
+  float weight = 0.0f;
+  for (int j = minx; j <= maxx; ++j) {
+    weight += kernel[j - x + offset];
+  }
+  float scale = 1.0f / weight;
+  for (size_t y = 0; y < in.ysize(); ++y) {
+    const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y);
+    float sum = 0.0f;
+    for (int j = minx; j <= maxx; ++j) {
+      sum += row_in[j] * kernel[j - x + offset];
+    }
+    row_out[y] = sum * scale;
+  }
+}
+
+// Computes a horizontal convolution and transposes the result.
+void ConvolutionWithTranspose(const ImageF& in,
+                              const std::vector<float>& kernel,
+                              ImageF* BUTTERAUGLI_RESTRICT out) {
+  PROFILER_FUNC;
+  JXL_CHECK(out->xsize() == in.ysize());
+  JXL_CHECK(out->ysize() == in.xsize());
+  const size_t len = kernel.size();
+  const size_t offset = len / 2;
+  float weight_no_border = 0.0f;
+  for (size_t j = 0; j < len; ++j) {
+    weight_no_border += kernel[j];
+  }
+  const float scale_no_border = 1.0f / weight_no_border;
+  const size_t border1 = std::min(in.xsize(), offset);
+  const size_t border2 = in.xsize() > offset ? in.xsize() - offset : 0;
+  std::vector<float> scaled_kernel(len / 2 + 1);
+  for (size_t i = 0; i <= len / 2; ++i) {
+    scaled_kernel[i] = kernel[i] * scale_no_border;
+  }
+
+  // middle
+  switch (len) {
+#if 1  // speed-optimized version
+    case 7: {
+      PROFILER_ZONE("conv7");
+      const float sk0 = scaled_kernel[0];
+      const float sk1 = scaled_kernel[1];
+      const float sk2 = scaled_kernel[2];
+      const float sk3 = scaled_kernel[3];
+      for (size_t y = 0; y < in.ysize(); ++y) {
+        const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset;
+        for (size_t x = border1; x < border2; ++x, ++row_in) {
+          const float sum0 = (row_in[0] + row_in[6]) * sk0;
+          const float sum1 = (row_in[1] + row_in[5]) * sk1;
+          const float sum2 = (row_in[2] + row_in[4]) * sk2;
+          const float sum = (row_in[3]) * sk3 + sum0 + sum1 + sum2;
+          float* BUTTERAUGLI_RESTRICT row_out = out->Row(x);
+          row_out[y] = sum;
+        }
+      }
+    } break;
+    case 13: {
+      PROFILER_ZONE("conv13");
+      for (size_t y = 0; y < in.ysize(); ++y) {
+        const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset;
+        for (size_t x = border1; x < border2; ++x, ++row_in) {
+          float sum0 = (row_in[0] + row_in[12]) * scaled_kernel[0];
+          float sum1 = (row_in[1] + row_in[11]) * scaled_kernel[1];
+          float sum2 = (row_in[2] + row_in[10]) * scaled_kernel[2];
+          float sum3 = (row_in[3] + row_in[9]) * scaled_kernel[3];
+          sum0 += (row_in[4] + row_in[8]) * scaled_kernel[4];
+          sum1 += (row_in[5] + row_in[7]) * scaled_kernel[5];
+          const float sum = (row_in[6]) * scaled_kernel[6];
+          float* BUTTERAUGLI_RESTRICT row_out = out->Row(x);
+          row_out[y] = sum + sum0 + sum1 + sum2 + sum3;
+        }
+      }
+      break;
+    }
+    case 15: {
+      PROFILER_ZONE("conv15");
+      for (size_t y = 0; y < in.ysize(); ++y) {
+        const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset;
+        for (size_t x = border1; x < border2; ++x, ++row_in) {
+          float sum0 = (row_in[0] + row_in[14]) * scaled_kernel[0];
+          float sum1 = (row_in[1] + row_in[13]) * scaled_kernel[1];
+          float sum2 = (row_in[2] + row_in[12]) * scaled_kernel[2];
+          float sum3 = (row_in[3] + row_in[11]) * scaled_kernel[3];
+          sum0 += (row_in[4] + row_in[10]) * scaled_kernel[4];
+          sum1 += (row_in[5] + row_in[9]) * scaled_kernel[5];
+          sum2 += (row_in[6] + row_in[8]) * scaled_kernel[6];
+          const float sum = (row_in[7]) * scaled_kernel[7];
+          float* BUTTERAUGLI_RESTRICT row_out = out->Row(x);
+          row_out[y] = sum + sum0 + sum1 + sum2 + sum3;
+        }
+      }
+      break;
+    }
+    case 25: {
+      PROFILER_ZONE("conv25");
+      for (size_t y = 0; y < in.ysize(); ++y) {
+        const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset;
+        for (size_t x = border1; x < border2; ++x, ++row_in) {
+          float sum0 = (row_in[0] + row_in[24]) * scaled_kernel[0];
+          float sum1 = (row_in[1] + row_in[23]) * scaled_kernel[1];
+          float sum2 = (row_in[2] + row_in[22]) * scaled_kernel[2];
+          float sum3 = (row_in[3] + row_in[21]) * scaled_kernel[3];
+          sum0 += (row_in[4] + row_in[20]) * scaled_kernel[4];
+          sum1 += (row_in[5] + row_in[19]) * scaled_kernel[5];
+          sum2 += (row_in[6] + row_in[18]) * scaled_kernel[6];
+          sum3 += (row_in[7] + row_in[17]) * scaled_kernel[7];
+          sum0 += (row_in[8] + row_in[16]) * scaled_kernel[8];
+          sum1 += (row_in[9] + row_in[15]) * scaled_kernel[9];
+          sum2 += (row_in[10] + row_in[14]) * scaled_kernel[10];
+          sum3 += (row_in[11] + row_in[13]) * scaled_kernel[11];
+          const float sum = (row_in[12]) * scaled_kernel[12];
+          float* BUTTERAUGLI_RESTRICT row_out = out->Row(x);
+          row_out[y] = sum + sum0 + sum1 + sum2 + sum3;
+        }
+      }
+      break;
+    }
(size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[32]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[31]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[30]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[29]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[28]) * scaled_kernel[4]; + sum1 += (row_in[5] + row_in[27]) * scaled_kernel[5]; + sum2 += (row_in[6] + row_in[26]) * scaled_kernel[6]; + sum3 += (row_in[7] + row_in[25]) * scaled_kernel[7]; + sum0 += (row_in[8] + row_in[24]) * scaled_kernel[8]; + sum1 += (row_in[9] + row_in[23]) * scaled_kernel[9]; + sum2 += (row_in[10] + row_in[22]) * scaled_kernel[10]; + sum3 += (row_in[11] + row_in[21]) * scaled_kernel[11]; + sum0 += (row_in[12] + row_in[20]) * scaled_kernel[12]; + sum1 += (row_in[13] + row_in[19]) * scaled_kernel[13]; + sum2 += (row_in[14] + row_in[18]) * scaled_kernel[14]; + sum3 += (row_in[15] + row_in[17]) * scaled_kernel[15]; + const float sum = (row_in[16]) * scaled_kernel[16]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + case 37: { + PROFILER_ZONE("conv37"); + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[36]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[35]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[34]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[33]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[32]) * scaled_kernel[4]; + sum0 += (row_in[5] + row_in[31]) * scaled_kernel[5]; + sum0 += (row_in[6] + row_in[30]) * scaled_kernel[6]; + sum0 += (row_in[7] + row_in[29]) * scaled_kernel[7]; + sum0 += (row_in[8] + row_in[28]) * scaled_kernel[8]; + sum1 += (row_in[9] + row_in[27]) * scaled_kernel[9]; + sum2 += (row_in[10] + row_in[26]) * scaled_kernel[10]; + sum3 += (row_in[11] + row_in[25]) * scaled_kernel[11]; + sum0 += (row_in[12] + row_in[24]) * scaled_kernel[12]; + sum1 += (row_in[13] + row_in[23]) * scaled_kernel[13]; + sum2 += (row_in[14] + row_in[22]) * scaled_kernel[14]; + sum3 += (row_in[15] + row_in[21]) * scaled_kernel[15]; + sum0 += (row_in[16] + row_in[20]) * scaled_kernel[16]; + sum1 += (row_in[17] + row_in[19]) * scaled_kernel[17]; + const float sum = (row_in[18]) * scaled_kernel[18]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + default: + printf("Warning: Unexpected kernel size! %zu\n", len); +#else + default: +#endif + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y); + for (size_t x = border1; x < border2; ++x) { + const int d = x - offset; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + float sum = 0.0f; + size_t j; + for (j = 0; j <= len / 2; ++j) { + sum += row_in[d + j] * scaled_kernel[j]; + } + for (; j < len; ++j) { + sum += row_in[d + j] * scaled_kernel[len - 1 - j]; + } + row_out[y] = sum; + } + } + } + // left border + for (size_t x = 0; x < border1; ++x) { + ConvolveBorderColumn(in, kernel, x, out->Row(x)); + } + + // right border + for (size_t x = border2; x < in.xsize(); ++x) { + ConvolveBorderColumn(in, kernel, x, out->Row(x)); + } +} + +// Separate horizontal and vertical (next function) convolution passes. 
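// Illustrative sketch, not part of the vendored file: ConvolutionWithTranspose
// above writes its output transposed (row_out[y] indexed by column x), so
// applying the same horizontal routine twice yields a full separable 2D
// filter. This is exactly how the non-fast path of Blur() below composes it;
// the helper name here is hypothetical.
ImageF SeparableBlur2D(const ImageF& in, float sigma) {
  const std::vector<float> kernel = ComputeKernel(sigma);
  ImageF transposed(in.ysize(), in.xsize());  // note swapped dimensions
  ImageF out(in.xsize(), in.ysize());
  ConvolutionWithTranspose(in, kernel, &transposed);   // rows, then transpose
  ConvolutionWithTranspose(transposed, kernel, &out);  // former columns, back
  return out;
}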
+void BlurHorizontalConv(const ImageF& in, const intptr_t xbegin, + const intptr_t xend, const intptr_t ybegin, + const intptr_t yend, const std::vector<float>& kernel, + ImageF* out) { + if (xbegin >= xend || ybegin >= yend) return; + const intptr_t xsize = in.xsize(); + const intptr_t ysize = in.ysize(); + JXL_ASSERT(0 <= xbegin && xend <= xsize); + JXL_ASSERT(0 <= ybegin && yend <= ysize); + (void)xsize; + (void)ysize; + const intptr_t radius = kernel.size() / 2; + + for (intptr_t y = ybegin; y < yend; ++y) { + float* JXL_RESTRICT row_out = out->Row(y); + for (intptr_t x = xbegin; x < xend; ++x) { + float sum = 0.0f; + float sum_weights = 0.0f; + const float* JXL_RESTRICT row_in = in.Row(y); + for (intptr_t ix = -radius; ix <= radius; ++ix) { + const intptr_t in_x = x + ix; + if (in_x < 0 || in_x >= xsize) continue; + const float weight_x = kernel[ix + radius]; + sum += row_in[in_x] * weight_x; + sum_weights += weight_x; + } + row_out[x] = sum / sum_weights; + } + } +} + +void BlurVerticalConv(const ImageF& in, const intptr_t xbegin, + const intptr_t xend, const intptr_t ybegin, + const intptr_t yend, const std::vector<float>& kernel, + ImageF* out) { + if (xbegin >= xend || ybegin >= yend) return; + const intptr_t xsize = in.xsize(); + const intptr_t ysize = in.ysize(); + JXL_ASSERT(0 <= xbegin && xend <= xsize); + JXL_ASSERT(0 <= ybegin && yend <= ysize); + (void)xsize; + const intptr_t radius = kernel.size() / 2; + for (intptr_t y = ybegin; y < yend; ++y) { + float* JXL_RESTRICT row_out = out->Row(y); + for (intptr_t x = xbegin; x < xend; ++x) { + float sum = 0.0f; + float sum_weights = 0.0f; + for (intptr_t iy = -radius; iy <= radius; ++iy) { + const intptr_t in_y = y + iy; + if (in_y < 0 || in_y >= ysize) continue; + const float weight_y = kernel[iy + radius]; + sum += in.ConstRow(in_y)[x] * weight_y; + sum_weights += weight_y; + } + row_out[x] = sum / sum_weights; + } + } +} + +// A blur somewhat similar to a 2D Gaussian blur. +// See: https://en.wikipedia.org/wiki/Gaussian_blur +// +// This is a bottleneck because the sigma can be quite large (>7). We can use +// gauss_blur.cc (runtime independent of sigma, closer to a 4*sigma truncated +// Gaussian and our 2.25 in ComputeKernel), but its boundary conditions are +// zero-valued. This leads to noticeable differences at the edges of diffmaps. +// We retain a special case for 5x5 kernels (even faster than gauss_blur), +// optionally use gauss_blur followed by fixup of the borders for large images, +// or fall back to the previous truncated FIR followed by a transpose. +void Blur(const ImageF& in, float sigma, const ButteraugliParams& params, + BlurTemp* temp, ImageF* out) { + std::vector<float> kernel = ComputeKernel(sigma); + // Separable5 does an in-place convolution, so this fast path is not safe if + // in aliases out. + if (kernel.size() == 5 && &in != out) { + float sum_weights = 0.0f; + for (const float w : kernel) { + sum_weights += w; + } + const float scale = 1.0f / sum_weights; + const float w0 = kernel[2] * scale; + const float w1 = kernel[1] * scale; + const float w2 = kernel[0] * scale; + const WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + }; + Separable5(in, Rect(in), weights, /*pool=*/nullptr, out); + return; + } + + const bool fast_gauss = params.approximate_border; + const bool kBorderFixup = fast_gauss && false; + // Fast+fixup is actually slower for small images that are all border.
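// Illustrative note, not part of the vendored file: BlurHorizontalConv and
// BlurVerticalConv above renormalize by the sum of in-bounds weights. With a
// 3-tap kernel {0.25, 0.5, 0.25}, the pixel at x = 0 gets
// (0.5*p[0] + 0.25*p[1]) / (0.5 + 0.25) rather than treating the missing left
// tap as a zero sample; border pixels therefore keep full energy instead of
// darkening, which is what makes these usable as a border fixup for
// FastGaussian's zero-valued boundary condition.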
+ const bool too_small_for_fast_gauss = + kBorderFixup && + in.xsize() * in.ysize() < 9 * kernel.size() * kernel.size(); + // If fast gaussian is disabled, use previous transposed convolution. + if (!fast_gauss || too_small_for_fast_gauss) { + ImageF* JXL_RESTRICT temp_t = temp->GetTransposed(in); + ConvolutionWithTranspose(in, kernel, temp_t); + ConvolutionWithTranspose(*temp_t, kernel, out); + return; + } + auto rg = CreateRecursiveGaussian(sigma); + ImageF* JXL_RESTRICT temp_ = temp->Get(in); + ThreadPool* null_pool = nullptr; + FastGaussian(rg, in, null_pool, temp_, out); + + if (kBorderFixup) { + // Produce rg_radius extra pixels around each border + const intptr_t rg_radius = rg->radius; + const intptr_t radius = kernel.size() / 2; + const intptr_t xsize = in.xsize(); + const intptr_t ysize = in.ysize(); + const intptr_t yend_top = std::min(rg_radius + radius, ysize); + const intptr_t ybegin_bottom = + std::max(intptr_t(0), ysize - rg_radius - radius); + // Top (requires radius extra for the vertical pass) + BlurHorizontalConv(in, 0, xsize, 0, yend_top, kernel, temp_); + // Bottom + BlurHorizontalConv(in, 0, xsize, ybegin_bottom, ysize, kernel, temp_); + // Left/right columns between top and bottom + const intptr_t xbegin_right = std::max(intptr_t(0), xsize - rg_radius); + const intptr_t xend_left = std::min(rg_radius, xsize); + BlurHorizontalConv(in, 0, xend_left, yend_top, ybegin_bottom, kernel, + temp_); + BlurHorizontalConv(in, xbegin_right, xsize, yend_top, ybegin_bottom, kernel, + temp_); + + // Entire left/right columns + BlurVerticalConv(*temp_, 0, xend_left, 0, ysize, kernel, out); + BlurVerticalConv(*temp_, xbegin_right, xsize, 0, ysize, kernel, out); + // Top/bottom between left/right + const intptr_t ybegin_bottom2 = std::max(intptr_t(0), ysize - rg_radius); + const intptr_t yend_top2 = std::min(rg_radius, ysize); + BlurVerticalConv(*temp_, xend_left, xbegin_right, 0, yend_top2, kernel, + out); + BlurVerticalConv(*temp_, xend_left, xbegin_right, ybegin_bottom2, ysize, + kernel, out); + } +} + +// Allows PaddedMaltaUnit to call either function via overloading. +struct MaltaTagLF {}; +struct MaltaTag {}; + +} // namespace jxl + +#endif // JXL_BUTTERAUGLI_ONCE + +#include <hwy/foreach_target.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Vec; + +template <class D, class V> +HWY_INLINE V MaximumClamp(D d, V v, double kMaxVal) { + static const double kMul = 0.724216145665; + const V mul = Set(d, kMul); + const V maxval = Set(d, kMaxVal); + // If greater than maxval or less than -maxval, replace with if_*. + const V if_pos = MulAdd(v - maxval, mul, maxval); + const V if_neg = MulSub(v + maxval, mul, maxval); + const V pos_or_v = IfThenElse(v >= maxval, if_pos, v); + return IfThenElse(v < Neg(maxval), if_neg, pos_or_v); +} + +// Make area around zero less important (remove it). +template <class D, class V> +HWY_INLINE V RemoveRangeAroundZero(const D d, const double kw, const V x) { + const auto w = Set(d, kw); + return IfThenElse(x > w, x - w, IfThenElseZero(x < Neg(w), x + w)); +} + +// Make area around zero more important (2x it until the limit). +template <class D, class V> +HWY_INLINE V AmplifyRangeAroundZero(const D d, const double kw, const V x) { + const auto w = Set(d, kw); + return IfThenElse(x > w, x + w, IfThenElse(x < Neg(w), x - w, x + x)); +} + +// XybLowFreqToVals converts from low-frequency XYB space to the 'vals' space. +// Vals space can be converted to L2-norm space (Euclidean and normalized) +// through visual masking.
+template +HWY_INLINE void XybLowFreqToVals(const D d, const V& x, const V& y, + const V& b_arg, V* HWY_RESTRICT valx, + V* HWY_RESTRICT valy, V* HWY_RESTRICT valb) { + static const double xmuli = 32.2217497012; + static const double ymuli = 13.7697791434; + static const double bmuli = 47.504615728; + static const double y_to_b_muli = -0.362267051518; + const V xmul = Set(d, xmuli); + const V ymul = Set(d, ymuli); + const V bmul = Set(d, bmuli); + const V y_to_b_mul = Set(d, y_to_b_muli); + const V b = MulAdd(y_to_b_mul, y, b_arg); + *valb = b * bmul; + *valx = x * xmul; + *valy = y * ymul; +} + +void SuppressXByY(const ImageF& in_x, const ImageF& in_y, const double yw, + ImageF* HWY_RESTRICT out) { + JXL_DASSERT(SameSize(in_x, in_y) && SameSize(in_x, *out)); + const size_t xsize = in_x.xsize(); + const size_t ysize = in_x.ysize(); + + const HWY_FULL(float) d; + static const double s = 0.653020556257; + const auto sv = Set(d, s); + const auto one_minus_s = Set(d, 1.0 - s); + const auto ywv = Set(d, yw); + + for (size_t y = 0; y < ysize; ++y) { + const float* HWY_RESTRICT row_x = in_x.ConstRow(y); + const float* HWY_RESTRICT row_y = in_y.ConstRow(y); + float* HWY_RESTRICT row_out = out->Row(y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto vx = Load(d, row_x + x); + const auto vy = Load(d, row_y + x); + const auto scaler = MulAdd(ywv / MulAdd(vy, vy, ywv), one_minus_s, sv); + Store(scaler * vx, d, row_out + x); + } + } +} + +static void SeparateFrequencies(size_t xsize, size_t ysize, + const ButteraugliParams& params, + BlurTemp* blur_temp, const Image3F& xyb, + PsychoImage& ps) { + PROFILER_FUNC; + const HWY_FULL(float) d; + + // Extract lf ... + static const double kSigmaLf = 7.15593339443; + static const double kSigmaHf = 3.22489901262; + static const double kSigmaUhf = 1.56416327805; + ps.mf = Image3F(xsize, ysize); + ps.hf[0] = ImageF(xsize, ysize); + ps.hf[1] = ImageF(xsize, ysize); + ps.lf = Image3F(xyb.xsize(), xyb.ysize()); + ps.mf = Image3F(xyb.xsize(), xyb.ysize()); + for (int i = 0; i < 3; ++i) { + Blur(xyb.Plane(i), kSigmaLf, params, blur_temp, &ps.lf.Plane(i)); + + // ... and keep everything else in mf. + for (size_t y = 0; y < ysize; ++y) { + const float* BUTTERAUGLI_RESTRICT row_xyb = xyb.PlaneRow(i, y); + const float* BUTTERAUGLI_RESTRICT row_lf = ps.lf.ConstPlaneRow(i, y); + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(i, y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto mf = Load(d, row_xyb + x) - Load(d, row_lf + x); + Store(mf, d, row_mf + x); + } + } + if (i == 2) { + Blur(ps.mf.Plane(i), kSigmaHf, params, blur_temp, &ps.mf.Plane(i)); + break; + } + // Divide mf into mf and hf. 
+ for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(i, y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[i].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + Store(Load(d, row_mf + x), d, row_hf + x); + } + } + Blur(ps.mf.Plane(i), kSigmaHf, params, blur_temp, &ps.mf.Plane(i)); + static const double kRemoveMfRange = 0.29; + static const double kAddMfRange = 0.1; + if (i == 0) { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[0].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto mf = Load(d, row_mf + x); + auto hf = Load(d, row_hf + x) - mf; + mf = RemoveRangeAroundZero(d, kRemoveMfRange, mf); + Store(mf, d, row_mf + x); + Store(hf, d, row_hf + x); + } + } + } else { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = ps.mf.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[1].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto mf = Load(d, row_mf + x); + auto hf = Load(d, row_hf + x) - mf; + + mf = AmplifyRangeAroundZero(d, kAddMfRange, mf); + Store(mf, d, row_mf + x); + Store(hf, d, row_hf + x); + } + } + } + } + + // Temporarily used as output of SuppressXByY + ps.uhf[0] = ImageF(xsize, ysize); + ps.uhf[1] = ImageF(xsize, ysize); + + // Suppress red-green by intensity change in the high freq channels. + static const double suppress = 46.0; + SuppressXByY(ps.hf[0], ps.hf[1], suppress, &ps.uhf[0]); + // hf is the SuppressXByY output, uhf will be written below. + ps.hf[0].Swap(ps.uhf[0]); + + for (int i = 0; i < 2; ++i) { + // Divide hf into hf and uhf. + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = ps.uhf[i].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[i].Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_uhf[x] = row_hf[x]; + } + } + Blur(ps.hf[i], kSigmaUhf, params, blur_temp, &ps.hf[i]); + static const double kRemoveHfRange = 1.5; + static const double kAddHfRange = 0.132; + static const double kRemoveUhfRange = 0.04; + static const double kMaxclampHf = 28.4691806922; + static const double kMaxclampUhf = 5.19175294647; + static double kMulYHf = 2.155; + static double kMulYUhf = 2.69313763794; + if (i == 0) { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = ps.uhf[0].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[0].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto hf = Load(d, row_hf + x); + auto uhf = Load(d, row_uhf + x) - hf; + hf = RemoveRangeAroundZero(d, kRemoveHfRange, hf); + uhf = RemoveRangeAroundZero(d, kRemoveUhfRange, uhf); + Store(hf, d, row_hf + x); + Store(uhf, d, row_uhf + x); + } + } + } else { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = ps.uhf[1].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = ps.hf[1].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto hf = Load(d, row_hf + x); + hf = MaximumClamp(d, hf, kMaxclampHf); + + auto uhf = Load(d, row_uhf + x) - hf; + uhf = MaximumClamp(d, uhf, kMaxclampUhf); + uhf *= Set(d, kMulYUhf); + Store(uhf, d, row_uhf + x); + + hf *= Set(d, kMulYHf); + hf = AmplifyRangeAroundZero(d, kAddHfRange, hf); + Store(hf, d, row_hf + x); + } + } + } + } + // Modify range around zero code only concerns the high frequency + // planes and only the X and Y channels. + // Convert low freq xyb to vals space so that we can do a simple squared sum + // diff on the low frequencies later. 
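// Illustrative scalar references, not part of the vendored file: the vector
// helpers used throughout SeparateFrequencies are simple piecewise-linear maps
// around zero. RemoveRangeAroundZero carves out a dead zone of half-width w
// (small values are dropped entirely); AmplifyRangeAroundZero doubles values
// inside the zone instead. Both are continuous at +/-w.
float RemoveRangeAroundZeroScalar(float w, float x) {
  if (x > w) return x - w;
  if (x < -w) return x + w;
  return 0.0f;  // |x| <= w: discarded
}
float AmplifyRangeAroundZeroScalar(float w, float x) {
  if (x > w) return x + w;
  if (x < -w) return x - w;
  return 2.0f * x;  // |x| <= w: doubled; equals x + w at x = w
}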
+ for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_x = ps.lf.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_y = ps.lf.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_b = ps.lf.PlaneRow(2, y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto valx = Undefined(d); + auto valy = Undefined(d); + auto valb = Undefined(d); + XybLowFreqToVals(d, Load(d, row_x + x), Load(d, row_y + x), + Load(d, row_b + x), &valx, &valy, &valb); + Store(valx, d, row_x + x); + Store(valy, d, row_y + x); + Store(valb, d, row_b + x); + } + } +} + +template +Vec MaltaUnit(MaltaTagLF /*tag*/, const D df, + const float* BUTTERAUGLI_RESTRICT d, const intptr_t xs) { + const intptr_t xs3 = 3 * xs; + + const auto center = LoadU(df, d); + + // x grows, y constant + const auto sum_yconst = LoadU(df, d - 4) + LoadU(df, d - 2) + center + + LoadU(df, d + 2) + LoadU(df, d + 4); + // Will return this, sum of all line kernels + auto retval = sum_yconst * sum_yconst; + { + // y grows, x constant + auto sum = LoadU(df, d - xs3 - xs) + LoadU(df, d - xs - xs) + center + + LoadU(df, d + xs + xs) + LoadU(df, d + xs3 + xs); + retval = MulAdd(sum, sum, retval); + } + { + // both grow + auto sum = LoadU(df, d - xs3 - 3) + LoadU(df, d - xs - xs - 2) + center + + LoadU(df, d + xs + xs + 2) + LoadU(df, d + xs3 + 3); + retval = MulAdd(sum, sum, retval); + } + { + // y grows, x shrinks + auto sum = LoadU(df, d - xs3 + 3) + LoadU(df, d - xs - xs + 2) + center + + LoadU(df, d + xs + xs - 2) + LoadU(df, d + xs3 - 3); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x shrinks 1 -> -1 + auto sum = LoadU(df, d - xs3 - xs + 1) + LoadU(df, d - xs - xs + 1) + + center + LoadU(df, d + xs + xs - 1) + + LoadU(df, d + xs3 + xs - 1); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x grows -1 -> 1 + auto sum = LoadU(df, d - xs3 - xs - 1) + LoadU(df, d - xs - xs - 1) + + center + LoadU(df, d + xs + xs + 1) + + LoadU(df, d + xs3 + xs + 1); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y grows -1 to 1 + auto sum = LoadU(df, d - 4 - xs) + LoadU(df, d - 2 - xs) + center + + LoadU(df, d + 2 + xs) + LoadU(df, d + 4 + xs); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y shrinks 1 to -1 + auto sum = LoadU(df, d - 4 + xs) + LoadU(df, d - 2 + xs) + center + + LoadU(df, d + 2 - xs) + LoadU(df, d + 4 - xs); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1__*______ + 2___*_____ + 3_________ + 4____0____ + 5_________ + 6_____*___ + 7______*__ + 8_________ */ + auto sum = LoadU(df, d - xs3 - 2) + LoadU(df, d - xs - xs - 1) + center + + LoadU(df, d + xs + xs + 1) + LoadU(df, d + xs3 + 2); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1______*__ + 2_____*___ + 3_________ + 4____0____ + 5_________ + 6___*_____ + 7__*______ + 8_________ */ + auto sum = LoadU(df, d - xs3 + 2) + LoadU(df, d - xs - xs + 1) + center + + LoadU(df, d + xs + xs - 1) + LoadU(df, d + xs3 - 2); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_*_______ + 3__*______ + 4____0____ + 5______*__ + 6_______*_ + 7_________ + 8_________ */ + auto sum = LoadU(df, d - xs - xs - 3) + LoadU(df, d - xs - 2) + center + + LoadU(df, d + xs + 2) + LoadU(df, d + xs + xs + 3); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_______*_ + 3______*__ + 4____0____ + 5__*______ + 6_*_______ + 7_________ + 8_________ */ + auto sum = LoadU(df, d - xs - xs + 3) + LoadU(df, d - xs + 2) + center + + LoadU(df, d + xs - 2) + LoadU(df, d + 
xs + xs - 3); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2________* + 3______*__ + 4____0____ + 5__*______ + 6*________ + 7_________ + 8_________ */ + + auto sum = LoadU(df, d + xs + xs - 4) + LoadU(df, d + xs - 2) + center + + LoadU(df, d - xs + 2) + LoadU(df, d - xs - xs + 4); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2*________ + 3__*______ + 4____0____ + 5______*__ + 6________* + 7_________ + 8_________ */ + auto sum = LoadU(df, d - xs - xs - 4) + LoadU(df, d - xs - 2) + center + + LoadU(df, d + xs + 2) + LoadU(df, d + xs + xs + 4); + retval = MulAdd(sum, sum, retval); + } + { + /* 0__*______ + 1_________ + 2___*_____ + 3_________ + 4____0____ + 5_________ + 6_____*___ + 7_________ + 8______*__ */ + auto sum = LoadU(df, d - xs3 - xs - 2) + LoadU(df, d - xs - xs - 1) + + center + LoadU(df, d + xs + xs + 1) + + LoadU(df, d + xs3 + xs + 2); + retval = MulAdd(sum, sum, retval); + } + { + /* 0______*__ + 1_________ + 2_____*___ + 3_________ + 4____0____ + 5_________ + 6___*_____ + 7_________ + 8__*______ */ + auto sum = LoadU(df, d - xs3 - xs + 2) + LoadU(df, d - xs - xs + 1) + + center + LoadU(df, d + xs + xs - 1) + + LoadU(df, d + xs3 + xs - 2); + retval = MulAdd(sum, sum, retval); + } + return retval; +} + +template +Vec MaltaUnit(MaltaTag /*tag*/, const D df, + const float* BUTTERAUGLI_RESTRICT d, const intptr_t xs) { + const intptr_t xs3 = 3 * xs; + + const auto center = LoadU(df, d); + + // x grows, y constant + const auto sum_yconst = LoadU(df, d - 4) + LoadU(df, d - 3) + + LoadU(df, d - 2) + LoadU(df, d - 1) + center + + LoadU(df, d + 1) + LoadU(df, d + 2) + + LoadU(df, d + 3) + LoadU(df, d + 4); + // Will return this, sum of all line kernels + auto retval = sum_yconst * sum_yconst; + + { + // y grows, x constant + auto sum = LoadU(df, d - xs3 - xs) + LoadU(df, d - xs3) + + LoadU(df, d - xs - xs) + LoadU(df, d - xs) + center + + LoadU(df, d + xs) + LoadU(df, d + xs + xs) + LoadU(df, d + xs3) + + LoadU(df, d + xs3 + xs); + retval = MulAdd(sum, sum, retval); + } + { + // both grow + auto sum = LoadU(df, d - xs3 - 3) + LoadU(df, d - xs - xs - 2) + + LoadU(df, d - xs - 1) + center + LoadU(df, d + xs + 1) + + LoadU(df, d + xs + xs + 2) + LoadU(df, d + xs3 + 3); + retval = MulAdd(sum, sum, retval); + } + { + // y grows, x shrinks + auto sum = LoadU(df, d - xs3 + 3) + LoadU(df, d - xs - xs + 2) + + LoadU(df, d - xs + 1) + center + LoadU(df, d + xs - 1) + + LoadU(df, d + xs + xs - 2) + LoadU(df, d + xs3 - 3); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x shrinks 1 -> -1 + auto sum = LoadU(df, d - xs3 - xs + 1) + LoadU(df, d - xs3 + 1) + + LoadU(df, d - xs - xs + 1) + LoadU(df, d - xs) + center + + LoadU(df, d + xs) + LoadU(df, d + xs + xs - 1) + + LoadU(df, d + xs3 - 1) + LoadU(df, d + xs3 + xs - 1); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x grows -1 -> 1 + auto sum = LoadU(df, d - xs3 - xs - 1) + LoadU(df, d - xs3 - 1) + + LoadU(df, d - xs - xs - 1) + LoadU(df, d - xs) + center + + LoadU(df, d + xs) + LoadU(df, d + xs + xs + 1) + + LoadU(df, d + xs3 + 1) + LoadU(df, d + xs3 + xs + 1); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y grows -1 to 1 + auto sum = LoadU(df, d - 4 - xs) + LoadU(df, d - 3 - xs) + + LoadU(df, d - 2 - xs) + LoadU(df, d - 1) + center + + LoadU(df, d + 1) + LoadU(df, d + 2 + xs) + + LoadU(df, d + 3 + xs) + LoadU(df, d + 4 + xs); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y shrinks 1 to -1 + auto sum = 
LoadU(df, d - 4 + xs) + LoadU(df, d - 3 + xs) + + LoadU(df, d - 2 + xs) + LoadU(df, d - 1) + center + + LoadU(df, d + 1) + LoadU(df, d + 2 - xs) + + LoadU(df, d + 3 - xs) + LoadU(df, d + 4 - xs); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1__*______ + 2___*_____ + 3___*_____ + 4____0____ + 5_____*___ + 6_____*___ + 7______*__ + 8_________ */ + auto sum = LoadU(df, d - xs3 - 2) + LoadU(df, d - xs - xs - 1) + + LoadU(df, d - xs - 1) + center + LoadU(df, d + xs + 1) + + LoadU(df, d + xs + xs + 1) + LoadU(df, d + xs3 + 2); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1______*__ + 2_____*___ + 3_____*___ + 4____0____ + 5___*_____ + 6___*_____ + 7__*______ + 8_________ */ + auto sum = LoadU(df, d - xs3 + 2) + LoadU(df, d - xs - xs + 1) + + LoadU(df, d - xs + 1) + center + LoadU(df, d + xs - 1) + + LoadU(df, d + xs + xs - 1) + LoadU(df, d + xs3 - 2); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_*_______ + 3__**_____ + 4____0____ + 5_____**__ + 6_______*_ + 7_________ + 8_________ */ + auto sum = LoadU(df, d - xs - xs - 3) + LoadU(df, d - xs - 2) + + LoadU(df, d - xs - 1) + center + LoadU(df, d + xs + 1) + + LoadU(df, d + xs + 2) + LoadU(df, d + xs + xs + 3); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_______*_ + 3_____**__ + 4____0____ + 5__**_____ + 6_*_______ + 7_________ + 8_________ */ + auto sum = LoadU(df, d - xs - xs + 3) + LoadU(df, d - xs + 2) + + LoadU(df, d - xs + 1) + center + LoadU(df, d + xs - 1) + + LoadU(df, d + xs - 2) + LoadU(df, d + xs + xs - 3); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_________ + 3______*** + 4___*0*___ + 5***______ + 6_________ + 7_________ + 8_________ */ + + auto sum = LoadU(df, d + xs - 4) + LoadU(df, d + xs - 3) + + LoadU(df, d + xs - 2) + LoadU(df, d - 1) + center + + LoadU(df, d + 1) + LoadU(df, d - xs + 2) + + LoadU(df, d - xs + 3) + LoadU(df, d - xs + 4); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_________ + 3***______ + 4___*0*___ + 5______*** + 6_________ + 7_________ + 8_________ */ + auto sum = LoadU(df, d - xs - 4) + LoadU(df, d - xs - 3) + + LoadU(df, d - xs - 2) + LoadU(df, d - 1) + center + + LoadU(df, d + 1) + LoadU(df, d + xs + 2) + + LoadU(df, d + xs + 3) + LoadU(df, d + xs + 4); + retval = MulAdd(sum, sum, retval); + } + { + /* 0___*_____ + 1___*_____ + 2___*_____ + 3____*____ + 4____0____ + 5____*____ + 6_____*___ + 7_____*___ + 8_____*___ */ + auto sum = LoadU(df, d - xs3 - xs - 1) + LoadU(df, d - xs3 - 1) + + LoadU(df, d - xs - xs - 1) + LoadU(df, d - xs) + center + + LoadU(df, d + xs) + LoadU(df, d + xs + xs + 1) + + LoadU(df, d + xs3 + 1) + LoadU(df, d + xs3 + xs + 1); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_____*___ + 1_____*___ + 2____ *___ + 3____*____ + 4____0____ + 5____*____ + 6___*_____ + 7___*_____ + 8___*_____ */ + auto sum = LoadU(df, d - xs3 - xs + 1) + LoadU(df, d - xs3 + 1) + + LoadU(df, d - xs - xs + 1) + LoadU(df, d - xs) + center + + LoadU(df, d + xs) + LoadU(df, d + xs + xs - 1) + + LoadU(df, d + xs3 - 1) + LoadU(df, d + xs3 + xs - 1); + retval = MulAdd(sum, sum, retval); + } + return retval; +} + +// Returns MaltaUnit. Avoids bounds-checks when x0 and y0 are known +// to be far enough from the image borders. "diffs" is a packed image. 
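// Illustrative sketch, not part of the vendored file: each braced block above
// sums the difference image along one short oriented line through the center
// tap and accumulates the square of that sum (5 taps per line in the
// MaltaTagLF variant, 9 in the MaltaTag variant). Coherent, edge-like errors
// line up within one orientation and are penalized quadratically, while
// incoherent noise spreads its energy across orientations. Schematic scalar
// form with a hypothetical offset table:
float MaltaUnitScalar(const float* d, const int* offsets, int num_orientations,
                      int taps_per_line) {
  float retval = 0.0f;
  for (int k = 0; k < num_orientations; ++k) {
    float sum = 0.0f;
    for (int t = 0; t < taps_per_line; ++t) {
      sum += d[offsets[k * taps_per_line + t]];  // one oriented line kernel
    }
    retval += sum * sum;  // square per orientation, then accumulate
  }
  return retval;
}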
+template +static BUTTERAUGLI_INLINE float PaddedMaltaUnit(const ImageF& diffs, + const size_t x0, + const size_t y0) { + const float* BUTTERAUGLI_RESTRICT d = diffs.ConstRow(y0) + x0; + const HWY_CAPPED(float, 1) df; + if ((x0 >= 4 && y0 >= 4 && x0 < (diffs.xsize() - 4) && + y0 < (diffs.ysize() - 4))) { + return GetLane(MaltaUnit(Tag(), df, d, diffs.PixelsPerRow())); + } + + PROFILER_ZONE("Padded Malta"); + float borderimage[12 * 9]; // round up to 4 + for (int dy = 0; dy < 9; ++dy) { + int y = y0 + dy - 4; + if (y < 0 || static_cast(y) >= diffs.ysize()) { + for (int dx = 0; dx < 12; ++dx) { + borderimage[dy * 12 + dx] = 0.0f; + } + continue; + } + + const float* row_diffs = diffs.ConstRow(y); + for (int dx = 0; dx < 9; ++dx) { + int x = x0 + dx - 4; + if (x < 0 || static_cast(x) >= diffs.xsize()) { + borderimage[dy * 12 + dx] = 0.0f; + } else { + borderimage[dy * 12 + dx] = row_diffs[x]; + } + } + std::fill(borderimage + dy * 12 + 9, borderimage + dy * 12 + 12, 0.0f); + } + return GetLane(MaltaUnit(Tag(), df, &borderimage[4 * 12 + 4], 12)); +} + +template +static void MaltaDiffMapT(const Tag tag, const ImageF& lum0, const ImageF& lum1, + const double w_0gt1, const double w_0lt1, + const double norm1, const double len, + const double mulli, ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + JXL_DASSERT(SameSize(lum0, lum1) && SameSize(lum0, *diffs)); + const size_t xsize_ = lum0.xsize(); + const size_t ysize_ = lum0.ysize(); + + const float kWeight0 = 0.5; + const float kWeight1 = 0.33; + + const double w_pre0gt1 = mulli * std::sqrt(kWeight0 * w_0gt1) / (len * 2 + 1); + const double w_pre0lt1 = mulli * std::sqrt(kWeight1 * w_0lt1) / (len * 2 + 1); + const float norm2_0gt1 = w_pre0gt1 * norm1; + const float norm2_0lt1 = w_pre0lt1 * norm1; + + for (size_t y = 0; y < ysize_; ++y) { + const float* HWY_RESTRICT row0 = lum0.ConstRow(y); + const float* HWY_RESTRICT row1 = lum1.ConstRow(y); + float* HWY_RESTRICT row_diffs = diffs->Row(y); + for (size_t x = 0; x < xsize_; ++x) { + const float absval = 0.5f * (std::abs(row0[x]) + std::abs(row1[x])); + const float diff = row0[x] - row1[x]; + const float scaler = norm2_0gt1 / (static_cast(norm1) + absval); + + // Primary symmetric quadratic objective. + row_diffs[x] = scaler * diff; + + const float scaler2 = norm2_0lt1 / (static_cast(norm1) + absval); + const double fabs0 = std::fabs(row0[x]); + + // Secondary half-open quadratic objectives. 
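// (Illustrative note, not part of the vendored file: the primary term above is
// a contrast-normalized difference, K * diff / (norm1 + absval): the same
// absolute difference counts for less where the local signal is strong, a
// simple contrast-masking model. With norm1 = 5 and K = 1, diff = 1 scores
// 1/5 = 0.2 on a flat background (absval = 0) but only 1/15 ~= 0.067 when
// absval = 10. The branches below add one-sided penalties whenever the
// distorted value row1[x] falls outside the band [0.55, 1.05] * |row0[x]|.)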
+ const double too_small = 0.55 * fabs0; + const double too_big = 1.05 * fabs0; + + if (row0[x] < 0) { + if (row1[x] > -too_small) { + double impact = scaler2 * (row1[x] + too_small); + if (diff < 0) { + row_diffs[x] -= impact; + } else { + row_diffs[x] += impact; + } + } else if (row1[x] < -too_big) { + double impact = scaler2 * (-row1[x] - too_big); + if (diff < 0) { + row_diffs[x] -= impact; + } else { + row_diffs[x] += impact; + } + } + } else { + if (row1[x] < too_small) { + double impact = scaler2 * (too_small - row1[x]); + if (diff < 0) { + row_diffs[x] -= impact; + } else { + row_diffs[x] += impact; + } + } else if (row1[x] > too_big) { + double impact = scaler2 * (row1[x] - too_big); + if (diff < 0) { + row_diffs[x] -= impact; + } else { + row_diffs[x] += impact; + } + } + } + } + } + + size_t y0 = 0; + // Top + for (; y0 < 4; ++y0) { + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->PlaneRow(c, y0); + for (size_t x0 = 0; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + } + + const HWY_FULL(float) df; + const size_t aligned_x = std::max(size_t(4), Lanes(df)); + const intptr_t stride = diffs->PixelsPerRow(); + + // Middle + for (; y0 < ysize_ - 4; ++y0) { + const float* BUTTERAUGLI_RESTRICT row_in = diffs->ConstRow(y0); + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->PlaneRow(c, y0); + size_t x0 = 0; + for (; x0 < aligned_x; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + for (; x0 + Lanes(df) + 4 <= xsize_; x0 += Lanes(df)) { + auto diff = Load(df, row_diff + x0); + diff += MaltaUnit(Tag(), df, row_in + x0, stride); + Store(diff, df, row_diff + x0); + } + + for (; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + } + + // Bottom + for (; y0 < ysize_; ++y0) { + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->PlaneRow(c, y0); + for (size_t x0 = 0; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit(*diffs, x0, y0); + } + } +} + +// Need non-template wrapper functions for HWY_EXPORT. +void MaltaDiffMap(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, const double len, + const double mulli, ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + MaltaDiffMapT(MaltaTag(), lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, + diffs, block_diff_ac, c); +} + +void MaltaDiffMapLF(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, const double len, + const double mulli, ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + MaltaDiffMapT(MaltaTagLF(), lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, + diffs, block_diff_ac, c); +} + +void DiffPrecompute(const ImageF& xyb, float mul, float bias_arg, ImageF* out) { + PROFILER_FUNC; + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + const float bias = mul * bias_arg; + const float sqrt_bias = sqrt(bias); + for (size_t y = 0; y < ysize; ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = xyb.Row(y); + float* BUTTERAUGLI_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + // kBias makes sqrt behave more linearly. 
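// (Illustrative note, not part of the vendored file: for small |v|,
// sqrt(mul * |v| + bias) - sqrt(bias) ~= (mul / (2 * sqrt(bias))) * |v| by
// first-order Taylor expansion, so the bias replaces sqrt's infinite slope at
// zero with a finite, nearly linear response while large |v| still gets
// sqrt-like compression. With the Mask() constants below, mul ~= 6.19 and
// bias = mul * 12.61 ~= 78.1, giving a slope near zero of about
// 6.19 / (2 * sqrt(78.1)) ~= 0.35.)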
+ row_out[x] = sqrt(mul * std::abs(row_in[x]) + bias) - sqrt_bias; + } + } +} + +// std::log(80.0) / std::log(255.0); +constexpr float kIntensityTargetNormalizationHack = 0.79079917404f; +static const float kInternalGoodQualityThreshold = + 17.1984479671f * kIntensityTargetNormalizationHack; +static const float kGlobalScale = 1.0 / kInternalGoodQualityThreshold; + +void StoreMin3(const float v, float& min0, float& min1, float& min2) { + if (v < min2) { + if (v < min0) { + min2 = min1; + min1 = min0; + min0 = v; + } else if (v < min1) { + min2 = min1; + min1 = v; + } else { + min2 = v; + } + } +} + +// Look for smooth areas near the area of degradation. +// If the areas are generally smooth, don't do masking. +void FuzzyErosion(const ImageF& from, ImageF* to) { + const size_t xsize = from.xsize(); + const size_t ysize = from.ysize(); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + float min0 = from.Row(y)[x]; + float min1 = 2 * min0; + float min2 = min1; + if (x >= 3) { + float v = from.Row(y)[x - 3]; + StoreMin3(v, min0, min1, min2); + if (y >= 3) { + float v = from.Row(y - 3)[x - 3]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - 3) { + float v = from.Row(y + 3)[x - 3]; + StoreMin3(v, min0, min1, min2); + } + } + if (x < xsize - 3) { + float v = from.Row(y)[x + 3]; + StoreMin3(v, min0, min1, min2); + if (y >= 3) { + float v = from.Row(y - 3)[x + 3]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - 3) { + float v = from.Row(y + 3)[x + 3]; + StoreMin3(v, min0, min1, min2); + } + } + if (y >= 3) { + float v = from.Row(y - 3)[x]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - 3) { + float v = from.Row(y + 3)[x]; + StoreMin3(v, min0, min1, min2); + } + to->Row(y)[x] = (0.45f * min0 + 0.3f * min1 + 0.25f * min2); + } + } +} + +// Compute values of local frequency and dc masking based on the activity +// in the two images. img_diff_ac may be null. +void Mask(const ImageF& mask0, const ImageF& mask1, + const ButteraugliParams& params, BlurTemp* blur_temp, + ImageF* BUTTERAUGLI_RESTRICT mask, + ImageF* BUTTERAUGLI_RESTRICT diff_ac) { + // Only X and Y components are involved in masking. B's influence + // is considered less important in the high frequency area, and we + // don't model masking from lower frequency signals. + PROFILER_FUNC; + const size_t xsize = mask0.xsize(); + const size_t ysize = mask0.ysize(); + *mask = ImageF(xsize, ysize); + static const float kMul = 6.19424080439; + static const float kBias = 12.61050594197; + static const float kRadius = 2.7; + ImageF diff0(xsize, ysize); + ImageF diff1(xsize, ysize); + ImageF blurred0(xsize, ysize); + ImageF blurred1(xsize, ysize); + DiffPrecompute(mask0, kMul, kBias, &diff0); + DiffPrecompute(mask1, kMul, kBias, &diff1); + Blur(diff0, kRadius, params, blur_temp, &blurred0); + FuzzyErosion(blurred0, &diff0); + Blur(diff1, kRadius, params, blur_temp, &blurred1); + FuzzyErosion(blurred1, &diff1); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + mask->Row(y)[x] = diff1.Row(y)[x]; + if (diff_ac != nullptr) { + static const float kMaskToErrorMul = 10.0; + float diff = blurred0.Row(y)[x] - blurred1.Row(y)[x]; + diff_ac->Row(y)[x] += kMaskToErrorMul * diff * diff; + } + } + } +} + +// `diff_ac` may be null.
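// Illustrative toy usage, not part of the vendored file: StoreMin3 maintains
// the three smallest values seen so far, and FuzzyErosion above seeds it with
// (v, 2v, 2v) from the center pixel before offering up to eight neighbors at
// distance 3. The 0.45/0.30/0.25 blend is a soft minimum: it erodes isolated
// mask spikes yet still responds when most of the neighborhood is active.
void SoftMinDemo() {
  float min0 = 9.0f, min1 = 2 * min0, min2 = min1;  // seeded as in FuzzyErosion
  for (float v : {7.0f, 11.0f, 3.0f}) StoreMin3(v, min0, min1, min2);
  // Now min0 = 3, min1 = 7, min2 = 9; blended soft minimum:
  // 0.45f * 3 + 0.3f * 7 + 0.25f * 9 = 5.7f.
}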
+void MaskPsychoImage(const PsychoImage& pi0, const PsychoImage& pi1, + const size_t xsize, const size_t ysize, + const ButteraugliParams& params, Image3F* temp, + BlurTemp* blur_temp, ImageF* BUTTERAUGLI_RESTRICT mask, + ImageF* BUTTERAUGLI_RESTRICT diff_ac) { + ImageF mask0(xsize, ysize); + ImageF mask1(xsize, ysize); + static const float muls[3] = { + 8.75000241361f, + 0.620978104816f, + 0.307585098253f, + }; + // Silly and unoptimized approach here. TODO(jyrki): rework this. + for (size_t y = 0; y < ysize; ++y) { + const float* BUTTERAUGLI_RESTRICT row_y_hf0 = pi0.hf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_y_hf1 = pi1.hf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_y_uhf0 = pi0.uhf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_y_uhf1 = pi1.uhf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_hf0 = pi0.hf[0].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_hf1 = pi1.hf[0].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_uhf0 = pi0.uhf[0].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_uhf1 = pi1.uhf[0].Row(y); + float* BUTTERAUGLI_RESTRICT row0 = mask0.Row(y); + float* BUTTERAUGLI_RESTRICT row1 = mask1.Row(y); + for (size_t x = 0; x < xsize; ++x) { + float xdiff0 = (row_x_uhf0[x] + row_x_hf0[x]) * muls[0]; + float xdiff1 = (row_x_uhf1[x] + row_x_hf1[x]) * muls[0]; + float ydiff0 = row_y_uhf0[x] * muls[1] + row_y_hf0[x] * muls[2]; + float ydiff1 = row_y_uhf1[x] * muls[1] + row_y_hf1[x] * muls[2]; + row0[x] = xdiff0 * xdiff0 + ydiff0 * ydiff0; + row0[x] = sqrt(row0[x]); + row1[x] = xdiff1 * xdiff1 + ydiff1 * ydiff1; + row1[x] = sqrt(row1[x]); + } + } + Mask(mask0, mask1, params, blur_temp, mask, diff_ac); +} + +double MaskY(double delta) { + static const double offset = 0.829591754942; + static const double scaler = 0.451936922203; + static const double mul = 2.5485944793; + const double c = mul / ((scaler * delta) + offset); + const double retval = kGlobalScale * (1.0 + c); + return retval * retval; +} + +double MaskDcY(double delta) { + static const double offset = 0.20025578522; + static const double scaler = 3.87449418804; + static const double mul = 0.505054525019; + const double c = mul / ((scaler * delta) + offset); + const double retval = kGlobalScale * (1.0 + c); + return retval * retval; +} + +inline float MaskColor(const float color[3], const float mask) { + return color[0] * mask + color[1] * mask + color[2] * mask; +} + +// Diffmap := sqrt of sum{diff images multiplied by X and Y/B masks} +void CombineChannelsToDiffmap(const ImageF& mask, const Image3F& block_diff_dc, + const Image3F& block_diff_ac, float xmul, + ImageF* result) { + PROFILER_FUNC; + JXL_CHECK(SameSize(mask, *result)); + size_t xsize = mask.xsize(); + size_t ysize = mask.ysize(); + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_out = result->Row(y); + for (size_t x = 0; x < xsize; ++x) { + float val = mask.Row(y)[x]; + float maskval = MaskY(val); + float dc_maskval = MaskDcY(val); + float diff_dc[3]; + float diff_ac[3]; + for (int i = 0; i < 3; ++i) { + diff_dc[i] = block_diff_dc.PlaneRow(i, y)[x]; + diff_ac[i] = block_diff_ac.PlaneRow(i, y)[x]; + } + diff_ac[0] *= xmul; + diff_dc[0] *= xmul; + row_out[x] = + sqrt(MaskColor(diff_dc, dc_maskval) + MaskColor(diff_ac, maskval)); + } + } +} + +// Adds weighted L2 difference between i0 and i1 to diffmap.
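// Illustrative note, not part of the vendored file: per pixel, the diffmap
// value computed above is
//   sqrt( MaskDcY(m) * (xmul*dc[0] + dc[1] + dc[2])
//       + MaskY(m)   * (xmul*ac[0] + ac[1] + ac[2]) )
// where m is the local masking value. MaskY and MaskDcY are decreasing in m,
// so differences in busy, highly textured regions are attenuated. MaskColor
// is just a channel sum under a shared weight, equivalent to:
inline float MaskColorFactored(const float color[3], const float mask) {
  return (color[0] + color[1] + color[2]) * mask;
}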
+static void L2Diff(const ImageF& i0, const ImageF& i1, const float w, + Image3F* BUTTERAUGLI_RESTRICT diffmap, size_t c) { + if (w == 0) return; + + const HWY_FULL(float) d; + const auto weight = Set(d, w); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.ConstRow(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->PlaneRow(c, y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto diff = Load(d, row0 + x) - Load(d, row1 + x); + const auto diff2 = diff * diff; + const auto prev = Load(d, row_diff + x); + Store(MulAdd(diff2, weight, prev), d, row_diff + x); + } + } +} + +// Initializes diffmap to the weighted L2 difference between i0 and i1. +static void SetL2Diff(const ImageF& i0, const ImageF& i1, const float w, + Image3F* BUTTERAUGLI_RESTRICT diffmap, size_t c) { + if (w == 0) return; + + const HWY_FULL(float) d; + const auto weight = Set(d, w); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.ConstRow(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->PlaneRow(c, y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto diff = Load(d, row0 + x) - Load(d, row1 + x); + const auto diff2 = diff * diff; + Store(diff2 * weight, d, row_diff + x); + } + } +} + +// i0 is the original image. +// i1 is the deformed copy. +static void L2DiffAsymmetric(const ImageF& i0, const ImageF& i1, float w_0gt1, + float w_0lt1, + Image3F* BUTTERAUGLI_RESTRICT diffmap, size_t c) { + if (w_0gt1 == 0 && w_0lt1 == 0) { + return; + } + + const HWY_FULL(float) d; + const auto vw_0gt1 = Set(d, w_0gt1 * 0.8); + const auto vw_0lt1 = Set(d, w_0lt1 * 0.8); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.Row(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.Row(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->PlaneRow(c, y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto val0 = Load(d, row0 + x); + const auto val1 = Load(d, row1 + x); + + // Primary symmetric quadratic objective. + const auto diff = val0 - val1; + auto total = MulAdd(diff * diff, vw_0gt1, Load(d, row_diff + x)); + + // Secondary half-open quadratic objectives. + const auto fabs0 = Abs(val0); + const auto too_small = Set(d, 0.4) * fabs0; + const auto too_big = fabs0; + + const auto if_neg = + IfThenElse(val1 > Neg(too_small), val1 + too_small, + IfThenElseZero(val1 < Neg(too_big), Neg(val1) - too_big)); + const auto if_pos = + IfThenElse(val1 < too_small, too_small - val1, + IfThenElseZero(val1 > too_big, val1 - too_big)); + const auto v = IfThenElse(val0 < Zero(d), if_neg, if_pos); + total += vw_0lt1 * v * v; + Store(total, d, row_diff + x); + } + } +} + +// A simple HDR compatible gamma function. +template +V Gamma(const DF df, V v) { + // ln(2) constant folded in because we want std::log but have FastLog2f. + const auto kRetMul = Set(df, 19.245013259874995f * 0.693147180559945f); + const auto kRetAdd = Set(df, -23.16046239805755); + // This should happen rarely, but may lead to a NaN in log, which is + // undesirable. Since negative photons don't exist we solve the NaNs by + // clamping here. + v = ZeroIfNegative(v); + + const auto biased = v + Set(df, 9.9710635769299145); + const auto log = FastLog2f(df, biased); + // We could fold this into a custom Log2 polynomial, but there would be + // relatively little gain. 
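// (Illustrative note, not part of the vendored file: the factor
// 0.693147180559945 folded into kRetMul above is ln(2). FastLog2f returns a
// base-2 logarithm, so kRetMul * log2(x) = 19.245... * ln(x), matching a
// gamma curve that was fitted in terms of the natural logarithm.)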
+ return MulAdd(kRetMul, log, kRetAdd); +} + +template +BUTTERAUGLI_INLINE void OpsinAbsorbance(const DF df, const V& in0, const V& in1, + const V& in2, V* JXL_RESTRICT out0, + V* JXL_RESTRICT out1, + V* JXL_RESTRICT out2) { + // https://en.wikipedia.org/wiki/Photopsin absorbance modeling. + static const double mixi0 = 0.29956550340058319; + static const double mixi1 = 0.63373087833825936; + static const double mixi2 = 0.077705617820981968; + static const double mixi3 = 1.7557483643287353; + static const double mixi4 = 0.22158691104574774; + static const double mixi5 = 0.69391388044116142; + static const double mixi6 = 0.0987313588422; + static const double mixi7 = 1.7557483643287353; + static const double mixi8 = 0.02; + static const double mixi9 = 0.02; + static const double mixi10 = 0.20480129041026129; + static const double mixi11 = 12.226454707163354; + + const V mix0 = Set(df, mixi0); + const V mix1 = Set(df, mixi1); + const V mix2 = Set(df, mixi2); + const V mix3 = Set(df, mixi3); + const V mix4 = Set(df, mixi4); + const V mix5 = Set(df, mixi5); + const V mix6 = Set(df, mixi6); + const V mix7 = Set(df, mixi7); + const V mix8 = Set(df, mixi8); + const V mix9 = Set(df, mixi9); + const V mix10 = Set(df, mixi10); + const V mix11 = Set(df, mixi11); + + *out0 = mix0 * in0 + mix1 * in1 + mix2 * in2 + mix3; + *out1 = mix4 * in0 + mix5 * in1 + mix6 * in2 + mix7; + *out2 = mix8 * in0 + mix9 * in1 + mix10 * in2 + mix11; + + if (Clamp) { + *out0 = Max(*out0, mix3); + *out1 = Max(*out1, mix7); + *out2 = Max(*out2, mix11); + } +} + +// `blurred` is a temporary image used inside this function and not returned. +Image3F OpsinDynamicsImage(const Image3F& rgb, const ButteraugliParams& params, + Image3F* blurred, BlurTemp* blur_temp) { + PROFILER_FUNC; + Image3F xyb(rgb.xsize(), rgb.ysize()); + const double kSigma = 1.2; + Blur(rgb.Plane(0), kSigma, params, blur_temp, &blurred->Plane(0)); + Blur(rgb.Plane(1), kSigma, params, blur_temp, &blurred->Plane(1)); + Blur(rgb.Plane(2), kSigma, params, blur_temp, &blurred->Plane(2)); + const HWY_FULL(float) df; + const auto intensity_target_multiplier = Set(df, params.intensity_target); + for (size_t y = 0; y < rgb.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_r = rgb.ConstPlaneRow(0, y); + const float* BUTTERAUGLI_RESTRICT row_g = rgb.ConstPlaneRow(1, y); + const float* BUTTERAUGLI_RESTRICT row_b = rgb.ConstPlaneRow(2, y); + const float* BUTTERAUGLI_RESTRICT row_blurred_r = + blurred->ConstPlaneRow(0, y); + const float* BUTTERAUGLI_RESTRICT row_blurred_g = + blurred->ConstPlaneRow(1, y); + const float* BUTTERAUGLI_RESTRICT row_blurred_b = + blurred->ConstPlaneRow(2, y); + float* BUTTERAUGLI_RESTRICT row_out_x = xyb.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_out_y = xyb.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_out_b = xyb.PlaneRow(2, y); + const auto min = Set(df, 1e-4f); + for (size_t x = 0; x < rgb.xsize(); x += Lanes(df)) { + auto sensitivity0 = Undefined(df); + auto sensitivity1 = Undefined(df); + auto sensitivity2 = Undefined(df); + { + // Calculate sensitivity based on the smoothed image gamma derivative. 
+ auto pre_mixed0 = Undefined(df); + auto pre_mixed1 = Undefined(df); + auto pre_mixed2 = Undefined(df); + OpsinAbsorbance( + df, Load(df, row_blurred_r + x) * intensity_target_multiplier, + Load(df, row_blurred_g + x) * intensity_target_multiplier, + Load(df, row_blurred_b + x) * intensity_target_multiplier, + &pre_mixed0, &pre_mixed1, &pre_mixed2); + pre_mixed0 = Max(pre_mixed0, min); + pre_mixed1 = Max(pre_mixed1, min); + pre_mixed2 = Max(pre_mixed2, min); + sensitivity0 = Gamma(df, pre_mixed0) / pre_mixed0; + sensitivity1 = Gamma(df, pre_mixed1) / pre_mixed1; + sensitivity2 = Gamma(df, pre_mixed2) / pre_mixed2; + sensitivity0 = Max(sensitivity0, min); + sensitivity1 = Max(sensitivity1, min); + sensitivity2 = Max(sensitivity2, min); + } + auto cur_mixed0 = Undefined(df); + auto cur_mixed1 = Undefined(df); + auto cur_mixed2 = Undefined(df); + OpsinAbsorbance(df, + Load(df, row_r + x) * intensity_target_multiplier, + Load(df, row_g + x) * intensity_target_multiplier, + Load(df, row_b + x) * intensity_target_multiplier, + &cur_mixed0, &cur_mixed1, &cur_mixed2); + cur_mixed0 *= sensitivity0; + cur_mixed1 *= sensitivity1; + cur_mixed2 *= sensitivity2; + // This is a kludge. The negative values should be zeroed away before + // blurring. Ideally there would be no negative values in the first place. + const auto min01 = Set(df, 1.7557483643287353f); + const auto min2 = Set(df, 12.226454707163354f); + cur_mixed0 = Max(cur_mixed0, min01); + cur_mixed1 = Max(cur_mixed1, min01); + cur_mixed2 = Max(cur_mixed2, min2); + + Store(cur_mixed0 - cur_mixed1, df, row_out_x + x); + Store(cur_mixed0 + cur_mixed1, df, row_out_y + x); + Store(cur_mixed2, df, row_out_b + x); + } + } + return xyb; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(SeparateFrequencies); // Local function. +HWY_EXPORT(MaskPsychoImage); // Local function. +HWY_EXPORT(L2DiffAsymmetric); // Local function. +HWY_EXPORT(L2Diff); // Local function. +HWY_EXPORT(SetL2Diff); // Local function. +HWY_EXPORT(CombineChannelsToDiffmap); // Local function. +HWY_EXPORT(MaltaDiffMap); // Local function. +HWY_EXPORT(MaltaDiffMapLF); // Local function. +HWY_EXPORT(OpsinDynamicsImage); // Local function. 
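// Illustrative sketch, not part of the vendored file: HWY_EXPORT /
// HWY_DYNAMIC_DISPATCH is highway's runtime-dispatch idiom. foreach_target.h
// recompiles this translation unit once per SIMD target, HWY_EXPORT records a
// table of the per-target symbols, and the dispatch macro picks the best
// variant for the running CPU on first use; wrapper name here is hypothetical:
void DispatchSeparateFrequencies(size_t xsize, size_t ysize,
                                 const ButteraugliParams& params,
                                 BlurTemp* blur_temp, const Image3F& xyb,
                                 PsychoImage& ps) {
  HWY_DYNAMIC_DISPATCH(SeparateFrequencies)
  (xsize, ysize, params, blur_temp, xyb, ps);
}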
+ +#if BUTTERAUGLI_ENABLE_CHECKS + +static inline bool IsNan(const float x) { + uint32_t bits; + memcpy(&bits, &x, sizeof(bits)); + const uint32_t bitmask_exp = 0x7F800000; + return (bits & bitmask_exp) == bitmask_exp && (bits & 0x7FFFFF); +} + +static inline bool IsNan(const double x) { + uint64_t bits; + memcpy(&bits, &x, sizeof(bits)); + return (0x7ff0000000000001ULL <= bits && bits <= 0x7fffffffffffffffULL) || + (0xfff0000000000001ULL <= bits && bits <= 0xffffffffffffffffULL); +} + +static inline void CheckImage(const ImageF& image, const char* name) { + PROFILER_FUNC; + for (size_t y = 0; y < image.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row = image.Row(y); + for (size_t x = 0; x < image.xsize(); ++x) { + if (IsNan(row[x])) { + printf("NAN: Image %s @ %zu,%zu (of %zu,%zu)\n", name, x, y, + image.xsize(), image.ysize()); + exit(1); + } + } + } +} + +#define CHECK_NAN(x, str) \ + do { \ + if (IsNan(x)) { \ + printf("%d: %s\n", __LINE__, str); \ + abort(); \ + } \ + } while (0) + +#define CHECK_IMAGE(image, name) CheckImage(image, name) + +#else // BUTTERAUGLI_ENABLE_CHECKS + +#define CHECK_NAN(x, str) +#define CHECK_IMAGE(image, name) + +#endif // BUTTERAUGLI_ENABLE_CHECKS + +// Calculate a 2x2 subsampled image for purposes of recursive butteraugli at +// multiresolution. +static Image3F SubSample2x(const Image3F& in) { + size_t xs = (in.xsize() + 1) / 2; + size_t ys = (in.ysize() + 1) / 2; + Image3F retval(xs, ys); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ys; ++y) { + for (size_t x = 0; x < xs; ++x) { + retval.PlaneRow(c, y)[x] = 0; + } + } + } + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < in.ysize(); ++y) { + for (size_t x = 0; x < in.xsize(); ++x) { + retval.PlaneRow(c, y / 2)[x / 2] += 0.25f * in.PlaneRow(c, y)[x]; + } + } + if ((in.xsize() & 1) != 0) { + for (size_t y = 0; y < retval.ysize(); ++y) { + size_t last_column = retval.xsize() - 1; + retval.PlaneRow(c, y)[last_column] *= 2.0f; + } + } + if ((in.ysize() & 1) != 0) { + for (size_t x = 0; x < retval.xsize(); ++x) { + size_t last_row = retval.ysize() - 1; + retval.PlaneRow(c, last_row)[x] *= 2.0f; + } + } + } + return retval; +} + +// Supersample src by 2x and add it to dest. +static void AddSupersampled2x(const ImageF& src, float w, ImageF& dest) { + for (size_t y = 0; y < dest.ysize(); ++y) { + for (size_t x = 0; x < dest.xsize(); ++x) { + // There will be less errors from the more averaged images. + // We take it into account to some extent using a scaler. + static const double kHeuristicMixingValue = 0.3; + dest.Row(y)[x] *= 1.0 - kHeuristicMixingValue * w; + dest.Row(y)[x] += w * src.Row(y / 2)[x / 2]; + } + } +} + +Image3F* ButteraugliComparator::Temp() const { + bool was_in_use = temp_in_use_.test_and_set(std::memory_order_acq_rel); + JXL_ASSERT(!was_in_use); + (void)was_in_use; + return &temp_; +} + +void ButteraugliComparator::ReleaseTemp() const { + temp_in_use_.clear(std::memory_order_acq_rel); +} + +ButteraugliComparator::ButteraugliComparator(const Image3F& rgb0, + const ButteraugliParams& params) + : xsize_(rgb0.xsize()), + ysize_(rgb0.ysize()), + params_(params), + temp_(xsize_, ysize_) { + if (xsize_ < 8 || ysize_ < 8) { + return; + } + + Image3F xyb0 = HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage)(rgb0, params, Temp(), + &blur_temp_); + ReleaseTemp(); + HWY_DYNAMIC_DISPATCH(SeparateFrequencies) + (xsize_, ysize_, params_, &blur_temp_, xyb0, pi0_); + + // Awful recursive construction of samples of different resolution. 
+ // This is an after-thought and possibly somewhat parallel in + // functionality with the PsychoImage multi-resolution approach. + sub_.reset(new ButteraugliComparator(SubSample2x(rgb0), params)); +} + +void ButteraugliComparator::Mask(ImageF* BUTTERAUGLI_RESTRICT mask) const { + HWY_DYNAMIC_DISPATCH(MaskPsychoImage) + (pi0_, pi0_, xsize_, ysize_, params_, Temp(), &blur_temp_, mask, nullptr); + ReleaseTemp(); +} + +void ButteraugliComparator::Diffmap(const Image3F& rgb1, ImageF& result) const { + PROFILER_FUNC; + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&result); + return; + } + const Image3F xyb1 = HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage)( + rgb1, params_, Temp(), &blur_temp_); + ReleaseTemp(); + DiffmapOpsinDynamicsImage(xyb1, result); + if (sub_) { + if (sub_->xsize_ < 8 || sub_->ysize_ < 8) { + return; + } + const Image3F sub_xyb = HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage)( + SubSample2x(rgb1), params_, sub_->Temp(), &sub_->blur_temp_); + sub_->ReleaseTemp(); + ImageF subresult; + sub_->DiffmapOpsinDynamicsImage(sub_xyb, subresult); + AddSupersampled2x(subresult, 0.5, result); + } +} + +void ButteraugliComparator::DiffmapOpsinDynamicsImage(const Image3F& xyb1, + ImageF& result) const { + PROFILER_FUNC; + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&result); + return; + } + PsychoImage pi1; + HWY_DYNAMIC_DISPATCH(SeparateFrequencies) + (xsize_, ysize_, params_, &blur_temp_, xyb1, pi1); + result = ImageF(xsize_, ysize_); + DiffmapPsychoImage(pi1, result); +} + +namespace { + +void MaltaDiffMap(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + PROFILER_FUNC; + const double len = 3.75; + static const double mulli = 0.39905817637; + HWY_DYNAMIC_DISPATCH(MaltaDiffMap) + (lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, diffs, block_diff_ac, c); +} + +void MaltaDiffMapLF(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + PROFILER_FUNC; + const double len = 3.75; + static const double mulli = 0.611612573796; + HWY_DYNAMIC_DISPATCH(MaltaDiffMapLF) + (lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, diffs, block_diff_ac, c); +} + +} // namespace + +void ButteraugliComparator::DiffmapPsychoImage(const PsychoImage& pi1, + ImageF& diffmap) const { + PROFILER_FUNC; + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&diffmap); + return; + } + + const float hf_asymmetry_ = params_.hf_asymmetry; + const float xmul_ = params_.xmul; + + ImageF diffs(xsize_, ysize_); + Image3F block_diff_ac(xsize_, ysize_); + ZeroFillImage(&block_diff_ac); + static const double wUhfMalta = 1.10039032555; + static const double norm1Uhf = 71.7800275169; + MaltaDiffMap(pi0_.uhf[1], pi1.uhf[1], wUhfMalta * hf_asymmetry_, + wUhfMalta / hf_asymmetry_, norm1Uhf, &diffs, &block_diff_ac, 1); + + static const double wUhfMaltaX = 173.5; + static const double norm1UhfX = 5.0; + MaltaDiffMap(pi0_.uhf[0], pi1.uhf[0], wUhfMaltaX * hf_asymmetry_, + wUhfMaltaX / hf_asymmetry_, norm1UhfX, &diffs, &block_diff_ac, + 0); + + static const double wHfMalta = 18.7237414387; + static const double norm1Hf = 4498534.45232; + MaltaDiffMapLF(pi0_.hf[1], pi1.hf[1], wHfMalta * std::sqrt(hf_asymmetry_), + wHfMalta / std::sqrt(hf_asymmetry_), norm1Hf, &diffs, + &block_diff_ac, 1); + + static const double wHfMaltaX = 6923.99476109; + static const double norm1HfX = 8051.15833247; + 
MaltaDiffMapLF(pi0_.hf[0], pi1.hf[0], wHfMaltaX * std::sqrt(hf_asymmetry_), + wHfMaltaX / std::sqrt(hf_asymmetry_), norm1HfX, &diffs, + &block_diff_ac, 0); + + static const double wMfMalta = 37.0819870399; + static const double norm1Mf = 130262059.556; + MaltaDiffMapLF(pi0_.mf.Plane(1), pi1.mf.Plane(1), wMfMalta, wMfMalta, norm1Mf, + &diffs, &block_diff_ac, 1); + + static const double wMfMaltaX = 8246.75321353; + static const double norm1MfX = 1009002.70582; + MaltaDiffMapLF(pi0_.mf.Plane(0), pi1.mf.Plane(0), wMfMaltaX, wMfMaltaX, + norm1MfX, &diffs, &block_diff_ac, 0); + + static const double wmul[9] = { + 400.0, + 1.50815703118, + 0, + 2150.0, + 10.6195433239, + 16.2176043152, + 29.2353797994, + 0.844626970982, + 0.703646627719, + }; + Image3F block_diff_dc(xsize_, ysize_); + for (size_t c = 0; c < 3; ++c) { + if (c < 2) { // No blue channel error accumulated at HF. + HWY_DYNAMIC_DISPATCH(L2DiffAsymmetric) + (pi0_.hf[c], pi1.hf[c], wmul[c] * hf_asymmetry_, wmul[c] / hf_asymmetry_, + &block_diff_ac, c); + } + HWY_DYNAMIC_DISPATCH(L2Diff) + (pi0_.mf.Plane(c), pi1.mf.Plane(c), wmul[3 + c], &block_diff_ac, c); + HWY_DYNAMIC_DISPATCH(SetL2Diff) + (pi0_.lf.Plane(c), pi1.lf.Plane(c), wmul[6 + c], &block_diff_dc, c); + } + + ImageF mask; + HWY_DYNAMIC_DISPATCH(MaskPsychoImage) + (pi0_, pi1, xsize_, ysize_, params_, Temp(), &blur_temp_, &mask, + &block_diff_ac.Plane(1)); + ReleaseTemp(); + + HWY_DYNAMIC_DISPATCH(CombineChannelsToDiffmap) + (mask, block_diff_dc, block_diff_ac, xmul_, &diffmap); +} + +double ButteraugliScoreFromDiffmap(const ImageF& diffmap, + const ButteraugliParams* params) { + PROFILER_FUNC; + // In approximate-border mode, skip pixels on the border likely to be affected + // by FastGauss' zero-valued-boundary behavior. The border is about half of + // the largest-diameter kernel (37x37 pixels), but only if the image is big. + size_t border = (params != nullptr && params->approximate_border) ? 8 : 0; + if (diffmap.xsize() <= 2 * border || diffmap.ysize() <= 2 * border) { + border = 0; + } + float retval = 0.0f; + for (size_t y = border; y < diffmap.ysize() - border; ++y) { + const float* BUTTERAUGLI_RESTRICT row = diffmap.ConstRow(y); + for (size_t x = border; x < diffmap.xsize() - border; ++x) { + retval = std::max(retval, row[x]); + } + } + return retval; +} + +bool ButteraugliDiffmap(const Image3F& rgb0, const Image3F& rgb1, + double hf_asymmetry, double xmul, ImageF& diffmap) { + ButteraugliParams params; + params.hf_asymmetry = hf_asymmetry; + params.xmul = xmul; + return ButteraugliDiffmap(rgb0, rgb1, params, diffmap); +} + +bool ButteraugliDiffmap(const Image3F& rgb0, const Image3F& rgb1, + const ButteraugliParams& params, ImageF& diffmap) { + PROFILER_FUNC; + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + if (xsize < 1 || ysize < 1) { + return JXL_FAILURE("Zero-sized image"); + } + if (!SameSize(rgb0, rgb1)) { + return JXL_FAILURE("Size mismatch"); + } + static const int kMax = 8; + if (xsize < kMax || ysize < kMax) { + // Butteraugli values for small (where xsize or ysize is smaller + // than 8 pixels) images are non-sensical, but most likely it is + // less disruptive to try to compute something than just give up. + // Temporarily extend the borders of the image to fit 8 x 8 size. + size_t xborder = xsize < kMax ? (kMax - xsize) / 2 : 0; + size_t yborder = ysize < kMax ? 
(kMax - ysize) / 2 : 0;
+    size_t xscaled = std::max<size_t>(kMax, xsize);
+    size_t yscaled = std::max<size_t>(kMax, ysize);
+    Image3F scaled0(xscaled, yscaled);
+    Image3F scaled1(xscaled, yscaled);
+    for (int i = 0; i < 3; ++i) {
+      for (size_t y = 0; y < yscaled; ++y) {
+        for (size_t x = 0; x < xscaled; ++x) {
+          size_t x2 =
+              std::min<size_t>(xsize - 1, std::max<int>(0, x - xborder));
+          size_t y2 =
+              std::min<size_t>(ysize - 1, std::max<int>(0, y - yborder));
+          scaled0.PlaneRow(i, y)[x] = rgb0.PlaneRow(i, y2)[x2];
+          scaled1.PlaneRow(i, y)[x] = rgb1.PlaneRow(i, y2)[x2];
+        }
+      }
+    }
+    ImageF diffmap_scaled;
+    const bool ok =
+        ButteraugliDiffmap(scaled0, scaled1, params, diffmap_scaled);
+    diffmap = ImageF(xsize, ysize);
+    for (size_t y = 0; y < ysize; ++y) {
+      for (size_t x = 0; x < xsize; ++x) {
+        diffmap.Row(y)[x] = diffmap_scaled.Row(y + yborder)[x + xborder];
+      }
+    }
+    return ok;
+  }
+  ButteraugliComparator butteraugli(rgb0, params);
+  butteraugli.Diffmap(rgb1, diffmap);
+  return true;
+}
+
+bool ButteraugliInterface(const Image3F& rgb0, const Image3F& rgb1,
+                          float hf_asymmetry, float xmul, ImageF& diffmap,
+                          double& diffvalue) {
+  ButteraugliParams params;
+  params.hf_asymmetry = hf_asymmetry;
+  params.xmul = xmul;
+  return ButteraugliInterface(rgb0, rgb1, params, diffmap, diffvalue);
+}
+
+bool ButteraugliInterface(const Image3F& rgb0, const Image3F& rgb1,
+                          const ButteraugliParams& params, ImageF& diffmap,
+                          double& diffvalue) {
+#if PROFILER_ENABLED
+  double t0 = Now();
+#endif
+  if (!ButteraugliDiffmap(rgb0, rgb1, params, diffmap)) {
+    return false;
+  }
+#if PROFILER_ENABLED
+  double t1 = Now();
+  const size_t mp = rgb0.xsize() * rgb0.ysize();
+  printf("diff MP/s %f\n", mp / (t1 - t0) * 1E-6);
+#endif
+  diffvalue = ButteraugliScoreFromDiffmap(diffmap, &params);
+  return true;
+}
+
+double ButteraugliFuzzyClass(double score) {
+  static const double fuzzy_width_up = 4.8;
+  static const double fuzzy_width_down = 4.8;
+  static const double m0 = 2.0;
+  static const double scaler = 0.7777;
+  double val;
+  if (score < 1.0) {
+    // val in [scaler .. 2.0]
+    val = m0 / (1.0 + exp((score - 1.0) * fuzzy_width_down));
+    val -= 1.0;           // from [1 .. 2] to [0 .. 1]
+    val *= 2.0 - scaler;  // from [0 .. 1] to [0 .. 2.0 - scaler]
+    val += scaler;        // from [0 .. 2.0 - scaler] to [scaler .. 2.0]
+  } else {
+    // val in [0 .. scaler]
+    val = m0 / (1.0 + exp((score - 1.0) * fuzzy_width_up));
+    val *= scaler;
+  }
+  return val;
+}
+
+// #define PRINT_OUT_NORMALIZATION
+
+double ButteraugliFuzzyInverse(double seek) {
+  double pos = 0;
+  // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter)
+  for (double range = 1.0; range >= 1e-10; range *= 0.5) {
+    double cur = ButteraugliFuzzyClass(pos);
+    if (cur < seek) {
+      pos -= range;
+    } else {
+      pos += range;
+    }
+  }
+#ifdef PRINT_OUT_NORMALIZATION
+  if (seek == 1.0) {
+    fprintf(stderr, "Fuzzy inverse %g\n", pos);
+  }
+#endif
+  return pos;
+}
+
+#ifdef PRINT_OUT_NORMALIZATION
+static double print_out_normalization = ButteraugliFuzzyInverse(1.0);
+#endif
+
+namespace {
+
+void ScoreToRgb(double score, double good_threshold, double bad_threshold,
+                float rgb[3]) {
+  double heatmap[12][3] = {
+      {0, 0, 0},       {0, 0, 1},
+      {0, 1, 1},       {0, 1, 0},  // Good level
+      {1, 1, 0},       {1, 0, 0},  // Bad level
+      {1, 0, 1},       {0.5, 0.5, 1.0},
+      {1.0, 0.5, 0.5},  // Pastel colors for the very bad quality range.
+      {1.0, 1.0, 0.5}, {1, 1, 1},
+      {1, 1, 1},  // Last color repeated to have a solid range of white.
+  };
+  if (score < good_threshold) {
+    score = (score / good_threshold) * 0.3;
+  } else if (score < bad_threshold) {
+    score = 0.3 +
+            (score - good_threshold) / (bad_threshold - good_threshold) * 0.15;
+  } else {
+    score = 0.45 + (score - bad_threshold) / (bad_threshold * 12) * 0.5;
+  }
+  static const int kTableSize = sizeof(heatmap) / sizeof(heatmap[0]);
+  score = std::min<double>(std::max(score * (kTableSize - 1), 0.0),
+                           kTableSize - 2);
+  int ix = static_cast<int>(score);
+  ix = std::min(std::max(0, ix), kTableSize - 2);  // Handle NaN
+  double mix = score - ix;
+  for (int i = 0; i < 3; ++i) {
+    double v = mix * heatmap[ix + 1][i] + (1 - mix) * heatmap[ix][i];
+    rgb[i] = pow(v, 0.5);
+  }
+}
+
+}  // namespace
+
+Image3F CreateHeatMapImage(const ImageF& distmap, double good_threshold,
+                           double bad_threshold) {
+  Image3F heatmap(distmap.xsize(), distmap.ysize());
+  for (size_t y = 0; y < distmap.ysize(); ++y) {
+    const float* BUTTERAUGLI_RESTRICT row_distmap = distmap.ConstRow(y);
+    float* BUTTERAUGLI_RESTRICT row_h0 = heatmap.PlaneRow(0, y);
+    float* BUTTERAUGLI_RESTRICT row_h1 = heatmap.PlaneRow(1, y);
+    float* BUTTERAUGLI_RESTRICT row_h2 = heatmap.PlaneRow(2, y);
+    for (size_t x = 0; x < distmap.xsize(); ++x) {
+      const float d = row_distmap[x];
+      float rgb[3];
+      ScoreToRgb(d, good_threshold, bad_threshold, rgb);
+      row_h0[x] = rgb[0];
+      row_h1[x] = rgb[1];
+      row_h2[x] = rgb[2];
+    }
+  }
+  return heatmap;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.h b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.h
new file mode 100644
index 000000000000..68d3ac107452
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.h
@@ -0,0 +1,229 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Author: Jyrki Alakuijala (jyrki.alakuijala@gmail.com)
+
+#ifndef LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_
+#define LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_
+
+#include <math.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <atomic>
+#include <cmath>
+#include <cstdint>
+#include <memory>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_ops.h"
+
+#define BUTTERAUGLI_ENABLE_CHECKS 0
+#define BUTTERAUGLI_RESTRICT JXL_RESTRICT
+
+// This is the main interface to butteraugli image similarity
+// analysis function.
+
+namespace jxl {
+
+struct ButteraugliParams {
+  // Multiplier for penalizing new HF artifacts more than blurring away
+  // features. 1.0=neutral.
+  float hf_asymmetry = 1.0f;
+
+  // Multiplier for the psychovisual difference in the X channel.
+  float xmul = 1.0f;
+
+  // Number of nits that correspond to 1.0f input values.
+  float intensity_target = 80.0f;
+
+  bool approximate_border = false;
+};
+
+// ButteraugliInterface defines the public interface for butteraugli.
+//
+// It calculates the difference between rgb0 and rgb1.
+//
+// rgb0 and rgb1 contain the images. rgb0[c][px] and rgb1[c][px] contain
+// the red image for c == 0, green for c == 1, blue for c == 2. Location index
+// px is calculated as y * xsize + x.
+//
+// Values of pixels of images rgb0 and rgb1 need to be represented as raw
+// intensity. Most image formats store gamma corrected intensity in pixel
+// values. This gamma correction has to be removed, by applying the following
+// function to values in the 0-1 range:
+// butteraugli_val = pow(input_val, gamma);
+// A typical value of gamma is 2.2. It is usually stored in the image header.
+// Take care not to confuse that value with its inverse. The gamma value should
+// always be greater than one.
+// Butteraugli does not work as intended if the caller does not perform
+// gamma correction.
+//
+// hf_asymmetry is a multiplier for penalizing new HF artifacts more than
+// blurring away features (1.0 -> neutral).
+//
+// diffmap will contain an image of the size xsize * ysize, containing
+// localized differences for values px (indexed with the px the same as rgb0
+// and rgb1). diffvalue will give a global score of similarity.
+//
+// A diffvalue smaller than kButteraugliGood indicates that images can be
+// observed as the same image.
+// diffvalue larger than kButteraugliBad indicates that a difference between
+// the images can be observed.
+// A diffvalue between kButteraugliGood and kButteraugliBad indicates that
+// a subtle difference can be observed between the images.
+//
+// Returns true on success.
+bool ButteraugliInterface(const Image3F &rgb0, const Image3F &rgb1,
+                          const ButteraugliParams &params, ImageF &diffmap,
+                          double &diffvalue);
+
+// Deprecated (calls the previous function)
+bool ButteraugliInterface(const Image3F &rgb0, const Image3F &rgb1,
+                          float hf_asymmetry, float xmul, ImageF &diffmap,
+                          double &diffvalue);
+
+// Converts the butteraugli score into fuzzy class values that are continuous
+// at the class boundary. The class boundary location is based on human
+// raters, but the slope is arbitrary. Particularly, it does not reflect
+// the expectation value of probabilities of the human raters. It is just
+// expected that a smoother class boundary will allow for higher-level
+// optimization algorithms to work faster.
+//
+// Returns 2.0 for a perfect match, and 1.0 for 'ok', 0.0 for bad. Because the
+// scoring is fuzzy, a butteraugli score of 0.96 would return a class of
+// around 1.9.
+double ButteraugliFuzzyClass(double score);
+
+// Input values should be in range 0 (bad) to 2 (good). Use
+// kButteraugliNormalization as normalization.
+double ButteraugliFuzzyInverse(double seek);
+
+// Implementation details, don't use anything below or your code will
+// break in the future.
+
+#ifdef _MSC_VER
+#define BUTTERAUGLI_INLINE __forceinline
+#else
+#define BUTTERAUGLI_INLINE inline
+#endif
+
+#ifdef __clang__
+// Early versions of Clang did not support __builtin_assume_aligned.
+#define BUTTERAUGLI_HAS_ASSUME_ALIGNED __has_builtin(__builtin_assume_aligned)
+#elif defined(__GNUC__)
+#define BUTTERAUGLI_HAS_ASSUME_ALIGNED 1
+#else
+#define BUTTERAUGLI_HAS_ASSUME_ALIGNED 0
+#endif
+
+// Returns a void* pointer which the compiler then assumes is N-byte aligned.
+// Example: float* JXL_RESTRICT aligned = (float*)JXL_ASSUME_ALIGNED(in, 32);
+//
+// The assignment semantics are required by GCC/Clang. ICC provides an in-place
+// __assume_aligned, whereas MSVC's __assume appears unsuitable.
+#if BUTTERAUGLI_HAS_ASSUME_ALIGNED
+#define BUTTERAUGLI_ASSUME_ALIGNED(ptr, align) \
+  __builtin_assume_aligned((ptr), (align))
+#else
+#define BUTTERAUGLI_ASSUME_ALIGNED(ptr, align) (ptr)
+#endif  // BUTTERAUGLI_HAS_ASSUME_ALIGNED
+
+struct PsychoImage {
+  ImageF uhf[2];  // XY
+  ImageF hf[2];   // XY
+  Image3F mf;     // XYB
+  Image3F lf;     // XYB
+};
+
+// Depending on implementation, Blur either needs a normal or transposed image.
+// Hold one or both of them here and only allocate on demand to reduce memory
+// usage.
+struct BlurTemp {
+  ImageF *Get(const ImageF &in) {
+    if (temp.xsize() == 0) {
+      temp = ImageF(in.xsize(), in.ysize());
+    }
+    return &temp;
+  }
+
+  ImageF *GetTransposed(const ImageF &in) {
+    if (transposed_temp.xsize() == 0) {
+      transposed_temp = ImageF(in.ysize(), in.xsize());
+    }
+    return &transposed_temp;
+  }
+
+  ImageF temp;
+  ImageF transposed_temp;
+};
+
+class ButteraugliComparator {
+ public:
+  // Butteraugli is calibrated at xmul = 1.0. We add a multiplier here so that
+  // we can test the hypothesis that a higher weighing of the X channel would
+  // improve results at higher Butteraugli values.
+  ButteraugliComparator(const Image3F &rgb0, const ButteraugliParams &params);
+  virtual ~ButteraugliComparator() = default;
+
+  // Computes the butteraugli map between the original image given in the
+  // constructor and the distorted image given here.
+  void Diffmap(const Image3F &rgb1, ImageF &result) const;
+
+  // Same as above, but OpsinDynamicsImage() was already applied.
+  void DiffmapOpsinDynamicsImage(const Image3F &xyb1, ImageF &result) const;
+
+  // Same as above, but the frequency decomposition was already applied.
+  void DiffmapPsychoImage(const PsychoImage &pi1, ImageF &diffmap) const;
+
+  void Mask(ImageF *BUTTERAUGLI_RESTRICT mask) const;
+
+ private:
+  Image3F *Temp() const;
+  void ReleaseTemp() const;
+
+  const size_t xsize_;
+  const size_t ysize_;
+  ButteraugliParams params_;
+  PsychoImage pi0_;
+
+  // Shared temporary image storage to reduce the number of allocations;
+  // obtained via Temp(), must call ReleaseTemp when no longer needed.
+  mutable Image3F temp_;
+  mutable std::atomic_flag temp_in_use_ = ATOMIC_FLAG_INIT;
+
+  mutable BlurTemp blur_temp_;
+  std::unique_ptr<ButteraugliComparator> sub_;
+};
+
+// Deprecated.
+bool ButteraugliDiffmap(const Image3F &rgb0, const Image3F &rgb1,
+                        double hf_asymmetry, double xmul, ImageF &diffmap);
+
+bool ButteraugliDiffmap(const Image3F &rgb0, const Image3F &rgb1,
+                        const ButteraugliParams &params, ImageF &diffmap);
+
+double ButteraugliScoreFromDiffmap(const ImageF &diffmap,
+                                   const ButteraugliParams *params = nullptr);
+
+// Generate rgb-representation of the distance between two images.
+Image3F CreateHeatMapImage(const ImageF &distmap, double good_threshold,
+                           double bad_threshold);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_
diff --git a/third_party/jpeg-xl/lib/jxl/butteraugli_test.cc b/third_party/jpeg-xl/lib/jxl/butteraugli_test.cc
new file mode 100644
index 000000000000..bab700ebc206
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/butteraugli_test.cc
@@ -0,0 +1,111 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "jxl/butteraugli.h"
+
+#include "gtest/gtest.h"
+#include "jxl/butteraugli_cxx.h"
+#include "lib/jxl/test_utils.h"
+
+TEST(ButteraugliTest, Lossless) {
+  uint32_t xsize = 171;
+  uint32_t ysize = 219;
+  std::vector<uint8_t> pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+
+  JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr));
+  JxlButteraugliResultPtr result(JxlButteraugliCompute(
+      api.get(), xsize, ysize, &pixel_format, pixels.data(), pixels.size(),
+      &pixel_format, pixels.data(), pixels.size()));
+  EXPECT_EQ(0.0, JxlButteraugliResultGetDistance(result.get(), 8.0));
+}
+
+TEST(ButteraugliTest, Distmap) {
+  uint32_t xsize = 171;
+  uint32_t ysize = 219;
+  std::vector<uint8_t> pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+
+  JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr));
+  JxlButteraugliResultPtr result(JxlButteraugliCompute(
+      api.get(), xsize, ysize, &pixel_format, pixels.data(), pixels.size(),
+      &pixel_format, pixels.data(), pixels.size()));
+  EXPECT_EQ(0.0, JxlButteraugliResultGetDistance(result.get(), 8.0));
+  const float* distmap;
+  uint32_t row_stride;
+  JxlButteraugliResultGetDistmap(result.get(), &distmap, &row_stride);
+  for (uint32_t y = 0; y < ysize; y++) {
+    for (uint32_t x = 0; x < xsize; x++) {
+      EXPECT_EQ(0.0, distmap[y * row_stride + x]);
+    }
+  }
+}
+
+TEST(ButteraugliTest, Distorted) {
+  uint32_t xsize = 171;
+  uint32_t ysize = 219;
+  std::vector<uint8_t> orig_pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+  std::vector<uint8_t> dist_pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+  dist_pixels[0] += 128;
+
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+
+  JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr));
+  JxlButteraugliResultPtr result(JxlButteraugliCompute(
+      api.get(), xsize, ysize, &pixel_format, orig_pixels.data(),
+      orig_pixels.size(), &pixel_format, dist_pixels.data(),
+      dist_pixels.size()));
+  EXPECT_NE(0.0, JxlButteraugliResultGetDistance(result.get(), 8.0));
+}
+
+TEST(ButteraugliTest, Api) {
+  uint32_t xsize = 171;
+  uint32_t ysize = 219;
+  std::vector<uint8_t> orig_pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+  std::vector<uint8_t> dist_pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+  dist_pixels[0] += 128;
+
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+
+  JxlButteraugliApiPtr api(JxlButteraugliApiCreate(nullptr));
+  JxlButteraugliApiSetHFAsymmetry(api.get(), 1.0f);
+  JxlButteraugliApiSetIntensityTarget(api.get(), 250.0f);
+  JxlButteraugliResultPtr result(JxlButteraugliCompute(
+      api.get(), xsize, ysize, &pixel_format, orig_pixels.data(),
+      orig_pixels.size(), &pixel_format, dist_pixels.data(),
+      dist_pixels.size()));
+  double distance0 = JxlButteraugliResultGetDistance(result.get(), 8.0);
+
+  JxlButteraugliApiSetHFAsymmetry(api.get(), 2.0f);
+  result.reset(JxlButteraugliCompute(api.get(), xsize, ysize, &pixel_format,
                                     orig_pixels.data(), orig_pixels.size(),
                                     &pixel_format,
dist_pixels.data(), + dist_pixels.size())); + double distance1 = JxlButteraugliResultGetDistance(result.get(), 8.0); + + EXPECT_NE(distance0, distance1); + + JxlButteraugliApiSetIntensityTarget(api.get(), 80.0f); + result.reset(JxlButteraugliCompute(api.get(), xsize, ysize, &pixel_format, + orig_pixels.data(), orig_pixels.size(), + &pixel_format, dist_pixels.data(), + dist_pixels.size())); + double distance2 = JxlButteraugliResultGetDistance(result.get(), 8.0); + + EXPECT_NE(distance1, distance2); +} diff --git a/third_party/jpeg-xl/lib/jxl/butteraugli_wrapper.cc b/third_party/jpeg-xl/lib/jxl/butteraugli_wrapper.cc new file mode 100644 index 000000000000..6c6b8ec12313 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/butteraugli_wrapper.cc @@ -0,0 +1,216 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include + +#include "jxl/butteraugli.h" +#include "jxl/parallel_runner.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_butteraugli_pnorm.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/memory_manager_internal.h" + +namespace { + +void SetMetadataFromPixelFormat(const JxlPixelFormat* pixel_format, + jxl::ImageMetadata* metadata) { + uint32_t potential_alpha_bits = 0; + switch (pixel_format->data_type) { + case JXL_TYPE_FLOAT: + metadata->SetFloat32Samples(); + potential_alpha_bits = 16; + break; + case JXL_TYPE_FLOAT16: + metadata->SetFloat16Samples(); + potential_alpha_bits = 16; + break; + case JXL_TYPE_UINT32: + metadata->SetUintSamples(32); + potential_alpha_bits = 16; + break; + case JXL_TYPE_UINT16: + metadata->SetUintSamples(16); + potential_alpha_bits = 16; + break; + case JXL_TYPE_UINT8: + metadata->SetUintSamples(8); + potential_alpha_bits = 8; + break; + case JXL_TYPE_BOOLEAN: + metadata->SetUintSamples(2); + potential_alpha_bits = 2; + break; + } + if (pixel_format->num_channels == 2 || pixel_format->num_channels == 4) { + metadata->SetAlphaBits(potential_alpha_bits); + } +} + +} // namespace + +struct JxlButteraugliResultStruct { + JxlMemoryManager memory_manager; + + jxl::ImageF distmap; + jxl::ButteraugliParams params; +}; + +struct JxlButteraugliApiStruct { + // Multiplier for penalizing new HF artifacts more than blurring away + // features. 1.0=neutral. + float hf_asymmetry = 1.0f; + + // Multiplier for the psychovisual difference in the X channel. + float xmul = 1.0f; + + // Number of nits that correspond to 1.0f input values. 
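+  // ButteraugliParams defaults this to 80.0f, the nominal SDR white level in
+  // cd/m^2; butteraugli_test.cc computes scores at both 250.0f and 80.0f and
+  // expects them to differ, since the assumed display brightness changes the
+  // result.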
+ float intensity_target = jxl::kDefaultIntensityTarget; + + bool approximate_border = false; + + JxlMemoryManager memory_manager; + std::unique_ptr thread_pool{nullptr}; +}; + +JxlButteraugliApi* JxlButteraugliApiCreate( + const JxlMemoryManager* memory_manager) { + JxlMemoryManager local_memory_manager; + if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) + return nullptr; + + void* alloc = + jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlButteraugliApi)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + JxlButteraugliApi* ret = new (alloc) JxlButteraugliApi(); + ret->memory_manager = local_memory_manager; + return ret; +} + +void JxlButteraugliApiSetParallelRunner(JxlButteraugliApi* api, + JxlParallelRunner parallel_runner, + void* parallel_runner_opaque) { + api->thread_pool = jxl::make_unique(parallel_runner, + parallel_runner_opaque); +} + +void JxlButteraugliApiSetHFAsymmetry(JxlButteraugliApi* api, float v) { + api->hf_asymmetry = v; +} + +void JxlButteraugliApiSetIntensityTarget(JxlButteraugliApi* api, float v) { + api->intensity_target = v; +} + +void JxlButteraugliApiDestroy(JxlButteraugliApi* api) { + if (api) { + // Call destructor directly since custom free function is used. + api->~JxlButteraugliApi(); + jxl::MemoryManagerFree(&api->memory_manager, api); + } +} + +JxlButteraugliResult* JxlButteraugliCompute( + const JxlButteraugliApi* api, uint32_t xsize, uint32_t ysize, + const JxlPixelFormat* pixel_format_orig, const void* buffer_orig, + size_t size_orig, const JxlPixelFormat* pixel_format_dist, + const void* buffer_dist, size_t size_dist) { + jxl::ImageMetadata orig_metadata; + SetMetadataFromPixelFormat(pixel_format_orig, &orig_metadata); + jxl::ImageBundle orig_ib(&orig_metadata); + jxl::ColorEncoding c_current; + if (pixel_format_orig->data_type == JXL_TYPE_FLOAT) { + c_current = + jxl::ColorEncoding::LinearSRGB(pixel_format_orig->num_channels < 3); + } else { + c_current = jxl::ColorEncoding::SRGB(pixel_format_orig->num_channels < 3); + } + if (!jxl::BufferToImageBundle(*pixel_format_orig, xsize, ysize, buffer_orig, + size_orig, api->thread_pool.get(), c_current, + &orig_ib)) { + return nullptr; + } + + jxl::ImageMetadata dist_metadata; + SetMetadataFromPixelFormat(pixel_format_dist, &dist_metadata); + jxl::ImageBundle dist_ib(&dist_metadata); + if (pixel_format_dist->data_type == JXL_TYPE_FLOAT) { + c_current = + jxl::ColorEncoding::LinearSRGB(pixel_format_dist->num_channels < 3); + } else { + c_current = jxl::ColorEncoding::SRGB(pixel_format_dist->num_channels < 3); + } + if (!jxl::BufferToImageBundle(*pixel_format_dist, xsize, ysize, buffer_dist, + size_dist, api->thread_pool.get(), c_current, + &dist_ib)) { + return nullptr; + } + + void* alloc = jxl::MemoryManagerAlloc(&api->memory_manager, + sizeof(JxlButteraugliResult)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + JxlButteraugliResult* result = new (alloc) JxlButteraugliResult(); + result->memory_manager = api->memory_manager; + result->params.hf_asymmetry = api->hf_asymmetry; + result->params.xmul = api->xmul; + result->params.intensity_target = api->intensity_target; + result->params.approximate_border = api->approximate_border; + jxl::ButteraugliDistance(orig_ib, dist_ib, result->params, &result->distmap, + api->thread_pool.get()); + + return result; +} + +float JxlButteraugliResultGetDistance(const JxlButteraugliResult* result, + float pnorm) { + return static_cast( + jxl::ComputeDistanceP(result->distmap, 
result->params, pnorm)); +} + +void JxlButteraugliResultGetDistmap(const JxlButteraugliResult* result, + const float** buffer, + uint32_t* row_stride) { + *buffer = result->distmap.Row(0); + *row_stride = result->distmap.PixelsPerRow(); +} + +float JxlButteraugliResultGetMaxDistance(const JxlButteraugliResult* result) { + float max_distance = 0.0; + for (uint32_t y = 0; y < result->distmap.ysize(); y++) { + for (uint32_t x = 0; x < result->distmap.xsize(); x++) { + if (result->distmap.ConstRow(y)[x] > max_distance) { + max_distance = result->distmap.ConstRow(y)[x]; + } + } + } + return max_distance; +} + +void JxlButteraugliResultDestroy(JxlButteraugliResult* result) { + if (result) { + // Call destructor directly since custom free function is used. + result->~JxlButteraugliResult(); + jxl::MemoryManagerFree(&result->memory_manager, result); + } +} diff --git a/third_party/jpeg-xl/lib/jxl/byte_order_test.cc b/third_party/jpeg-xl/lib/jxl/byte_order_test.cc new file mode 100644 index 000000000000..3673893cf30e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/byte_order_test.cc @@ -0,0 +1,62 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/base/byte_order.h" + +#include "gtest/gtest.h" + +namespace jxl { +namespace { + +TEST(ByteOrderTest, TestRoundTripBE16) { + const uint32_t in = 0x1234; + uint8_t buf[2]; + StoreBE16(in, buf); + EXPECT_EQ(in, LoadBE16(buf)); + EXPECT_NE(in, LoadLE16(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE16) { + const uint32_t in = 0x1234; + uint8_t buf[2]; + StoreLE16(in, buf); + EXPECT_EQ(in, LoadLE16(buf)); + EXPECT_NE(in, LoadBE16(buf)); +} + +TEST(ByteOrderTest, TestRoundTripBE32) { + const uint32_t in = 0xFEDCBA98u; + uint8_t buf[4]; + StoreBE32(in, buf); + EXPECT_EQ(in, LoadBE32(buf)); + EXPECT_NE(in, LoadLE32(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE32) { + const uint32_t in = 0xFEDCBA98u; + uint8_t buf[4]; + StoreLE32(in, buf); + EXPECT_EQ(in, LoadLE32(buf)); + EXPECT_NE(in, LoadBE32(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE64) { + const uint64_t in = 0xFEDCBA9876543210ull; + uint8_t buf[8]; + StoreLE64(in, buf); + EXPECT_EQ(in, LoadLE64(buf)); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/chroma_from_luma.cc b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.cc new file mode 100644 index 000000000000..8d2c9b7b51a0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.cc @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/chroma_from_luma.h" + +namespace jxl { + +ColorCorrelationMap::ColorCorrelationMap(size_t xsize, size_t ysize, bool XYB) + : ytox_map(DivCeil(xsize, kColorTileDim), DivCeil(ysize, kColorTileDim)), + ytob_map(DivCeil(xsize, kColorTileDim), DivCeil(ysize, kColorTileDim)) { + ZeroFillImage(&ytox_map); + ZeroFillImage(&ytob_map); + if (!XYB) { + base_correlation_b_ = 0; + } + RecomputeDCFactors(); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/chroma_from_luma.h b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.h new file mode 100644 index 000000000000..b289857a1b20 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.h @@ -0,0 +1,160 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_CHROMA_FROM_LUMA_H_ +#define LIB_JXL_CHROMA_FROM_LUMA_H_ + +// Chroma-from-luma, computed using heuristics to determine the best linear +// model for the X and B channels from the Y channel. + +#include +#include + +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +// Tile is the rectangular grid of blocks that share color correlation +// parameters ("factor_x/b" such that residual_b = blue - Y * factor_b). +static constexpr size_t kColorTileDim = 64; + +static_assert(kColorTileDim % kBlockDim == 0, + "Color tile dim should be divisible by block dim"); +static constexpr size_t kColorTileDimInBlocks = kColorTileDim / kBlockDim; + +static_assert(kGroupDimInBlocks % kColorTileDimInBlocks == 0, + "Group dim should be divisible by color tile dim"); + +static constexpr uint8_t kDefaultColorFactor = 84; + +// JPEG DCT coefficients are at most 1024. CfL constants are at most 127, and +// the ratio of two entries in a JPEG quantization table is at most 255. Thus, +// since the CfL denominator is 84, this leaves 12 bits of mantissa to be used. +// For extra caution, we use 11. 
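+// (Concretely: 1024 * 127 * 255 / 84 is below 2^19, so an extra 11 bits of
+// fixed-point scale keeps products under 2^30, within a signed 32-bit range.)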
+static constexpr uint8_t kCFLFixedPointPrecision = 11;
+
+static constexpr U32Enc kColorFactorDist(Val(kDefaultColorFactor), Val(256),
+                                         BitsOffset(2, 8),
+                                         BitsOffset(258, 12));
+
+struct ColorCorrelationMap {
+  ColorCorrelationMap() = default;
+  // xsize/ysize are in pixels
+  // set XYB=false to do something close to no-op cmap (needed for now since
+  // cmap is mandatory)
+  ColorCorrelationMap(size_t xsize, size_t ysize, bool XYB = true);
+
+  float YtoXRatio(int32_t x_factor) const {
+    return base_correlation_x_ + x_factor * color_scale_;
+  }
+
+  float YtoBRatio(int32_t b_factor) const {
+    return base_correlation_b_ + b_factor * color_scale_;
+  }
+
+  Status DecodeDC(BitReader* br) {
+    if (br->ReadFixedBits<1>() == 1) {
+      // All default.
+      return true;
+    }
+    SetColorFactor(U32Coder::Read(kColorFactorDist, br));
+    JXL_RETURN_IF_ERROR(F16Coder::Read(br, &base_correlation_x_));
+    if (std::abs(base_correlation_x_) > 4.0f) {
+      return JXL_FAILURE("Base X correlation is out of range");
+    }
+    JXL_RETURN_IF_ERROR(F16Coder::Read(br, &base_correlation_b_));
+    if (std::abs(base_correlation_b_) > 4.0f) {
+      return JXL_FAILURE("Base B correlation is out of range");
+    }
+    ytox_dc_ = static_cast<int>(br->ReadFixedBits<kBitsPerByte>()) +
+               std::numeric_limits<int8_t>::min();
+    ytob_dc_ = static_cast<int>(br->ReadFixedBits<kBitsPerByte>()) +
+               std::numeric_limits<int8_t>::min();
+    RecomputeDCFactors();
+    return true;
+  }
+
+  // We consider a CfL map to be JPEG-reconstruction-compatible if base
+  // correlation is 0, no DC correlation is used, and we use the default color
+  // factor.
+  bool IsJPEGCompatible() const {
+    return base_correlation_x_ == 0 && base_correlation_b_ == 0 &&
+           ytob_dc_ == 0 && ytox_dc_ == 0 &&
+           color_factor_ == kDefaultColorFactor;
+  }
+
+  int32_t RatioJPEG(int32_t factor) const {
+    return factor * (1 << kCFLFixedPointPrecision) / kDefaultColorFactor;
+  }
+
+  void SetColorFactor(uint32_t factor) {
+    color_factor_ = factor;
+    color_scale_ = 1.0f / color_factor_;
+    RecomputeDCFactors();
+  }
+
+  void SetYToBDC(int32_t ytob_dc) {
+    ytob_dc_ = ytob_dc;
+    RecomputeDCFactors();
+  }
+  void SetYToXDC(int32_t ytox_dc) {
+    ytox_dc_ = ytox_dc;
+    RecomputeDCFactors();
+  }
+
+  int32_t GetYToXDC() const { return ytox_dc_; }
+  int32_t GetYToBDC() const { return ytob_dc_; }
+  float GetColorFactor() const { return color_factor_; }
+  float GetBaseCorrelationX() const { return base_correlation_x_; }
+  float GetBaseCorrelationB() const { return base_correlation_b_; }
+
+  const float* DCFactors() const { return dc_factors_; }
+
+  void RecomputeDCFactors() {
+    dc_factors_[0] = YtoXRatio(ytox_dc_);
+    dc_factors_[2] = YtoBRatio(ytob_dc_);
+  }
+
+  ImageSB ytox_map;
+  ImageSB ytob_map;
+
+ private:
+  float dc_factors_[4] = {};
+  // range of factor: -1.51 to +1.52
+  uint32_t color_factor_ = kDefaultColorFactor;
+  float color_scale_ = 1.0f / color_factor_;
+  float base_correlation_x_ = 0.0f;
+  float base_correlation_b_ = kYToBRatio;
+  int32_t ytox_dc_ = 0;
+  int32_t ytob_dc_ = 0;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_CHROMA_FROM_LUMA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/codec_in_out.h b/third_party/jpeg-xl/lib/jxl/codec_in_out.h
new file mode 100644
index 000000000000..603ad2492d19
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/codec_in_out.h
@@ -0,0 +1,249 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_CODEC_IN_OUT_H_
+#define LIB_JXL_CODEC_IN_OUT_H_
+
+// Holds inputs/outputs for decoding/encoding images.
+
+#include <stddef.h>
+
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/luminance.h"
+
+namespace jxl {
+
+// Per-channel interval, used to convert between (full-range) external and
+// (bounded or unbounded) temp values. See external_image.cc for the
+// definitions of temp/external.
+struct CodecInterval {
+  CodecInterval() = default;
+  constexpr CodecInterval(float min, float max) : min(min), width(max - min) {}
+  // Defaults for temp.
+  float min = 0.0f;
+  float width = 1.0f;
+};
+
+struct SizeConstraints {
+  // Upper limit on pixel dimensions/area, enforced by VerifyDimensions
+  // (called from decoders). Fuzzers set smaller values to limit memory use.
+  uint32_t dec_max_xsize = 0xFFFFFFFFu;
+  uint32_t dec_max_ysize = 0xFFFFFFFFu;
+  uint64_t dec_max_pixels = 0xFFFFFFFFu;  // Might be up to ~0ull
+};
+
+template <typename T,
+          class = typename std::enable_if<std::is_unsigned<T>::value>::type>
+Status VerifyDimensions(const SizeConstraints* constraints, T xs, T ys) {
+  if (!constraints) return true;
+
+  if (xs == 0 || ys == 0) return JXL_FAILURE("Empty image.");
+  if (xs > constraints->dec_max_xsize) return JXL_FAILURE("Image too wide.");
+  if (ys > constraints->dec_max_ysize) return JXL_FAILURE("Image too tall.");
+
+  const uint64_t num_pixels = static_cast<uint64_t>(xs) * ys;
+  if (num_pixels > constraints->dec_max_pixels) {
+    return JXL_FAILURE("Image too big.");
+  }
+
+  return true;
+}
+
+using CodecIntervals = std::array<CodecInterval, 4>;  // RGB[A] or Y[A]
+
+// Allows passing arbitrary metadata to decoders (required for PNM).
+class DecoderHints {
+ public:
+  // key=color_space, value=Description(c/pp): specify the ColorEncoding of
+  // the pixels for decoding. Otherwise, if the codec did not obtain an ICC
+  // profile from the image, assume sRGB.
+  //
+  // Strings are taken from the command line, so avoid spaces for convenience.
+  void Add(const std::string& key, const std::string& value) {
+    kv_.emplace_back(key, value);
+  }
+
+  // Calls `func(key, value)` for each key/value in the order they were added,
+  // returning false immediately if `func` returns false.
+  template <class Func>
+  Status Foreach(const Func& func) const {
+    for (const KeyValue& kv : kv_) {
+      Status ok = func(kv.key, kv.value);
+      if (!ok) {
+        return JXL_FAILURE("DecoderHints::Foreach returned false");
+      }
+    }
+    return true;
+  }
+
+ private:
+  // Splitting into key/value avoids parsing in each codec.
+  struct KeyValue {
+    KeyValue(std::string key, std::string value)
+        : key(std::move(key)), value(std::move(value)) {}
+
+    std::string key;
+    std::string value;
+  };
+
+  std::vector<KeyValue> kv_;
+};
+
+// Optional text/EXIF metadata.
+struct Blobs {
+  PaddedBytes exif;
+  PaddedBytes iptc;
+  PaddedBytes jumbf;
+  PaddedBytes xmp;
+};
+
+// For Codec::kJPG, convert between JPEG and pixels or between JPEG and
+// quantized DCT coefficients.
+// For pixel data, the nominal range is 0..1.
+enum class DecodeTarget { kPixels, kQuantizedCoeffs }; + +// Holds a preview, a main image or one or more frames, plus the inputs/outputs +// to/from decoding/encoding. +class CodecInOut { + public: + CodecInOut() : preview_frame(&metadata.m) { + frames.reserve(1); + frames.emplace_back(&metadata.m); + } + + // Move-only. + CodecInOut(CodecInOut&&) = default; + CodecInOut& operator=(CodecInOut&&) = default; + + size_t LastStillFrame() const { + JXL_DASSERT(frames.size() > 0); + size_t last = 0; + for (size_t i = 0; i < frames.size(); i++) { + last = i; + if (frames[i].duration > 0) break; + } + return last; + } + + ImageBundle& Main() { return frames[LastStillFrame()]; } + const ImageBundle& Main() const { return frames[LastStillFrame()]; } + + // If c_current.IsGray(), all planes must be identical. + void SetFromImage(Image3F&& color, const ColorEncoding& c_current) { + Main().SetFromImage(std::move(color), c_current); + SetIntensityTarget(this); + SetSize(Main().xsize(), Main().ysize()); + } + + void SetSize(size_t xsize, size_t ysize) { + JXL_CHECK(metadata.size.Set(xsize, ysize)); + } + + void CheckMetadata() const { + JXL_CHECK(metadata.m.bit_depth.bits_per_sample != 0); + JXL_CHECK(!metadata.m.color_encoding.ICC().empty()); + + if (preview_frame.xsize() != 0) preview_frame.VerifyMetadata(); + JXL_CHECK(preview_frame.metadata() == &metadata.m); + + for (const ImageBundle& ib : frames) { + ib.VerifyMetadata(); + JXL_CHECK(ib.metadata() == &metadata.m); + } + } + + size_t xsize() const { return metadata.size.xsize(); } + size_t ysize() const { return metadata.size.ysize(); } + void ShrinkTo(size_t xsize, size_t ysize) { + // preview is unaffected. + for (ImageBundle& ib : frames) { + ib.ShrinkTo(xsize, ysize); + } + SetSize(xsize, ysize); + } + + // Calls TransformTo for each ImageBundle (preview/frames). + Status TransformTo(const ColorEncoding& c_desired, + ThreadPool* pool = nullptr) { + if (metadata.m.have_preview) { + JXL_RETURN_IF_ERROR(preview_frame.TransformTo(c_desired, pool)); + } + for (ImageBundle& ib : frames) { + JXL_RETURN_IF_ERROR(ib.TransformTo(c_desired, pool)); + } + return true; + } + + // -- DECODER INPUT: + + SizeConstraints constraints; + // Used to set c_current for codecs that lack color space metadata. + DecoderHints dec_hints; + // Decode to pixels or keep JPEG as quantized DCT coefficients + DecodeTarget dec_target = DecodeTarget::kPixels; + + // Intended white luminance, in nits (cd/m^2). + // It is used by codecs that do not know the absolute luminance of their + // images. For those codecs, decoders map from white to this luminance. There + // is no other way of knowing the target brightness for those codecs - depends + // on source material. 709 typically targets 100 nits, BT.2100 PQ up to 10K, + // but HDR content is more typically mastered to 4K nits. Codecs that do know + // the absolute luminance of their images will typically ignore it as a + // decoder input. The corresponding decoder output and encoder input is the + // intensity target in the metadata. ALL decoders MUST set that metadata + // appropriately, but it does not have to be identical to this hint. Encoders + // for codecs that do not encode absolute luminance levels should use that + // metadata to decide on what to map to white. Encoders for codecs that *do* + // encode absolute luminance levels may use it to decide on encoding values, + // but not in a way that would affect the range of interpreted luminance. 
+ // + // 0 means that it is up to the codec to decide on a reasonable value to use. + + float target_nits = 0; + + // -- DECODER OUTPUT: + + // Total number of pixels decoded (may differ from #frames * xsize * ysize + // if frames are cropped) + uint64_t dec_pixels = 0; + + // -- DECODER OUTPUT, ENCODER INPUT: + + // Metadata stored into / retrieved from bitstreams. + + Blobs blobs; + + CodecMetadata metadata; // applies to preview and all frames + + // If metadata.have_preview: + ImageBundle preview_frame; + + std::vector frames; // size=1 if !metadata.have_animation + + bool use_sjpeg = false; + // If the image should be written to a JPEG, use this quality for encoding. + size_t jpeg_quality; +}; + +} // namespace jxl + +#endif // LIB_JXL_CODEC_IN_OUT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order.cc b/third_party/jpeg-xl/lib/jxl/coeff_order.cc new file mode 100644 index 000000000000..0a2a09613601 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/coeff_order.cc @@ -0,0 +1,163 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/coeff_order.h" + +#include + +#include +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/lehmer_code.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +void SetDefaultOrder(AcStrategy acs, coeff_order_t* JXL_RESTRICT order) { + PROFILER_FUNC; + const size_t size = + kDCTBlockSize * acs.covered_blocks_x() * acs.covered_blocks_y(); + const coeff_order_t* natural_coeff_order = acs.NaturalCoeffOrder(); + for (size_t k = 0; k < size; ++k) { + order[k] = natural_coeff_order[k]; + } +} + +uint32_t CoeffOrderContext(uint32_t val) { + uint32_t token, nbits, bits; + HybridUintConfig(0, 0, 0).Encode(val, &token, &nbits, &bits); + return std::min(token, kPermutationContexts - 1); +} + +namespace { +Status ReadPermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br, ANSSymbolReader* reader, + const std::vector& context_map) { + std::vector lehmer(size); + // temp space needs to be as large as the next power of 2, so doubling the + // allocated size is enough. 
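+  // A Lehmer code records, for each position, the index of the chosen element
+  // among those not yet used; e.g. the permutation (2, 0, 1) of (0, 1, 2) has
+  // Lehmer code (2, 0, 0). This is why the bound lehmer[i] + i < size is
+  // checked below before DecodeLehmerCode reconstructs the permutation.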
+  std::vector<uint32_t> temp(size * 2);
+  uint32_t end =
+      reader->ReadHybridUint(CoeffOrderContext(size), br, context_map) + skip;
+  if (end > size) {
+    return JXL_FAILURE("Invalid permutation size");
+  }
+  uint32_t last = 0;
+  for (size_t i = skip; i < end; ++i) {
+    lehmer[i] =
+        reader->ReadHybridUint(CoeffOrderContext(last), br, context_map);
+    last = lehmer[i];
+    if (lehmer[i] + i >= size) {
+      return JXL_FAILURE("Invalid lehmer code");
+    }
+  }
+  if (order == nullptr) return true;
+  DecodeLehmerCode(lehmer.data(), temp.data(), size, order);
+  return true;
+}
+
+}  // namespace
+
+Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order,
+                         BitReader* br) {
+  std::vector<uint8_t> context_map;
+  ANSCode code;
+  JXL_RETURN_IF_ERROR(
+      DecodeHistograms(br, kPermutationContexts, &code, &context_map));
+  ANSSymbolReader reader(&code, br);
+  JXL_RETURN_IF_ERROR(
+      ReadPermutation(skip, size, order, br, &reader, context_map));
+  if (!reader.CheckANSFinalState()) {
+    return JXL_FAILURE("Invalid ANS stream");
+  }
+  return true;
+}
+
+namespace {
+
+Status DecodeCoeffOrder(AcStrategy acs, coeff_order_t* order, BitReader* br,
+                        ANSSymbolReader* reader,
+                        const std::vector<uint8_t>& context_map) {
+  PROFILER_FUNC;
+  const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y();
+  const size_t size = kDCTBlockSize * llf;
+
+  JXL_RETURN_IF_ERROR(
+      ReadPermutation(llf, size, order, br, reader, context_map));
+  if (order == nullptr) return true;
+  const coeff_order_t* natural_coeff_order = acs.NaturalCoeffOrder();
+  for (size_t k = 0; k < size; ++k) {
+    order[k] = natural_coeff_order[order[k]];
+  }
+  return true;
+}
+
+}  // namespace
+
+Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs,
+                         coeff_order_t* order, BitReader* br) {
+  uint16_t computed = 0;
+  std::vector<uint8_t> context_map;
+  ANSCode code;
+  std::unique_ptr<ANSSymbolReader> reader;
+  // Bitstream does not have histograms if no coefficient order is used.
+  if (used_orders != 0) {
+    JXL_RETURN_IF_ERROR(
+        DecodeHistograms(br, kPermutationContexts, &code, &context_map));
+    reader = make_unique<ANSSymbolReader>(&code, br);
+  }
+  uint32_t acs_mask = 0;
+  for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) {
+    if ((used_acs & (1 << o)) == 0) continue;
+    acs_mask |= 1 << kStrategyOrder[o];
+  }
+  for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) {
+    uint8_t ord = kStrategyOrder[o];
+    if (computed & (1 << ord)) continue;
+    computed |= 1 << ord;
+    AcStrategy acs = AcStrategy::FromRawStrategy(o);
+    bool used = (acs_mask & (1 << ord)) != 0;
+    if ((used_orders & (1 << ord)) == 0) {
+      // No need to set the default order if no ACS uses this order.
+      if (used) {
+        for (size_t c = 0; c < 3; c++) {
+          SetDefaultOrder(acs, &order[CoeffOrderOffset(ord, c)]);
+        }
+      }
+    } else {
+      for (size_t c = 0; c < 3; c++) {
+        coeff_order_t* dest = used ? &order[CoeffOrderOffset(ord, c)] : nullptr;
+        JXL_RETURN_IF_ERROR(
+            DecodeCoeffOrder(acs, dest, br, reader.get(), context_map));
+      }
+    }
+  }
+  if (used_orders && !reader->CheckANSFinalState()) {
+    return JXL_FAILURE("Invalid ANS stream");
+  }
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order.h b/third_party/jpeg-xl/lib/jxl/coeff_order.h
new file mode 100644
index 000000000000..a1dc927ba5fd
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/coeff_order.h
@@ -0,0 +1,75 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_COEFF_ORDER_H_ +#define LIB_JXL_COEFF_ORDER_H_ + +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" + +namespace jxl { + +// Those offsets get multiplied by kDCTBlockSize. +static constexpr size_t kCoeffOrderOffset[] = { + 0, 1, 2, 3, 4, 5, 6, 10, 14, 18, + 34, 50, 66, 68, 70, 72, 76, 80, 84, 92, + 100, 108, 172, 236, 300, 332, 364, 396, 652, 908, + 1164, 1292, 1420, 1548, 2572, 3596, 4620, 5132, 5644, 6156, +}; +static_assert(3 * kNumOrders + 1 == + sizeof(kCoeffOrderOffset) / sizeof(*kCoeffOrderOffset), + "Update this array when adding or removing order types."); + +static constexpr size_t CoeffOrderOffset(size_t order, size_t c) { + return kCoeffOrderOffset[3 * order + c] * kDCTBlockSize; +} + +static constexpr size_t kCoeffOrderMaxSize = + kCoeffOrderOffset[3 * kNumOrders] * kDCTBlockSize; + +// Mapping from AC strategy to order bucket. Strategies with different natural +// orders must have different buckets. +constexpr uint8_t kStrategyOrder[] = { + 0, 1, 1, 1, 2, 3, 4, 4, 5, 5, 6, 6, 1, 1, + 1, 1, 1, 1, 7, 8, 8, 9, 10, 10, 11, 12, 12, +}; + +static_assert(AcStrategy::kNumValidStrategies == + sizeof(kStrategyOrder) / sizeof(*kStrategyOrder), + "Update this array when adding or removing AC strategies."); + +constexpr uint32_t kPermutationContexts = 8; + +uint32_t CoeffOrderContext(uint32_t val); + +void SetDefaultOrder(AcStrategy acs, coeff_order_t* JXL_RESTRICT order); + +Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, + coeff_order_t* order, BitReader* br); + +Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br); + +} // namespace jxl + +#endif // LIB_JXL_COEFF_ORDER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order_fwd.h b/third_party/jpeg-xl/lib/jxl/coeff_order_fwd.h new file mode 100644 index 000000000000..8ae76e839239 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/coeff_order_fwd.h @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_COEFF_ORDER_FWD_H_ +#define LIB_JXL_COEFF_ORDER_FWD_H_ + +// Breaks circular dependency between ac_strategy and coeff_order. + +#include +#include + +#include "base/compiler_specific.h" + +namespace jxl { + +// Needs at least 16 bits. A 32-bit type speeds up DecodeAC by 2% at the cost of +// more memory. 
+using coeff_order_t = uint32_t; + +// Maximum number of orders to be used. Note that this needs to be multiplied by +// the number of channels. One per "size class" (plus one extra for DCT8), +// shared between transforms of size XxY and of size YxX. +constexpr uint8_t kNumOrders = 13; + +// DCT coefficients are laid out in such a way that the number of rows of +// coefficients is always the smaller coordinate. +JXL_INLINE constexpr size_t CoefficientRows(size_t rows, size_t columns) { + return rows < columns ? rows : columns; +} + +JXL_INLINE constexpr size_t CoefficientColumns(size_t rows, size_t columns) { + return rows < columns ? columns : rows; +} + +JXL_INLINE void CoefficientLayout(size_t* JXL_RESTRICT rows, + size_t* JXL_RESTRICT columns) { + size_t r = *rows; + size_t c = *columns; + *rows = CoefficientRows(r, c); + *columns = CoefficientColumns(r, c); +} + +} // namespace jxl + +#endif // LIB_JXL_COEFF_ORDER_FWD_H_ diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order_test.cc b/third_party/jpeg-xl/lib/jxl/coeff_order_test.cc new file mode 100644 index 000000000000..7551de672861 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/coeff_order_test.cc @@ -0,0 +1,110 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lib/jxl/coeff_order.h" + +#include + +#include +#include // iota +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_coeff_order.h" + +namespace jxl { +namespace { + +void RoundtripPermutation(coeff_order_t* perm, coeff_order_t* out, size_t len, + size_t* size) { + BitWriter writer; + EncodePermutation(perm, 0, len, &writer, 0, nullptr); + writer.ZeroPadToByte(); + Status status = true; + { + BitReader reader(writer.GetSpan()); + BitReaderScopedCloser closer(&reader, &status); + ASSERT_TRUE(DecodePermutation(0, len, out, &reader)); + } + ASSERT_TRUE(status); + *size = writer.GetSpan().size(); +} + +enum Permutation { kIdentity, kFewSwaps, kFewSlides, kRandom }; + +constexpr size_t kNumReps = 128; +constexpr size_t kSwaps = 32; + +void TestPermutation(Permutation kind, size_t len) { + std::vector perm(len); + std::iota(perm.begin(), perm.end(), 0); + std::mt19937 rng; + if (kind == kFewSwaps) { + std::uniform_int_distribution dist(0, len - 1); + for (size_t i = 0; i < kSwaps; i++) { + size_t a = dist(rng); + size_t b = dist(rng); + std::swap(perm[a], perm[b]); + } + } + if (kind == kFewSlides) { + std::uniform_int_distribution dist(0, len - 1); + for (size_t i = 0; i < kSwaps; i++) { + size_t a = dist(rng); + size_t b = dist(rng); + size_t from = std::min(a, b); + size_t to = std::max(a, b); + size_t start = perm[from]; + for (size_t j = from; j < to; j++) { + perm[j] = perm[j + 1]; + } + perm[to] = start; + } + } + if (kind == kRandom) { + std::shuffle(perm.begin(), perm.end(), rng); + } + std::vector out(len); + size_t size = 0; + for (size_t i = 0; i < kNumReps; i++) { + RoundtripPermutation(perm.data(), out.data(), len, &size); + for (size_t idx = 0; idx < len; idx++) { + EXPECT_EQ(perm[idx], out[idx]); + } + } + printf("Encoded size: %zu\n", size); +} + +TEST(CoeffOrderTest, IdentitySmall) { TestPermutation(kIdentity, 256); } +TEST(CoeffOrderTest, FewSlidesSmall) { TestPermutation(kFewSlides, 256); } +TEST(CoeffOrderTest, FewSwapsSmall) { TestPermutation(kFewSwaps, 256); } +TEST(CoeffOrderTest, RandomSmall) { TestPermutation(kRandom, 256); } + +TEST(CoeffOrderTest, IdentityMedium) { TestPermutation(kIdentity, 1 << 12); } +TEST(CoeffOrderTest, FewSlidesMedium) { TestPermutation(kFewSlides, 1 << 12); } +TEST(CoeffOrderTest, FewSwapsMedium) { TestPermutation(kFewSwaps, 1 << 12); } +TEST(CoeffOrderTest, RandomMedium) { TestPermutation(kRandom, 1 << 12); } + +TEST(CoeffOrderTest, IdentityBig) { TestPermutation(kIdentity, 1 << 16); } +TEST(CoeffOrderTest, FewSlidesBig) { TestPermutation(kFewSlides, 1 << 16); } +TEST(CoeffOrderTest, FewSwapsBig) { TestPermutation(kFewSwaps, 1 << 16); } +TEST(CoeffOrderTest, RandomBig) { TestPermutation(kRandom, 1 << 16); } + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc new file mode 100644 index 000000000000..a552460c93c0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc @@ -0,0 +1,746 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/color_encoding_internal.h" + +#include + +#include +#include + +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/linalg.h" + +namespace jxl { +namespace { + +// Highest reasonable value for the gamma of a transfer curve. +constexpr uint32_t kMaxGamma = 8192; + +// These strings are baked into Description - do not change. + +std::string ToString(ColorSpace color_space) { + switch (color_space) { + case ColorSpace::kRGB: + return "RGB"; + case ColorSpace::kGray: + return "Gra"; + case ColorSpace::kXYB: + return "XYB"; + case ColorSpace::kUnknown: + return "CS?"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid ColorSpace %u", static_cast(color_space)); +} + +std::string ToString(WhitePoint white_point) { + switch (white_point) { + case WhitePoint::kD65: + return "D65"; + case WhitePoint::kCustom: + return "Cst"; + case WhitePoint::kE: + return "EER"; + case WhitePoint::kDCI: + return "DCI"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid WhitePoint %u", static_cast(white_point)); +} + +std::string ToString(Primaries primaries) { + switch (primaries) { + case Primaries::kSRGB: + return "SRG"; + case Primaries::k2100: + return "202"; + case Primaries::kP3: + return "DCI"; + case Primaries::kCustom: + return "Cst"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid Primaries %u", static_cast(primaries)); +} + +std::string ToString(TransferFunction transfer_function) { + switch (transfer_function) { + case TransferFunction::kSRGB: + return "SRG"; + case TransferFunction::kLinear: + return "Lin"; + case TransferFunction::k709: + return "709"; + case TransferFunction::kPQ: + return "PeQ"; + case TransferFunction::kHLG: + return "HLG"; + case TransferFunction::kDCI: + return "DCI"; + case TransferFunction::kUnknown: + return "TF?"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_ABORT("Invalid TransferFunction %u", + static_cast(transfer_function)); +} + +std::string ToString(RenderingIntent rendering_intent) { + switch (rendering_intent) { + case RenderingIntent::kPerceptual: + return "Per"; + case RenderingIntent::kRelative: + return "Rel"; + case RenderingIntent::kSaturation: + return "Sat"; + case RenderingIntent::kAbsolute: + return "Abs"; + } + // Should not happen - visitor fails if enum is invalid. 
+  JXL_ABORT("Invalid RenderingIntent %u",
+            static_cast<uint32_t>(rendering_intent));
+}
+
+template <typename Enum>
+Status ParseEnum(const std::string& token, Enum* value) {
+  std::string str;
+  for (Enum e : Values<Enum>()) {
+    if (ToString(e) == token) {
+      *value = e;
+      return true;
+    }
+  }
+  return false;
+}
+
+class Tokenizer {
+ public:
+  Tokenizer(const std::string* input, char separator)
+      : input_(input), separator_(separator) {}
+
+  Status Next(std::string* JXL_RESTRICT next) {
+    const size_t end = input_->find(separator_, start_);
+    if (end == std::string::npos) {
+      *next = input_->substr(start_);  // rest of string
+    } else {
+      *next = input_->substr(start_, end - start_);
+    }
+    if (next->empty()) return JXL_FAILURE("Missing token");
+    start_ = end + 1;
+    return true;
+  }
+
+ private:
+  const std::string* const input_;  // not owned
+  const char separator_;
+  size_t start_ = 0;  // of next token
+};
+
+Status ParseDouble(const std::string& num, double* JXL_RESTRICT d) {
+  char* end;
+  errno = 0;
+  *d = strtod(num.c_str(), &end);
+  if (*d == 0.0 && end == num.c_str()) {
+    return JXL_FAILURE("Invalid double: %s", num.c_str());
+  }
+  if (std::isnan(*d)) {
+    return JXL_FAILURE("Invalid double: %s", num.c_str());
+  }
+  if (errno == ERANGE) {
+    return JXL_FAILURE("Double out of range: %s", num.c_str());
+  }
+  return true;
+}
+
+Status ParseDouble(Tokenizer* tokenizer, double* JXL_RESTRICT d) {
+  std::string num;
+  JXL_RETURN_IF_ERROR(tokenizer->Next(&num));
+  return ParseDouble(num, d);
+}
+
+Status ParseColorSpace(Tokenizer* JXL_RESTRICT tokenizer,
+                       ColorEncoding* JXL_RESTRICT c) {
+  std::string str;
+  JXL_RETURN_IF_ERROR(tokenizer->Next(&str));
+  ColorSpace cs;
+  if (ParseEnum(str, &cs)) {
+    c->SetColorSpace(cs);
+    return true;
+  }
+
+  return JXL_FAILURE("Unknown ColorSpace %s", str.c_str());
+}
+
+Status ParseWhitePoint(Tokenizer* JXL_RESTRICT tokenizer,
+                       ColorEncoding* JXL_RESTRICT c) {
+  if (c->ImplicitWhitePoint()) return true;
+
+  std::string str;
+  JXL_RETURN_IF_ERROR(tokenizer->Next(&str));
+  if (ParseEnum(str, &c->white_point)) return true;
+
+  CIExy xy;
+  Tokenizer xy_tokenizer(&str, ';');
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.x));
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.y));
+  if (c->SetWhitePoint(xy)) return true;
+
+  return JXL_FAILURE("Invalid white point %s", str.c_str());
+}
+
+Status ParsePrimaries(Tokenizer* JXL_RESTRICT tokenizer,
+                      ColorEncoding* JXL_RESTRICT c) {
+  if (!c->HasPrimaries()) return true;
+
+  std::string str;
+  JXL_RETURN_IF_ERROR(tokenizer->Next(&str));
+  if (ParseEnum(str, &c->primaries)) return true;
+
+  PrimariesCIExy xy;
+  Tokenizer xy_tokenizer(&str, ';');
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.r.x));
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.r.y));
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.g.x));
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.g.y));
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.b.x));
+  JXL_RETURN_IF_ERROR(ParseDouble(&xy_tokenizer, &xy.b.y));
+  if (c->SetPrimaries(xy)) return true;
+
+  return JXL_FAILURE("Invalid primaries %s", str.c_str());
+}
+
+Status ParseRenderingIntent(Tokenizer* JXL_RESTRICT tokenizer,
+                            ColorEncoding* JXL_RESTRICT c) {
+  std::string str;
+  JXL_RETURN_IF_ERROR(tokenizer->Next(&str));
+  if (ParseEnum(str, &c->rendering_intent)) return true;
+
+  return JXL_FAILURE("Invalid RenderingIntent %s\n", str.c_str());
+}
+
+Status ParseTransferFunction(Tokenizer* JXL_RESTRICT tokenizer,
+                             ColorEncoding* JXL_RESTRICT c) {
+  if (c->tf.SetImplicit()) return true;
+
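+  // (For kXYB, SetImplicit() above already fixed the transfer function to a
+  // 1/3 gamma - see CustomTransferFunction::SetImplicit below - so no token
+  // is consumed for it.)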
+  std::string str;
+  JXL_RETURN_IF_ERROR(tokenizer->Next(&str));
+  TransferFunction transfer_function;
+  if (ParseEnum(str, &transfer_function)) {
+    c->tf.SetTransferFunction(transfer_function);
+    return true;
+  }
+
+  if (str[0] == 'g') {
+    double gamma;
+    JXL_RETURN_IF_ERROR(ParseDouble(str.substr(1), &gamma));
+    if (c->tf.SetGamma(gamma)) return true;
+  }
+
+  return JXL_FAILURE("Invalid gamma %s", str.c_str());
+}
+
+static double F64FromCustomxyI32(const int32_t i) { return i * 1E-6; }
+static Status F64ToCustomxyI32(const double f, int32_t* JXL_RESTRICT i) {
+  if (!(-4 <= f && f <= 4)) {
+    return JXL_FAILURE("F64 out of bounds for CustomxyI32");
+  }
+  *i = static_cast<int32_t>(roundf(f * 1E6));
+  return true;
+}
+
+}  // namespace
+
+CIExy Customxy::Get() const {
+  CIExy xy;
+  xy.x = F64FromCustomxyI32(x);
+  xy.y = F64FromCustomxyI32(y);
+  return xy;
+}
+
+Status Customxy::Set(const CIExy& xy) {
+  JXL_RETURN_IF_ERROR(F64ToCustomxyI32(xy.x, &x));
+  JXL_RETURN_IF_ERROR(F64ToCustomxyI32(xy.y, &y));
+  size_t extension_bits, total_bits;
+  if (!Bundle::CanEncode(*this, &extension_bits, &total_bits)) {
+    return JXL_FAILURE("Unable to encode XY %f %f", xy.x, xy.y);
+  }
+  return true;
+}
+
+bool CustomTransferFunction::SetImplicit() {
+  if (nonserialized_color_space == ColorSpace::kXYB) {
+    if (!SetGamma(1.0 / 3)) JXL_ASSERT(false);
+    return true;
+  }
+  return false;
+}
+
+Status CustomTransferFunction::SetGamma(double gamma) {
+  if (gamma < (1.0f / kMaxGamma) || gamma > 1.0) {
+    return JXL_FAILURE("Invalid gamma %f", gamma);
+  }
+
+  have_gamma_ = false;
+  if (ApproxEq(gamma, 1.0)) {
+    transfer_function_ = TransferFunction::kLinear;
+    return true;
+  }
+  if (ApproxEq(gamma, 1.0 / 2.6)) {
+    transfer_function_ = TransferFunction::kDCI;
+    return true;
+  }
+  // Don't translate 0.45.. to kSRGB nor k709 - that might change pixel
+  // values because those curves also have a linear part.
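+  // (Illustrative numbers: near black, sRGB encodes through its linear
+  // segment, 12.92 * 0.001 ~= 0.0129, while a pure 0.45 power curve gives
+  // 0.001^0.45 ~= 0.045, so mapping gamma 0.45 to kSRGB would alter pixels.)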
+
+  have_gamma_ = true;
+  gamma_ = roundf(gamma * kGammaMul);
+  transfer_function_ = TransferFunction::kUnknown;
+  return true;
+}
+
+namespace {
+
+std::array<ColorEncoding, 2> CreateC2(const Primaries pr,
+                                      const TransferFunction tf) {
+  std::array<ColorEncoding, 2> c2;
+
+  {
+    ColorEncoding* c_rgb = c2.data() + 0;
+    c_rgb->SetColorSpace(ColorSpace::kRGB);
+    c_rgb->white_point = WhitePoint::kD65;
+    c_rgb->primaries = pr;
+    c_rgb->tf.SetTransferFunction(tf);
+    JXL_CHECK(c_rgb->CreateICC());
+  }
+
+  {
+    ColorEncoding* c_gray = c2.data() + 1;
+    c_gray->SetColorSpace(ColorSpace::kGray);
+    c_gray->white_point = WhitePoint::kD65;
+    c_gray->primaries = pr;
+    c_gray->tf.SetTransferFunction(tf);
+    JXL_CHECK(c_gray->CreateICC());
+  }
+
+  return c2;
+}
+
+}  // namespace
+
+const ColorEncoding& ColorEncoding::SRGB(bool is_gray) {
+  static std::array<ColorEncoding, 2> c2 =
+      CreateC2(Primaries::kSRGB, TransferFunction::kSRGB);
+  return c2[is_gray];
+}
+const ColorEncoding& ColorEncoding::LinearSRGB(bool is_gray) {
+  static std::array<ColorEncoding, 2> c2 =
+      CreateC2(Primaries::kSRGB, TransferFunction::kLinear);
+  return c2[is_gray];
+}
+
+CIExy ColorEncoding::GetWhitePoint() const {
+  JXL_DASSERT(have_fields_);
+  CIExy xy;
+  switch (white_point) {
+    case WhitePoint::kCustom:
+      return white_.Get();
+
+    case WhitePoint::kD65:
+      xy.x = 0.3127;
+      xy.y = 0.3290;
+      return xy;
+
+    case WhitePoint::kDCI:
+      // From https://ieeexplore.ieee.org/document/7290729 C.2 page 11
+      xy.x = 0.314;
+      xy.y = 0.351;
+      return xy;
+
+    case WhitePoint::kE:
+      xy.x = xy.y = 1.0 / 3;
+      return xy;
+  }
+  JXL_ABORT("Invalid WhitePoint %u", static_cast<uint32_t>(white_point));
+}
+
+Status ColorEncoding::SetWhitePoint(const CIExy& xy) {
+  JXL_DASSERT(have_fields_);
+  if (xy.x == 0.0 || xy.y == 0.0) {
+    return JXL_FAILURE("Invalid white point %f %f", xy.x, xy.y);
+  }
+  if (ApproxEq(xy.x, 0.3127) && ApproxEq(xy.y, 0.3290)) {
+    white_point = WhitePoint::kD65;
+    return true;
+  }
+  if (ApproxEq(xy.x, 1.0 / 3) && ApproxEq(xy.y, 1.0 / 3)) {
+    white_point = WhitePoint::kE;
+    return true;
+  }
+  if (ApproxEq(xy.x, 0.314) && ApproxEq(xy.y, 0.351)) {
+    white_point = WhitePoint::kDCI;
+    return true;
+  }
+  white_point = WhitePoint::kCustom;
+  return white_.Set(xy);
+}
+
+PrimariesCIExy ColorEncoding::GetPrimaries() const {
+  JXL_DASSERT(have_fields_);
+  JXL_ASSERT(HasPrimaries());
+  PrimariesCIExy xy;
+  switch (primaries) {
+    case Primaries::kCustom:
+      xy.r = red_.Get();
+      xy.g = green_.Get();
+      xy.b = blue_.Get();
+      return xy;
+
+    case Primaries::kSRGB:
+      xy.r.x = 0.639998686;
+      xy.r.y = 0.330010138;
+      xy.g.x = 0.300003784;
+      xy.g.y = 0.600003357;
+      xy.b.x = 0.150002046;
+      xy.b.y = 0.059997204;
+      return xy;
+
+    case Primaries::k2100:
+      xy.r.x = 0.708;
+      xy.r.y = 0.292;
+      xy.g.x = 0.170;
+      xy.g.y = 0.797;
+      xy.b.x = 0.131;
+      xy.b.y = 0.046;
+      return xy;
+
+    case Primaries::kP3:
+      xy.r.x = 0.680;
+      xy.r.y = 0.320;
+      xy.g.x = 0.265;
+      xy.g.y = 0.690;
+      xy.b.x = 0.150;
+      xy.b.y = 0.060;
+      return xy;
+  }
+  JXL_ABORT("Invalid Primaries %u", static_cast<uint32_t>(primaries));
+}
+
+Status ColorEncoding::SetPrimaries(const PrimariesCIExy& xy) {
+  JXL_DASSERT(have_fields_);
+  JXL_ASSERT(HasPrimaries());
+  if (xy.r.x == 0.0 || xy.r.y == 0.0 || xy.g.x == 0.0 || xy.g.y == 0.0 ||
+      xy.b.x == 0.0 || xy.b.y == 0.0) {
+    return JXL_FAILURE("Invalid primaries %f %f %f %f %f %f", xy.r.x, xy.r.y,
+                       xy.g.x, xy.g.y, xy.b.x, xy.b.y);
+  }
+
+  if (ApproxEq(xy.r.x, 0.64) && ApproxEq(xy.r.y, 0.33) &&
+      ApproxEq(xy.g.x, 0.30) && ApproxEq(xy.g.y, 0.60) &&
+      ApproxEq(xy.b.x, 0.15) && ApproxEq(xy.b.y, 0.06)) {
+    primaries = Primaries::kSRGB;
+    return true;
+  }
+
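+  // (Why ApproxEq instead of ==: GetPrimaries() above reports e.g.
+  // 0.639998686 for the kSRGB red x, and |0.639998686 - 0.64| ~= 1.3E-6 is
+  // well within the ApproxEq tolerance, so a Get/Set round trip lands back on
+  // the kSRGB enum rather than on kCustom.)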
+  if (ApproxEq(xy.r.x, 0.708) && ApproxEq(xy.r.y, 0.292) &&
+      ApproxEq(xy.g.x, 0.170) && ApproxEq(xy.g.y, 0.797) &&
+      ApproxEq(xy.b.x, 0.131) && ApproxEq(xy.b.y, 0.046)) {
+    primaries = Primaries::k2100;
+    return true;
+  }
+  if (ApproxEq(xy.r.x, 0.680) && ApproxEq(xy.r.y, 0.320) &&
+      ApproxEq(xy.g.x, 0.265) && ApproxEq(xy.g.y, 0.690) &&
+      ApproxEq(xy.b.x, 0.150) && ApproxEq(xy.b.y, 0.060)) {
+    primaries = Primaries::kP3;
+    return true;
+  }
+
+  primaries = Primaries::kCustom;
+  JXL_RETURN_IF_ERROR(red_.Set(xy.r));
+  JXL_RETURN_IF_ERROR(green_.Set(xy.g));
+  JXL_RETURN_IF_ERROR(blue_.Set(xy.b));
+  return true;
+}
+
+Status ColorEncoding::CreateICC() {
+  InternalRemoveICC();
+  if (!MaybeCreateProfile(*this, &icc_)) {
+    return JXL_FAILURE("Failed to create profile from fields");
+  }
+  return true;
+}
+
+std::string Description(const ColorEncoding& c_in) {
+  // Copy required for Implicit*
+  ColorEncoding c = c_in;
+
+  std::string d = ToString(c.GetColorSpace());
+
+  if (!c.ImplicitWhitePoint()) {
+    d += '_';
+    if (c.white_point == WhitePoint::kCustom) {
+      const CIExy wp = c.GetWhitePoint();
+      d += ToString(wp.x) + ';';
+      d += ToString(wp.y);
+    } else {
+      d += ToString(c.white_point);
+    }
+  }
+
+  if (c.HasPrimaries()) {
+    d += '_';
+    if (c.primaries == Primaries::kCustom) {
+      const PrimariesCIExy pr = c.GetPrimaries();
+      d += ToString(pr.r.x) + ';';
+      d += ToString(pr.r.y) + ';';
+      d += ToString(pr.g.x) + ';';
+      d += ToString(pr.g.y) + ';';
+      d += ToString(pr.b.x) + ';';
+      d += ToString(pr.b.y);
+    } else {
+      d += ToString(c.primaries);
+    }
+  }
+
+  d += '_';
+  d += ToString(c.rendering_intent);
+
+  if (!c.tf.SetImplicit()) {
+    d += '_';
+    if (c.tf.IsGamma()) {
+      d += 'g';
+      d += ToString(c.tf.GetGamma());
+    } else {
+      d += ToString(c.tf.GetTransferFunction());
+    }
+  }
+
+  return d;
+}
+
+Status ParseDescription(const std::string& description,
+                        ColorEncoding* JXL_RESTRICT c) {
+  Tokenizer tokenizer(&description, '_');
+  JXL_RETURN_IF_ERROR(ParseColorSpace(&tokenizer, c));
+  JXL_RETURN_IF_ERROR(ParseWhitePoint(&tokenizer, c));
+  JXL_RETURN_IF_ERROR(ParsePrimaries(&tokenizer, c));
+  JXL_RETURN_IF_ERROR(ParseRenderingIntent(&tokenizer, c));
+  JXL_RETURN_IF_ERROR(ParseTransferFunction(&tokenizer, c));
+  return true;
+}
+
+Customxy::Customxy() { Bundle::Init(this); }
+Status Customxy::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  uint32_t ux = PackSigned(x);
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(19), BitsOffset(19, 524288),
+                                         BitsOffset(20, 1048576),
+                                         BitsOffset(21, 2097152), 0, &ux));
+  x = UnpackSigned(ux);
+  uint32_t uy = PackSigned(y);
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(19), BitsOffset(19, 524288),
+                                         BitsOffset(20, 1048576),
+                                         BitsOffset(21, 2097152), 0, &uy));
+  y = UnpackSigned(uy);
+  return true;
+}
+
+CustomTransferFunction::CustomTransferFunction() { Bundle::Init(this); }
+Status CustomTransferFunction::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->Conditional(!SetImplicit())) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_gamma_));
+
+    if (visitor->Conditional(have_gamma_)) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(24, kGammaMul, &gamma_));
+      if (gamma_ > kGammaMul || gamma_ * kMaxGamma < kGammaMul) {
+        return JXL_FAILURE("Invalid gamma %u", gamma_);
+      }
+    }
+
+    if (visitor->Conditional(!have_gamma_)) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->Enum(TransferFunction::kSRGB, &transfer_function_));
+    }
+  }
+
+  return true;
+}
+
+ColorEncoding::ColorEncoding() { Bundle::Init(this); }
+Status ColorEncoding::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->AllDefault(*this, &all_default)) {
+    // Overwrite all serialized fields, but not any nonserialized_*.
+    visitor->SetDefault(this);
+    return true;
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &want_icc_));
+
+  // Always send even if want_icc_ because this affects decoding.
+  // We can skip the white point/primaries because they do not.
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(ColorSpace::kRGB, &color_space_));
+
+  if (visitor->Conditional(!WantICC())) {
+    // Serialize enums. NOTE: we set the defaults to the most common values so
+    // ImageMetadata.all_default is true in the common case.
+
+    if (visitor->Conditional(!ImplicitWhitePoint())) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(WhitePoint::kD65, &white_point));
+      if (visitor->Conditional(white_point == WhitePoint::kCustom)) {
+        JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&white_));
+      }
+    }
+
+    if (visitor->Conditional(HasPrimaries())) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(Primaries::kSRGB, &primaries));
+      if (visitor->Conditional(primaries == Primaries::kCustom)) {
+        JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&red_));
+        JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&green_));
+        JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&blue_));
+      }
+    }
+
+    JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&tf));
+
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->Enum(RenderingIntent::kRelative, &rendering_intent));
+
+    // We didn't have ICC, so all fields should be known.
+    if (color_space_ == ColorSpace::kUnknown || tf.IsUnknown()) {
+      return JXL_FAILURE(
+          "No ICC but cs %u and tf %u%s", static_cast<uint32_t>(color_space_),
+          tf.IsGamma() ? 0 : static_cast<uint32_t>(tf.GetTransferFunction()),
+          tf.IsGamma() ? "(gamma)" : "");
+    }
+
+    JXL_RETURN_IF_ERROR(CreateICC());
+  }
+
+  if (WantICC() && visitor->IsReading()) {
+    // Haven't called SetICC() yet, do nothing.
+  } else {
+    if (ICC().empty()) return JXL_FAILURE("Empty ICC");
+  }
+
+  return true;
+}
+
+void ConvertInternalToExternalColorEncoding(const ColorEncoding& internal,
+                                            JxlColorEncoding* external) {
+  external->color_space =
+      static_cast<JxlColorSpace>(internal.GetColorSpace());
+
+  external->white_point = static_cast<JxlWhitePoint>(internal.white_point);
+
+  jxl::CIExy whitepoint = internal.GetWhitePoint();
+  external->white_point_xy[0] = whitepoint.x;
+  external->white_point_xy[1] = whitepoint.y;
+
+  if (external->color_space == JXL_COLOR_SPACE_RGB ||
+      external->color_space == JXL_COLOR_SPACE_UNKNOWN) {
+    external->primaries = static_cast<JxlPrimaries>(internal.primaries);
+    jxl::PrimariesCIExy primaries = internal.GetPrimaries();
+    external->primaries_red_xy[0] = primaries.r.x;
+    external->primaries_red_xy[1] = primaries.r.y;
+    external->primaries_green_xy[0] = primaries.g.x;
+    external->primaries_green_xy[1] = primaries.g.y;
+    external->primaries_blue_xy[0] = primaries.b.x;
+    external->primaries_blue_xy[1] = primaries.b.y;
+  }
+
+  if (internal.tf.IsGamma()) {
+    external->transfer_function = JXL_TRANSFER_FUNCTION_GAMMA;
+    external->gamma = internal.tf.GetGamma();
+  } else {
+    external->transfer_function =
+        static_cast<JxlTransferFunction>(internal.tf.GetTransferFunction());
+    external->gamma = 0;
+  }
+
+  external->rendering_intent =
+      static_cast<JxlRenderingIntent>(internal.rendering_intent);
+}
+
+/* Chromatic adaptation matrices */
+static float kBradford[9] = {
+    0.8951f, 0.2664f, -0.1614f, -0.7502f, 1.7135f,
+    0.0367f, 0.0389f, -0.0685f, 1.0296f,
+};
+
+static float kBradfordInv[9] = {
+    0.9869929f, -0.1470543f, 0.1599627f, 0.4323053f, 0.5183603f,
+    0.0492912f, -0.0085287f, 0.0400428f, 0.9684867f,
+};
+
+// Adapts whitepoint x, y to D50
+Status AdaptToXYZD50(float wx, float wy, float matrix[9]) {
+  if (wx < 0 || wx > 1 || wy < 0 || wy > 1) {
+    return JXL_FAILURE("xy color out of range");
+  }
+
+  float w[3] = {wx / wy, 1.0f, (1.0f - wx - wy) / wy};
+  float w50[3] = {0.96422f, 1.0f, 0.82521f};
+
+  float lms[3];
+  float lms50[3];
+
+  MatMul(kBradford, w, 3, 3, 1, lms);
+  MatMul(kBradford, w50, 3, 3, 1, lms50);
+
+  float a[9] = {
+      lms50[0] / lms[0], 0, 0, 0, lms50[1] / lms[1], 0, 0, 0, lms50[2] / lms[2],
+  };
+
+  float b[9];
+  MatMul(a, kBradford, 3, 3, 3, b);
+  MatMul(kBradfordInv, b, 3, 3, 3, matrix);
+
+  return true;
+}
+
+Status PrimariesToXYZD50(float rx, float ry, float gx, float gy, float bx,
+                         float by, float wx, float wy, float matrix[9]) {
+  if (rx < 0 || rx > 1 || ry < 0 || ry > 1 || gx < 0 || gx > 1 || gy < 0 ||
+      gy > 1 || bx < 0 || bx > 1 || by < 0 || by > 1 || wx < 0 || wx > 1 ||
+      wy < 0 || wy > 1) {
+    return JXL_FAILURE("xy color out of range");
+  }
+
+  float primaries[9] = {
+      rx, gx, bx, ry, gy, by, 1.0f - rx - ry, 1.0f - gx - gy, 1.0f - bx - by};
+  float primaries_inv[9];
+  memcpy(primaries_inv, primaries, sizeof(float) * 9);
+  Inv3x3Matrix(primaries_inv);
+
+  float w[3] = {wx / wy, 1.0f, (1.0f - wx - wy) / wy};
+  float xyz[3];
+  MatMul(primaries_inv, w, 3, 3, 1, xyz);
+
+  float a[9] = {
+      xyz[0], 0, 0, 0, xyz[1], 0, 0, 0, xyz[2],
+  };
+
+  float toXYZ[9];
+  MatMul(primaries, a, 3, 3, 3, toXYZ);
+
+  float d50[9];
+  JXL_RETURN_IF_ERROR(AdaptToXYZD50(wx, wy, d50));
+
+  MatMul(d50, toXYZ, 3, 3, 3, matrix);
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h
new file mode 100644
index 000000000000..4c563b466fd6
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h
@@ -0,0 +1,462 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_COLOR_ENCODING_INTERNAL_H_
+#define LIB_JXL_COLOR_ENCODING_INTERNAL_H_
+
+// Metadata for color space conversions.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <cmath>  // std::abs
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "jxl/color_encoding.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/field_encodings.h"
+
+namespace jxl {
+
+// (All CIE units are for the standard 1931 2 degree observer)
+
+// Color space the color pixel data is encoded in. The color pixel data is
+// 3-channel in all cases except in case of kGray, where it uses only 1
+// channel. This also determines the number of channels used in modular
+// encoding.
+enum class ColorSpace : uint32_t {
+  // Trichromatic color data. This also includes CMYK if a kBlack
+  // ExtraChannelInfo is present. This implies, if there is an ICC profile,
+  // that the ICC profile uses a 3-channel color space if no kBlack extra
+  // channel is present, or uses color space 'CMYK' if a kBlack extra channel
+  // is present.
+  kRGB,
+  // Single-channel data. This implies, if there is an ICC profile, that the
+  // ICC profile also represents single-channel data and has the appropriate
+  // color space ('GRAY').
+  kGray,
+  // Like kRGB, but implies fixed values for primaries etc.
+  kXYB,
+  // For non-RGB/gray data, e.g. from non-electro-optical sensors. Otherwise
+  // the same conditions as kRGB apply.
+  kUnknown
+};
+
+static inline const char* EnumName(ColorSpace /*unused*/) {
+  return "ColorSpace";
+}
+static inline constexpr uint64_t EnumBits(ColorSpace /*unused*/) {
+  using CS = ColorSpace;
+  return MakeBit(CS::kRGB) | MakeBit(CS::kGray) | MakeBit(CS::kXYB) |
+         MakeBit(CS::kUnknown);
+}
+
+// Values from CICP ColourPrimaries.
+enum class WhitePoint : uint32_t {
+  kD65 = 1,     // sRGB/BT.709/Display P3/BT.2020
+  kCustom = 2,  // Actual values encoded in separate fields
+  kE = 10,      // XYZ
+  kDCI = 11,    // DCI-P3
+};
+
+static inline const char* EnumName(WhitePoint /*unused*/) {
+  return "WhitePoint";
+}
+static inline constexpr uint64_t EnumBits(WhitePoint /*unused*/) {
+  return MakeBit(WhitePoint::kD65) | MakeBit(WhitePoint::kCustom) |
+         MakeBit(WhitePoint::kE) | MakeBit(WhitePoint::kDCI);
+}
+
+// Values from CICP ColourPrimaries
+enum class Primaries : uint32_t {
+  kSRGB = 1,    // Same as BT.709
+  kCustom = 2,  // Actual values encoded in separate fields
+  k2100 = 9,    // Same as BT.2020
+  kP3 = 11,
+};
+
+static inline const char* EnumName(Primaries /*unused*/) { return "Primaries"; }
+static inline constexpr uint64_t EnumBits(Primaries /*unused*/) {
+  using Pr = Primaries;
+  return MakeBit(Pr::kSRGB) | MakeBit(Pr::kCustom) | MakeBit(Pr::k2100) |
+         MakeBit(Pr::kP3);
+}
+
+// Values from CICP TransferCharacteristics
+enum class TransferFunction : uint32_t {
+  k709 = 1,
+  kUnknown = 2,
+  kLinear = 8,
+  kSRGB = 13,
+  kPQ = 16,   // from BT.2100
+  kDCI = 17,  // from SMPTE RP 431-2 reference projector
+  kHLG = 18,  // from BT.2100
+};
+
+static inline const char* EnumName(TransferFunction /*unused*/) {
+  return "TransferFunction";
+}
+static inline constexpr uint64_t EnumBits(TransferFunction /*unused*/) {
+  using TF = TransferFunction;
+  return MakeBit(TF::k709) | MakeBit(TF::kLinear) | MakeBit(TF::kSRGB) |
+         MakeBit(TF::kPQ) | MakeBit(TF::kDCI) | MakeBit(TF::kHLG) |
+         MakeBit(TF::kUnknown);
+}
+
+enum class RenderingIntent : uint32_t {
+  // Values match ICC sRGB encodings.
+  kPerceptual = 0,  // good for photos, requires a profile with LUT.
+  kRelative,        // good for logos.
+  kSaturation,      // perhaps useful for CG with fully saturated colors.
+  kAbsolute,        // leaves white point unchanged; good for proofing.
+};
+
+static inline const char* EnumName(RenderingIntent /*unused*/) {
+  return "RenderingIntent";
+}
+static inline constexpr uint64_t EnumBits(RenderingIntent /*unused*/) {
+  using RI = RenderingIntent;
+  return MakeBit(RI::kPerceptual) | MakeBit(RI::kRelative) |
+         MakeBit(RI::kSaturation) | MakeBit(RI::kAbsolute);
+}
+
+// Chromaticity (Y is omitted because it is 1 for primaries/white points)
+struct CIExy {
+  double x = 0.0;
+  double y = 0.0;
+};
+
+struct PrimariesCIExy {
+  CIExy r;
+  CIExy g;
+  CIExy b;
+};
+
+// Serializable form of CIExy.
+struct Customxy : public Fields {
+  Customxy();
+  const char* Name() const override { return "Customxy"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  CIExy Get() const;
+  // Returns false if x or y do not fit in the encoding.
+  Status Set(const CIExy& xy);
+
+  int32_t x;
+  int32_t y;
+};
+
+struct CustomTransferFunction : public Fields {
+  CustomTransferFunction();
+  const char* Name() const override { return "CustomTransferFunction"; }
+
+  // Sets fields and returns true if nonserialized_color_space has an implicit
+  // transfer function, otherwise leaves fields unchanged and returns false.
+  bool SetImplicit();
+
+  // Gamma: only used for PNG inputs
+  bool IsGamma() const { return have_gamma_; }
+  double GetGamma() const {
+    JXL_ASSERT(IsGamma());
+    return gamma_ * 1E-7;  // (0, 1)
+  }
+  Status SetGamma(double gamma);
+
+  TransferFunction GetTransferFunction() const {
+    JXL_ASSERT(!IsGamma());
+    return transfer_function_;
+  }
+  void SetTransferFunction(const TransferFunction tf) {
+    have_gamma_ = false;
+    transfer_function_ = tf;
+  }
+
+  bool IsUnknown() const {
+    return !have_gamma_ && (transfer_function_ == TransferFunction::kUnknown);
+  }
+  bool IsSRGB() const {
+    return !have_gamma_ && (transfer_function_ == TransferFunction::kSRGB);
+  }
+  bool IsLinear() const {
+    return !have_gamma_ && (transfer_function_ == TransferFunction::kLinear);
+  }
+  bool IsPQ() const {
+    return !have_gamma_ && (transfer_function_ == TransferFunction::kPQ);
+  }
+  bool IsHLG() const {
+    return !have_gamma_ && (transfer_function_ == TransferFunction::kHLG);
+  }
+  bool IsSame(const CustomTransferFunction& other) const {
+    if (have_gamma_ != other.have_gamma_) return false;
+    if (have_gamma_) {
+      if (gamma_ != other.gamma_) return false;
+    } else {
+      if (transfer_function_ != other.transfer_function_) return false;
+    }
+    return true;
+  }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  // Must be set before calling VisitFields!
+  ColorSpace nonserialized_color_space = ColorSpace::kRGB;
+
+ private:
+  static constexpr uint32_t kGammaMul = 10000000;
+
+  bool have_gamma_;
+
+  // OETF exponent to go from linear to gamma-compressed.
+  uint32_t gamma_;  // Only used if have_gamma_.
+
+  // Can be kUnknown.
+  TransferFunction transfer_function_;  // Only used if !have_gamma_.
+};
+
+// Compact encoding of data required to interpret and translate pixels to a
+// known color space. Stored in Metadata. Thread-compatible.
+struct ColorEncoding : public Fields {
+  ColorEncoding();
+  const char* Name() const override { return "ColorEncoding"; }
+
+  // Returns ready-to-use color encodings (initialized on-demand).
+  static const ColorEncoding& SRGB(bool is_gray = false);
+  static const ColorEncoding& LinearSRGB(bool is_gray = false);
+
+  // Returns true if an ICC profile was successfully created from fields.
+  // Must be called after modifying fields. Defined in color_management.cc.
+  Status CreateICC();
+
+  // Returns non-empty and valid ICC profile, unless:
+  // - between calling InternalRemoveICC() and CreateICC() in tests;
+  // - WantICC() == true and SetICC() was not yet called;
+  // - after a failed call to SetSRGB(), SetICC(), or CreateICC().
+  const PaddedBytes& ICC() const { return icc_; }
+
+  // Internal only, do not call except from tests.
+  void InternalRemoveICC() { icc_.clear(); }
+
+  // Returns true if `icc` is assigned and decoded successfully. If so,
+  // subsequent WantICC() will return true until DecideIfWantICC() changes it.
+  // Returning false indicates data has been lost.
+  Status SetICC(PaddedBytes&& icc) {
+    if (icc.empty()) return false;
+    icc_ = std::move(icc);
+
+    if (!SetFieldsFromICC()) {
+      InternalRemoveICC();
+      return false;
+    }
+
+    want_icc_ = true;
+    return true;
+  }
+
+  // Sets the raw ICC profile bytes, without parsing the ICC, and without
+  // updating the direct fields such as white point, primaries, and color
+  // space. Functions that get or set fields, such as SetWhitePoint, cannot be
+  // used after this, and functions such as IsSRGB return false regardless of
+  // the ICC profile's contents.
+  Status SetICCRaw(PaddedBytes&& icc) {
+    if (icc.empty()) return false;
+    icc_ = std::move(icc);
+
+    want_icc_ = true;
+    have_fields_ = false;
+    return true;
+  }
+
+  // Returns whether to send the ICC profile in the codestream.
+  bool WantICC() const { return want_icc_; }
+
+  // Returns whether the direct fields are set; if false but ICC is set, only
+  // the raw ICC bytes are known.
+  bool HaveFields() const { return have_fields_; }
+
+  // Causes WantICC() to return false if ICC() can be reconstructed from
+  // fields. Defined in color_management.cc.
+  void DecideIfWantICC();
+
+  bool IsGray() const { return color_space_ == ColorSpace::kGray; }
+  size_t Channels() const { return IsGray() ? 1 : 3; }
+
+  // Returns false if the field is invalid and unusable.
+  bool HasPrimaries() const {
+    return !IsGray() && color_space_ != ColorSpace::kXYB;
+  }
+
+  // Returns true after setting the field to a value defined by color_space,
+  // otherwise false and leaves the field unchanged.
+  bool ImplicitWhitePoint() {
+    if (color_space_ == ColorSpace::kXYB) {
+      white_point = WhitePoint::kD65;
+      return true;
+    }
+    return false;
+  }
+
+  // Returns whether the color space is known to be sRGB. If a raw unparsed
+  // ICC profile is set without the fields being set, this returns false,
+  // even if the content of the ICC profile would match sRGB.
+  bool IsSRGB() const {
+    if (!have_fields_) return false;
+    if (!IsGray() && color_space_ != ColorSpace::kRGB) return false;
+    if (white_point != WhitePoint::kD65) return false;
+    if (primaries != Primaries::kSRGB) return false;
+    if (!tf.IsSRGB()) return false;
+    return true;
+  }
+
+  // Returns whether the color space is known to be linear sRGB. If a raw
+  // unparsed ICC profile is set without the fields being set, this returns
+  // false, even if the content of the ICC profile would match linear sRGB.
+  bool IsLinearSRGB() const {
+    if (!have_fields_) return false;
+    if (!IsGray() && color_space_ != ColorSpace::kRGB) return false;
+    if (white_point != WhitePoint::kD65) return false;
+    if (primaries != Primaries::kSRGB) return false;
+    if (!tf.IsLinear()) return false;
+    return true;
+  }
+
+  Status SetSRGB(const ColorSpace cs,
+                 const RenderingIntent ri = RenderingIntent::kRelative) {
+    InternalRemoveICC();
+    JXL_ASSERT(cs == ColorSpace::kGray || cs == ColorSpace::kRGB);
+    color_space_ = cs;
+    white_point = WhitePoint::kD65;
+    primaries = Primaries::kSRGB;
+    tf.SetTransferFunction(TransferFunction::kSRGB);
+    rendering_intent = ri;
+    return CreateICC();
+  }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  // Accessors ensure tf.nonserialized_color_space is updated at the same time.
+  ColorSpace GetColorSpace() const { return color_space_; }
+  void SetColorSpace(const ColorSpace cs) {
+    color_space_ = cs;
+    tf.nonserialized_color_space = cs;
+  }
+
+  CIExy GetWhitePoint() const;
+  Status SetWhitePoint(const CIExy& xy);
+
+  PrimariesCIExy GetPrimaries() const;
+  Status SetPrimaries(const PrimariesCIExy& xy);
+
+  // Checks if the color spaces (including white point / primaries) are the
+  // same, but ignores the transfer function, rendering intent and ICC bytes.
+ bool SameColorSpace(const ColorEncoding& other) const { + if (color_space_ != other.color_space_) return false; + + if (white_point != other.white_point) return false; + if (white_point == WhitePoint::kCustom) { + if (white_.x != other.white_.x || white_.y != other.white_.y) + return false; + } + + if (HasPrimaries() != other.HasPrimaries()) return false; + if (HasPrimaries()) { + if (primaries != other.primaries) return false; + if (primaries == Primaries::kCustom) { + if (red_.x != other.red_.x || red_.y != other.red_.y) return false; + if (green_.x != other.green_.x || green_.y != other.green_.y) + return false; + if (blue_.x != other.blue_.x || blue_.y != other.blue_.y) return false; + } + } + return true; + } + + // Checks if the color space and transfer function are the same, ignoring + // rendering intent and ICC bytes + bool SameColorEncoding(const ColorEncoding& other) const { + return SameColorSpace(other) && tf.IsSame(other.tf); + } + + mutable bool all_default; + + // Only valid if HaveFields() + WhitePoint white_point; + Primaries primaries; // Only valid if HasPrimaries() + CustomTransferFunction tf; + RenderingIntent rendering_intent; + + private: + // Returns true if all fields have been initialized (possibly to kUnknown). + // Returns false if the ICC profile is invalid or decoding it fails. + // Defined in color_management.cc. + Status SetFieldsFromICC(); + + // If true, the codestream contains an ICC profile and we do not serialize + // fields. Otherwise, fields are serialized and we create an ICC profile. + bool want_icc_; + + // When false, fields such as white_point and tf are invalid and must not be + // used. This occurs after setting a raw bytes-only ICC profile, only the + // ICC bytes may be used. The color_space_ field is still valid. + bool have_fields_ = true; + + PaddedBytes icc_; // Valid ICC profile + + ColorSpace color_space_; // Can be kUnknown + + // Only used if white_point == kCustom. + Customxy white_; + + // Only used if primaries == kCustom. + Customxy red_; + Customxy green_; + Customxy blue_; +}; + +// Returns whether the two inputs are approximately equal. +static inline bool ApproxEq(const double a, const double b, +#if JPEGXL_ENABLE_SKCMS + double max_l1 = 1E-3) { +#else + double max_l1 = 8E-5) { +#endif + // Threshold should be sufficient for ICC's 15-bit fixed-point numbers. + // We have seen differences of 7.1E-5 with lcms2 and 1E-3 with skcms. + return std::abs(a - b) <= max_l1; +} + +// Returns a representation of the ColorEncoding fields (not icc). 
+// Example description: "RGB_D65_SRG_Rel_Lin"
+std::string Description(const ColorEncoding& c);
+Status ParseDescription(const std::string& description,
+                        ColorEncoding* JXL_RESTRICT c);
+
+static inline std::ostream& operator<<(std::ostream& os,
+                                       const ColorEncoding& c) {
+  return os << Description(c);
+}
+
+void ConvertInternalToExternalColorEncoding(const jxl::ColorEncoding& internal,
+                                            JxlColorEncoding* external);
+
+Status PrimariesToXYZD50(float rx, float ry, float gx, float gy, float bx,
+                         float by, float wx, float wy, float matrix[9]);
+Status AdaptToXYZD50(float wx, float wy, float matrix[9]);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_COLOR_ENCODING_INTERNAL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal_test.cc b/third_party/jpeg-xl/lib/jxl/color_encoding_internal_test.cc
new file mode 100644
index 000000000000..f0a8ca3911ce
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal_test.cc
@@ -0,0 +1,183 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/color_encoding_internal.h"
+
+#include <stdio.h>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/encode_internal.h"
+#include "lib/jxl/test_utils.h"
+
+namespace jxl {
+namespace {
+
+TEST(ColorEncodingTest, RoundTripAll) {
+  for (const test::ColorEncodingDescriptor& cdesc : test::AllEncodings()) {
+    const ColorEncoding c_original = test::ColorEncodingFromDescriptor(cdesc);
+    // Verify Set(Get) yields the same white point/primaries/gamma.
+    {
+      ColorEncoding c;
+      EXPECT_TRUE(c.SetWhitePoint(c_original.GetWhitePoint()));
+      EXPECT_EQ(c_original.white_point, c.white_point);
+    }
+    {
+      ColorEncoding c;
+      EXPECT_TRUE(c.SetPrimaries(c_original.GetPrimaries()));
+      EXPECT_EQ(c_original.primaries, c.primaries);
+    }
+    if (c_original.tf.IsGamma()) {
+      ColorEncoding c;
+      EXPECT_TRUE(c.tf.SetGamma(c_original.tf.GetGamma()));
+      EXPECT_TRUE(c_original.tf.IsSame(c.tf));
+    }
+
+    // Verify ParseDescription(Description) yields the same ColorEncoding
+    {
+      const std::string description = Description(c_original);
+      printf("%s\n", description.c_str());
+      ColorEncoding c;
+      EXPECT_TRUE(ParseDescription(description, &c));
+      EXPECT_TRUE(c_original.SameColorEncoding(c));
+    }
+  }
+}
+
+// Verify Set(Get) for specific custom values
+
+TEST(ColorEncodingTest, NanGamma) {
+  const std::string description = "Gra_2_Per_gnan";
+  ColorEncoding c;
+  EXPECT_FALSE(ParseDescription(description, &c));
+}
+
+TEST(ColorEncodingTest, CustomWhitePoint) {
+  ColorEncoding c;
+  // Nonsensical values
+  CIExy xy_in;
+  xy_in.x = 0.8;
+  xy_in.y = 0.01;
+  EXPECT_TRUE(c.SetWhitePoint(xy_in));
+  const CIExy xy = c.GetWhitePoint();
+
+  ColorEncoding c2;
+  EXPECT_TRUE(c2.SetWhitePoint(xy));
+  EXPECT_TRUE(c.SameColorSpace(c2));
+}
+
+TEST(ColorEncodingTest, CustomPrimaries) {
+  ColorEncoding c;
+  PrimariesCIExy xy_in;
+  // Nonsensical values
+  xy_in.r.x = -0.01;
+  xy_in.r.y = 0.2;
+  xy_in.g.x = 0.4;
+  xy_in.g.y = 0.401;
+  xy_in.b.x = 1.1;
+  xy_in.b.y = -1.2;
+  EXPECT_TRUE(c.SetPrimaries(xy_in));
+  const PrimariesCIExy xy = c.GetPrimaries();
+
+  ColorEncoding c2;
+  EXPECT_TRUE(c2.SetPrimaries(xy));
+  EXPECT_TRUE(c.SameColorSpace(c2));
+}
+
+TEST(ColorEncodingTest, CustomGamma) {
+  ColorEncoding c;
+#ifndef JXL_CRASH_ON_ERROR
+  EXPECT_FALSE(c.tf.SetGamma(0.0));
+  EXPECT_FALSE(c.tf.SetGamma(-1E-6));
+  EXPECT_FALSE(c.tf.SetGamma(1.001));
+#endif
+  EXPECT_TRUE(c.tf.SetGamma(1.0));
+  EXPECT_FALSE(c.tf.IsGamma());
+  EXPECT_TRUE(c.tf.IsLinear());
+
+  EXPECT_TRUE(c.tf.SetGamma(0.123));
+  EXPECT_TRUE(c.tf.IsGamma());
+  const double gamma = c.tf.GetGamma();
+
+  ColorEncoding c2;
+  EXPECT_TRUE(c2.tf.SetGamma(gamma));
+  EXPECT_TRUE(c.SameColorEncoding(c2));
+  EXPECT_TRUE(c2.tf.IsGamma());
+}
+
+TEST(ColorEncodingTest, InternalExternalConversion) {
+  ColorEncoding source_internal;
+  JxlColorEncoding external;
+  ColorEncoding destination_internal;
+
+  for (int i = 0; i < 100; i++) {
+    source_internal.SetColorSpace(static_cast<ColorSpace>(rand() % 4));
+    CIExy wp;
+    wp.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+    wp.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+    EXPECT_TRUE(source_internal.SetWhitePoint(wp));
+    if (source_internal.HasPrimaries()) {
+      PrimariesCIExy primaries;
+      primaries.r.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+      primaries.r.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+      primaries.g.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+      primaries.g.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+      primaries.b.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+      primaries.b.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25;
+      EXPECT_TRUE(source_internal.SetPrimaries(primaries));
+    }
+    CustomTransferFunction tf;
+    EXPECT_TRUE(tf.SetGamma((float(rand()) / float((RAND_MAX)) * 0.5) + 0.25));
+    source_internal.tf = tf;
+    source_internal.rendering_intent =
+        static_cast<RenderingIntent>(rand() % 4);
+
+    ConvertInternalToExternalColorEncoding(source_internal, &external);
+    EXPECT_TRUE(ConvertExternalToInternalColorEncoding(external,
+                                                       &destination_internal));
+
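+    // The internal -> external -> internal round trip is expected to be
+    // lossless: every field set above is compared with exact EXPECT_EQ below.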
+    EXPECT_EQ(source_internal.GetColorSpace(),
+              destination_internal.GetColorSpace());
+    EXPECT_EQ(source_internal.white_point, destination_internal.white_point);
+    EXPECT_EQ(source_internal.GetWhitePoint().x,
+              destination_internal.GetWhitePoint().x);
+    EXPECT_EQ(source_internal.GetWhitePoint().y,
+              destination_internal.GetWhitePoint().y);
+    if (source_internal.HasPrimaries()) {
+      EXPECT_EQ(source_internal.GetPrimaries().r.x,
+                destination_internal.GetPrimaries().r.x);
+      EXPECT_EQ(source_internal.GetPrimaries().r.y,
+                destination_internal.GetPrimaries().r.y);
+      EXPECT_EQ(source_internal.GetPrimaries().g.x,
+                destination_internal.GetPrimaries().g.x);
+      EXPECT_EQ(source_internal.GetPrimaries().g.y,
+                destination_internal.GetPrimaries().g.y);
+      EXPECT_EQ(source_internal.GetPrimaries().b.x,
+                destination_internal.GetPrimaries().b.x);
+      EXPECT_EQ(source_internal.GetPrimaries().b.y,
+                destination_internal.GetPrimaries().b.y);
+    }
+    EXPECT_EQ(source_internal.tf.IsGamma(), destination_internal.tf.IsGamma());
+    if (source_internal.tf.IsGamma()) {
+      EXPECT_EQ(source_internal.tf.GetGamma(),
+                destination_internal.tf.GetGamma());
+    } else {
+      EXPECT_EQ(source_internal.tf.GetTransferFunction(),
+                destination_internal.tf.GetTransferFunction());
+    }
+    EXPECT_EQ(source_internal.rendering_intent,
+              destination_internal.rendering_intent);
+  }
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/color_management.cc b/third_party/jpeg-xl/lib/jxl/color_management.cc
new file mode 100644
index 000000000000..232ee88b3bb6
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/color_management.cc
@@ -0,0 +1,527 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Defined by build system; this avoids IDE warnings. Must come before
+// color_management.h (affects header definitions).
+#ifndef JPEGXL_ENABLE_SKCMS
+#define JPEGXL_ENABLE_SKCMS 0
+#endif
+
+#include "lib/jxl/color_management.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <memory>
+#include <string>
+#include <utility>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/color_management.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/field_encodings.h"
+#include "lib/jxl/linalg.h"  // MatMul, Inv3x3Matrix
+#include "lib/jxl/transfer_functions-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// NOTE: this is only used to provide a reasonable ICC profile that other
+// software can read. Our own transforms use ExtraTF instead because that is
+// more precise and supports unbounded mode.
+std::vector<uint16_t> CreateTableCurve(uint32_t N, const ExtraTF tf) {
+  JXL_ASSERT(N <= 4096);  // ICC MFT2 only allows 4K entries
+  JXL_ASSERT(tf == ExtraTF::kPQ || tf == ExtraTF::kHLG);
+  // No point using float - LCMS converts to 16-bit for A2B/MFT.
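+  // (Illustrative: entry i stores the EOTF output at x = i / (N - 1),
+  // quantized to 16 bits; e.g. y = 0.5 becomes roundf(0.5 * 65535) = 32768.)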
+  std::vector<uint16_t> table(N);
+  for (uint32_t i = 0; i < N; ++i) {
+    const float x = static_cast<float>(i) / (N - 1);  // 1.0 at index N - 1.
+    const double dx = static_cast<double>(x);
+    // LCMS requires EOTF (e.g. 2.4 exponent).
+    double y = (tf == ExtraTF::kHLG) ? TF_HLG().DisplayFromEncoded(dx)
+                                     : TF_PQ().DisplayFromEncoded(dx);
+    JXL_ASSERT(y >= 0.0);
+    // Clamp to table range - necessary for HLG.
+    if (y > 1.0) y = 1.0;
+    // 1.0 corresponds to table value 0xFFFF.
+    table[i] = static_cast<uint16_t>(roundf(y * 65535.0));
+  }
+  return table;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(CreateTableCurve);  // Local function.
+
+Status CIEXYZFromWhiteCIExy(const CIExy& xy, float XYZ[3]) {
+  // Target Y = 1.
+  if (std::abs(xy.y) < 1e-12) return JXL_FAILURE("Y value is too small");
+  const float factor = 1 / xy.y;
+  XYZ[0] = xy.x * factor;
+  XYZ[1] = 1;
+  XYZ[2] = (1 - xy.x - xy.y) * factor;
+  return true;
+}
+
+namespace {
+
+// NOTE: this is only used to provide a reasonable ICC profile that other
+// software can read. Our own transforms use ExtraTF instead because that is
+// more precise and supports unbounded mode.
+template <class Func>
+std::vector<uint16_t> CreateTableCurve(uint32_t N, const Func& func) {
+  JXL_ASSERT(N <= 4096);  // ICC MFT2 only allows 4K entries
+  // No point using float - LCMS converts to 16-bit for A2B/MFT.
+  std::vector<uint16_t> table(N);
+  for (uint32_t i = 0; i < N; ++i) {
+    const float x = static_cast<float>(i) / (N - 1);  // 1.0 at index N - 1.
+    // LCMS requires EOTF (e.g. 2.4 exponent).
+    double y = func.DisplayFromEncoded(static_cast<double>(x));
+    JXL_ASSERT(y >= 0.0);
+    // Clamp to table range - necessary for HLG.
+    if (y > 1.0) y = 1.0;
+    // 1.0 corresponds to table value 0xFFFF.
+    table[i] = static_cast<uint16_t>(roundf(y * 65535.0));
+  }
+  return table;
+}
+
+void ICCComputeMD5(const PaddedBytes& data, uint8_t sum[16]) {
+  PaddedBytes data64 = data;
+  data64.push_back(128);
+  // Add bytes such that ((size + 8) & 63) == 0.
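+  // (Illustrative: a 100-byte profile is 101 bytes after the 0x80 marker
+  // above, so extra = (64 - ((101 + 8) & 63)) & 63 = 19, and
+  // 101 + 19 + 8 = 128 is a multiple of 64, as MD5 requires.)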
+  size_t extra = ((64 - ((data64.size() + 8) & 63)) & 63);
+  data64.resize(data64.size() + extra, 0);
+  for (uint64_t i = 0; i < 64; i += 8) {
+    data64.push_back(static_cast<uint64_t>(data.size() << 3u) >> i);
+  }
+
+  static const uint32_t sineparts[64] = {
+      0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a,
+      0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be,
+      0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340,
+      0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8,
+      0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8,
+      0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c,
+      0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa,
+      0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665,
+      0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92,
+      0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1,
+      0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391,
+  };
+  static const uint32_t shift[64] = {
+      7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22,
+      5, 9,  14, 20, 5, 9,  14, 20, 5, 9,  14, 20, 5, 9,  14, 20,
+      4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23,
+      6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21,
+  };
+
+  uint32_t a0 = 0x67452301, b0 = 0xefcdab89, c0 = 0x98badcfe, d0 = 0x10325476;
+
+  for (size_t i = 0; i < data64.size(); i += 64) {
+    uint32_t a = a0, b = b0, c = c0, d = d0, f, g;
+    for (size_t j = 0; j < 64; j++) {
+      if (j < 16) {
+        f = (b & c) | ((~b) & d);
+        g = j;
+      } else if (j < 32) {
+        f = (d & b) | ((~d) & c);
+        g = (5 * j + 1) & 0xf;
+      } else if (j < 48) {
+        f = b ^ c ^ d;
+        g = (3 * j + 5) & 0xf;
+      } else {
+        f = c ^ (b | (~d));
+        g = (7 * j) & 0xf;
+      }
+      uint32_t dg0 = data64[i + g * 4 + 0], dg1 = data64[i + g * 4 + 1],
+               dg2 = data64[i + g * 4 + 2], dg3 = data64[i + g * 4 + 3];
+      uint32_t u = dg0 | (dg1 << 8u) | (dg2 << 16u) | (dg3 << 24u);
+      f += a + sineparts[j] + u;
+      a = d;
+      d = c;
+      c = b;
+      b += (f << shift[j]) | (f >> (32u - shift[j]));
+    }
+    a0 += a;
+    b0 += b;
+    c0 += c;
+    d0 += d;
+  }
+  sum[0] = a0;
+  sum[1] = a0 >> 8u;
+  sum[2] = a0 >> 16u;
+  sum[3] = a0 >> 24u;
+  sum[4] = b0;
+  sum[5] = b0 >> 8u;
+  sum[6] = b0 >> 16u;
+  sum[7] = b0 >> 24u;
+  sum[8] = c0;
+  sum[9] = c0 >> 8u;
+  sum[10] = c0 >> 16u;
+  sum[11] = c0 >> 24u;
+  sum[12] = d0;
+  sum[13] = d0 >> 8u;
+  sum[14] = d0 >> 16u;
+  sum[15] = d0 >> 24u;
+}
+
+Status CreateICCChadMatrix(CIExy w, float result[9]) {
+  float m[9];
+  if (w.y == 0) {  // WhitePoint can not be pitch-black.
+    return JXL_FAILURE("Invalid WhitePoint");
+  }
+  JXL_RETURN_IF_ERROR(AdaptToXYZD50(w.x, w.y, m));
+  memcpy(result, m, sizeof(float) * 9);
+  return true;
+}
+
+// Creates RGB to XYZ matrix given RGB primaries and whitepoint in xy.
+Status CreateICCRGBMatrix(CIExy r, CIExy g, CIExy b, CIExy w, float result[9]) {
+  float m[9];
+  JXL_RETURN_IF_ERROR(
+      PrimariesToXYZD50(r.x, r.y, g.x, g.y, b.x, b.y, w.x, w.y, m));
+  memcpy(result, m, sizeof(float) * 9);
+  return true;
+}
+
+void WriteICCUint32(uint32_t value, size_t pos, PaddedBytes* JXL_RESTRICT icc) {
+  if (icc->size() < pos + 4) icc->resize(pos + 4);
+  (*icc)[pos + 0] = (value >> 24u) & 255;
+  (*icc)[pos + 1] = (value >> 16u) & 255;
+  (*icc)[pos + 2] = (value >> 8u) & 255;
+  (*icc)[pos + 3] = value & 255;
+}
+
+void WriteICCUint16(uint16_t value, size_t pos, PaddedBytes* JXL_RESTRICT icc) {
+  if (icc->size() < pos + 2) icc->resize(pos + 2);
+  (*icc)[pos + 0] = (value >> 8u) & 255;
+  (*icc)[pos + 1] = value & 255;
+}
+
+// Writes a 4-character tag
+void WriteICCTag(const char* value, size_t pos, PaddedBytes* JXL_RESTRICT icc) {
+  if (icc->size() < pos + 4) icc->resize(pos + 4);
+  memcpy(icc->data() + pos, value, 4);
+}
+
+Status WriteICCS15Fixed16(float value, size_t pos,
+                          PaddedBytes* JXL_RESTRICT icc) {
+  // "nextafterf" for 32768.0f towards zero are:
+  // 32767.998046875, 32767.99609375, 32767.994140625
+  // Even the first value works well,...
+  bool ok = (-32767.995f <= value) && (value <= 32767.995f);
+  if (!ok) return JXL_FAILURE("ICC value is out of range / NaN");
+  int32_t i = value * 65536.0f + 0.5f;
+  // Use two's complement
+  uint32_t u = static_cast<uint32_t>(i);
+  WriteICCUint32(u, pos, icc);
+  return true;
+}
+
+Status CreateICCHeader(const ColorEncoding& c,
+                       PaddedBytes* JXL_RESTRICT header) {
+  // TODO(lode): choose color management engine name, e.g. "skia" if
+  // integrated in skia.
+  static const char* kCmm = "jxl ";
+
+  header->resize(128, 0);
+
+  WriteICCUint32(0, 0, header);  // size, correct value filled in at end
+  WriteICCTag(kCmm, 4, header);
+  WriteICCUint32(0x04300000u, 8, header);
+  WriteICCTag("mntr", 12, header);
+  WriteICCTag(c.IsGray() ? "GRAY" : "RGB ", 16, header);
+  WriteICCTag("XYZ ", 20, header);
+
+  // dateTimeNumber: six uint16 fields (12 bytes), written below.
+  // TODO(lode): encode actual date and time, this is a placeholder
+  uint32_t year = 2019, month = 12, day = 1;
+  uint32_t hour = 0, minute = 0, second = 0;
+  WriteICCUint16(year, 24, header);
+  WriteICCUint16(month, 26, header);
+  WriteICCUint16(day, 28, header);
+  WriteICCUint16(hour, 30, header);
+  WriteICCUint16(minute, 32, header);
+  WriteICCUint16(second, 34, header);
+
+  WriteICCTag("acsp", 36, header);
+  WriteICCTag("APPL", 40, header);
+  WriteICCUint32(0, 44, header);  // flags
+  WriteICCUint32(0, 48, header);  // device manufacturer
+  WriteICCUint32(0, 52, header);  // device model
+  WriteICCUint32(0, 56, header);  // device attributes
+  WriteICCUint32(0, 60, header);  // device attributes
+  WriteICCUint32(static_cast<uint32_t>(c.rendering_intent), 64, header);
+
+  // Mandatory D50 white point of profile connection space
+  WriteICCUint32(0x0000f6d6, 68, header);
+  WriteICCUint32(0x00010000, 72, header);
+  WriteICCUint32(0x0000d32d, 76, header);
+
+  WriteICCTag(kCmm, 80, header);
+
+  return true;
+}
+
+void AddToICCTagTable(const char* tag, size_t offset, size_t size,
+                      PaddedBytes* JXL_RESTRICT tagtable,
+                      std::vector<size_t>* offsets) {
+  WriteICCTag(tag, tagtable->size(), tagtable);
+  // writing true offset deferred to later
+  WriteICCUint32(0, tagtable->size(), tagtable);
+  offsets->push_back(offset);
+  WriteICCUint32(size, tagtable->size(), tagtable);
+}
+
+void FinalizeICCTag(PaddedBytes* JXL_RESTRICT tags, size_t* offset,
+                    size_t* size) {
+  while ((tags->size() & 3) != 0) {
+    tags->push_back(0);
+  }
+  *offset += *size;
+  *size = tags->size() - *offset;
+}
+
+// The input text must be ASCII, writing other characters to UTF-16 is not
+// implemented.
+void CreateICCMlucTag(const std::string& text, PaddedBytes* JXL_RESTRICT tags) {
+  WriteICCTag("mluc", tags->size(), tags);
+  WriteICCUint32(0, tags->size(), tags);
+  WriteICCUint32(1, tags->size(), tags);
+  WriteICCUint32(12, tags->size(), tags);
+  WriteICCTag("enUS", tags->size(), tags);
+  WriteICCUint32(text.size() * 2, tags->size(), tags);
+  WriteICCUint32(28, tags->size(), tags);
+  for (size_t i = 0; i < text.size(); i++) {
+    tags->push_back(0);  // prepend 0 for UTF-16
+    tags->push_back(text[i]);
+  }
+}
+
+Status CreateICCXYZTag(float xyz[3], PaddedBytes* JXL_RESTRICT tags) {
+  WriteICCTag("XYZ ", tags->size(), tags);
+  WriteICCUint32(0, tags->size(), tags);
+  for (size_t i = 0; i < 3; ++i) {
+    JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(xyz[i], tags->size(), tags));
+  }
+  return true;
+}
+
+Status CreateICCChadTag(float chad[9], PaddedBytes* JXL_RESTRICT tags) {
+  WriteICCTag("sf32", tags->size(), tags);
+  WriteICCUint32(0, tags->size(), tags);
+  for (size_t i = 0; i < 9; i++) {
+    JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(chad[i], tags->size(), tags));
+  }
+  return true;
+}
+
+void CreateICCCurvCurvTag(const std::vector<uint16_t>& curve,
+                          PaddedBytes* JXL_RESTRICT tags) {
+  size_t pos = tags->size();
+  tags->resize(tags->size() + 12 + curve.size() * 2, 0);
+  WriteICCTag("curv", pos, tags);
+  WriteICCUint32(0, pos + 4, tags);
+  WriteICCUint32(curve.size(), pos + 8, tags);
+  for (size_t i = 0; i < curve.size(); i++) {
+    WriteICCUint16(curve[i], pos + 12 + i * 2, tags);
+  }
+}
+
+Status CreateICCCurvParaTag(std::vector<float> params, size_t curve_type,
+                            PaddedBytes* JXL_RESTRICT tags) {
+  WriteICCTag("para", tags->size(), tags);
+  WriteICCUint32(0, tags->size(), tags);
+  WriteICCUint16(curve_type, tags->size(), tags);
+  WriteICCUint16(0, tags->size(), tags);
+  for (size_t i = 0; i < params.size(); i++) {
+    JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(params[i], tags->size(), tags));
+  }
+  return true;
+}
+}  // namespace
+
+Status MaybeCreateProfile(const ColorEncoding& c,
+                          PaddedBytes* JXL_RESTRICT icc) {
+  PaddedBytes header, tagtable, tags;
+
+  if (c.GetColorSpace() == ColorSpace::kUnknown || c.tf.IsUnknown()) {
+    return false;  // Not an error
+  }
+
+  switch (c.GetColorSpace()) {
+    case ColorSpace::kRGB:
+    case ColorSpace::kGray:
+      break;  // OK
+    case ColorSpace::kXYB:
+      return JXL_FAILURE("XYB ICC not yet implemented");
+    default:
+      return JXL_FAILURE("Invalid CS %u",
+                         static_cast<uint32_t>(c.GetColorSpace()));
+  }
+
+  JXL_RETURN_IF_ERROR(CreateICCHeader(c, &header));
+
+  std::vector<size_t> offsets;
+  // tag count, deferred to later
+  WriteICCUint32(0, tagtable.size(), &tagtable);
+
+  size_t tag_offset = 0, tag_size = 0;
+
+  CreateICCMlucTag(Description(c), &tags);
+  FinalizeICCTag(&tags, &tag_offset, &tag_size);
+  AddToICCTagTable("desc", tag_offset, tag_size, &tagtable, &offsets);
+
+  const std::string copyright =
+      "Copyright 2019 Google LLC, CC-BY-SA 3.0 Unported "
+      "license(https://creativecommons.org/licenses/by-sa/3.0/legalcode)";
+  CreateICCMlucTag(copyright, &tags);
+  FinalizeICCTag(&tags, &tag_offset, &tag_size);
+  AddToICCTagTable("cprt", tag_offset, tag_size, &tagtable, &offsets);
+
+  // TODO(eustas): isn't it the other way round: gray image has d50 WhitePoint?
+  if (c.IsGray()) {
+    float wtpt[3];
+    JXL_RETURN_IF_ERROR(CIEXYZFromWhiteCIExy(c.GetWhitePoint(), wtpt));
+    JXL_RETURN_IF_ERROR(CreateICCXYZTag(wtpt, &tags));
+  } else {
+    float d50[3] = {0.964203, 1.0, 0.824905};
+    JXL_RETURN_IF_ERROR(CreateICCXYZTag(d50, &tags));
+  }
+  FinalizeICCTag(&tags, &tag_offset, &tag_size);
+  AddToICCTagTable("wtpt", tag_offset, tag_size, &tagtable, &offsets);
+
+  if (!c.IsGray()) {
+    // Chromatic adaptation matrix
+    float chad[9];
+    JXL_RETURN_IF_ERROR(CreateICCChadMatrix(c.GetWhitePoint(), chad));
+
+    const PrimariesCIExy primaries = c.GetPrimaries();
+    float m[9];
+    JXL_RETURN_IF_ERROR(CreateICCRGBMatrix(primaries.r, primaries.g,
+                                           primaries.b, c.GetWhitePoint(), m));
+    float r[3] = {m[0], m[3], m[6]};
+    float g[3] = {m[1], m[4], m[7]};
+    float b[3] = {m[2], m[5], m[8]};
+
+    JXL_RETURN_IF_ERROR(CreateICCChadTag(chad, &tags));
+    FinalizeICCTag(&tags, &tag_offset, &tag_size);
+    AddToICCTagTable("chad", tag_offset, tag_size, &tagtable, &offsets);
+
+    JXL_RETURN_IF_ERROR(CreateICCXYZTag(r, &tags));
+    FinalizeICCTag(&tags, &tag_offset, &tag_size);
+    AddToICCTagTable("rXYZ", tag_offset, tag_size, &tagtable, &offsets);
+
+    JXL_RETURN_IF_ERROR(CreateICCXYZTag(g, &tags));
+    FinalizeICCTag(&tags, &tag_offset, &tag_size);
+    AddToICCTagTable("gXYZ", tag_offset, tag_size, &tagtable, &offsets);
+
+    JXL_RETURN_IF_ERROR(CreateICCXYZTag(b, &tags));
+    FinalizeICCTag(&tags, &tag_offset, &tag_size);
+    AddToICCTagTable("bXYZ", tag_offset, tag_size, &tagtable, &offsets);
+  }
+
+  if (c.tf.IsGamma()) {
+    float gamma = 1.0 / c.tf.GetGamma();
+    JXL_RETURN_IF_ERROR(
+        CreateICCCurvParaTag({gamma, 1.0, 0.0, 1.0, 0.0}, 3, &tags));
+  } else {
+    switch (c.tf.GetTransferFunction()) {
+      case TransferFunction::kHLG:
+        CreateICCCurvCurvTag(
+            HWY_DYNAMIC_DISPATCH(CreateTableCurve)(4096, ExtraTF::kHLG), &tags);
+        break;
+      case TransferFunction::kPQ:
+        CreateICCCurvCurvTag(
+            HWY_DYNAMIC_DISPATCH(CreateTableCurve)(4096, ExtraTF::kPQ), &tags);
+        break;
+      case TransferFunction::kSRGB:
+        JXL_RETURN_IF_ERROR(CreateICCCurvParaTag(
+            {2.4, 1.0 / 1.055, 0.055 / 1.055, 1.0 / 12.92, 0.04045}, 3, &tags));
+        break;
+      case TransferFunction::k709:
+        JXL_RETURN_IF_ERROR(CreateICCCurvParaTag(
+            {1.0 / 0.45, 1.0 / 1.099, 0.099 / 1.099, 1.0 / 4.5, 0.081}, 3,
+            &tags));
+        break;
+      case TransferFunction::kLinear:
+        JXL_RETURN_IF_ERROR(
+            CreateICCCurvParaTag({1.0, 1.0, 0.0, 1.0, 0.0}, 3, &tags));
+        break;
+      case TransferFunction::kDCI:
+        JXL_RETURN_IF_ERROR(
+            CreateICCCurvParaTag({2.6, 1.0, 0.0, 1.0, 0.0}, 3, &tags));
+        break;
+      default:
+        JXL_ABORT("Unknown TF %u",
+                  static_cast<uint32_t>(c.tf.GetTransferFunction()));
+    }
+  }
+  FinalizeICCTag(&tags, &tag_offset, &tag_size);
+  if (c.IsGray()) {
+    AddToICCTagTable("kTRC", tag_offset, tag_size, &tagtable, &offsets);
+  } else {
+    AddToICCTagTable("rTRC", tag_offset, tag_size, &tagtable, &offsets);
+    AddToICCTagTable("gTRC", tag_offset, tag_size, &tagtable, &offsets);
+    AddToICCTagTable("bTRC", tag_offset, tag_size, &tagtable, &offsets);
+  }
+
+  // Tag count
+  WriteICCUint32(offsets.size(), 0, &tagtable);
+  for (size_t i = 0; i < offsets.size(); i++) {
+    WriteICCUint32(offsets[i] + header.size() + tagtable.size(), 4 + 12 * i + 4,
+                   &tagtable);
+  }
+
+  // ICC profile size
+  WriteICCUint32(header.size() + tagtable.size() + tags.size(), 0, &header);
+
+  *icc = header;
+  icc->append(tagtable);
+  icc->append(tags);
+
+  // The MD5 checksum must be computed on the profile with profile flags,
+  // rendering intent, and region of the checksum itself, set to 0.
+  // TODO(lode): manually verify with a reliable tool that this creates correct
+  // signature (profile id) for ICC profiles.
+  PaddedBytes icc_sum = *icc;
+  memset(icc_sum.data() + 44, 0, 4);
+  memset(icc_sum.data() + 64, 0, 4);
+  uint8_t checksum[16];
+  ICCComputeMD5(icc_sum, checksum);
+
+  memcpy(icc->data() + 84, checksum, sizeof(checksum));
+
+  return true;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/color_management.h b/third_party/jpeg-xl/lib/jxl/color_management.h
new file mode 100644
index 000000000000..ddfa49cf0b5f
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/color_management.h
@@ -0,0 +1,47 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_COLOR_MANAGEMENT_H_
+#define LIB_JXL_COLOR_MANAGEMENT_H_
+
+// ICC profiles and color space conversions.
+ +#include +#include + +#include + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +enum class ExtraTF { + kNone, + kPQ, + kHLG, + kSRGB, +}; + +Status MaybeCreateProfile(const ColorEncoding& c, + PaddedBytes* JXL_RESTRICT icc); + +Status CIEXYZFromWhiteCIExy(const CIExy& xy, float XYZ[3]); + +} // namespace jxl + +#endif // LIB_JXL_COLOR_MANAGEMENT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/color_management_test.cc b/third_party/jpeg-xl/lib/jxl/color_management_test.cc new file mode 100644 index 000000000000..db65925351a1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/color_management_test.cc @@ -0,0 +1,246 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/color_management.h" + +#include +#include + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/file_io.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { + +std::ostream& operator<<(std::ostream& os, const CIExy& xy) { + return os << "{x=" << xy.x << ", y=" << xy.y << "}"; +} + +std::ostream& operator<<(std::ostream& os, const PrimariesCIExy& primaries) { + return os << "{r=" << primaries.r << ", g=" << primaries.g + << ", b=" << primaries.b << "}"; +} + +namespace { + +using ::testing::ElementsAre; +using ::testing::FloatNear; + +// Small enough to be fast. If changed, must update Generate*. +static constexpr size_t kWidth = 16; + +struct Globals { + // TODO(deymo): Make this a const. + static Globals* GetInstance() { + static Globals ret; + return &ret; + } + + private: + static constexpr size_t kNumThreads = 0; // only have a single row. 
+ + Globals() : pool(kNumThreads) { + in_gray = GenerateGray(); + in_color = GenerateColor(); + out_gray = ImageF(kWidth, 1); + out_color = ImageF(kWidth * 3, 1); + + c_native = ColorEncoding::LinearSRGB(/*is_gray=*/false); + c_gray = ColorEncoding::LinearSRGB(/*is_gray=*/true); + } + + static ImageF GenerateGray() { + ImageF gray(kWidth, 1); + float* JXL_RESTRICT row = gray.Row(0); + // Increasing left to right + for (uint32_t x = 0; x < kWidth; ++x) { + row[x] = x * 1.0f / (kWidth - 1); // [0, 1] + } + return gray; + } + + static ImageF GenerateColor() { + ImageF image(kWidth * 3, 1); + float* JXL_RESTRICT interleaved = image.Row(0); + std::fill(interleaved, interleaved + kWidth * 3, 0.0f); + + // [0, 4): neutral + for (int32_t x = 0; x < 4; ++x) { + interleaved[3 * x + 0] = x * 1.0f / 3; // [0, 1] + interleaved[3 * x + 2] = interleaved[3 * x + 1] = interleaved[3 * x + 0]; + } + + // [4, 13): pure RGB with low/medium/high saturation + for (int32_t c = 0; c < 3; ++c) { + interleaved[3 * (4 + c) + c] = 0.08f + c * 0.01f; + interleaved[3 * (7 + c) + c] = 0.75f + c * 0.01f; + interleaved[3 * (10 + c) + c] = 1.0f; + } + + // [13, 16): impure, not quite saturated RGB + interleaved[3 * 13 + 0] = 0.86f; + interleaved[3 * 13 + 2] = interleaved[3 * 13 + 1] = 0.16f; + interleaved[3 * 14 + 1] = 0.87f; + interleaved[3 * 14 + 2] = interleaved[3 * 14 + 0] = 0.16f; + interleaved[3 * 15 + 2] = 0.88f; + interleaved[3 * 15 + 1] = interleaved[3 * 15 + 0] = 0.16f; + + return image; + } + + public: + ThreadPoolInternal pool; + + // ImageF so we can use VerifyRelativeError; all are interleaved RGB. + ImageF in_gray; + ImageF in_color; + ImageF out_gray; + ImageF out_color; + ColorEncoding c_native; + ColorEncoding c_gray; +}; + +class ColorManagementTest + : public ::testing::TestWithParam { + public: + static void VerifySameFields(const ColorEncoding& c, + const ColorEncoding& c2) { + ASSERT_EQ(c.rendering_intent, c2.rendering_intent); + ASSERT_EQ(c.GetColorSpace(), c2.GetColorSpace()); + ASSERT_EQ(c.white_point, c2.white_point); + if (c.HasPrimaries()) { + ASSERT_EQ(c.primaries, c2.primaries); + } + ASSERT_TRUE(c.tf.IsSame(c2.tf)); + } + + // "Same" pixels after converting g->c_native -> c -> g->c_native. + static void VerifyPixelRoundTrip(const ColorEncoding& c) { + Globals* g = Globals::GetInstance(); + const ColorEncoding& c_native = c.IsGray() ? g->c_gray : g->c_native; + ColorSpaceTransform xform_fwd; + ColorSpaceTransform xform_rev; + ASSERT_TRUE(xform_fwd.Init(c_native, c, kDefaultIntensityTarget, kWidth, + g->pool.NumThreads())); + ASSERT_TRUE(xform_rev.Init(c, c_native, kDefaultIntensityTarget, kWidth, + g->pool.NumThreads())); + + const size_t thread = 0; + const ImageF& in = c.IsGray() ? g->in_gray : g->in_color; + ImageF* JXL_RESTRICT out = c.IsGray() ? &g->out_gray : &g->out_color; + DoColorSpaceTransform(&xform_fwd, thread, in.Row(0), + xform_fwd.BufDst(thread)); + DoColorSpaceTransform(&xform_rev, thread, xform_fwd.BufDst(thread), + out->Row(0)); + +#if JPEGXL_ENABLE_SKCMS + double max_l1 = 7E-4; + double max_rel = 4E-7; +#else + double max_l1 = 5E-5; + // Most are lower; reached 3E-7 with D60 AP0. + double max_rel = 4E-7; +#endif + if (c.IsGray()) max_rel = 2E-5; + VerifyRelativeError(in, *out, max_l1, max_rel); + } +}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(ColorManagementTestInstantiation, + ColorManagementTest, + ::testing::ValuesIn(test::AllEncodings())); + +// Exercises the ColorManagement interface for ALL ColorEncoding synthesizable +// via enums. 
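Before the parameterized test below, note the tolerance scheme VerifyPixelRoundTrip applies above: each pixel must pass either an absolute (max_l1) or a relative (max_rel) bound, with looser limits for the skcms backend and for gray. An illustrative approximation of that per-sample criterion; WithinTolerance is a hypothetical name, and VerifyRelativeError's exact contract lives in image_test_utils.h:

```cpp
#include <cmath>

// True if `got` is close enough to `want` in absolute OR relative terms.
bool WithinTolerance(float want, float got, double max_l1, double max_rel) {
  const double diff = std::abs(static_cast<double>(want) - got);
  return diff <= max_l1 ||
         diff <= max_rel * std::abs(static_cast<double>(want));
}
```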
+TEST_P(ColorManagementTest, VerifyAllProfiles) {
+  ColorEncoding c = ColorEncodingFromDescriptor(GetParam());
+  printf("%s\n", Description(c).c_str());
+
+  // Can create profile.
+  ASSERT_TRUE(c.CreateICC());
+
+  // Can set an equivalent ColorEncoding from the generated ICC profile.
+  ColorEncoding c3;
+  ASSERT_TRUE(c3.SetICC(PaddedBytes(c.ICC())));
+  VerifySameFields(c, c3);
+
+  VerifyPixelRoundTrip(c);
+}
+
+testing::Matcher<CIExy> CIExyIs(const double x, const double y) {
+  static constexpr double kMaxError = 1e-4;
+  return testing::AllOf(
+      testing::Field(&CIExy::x, testing::DoubleNear(x, kMaxError)),
+      testing::Field(&CIExy::y, testing::DoubleNear(y, kMaxError)));
+}
+
+testing::Matcher<PrimariesCIExy> PrimariesAre(
+    const testing::Matcher<CIExy>& r, const testing::Matcher<CIExy>& g,
+    const testing::Matcher<CIExy>& b) {
+  return testing::AllOf(testing::Field(&PrimariesCIExy::r, r),
+                        testing::Field(&PrimariesCIExy::g, g),
+                        testing::Field(&PrimariesCIExy::b, b));
+}
+
+TEST_F(ColorManagementTest, sRGBChromaticity) {
+  const ColorEncoding sRGB = ColorEncoding::SRGB();
+  EXPECT_THAT(sRGB.GetWhitePoint(), CIExyIs(0.3127, 0.3290));
+  EXPECT_THAT(sRGB.GetPrimaries(),
+              PrimariesAre(CIExyIs(0.64, 0.33), CIExyIs(0.30, 0.60),
+                           CIExyIs(0.15, 0.06)));
+}
+
+TEST_F(ColorManagementTest, D2700Chromaticity) {
+  PaddedBytes icc = ReadTestData("jxl/color_management/sRGB-D2700.icc");
+  ColorEncoding sRGB_D2700;
+  ASSERT_TRUE(sRGB_D2700.SetICC(std::move(icc)));
+
+  EXPECT_THAT(sRGB_D2700.GetWhitePoint(), CIExyIs(0.45986, 0.41060));
+  // The illuminant-relative chromaticities of this profile's primaries are the
+  // same as for sRGB. It is the PCS-relative chromaticities that would be
+  // different.
+  EXPECT_THAT(sRGB_D2700.GetPrimaries(),
+              PrimariesAre(CIExyIs(0.64, 0.33), CIExyIs(0.30, 0.60),
+                           CIExyIs(0.15, 0.06)));
+}
+
+TEST_F(ColorManagementTest, D2700ToSRGB) {
+  PaddedBytes icc = ReadTestData("jxl/color_management/sRGB-D2700.icc");
+  ColorEncoding sRGB_D2700;
+  ASSERT_TRUE(sRGB_D2700.SetICC(std::move(icc)));
+
+  ColorSpaceTransform transform;
+  ASSERT_TRUE(transform.Init(sRGB_D2700, ColorEncoding::SRGB(),
+                             kDefaultIntensityTarget, 1, 1));
+  const float sRGB_D2700_values[3] = {0.863, 0.737, 0.490};
+  float sRGB_values[3];
+  DoColorSpaceTransform(&transform, 0, sRGB_D2700_values, sRGB_values);
+  EXPECT_THAT(sRGB_values,
+              ElementsAre(FloatNear(0.914, 1e-3), FloatNear(0.745, 1e-3),
+                          FloatNear(0.601, 1e-3)));
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/common.h b/third_party/jpeg-xl/lib/jxl/common.h
new file mode 100644
index 000000000000..306e2c4e853c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/common.h
@@ -0,0 +1,224 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_COMMON_H_
+#define LIB_JXL_COMMON_H_
+
+// Shared constants and helper functions.
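The header that follows defines small integer helpers used throughout the codec: DivCeil/RoundUpTo for block and group arithmetic, and the PackSigned/UnpackSigned zigzag mapping used by the entropy coder. A quick standalone sanity check of how they behave (the helpers are re-declared here only so the snippet compiles on its own):

```cpp
#include <cstddef>
#include <cstdint>

constexpr size_t DivCeil(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t RoundUpTo(size_t what, size_t align) {
  return DivCeil(what, align) * align;
}
constexpr uint32_t PackSigned(int32_t v) {
  return (static_cast<uint32_t>(v) << 1) ^
         ((static_cast<uint32_t>(~v) >> 31) - 1);
}

int main() {
  static_assert(DivCeil(17, 8) == 3, "17 pixels span 3 blocks of 8");
  static_assert(RoundUpTo(17, 8) == 24, "padded width");
  // Zigzag order 0, -1, 1, -2, 2, ... maps to 0, 1, 2, 3, 4, ...
  static_assert(PackSigned(0) == 0 && PackSigned(-1) == 1 &&
                    PackSigned(1) == 2 && PackSigned(-2) == 3,
                "zigzag order");
}
```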
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <limits>  // numeric_limits
+#include <memory>  // unique_ptr
+#include <string>
+
+#include "lib/jxl/base/compiler_specific.h"
+
+#ifndef JXL_HIGH_PRECISION
+#define JXL_HIGH_PRECISION 1
+#endif
+
+namespace jxl {
+// Some enums and typedefs used by more than one header file.
+
+constexpr size_t kBitsPerByte = 8;  // more clear than CHAR_BIT
+
+constexpr inline size_t RoundUpBitsToByteMultiple(size_t bits) {
+  return (bits + 7) & ~size_t(7);
+}
+
+constexpr inline size_t RoundUpToBlockDim(size_t dim) {
+  return (dim + 7) & ~size_t(7);
+}
+
+static inline bool JXL_MAYBE_UNUSED SafeAdd(const uint64_t a, const uint64_t b,
+                                            uint64_t& sum) {
+  sum = a + b;
+  return sum >= a;  // no need to check b - either sum >= both or < both.
+}
+
+template <typename T1, typename T2>
+constexpr inline T1 DivCeil(T1 a, T2 b) {
+  return (a + b - 1) / b;
+}
+
+// Works for any `align`; if a power of two, compiler emits ADD+AND.
+constexpr inline size_t RoundUpTo(size_t what, size_t align) {
+  return DivCeil(what, align) * align;
+}
+
+constexpr double kPi = 3.14159265358979323846264338327950288;
+
+// Reasonable default for sRGB, matches common monitors. We map white to this
+// many nits (cd/m^2) by default. Butteraugli was tuned for 250 nits, which is
+// very close.
+static constexpr float kDefaultIntensityTarget = 255;
+
+template <typename T>
+constexpr T Pi(T multiplier) {
+  return static_cast<T>(multiplier * kPi);
+}
+
+// Block is the square grid of pixels to which an "energy compaction"
+// transformation (e.g. DCT) is applied. Each block has its own AC quantizer.
+constexpr size_t kBlockDim = 8;
+
+constexpr size_t kDCTBlockSize = kBlockDim * kBlockDim;
+
+// Group is the rectangular grid of blocks that can be decoded in parallel. This
+// is different for DC.
+// TODO(jon) : signal kDcGroupDimInBlocks and kGroupDim (and make them
+// variables),
+// allowing powers of two between (say) 64 and 1024
+constexpr size_t kDcGroupDimInBlocks = 256;
+constexpr size_t kDcGroupDim = kDcGroupDimInBlocks * kBlockDim;
+// 512x512 DC = 4096x4096, enough for a 4K frame (3840x2160)
+// (setting it to 256 results in four DC groups of size 256x256, 224x256,
+// 256x14, 224x14)
+constexpr size_t kGroupDim = 256;
+static_assert(kGroupDim % kBlockDim == 0,
+              "Group dim should be divisible by block dim");
+constexpr size_t kGroupDimInBlocks = kGroupDim / kBlockDim;
+
+// Maximum number of passes in an image.
+constexpr size_t kMaxNumPasses = 11;
+
+// Maximum number of reference frames.
+constexpr size_t kMaxNumReferenceFrames = 3;
+
+// Dimensions of a frame, in pixels, and other derived dimensions.
+// Computed from FrameHeader.
+// TODO(veluca): add extra channels.
+struct FrameDimensions {
+  void Set(size_t xsize, size_t ysize, size_t group_size_shift,
+           size_t max_hshift, size_t max_vshift, bool modular_mode,
+           size_t upsampling) {
+    group_dim = (kGroupDim >> 1) << group_size_shift;
+    static_assert(
+        kGroupDim == kDcGroupDimInBlocks,
+        "DC groups (in blocks) and groups (in pixels) have different size");
+    xsize_upsampled = xsize;
+    ysize_upsampled = ysize;
+    this->xsize = DivCeil(xsize, upsampling);
+    this->ysize = DivCeil(ysize, upsampling);
+    xsize_blocks = DivCeil(this->xsize, kBlockDim << max_hshift) << max_hshift;
+    ysize_blocks = DivCeil(this->ysize, kBlockDim << max_vshift) << max_vshift;
+    xsize_padded = xsize_blocks * kBlockDim;
+    ysize_padded = ysize_blocks * kBlockDim;
+    if (modular_mode) {
+      // Modular mode doesn't have any padding.
+      xsize_padded = this->xsize;
+      ysize_padded = this->ysize;
+    }
+    xsize_upsampled_padded = xsize_padded * upsampling;
+    ysize_upsampled_padded = ysize_padded * upsampling;
+    xsize_groups = DivCeil(this->xsize, group_dim);
+    ysize_groups = DivCeil(this->ysize, group_dim);
+    xsize_dc_groups = DivCeil(xsize_blocks, group_dim);
+    ysize_dc_groups = DivCeil(ysize_blocks, group_dim);
+    num_groups = xsize_groups * ysize_groups;
+    num_dc_groups = xsize_dc_groups * ysize_dc_groups;
+  }
+
+  // Image size without any upsampling, i.e. original_size / upsampling.
+  size_t xsize;
+  size_t ysize;
+  // Original image size.
+  size_t xsize_upsampled;
+  size_t ysize_upsampled;
+  // Image size after upsampling the padded image.
+  size_t xsize_upsampled_padded;
+  size_t ysize_upsampled_padded;
+  // Image size after padding to a multiple of kBlockDim (if VarDCT mode).
+  size_t xsize_padded;
+  size_t ysize_padded;
+  // Image size in kBlockDim blocks.
+  size_t xsize_blocks;
+  size_t ysize_blocks;
+  // Image size in number of groups.
+  size_t xsize_groups;
+  size_t ysize_groups;
+  // Image size in number of DC groups.
+  size_t xsize_dc_groups;
+  size_t ysize_dc_groups;
+  // Number of AC or DC groups.
+  size_t num_groups;
+  size_t num_dc_groups;
+  // Size of a group.
+  size_t group_dim;
+};
+
+// Prior to C++14 (i.e. C++11): provide our own make_unique
+#if __cplusplus < 201402L
+template <typename T, typename... Args>
+std::unique_ptr<T> make_unique(Args&&... args) {
+  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
+}
+#else
+using std::make_unique;
+#endif
+
+template <typename T>
+JXL_INLINE T Clamp1(T val, T low, T hi) {
+  return val < low ? low : val > hi ? hi : val;
+}
+
+template <typename T>
+JXL_INLINE T ClampToRange(int64_t val) {
+  return Clamp1<int64_t>(val, std::numeric_limits<T>::min(),
+                         std::numeric_limits<T>::max());
+}
+
+template <typename T>
+JXL_INLINE T SaturatingMul(int64_t a, int64_t b) {
+  return ClampToRange<T>(a * b);
+}
+
+template <typename T>
+JXL_INLINE T SaturatingAdd(int64_t a, int64_t b) {
+  return ClampToRange<T>(a + b);
+}
+
+// Encodes non-negative (X) into (2 * X), negative (-X) into (2 * X - 1)
+constexpr uint32_t PackSigned(int32_t value) {
+  return (static_cast<uint32_t>(value) << 1) ^
+         ((static_cast<uint32_t>(~value) >> 31) - 1);
+}
+
+// Reverse to PackSigned, i.e. UnpackSigned(PackSigned(X)) == X.
+constexpr intptr_t UnpackSigned(size_t value) {
+  return static_cast<intptr_t>((value >> 1) ^ (((~value) & 1) - 1));
+}
+
+// conversion from integer to string.
+template <typename T>
+std::string ToString(T n) {
+  char data[32] = {};
+  if (T(0.1) != T(0)) {
+    // float
+    snprintf(data, sizeof(data), "%g", static_cast<double>(n));
+  } else if (T(-1) > T(0)) {
+    // unsigned
+    snprintf(data, sizeof(data), "%llu", static_cast<unsigned long long>(n));
+  } else {
+    // signed
+    snprintf(data, sizeof(data), "%lld", static_cast<long long>(n));
+  }
+  return data;
+}
+}  // namespace jxl
+
+#endif  // LIB_JXL_COMMON_H_
diff --git a/third_party/jpeg-xl/lib/jxl/compressed_dc.cc b/third_party/jpeg-xl/lib/jxl/compressed_dc.cc
new file mode 100644
index 000000000000..f1d00e56e551
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/compressed_dc.cc
@@ -0,0 +1,321 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/compressed_dc.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc" +#include +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +using D = HWY_FULL(float); +using DScalar = HWY_CAPPED(float, 1); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Vec; + +// TODO(veluca): optimize constants. +const float w1 = 0.20345139757231578f; +const float w2 = 0.0334829185968739f; +const float w0 = 1.0f - 4.0f * (w1 + w2); + +template +V MaxWorkaround(V a, V b) { +#if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG <= 800 + // Prevents "Do not know how to split the result of this operator" error + return IfThenElse(a > b, a, b); +#else + return Max(a, b); +#endif +} + +template +JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor, + const float* JXL_RESTRICT row_top, + const float* JXL_RESTRICT row, + const float* JXL_RESTRICT row_bottom, + Vec* JXL_RESTRICT mc, + Vec* JXL_RESTRICT sm, + Vec* JXL_RESTRICT gap, size_t x) { + const auto tl = LoadU(d, row_top + x - 1); + const auto tc = Load(d, row_top + x); + const auto tr = LoadU(d, row_top + x + 1); + + const auto ml = LoadU(d, row + x - 1); + *mc = Load(d, row + x); + const auto mr = LoadU(d, row + x + 1); + + const auto bl = LoadU(d, row_bottom + x - 1); + const auto bc = Load(d, row_bottom + x); + const auto br = LoadU(d, row_bottom + x + 1); + + const auto w_center = Set(d, w0); + const auto w_side = Set(d, w1); + const auto w_corner = Set(d, w2); + + const auto corner = tl + tr + bl + br; + const auto side = ml + mr + tc + bc; + *sm = corner * w_corner + side * w_side + *mc * w_center; + + const auto dc_quant = Set(d, dc_factor); + *gap = MaxWorkaround(*gap, Abs((*mc - *sm) / dc_quant)); +} + +template +JXL_INLINE void ComputePixel( + const float* JXL_RESTRICT dc_factors, + const float* JXL_RESTRICT* JXL_RESTRICT rows_top, + const float* JXL_RESTRICT* JXL_RESTRICT rows, + const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom, + float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) { + const D d; + auto mc_x = Undefined(d); + auto mc_y = Undefined(d); + auto mc_b = Undefined(d); + auto sm_x = Undefined(d); + auto sm_y = Undefined(d); + auto sm_b = Undefined(d); + auto gap = Set(d, 0.5f); + ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0], + &mc_x, &sm_x, &gap, x); + ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1], + &mc_y, &sm_y, &gap, x); + ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2], + &mc_b, &sm_b, &gap, x); + auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f)); + factor = ZeroIfNegative(factor); + + auto out = MulAdd(sm_x - mc_x, 
factor, mc_x); + Store(out, d, out_rows[0] + x); + out = MulAdd(sm_y - mc_y, factor, mc_y); + Store(out, d, out_rows[1] + x); + out = MulAdd(sm_b - mc_b, factor, mc_b); + Store(out, d, out_rows[2] + x); +} + +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool) { + const size_t xsize = dc->xsize(); + const size_t ysize = dc->ysize(); + if (ysize <= 2 || xsize <= 2) return; + + // TODO(veluca): use tile-based processing? + // TODO(veluca): decide if changes to the y channel should be propagated to + // the x and b channels through color correlation. + JXL_ASSERT(w1 + w2 < 0.25f); + + PROFILER_FUNC; + + Image3F smoothed(xsize, ysize); + // Fill in borders that the loop below will not. First and last are unused. + for (size_t c = 0; c < 3; c++) { + for (size_t y : {size_t(0), ysize - 1}) { + memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y), + xsize * sizeof(float)); + } + } + auto process_row = [&](int y, int /*thread*/) { + const float* JXL_RESTRICT rows_top[3]{ + dc->ConstPlaneRow(0, y - 1), + dc->ConstPlaneRow(1, y - 1), + dc->ConstPlaneRow(2, y - 1), + }; + const float* JXL_RESTRICT rows[3] = { + dc->ConstPlaneRow(0, y), + dc->ConstPlaneRow(1, y), + dc->ConstPlaneRow(2, y), + }; + const float* JXL_RESTRICT rows_bottom[3] = { + dc->ConstPlaneRow(0, y + 1), + dc->ConstPlaneRow(1, y + 1), + dc->ConstPlaneRow(2, y + 1), + }; + float* JXL_RESTRICT rows_out[3] = { + smoothed.PlaneRow(0, y), + smoothed.PlaneRow(1, y), + smoothed.PlaneRow(2, y), + }; + for (size_t x : {size_t(0), xsize - 1}) { + for (size_t c = 0; c < 3; c++) { + rows_out[c][x] = rows[c][x]; + } + } + + size_t x = 1; + // First pixels + const size_t N = Lanes(D()); + for (; x < std::min(N, xsize - 1); x++) { + ComputePixel(dc_factors, rows_top, rows, rows_bottom, rows_out, + x); + } + // Full vectors. + for (; x + N <= xsize - 1; x += N) { + ComputePixel(dc_factors, rows_top, rows, rows_bottom, rows_out, x); + } + // Last pixels. + for (; x < xsize - 1; x++) { + ComputePixel(dc_factors, rows_top, rows, rows_bottom, rows_out, + x); + } + }; + RunOnPool(pool, 1, ysize - 1, ThreadPool::SkipInit(), process_row, + "DCSmoothingRow"); + dc->Swap(smoothed); +} + +// DC dequantization. 
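Before the dequantization code below, a scalar recap of the smoothing above: each pixel is compared against a 3x3 weighted average, the largest per-channel deviation (relative to the DC quantization step) becomes `gap`, and the output blends toward the average by `factor = max(0, 3 - 4 * gap)`. Since `gap` starts at 0.5, `factor` never exceeds 1. A single-channel sketch under the assumption that the 3x3 neighborhood `n[dy][dx]` is already gathered; `SmoothDC` is an illustrative name:

```cpp
#include <algorithm>
#include <cmath>

float SmoothDC(const float n[3][3], float dc_quant,
               float w0, float w1, float w2) {
  const float side = n[0][1] + n[1][0] + n[1][2] + n[2][1];
  const float corner = n[0][0] + n[0][2] + n[2][0] + n[2][2];
  const float sm = corner * w2 + side * w1 + n[1][1] * w0;
  // The vector code takes the max of this over all three channels,
  // starting from 0.5; one channel is shown here.
  const float gap = std::max(0.5f, std::abs((n[1][1] - sm) / dc_quant));
  const float factor = std::max(0.0f, 3.0f - 4.0f * gap);  // ZeroIfNegative
  return (sm - n[1][1]) * factor + n[1][1];  // blend toward the smooth value
}
```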
+void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx) { + const HWY_FULL(float) df; + const Rebind di; // assumes pixel_type <= float + if (chroma_subsampling.Is444()) { + const auto fac_x = Set(df, dc_factors[0] * mul); + const auto fac_y = Set(df, dc_factors[1] * mul); + const auto fac_b = Set(df, dc_factors[2] * mul); + const auto cfl_fac_x = Set(df, cfl_factors[0]); + const auto cfl_fac_b = Set(df, cfl_factors[2]); + for (size_t y = 0; y < r.ysize(); y++) { + float* dec_row_x = r.PlaneRow(dc, 0, y); + float* dec_row_y = r.PlaneRow(dc, 1, y); + float* dec_row_b = r.PlaneRow(dc, 2, y); + const int32_t* quant_row_x = in.channel[1].plane.Row(y); + const int32_t* quant_row_y = in.channel[0].plane.Row(y); + const int32_t* quant_row_b = in.channel[2].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x += Lanes(di)) { + const auto in_q_x = Load(di, quant_row_x + x); + const auto in_q_y = Load(di, quant_row_y + x); + const auto in_q_b = Load(di, quant_row_b + x); + const auto in_x = ConvertTo(df, in_q_x) * fac_x; + const auto in_y = ConvertTo(df, in_q_y) * fac_y; + const auto in_b = ConvertTo(df, in_q_b) * fac_b; + Store(in_y, df, dec_row_y + x); + Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x); + Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x); + } + } + } else { + for (size_t c : {1, 0, 2}) { + Rect rect(r.x0() >> chroma_subsampling.HShift(c), + r.y0() >> chroma_subsampling.VShift(c), + r.xsize() >> chroma_subsampling.HShift(c), + r.ysize() >> chroma_subsampling.VShift(c)); + const auto fac = Set(df, dc_factors[c] * mul); + const Channel& ch = in.channel[c < 2 ? c ^ 1 : c]; + for (size_t y = 0; y < rect.ysize(); y++) { + const int32_t* quant_row = ch.plane.Row(y); + float* row = rect.PlaneRow(dc, c, y); + for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) { + const auto in_q = Load(di, quant_row + x); + const auto in = ConvertTo(df, in_q) * fac; + Store(in, df, row + x); + } + } + } + } + if (bctx.num_dc_ctxs <= 1) { + for (size_t y = 0; y < r.ysize(); y++) { + uint8_t* qdc_row = r.Row(quant_dc, y); + memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize()); + } + } else { + for (size_t y = 0; y < r.ysize(); y++) { + uint8_t* qdc_row_val = r.Row(quant_dc, y); + const int32_t* quant_row_x = + in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0)); + const int32_t* quant_row_y = + in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1)); + const int32_t* quant_row_b = + in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2)); + for (size_t x = 0; x < r.xsize(); x++) { + int bucket_x = 0, bucket_y = 0, bucket_b = 0; + for (int t : bctx.dc_thresholds[0]) { + if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++; + } + for (int t : bctx.dc_thresholds[1]) { + if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++; + } + for (int t : bctx.dc_thresholds[2]) { + if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++; + } + int bucket = bucket_x; + bucket *= bctx.dc_thresholds[2].size() + 1; + bucket += bucket_b; + bucket *= bctx.dc_thresholds[1].size() + 1; + bucket += bucket_y; + qdc_row_val[x] = bucket; + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(DequantDC); +HWY_EXPORT(AdaptiveDCSmoothing); +void AdaptiveDCSmoothing(const float* 
dc_factors, Image3F* dc, + ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(dc_factors, dc, pool); +} + +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx) { + return HWY_DYNAMIC_DISPATCH(DequantDC)(r, dc, quant_dc, in, dc_factors, mul, + cfl_factors, chroma_subsampling, bctx); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/compressed_dc.h b/third_party/jpeg-xl/lib/jxl/compressed_dc.h new file mode 100644 index 000000000000..e3d3ad40df04 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/compressed_dc.h @@ -0,0 +1,43 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_COMPRESSED_DC_H_ +#define LIB_JXL_COMPRESSED_DC_H_ + +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/modular_image.h" + +// DC handling functions: encoding and decoding of DC to and from bitstream, and +// related function to initialize the per-group decoder cache. + +namespace jxl { + +// Smooth DC in already-smooth areas, to counteract banding. +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool); + +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx); + +} // namespace jxl + +#endif // LIB_JXL_COMPRESSED_DC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/compressed_image_test.cc b/third_party/jpeg-xl/lib/jxl/compressed_image_test.cc new file mode 100644 index 000000000000..eb4cf250a730 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/compressed_image_test.cc @@ -0,0 +1,111 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
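The wrappers just above show the highway runtime-dispatch pattern that compressed_dc.cc (and the other SIMD translation units in this patch) follow: the file is re-included once per target via foreach_target.h, each pass compiles the body inside a target-specific HWY_NAMESPACE, and the HWY_ONCE section exports a table and a plain-C++ entry point. A minimal standalone sketch; the file name `example.cc` and function `MulBy2` are illustrative:

```cpp
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "example.cc"  // this file, re-included per target
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

HWY_BEFORE_NAMESPACE();
namespace example {
namespace HWY_NAMESPACE {

// Doubles n floats in place; assumes n is a multiple of the vector width.
void MulBy2(float* x, size_t n) {
  const HWY_FULL(float) d;
  for (size_t i = 0; i < n; i += Lanes(d)) {
    Store(Load(d, x + i) * Set(d, 2.0f), d, x + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace example
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace example {
HWY_EXPORT(MulBy2);  // table of per-target implementations
void MulBy2(float* x, size_t n) {
  return HWY_DYNAMIC_DISPATCH(MulBy2)(x, n);  // picks the best at runtime
}
}  // namespace example
#endif  // HWY_ONCE
```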
+ +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/gaborish.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { + +// Verifies ReconOpsinImage reconstructs with low butteraugli distance. +void RunRGBRoundTrip(float distance, bool fast) { + ThreadPoolInternal pool(4); + + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + JXL_CHECK(SetFromBytes(Span(orig), &io, &pool)); + // This test can only handle a single group. + io.ShrinkTo(std::min(io.xsize(), kGroupDim), std::min(io.ysize(), kGroupDim)); + + Image3F opsin(io.xsize(), io.ysize()); + (void)ToXYB(io.Main(), &pool, &opsin); + opsin = PadImageToMultiple(opsin, kBlockDim); + GaborishInverse(&opsin, 1.0f, &pool); + + CompressParams cparams; + cparams.butteraugli_distance = distance; + if (fast) { + cparams.speed_tier = SpeedTier::kWombat; + } + + JXL_CHECK(io.metadata.size.Set(opsin.xsize(), opsin.ysize())); + FrameHeader frame_header(&io.metadata); + frame_header.color_transform = ColorTransform::kXYB; + frame_header.loop_filter.epf_iters = 0; + + // Use custom weights for Gaborish. 
+ frame_header.loop_filter.gab_custom = true; + frame_header.loop_filter.gab_x_weight1 = 0.11501538179658321f; + frame_header.loop_filter.gab_x_weight2 = 0.089979079587015454f; + frame_header.loop_filter.gab_y_weight1 = 0.11501538179658321f; + frame_header.loop_filter.gab_y_weight2 = 0.089979079587015454f; + frame_header.loop_filter.gab_b_weight1 = 0.11501538179658321f; + frame_header.loop_filter.gab_b_weight2 = 0.089979079587015454f; + + PassesEncoderState enc_state; + JXL_CHECK(InitializePassesSharedState(frame_header, &enc_state.shared)); + + enc_state.shared.quantizer.SetQuant(4.0f, 4.0f, + &enc_state.shared.raw_quant_field); + enc_state.shared.ac_strategy.FillDCT8(); + enc_state.cparams = cparams; + ZeroFillImage(&enc_state.shared.epf_sharpness); + CodecInOut io1; + io1.Main() = RoundtripImage(opsin, &enc_state, &pool); + io1.metadata.m.color_encoding = io1.Main().c_current(); + + EXPECT_LE(ButteraugliDistance(io, io1, cparams.ba_params, + /*distmap=*/nullptr, &pool), + 1.2); +} + +TEST(CompressedImageTest, RGBRoundTrip_1) { RunRGBRoundTrip(1.0, false); } + +TEST(CompressedImageTest, RGBRoundTrip_1_fast) { RunRGBRoundTrip(1.0, true); } + +TEST(CompressedImageTest, RGBRoundTrip_2) { RunRGBRoundTrip(2.0, false); } + +TEST(CompressedImageTest, RGBRoundTrip_2_fast) { RunRGBRoundTrip(2.0, true); } + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/convolve-inl.h b/third_party/jpeg-xl/lib/jxl/convolve-inl.h new file mode 100644 index 000000000000..8977156ad4b8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve-inl.h @@ -0,0 +1,128 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(LIB_JXL_CONVOLVE_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_CONVOLVE_INL_H_ +#undef LIB_JXL_CONVOLVE_INL_H_ +#else +#define LIB_JXL_CONVOLVE_INL_H_ +#endif + +#include + +#include "lib/jxl/base/status.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Broadcast; +#if HWY_TARGET != HWY_SCALAR +using hwy::HWY_NAMESPACE::CombineShiftRightBytes; +#endif +using hwy::HWY_NAMESPACE::Vec; + +// Synthesizes left/right neighbors from a vector of center pixels. +class Neighbors { + public: + using D = HWY_CAPPED(float, 16); + using V = Vec; + + // Returns l[i] == c[Mirror(i - 1)]. 
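In scalar terms, the lane permutes implementing this below compute the following (a model, not the library routine; lane 0 reuses itself as its own mirrored left neighbor):

```cpp
#include <cstddef>

// l[i] = c[mirror(i - 1)]: shift right by one lane, duplicating lane 0.
void FirstL1Model(const float* c, float* l, size_t num_lanes) {
  l[0] = c[0];
  for (size_t i = 1; i < num_lanes; ++i) l[i] = c[i - 1];
}
```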
+ HWY_INLINE HWY_MAYBE_UNUSED static V FirstL1(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {0, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // ONML'KJII +#elif HWY_TARGET == HWY_SCALAR + return c; // Same (the first mirrored value is the last valid one) +#else // 128 bit + // c = LKJI +#if HWY_ARCH_X86 + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(2, 1, 0, 0))}; // KJII +#else + const D d; + // TODO(deymo): Figure out if this can be optimized using a single vsri + // instruction to convert LKJI to KJII. + HWY_ALIGN constexpr int lanes[4] = {0, 0, 1, 2}; // KJII + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } + + // Returns l[i] == c[Mirror(i - 2)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL2(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {1, 0, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // NMLK'JIIJ +#elif HWY_TARGET == HWY_SCALAR + const D d; + JXL_ASSERT(false); // unsupported, avoid calling this. + return Zero(d); +#else // 128 bit + // c = LKJI +#if HWY_ARCH_X86 + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(1, 0, 0, 1))}; // JIIJ +#else + const D d; + HWY_ALIGN constexpr int lanes[4] = {1, 0, 0, 1}; // JIIJ + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } + + // Returns l[i] == c[Mirror(i - 3)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL3(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {2, 1, 0, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // MLKJ'IIJK +#elif HWY_TARGET == HWY_SCALAR + const D d; + JXL_ASSERT(false); // unsupported, avoid calling this. + return Zero(d); +#else // 128 bit + // c = LKJI +#if HWY_ARCH_X86 + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(0, 0, 1, 2))}; // IIJK +#else + const D d; + HWY_ALIGN constexpr int lanes[4] = {2, 1, 0, 0}; // IIJK + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_CONVOLVE_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/convolve.cc b/third_party/jpeg-xl/lib/jxl/convolve.cc new file mode 100644 index 000000000000..a2c9051eeeb7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve.cc @@ -0,0 +1,1341 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
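The convolution code below leans on mirrored boundary handling (WrapMirror / Mirror). A reference scalar version consistent with how it is used here, where the edge sample is repeated once at the boundary (so index -1 maps back to 0), matching the FirstL1 convention above; `MirrorIndex` is a hypothetical name:

```cpp
#include <cstdint>

int64_t MirrorIndex(int64_t x, int64_t size) {  // assumes size >= 1
  while (x < 0 || x >= size) {
    if (x < 0) {
      x = -x - 1;  // -1 -> 0, -2 -> 1, ...
    } else {
      x = 2 * size - 1 - x;  // size -> size - 1, ...
    }
  }
  return x;
}
```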
+ +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve.cc" +#include +#include + +#include "lib/jxl/common.h" // RoundUpTo +#include "lib/jxl/convolve-inl.h" +#include "lib/jxl/image_ops.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Vec; + +// Weighted sum of 1x5 pixels around ix, iy with [wx2 wx1 wx0 wx1 wx2]. +template +static float WeightedSumBorder(const ImageF& in, const WrapY wrap_y, + const int64_t ix, const int64_t iy, + const size_t xsize, const size_t ysize, + const float wx0, const float wx1, + const float wx2) { + const WrapMirror wrap_x; + const float* JXL_RESTRICT row = in.ConstRow(wrap_y(iy, ysize)); + const float in_m2 = row[wrap_x(ix - 2, xsize)]; + const float in_p2 = row[wrap_x(ix + 2, xsize)]; + const float in_m1 = row[wrap_x(ix - 1, xsize)]; + const float in_p1 = row[wrap_x(ix + 1, xsize)]; + const float in_00 = row[ix]; + const float sum_2 = wx2 * (in_m2 + in_p2); + const float sum_1 = wx1 * (in_m1 + in_p1); + const float sum_0 = wx0 * in_00; + return sum_2 + sum_1 + sum_0; +} + +template +static V WeightedSum(const ImageF& in, const WrapY wrap_y, const size_t ix, + const int64_t iy, const size_t ysize, const V wx0, + const V wx1, const V wx2) { + const HWY_FULL(float) d; + const float* JXL_RESTRICT center = in.ConstRow(wrap_y(iy, ysize)) + ix; + const auto in_m2 = LoadU(d, center - 2); + const auto in_p2 = LoadU(d, center + 2); + const auto in_m1 = LoadU(d, center - 1); + const auto in_p1 = LoadU(d, center + 1); + const auto in_00 = Load(d, center); + const auto sum_2 = wx2 * (in_m2 + in_p2); + const auto sum_1 = wx1 * (in_m1 + in_p1); + const auto sum_0 = wx0 * in_00; + return sum_2 + sum_1 + sum_0; +} + +// Produces result for one pixel +template +float Symmetric5Border(const ImageF& in, const Rect& rect, const int64_t ix, + const int64_t iy, const WeightsSymmetric5& weights) { + const float w0 = weights.c[0]; + const float w1 = weights.r[0]; + const float w2 = weights.R[0]; + const float w4 = weights.d[0]; + const float w5 = weights.L[0]; + const float w8 = weights.D[0]; + + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + const WrapY wrap_y; + // Unrolled loop over all 5 rows of the kernel. + float sum0 = WeightedSumBorder(in, wrap_y, ix, iy, xsize, ysize, w0, w1, w2); + + sum0 += WeightedSumBorder(in, wrap_y, ix, iy - 2, xsize, ysize, w2, w5, w8); + float sum1 = + WeightedSumBorder(in, wrap_y, ix, iy + 2, xsize, ysize, w2, w5, w8); + + sum0 += WeightedSumBorder(in, wrap_y, ix, iy - 1, xsize, ysize, w1, w4, w5); + sum1 += WeightedSumBorder(in, wrap_y, ix, iy + 1, xsize, ysize, w1, w4, w5); + + return sum0 + sum1; +} + +// Produces result for one vector's worth of pixels +template +static void Symmetric5Interior(const ImageF& in, const Rect& rect, + const int64_t ix, const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + const HWY_FULL(float) d; + + const auto w0 = LoadDup128(d, weights.c); + const auto w1 = LoadDup128(d, weights.r); + const auto w2 = LoadDup128(d, weights.R); + const auto w4 = LoadDup128(d, weights.d); + const auto w5 = LoadDup128(d, weights.L); + const auto w8 = LoadDup128(d, weights.D); + + const size_t ysize = rect.ysize(); + const WrapY wrap_y; + // Unrolled loop over all 5 rows of the kernel. 
+ auto sum0 = WeightedSum(in, wrap_y, ix, iy, ysize, w0, w1, w2); + + sum0 += WeightedSum(in, wrap_y, ix, iy - 2, ysize, w2, w5, w8); + auto sum1 = WeightedSum(in, wrap_y, ix, iy + 2, ysize, w2, w5, w8); + + sum0 += WeightedSum(in, wrap_y, ix, iy - 1, ysize, w1, w4, w5); + sum1 += WeightedSum(in, wrap_y, ix, iy + 1, ysize, w1, w4, w5); + + Store(sum0 + sum1, d, row_out + ix); +} + +template +static void Symmetric5Row(const ImageF& in, const Rect& rect, const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + const int64_t kRadius = 2; + const size_t xsize = rect.xsize(); + + size_t ix = 0; + const HWY_FULL(float) d; + const size_t N = Lanes(d); + const size_t aligned_x = RoundUpTo(kRadius, N); + for (; ix < std::min(aligned_x, xsize); ++ix) { + row_out[ix] = Symmetric5Border(in, rect, ix, iy, weights); + } + for (; ix + N + kRadius <= xsize; ix += N) { + Symmetric5Interior(in, rect, ix, iy, weights, row_out); + } + for (; ix < xsize; ++ix) { + row_out[ix] = Symmetric5Border(in, rect, ix, iy, weights); + } +} + +static JXL_NOINLINE void Symmetric5BorderRow(const ImageF& in, const Rect& rect, + const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + return Symmetric5Row(in, rect, iy, weights, row_out); +} + +#if HWY_TARGET != HWY_SCALAR + +// Returns indices for SetTableIndices such that TableLookupLanes on the +// rightmost unaligned vector (rightmost sample in its most-significant lane) +// returns the mirrored values, with the mirror outside the last valid sample. +static inline const int32_t* MirrorLanes(const size_t mod) { + const HWY_CAPPED(float, 16) d; + constexpr size_t kN = MaxLanes(d); + + // For mod = `image width mod 16` 0..15: + // last full vec mirrored (mem order) loadedVec mirrorVec idxVec + // 0123456789abcdef| fedcba9876543210 fed..210 012..def 012..def + // 0123456789abcdef|0 0fedcba98765432 0fe..321 234..f00 123..eff + // 0123456789abcdef|01 10fedcba987654 10f..432 456..110 234..ffe + // 0123456789abcdef|012 210fedcba9876 210..543 67..2210 34..ffed + // 0123456789abcdef|0123 3210fedcba98 321..654 8..33210 4..ffedc + // 0123456789abcdef|01234 43210fedcba + // 0123456789abcdef|012345 543210fedc + // 0123456789abcdef|0123456 6543210fe + // 0123456789abcdef|01234567 76543210 + // 0123456789abcdef|012345678 8765432 + // 0123456789abcdef|0123456789 987654 + // 0123456789abcdef|0123456789A A9876 + // 0123456789abcdef|0123456789AB BA98 + // 0123456789abcdef|0123456789ABC CBA + // 0123456789abcdef|0123456789ABCD DC + // 0123456789abcdef|0123456789ABCDE E EDC..10f EED..210 ffe..321 +#if HWY_CAP_GE512 + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, // + 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; +#elif HWY_CAP_GE256 + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { + 1, 2, 3, 4, 5, 6, 7, 7, // + 6, 5, 4, 3, 2, 1, 0}; +#else // 128-bit + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {1, 2, 3, 3, // + 2, 1, 0}; +#endif + return idx_lanes + kN - 1 - mod; +} + +#endif // HWY_TARGET != HWY_SCALAR + +namespace strategy { + +struct StrategyBase { + using D = HWY_CAPPED(float, 16); + using V = Vec; +}; + +// 3x3 convolution by symmetric kernel with a single scan through the input. +class Symmetric3 : public StrategyBase { + public: + static constexpr int64_t kRadius = 1; + + // Only accesses pixels in [0, xsize). 
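As a scalar reference for the row convolution that follows: the 3x3 symmetric kernel uses one weight for the center, one for the four edge-adjacent neighbors, and one for the four corners (a sketch of interior pixels only; borders use the mirroring described earlier, and `Symmetric3Pixel` is an illustrative name):

```cpp
#include <cstdint>

float Symmetric3Pixel(const float* top, const float* mid, const float* bot,
                      int64_t x, float w0, float w1, float w2) {
  const float sum1 = top[x] + bot[x] + mid[x - 1] + mid[x + 1];  // edges
  const float sum2 = top[x - 1] + top[x + 1] + bot[x - 1] + bot[x + 1];
  return w0 * mid[x] + w1 * sum1 + w2 * sum2;  // corners weighted by w2
}
```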
+ template + static JXL_INLINE void ConvolveRow(const float* const JXL_RESTRICT row_m, + const size_t xsize, const int64_t stride, + const WrapRow& wrap_row, + const WeightsSymmetric3& weights, + float* const JXL_RESTRICT row_out) { + const D d; + // t, m, b = top, middle, bottom row; + const float* const JXL_RESTRICT row_t = wrap_row(row_m - stride, stride); + const float* const JXL_RESTRICT row_b = wrap_row(row_m + stride, stride); + + // Must load in advance - compiler doesn't understand LoadDup128 and + // schedules them too late. + const V w0 = LoadDup128(d, weights.c); + const V w1 = LoadDup128(d, weights.r); + const V w2 = LoadDup128(d, weights.d); + + // l, c, r = left, center, right. Leftmost vector: need FirstL1. + { + const V tc = LoadU(d, row_t + 0); + const V mc = LoadU(d, row_m + 0); + const V bc = LoadU(d, row_b + 0); + const V tl = Neighbors::FirstL1(tc); + const V tr = LoadU(d, row_t + 0 + 1); + const V ml = Neighbors::FirstL1(mc); + const V mr = LoadU(d, row_m + 0 + 1); + const V bl = Neighbors::FirstL1(bc); + const V br = LoadU(d, row_b + 0 + 1); + const V conv = + WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + Store(conv, d, row_out + 0); + } + + // Loop as long as we can load enough new values: + const size_t N = Lanes(d); + size_t x = N; + for (; x + N + kRadius <= xsize; x += N) { + const auto conv = ConvolveValid(row_t, row_m, row_b, x, w0, w1, w2); + Store(conv, d, row_out + x); + } + + // For final (partial) vector: + const V tc = LoadU(d, row_t + x); + const V mc = LoadU(d, row_m + x); + const V bc = LoadU(d, row_b + x); + + V tr, mr, br; +#if HWY_TARGET == HWY_SCALAR + tr = tc; // Single-lane => mirrored right neighbor = center value. + mr = mc; + br = bc; +#else + if (kSizeModN == 0) { + // The above loop didn't handle the last vector because it needs an + // additional right neighbor (generated via mirroring). + auto mirror = SetTableIndices(d, MirrorLanes(N - 1)); + tr = TableLookupLanes(tc, mirror); + mr = TableLookupLanes(mc, mirror); + br = TableLookupLanes(bc, mirror); + } else { + auto mirror = SetTableIndices(d, MirrorLanes((xsize % N) - 1)); + // Loads last valid value into uppermost lane and mirrors. + tr = TableLookupLanes(LoadU(d, row_t + xsize - N), mirror); + mr = TableLookupLanes(LoadU(d, row_m + xsize - N), mirror); + br = TableLookupLanes(LoadU(d, row_b + xsize - N), mirror); + } +#endif + + const V tl = LoadU(d, row_t + x - 1); + const V ml = LoadU(d, row_m + x - 1); + const V bl = LoadU(d, row_b + x - 1); + const V conv = WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + Store(conv, d, row_out + x); + } + + private: + // Returns sum{x_i * w_i}. + template + static JXL_INLINE V WeightedSum(const V tl, const V tc, const V tr, + const V ml, const V mc, const V mr, + const V bl, const V bc, const V br, + const V w0, const V w1, const V w2) { + const V sum_tb = tc + bc; + + // Faster than 5 mul + 4 FMA. 
+ const V mul0 = mc * w0; + const V sum_lr = ml + mr; + + const V x1 = sum_tb + sum_lr; + const V mul1 = MulAdd(x1, w1, mul0); + + const V sum_t2 = tl + tr; + const V sum_b2 = bl + br; + const V x2 = sum_t2 + sum_b2; + const V mul2 = MulAdd(x2, w2, mul1); + return mul2; + } + + static JXL_INLINE V ConvolveValid(const float* JXL_RESTRICT row_t, + const float* JXL_RESTRICT row_m, + const float* JXL_RESTRICT row_b, + const int64_t x, const V w0, const V w1, + const V w2) { + const D d; + const V tc = LoadU(d, row_t + x); + const V mc = LoadU(d, row_m + x); + const V bc = LoadU(d, row_b + x); + const V tl = LoadU(d, row_t + x - 1); + const V tr = LoadU(d, row_t + x + 1); + const V ml = LoadU(d, row_m + x - 1); + const V mr = LoadU(d, row_m + x + 1); + const V bl = LoadU(d, row_b + x - 1); + const V br = LoadU(d, row_b + x + 1); + return WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + } +}; + +// 5x5 convolution by separable kernel with a single scan through the input. +// This is more cache-efficient than separate horizontal/vertical passes, and +// possibly faster (given enough registers) than tiling and/or transposing. +// +// Overview: imagine a 5x5 window around a central pixel. First convolve the +// rows by multiplying the pixels with the corresponding weights from +// WeightsSeparable5.horz[abs(x_offset) * 4]. Then multiply each of these +// intermediate results by the corresponding vertical weight, i.e. +// vert[abs(y_offset) * 4]. Finally, store the sum of these values as the +// convolution result at the position of the central pixel in the output. +// +// Each of these operations uses SIMD vectors. The central pixel and most +// importantly the output are aligned, so neighnoring pixels (e.g. x_offset=1) +// require unaligned loads. Because weights are supplied in identical groups of +// 4, we can use LoadDup128 to load them (slightly faster). +// +// Uses mirrored boundary handling. Until x >= kRadius, the horizontal +// convolution uses Neighbors class to shuffle vectors as if each of its lanes +// had been loaded from the mirrored offset. Similarly, the last full vector to +// write uses mirroring. In the case of scalar vectors, Neighbors is not usable +// and the value is loaded directly. Otherwise, the number of valid pixels +// modulo the vector size enables a small optimization: for smaller offsets, +// a non-mirrored load is sufficient. +class Separable5 : public StrategyBase { + public: + static constexpr int64_t kRadius = 2; + + template + static JXL_INLINE void ConvolveRow(const float* const JXL_RESTRICT row_m, + const size_t xsize, const int64_t stride, + const WrapRow& wrap_row, + const WeightsSeparable5& weights, + float* const JXL_RESTRICT row_out) { + const D d; + const int64_t neg_stride = -stride; // allows LEA addressing. + const float* const JXL_RESTRICT row_t2 = + wrap_row(row_m + 2 * neg_stride, stride); + const float* const JXL_RESTRICT row_t1 = + wrap_row(row_m + 1 * neg_stride, stride); + const float* const JXL_RESTRICT row_b1 = + wrap_row(row_m + 1 * stride, stride); + const float* const JXL_RESTRICT row_b2 = + wrap_row(row_m + 2 * stride, stride); + + const V wh0 = LoadDup128(d, weights.horz + 0 * 4); + const V wh1 = LoadDup128(d, weights.horz + 1 * 4); + const V wh2 = LoadDup128(d, weights.horz + 2 * 4); + const V wv0 = LoadDup128(d, weights.vert + 0 * 4); + const V wv1 = LoadDup128(d, weights.vert + 1 * 4); + const V wv2 = LoadDup128(d, weights.vert + 2 * 4); + + size_t x = 0; + + // More than one iteration for scalars. 
+ for (; x < kRadius; x += Lanes(d)) { + const V conv0 = HorzConvolveFirst(row_m, x, xsize, wh0, wh1, wh2) * wv0; + + const V conv1t = HorzConvolveFirst(row_t1, x, xsize, wh0, wh1, wh2); + const V conv1b = HorzConvolveFirst(row_b1, x, xsize, wh0, wh1, wh2); + const V conv1 = MulAdd(conv1t + conv1b, wv1, conv0); + + const V conv2t = HorzConvolveFirst(row_t2, x, xsize, wh0, wh1, wh2); + const V conv2b = HorzConvolveFirst(row_b2, x, xsize, wh0, wh1, wh2); + const V conv2 = MulAdd(conv2t + conv2b, wv2, conv1); + Store(conv2, d, row_out + x); + } + + // Main loop: load inputs without padding + for (; x + Lanes(d) + kRadius <= xsize; x += Lanes(d)) { + const V conv0 = HorzConvolve(row_m + x, wh0, wh1, wh2) * wv0; + + const V conv1t = HorzConvolve(row_t1 + x, wh0, wh1, wh2); + const V conv1b = HorzConvolve(row_b1 + x, wh0, wh1, wh2); + const V conv1 = MulAdd(conv1t + conv1b, wv1, conv0); + + const V conv2t = HorzConvolve(row_t2 + x, wh0, wh1, wh2); + const V conv2b = HorzConvolve(row_b2 + x, wh0, wh1, wh2); + const V conv2 = MulAdd(conv2t + conv2b, wv2, conv1); + Store(conv2, d, row_out + x); + } + + // Last full vector to write (the above loop handled mod >= kRadius) +#if HWY_TARGET == HWY_SCALAR + while (x < xsize) { +#else + if (kSizeModN < kRadius) { +#endif + const V conv0 = + HorzConvolveLast(row_m, x, xsize, wh0, wh1, wh2) * wv0; + + const V conv1t = + HorzConvolveLast(row_t1, x, xsize, wh0, wh1, wh2); + const V conv1b = + HorzConvolveLast(row_b1, x, xsize, wh0, wh1, wh2); + const V conv1 = MulAdd(conv1t + conv1b, wv1, conv0); + + const V conv2t = + HorzConvolveLast(row_t2, x, xsize, wh0, wh1, wh2); + const V conv2b = + HorzConvolveLast(row_b2, x, xsize, wh0, wh1, wh2); + const V conv2 = MulAdd(conv2t + conv2b, wv2, conv1); + Store(conv2, d, row_out + x); + x += Lanes(d); + } + + // If mod = 0, the above vector was the last. + if (kSizeModN != 0) { + for (; x < xsize; ++x) { + float mul = 0.0f; + for (int64_t dy = -kRadius; dy <= kRadius; ++dy) { + const float wy = weights.vert[std::abs(dy) * 4]; + const float* clamped_row = wrap_row(row_m + dy * stride, stride); + for (int64_t dx = -kRadius; dx <= kRadius; ++dx) { + const float wx = weights.horz[std::abs(dx) * 4]; + const int64_t clamped_x = Mirror(x + dx, xsize); + mul += clamped_row[clamped_x] * wx * wy; + } + } + row_out[x] = mul; + } + } + } + + private: + // Same as HorzConvolve for the first/last vector in a row. 
+ static JXL_INLINE V HorzConvolveFirst(const float* const JXL_RESTRICT row, + const int64_t x, const int64_t xsize, + const V wh0, const V wh1, const V wh2) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = c * wh0; + +#if HWY_TARGET == HWY_SCALAR + const V l1 = LoadU(d, row + Mirror(x - 1, xsize)); + const V l2 = LoadU(d, row + Mirror(x - 2, xsize)); +#else + (void)xsize; + const V l1 = Neighbors::FirstL1(c); + const V l2 = Neighbors::FirstL2(c); +#endif + + const V r1 = LoadU(d, row + x + 1); + const V r2 = LoadU(d, row + x + 2); + + const V mul1 = MulAdd(l1 + r1, wh1, mul0); + const V mul2 = MulAdd(l2 + r2, wh2, mul1); + return mul2; + } + + template + static JXL_INLINE V HorzConvolveLast(const float* const JXL_RESTRICT row, + const int64_t x, const int64_t xsize, + const V wh0, const V wh1, const V wh2) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = c * wh0; + + const V l1 = LoadU(d, row + x - 1); + const V l2 = LoadU(d, row + x - 2); + + V r1, r2; +#if HWY_TARGET == HWY_SCALAR + r1 = LoadU(d, row + Mirror(x + 1, xsize)); + r2 = LoadU(d, row + Mirror(x + 2, xsize)); +#else + const size_t N = Lanes(d); + if (kSizeModN == 0) { + r2 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 2))); + r1 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 1))); + } else { // == 1 + const auto last = LoadU(d, row + xsize - N); + r2 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1))); + r1 = last; + } +#endif + + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = l1 + r1; + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = l2 + r2; + const V mul2 = MulAdd(sum2, wh2, mul1); + return mul2; + } + + // Requires kRadius valid pixels before/after pos. + static JXL_INLINE V HorzConvolve(const float* const JXL_RESTRICT pos, + const V wh0, const V wh1, const V wh2) { + const D d; + const V c = LoadU(d, pos); + const V mul0 = c * wh0; + + // Loading anew is faster than combining vectors. + const V l1 = LoadU(d, pos - 1); + const V r1 = LoadU(d, pos + 1); + const V l2 = LoadU(d, pos - 2); + const V r2 = LoadU(d, pos + 2); + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = l1 + r1; + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = l2 + r2; + const V mul2 = MulAdd(sum2, wh2, mul1); + return mul2; + } +}; // namespace strategy + +// 7x7 convolution by separable kernel with a single scan through the input. +// Extended version of Separable5, see documentation there. +class Separable7 : public StrategyBase { + public: + static constexpr int64_t kRadius = 3; + + template + static JXL_INLINE void ConvolveRow(const float* const JXL_RESTRICT row_m, + const size_t xsize, const int64_t stride, + const WrapRow& wrap_row, + const WeightsSeparable7& weights, + float* const JXL_RESTRICT row_out) { + const D d; + const int64_t neg_stride = -stride; // allows LEA addressing. 
+ const float* const JXL_RESTRICT row_t3 = + wrap_row(row_m + 3 * neg_stride, stride); + const float* const JXL_RESTRICT row_t2 = + wrap_row(row_m + 2 * neg_stride, stride); + const float* const JXL_RESTRICT row_t1 = + wrap_row(row_m + 1 * neg_stride, stride); + const float* const JXL_RESTRICT row_b1 = + wrap_row(row_m + 1 * stride, stride); + const float* const JXL_RESTRICT row_b2 = + wrap_row(row_m + 2 * stride, stride); + const float* const JXL_RESTRICT row_b3 = + wrap_row(row_m + 3 * stride, stride); + + const V wh0 = LoadDup128(d, weights.horz + 0 * 4); + const V wh1 = LoadDup128(d, weights.horz + 1 * 4); + const V wh2 = LoadDup128(d, weights.horz + 2 * 4); + const V wh3 = LoadDup128(d, weights.horz + 3 * 4); + const V wv0 = LoadDup128(d, weights.vert + 0 * 4); + const V wv1 = LoadDup128(d, weights.vert + 1 * 4); + const V wv2 = LoadDup128(d, weights.vert + 2 * 4); + const V wv3 = LoadDup128(d, weights.vert + 3 * 4); + + size_t x = 0; + + // More than one iteration for scalars. + for (; x < kRadius; x += Lanes(d)) { + const V conv0 = + HorzConvolveFirst(row_m, x, xsize, wh0, wh1, wh2, wh3) * wv0; + + const V conv1t = HorzConvolveFirst(row_t1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1b = HorzConvolveFirst(row_b1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1 = MulAdd(conv1t + conv1b, wv1, conv0); + + const V conv2t = HorzConvolveFirst(row_t2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2b = HorzConvolveFirst(row_b2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2 = MulAdd(conv2t + conv2b, wv2, conv1); + + const V conv3t = HorzConvolveFirst(row_t3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3b = HorzConvolveFirst(row_b3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3 = MulAdd(conv3t + conv3b, wv3, conv2); + + Store(conv3, d, row_out + x); + } + + // Main loop: load inputs without padding + for (; x + Lanes(d) + kRadius <= xsize; x += Lanes(d)) { + const V conv0 = HorzConvolve(row_m + x, wh0, wh1, wh2, wh3) * wv0; + + const V conv1t = HorzConvolve(row_t1 + x, wh0, wh1, wh2, wh3); + const V conv1b = HorzConvolve(row_b1 + x, wh0, wh1, wh2, wh3); + const V conv1 = MulAdd(conv1t + conv1b, wv1, conv0); + + const V conv2t = HorzConvolve(row_t2 + x, wh0, wh1, wh2, wh3); + const V conv2b = HorzConvolve(row_b2 + x, wh0, wh1, wh2, wh3); + const V conv2 = MulAdd(conv2t + conv2b, wv2, conv1); + + const V conv3t = HorzConvolve(row_t3 + x, wh0, wh1, wh2, wh3); + const V conv3b = HorzConvolve(row_b3 + x, wh0, wh1, wh2, wh3); + const V conv3 = MulAdd(conv3t + conv3b, wv3, conv2); + + Store(conv3, d, row_out + x); + } + + // Last full vector to write (the above loop handled mod >= kRadius) +#if HWY_TARGET == HWY_SCALAR + while (x < xsize) { +#else + if (kSizeModN < kRadius) { +#endif + const V conv0 = + HorzConvolveLast(row_m, x, xsize, wh0, wh1, wh2, wh3) * + wv0; + + const V conv1t = + HorzConvolveLast(row_t1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1b = + HorzConvolveLast(row_b1, x, xsize, wh0, wh1, wh2, wh3); + const V conv1 = MulAdd(conv1t + conv1b, wv1, conv0); + + const V conv2t = + HorzConvolveLast(row_t2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2b = + HorzConvolveLast(row_b2, x, xsize, wh0, wh1, wh2, wh3); + const V conv2 = MulAdd(conv2t + conv2b, wv2, conv1); + + const V conv3t = + HorzConvolveLast(row_t3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3b = + HorzConvolveLast(row_b3, x, xsize, wh0, wh1, wh2, wh3); + const V conv3 = MulAdd(conv3t + conv3b, wv3, conv2); + + Store(conv3, d, row_out + x); + x += Lanes(d); + } + + // If mod = 0, the above vector was the last. 
+    if (kSizeModN != 0) {
+      for (; x < xsize; ++x) {
+        float mul = 0.0f;
+        for (int64_t dy = -kRadius; dy <= kRadius; ++dy) {
+          const float wy = weights.vert[std::abs(dy) * 4];
+          const float* clamped_row = wrap_row(row_m + dy * stride, stride);
+          for (int64_t dx = -kRadius; dx <= kRadius; ++dx) {
+            const float wx = weights.horz[std::abs(dx) * 4];
+            const int64_t clamped_x = Mirror(x + dx, xsize);
+            mul += clamped_row[clamped_x] * wx * wy;
+          }
+        }
+        row_out[x] = mul;
+      }
+    }
+  }
+
+ private:
+  // Same as HorzConvolve for the first/last vector in a row.
+  static JXL_INLINE V HorzConvolveFirst(const float* const JXL_RESTRICT row,
+                                        const int64_t x, const int64_t xsize,
+                                        const V wh0, const V wh1, const V wh2,
+                                        const V wh3) {
+    const D d;
+    const V c = LoadU(d, row + x);
+    const V mul0 = c * wh0;
+
+#if HWY_TARGET == HWY_SCALAR
+    const V l1 = LoadU(d, row + Mirror(x - 1, xsize));
+    const V l2 = LoadU(d, row + Mirror(x - 2, xsize));
+    const V l3 = LoadU(d, row + Mirror(x - 3, xsize));
+#else
+    (void)xsize;
+    const V l1 = Neighbors::FirstL1(c);
+    const V l2 = Neighbors::FirstL2(c);
+    const V l3 = Neighbors::FirstL3(c);
+#endif
+
+    const V r1 = LoadU(d, row + x + 1);
+    const V r2 = LoadU(d, row + x + 2);
+    const V r3 = LoadU(d, row + x + 3);
+
+    const V mul1 = MulAdd(l1 + r1, wh1, mul0);
+    const V mul2 = MulAdd(l2 + r2, wh2, mul1);
+    const V mul3 = MulAdd(l3 + r3, wh3, mul2);
+    return mul3;
+  }
+
+  template <size_t kSizeModN>
+  static JXL_INLINE V HorzConvolveLast(const float* const JXL_RESTRICT row,
+                                       const int64_t x, const int64_t xsize,
+                                       const V wh0, const V wh1, const V wh2,
+                                       const V wh3) {
+    const D d;
+    const V c = LoadU(d, row + x);
+    const V mul0 = c * wh0;
+
+    const V l1 = LoadU(d, row + x - 1);
+    const V l2 = LoadU(d, row + x - 2);
+    const V l3 = LoadU(d, row + x - 3);
+
+    V r1, r2, r3;
+#if HWY_TARGET == HWY_SCALAR
+    r1 = LoadU(d, row + Mirror(x + 1, xsize));
+    r2 = LoadU(d, row + Mirror(x + 2, xsize));
+    r3 = LoadU(d, row + Mirror(x + 3, xsize));
+#else
+    const size_t N = Lanes(d);
+    if (kSizeModN == 0) {
+      r3 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 3)));
+      r2 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 2)));
+      r1 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 1)));
+    } else if (kSizeModN == 1) {
+      const auto last = LoadU(d, row + xsize - N);
+      r3 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 2)));
+      r2 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1)));
+      r1 = last;
+    } else /* kSizeModN >= 2 */ {
+      const auto last = LoadU(d, row + xsize - N);
+      r3 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1)));
+      r2 = last;
+      r1 = LoadU(d, row + x + 1);
+    }
+#endif
+
+    // Sum of pixels with Manhattan distance i, multiplied by weights[i].
+    const V sum1 = l1 + r1;
+    const V mul1 = MulAdd(sum1, wh1, mul0);
+    const V sum2 = l2 + r2;
+    const V mul2 = MulAdd(sum2, wh2, mul1);
+    const V sum3 = l3 + r3;
+    const V mul3 = MulAdd(sum3, wh3, mul2);
+    return mul3;
+  }
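+
+  // (Illustrative note, not part of the original change.) kSizeModN is the
+  // compile-time value of xsize % Lanes(d); ConvolveT::Run below switches on
+  // it so each specialization knows how many of the rightmost lanes need
+  // mirrored neighbors. E.g. kSizeModN == 0 means the last vector ends
+  // exactly at the image edge, so all of r1..r3 come from mirrored lanes.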
+
+  // Returns one vector of horizontal convolution results; lane i is the
+  // result for pixel pos + i. This is the fast path for interior pixels,
+  // i.e. kRadius valid pixels before/after pos.
+  static JXL_INLINE V HorzConvolve(const float* const JXL_RESTRICT pos,
+                                   const V wh0, const V wh1, const V wh2,
+                                   const V wh3) {
+    const D d;
+    const V c = LoadU(d, pos);
+    const V mul0 = c * wh0;
+
+    // TODO(janwas): better to Combine
+    const V l1 = LoadU(d, pos - 1);
+    const V r1 = LoadU(d, pos + 1);
+    const V l2 = LoadU(d, pos - 2);
+    const V r2 = LoadU(d, pos + 2);
+    const V l3 = LoadU(d, pos - 3);
+    const V r3 = LoadU(d, pos + 3);
+    // Sum of pixels with Manhattan distance i, multiplied by weights[i].
+    const V sum1 = l1 + r1;
+    const V mul1 = MulAdd(sum1, wh1, mul0);
+    const V sum2 = l2 + r2;
+    const V mul2 = MulAdd(sum2, wh2, mul1);
+    const V sum3 = l3 + r3;
+    const V mul3 = MulAdd(sum3, wh3, mul2);
+    return mul3;
+  }
+};
+
+}  // namespace strategy
+
+// Single entry point for convolution.
+// "Strategy" (Direct*/Separable*) decides kernel size and how to evaluate it.
+template <class Strategy>
+class ConvolveT {
+  static constexpr int64_t kRadius = Strategy::kRadius;
+  using Simd = HWY_CAPPED(float, 16);
+
+ public:
+  static size_t MinWidth() {
+#if HWY_TARGET == HWY_SCALAR
+    // First/Last use mirrored loads of up to +/- kRadius.
+    return 2 * kRadius;
+#else
+    return Lanes(Simd()) + kRadius;
+#endif
+  }
+
+  // "Image" is ImageF or Image3F.
+  template <class Image, class Weights>
+  static void Run(const Image& in, const Rect& rect, const Weights& weights,
+                  ThreadPool* pool, Image* out) {
+    PROFILER_ZONE("ConvolveT::Run");
+    JXL_CHECK(SameSize(rect, *out));
+    JXL_CHECK(rect.xsize() >= MinWidth());
+
+    static_assert(int64_t(kRadius) <= 3,
+                  "Must handle [0, kRadius) and >= kRadius");
+    switch (rect.xsize() % Lanes(Simd())) {
+      case 0:
+        return RunRows<0>(in, rect, weights, pool, out);
+      case 1:
+        return RunRows<1>(in, rect, weights, pool, out);
+      case 2:
+        return RunRows<2>(in, rect, weights, pool, out);
+      default:
+        return RunRows<3>(in, rect, weights, pool, out);
+    }
+  }
+
+ private:
+  template <size_t kSizeModN, class WrapRow, class Weights>
+  static JXL_INLINE void RunRow(const float* JXL_RESTRICT in,
+                                const size_t xsize, const int64_t stride,
+                                const WrapRow& wrap_row, const Weights& weights,
+                                float* JXL_RESTRICT out) {
+    Strategy::template ConvolveRow<kSizeModN>(in, xsize, stride, wrap_row,
+                                              weights, out);
+  }
+
+  template <size_t kSizeModN, class Weights>
+  static JXL_INLINE void RunBorderRows(const ImageF& in, const Rect& rect,
+                                       const int64_t ybegin, const int64_t yend,
+                                       const Weights& weights, ImageF* out) {
+    const int64_t stride = in.PixelsPerRow();
+    const WrapRowMirror wrap_row(in, rect.ysize());
+    for (int64_t y = ybegin; y < yend; ++y) {
+      RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, wrap_row,
+                        weights, out->Row(y));
+    }
+  }
+
+  // Image3F.
+  template <size_t kSizeModN, class Weights>
+  static JXL_INLINE void RunBorderRows(const Image3F& in, const Rect& rect,
+                                       const int64_t ybegin, const int64_t yend,
+                                       const Weights& weights, Image3F* out) {
+    const int64_t stride = in.PixelsPerRow();
+    for (int64_t y = ybegin; y < yend; ++y) {
+      for (size_t c = 0; c < 3; ++c) {
+        const WrapRowMirror wrap_row(in.Plane(c), rect.ysize());
+        RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), stride,
+                          wrap_row, weights, out->PlaneRow(c, y));
+      }
+    }
+  }
+
+  template <size_t kSizeModN, class Weights>
+  static JXL_INLINE void RunInteriorRows(const ImageF& in, const Rect& rect,
+                                         const int64_t ybegin,
+                                         const int64_t yend,
+                                         const Weights& weights,
+                                         ThreadPool* pool, ImageF* out) {
+    const int64_t stride = in.PixelsPerRow();
+    RunOnPool(
+        pool, ybegin, yend, ThreadPool::SkipInit(),
+        [&](const int y, int /*thread*/) HWY_ATTR {
+          RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride,
+                            WrapRowUnchanged(), weights, out->Row(y));
+        },
+        "Convolve");
+  }
+
+  // Image3F.
+  template <size_t kSizeModN, class Weights>
+  static JXL_INLINE void RunInteriorRows(const Image3F& in, const Rect& rect,
+                                         const int64_t ybegin,
+                                         const int64_t yend,
+                                         const Weights& weights,
+                                         ThreadPool* pool, Image3F* out) {
+    const int64_t stride = in.PixelsPerRow();
+    RunOnPool(
+        pool, ybegin, yend, ThreadPool::SkipInit(),
+        [&](const int y, int /*thread*/) HWY_ATTR {
+          for (size_t c = 0; c < 3; ++c) {
+            RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(),
+                              stride, WrapRowUnchanged(), weights,
+                              out->PlaneRow(c, y));
+          }
+        },
+        "Convolve3");
+  }
+
+  template <size_t kSizeModN, class Image, class Weights>
+  static JXL_INLINE void RunRows(const Image& in, const Rect& rect,
+                                 const Weights& weights, ThreadPool* pool,
+                                 Image* out) {
+    const int64_t ysize = rect.ysize();
+    RunBorderRows<kSizeModN>(in, rect, 0, std::min(int64_t(kRadius), ysize),
+                             weights, out);
+    if (ysize > 2 * int64_t(kRadius)) {
+      RunInteriorRows<kSizeModN>(in, rect, int64_t(kRadius),
+                                 ysize - int64_t(kRadius), weights, pool, out);
+    }
+    if (ysize > int64_t(kRadius)) {
+      RunBorderRows<kSizeModN>(in, rect, ysize - int64_t(kRadius), ysize,
+                               weights, out);
+    }
+  }
+};
+
+void Symmetric3(const ImageF& in, const Rect& rect,
+                const WeightsSymmetric3& weights, ThreadPool* pool,
+                ImageF* out) {
+  using Conv = ConvolveT<strategy::Symmetric3>;
+  if (rect.xsize() >= Conv::MinWidth()) {
+    return Conv::Run(in, rect, weights, pool, out);
+  }
+
+  return SlowSymmetric3(in, rect, weights, pool, out);
+}
+
+// Symmetric5 is implemented above without ConvolveT.
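+
+// (Illustrative usage sketch, not part of the original change; all names are
+// declared in convolve.h.)
+//
+//   ImageF in(xsize, ysize), out(xsize, ysize);
+//   Symmetric3(in, Rect(in), WeightsSymmetric3Lowpass(), /*pool=*/nullptr,
+//              &out);  // falls back to SlowSymmetric3 when xsize < MinWidth()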
+
+void Separable5(const ImageF& in, const Rect& rect,
+                const WeightsSeparable5& weights, ThreadPool* pool,
+                ImageF* out) {
+  using Conv = ConvolveT<strategy::Separable5>;
+  if (rect.xsize() >= Conv::MinWidth()) {
+    return Conv::Run(in, rect, weights, pool, out);
+  }
+
+  return SlowSeparable5(in, rect, weights, pool, out);
+}
+void Separable5_3(const Image3F& in, const Rect& rect,
+                  const WeightsSeparable5& weights, ThreadPool* pool,
+                  Image3F* out) {
+  using Conv = ConvolveT<strategy::Separable5>;
+  if (rect.xsize() >= Conv::MinWidth()) {
+    return Conv::Run(in, rect, weights, pool, out);
+  }
+
+  return SlowSeparable5(in, rect, weights, pool, out);
+}
+
+void Separable7(const ImageF& in, const Rect& rect,
+                const WeightsSeparable7& weights, ThreadPool* pool,
+                ImageF* out) {
+  using Conv = ConvolveT<strategy::Separable7>;
+  if (rect.xsize() >= Conv::MinWidth()) {
+    return Conv::Run(in, rect, weights, pool, out);
+  }
+
+  return SlowSeparable7(in, rect, weights, pool, out);
+}
+void Separable7_3(const Image3F& in, const Rect& rect,
+                  const WeightsSeparable7& weights, ThreadPool* pool,
+                  Image3F* out) {
+  using Conv = ConvolveT<strategy::Separable7>;
+  if (rect.xsize() >= Conv::MinWidth()) {
+    return Conv::Run(in, rect, weights, pool, out);
+  }
+
+  return SlowSeparable7(in, rect, weights, pool, out);
+}
+
+// Semi-vectorized (interior pixels only); called directly like slow::, unlike
+// the fully vectorized strategies below.
+void Symmetric5(const ImageF& in, const Rect& rect,
+                const WeightsSymmetric5& weights, ThreadPool* pool,
+                ImageF* JXL_RESTRICT out) {
+  PROFILER_FUNC;
+
+  const size_t ysize = rect.ysize();
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const int64_t iy = task;
+
+        if (iy < 2 || iy >= static_cast<int64_t>(ysize) - 2) {
+          Symmetric5BorderRow(in, rect, iy, weights, out->Row(iy));
+        } else {
+          Symmetric5Row(in, rect, iy, weights, out->Row(iy));
+        }
+      },
+      "Symmetric5x5Convolution");
+}
+
+void Symmetric5_3(const Image3F& in, const Rect& rect,
+                  const WeightsSymmetric5& weights, ThreadPool* pool,
+                  Image3F* JXL_RESTRICT out) {
+  PROFILER_FUNC;
+
+  const size_t ysize = rect.ysize();
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const size_t iy = task;
+
+        if (iy < 2 || iy >= ysize - 2) {
+          for (size_t c = 0; c < 3; ++c) {
+            Symmetric5BorderRow(in.Plane(c), rect, iy, weights,
+                                out->PlaneRow(c, iy));
+          }
+        } else {
+          for (size_t c = 0; c < 3; ++c) {
+            Symmetric5Row(in.Plane(c), rect, iy, weights,
+                          out->PlaneRow(c, iy));
+          }
+        }
+      },
+      "Symmetric5x5Convolution3");
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(Symmetric3);
+void Symmetric3(const ImageF& in, const Rect& rect,
+                const WeightsSymmetric3& weights, ThreadPool* pool,
+                ImageF* out) {
+  return HWY_DYNAMIC_DISPATCH(Symmetric3)(in, rect, weights, pool, out);
+}
+
+HWY_EXPORT(Symmetric5);
+void Symmetric5(const ImageF& in, const Rect& rect,
+                const WeightsSymmetric5& weights, ThreadPool* pool,
+                ImageF* JXL_RESTRICT out) {
+  return HWY_DYNAMIC_DISPATCH(Symmetric5)(in, rect, weights, pool, out);
+}
+
+HWY_EXPORT(Symmetric5_3);
+void Symmetric5_3(const Image3F& in, const Rect& rect,
+                  const WeightsSymmetric5& weights, ThreadPool* pool,
+                  Image3F* JXL_RESTRICT out) {
+  return HWY_DYNAMIC_DISPATCH(Symmetric5_3)(in, rect, weights, pool, out);
+}
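+
+// (Illustrative note, not part of the original change.) Each HWY_EXPORT /
+// HWY_DYNAMIC_DISPATCH pair below selects, at runtime, the variant compiled
+// for the best SIMD target the CPU supports; these plain wrappers are the
+// only entry points the rest of the library sees.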
+
+HWY_EXPORT(Separable5);
+void Separable5(const ImageF& in, const Rect& rect,
+                const WeightsSeparable5& weights, ThreadPool* pool,
+                ImageF* out) {
+  return HWY_DYNAMIC_DISPATCH(Separable5)(in, rect, weights, pool, out);
+}
+
+HWY_EXPORT(Separable5_3);
+void Separable5_3(const Image3F& in, const Rect& rect,
+                  const WeightsSeparable5& weights, ThreadPool* pool,
+                  Image3F* out) {
+  return HWY_DYNAMIC_DISPATCH(Separable5_3)(in, rect, weights, pool, out);
+}
+
+HWY_EXPORT(Separable7);
+void Separable7(const ImageF& in, const Rect& rect,
+                const WeightsSeparable7& weights, ThreadPool* pool,
+                ImageF* out) {
+  return HWY_DYNAMIC_DISPATCH(Separable7)(in, rect, weights, pool, out);
+}
+
+HWY_EXPORT(Separable7_3);
+void Separable7_3(const Image3F& in, const Rect& rect,
+                  const WeightsSeparable7& weights, ThreadPool* pool,
+                  Image3F* out) {
+  return HWY_DYNAMIC_DISPATCH(Separable7_3)(in, rect, weights, pool, out);
+}
+
+//------------------------------------------------------------------------------
+// Kernels
+
+// Concentrates energy in low-frequency components (e.g. for antialiasing).
+const WeightsSymmetric3& WeightsSymmetric3Lowpass() {
+  // Computed by research/convolve_weights.py's cubic spline approximations of
+  // prolate spheroidal wave functions.
+  constexpr float w0 = 0.36208932f;
+  constexpr float w1 = 0.12820096f;
+  constexpr float w2 = 0.03127668f;
+  static constexpr WeightsSymmetric3 weights = {
+      {HWY_REP4(w0)}, {HWY_REP4(w1)}, {HWY_REP4(w2)}};
+  return weights;
+}
+
+const WeightsSeparable5& WeightsSeparable5Lowpass() {
+  constexpr float w0 = 0.41714928f;
+  constexpr float w1 = 0.25539268f;
+  constexpr float w2 = 0.03603267f;
+  static constexpr WeightsSeparable5 weights = {
+      {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)},
+      {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}};
+  return weights;
+}
+
+const WeightsSymmetric5& WeightsSymmetric5Lowpass() {
+  static constexpr WeightsSymmetric5 weights = {
+      {HWY_REP4(0.1740135f)}, {HWY_REP4(0.1065369f)}, {HWY_REP4(0.0150310f)},
+      {HWY_REP4(0.0652254f)}, {HWY_REP4(0.0012984f)}, {HWY_REP4(0.0092025f)}};
+  return weights;
+}
+
+const WeightsSeparable5& WeightsSeparable5Gaussian1() {
+  constexpr float w0 = 0.38774f;
+  constexpr float w1 = 0.24477f;
+  constexpr float w2 = 0.06136f;
+  static constexpr WeightsSeparable5 weights = {
+      {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)},
+      {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}};
+  return weights;
+}
+
+const WeightsSeparable5& WeightsSeparable5Gaussian2() {
+  constexpr float w0 = 0.250301f;
+  constexpr float w1 = 0.221461f;
+  constexpr float w2 = 0.153388f;
+  static constexpr WeightsSeparable5 weights = {
+      {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)},
+      {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}};
+  return weights;
+}
+
+//------------------------------------------------------------------------------
+// Slow
+
+namespace {
+
+template <class WrapX, class WrapY>
+float SlowSymmetric3Pixel(const ImageF& in, const int64_t ix, const int64_t iy,
+                          const int64_t xsize, const int64_t ysize,
+                          const WeightsSymmetric3& weights) {
+  float sum = 0.0f;
+
+  // ix: image; kx: kernel
+  for (int64_t ky = -1; ky <= 1; ky++) {
+    const int64_t y = WrapY()(iy + ky, ysize);
+    const float* JXL_RESTRICT row_in = in.ConstRow(static_cast<size_t>(y));
+
+    const float wc = ky == 0 ? weights.c[0] : weights.r[0];
+    const float wlr = ky == 0 ? weights.r[0] : weights.d[0];
+
+    const int64_t xm1 = WrapX()(ix - 1, xsize);
+    const int64_t xp1 = WrapX()(ix + 1, xsize);
+    sum += row_in[ix] * wc + (row_in[xm1] + row_in[xp1]) * wlr;
+  }
+  return sum;
+}
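+
+// (Illustrative note, not part of the original change.) With the layout from
+// convolve.h, the 3x3 symmetric kernel is
+//
+//   d r d
+//   r c r   =>  out = c*center + r*(4 edge neighbors) + d*(4 diagonals),
+//   d r d
+//
+// which the loop above accumulates one (wrapped) row at a time.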
+
+template <class WrapY>
+void SlowSymmetric3Row(const ImageF& in, const int64_t iy, const int64_t xsize,
+                       const int64_t ysize, const WeightsSymmetric3& weights,
+                       float* JXL_RESTRICT row_out) {
+  row_out[0] =
+      SlowSymmetric3Pixel<WrapMirror, WrapY>(in, 0, iy, xsize, ysize, weights);
+  for (int64_t ix = 1; ix < xsize - 1; ix++) {
+    row_out[ix] = SlowSymmetric3Pixel<WrapUnchanged, WrapY>(in, ix, iy, xsize,
+                                                            ysize, weights);
+  }
+  {
+    const int64_t ix = xsize - 1;
+    row_out[ix] = SlowSymmetric3Pixel<WrapMirror, WrapY>(in, ix, iy, xsize,
+                                                         ysize, weights);
+  }
+}
+
+}  // namespace
+
+void SlowSymmetric3(const ImageF& in, const Rect& rect,
+                    const WeightsSymmetric3& weights, ThreadPool* pool,
+                    ImageF* JXL_RESTRICT out) {
+  PROFILER_FUNC;
+
+  const int64_t xsize = static_cast<int64_t>(rect.xsize());
+  const int64_t ysize = static_cast<int64_t>(rect.ysize());
+  const int64_t kRadius = 1;
+
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const int64_t iy = task;
+        float* JXL_RESTRICT out_row = out->Row(static_cast<size_t>(iy));
+
+        if (iy < kRadius || iy >= ysize - kRadius) {
+          SlowSymmetric3Row<WrapMirror>(in, iy, xsize, ysize, weights,
+                                        out_row);
+        } else {
+          SlowSymmetric3Row<WrapUnchanged>(in, iy, xsize, ysize, weights,
+                                           out_row);
+        }
+      },
+      "SlowSymmetric3");
+}
+
+void SlowSymmetric3(const Image3F& in, const Rect& rect,
+                    const WeightsSymmetric3& weights, ThreadPool* pool,
+                    Image3F* JXL_RESTRICT out) {
+  PROFILER_FUNC;
+
+  const int64_t xsize = static_cast<int64_t>(rect.xsize());
+  const int64_t ysize = static_cast<int64_t>(rect.ysize());
+  const int64_t kRadius = 1;
+
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const int64_t iy = task;
+        const size_t oy = static_cast<size_t>(iy);
+
+        if (iy < kRadius || iy >= ysize - kRadius) {
+          for (size_t c = 0; c < 3; ++c) {
+            SlowSymmetric3Row<WrapMirror>(in.Plane(c), iy, xsize, ysize,
+                                          weights, out->PlaneRow(c, oy));
+          }
+        } else {
+          for (size_t c = 0; c < 3; ++c) {
+            SlowSymmetric3Row<WrapUnchanged>(in.Plane(c), iy, xsize, ysize,
+                                             weights, out->PlaneRow(c, oy));
+          }
+        }
+      },
+      "SlowSymmetric3");
+}
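+
+// (Illustrative note, not part of the original change.) WrapMirror/Mirror
+// reflect out-of-range coordinates back into [0, size), duplicating the edge
+// sample; e.g. for size = 4, x = -2 -1 0 1 2 3 4 5 maps to 1 0 0 1 2 3 3 2.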
+
+namespace {
+
+// Separable kernels, any radius.
+float SlowSeparablePixel(const ImageF& in, const Rect& rect, const int64_t x,
+                         const int64_t y, const int64_t radius,
+                         const float* JXL_RESTRICT horz_weights,
+                         const float* JXL_RESTRICT vert_weights) {
+  const size_t xsize = rect.xsize();
+  const size_t ysize = rect.ysize();
+  const WrapMirror wrap;
+
+  float mul = 0.0f;
+  for (int dy = -radius; dy <= radius; ++dy) {
+    const float wy = vert_weights[std::abs(dy) * 4];
+    const size_t sy = wrap(y + dy, ysize);
+    JXL_CHECK(sy < ysize);
+    const float* const JXL_RESTRICT row = rect.ConstRow(in, sy);
+    for (int dx = -radius; dx <= radius; ++dx) {
+      const float wx = horz_weights[std::abs(dx) * 4];
+      const size_t sx = wrap(x + dx, xsize);
+      JXL_CHECK(sx < xsize);
+      mul += row[sx] * wx * wy;
+    }
+  }
+  return mul;
+}
+
+}  // namespace
+
+void SlowSeparable5(const ImageF& in, const Rect& rect,
+                    const WeightsSeparable5& weights, ThreadPool* pool,
+                    ImageF* out) {
+  PROFILER_FUNC;
+  const float* horz_weights = &weights.horz[0];
+  const float* vert_weights = &weights.vert[0];
+
+  const size_t ysize = rect.ysize();
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const int64_t y = task;
+
+        float* const JXL_RESTRICT row_out = out->Row(y);
+        for (size_t x = 0; x < rect.xsize(); ++x) {
+          row_out[x] = SlowSeparablePixel(in, rect, x, y, /*radius=*/2,
+                                          horz_weights, vert_weights);
+        }
+      },
+      "SlowSeparable5");
+}
+
+void SlowSeparable5(const Image3F& in, const Rect& rect,
+                    const WeightsSeparable5& weights, ThreadPool* pool,
+                    Image3F* out) {
+  for (size_t c = 0; c < 3; ++c) {
+    SlowSeparable5(in.Plane(c), rect, weights, pool, &out->Plane(c));
+  }
+}
+
+void SlowSeparable7(const ImageF& in, const Rect& rect,
+                    const WeightsSeparable7& weights, ThreadPool* pool,
+                    ImageF* out) {
+  PROFILER_FUNC;
+  const float* horz_weights = &weights.horz[0];
+  const float* vert_weights = &weights.vert[0];
+
+  const size_t ysize = rect.ysize();
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const int64_t y = task;
+
+        float* const JXL_RESTRICT row_out = out->Row(y);
+        for (size_t x = 0; x < rect.xsize(); ++x) {
+          row_out[x] = SlowSeparablePixel(in, rect, x, y, /*radius=*/3,
+                                          horz_weights, vert_weights);
+        }
+      },
+      "SlowSeparable7");
+}
+
+void SlowSeparable7(const Image3F& in, const Rect& rect,
+                    const WeightsSeparable7& weights, ThreadPool* pool,
+                    Image3F* out) {
+  for (size_t c = 0; c < 3; ++c) {
+    SlowSeparable7(in.Plane(c), rect, weights, pool, &out->Plane(c));
+  }
+}
+
+void SlowLaplacian5(const ImageF& in, const Rect& rect, ThreadPool* pool,
+                    ImageF* out) {
+  PROFILER_FUNC;
+  JXL_CHECK(SameSize(rect, *out));
+
+  const size_t xsize = rect.xsize();
+  const size_t ysize = rect.ysize();
+  const WrapMirror wrap;
+
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const int64_t y = task;
+
+        const float* const JXL_RESTRICT row_t =
+            rect.ConstRow(in, wrap(y - 2, ysize));
+        const float* const JXL_RESTRICT row_m = rect.ConstRow(in, y);
+        const float* const JXL_RESTRICT row_b =
+            rect.ConstRow(in, wrap(y + 2, ysize));
+        float* const JXL_RESTRICT row_out = out->Row(y);
+
+        for (int64_t x = 0; static_cast<size_t>(x) < xsize; ++x) {
+          const int64_t xm2 = wrap(x - 2, xsize);
+          const int64_t xp2 = wrap(x + 2, xsize);
+          float r = 0.0f;
+          r += /*               */ 1.0f * row_t[x];
+          r += 1.0f * row_m[xm2] - 4.0f * row_m[x] + 1.0f * row_m[xp2];
+          r += /*               */ 1.0f * row_b[x];
+          row_out[x] = r;
+        }
+      },
+      "SlowLaplacian5");
+}
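+
+// (Illustrative note, not part of the original change.) SlowLaplacian5 uses a
+// cross-shaped stencil with taps two pixels apart:
+//
+//   out(x, y) = in(x, y-2) + in(x-2, y) - 4*in(x, y) + in(x+2, y) + in(x, y+2),
+//
+// with mirrored coordinates near the image border.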
+
+void SlowLaplacian5(const Image3F& in, const Rect& rect, ThreadPool* pool,
+                    Image3F* out) {
+  for (size_t c = 0; c < 3; ++c) {
+    SlowLaplacian5(in.Plane(c), rect, pool, &out->Plane(c));
+  }
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/convolve.h b/third_party/jpeg-xl/lib/jxl/convolve.h
new file mode 100644
index 000000000000..27bb7990aa82
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/convolve.h
@@ -0,0 +1,140 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_CONVOLVE_H_
+#define LIB_JXL_CONVOLVE_H_
+
+// 2D convolution.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+// No valid values outside [0, xsize), but the strategy may still safely load
+// the preceding vector, and/or round xsize up to the vector lane count. This
+// avoids needing PadImage.
+// Requires xsize >= kConvolveLanes + kConvolveMaxRadius.
+static constexpr size_t kConvolveMaxRadius = 3;
+
+// Weights must already be normalized.
+
+struct WeightsSymmetric3 {
+  // d r d (each replicated 4x)
+  // r c r
+  // d r d
+  float c[4];
+  float r[4];
+  float d[4];
+};
+
+struct WeightsSymmetric5 {
+  // The lower-right quadrant is: c r R (each replicated 4x)
+  //                              r d L
+  //                              R L D
+  float c[4];
+  float r[4];
+  float R[4];
+  float d[4];
+  float D[4];
+  float L[4];
+};
+
+// Weights for separable 5x5 filters (typically but not necessarily the same
+// values for horizontal and vertical directions). The kernel must already be
+// normalized, but note that values for negative offsets are omitted, so the
+// given values do not sum to 1.
+struct WeightsSeparable5 {
+  // Horizontal 1D, distances 0..2 (each replicated 4x)
+  float horz[3 * 4];
+  float vert[3 * 4];
+};
+
+// Weights for separable 7x7 filters (typically but not necessarily the same
+// values for horizontal and vertical directions). The kernel must already be
+// normalized, but note that values for negative offsets are omitted, so the
+// given values do not sum to 1.
+//
+// NOTE: for >= 7x7 Gaussian kernels, it is faster to use FastGaussian instead,
+// at least when images exceed the L1 cache size.
+struct WeightsSeparable7 { + // Horizontal 1D, distances 0..3 (each replicated 4x) + float horz[4 * 4]; + float vert[4 * 4]; +}; + +const WeightsSymmetric3& WeightsSymmetric3Lowpass(); +const WeightsSeparable5& WeightsSeparable5Lowpass(); +const WeightsSymmetric5& WeightsSymmetric5Lowpass(); + +void SlowSymmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out); +void SlowSymmetric3(const Image3F& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + Image3F* JXL_RESTRICT out); + +void SlowSeparable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out); +void SlowSeparable5(const Image3F& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + Image3F* out); + +void SlowSeparable7(const ImageF& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + ImageF* out); +void SlowSeparable7(const Image3F& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + Image3F* out); + +void SlowLaplacian5(const ImageF& in, const Rect& rect, ThreadPool* pool, + ImageF* out); +void SlowLaplacian5(const Image3F& in, const Rect& rect, ThreadPool* pool, + Image3F* out); + +void Symmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* out); + +void Symmetric5(const ImageF& in, const Rect& rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out); + +void Symmetric5_3(const Image3F& in, const Rect& rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + Image3F* JXL_RESTRICT out); + +void Separable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out); + +void Separable5_3(const Image3F& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + Image3F* out); + +void Separable7(const ImageF& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + ImageF* out); + +void Separable7_3(const Image3F& in, const Rect& rect, + const WeightsSeparable7& weights, ThreadPool* pool, + Image3F* out); + +} // namespace jxl + +#endif // LIB_JXL_CONVOLVE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/convolve_test.cc b/third_party/jpeg-xl/lib/jxl/convolve_test.cc new file mode 100644 index 000000000000..bd355c55de13 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve_test.cc @@ -0,0 +1,259 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lib/jxl/convolve.h"
+
+#include <time.h>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/convolve_test.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+#include <hwy/nanobenchmark.h>
+#include <hwy/tests/test_util-inl.h>
+#include <random>
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/image_test_utils.h"
+
+#ifndef JXL_DEBUG_CONVOLVE
+#define JXL_DEBUG_CONVOLVE 0
+#endif
+
+#include "lib/jxl/convolve-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+void TestNeighbors() {
+  const Neighbors::D d;
+  const Neighbors::V v = Iota(d, 0);
+  HWY_ALIGN float actual[hwy::kTestMaxVectorSize / sizeof(float)] = {0};
+
+  HWY_ALIGN float first_l1[hwy::kTestMaxVectorSize / sizeof(float)] = {
+      0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
+  Store(Neighbors::FirstL1(v), d, actual);
+  const size_t N = Lanes(d);
+  EXPECT_EQ(std::vector<float>(first_l1, first_l1 + N),
+            std::vector<float>(actual, actual + N));
+
+#if HWY_TARGET != HWY_SCALAR
+  HWY_ALIGN float first_l2[hwy::kTestMaxVectorSize / sizeof(float)] = {
+      1, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13};
+  Store(Neighbors::FirstL2(v), d, actual);
+  EXPECT_EQ(std::vector<float>(first_l2, first_l2 + N),
+            std::vector<float>(actual, actual + N));
+
+  HWY_ALIGN float first_l3[hwy::kTestMaxVectorSize / sizeof(float)] = {
+      2, 1, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  Store(Neighbors::FirstL3(v), d, actual);
+  EXPECT_EQ(std::vector<float>(first_l3, first_l3 + N),
+            std::vector<float>(actual, actual + N));
+#endif  // HWY_TARGET != HWY_SCALAR
+}
+
+template <class Random>
+void VerifySymmetric3(const size_t xsize, const size_t ysize, ThreadPool* pool,
+                      Random* rng) {
+  const Rect rect(0, 0, xsize, ysize);
+
+  ImageF in(xsize, ysize);
+  GenerateImage(GeneratorRandom<float, Random>(rng, 1.0f), &in);
+
+  ImageF out_expected(xsize, ysize);
+  ImageF out_actual(xsize, ysize);
+
+  const WeightsSymmetric3& weights = WeightsSymmetric3Lowpass();
+  Symmetric3(in, rect, weights, pool, &out_expected);
+  SlowSymmetric3(in, rect, weights, pool, &out_actual);
+
+  VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f);
+}
+
+// Ensures Symmetric and Separable give the same result.
+template <class Random>
+void VerifySymmetric5(const size_t xsize, const size_t ysize, ThreadPool* pool,
+                      Random* rng) {
+  const Rect rect(0, 0, xsize, ysize);
+
+  ImageF in(xsize, ysize);
+  GenerateImage(GeneratorRandom<float, Random>(rng, 1.0f), &in);
+
+  ImageF out_expected(xsize, ysize);
+  ImageF out_actual(xsize, ysize);
+
+  Separable5(in, Rect(in), WeightsSeparable5Lowpass(), pool, &out_expected);
+  Symmetric5(in, rect, WeightsSymmetric5Lowpass(), pool, &out_actual);
+
+  VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f);
+}
+
+template <class Random>
+void VerifySeparable5(const size_t xsize, const size_t ysize, ThreadPool* pool,
+                      Random* rng) {
+  const Rect rect(0, 0, xsize, ysize);
+
+  ImageF in(xsize, ysize);
+  GenerateImage(GeneratorRandom<float, Random>(rng, 1.0f), &in);
+
+  ImageF out_expected(xsize, ysize);
+  ImageF out_actual(xsize, ysize);
+
+  const WeightsSeparable5& weights = WeightsSeparable5Lowpass();
+  Separable5(in, Rect(in), weights, pool, &out_expected);
+  SlowSeparable5(in, rect, weights, pool, &out_actual);
+
+  VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f);
+}
+
+template <class Random>
+void VerifySeparable7(const size_t xsize, const size_t ysize, ThreadPool* pool,
+                      Random* rng) {
+  const Rect rect(0, 0, xsize, ysize);
+
+  ImageF in(xsize, ysize);
+  GenerateImage(GeneratorRandom<float, Random>(rng, 1.0f), &in);
+
+  ImageF out_expected(xsize, ysize);
+  ImageF out_actual(xsize, ysize);
+
+  // Gaussian sigma 1.0
+  const WeightsSeparable7 weights = {{HWY_REP4(0.383103f), HWY_REP4(0.241843f),
+                                      HWY_REP4(0.060626f), HWY_REP4(0.00598f)},
+                                     {HWY_REP4(0.383103f), HWY_REP4(0.241843f),
+                                      HWY_REP4(0.060626f), HWY_REP4(0.00598f)}};
+
+  SlowSeparable7(in, rect, weights, pool, &out_expected);
+  Separable7(in, Rect(in), weights, pool, &out_actual);
+
+  VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f);
+}
+
+// For all xsize/ysize and kernels:
+void TestConvolve() {
+  TestNeighbors();
+
+  ThreadPoolInternal pool(4);
+  pool.Run(kConvolveMaxRadius, 40, ThreadPool::SkipInit(),
+           [](const int task, int /*thread*/) {
+             const size_t xsize = task;
+             std::mt19937_64 rng(129 + 13 * xsize);
+
+             ThreadPool* null_pool = nullptr;
+             ThreadPoolInternal pool3(3);
+             for (size_t ysize = kConvolveMaxRadius; ysize < 16; ++ysize) {
+               JXL_DEBUG(JXL_DEBUG_CONVOLVE,
+                         "%zu x %zu (target %d)===============================",
+                         xsize, ysize, HWY_TARGET);
+
+               JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym3------------------");
+               VerifySymmetric3(xsize, ysize, null_pool, &rng);
+               VerifySymmetric3(xsize, ysize, &pool3, &rng);
+
+               JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym5------------------");
+               VerifySymmetric5(xsize, ysize, null_pool, &rng);
+               VerifySymmetric5(xsize, ysize, &pool3, &rng);
+
+               JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sep5------------------");
+               VerifySeparable5(xsize, ysize, null_pool, &rng);
+               VerifySeparable5(xsize, ysize, &pool3, &rng);
+
+               JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sep7------------------");
+               VerifySeparable7(xsize, ysize, null_pool, &rng);
+               VerifySeparable7(xsize, ysize, &pool3, &rng);
+             }
+           });
+}
+
+// Measures durations, verifies results, prints timings. `unpredictable1`
+// must have value 1 (unknown to the compiler to prevent elision).
+template <class Conv>
+void BenchmarkConv(const char* caption, const Conv& conv,
+                   const hwy::FuncInput unpredictable1) {
+  const size_t kNumInputs = 1;
+  const hwy::FuncInput inputs[kNumInputs] = {unpredictable1};
+  hwy::Result results[kNumInputs];
+
+  const size_t kDim = 160;  // in+out fit in L2
+  ImageF in(kDim, kDim);
+  ZeroFillImage(&in);
+  in.Row(kDim / 2)[kDim / 2] = unpredictable1;
+  ImageF out(kDim, kDim);
+
+  hwy::Params p;
+  p.verbose = false;
+  p.max_evals = 7;
+  p.target_rel_mad = 0.002;
+  const size_t num_results = MeasureClosure(
+      [&in, &conv, &out](const hwy::FuncInput input) {
+        conv(in, &out);
+        return out.Row(input)[0];
+      },
+      inputs, kNumInputs, results, p);
+  if (num_results != kNumInputs) {
+    fprintf(stderr, "MeasureClosure failed.\n");
+  }
+  for (size_t i = 0; i < num_results; ++i) {
+    const double seconds = static_cast<double>(results[i].ticks) /
+                           hwy::platform::InvariantTicksPerSecond();
+    printf("%12s: %7.2f MP/s (MAD=%4.2f%%)\n", caption,
+           kDim * kDim * 1E-6 / seconds,
+           static_cast<double>(results[i].variability) * 100.0);
+  }
+}
+
+struct ConvSymmetric3 {
+  void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const {
+    ThreadPool* null_pool = nullptr;
+    Symmetric3(in, Rect(in), WeightsSymmetric3Lowpass(), null_pool, out);
+  }
+};
+
+struct ConvSeparable5 {
+  void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const {
+    ThreadPool* null_pool = nullptr;
+    Separable5(in, Rect(in), WeightsSeparable5Lowpass(), null_pool, out);
+  }
+};
+
+void BenchmarkAll() {
+#if 0  // disabled to avoid test timeouts, run manually on demand
+  const hwy::FuncInput unpredictable1 = time(nullptr) != 1234;
+  BenchmarkConv("Symmetric3", ConvSymmetric3(), unpredictable1);
+  BenchmarkConv("Separable5", ConvSeparable5(), unpredictable1);
+#endif
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+class ConvolveTest : public hwy::TestWithParamTarget {};
+HWY_TARGET_INSTANTIATE_TEST_SUITE_P(ConvolveTest);
+
+HWY_EXPORT_AND_TEST_P(ConvolveTest, TestConvolve);
+
+HWY_EXPORT_AND_TEST_P(ConvolveTest, BenchmarkAll);
+
+}  // namespace jxl
+#endif
diff --git a/third_party/jpeg-xl/lib/jxl/data_parallel_test.cc b/third_party/jpeg-xl/lib/jxl/data_parallel_test.cc
new file mode 100644
index 000000000000..7d2b14edff78
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/data_parallel_test.cc
@@ -0,0 +1,120 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/data_parallel.h"
+
+#include "gtest/gtest.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/test_utils.h"
+
+namespace jxl {
+namespace {
+
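+// (Illustrative sketch, not part of the original change.) A conforming
+// JxlParallelRunner must call init(jpegxl_opaque, num_threads) once, then
+// func(jpegxl_opaque, task, thread) for every task in [start_range,
+// end_range); a minimal sequential runner would be:
+//
+//   int SeqRunner(void* runner_opaque, void* jpegxl_opaque,
+//                 JxlParallelRunInit init, JxlParallelRunFunction func,
+//                 uint32_t start_range, uint32_t end_range) {
+//     if (int err = init(jpegxl_opaque, /*num_threads=*/1)) return err;
+//     for (uint32_t i = start_range; i < end_range; ++i) {
+//       func(jpegxl_opaque, i, /*thread=*/0);
+//     }
+//     return 0;
+//   }
+//
+// The fixture below installs FakeRunner instead, to record these calls.
+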
+class DataParallelTest : public ::testing::Test {
+ protected:
+  // A fake class to verify that DataParallel is properly calling the
+  // client-provided runner functions.
+  static int FakeRunner(void* runner_opaque, void* jpegxl_opaque,
+                        JxlParallelRunInit init, JxlParallelRunFunction func,
+                        uint32_t start_range, uint32_t end_range) {
+    DataParallelTest* self = static_cast<DataParallelTest*>(runner_opaque);
+    self->runner_called_++;
+    self->jpegxl_opaque_ = jpegxl_opaque;
+    self->init_ = init;
+    self->func_ = func;
+    self->start_range_ = start_range;
+    self->end_range_ = end_range;
+    return self->runner_return_;
+  }
+
+  ThreadPool pool_{&DataParallelTest::FakeRunner, this};
+
+  // Number of times FakeRunner() was called.
+  int runner_called_ = 0;
+
+  // Parameters passed to FakeRunner.
+  void* jpegxl_opaque_ = nullptr;
+  JxlParallelRunInit init_ = nullptr;
+  JxlParallelRunFunction func_ = nullptr;
+  uint32_t start_range_ = -1;
+  uint32_t end_range_ = -1;
+
+  // Return value that FakeRunner will return.
+  int runner_return_ = 0;
+};
+
+// JxlParallelRunInit interface.
+typedef int (*JxlParallelRunInit)();
+int TestInit(void* jpegxl_opaque, size_t num_threads) { return 0; }
+
+}  // namespace
+
+TEST_F(DataParallelTest, RunnerCalledParameters) {
+  EXPECT_TRUE(pool_.Run(
+      1234, 5678, [](const size_t num_threads) { return true; },
+      [](const int task, const int thread) { return; }));
+  EXPECT_EQ(1, runner_called_);
+  EXPECT_NE(nullptr, init_);
+  EXPECT_NE(nullptr, func_);
+  EXPECT_NE(nullptr, jpegxl_opaque_);
+  EXPECT_EQ(1234u, start_range_);
+  EXPECT_EQ(5678u, end_range_);
+}
+
+TEST_F(DataParallelTest, RunnerFailurePropagates) {
+  runner_return_ = -1;  // FakeRunner return value.
+  EXPECT_FALSE(pool_.Run(
+      1234, 5678, [](const size_t num_threads) { return false; },
+      [](const int task, const int thread) { return; }));
+  EXPECT_FALSE(RunOnPool(
+      nullptr, 1234, 5678, [](const size_t num_threads) { return false; },
+      [](const int task, const int thread) { return; }, "Test"));
+}
+
+TEST_F(DataParallelTest, RunnerNotCalledOnEmptyRange) {
+  runner_return_ = -1;  // FakeRunner return value.
+  EXPECT_TRUE(pool_.Run(
+      123, 123, [](const size_t num_threads) { return false; },
+      [](const int task, const int thread) { return; }));
+  EXPECT_TRUE(RunOnPool(
+      nullptr, 123, 123, [](const size_t num_threads) { return false; },
+      [](const int task, const int thread) { return; }, "Test"));
+  // We don't call the external runner when the range is empty. We don't even
+  // need to call the init function.
+  EXPECT_EQ(0, runner_called_);
+}
+
+// The TestDivider is slow when compiled in debug mode.
+TEST_F(DataParallelTest, JXL_SLOW_TEST(TestDivider)) {
+  jxl::ThreadPoolInternal pool(8);
+  // 1, 2 are powers of two.
+  pool.Run(3, 2 * 1024, ThreadPool::SkipInit(),
+           [](const int d, const int thread) {
+             // powers of two are not supported.
+             if ((d & (d - 1)) == 0) return;
+
+             const Divider div(d);
+#ifdef NDEBUG
+             const int max_dividend = 4 * 1024 * 1024;
+#else
+             const int max_dividend = 2 * 1024 + 1;
+#endif
+             for (int x = 0; x < max_dividend; ++x) {
+               const int q = div(x);
+               ASSERT_EQ(x / d, q) << x << "/" << d;
+             }
+           });
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dct-inl.h b/third_party/jpeg-xl/lib/jxl/dct-inl.h
new file mode 100644
index 000000000000..775eb1e90bbe
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dct-inl.h
@@ -0,0 +1,370 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Fast SIMD floating-point (I)DCT, any power of two. + +#if defined(LIB_JXL_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DCT_INL_H_ +#undef LIB_JXL_DCT_INL_H_ +#else +#define LIB_JXL_DCT_INL_H_ +#endif + +#include + +#include + +#include "lib/jxl/dct_block-inl.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/transpose-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +template +struct FVImpl { + using type = HWY_CAPPED(float, SZ); +}; + +template <> +struct FVImpl<0> { + using type = HWY_FULL(float); +}; + +template +using FV = typename FVImpl::type; + +// Implementation of Lowest Complexity Self Recursive Radix-2 DCT II/III +// Algorithms, by Siriani M. Perera and Jianhua Liu. + +template +struct CoeffBundle { + static void AddReverse(const float* JXL_RESTRICT ain1, + const float* JXL_RESTRICT ain2, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N; i++) { + auto in1 = Load(FV(), ain1 + i * SZ); + auto in2 = Load(FV(), ain2 + (N - i - 1) * SZ); + Store(in1 + in2, FV(), aout + i * SZ); + } + } + static void SubReverse(const float* JXL_RESTRICT ain1, + const float* JXL_RESTRICT ain2, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N; i++) { + auto in1 = Load(FV(), ain1 + i * SZ); + auto in2 = Load(FV(), ain2 + (N - i - 1) * SZ); + Store(in1 - in2, FV(), aout + i * SZ); + } + } + static void B(float* JXL_RESTRICT coeff) { + auto sqrt2 = Set(FV(), square_root<2>::value); + auto in1 = Load(FV(), coeff); + auto in2 = Load(FV(), coeff + SZ); + Store(MulAdd(in1, sqrt2, in2), FV(), coeff); + for (size_t i = 1; i + 1 < N; i++) { + auto in1 = Load(FV(), coeff + i * SZ); + auto in2 = Load(FV(), coeff + (i + 1) * SZ); + Store(in1 + in2, FV(), coeff + i * SZ); + } + } + static void BTranspose(float* JXL_RESTRICT coeff) { + for (size_t i = N - 1; i > 0; i--) { + auto in1 = Load(FV(), coeff + i * SZ); + auto in2 = Load(FV(), coeff + (i - 1) * SZ); + Store(in1 + in2, FV(), coeff + i * SZ); + } + auto sqrt2 = Set(FV(), square_root<2>::value); + auto in1 = Load(FV(), coeff); + Store(in1 * sqrt2, FV(), coeff); + } + // Ideally optimized away by compiler (except the multiply). + static void InverseEvenOdd(const float* JXL_RESTRICT ain, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = Load(FV(), ain + i * SZ); + Store(in1, FV(), aout + 2 * i * SZ); + } + for (size_t i = N / 2; i < N; i++) { + auto in1 = Load(FV(), ain + i * SZ); + Store(in1, FV(), aout + (2 * (i - N / 2) + 1) * SZ); + } + } + // Ideally optimized away by compiler. + static void ForwardEvenOdd(const float* JXL_RESTRICT ain, size_t ain_stride, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = LoadU(FV(), ain + 2 * i * ain_stride); + Store(in1, FV(), aout + i * SZ); + } + for (size_t i = N / 2; i < N; i++) { + auto in1 = LoadU(FV(), ain + (2 * (i - N / 2) + 1) * ain_stride); + Store(in1, FV(), aout + i * SZ); + } + } + // Invoked on full vector. 
+ static void Multiply(float* JXL_RESTRICT coeff) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = Load(FV(), coeff + (N / 2 + i) * SZ); + auto mul = Set(FV(), WcMultipliers::kMultipliers[i]); + Store(in1 * mul, FV(), coeff + (N / 2 + i) * SZ); + } + } + static void MultiplyAndAdd(const float* JXL_RESTRICT coeff, + float* JXL_RESTRICT out, size_t out_stride) { + for (size_t i = 0; i < N / 2; i++) { + auto mul = Set(FV(), WcMultipliers::kMultipliers[i]); + auto in1 = Load(FV(), coeff + i * SZ); + auto in2 = Load(FV(), coeff + (N / 2 + i) * SZ); + auto out1 = MulAdd(mul, in2, in1); + auto out2 = NegMulAdd(mul, in2, in1); + StoreU(out1, FV(), out + i * out_stride); + StoreU(out2, FV(), out + (N - i - 1) * out_stride); + } + } + template + static void LoadFromBlock(const Block& in, size_t off, + float* JXL_RESTRICT coeff) { + for (size_t i = 0; i < N; i++) { + Store(in.LoadPart(FV(), i, off), FV(), coeff + i * SZ); + } + } + template + static void StoreToBlockAndScale(const float* JXL_RESTRICT coeff, + const Block& out, size_t off) { + auto mul = Set(FV(), 1.0f / N); + for (size_t i = 0; i < N; i++) { + out.StorePart(FV(), mul * Load(FV(), coeff + i * SZ), i, off); + } + } +}; + +template +struct DCT1DImpl; + +template +struct DCT1DImpl<1, SZ> { + JXL_INLINE void operator()(float* JXL_RESTRICT mem) {} +}; + +template +struct DCT1DImpl<2, SZ> { + JXL_INLINE void operator()(float* JXL_RESTRICT mem) { + auto in1 = Load(FV(), mem); + auto in2 = Load(FV(), mem + SZ); + Store(in1 + in2, FV(), mem); + Store(in1 - in2, FV(), mem + SZ); + } +}; + +template +struct DCT1DImpl { + void operator()(float* JXL_RESTRICT mem) { + // This is relatively small (4kB with 64-DCT and AVX-512) + HWY_ALIGN float tmp[N * SZ]; + CoeffBundle::AddReverse(mem, mem + N / 2 * SZ, tmp); + DCT1DImpl()(tmp); + CoeffBundle::SubReverse(mem, mem + N / 2 * SZ, tmp + N / 2 * SZ); + CoeffBundle::Multiply(tmp); + DCT1DImpl()(tmp + N / 2 * SZ); + CoeffBundle::B(tmp + N / 2 * SZ); + CoeffBundle::InverseEvenOdd(tmp, mem); + } +}; + +template +struct IDCT1DImpl; + +template +struct IDCT1DImpl<1, SZ> { + JXL_INLINE void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride) { + StoreU(LoadU(FV(), from), FV(), to); + } +}; + +template +struct IDCT1DImpl<2, SZ> { + JXL_INLINE void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride) { + JXL_DASSERT(from_stride >= SZ); + JXL_DASSERT(to_stride >= SZ); + auto in1 = LoadU(FV(), from); + auto in2 = LoadU(FV(), from + from_stride); + StoreU(in1 + in2, FV(), to); + StoreU(in1 - in2, FV(), to + to_stride); + } +}; + +template +struct IDCT1DImpl { + void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride) { + JXL_DASSERT(from_stride >= SZ); + JXL_DASSERT(to_stride >= SZ); + // This is relatively small (4kB with 64-DCT and AVX-512) + HWY_ALIGN float tmp[N * SZ]; + CoeffBundle::ForwardEvenOdd(from, from_stride, tmp); + IDCT1DImpl()(tmp, SZ, tmp, SZ); + CoeffBundle::BTranspose(tmp + N / 2 * SZ); + IDCT1DImpl()(tmp + N / 2 * SZ, SZ, tmp + N / 2 * SZ, SZ); + CoeffBundle::MultiplyAndAdd(tmp, to, to_stride); + } +}; + +template +void DCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp) { + size_t M = M_or_0 != 0 ? 
M_or_0 : Mp; + constexpr size_t SZ = MaxLanes(FV()); + HWY_ALIGN float tmp[N * SZ]; + for (size_t i = 0; i < M; i += Lanes(FV())) { + // TODO(veluca): consider removing the temprorary memory here (as is done in + // IDCT), if it turns out that some compilers don't optimize away the loads + // and this is performance-critical. + CoeffBundle::LoadFromBlock(from, i, tmp); + DCT1DImpl()(tmp); + CoeffBundle::StoreToBlockAndScale(tmp, to, i); + } +} + +template +void IDCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp) { + size_t M = M_or_0 != 0 ? M_or_0 : Mp; + constexpr size_t SZ = MaxLanes(FV()); + for (size_t i = 0; i < M; i += Lanes(FV())) { + IDCT1DImpl()(from.Address(0, i), from.Stride(), to.Address(0, i), + to.Stride()); + } +} + +template +struct DCT1D { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return DCT1DWrapper(from, to, M); + } +}; + +template +struct DCT1D MaxLanes(FV<0>()))>::type> { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return NoInlineWrapper(DCT1DWrapper, from, to, M); + } +}; + +template +struct IDCT1D { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return IDCT1DWrapper(from, to, M); + } +}; + +template +struct IDCT1D MaxLanes(FV<0>()))>::type> { + template + void operator()(const FromBlock& from, const ToBlock& to) { + return NoInlineWrapper(IDCT1DWrapper, from, to, + M); + } +}; + +// Computes the in-place NxN transposed-scaled-DCT (tsDCT) of block. +// Requires that block is HWY_ALIGN'ed. +// +// See also DCTSlow, ComputeDCT +template +struct ComputeTransposedScaledDCT { + // scratch_space must be aligned, and should have space for N*N floats. + template + HWY_MAYBE_UNUSED void operator()(const From& from, float* JXL_RESTRICT to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + DCT1D()(from, DCTTo(to, N)); + Transpose::Run(DCTFrom(to, N), DCTTo(block, N)); + DCT1D()(DCTFrom(block, N), DCTTo(to, N)); + } +}; + +// Computes the in-place NxN transposed-scaled-iDCT (tsIDCT)of block. +// Requires that block is HWY_ALIGN'ed. +// +// See also IDCTSlow, ComputeIDCT. + +template +struct ComputeTransposedScaledIDCT { + // scratch_space must be aligned, and should have space for N*N floats. + template + HWY_MAYBE_UNUSED void operator()(float* JXL_RESTRICT from, const To& to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + IDCT1D()(DCTFrom(from, N), DCTTo(block, N)); + Transpose::Run(DCTFrom(block, N), DCTTo(from, N)); + IDCT1D()(DCTFrom(from, N), to); + } +}; +// Computes the non-transposed, scaled DCT of a block, that needs to be +// HWY_ALIGN'ed. Used for rectangular blocks. +template +struct ComputeScaledDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // floats. + template + HWY_MAYBE_UNUSED void operator()(const From& from, float* to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + if (ROWS < COLS) { + DCT1D()(from, DCTTo(block, COLS)); + Transpose::Run(DCTFrom(block, COLS), DCTTo(to, ROWS)); + DCT1D()(DCTFrom(to, ROWS), DCTTo(block, ROWS)); + Transpose::Run(DCTFrom(block, ROWS), DCTTo(to, COLS)); + } else { + DCT1D()(from, DCTTo(to, COLS)); + Transpose::Run(DCTFrom(to, COLS), DCTTo(block, ROWS)); + DCT1D()(DCTFrom(block, ROWS), DCTTo(to, ROWS)); + } + } +}; +// Computes the non-transposed, scaled DCT of a block, that needs to be +// HWY_ALIGN'ed. Used for rectangular blocks. 
+template +struct ComputeScaledIDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // floats. + template + HWY_MAYBE_UNUSED void operator()(float* JXL_RESTRICT from, const To& to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + // Reverse the steps done in ComputeScaledDCT. + if (ROWS < COLS) { + Transpose::Run(DCTFrom(from, COLS), DCTTo(block, ROWS)); + IDCT1D()(DCTFrom(block, ROWS), DCTTo(from, ROWS)); + Transpose::Run(DCTFrom(from, ROWS), DCTTo(block, COLS)); + IDCT1D()(DCTFrom(block, COLS), to); + } else { + IDCT1D()(DCTFrom(from, ROWS), DCTTo(block, ROWS)); + Transpose::Run(DCTFrom(block, ROWS), DCTTo(from, COLS)); + IDCT1D()(DCTFrom(from, COLS), to); + } + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); +#endif // LIB_JXL_DCT_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dct_block-inl.h b/third_party/jpeg-xl/lib/jxl/dct_block-inl.h new file mode 100644 index 000000000000..785a0bcbe4b8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_block-inl.h @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Adapters for DCT input/output: from/to contiguous blocks or image rows. + +#if defined(LIB_JXL_DCT_BLOCK_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DCT_BLOCK_INL_H_ +#undef LIB_JXL_DCT_BLOCK_INL_H_ +#else +#define LIB_JXL_DCT_BLOCK_INL_H_ +#endif + +#include + +#include + +#include "lib/jxl/base/status.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Vec; + +// Block: (x, y) <-> (N * y + x) +// Lines: (x, y) <-> (stride * y + x) +// +// I.e. Block is a specialization of Lines with fixed stride. +// +// FromXXX should implement Read and Load (Read vector). +// ToXXX should implement Write and Store (Write vector). + +template +using BlockDesc = HWY_CAPPED(float, N); + +// Here and in the following, the SZ template parameter specifies the number of +// values to load/store. Needed because we want to handle 4x4 sub-blocks of +// 16x16 blocks. +class DCTFrom { + public: + DCTFrom(const float* data, size_t stride) : stride_(stride), data_(data) {} + + template + HWY_INLINE Vec LoadPart(D, const size_t row, size_t i) const { + JXL_DASSERT(Lanes(D()) <= stride_); + // Since these functions are used also for DC, no alignment at all is + // guaranteed in the case of floating blocks. + // TODO(veluca): consider using a different class for DC-to-LF and + // DC-from-LF, or copying DC values to/from a temporary aligned location. 
+ return LoadU(D(), Address(row, i)); + } + + HWY_INLINE float Read(const size_t row, const size_t i) const { + return *Address(row, i); + } + + constexpr HWY_INLINE const float* Address(const size_t row, + const size_t i) const { + return data_ + row * stride_ + i; + } + + size_t Stride() const { return stride_; } + + private: + size_t stride_; + const float* JXL_RESTRICT data_; +}; + +class DCTTo { + public: + DCTTo(float* data, size_t stride) : stride_(stride), data_(data) {} + + template + HWY_INLINE void StorePart(D, const Vec& v, const size_t row, + size_t i) const { + JXL_DASSERT(Lanes(D()) <= stride_); + // Since these functions are used also for DC, no alignment at all is + // guaranteed in the case of floating blocks. + // TODO(veluca): consider using a different class for DC-to-LF and + // DC-from-LF, or copying DC values to/from a temporary aligned location. + StoreU(v, D(), Address(row, i)); + } + + HWY_INLINE void Write(float v, const size_t row, const size_t i) const { + *Address(row, i) = v; + } + + constexpr HWY_INLINE float* Address(const size_t row, const size_t i) const { + return data_ + row * stride_ + i; + } + + size_t Stride() const { return stride_; } + + private: + size_t stride_; + float* JXL_RESTRICT data_; +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DCT_BLOCK_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dct_for_test.h b/third_party/jpeg-xl/lib/jxl/dct_for_test.h new file mode 100644 index 000000000000..0e4818546723 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_for_test.h @@ -0,0 +1,108 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_DCT_FOR_TEST_H_ +#define LIB_JXL_DCT_FOR_TEST_H_ + +// Unoptimized DCT only for use in tests. + +#include // memcpy + +#include +#include + +#include "lib/jxl/common.h" // Pi + +namespace jxl { + +namespace test { +static inline double alpha(int u) { return u == 0 ? 0.7071067811865475 : 1.0; } + +// N-DCT on M columns, divided by sqrt(N). Matches the definition in the spec. +template +void DCT1D(double block[N * M], double out[N * M]) { + std::vector matrix(N * N); + const double scale = std::sqrt(2.0) / N; + for (size_t y = 0; y < N; y++) { + for (size_t u = 0; u < N; u++) { + matrix[N * u + y] = alpha(u) * cos((y + 0.5) * u * Pi(1.0 / N)) * scale; + } + } + for (size_t x = 0; x < M; x++) { + for (size_t u = 0; u < N; u++) { + out[M * u + x] = 0; + for (size_t y = 0; y < N; y++) { + out[M * u + x] += matrix[N * u + y] * block[M * y + x]; + } + } + } +} + +// N-IDCT on M columns, multiplied by sqrt(N). Matches the definition in the +// spec. +template +void IDCT1D(double block[N * M], double out[N * M]) { + std::vector matrix(N * N); + const double scale = std::sqrt(2.0); + for (size_t y = 0; y < N; y++) { + for (size_t u = 0; u < N; u++) { + // Transpose of DCT matrix. 
+ matrix[N * y + u] = alpha(u) * cos((y + 0.5) * u * Pi(1.0 / N)) * scale; + } + } + for (size_t x = 0; x < M; x++) { + for (size_t u = 0; u < N; u++) { + out[M * u + x] = 0; + for (size_t y = 0; y < N; y++) { + out[M * u + x] += matrix[N * u + y] * block[M * y + x]; + } + } + } +} + +template +void TransposeBlock(double in[N * M], double out[M * N]) { + for (size_t x = 0; x < N; x++) { + for (size_t y = 0; y < M; y++) { + out[y * N + x] = in[x * M + y]; + } + } +} +} // namespace test + +// Untransposed DCT. +template +void DCTSlow(double block[N * N]) { + constexpr size_t kBlockSize = N * N; + std::vector g(kBlockSize); + test::DCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); + test::DCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); +} + +// Untransposed IDCT. +template +void IDCTSlow(double block[N * N]) { + constexpr size_t kBlockSize = N * N; + std::vector g(kBlockSize); + test::IDCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); + test::IDCT1D(block, g.data()); + test::TransposeBlock(g.data(), block); +} + +} // namespace jxl + +#endif // LIB_JXL_DCT_FOR_TEST_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dct_scales.cc b/third_party/jpeg-xl/lib/jxl/dct_scales.cc new file mode 100644 index 000000000000..d3f1fb6300ef --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_scales.cc @@ -0,0 +1,40 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/dct_scales.h" + +namespace jxl { + +// Definition of constexpr arrays. +constexpr float DCTResampleScales<1, 8>::kScales[]; +constexpr float DCTResampleScales<2, 16>::kScales[]; +constexpr float DCTResampleScales<4, 32>::kScales[]; +constexpr float DCTResampleScales<8, 64>::kScales[]; +constexpr float DCTResampleScales<16, 128>::kScales[]; +constexpr float DCTResampleScales<32, 256>::kScales[]; +constexpr float DCTResampleScales<8, 1>::kScales[]; +constexpr float DCTResampleScales<16, 2>::kScales[]; +constexpr float DCTResampleScales<32, 4>::kScales[]; +constexpr float DCTResampleScales<64, 8>::kScales[]; +constexpr float DCTResampleScales<128, 16>::kScales[]; +constexpr float DCTResampleScales<256, 32>::kScales[]; +constexpr float WcMultipliers<4>::kMultipliers[]; +constexpr float WcMultipliers<8>::kMultipliers[]; +constexpr float WcMultipliers<16>::kMultipliers[]; +constexpr float WcMultipliers<32>::kMultipliers[]; +constexpr float WcMultipliers<64>::kMultipliers[]; +constexpr float WcMultipliers<128>::kMultipliers[]; +constexpr float WcMultipliers<256>::kMultipliers[]; + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dct_scales.h b/third_party/jpeg-xl/lib/jxl/dct_scales.h new file mode 100644 index 000000000000..1e7efe7c9b01 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_scales.h @@ -0,0 +1,399 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DCT_SCALES_H_
+#define LIB_JXL_DCT_SCALES_H_
+
+// Scaling factors.
+
+#include <stddef.h>
+
+namespace jxl {
+
+template <size_t N>
+struct square_root {
+  static constexpr float value = square_root<N / 4>::value * 2;
+};
+
+template <>
+struct square_root<1> {
+  static constexpr float value = 1.0f;
+};
+
+template <>
+struct square_root<2> {
+  static constexpr float value = 1.4142135623730951f;
+};
+
+// For n != 0, the n-th basis function of a N-DCT, evaluated in pixel k, has a
+// value of cos((k+1/2) n/N pi). When downsampling by 2x, we average
+// the values for pixel k and k+1 to get the value for pixel (k/2), thus we get
+//
+// [cos((k+1/2) n/N pi) + cos((k+3/2) n/N pi)]/2 =
+// cos(n/(2N) pi) cos((k+1) n/N pi) =
+// cos(n/(2N) pi) cos(((k/2)+1/2) n/(N/2) pi)
+//
+// which is exactly the same as the value of pixel k/2 of a N/2-sized DCT,
+// except for the cos(n/(2N) pi) scaling factor (which does *not*
+// depend on the pixel). Thus, when using the lower-frequency coefficients of a
+// DCT-N to compute a DCT-(N/2), they should be scaled by this constant. Scaling
+// factors for a DCT-(N/4) etc can then be obtained by successive
+// multiplications. The structs below contain the above-mentioned scaling
+// factors.
+//
+// Python code for the tables below:
+//
+// for i in range(N // 8):
+//   v = math.cos(i / (2 * N) * math.pi)
+//   v *= math.cos(i / (N) * math.pi)
+//   v *= math.cos(i / (N / 2) * math.pi)
+//   print(v, end=", ")
+//
+// (Sanity check: i = 1, N = 16 gives cos(pi/32) * cos(pi/16) * cos(pi/8) =
+// 0.9017641..., matching kScales[1] of DCTResampleScales<16, 2> below.)
+
+template <size_t FROM, size_t TO>
+struct DCTResampleScales;
+
+template <>
+struct DCTResampleScales<8, 1> {
+  static constexpr float kScales[] = {
+      1.000000000000000000,
+  };
+};
+
+template <>
+struct DCTResampleScales<16, 2> {
+  static constexpr float kScales[] = {
+      1.000000000000000000,
+      0.901764195028874394,
+  };
+};
+
+template <>
+struct DCTResampleScales<32, 4> {
+  static constexpr float kScales[] = {
+      1.000000000000000000,
+      0.974886821136879522,
+      0.901764195028874394,
+      0.787054918159101335,
+  };
+};
+
+template <>
+struct DCTResampleScales<64, 8> {
+  static constexpr float kScales[] = {
+      1.0000000000000000, 0.9936866130906366, 0.9748868211368796,
+      0.9440180941651672, 0.9017641950288744, 0.8490574973847023,
+      0.7870549181591013, 0.7171081282466044,
+  };
+};
+
+template <>
+struct DCTResampleScales<128, 16> {
+  static constexpr float kScales[] = {
+      1.0,
+      0.9984194528776054,
+      0.9936866130906366,
+      0.9858278282666936,
+      0.9748868211368796,
+      0.9609244059440204,
+      0.9440180941651672,
+      0.9242615922757944,
+      0.9017641950288744,
+      0.8766500784429904,
+      0.8490574973847023,
+      0.8191378932865928,
+      0.7870549181591013,
+      0.7529833816270532,
+      0.7171081282466044,
+      0.6796228528314651,
+  };
+};
+
+template <>
+struct DCTResampleScales<256, 32> {
+  static constexpr float kScales[] = {
+      1.0,
+      0.9996047255830407,
+      0.9984194528776054,
+      0.9964458326264695,
+      0.9936866130906366,
+      0.9901456355893141,
+      0.9858278282666936,
+      0.9807391980963174,
+      0.9748868211368796,
+      0.9682788310563117,
+      0.9609244059440204,
+      0.9528337534340876,
+      0.9440180941651672,
+      0.9344896436056892,
+      0.9242615922757944,
+      0.913348084400198,
+      0.9017641950288744,
+      0.8895259056651056,
+      0.8766500784429904,
+      0.8631544288990163,
+      0.8490574973847023,
+      0.8343786191696513,
+      0.8191378932865928,
+      0.8033561501721485,
+      0.7870549181591013,
+      0.7702563888779096,
+      0.7529833816270532,
+      0.7352593067735488,
+      0.7171081282466044,
+      0.6985543251889097,
+      0.6796228528314651,
+      0.6603391026591464,
+  };
+};
+
+// Inverses of the above.
+template <>
+struct DCTResampleScales<1, 8> {
+  static constexpr float kScales[] = {
+      1.000000000000000000,
+  };
+};
+
+template <>
+struct DCTResampleScales<2, 16> {
+  static constexpr float kScales[] = {
+      1.000000000000000000,
+      1.108937353592731823,
+  };
+};
+
+template <>
+struct DCTResampleScales<4, 32> {
+  static constexpr float kScales[] = {
+      1.000000000000000000,
+      1.025760096781116015,
+      1.108937353592731823,
+      1.270559368765487251,
+  };
+};
+
+template <>
+struct DCTResampleScales<8, 64> {
+  static constexpr float kScales[] = {
+      1.0000000000000000, 1.0063534990068217, 1.0257600967811158,
+      1.0593017296817173, 1.1089373535927318, 1.1777765381970435,
+      1.2705593687654873, 1.3944898413647777,
+  };
+};
+
+template <>
+struct DCTResampleScales<16, 128> {
+  static constexpr float kScales[] = {
+      1.0,
+      1.0015830492062623,
+      1.0063534990068217,
+      1.0143759095928793,
+      1.0257600967811158,
+      1.0406645869480142,
+      1.0593017296817173,
+      1.0819447744633812,
+      1.1089373535927318,
+      1.1407059950032632,
+      1.1777765381970435,
+      1.2207956782315876,
+      1.2705593687654873,
+      1.3280505578213306,
+      1.3944898413647777,
+      1.4714043176061107,
+  };
+};
+
+template <>
+struct DCTResampleScales<32, 256> {
+  static constexpr float kScales[] = {
+      1.0,
+      1.0003954307206069,
+      1.0015830492062623,
+      1.0035668445360069,
+      1.0063534990068217,
+      1.009952439375063,
+      1.0143759095928793,
+      1.0196390660647288,
+      1.0257600967811158,
+      1.0327603660498115,
+      1.0406645869480142,
+      1.049501024072585,
+      1.0593017296817173,
+      1.0701028169146336,
+      1.0819447744633812,
+      1.0948728278734026,
+      1.1089373535927318,
+      1.124194353004584,
+      1.1407059950032632,
+      1.158541237256391,
+      1.1777765381970435,
+      1.1984966740820495,
+      1.2207956782315876,
+      1.244777922949508,
+      1.2705593687654873,
+      1.2982690107339132,
+      1.3280505578213306,
+      1.3600643892400104,
+      1.3944898413647777,
+      1.4315278911623237,
+      1.4714043176061107,
+      1.5143734423314616,
+  };
+};
+
+// Constants for DCT implementation. Generated by the following snippet:
+//
+// for i in range(N // 2):
+//   print(1.0 / (2 * math.cos((i + 0.5) * math.pi / N)), end=", ")
+//
+// (Sanity check: i = 0, N = 4 gives 1 / (2 * cos(pi/8)) = 0.5411961...,
+// matching kMultipliers[0] of WcMultipliers<4> below.)
+template <size_t N>
+struct WcMultipliers;
+
+template <>
+struct WcMultipliers<4> {
+  static constexpr float kMultipliers[] = {
+      0.541196100146197,
+      1.3065629648763764,
+  };
+};
+
+template <>
+struct WcMultipliers<8> {
+  static constexpr float kMultipliers[] = {
+      0.5097955791041592,
+      0.6013448869350453,
+      0.8999762231364156,
+      2.5629154477415055,
+  };
+};
+
+template <>
+struct WcMultipliers<16> {
+  static constexpr float kMultipliers[] = {
+      0.5024192861881557, 0.5224986149396889, 0.5669440348163577,
+      0.6468217833599901, 0.7881546234512502, 1.060677685990347,
+      1.7224470982383342, 5.101148618689155,
+  };
+};
+
+template <>
+struct WcMultipliers<32> {
+  static constexpr float kMultipliers[] = {
+      0.5006029982351963, 0.5054709598975436, 0.5154473099226246,
+      0.5310425910897841, 0.5531038960344445, 0.5829349682061339,
+      0.6225041230356648, 0.6748083414550057, 0.7445362710022986,
+      0.8393496454155268, 0.9725682378619608, 1.1694399334328847,
+      1.4841646163141662, 2.057781009953411,  3.407608418468719,
+      10.190008123548033,
+  };
+};
+template <>
+struct WcMultipliers<64> {
+  static constexpr float kMultipliers[] = {
+      0.500150636020651,  0.5013584524464084, 0.5037887256810443,
+      0.5074711720725553, 0.5124514794082247, 0.5187927131053328,
+      0.52657731515427,   0.535909816907992,  0.5469204379855088,
+      0.5597698129470802, 0.57465518403266,   0.5918185358574165,
+      0.6115573478825099, 0.6342389366884031, 0.6603198078137061,
+      0.6903721282002123, 0.7251205223771985, 0.7654941649730891,
+      0.8127020908144905, 0.8683447152233481, 0.9345835970364075,
+      1.0144082649970547, 1.1120716205797176, 1.233832737976571,
+      1.3892939586328277, 1.5939722833856311, 1.8746759800084078,
+      2.282050068005162,  2.924628428158216,  4.084611078129248,
+      6.796750711673633,  20.373878167231453,
+  };
+};
+template <>
+struct WcMultipliers<128> {
+  static constexpr float kMultipliers[] = {
+      0.5000376519155477, 0.5003390374428216, 0.5009427176380873,
+      0.5018505174842379, 0.5030651913013697, 0.5045904432216454,
+      0.5064309549285542, 0.5085924210498143, 0.5110815927066812,
+      0.5139063298475396, 0.5170756631334912, 0.5205998663018917,
+      0.524490540114724,  0.5287607092074876, 0.5334249333971333,
+      0.538499435291984,  0.5440022463817783, 0.549953374183236,
+      0.5563749934898856, 0.5632916653417023, 0.5707305880121454,
+      0.5787218851348208, 0.5872989370937893, 0.5964987630244563,
+      0.606362462272146,  0.6169357260050706, 0.6282694319707711,
+      0.6404203382416639, 0.6534518953751283, 0.6674352009263413,
+      0.6824501259764195, 0.6985866506472291, 0.7159464549705746,
+      0.7346448236478627, 0.7548129391165311, 0.776600658233963,
+      0.8001798956216941, 0.8257487738627852, 0.8535367510066064,
+      0.8838110045596234, 0.9168844461846523, 0.9531258743921193,
+      0.9929729612675466, 1.036949040910389,  1.0856850642580145,
+      1.1399486751015042, 1.2006832557294167, 1.2690611716991191,
+      1.346557628206286,  1.4350550884414341, 1.5369941008524954,
+      1.6555965242641195, 1.7952052190778898, 1.961817848571166,
+      2.163957818751979,  2.4141600002500763, 2.7316450287739396,
+      3.147462191781909,  3.7152427383269746, 4.5362909369693565,
+      5.827688377844654,  8.153848602466814,  13.58429025728446,
+      40.744688103351834,
+  };
+};
+
+template <>
+struct WcMultipliers<256> {
+  static constexpr float kMultipliers[128] = {
+      0.5000094125358878, 0.500084723455784,  0.5002354020255269,
+      0.5004615618093246, 0.5007633734146156, 0.5011410648064231,
+      0.5015949217281668, 0.502125288230386,  0.5027325673091954,
+      0.5034172216566842, 0.5041797745258774, 0.5050208107132756,
+      0.5059409776624396, 0.5069409866925212, 0.5080216143561264,
+      0.509183703931388,  0.5104281670536573, 0.5117559854927805,
+      0.5131682130825206, 0.5146659778093218, 0.516250484068288,
+      0.5179230150949777, 0.5196849355823947, 0.5215376944933958,
+      0.5234828280796439, 0.52552196311921,   0.5276568203859896,
+      0.5298892183652453, 0.5322210772308335, 0.5346544231010253,
+      0.537191392591309,  0.5398342376841637, 0.5425853309375497,
+      0.545447171055775,  0.5484223888484947, 0.551513753605893,
+      0.554724179920619,  0.5580567349898085, 0.5615146464335654,
+      0.5651013106696203, 0.5688203018875696, 0.5726753816701664,
+      0.5766705093136241, 0.5808098529038624, 0.5850978012111273,
+      0.58953897647151,   0.5941382481306648, 0.5989007476325463,
+      0.6038318843443582, 0.6089373627182432, 0.614223200800649,
+      0.6196957502119484, 0.6253617177319102, 0.6312281886412079,
+      0.6373026519855411, 0.6435930279473415, 0.6501076975307724,
+      0.6568555347890955, 0.6638459418498757, 0.6710888870233562,
+      0.6785949463131795, 0.6863753486870501, 0.6944420255086364,
+      0.7028076645818034, 0.7114857693151208, 0.7204907235796304,
+      0.7298378629074134, 0.7395435527641373, 0.749625274727372,
+      0.7601017215162176, 0.7709929019493761, 0.7823202570613161,
+      0.7941067887834509, 0.8063772028037925, 0.8191580674598145,
+      0.83247799080191,   0.8463678182968619, 0.860860854031955,
+      0.8759931087426972, 0.8918035785352535, 0.9083345588266809,
+      0.9256319988042384, 0.9437459026371479, 0.962730784794803,
+      0.9826461881778968, 1.0035572754078206, 1.0255355056139732,
+      1.048659411496106,  1.0730154944316674, 1.0986992590905857,
+      1.1258164135986009, 1.1544842669978943, 1.184833362908442,
+      1.217009397314603,  1.2511754798461228, 1.287514812536712,
+      1.326233878832723,  1.3675662599582539, 1.411777227500661,
+      1.459169302866857,  1.5100890297227016, 1.5649352798258847,
+      1.6241695131835794, 1.6883285509131505, 1.7580406092704062,
+      1.8340456094306077, 1.9172211551275689, 2.0086161135167564,
+      2.1094945286246385, 2.22139377701127,   2.346202662531156,
+      2.486267909203593,  2.644541877144861,  2.824791402350551,
+      3.0318994541759925, 3.2723115884254845, 3.5547153325075804,
+      3.891107790700307,  4.298537526449054,  4.802076008665048,
+      5.440166215091329,  6.274908408039339,  7.413566756422303,
+      9.058751453879703,  11.644627325175037, 16.300023088031555,
+      27.163977662448232, 81.48784219222516,
+  };
+};
+
+// Apply the DCT algorithm-intrinsic constants to DCTResampleScale.
+template <size_t FROM, size_t TO>
+constexpr float DCTTotalResampleScale(size_t x) {
+  return DCTResampleScales<FROM, TO>::kScales[x];
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DCT_SCALES_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dct_test.cc b/third_party/jpeg-xl/lib/jxl/dct_test.cc
new file mode 100644
index 000000000000..4b974baef090
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dct_test.cc
@@ -0,0 +1,399 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string.h>
+
+#include <cmath>
+#include <numeric>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dct_test.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+#include <hwy/tests/test_util-inl.h>
+
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct-inl.h"
+#include "lib/jxl/dct_for_test.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/test_utils.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// Computes the in-place NxN DCT of block.
+// Requires that block is HWY_ALIGN'ed.
+//
+// Performs ComputeTransposedScaledDCT and then transposes and scales it to
+// obtain "vanilla" DCT.
+template <size_t N>
+void ComputeDCT(float block[N * N]) {
+  HWY_ALIGN float tmp_block[N * N];
+  HWY_ALIGN float scratch_space[N * N];
+  ComputeTransposedScaledDCT<N>()(DCTFrom(block, N), tmp_block, scratch_space);
+
+  // Untranspose.
+  Transpose<N, N>::Run(DCTFrom(tmp_block, N), DCTTo(block, N));
+}
+
+// Computes the in-place NxN iDCT of block.
+// Requires that block is HWY_ALIGN'ed.
+template <size_t N>
+void ComputeIDCT(float block[N * N]) {
+  HWY_ALIGN float tmp_block[N * N];
+  HWY_ALIGN float scratch_space[N * N];
+  // Untranspose.
+  Transpose<N, N>::Run(DCTFrom(block, N), DCTTo(tmp_block, N));
+
+  ComputeTransposedScaledIDCT<N>()(tmp_block, DCTTo(block, N), scratch_space);
+}
+
+template <size_t N>
+void TransposeTestT(float accuracy) {
+  constexpr size_t kBlockSize = N * N;
+  HWY_ALIGN float src[kBlockSize];
+  DCTTo to_src(src, N);
+  for (size_t y = 0; y < N; ++y) {
+    for (size_t x = 0; x < N; ++x) {
+      to_src.Write(y * N + x, y, x);
+    }
+  }
+  HWY_ALIGN float dst[kBlockSize];
+  Transpose<N, N>::Run(DCTFrom(src, N), DCTTo(dst, N));
+  DCTFrom from_dst(dst, N);
+  for (size_t y = 0; y < N; ++y) {
+    for (size_t x = 0; x < N; ++x) {
+      float expected = x * N + y;
+      float actual = from_dst.Read(y, x);
+      EXPECT_NEAR(expected, actual, accuracy) << "x = " << x << ", y = " << y;
+    }
+  }
+}
+
+void TransposeTest() {
+  TransposeTestT<8>(1e-7f);
+  TransposeTestT<16>(1e-7f);
+  TransposeTestT<32>(1e-7f);
+}
+
+template <size_t N>
+void ColumnDctRoundtripT(float accuracy) {
+  constexpr size_t kBlockSize = N * N;
+  // Though we are only interested in single column result, dct.h has built-in
+  // limit on minimal number of columns processed. So, to be safe, we do
+  // regular 8x8 block transformation. On the bright side - we could check all
+  // 8 basis vectors at once.
+  HWY_ALIGN float block[kBlockSize];
+  DCTTo to(block, N);
+  DCTFrom from(block, N);
+  for (size_t i = 0; i < N; ++i) {
+    for (size_t j = 0; j < N; ++j) {
+      to.Write((i == j) ? 1.0f : 0.0f, i, j);
+    }
+  }
+
+  // Running (I)DCT on the same memory block seems to trigger a compiler bug on
+  // ARMv7 with clang6.
+  HWY_ALIGN float tmp[kBlockSize];
+  DCTTo to_tmp(tmp, N);
+  DCTFrom from_tmp(tmp, N);
+
+  DCT1D<N, N>()(from, to_tmp);
+  IDCT1D<N, N>()(from_tmp, to);
+
+  for (size_t i = 0; i < N; ++i) {
+    for (size_t j = 0; j < N; ++j) {
+      float expected = (i == j) ? 1.0f : 0.0f;
+      float actual = from.Read(i, j);
+      EXPECT_NEAR(expected, actual, accuracy) << " i=" << i << ", j=" << j;
+    }
+  }
+}
+
+void ColumnDctRoundtrip() {
+  ColumnDctRoundtripT<8>(1e-6f);
+  ColumnDctRoundtripT<16>(1e-6f);
+  ColumnDctRoundtripT<32>(1e-6f);
+}
+
+template <size_t N>
+void TestDctAccuracy(float accuracy, size_t start = 0, size_t end = N * N) {
+  constexpr size_t kBlockSize = N * N;
+  for (size_t i = start; i < end; i++) {
+    HWY_ALIGN float fast[kBlockSize] = {0.0f};
+    double slow[kBlockSize] = {0.0};
+    fast[i] = 1.0;
+    slow[i] = 1.0;
+    DCTSlow<N>(slow);
+    ComputeDCT<N>(fast);
+    for (size_t k = 0; k < kBlockSize; ++k) {
+      EXPECT_NEAR(fast[k], slow[k], accuracy / N)
+          << "i = " << i << ", k = " << k << ", N = " << N;
+    }
+  }
+}
+
+template <size_t N>
+void TestIdctAccuracy(float accuracy, size_t start = 0, size_t end = N * N) {
+  constexpr size_t kBlockSize = N * N;
+  for (size_t i = start; i < end; i++) {
+    HWY_ALIGN float fast[kBlockSize] = {0.0f};
+    double slow[kBlockSize] = {0.0};
+    fast[i] = 1.0;
+    slow[i] = 1.0;
+    IDCTSlow<N>(slow);
+    ComputeIDCT<N>(fast);
+    for (size_t k = 0; k < kBlockSize; ++k) {
+      EXPECT_NEAR(fast[k], slow[k], accuracy * N)
+          << "i = " << i << ", k = " << k << ", N = " << N;
+    }
+  }
+}
+
+template <size_t N>
+void TestInverseT(float accuracy) {
+  ThreadPoolInternal pool(N < 32 ? 0 : 8);
+  enum { kBlockSize = N * N };
+  RunOnPool(
+      &pool, 0, kBlockSize, ThreadPool::SkipInit(),
+      [accuracy](const int task, int /*thread*/) {
+        const size_t i = static_cast<size_t>(task);
+        HWY_ALIGN float x[kBlockSize] = {0.0f};
+        x[i] = 1.0;
+
+        ComputeIDCT<N>(x);
+        ComputeDCT<N>(x);
+
+        for (size_t k = 0; k < kBlockSize; ++k) {
+          EXPECT_NEAR(x[k], (k == i) ? 1.0f : 0.0f, accuracy)
+              << "i = " << i << ", k = " << k;
+        }
+      },
+      "TestInverse");
+}
+
+void InverseTest() {
+  TestInverseT<8>(1e-6f);
+  TestInverseT<16>(1e-6f);
+  TestInverseT<32>(3e-6f);
+}
+
+template <size_t N>
+void TestDctTranspose(float accuracy, size_t start = 0, size_t end = N * N) {
+  constexpr size_t kBlockSize = N * N;
+  for (size_t i = start; i < end; i++) {
+    for (size_t j = 0; j < kBlockSize; ++j) {
+      // We check that <e_i, Me_j> = <M^\dagger{}e_i, e_j>.
+      // That means (Me_j)_i = (M^\dagger{}e_i)_j
+
+      // x := Me_j
+      HWY_ALIGN float x[kBlockSize] = {0.0f};
+      x[j] = 1.0;
+      ComputeIDCT<N>(x);
+      // y := M^\dagger{}e_i
+      HWY_ALIGN float y[kBlockSize] = {0.0f};
+      y[i] = 1.0;
+      ComputeDCT<N>(y);
+
+      EXPECT_NEAR(x[i] / N, y[j] * N, accuracy) << "i = " << i << ", j = " << j;
+    }
+  }
+}
+
+template <size_t N>
+void TestSlowInverse(float accuracy, size_t start = 0, size_t end = N * N) {
+  constexpr size_t kBlockSize = N * N;
+  for (size_t i = start; i < end; i++) {
+    double x[kBlockSize] = {0.0f};
+    x[i] = 1.0;
+
+    DCTSlow<N>(x);
+    IDCTSlow<N>(x);
+
+    for (size_t k = 0; k < kBlockSize; ++k) {
+      EXPECT_NEAR(x[k], (k == i) ? 1.0f : 0.0f, accuracy)
+          << "i = " << i << ", k = " << k;
+    }
+  }
+}
+
+template <size_t ROWS, size_t COLS>
+void TestRectInverseT(float accuracy) {
+  constexpr size_t kBlockSize = ROWS * COLS;
+  for (size_t i = 0; i < kBlockSize; ++i) {
+    HWY_ALIGN float x[kBlockSize] = {0.0f};
+    HWY_ALIGN float out[kBlockSize] = {0.0f};
+    x[i] = 1.0;
+    HWY_ALIGN float coeffs[kBlockSize] = {0.0f};
+    HWY_ALIGN float scratch_space[kBlockSize * 2];
+
+    ComputeScaledDCT<ROWS, COLS>()(DCTFrom(x, COLS), coeffs, scratch_space);
+    ComputeScaledIDCT<ROWS, COLS>()(coeffs, DCTTo(out, COLS), scratch_space);
+
+    for (size_t k = 0; k < kBlockSize; ++k) {
+      EXPECT_NEAR(out[k], (k == i) ? 1.0f : 0.0f, accuracy)
+          << "i = " << i << ", k = " << k << " ROWS = " << ROWS
+          << " COLS = " << COLS;
+    }
+  }
+}
+
+void TestRectInverse() {
+  TestRectInverseT<16, 32>(1e-6f);
+  TestRectInverseT<8, 32>(1e-6f);
+  TestRectInverseT<8, 16>(1e-6f);
+  TestRectInverseT<4, 8>(1e-6f);
+  TestRectInverseT<2, 4>(1e-6f);
+  TestRectInverseT<1, 4>(1e-6f);
+  TestRectInverseT<1, 2>(1e-6f);
+
+  TestRectInverseT<32, 16>(1e-6f);
+  TestRectInverseT<32, 8>(1e-6f);
+  TestRectInverseT<16, 8>(1e-6f);
+  TestRectInverseT<8, 4>(1e-6f);
+  TestRectInverseT<4, 2>(1e-6f);
+  TestRectInverseT<4, 1>(1e-6f);
+  TestRectInverseT<2, 1>(1e-6f);
+}
+
+template <size_t ROWS, size_t COLS>
+void TestRectTransposeT(float accuracy) {
+  constexpr size_t kBlockSize = ROWS * COLS;
+  HWY_ALIGN float scratch_space[kBlockSize * 2];
+  for (size_t px = 0; px < COLS; ++px) {
+    for (size_t py = 0; py < ROWS; ++py) {
+      HWY_ALIGN float x1[kBlockSize] = {0.0f};
+      HWY_ALIGN float x2[kBlockSize] = {0.0f};
+      HWY_ALIGN float coeffs1[kBlockSize] = {0.0f};
+      HWY_ALIGN float coeffs2[kBlockSize] = {0.0f};
+      x1[py * COLS + px] = 1;
+      x2[px * ROWS + py] = 1;
+
+      constexpr size_t OUT_ROWS = ROWS < COLS ? ROWS : COLS;
+      constexpr size_t OUT_COLS = ROWS < COLS ? COLS : ROWS;
+
+      ComputeScaledDCT<ROWS, COLS>()(DCTFrom(x1, COLS), coeffs1,
+                                     scratch_space);
+      ComputeScaledDCT<COLS, ROWS>()(DCTFrom(x2, ROWS), coeffs2,
+                                     scratch_space);
+
+      for (size_t x = 0; x < OUT_COLS; ++x) {
+        for (size_t y = 0; y < OUT_ROWS; ++y) {
+          EXPECT_NEAR(coeffs1[y * OUT_COLS + x], coeffs2[y * OUT_COLS + x],
+                      accuracy)
+              << " px = " << px << ", py = " << py << ", x = " << x
+              << ", y = " << y;
+        }
+      }
+    }
+  }
+}
+
+void TestRectTranspose() {
+  TestRectTransposeT<16, 32>(1e-6f);
+  TestRectTransposeT<8, 32>(1e-6f);
+  TestRectTransposeT<8, 16>(1e-6f);
+  TestRectTransposeT<4, 8>(1e-6f);
+  TestRectTransposeT<2, 4>(1e-6f);
+  TestRectTransposeT<1, 4>(1e-6f);
+  TestRectTransposeT<1, 2>(1e-6f);
+
+  // Identical to 8, 16
+  // TestRectTranspose<16, 8>(1e-6f);
+}
+
+void TestDctAccuracyShard(size_t shard) {
+  if (shard == 0) {
+    TestDctAccuracy<1>(1.1E-7f);
+    TestDctAccuracy<2>(1.1E-7f);
+    TestDctAccuracy<4>(1.1E-7f);
+    TestDctAccuracy<8>(1.1E-7f);
+    TestDctAccuracy<16>(1.3E-7f);
+  }
+  TestDctAccuracy<32>(1.1E-7f, 32 * shard, 32 * (shard + 1));
+}
+
+void TestIdctAccuracyShard(size_t shard) {
+  if (shard == 0) {
+    TestIdctAccuracy<1>(1E-7f);
+    TestIdctAccuracy<2>(1E-7f);
+    TestIdctAccuracy<4>(1E-7f);
+    TestIdctAccuracy<8>(1E-7f);
+    TestIdctAccuracy<16>(1E-7f);
+  }
+  TestIdctAccuracy<32>(1E-7f, 32 * shard, 32 * (shard + 1));
+}
+
+void TestDctTransposeShard(size_t shard) {
+  if (shard == 0) {
+    TestDctTranspose<8>(1E-6f);
+    TestDctTranspose<16>(1E-6f);
+  }
+  TestDctTranspose<32>(3E-6f, 32 * shard, 32 * (shard + 1));
+}
+
+void TestSlowInverseShard(size_t shard) {
+  if (shard == 0) {
+    TestSlowInverse<1>(1E-5f);
+    TestSlowInverse<2>(1E-5f);
+    TestSlowInverse<4>(1E-5f);
+    TestSlowInverse<8>(1E-5f);
+    TestSlowInverse<16>(1E-5f);
+  }
+  TestSlowInverse<32>(1E-5f, 32 * shard, 32 * (shard + 1));
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+class TransposeTest : public hwy::TestWithParamTarget {};
+
+HWY_TARGET_INSTANTIATE_TEST_SUITE_P(TransposeTest);
+
+HWY_EXPORT_AND_TEST_P(TransposeTest, TransposeTest);
+HWY_EXPORT_AND_TEST_P(TransposeTest, InverseTest);
+HWY_EXPORT_AND_TEST_P(TransposeTest, ColumnDctRoundtrip);
+HWY_EXPORT_AND_TEST_P(TransposeTest, TestRectInverse);
+HWY_EXPORT_AND_TEST_P(TransposeTest, TestRectTranspose);
+
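+// Each HWY_EXPORT_AND_TEST_P above runs the exported function once per SIMD
+// target compiled in by highway, so the DCT round-trips are exercised on
+// every vector backend.
+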
+// Tests in the DctShardedTest class are sharded for N=32.
+class DctShardedTest : public ::hwy::TestWithParamTargetAndT<uint32_t> {};
+
+std::vector<uint32_t> ShardRange(uint32_t n) {
+#ifdef JXL_DISABLE_SLOW_TESTS
+  JXL_ASSERT(n > 6);
+  std::vector<uint32_t> ret = {0, 1, 3, 5, n - 1};
+#else
+  std::vector<uint32_t> ret(n);
+  std::iota(ret.begin(), ret.end(), 0);
+#endif  // JXL_DISABLE_SLOW_TESTS
+  return ret;
+}
+
+HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(DctShardedTest,
+                                      ::testing::ValuesIn(ShardRange(32)));
+
+HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestDctAccuracyShard);
+HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestIdctAccuracyShard);
+HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestDctTransposeShard);
+HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestSlowInverseShard);
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/dct_util.h b/third_party/jpeg-xl/lib/jxl/dct_util.h
new file mode 100644
index 000000000000..21c471c13354
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dct_util.h
@@ -0,0 +1,95 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DCT_UTIL_H_
+#define LIB_JXL_DCT_UTIL_H_
+
+#include <stdint.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+
+union ACPtr {
+  int32_t* ptr32;
+  int16_t* ptr16;
+  ACPtr() = default;
+  explicit ACPtr(int16_t* p) : ptr16(p) {}
+  explicit ACPtr(int32_t* p) : ptr32(p) {}
+};
+
+union ConstACPtr {
+  const int32_t* ptr32;
+  const int16_t* ptr16;
+  ConstACPtr() = default;
+  explicit ConstACPtr(const int16_t* p) : ptr16(p) {}
+  explicit ConstACPtr(const int32_t* p) : ptr32(p) {}
+};
+
+enum class ACType { k16 = 0, k32 = 1 };
+
+class ACImage {
+ public:
+  virtual ~ACImage() = default;
+  virtual ACType Type() const = 0;
+  virtual ACPtr PlaneRow(size_t c, size_t y, size_t xbase) = 0;
+  virtual ConstACPtr PlaneRow(size_t c, size_t y, size_t xbase) const = 0;
+  virtual size_t PixelsPerRow() const = 0;
+  virtual void ZeroFill() = 0;
+  virtual void ZeroFillPlane(size_t c) = 0;
+  virtual bool IsEmpty() const = 0;
+};
+
+template <typename T>
+class ACImageT final : public ACImage {
+ public:
+  ACImageT() = default;
+  ACImageT(size_t xsize, size_t ysize) {
+    static_assert(
+        std::is_same<T, int32_t>::value || std::is_same<T, int16_t>::value,
+        "ACImage must be either 32- or 16- bit");
+    img_ = Image3<T>(xsize, ysize);
+  }
+  ACType Type() const override {
+    return sizeof(T) == 2 ? ACType::k16 : ACType::k32;
+  }
+  ACPtr PlaneRow(size_t c, size_t y, size_t xbase) override {
+    return ACPtr(img_.PlaneRow(c, y) + xbase);
+  }
+  ConstACPtr PlaneRow(size_t c, size_t y, size_t xbase) const override {
+    return ConstACPtr(img_.PlaneRow(c, y) + xbase);
+  }
+
+  size_t PixelsPerRow() const override { return img_.PixelsPerRow(); }
+
+  void ZeroFill() override { ZeroFillImage(&img_); }
+
+  void ZeroFillPlane(size_t c) override { ZeroFillImage(&img_.Plane(c)); }
+
+  bool IsEmpty() const override {
+    return img_.xsize() == 0 || img_.ysize() == 0;
+  }
+
+ private:
+  Image3<T> img_;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DCT_UTIL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_ans.cc b/third_party/jpeg-xl/lib/jxl/dec_ans.cc
new file mode 100644
index 000000000000..f8f2aa507971
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_ans.cc
@@ -0,0 +1,383 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_ans.h"
+
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/ans_common.h"
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_context_map.h"
+#include "lib/jxl/fields.h"
+
+namespace jxl {
+namespace {
+
+// Decodes a number in the range [0..255], by reading 1 - 11 bits.
+inline int DecodeVarLenUint8(BitReader* input) {
+  if (input->ReadFixedBits<1>()) {
+    int nbits = static_cast<int>(input->ReadFixedBits<3>());
+    if (nbits == 0) {
+      return 1;
+    } else {
+      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
+    }
+  }
+  return 0;
+}
+
+// Decodes a number in the range [0..65535], by reading 1 - 20 bits.
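+// Example: a leading 0 bit decodes to 0; a 1 bit followed by nbits == 0
+// decodes to 1; a 1 bit, nbits == 4 and payload 0101 decodes to
+// 5 + (1 << 4) = 21, consuming 1 + 4 + 4 = 9 bits in total.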
+inline int DecodeVarLenUint16(BitReader* input) {
+  if (input->ReadFixedBits<1>()) {
+    int nbits = static_cast<int>(input->ReadFixedBits<4>());
+    if (nbits == 0) {
+      return 1;
+    } else {
+      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
+    }
+  }
+  return 0;
+}
+
+Status ReadHistogram(int precision_bits, std::vector<ANSHistBin>* counts,
+                     BitReader* input) {
+  int simple_code = input->ReadBits(1);
+  if (simple_code == 1) {
+    int i;
+    int symbols[2] = {0};
+    int max_symbol = 0;
+    const int num_symbols = input->ReadBits(1) + 1;
+    for (i = 0; i < num_symbols; ++i) {
+      symbols[i] = DecodeVarLenUint8(input);
+      if (symbols[i] > max_symbol) max_symbol = symbols[i];
+    }
+    counts->resize(max_symbol + 1);
+    if (num_symbols == 1) {
+      (*counts)[symbols[0]] = 1 << precision_bits;
+    } else {
+      if (symbols[0] == symbols[1]) {  // corrupt data
+        return false;
+      }
+      (*counts)[symbols[0]] = input->ReadBits(precision_bits);
+      (*counts)[symbols[1]] = (1 << precision_bits) - (*counts)[symbols[0]];
+    }
+  } else {
+    int is_flat = input->ReadBits(1);
+    if (is_flat == 1) {
+      int alphabet_size = DecodeVarLenUint8(input) + 1;
+      if (alphabet_size == 0) {
+        return JXL_FAILURE("Invalid alphabet size for flat histogram.");
+      }
+      *counts = CreateFlatHistogram(alphabet_size, 1 << precision_bits);
+      return true;
+    }
+
+    uint32_t shift;
+    {
+      // TODO(veluca): speed up reading with table lookups.
+      int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
+      int log = 0;
+      for (; log < upper_bound_log; log++) {
+        if (input->ReadFixedBits<1>() == 0) break;
+      }
+      shift = (input->ReadBits(log) | (1 << log)) - 1;
+      if (shift > ANS_LOG_TAB_SIZE + 1) {
+        return JXL_FAILURE("Invalid shift value");
+      }
+    }
+
+    int length = DecodeVarLenUint8(input) + 3;
+    counts->resize(length);
+    int total_count = 0;
+
+    static const uint8_t huff[128][2] = {
+        {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+        {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
+        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
+    };
+
+    std::vector<int> logcounts(counts->size());
+    int omit_log = -1;
+    int omit_pos = -1;
+    // This array remembers which symbols have an RLE length.
+    std::vector<int> same(counts->size(), 0);
+    for (size_t i = 0; i < logcounts.size(); ++i) {
+      input->Refill();  // for PeekFixedBits + Advance
+      int idx = input->PeekFixedBits<7>();
+      input->Consume(huff[idx][0]);
+      logcounts[i] = huff[idx][1];
+      // The RLE symbol.
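+      // (A decoded value of ANS_LOG_TAB_SIZE + 1 is not a log count: it
+      // marks a run, and the following VarLenUint8 makes the next
+      // rle_length + 4 entries repeat the previous count with no further
+      // data in the stream.)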
+      if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) {
+        int rle_length = DecodeVarLenUint8(input);
+        same[i] = rle_length + 5;
+        i += rle_length + 3;
+        continue;
+      }
+      if (logcounts[i] > omit_log) {
+        omit_log = logcounts[i];
+        omit_pos = i;
+      }
+    }
+    // Invalid input, e.g. due to invalid usage of RLE.
+    if (omit_pos < 0) return JXL_FAILURE("Invalid histogram.");
+    if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() &&
+        logcounts[omit_pos + 1] == ANS_LOG_TAB_SIZE + 1) {
+      return JXL_FAILURE("Invalid histogram.");
+    }
+    int prev = 0;
+    int numsame = 0;
+    for (size_t i = 0; i < logcounts.size(); ++i) {
+      if (same[i]) {
+        // RLE sequence, let this loop output the same count for the next
+        // iterations.
+        numsame = same[i] - 1;
+        prev = i > 0 ? (*counts)[i - 1] : 0;
+      }
+      if (numsame > 0) {
+        (*counts)[i] = prev;
+        numsame--;
+      } else {
+        int code = logcounts[i];
+        // omit_pos may not be negative at this point (checked before).
+        if (i == static_cast<size_t>(omit_pos)) {
+          continue;
+        } else if (code == 0) {
+          continue;
+        } else if (code == 1) {
+          (*counts)[i] = 1;
+        } else {
+          int bitcount = GetPopulationCountPrecision(code - 1, shift);
+          (*counts)[i] = (1 << (code - 1)) +
+                         (input->ReadBits(bitcount) << (code - 1 - bitcount));
+        }
+      }
+      total_count += (*counts)[i];
+    }
+    (*counts)[omit_pos] = (1 << precision_bits) - total_count;
+    if ((*counts)[omit_pos] <= 0) {
+      // The histogram we've read sums to more than total_count (including at
+      // least 1 for the omitted value).
+      return JXL_FAILURE("Invalid histogram count.");
+    }
+  }
+  return true;
+}
+
+}  // namespace
+
+Status DecodeANSCodes(const size_t num_histograms,
+                      const size_t max_alphabet_size, BitReader* in,
+                      ANSCode* result) {
+  result->degenerate_symbols.resize(num_histograms, -1);
+  if (result->use_prefix_code) {
+    JXL_ASSERT(max_alphabet_size <= 1 << PREFIX_MAX_BITS);
+    result->huffman_data.resize(num_histograms);
+    std::vector<uint32_t> alphabet_sizes(num_histograms);
+    for (size_t c = 0; c < num_histograms; c++) {
+      alphabet_sizes[c] = DecodeVarLenUint16(in) + 1;
+      if (alphabet_sizes[c] > max_alphabet_size) {
+        return JXL_FAILURE("Alphabet size is too large: %u",
+                           alphabet_sizes[c]);
+      }
+    }
+    for (size_t c = 0; c < num_histograms; c++) {
+      if (alphabet_sizes[c] > 1) {
+        if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) {
+          if (!in->AllReadsWithinBounds()) {
+            return JXL_STATUS(StatusCode::kNotEnoughBytes,
+                              "Not enough bytes for huffman code");
+          }
+          return JXL_FAILURE(
+              "Invalid huffman tree number %zu, alphabet size %u", c,
+              alphabet_sizes[c]);
+        }
+      } else {
+        // 0-bit codes do not require extension tables.
+        result->huffman_data[c].table_.resize(1u << kHuffmanTableBits);
+      }
+      for (const auto& h : result->huffman_data[c].table_) {
+        if (h.bits <= kHuffmanTableBits) {
+          result->UpdateMaxNumBits(c, h.value);
+        }
+      }
+    }
+  } else {
+    JXL_ASSERT(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE);
+    result->alias_tables =
+        AllocateArray(num_histograms * (1 << result->log_alpha_size) *
+                      sizeof(AliasTable::Entry));
+    AliasTable::Entry* alias_tables =
+        reinterpret_cast<AliasTable::Entry*>(result->alias_tables.get());
+    for (size_t c = 0; c < num_histograms; ++c) {
+      std::vector<ANSHistBin> counts;
+      if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) {
+        return JXL_FAILURE("Invalid histogram bitstream.");
+      }
+      if (counts.size() > max_alphabet_size) {
+        return JXL_FAILURE("Alphabet size is too large: %zu", counts.size());
+      }
+      while (!counts.empty() && counts.back() == 0) {
+        counts.pop_back();
+      }
+      for (size_t s = 0; s < counts.size(); s++) {
+        if (counts[s] != 0) {
+          result->UpdateMaxNumBits(c, s);
+        }
+      }
+      // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol.
+      int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1);
+      for (int s = 0; s < degenerate_symbol; ++s) {
+        if (counts[s] != 0) {
+          degenerate_symbol = -1;
+          break;
+        }
+      }
+      result->degenerate_symbols[c] = degenerate_symbol;
+      InitAliasTable(counts, ANS_TAB_SIZE, result->log_alpha_size,
+                     alias_tables + c * (1 << result->log_alpha_size));
+    }
+  }
+  return true;
+}
+
+Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config,
+                        BitReader* br) {
+  br->Refill();
+  size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1));
+  size_t msb_in_token = 0, lsb_in_token = 0;
+  if (split_exponent != log_alpha_size) {
+    // otherwise, msb/lsb don't matter.
+    size_t nbits = CeilLog2Nonzero(split_exponent + 1);
+    msb_in_token = br->ReadBits(nbits);
+    if (msb_in_token > split_exponent) {
+      // This could be invalid here already and we need to check this before
+      // we use its value to read more bits.
+      return JXL_FAILURE("Invalid HybridUintConfig");
+    }
+    nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1);
+    lsb_in_token = br->ReadBits(nbits);
+  }
+  if (lsb_in_token + msb_in_token > split_exponent) {
+    return JXL_FAILURE("Invalid HybridUintConfig");
+  }
+  *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token);
+  return true;
+}
+
+Status DecodeUintConfigs(size_t log_alpha_size,
+                         std::vector<HybridUintConfig>* uint_config,
+                         BitReader* br) {
+  // TODO(veluca): RLE?
+  for (size_t i = 0; i < uint_config->size(); i++) {
+    JXL_RETURN_IF_ERROR(
+        DecodeUintConfig(log_alpha_size, &(*uint_config)[i], br));
+  }
+  return true;
+}
+
+LZ77Params::LZ77Params() { Bundle::Init(this); }
+Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled));
+  if (!visitor->Conditional(enabled)) return true;
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096),
+                                         BitsOffset(15, 8), 224, &min_symbol));
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5),
+                                         BitsOffset(8, 9), 3, &min_length));
+  return true;
+}
+
+void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) {
+  HybridUintConfig* cfg = &uint_config[ctx];
+  // LZ77 symbols use a different uint config.
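+  // (With the default config {split_exponent 4, msb 2, lsb 0}, e.g. symbol 63
+  // gives n_extra_bits = 2 + ((63 - 16) >> 2) = 13 below, for a conservative
+  // total of 2 + 0 + 13 + 1 = 16 bits.)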
+  if (lz77.enabled && lz77.nonserialized_distance_context != ctx &&
+      symbol >= lz77.min_symbol) {
+    symbol -= lz77.min_symbol;
+    cfg = &lz77.length_uint_config;
+  }
+  size_t split_token = cfg->split_token;
+  size_t msb_in_token = cfg->msb_in_token;
+  size_t lsb_in_token = cfg->lsb_in_token;
+  size_t split_exponent = cfg->split_exponent;
+  if (symbol < split_token) {
+    max_num_bits = std::max(max_num_bits, split_exponent);
+    return;
+  }
+  uint32_t n_extra_bits =
+      split_exponent - (msb_in_token + lsb_in_token) +
+      ((symbol - split_token) >> (msb_in_token + lsb_in_token));
+  size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1;
+  max_num_bits = std::max(max_num_bits, total_bits);
+}
+
+Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code,
+                        std::vector<uint8_t>* context_map,
+                        bool disallow_lz77) {
+  PROFILER_FUNC;
+  JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77));
+  if (code->lz77.enabled) {
+    num_contexts++;
+    JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8,
+                                         &code->lz77.length_uint_config, br));
+  }
+  if (code->lz77.enabled && disallow_lz77) {
+    return JXL_FAILURE("Using LZ77 when explicitly disallowed");
+  }
+  size_t num_histograms = 1;
+  context_map->resize(num_contexts);
+  if (num_contexts > 1) {
+    JXL_RETURN_IF_ERROR(DecodeContextMap(context_map, &num_histograms, br));
+  }
+  code->lz77.nonserialized_distance_context = context_map->back();
+  code->use_prefix_code = br->ReadFixedBits<1>();
+  if (code->use_prefix_code) {
+    code->log_alpha_size = PREFIX_MAX_BITS;
+  } else {
+    code->log_alpha_size = br->ReadFixedBits<2>() + 5;
+  }
+  code->uint_config.resize(num_histograms);
+  JXL_RETURN_IF_ERROR(
+      DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br));
+  const size_t max_alphabet_size = 1 << code->log_alpha_size;
+  JXL_RETURN_IF_ERROR(
+      DecodeANSCodes(num_histograms, max_alphabet_size, br, code));
+  // When using LZ77, flat codes might result in valid codestreams with
+  // histograms that potentially allow very large bit counts.
+  // TODO(veluca): in principle, a valid codestream might contain a histogram
+  // that could allow very large numbers of bits that is never used during ANS
+  // decoding. There's no benefit to doing that, though.
+  if (!code->lz77.enabled && code->max_num_bits > 32) {
+    // Just emit a warning as there are many opportunities for false positives.
+    JXL_WARNING("Histogram can represent numbers that are too large: %zu\n",
+                code->max_num_bits);
+  }
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_ans.h b/third_party/jpeg-xl/lib/jxl/dec_ans.h
new file mode 100644
index 000000000000..c357481eb15e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_ans.h
@@ -0,0 +1,392 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_ANS_H_
+#define LIB_JXL_DEC_ANS_H_
+
+// Library to decode the ANS population counts from the bit-stream and build a
+// decoding table from them.
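+//
+// Typical use: DecodeHistograms() fills an ANSCode (and the context map) from
+// the bitstream; an ANSSymbolReader constructed from that ANSCode then
+// decodes one value per ReadHybridUint() call, mapping the context through
+// the context map to a clustered histogram.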
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "lib/jxl/ans_common.h"
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/cache_aligned.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/dec_huffman.h"
+#include "lib/jxl/field_encodings.h"
+
+namespace jxl {
+
+class ANSSymbolReader;
+
+// Experiments show that best performance is typically achieved for a
+// split-exponent of 3 or 4. Trend seems to be that '4' is better
+// for large-ish pictures, and '3' better for rather small-ish pictures.
+// This is plausible - the more special symbols we have, the better
+// statistics we need to get a benefit out of them.
+
+// Our hybrid-encoding scheme has dedicated tokens for the smallest
+// (1 << split_exponent) numbers, and for the rest
+// encodes (number of bits) + (msb_in_token sub-leading binary digits) +
+// (lsb_in_token lowest binary digits) in the token, with the remaining bits
+// then being encoded as data.
+//
+// Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0.
+//
+// Numbers N in [0 .. 15]:
+//   These get represented as (token=N, bits='').
+// Numbers N >= 16:
+//   If n is such that 2**n <= N < 2**(n+1),
+//   and m = N - 2**n is the 'mantissa',
+//   these get represented as:
+// (token=split_token +
+//        ((n - split_exponent) * 4) +
+//        (m >> (n - msb_in_token)),
+//  bits=m & (1 << (n - msb_in_token)) - 1)
+// Specifically, we would get:
+// N = 0 - 15:      (token=N, nbits=0, bits='')
+// N = 16 (10000):  (token=16, nbits=2, bits='00')
+// N = 17 (10001):  (token=16, nbits=2, bits='01')
+// N = 20 (10100):  (token=17, nbits=2, bits='00')
+// N = 24 (11000):  (token=18, nbits=2, bits='00')
+// N = 28 (11100):  (token=19, nbits=2, bits='00')
+// N = 32 (100000): (token=20, nbits=3, bits='000')
+// N = 65535:       (token=63, nbits=13, bits='1111111111111')
+struct HybridUintConfig {
+  uint32_t split_exponent;
+  uint32_t split_token;
+  uint32_t msb_in_token;
+  uint32_t lsb_in_token;
+  JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token,
+                         uint32_t* JXL_RESTRICT nbits,
+                         uint32_t* JXL_RESTRICT bits) const {
+    if (value < split_token) {
+      *token = value;
+      *nbits = 0;
+      *bits = 0;
+    } else {
+      uint32_t n = FloorLog2Nonzero(value);
+      uint32_t m = value - (1 << n);
+      *token = split_token +
+               ((n - split_exponent) << (msb_in_token + lsb_in_token)) +
+               ((m >> (n - msb_in_token)) << lsb_in_token) +
+               (m & ((1 << lsb_in_token) - 1));
+      *nbits = n - msb_in_token - lsb_in_token;
+      *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1);
+    }
+  }
+
+  explicit HybridUintConfig(uint32_t split_exponent = 4,
+                            uint32_t msb_in_token = 2,
+                            uint32_t lsb_in_token = 0)
+      : split_exponent(split_exponent),
+        split_token(1 << split_exponent),
+        msb_in_token(msb_in_token),
+        lsb_in_token(lsb_in_token) {
+    JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token);
+  }
+};
+
+struct LZ77Params : public Fields {
+  LZ77Params();
+  const char* Name() const override { return "LZ77Params"; }
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+  bool enabled;
+
+  // Symbols above min_symbol use a special hybrid uint encoding and
+  // represent a length, to be added to min_length.
+  uint32_t min_symbol;
+  uint32_t min_length;
+
+  // Not serialized by VisitFields.
+  HybridUintConfig length_uint_config{0, 0, 0};
+
+  size_t nonserialized_distance_context;
+};
+
+static constexpr size_t kWindowSize = 1 << 20;
+static constexpr size_t kNumSpecialDistances = 120;
+// Table of special distance codes from WebP lossless.
+static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = {
+    {0, 1},  {1, 0},  {1, 1},  {-1, 1}, {0, 2},  {2, 0},  {1, 2},  {-1, 2},
+    {2, 1},  {-2, 1}, {2, 2},  {-2, 2}, {0, 3},  {3, 0},  {1, 3},  {-1, 3},
+    {3, 1},  {-3, 1}, {2, 3},  {-2, 3}, {3, 2},  {-3, 2}, {0, 4},  {4, 0},
+    {1, 4},  {-1, 4}, {4, 1},  {-4, 1}, {3, 3},  {-3, 3}, {2, 4},  {-2, 4},
+    {4, 2},  {-4, 2}, {0, 5},  {3, 4},  {-3, 4}, {4, 3},  {-4, 3}, {5, 0},
+    {1, 5},  {-1, 5}, {5, 1},  {-5, 1}, {2, 5},  {-2, 5}, {5, 2},  {-5, 2},
+    {4, 4},  {-4, 4}, {3, 5},  {-3, 5}, {5, 3},  {-5, 3}, {0, 6},  {6, 0},
+    {1, 6},  {-1, 6}, {6, 1},  {-6, 1}, {2, 6},  {-2, 6}, {6, 2},  {-6, 2},
+    {4, 5},  {-4, 5}, {5, 4},  {-5, 4}, {3, 6},  {-3, 6}, {6, 3},  {-6, 3},
+    {0, 7},  {7, 0},  {1, 7},  {-1, 7}, {5, 5},  {-5, 5}, {7, 1},  {-7, 1},
+    {4, 6},  {-4, 6}, {6, 4},  {-6, 4}, {2, 7},  {-2, 7}, {7, 2},  {-7, 2},
+    {3, 7},  {-3, 7}, {7, 3},  {-7, 3}, {5, 6},  {-5, 6}, {6, 5},  {-6, 5},
+    {8, 0},  {4, 7},  {-4, 7}, {7, 4},  {-7, 4}, {8, 1},  {8, 2},  {6, 6},
+    {-6, 6}, {8, 3},  {5, 7},  {-5, 7}, {7, 5},  {-7, 5}, {8, 4},  {6, 7},
+    {-6, 7}, {7, 6},  {-7, 6}, {8, 5},  {7, 7},  {-7, 7}, {8, 6},  {8, 7}};
+
+struct ANSCode {
+  CacheAlignedUniquePtr alias_tables;
+  std::vector<HuffmanDecodingData> huffman_data;
+  std::vector<HybridUintConfig> uint_config;
+  std::vector<int> degenerate_symbols;
+  bool use_prefix_code;
+  uint8_t log_alpha_size;  // for ANS.
+  LZ77Params lz77;
+  // Maximum number of bits necessary to represent the result of a
+  // ReadHybridUint call done with this ANSCode.
+  size_t max_num_bits = 0;
+  void UpdateMaxNumBits(size_t ctx, size_t symbol);
+};
+
+class ANSSymbolReader {
+ public:
+  // Invalid symbol reader, to be overwritten.
+  ANSSymbolReader() = default;
+  ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br,
+                  size_t distance_multiplier = 0)
+      : alias_tables_(
+            reinterpret_cast<AliasTable::Entry*>(code->alias_tables.get())),
+        huffman_data_(code->huffman_data.data()),
+        use_prefix_code_(code->use_prefix_code),
+        configs(code->uint_config.data()) {
+    if (!use_prefix_code_) {
+      state_ = static_cast<uint32_t>(br->ReadFixedBits<32>());
+      log_alpha_size_ = code->log_alpha_size;
+      log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size;
+      entry_size_minus_1_ = (1 << log_entry_size_) - 1;
+    } else {
+      state_ = (ANS_SIGNATURE << 16u);
+    }
+    if (!code->lz77.enabled) return;
+    // a std::vector incurs unacceptable decoding speed loss because of
+    // initialization.
+    lz77_window_storage_ = AllocateArray(kWindowSize * sizeof(uint32_t));
+    lz77_window_ = reinterpret_cast<uint32_t*>(lz77_window_storage_.get());
+    lz77_ctx_ = code->lz77.nonserialized_distance_context;
+    lz77_length_uint_ = code->lz77.length_uint_config;
+    lz77_threshold_ = code->lz77.min_symbol;
+    lz77_min_length_ = code->lz77.min_length;
+    num_special_distances_ =
+        distance_multiplier == 0 ? 0 : kNumSpecialDistances;
+    for (size_t i = 0; i < num_special_distances_; i++) {
+      int dist = kSpecialDistances[i][0];
+      dist += static_cast<int>(distance_multiplier) * kSpecialDistances[i][1];
+      if (dist < 1) dist = 1;
+      special_distances_[i] = dist;
+    }
+  }
+
+  JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx,
+                                               BitReader* JXL_RESTRICT br) {
+    const uint32_t res = state_ & (ANS_TAB_SIZE - 1u);
+
+    const AliasTable::Entry* table =
+        &alias_tables_[histo_idx << log_alpha_size_];
+    const AliasTable::Symbol symbol =
+        AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_);
+    state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset;
+
+#if 1
+    // Branchless version is about equally fast on SKX.
+    const uint32_t new_state =
+        (state_ << 16u) | static_cast<uint32_t>(br->PeekFixedBits<16>());
+    const bool normalize = state_ < (1u << 16u);
+    state_ = normalize ? new_state : state_;
+    br->Consume(normalize ? 16 : 0);
+#else
+    if (JXL_UNLIKELY(state_ < (1u << 16u))) {
+      state_ = (state_ << 16u) | br->PeekFixedBits<16>();
+      br->Consume(16);
+    }
+#endif
+    const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u);
+    AliasTable::Prefetch(table, next_res, log_entry_size_);
+
+    return symbol.value;
+  }
+
+  JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx,
+                                                BitReader* JXL_RESTRICT br) {
+    return huffman_data_[histo_idx].ReadSymbol(br);
+  }
+
+  JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx,
+                                            BitReader* JXL_RESTRICT br) {
+    // TODO(veluca): hoist if in hotter loops.
+    if (JXL_UNLIKELY(use_prefix_code_)) {
+      return ReadSymbolHuffWithoutRefill(histo_idx, br);
+    }
+    return ReadSymbolANSWithoutRefill(histo_idx, br);
+  }
+
+  JXL_INLINE size_t ReadSymbol(const size_t histo_idx,
+                               BitReader* JXL_RESTRICT br) {
+    br->Refill();
+    return ReadSymbolWithoutRefill(histo_idx, br);
+  }
+
+  bool CheckANSFinalState() { return state_ == (ANS_SIGNATURE << 16u); }
+
+  static JXL_INLINE uint32_t ReadHybridUintConfig(
+      const HybridUintConfig& config, size_t token, BitReader* br) {
+    size_t split_token = config.split_token;
+    size_t msb_in_token = config.msb_in_token;
+    size_t lsb_in_token = config.lsb_in_token;
+    size_t split_exponent = config.split_exponent;
+    // Fast-track version of hybrid integer decoding.
+    if (token < split_token) return token;
+    uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) +
+                     ((token - split_token) >> (msb_in_token + lsb_in_token));
+    // Max amount of bits for ReadBits is 32 and max valid left shift is 29
+    // bits. However, for speed no error is propagated here, instead limit the
+    // nbits size. If nbits > 29, the code stream is invalid, but no error is
+    // returned.
+    // Note that in most cases we will emit an error if the histogram allows
+    // representing numbers that would cause invalid shifts, but we need to
+    // keep this check as when LZ77 is enabled it might make sense to have a
+    // histogram that could in principle cause invalid shifts.
+    nbits &= 31u;
+    uint32_t low = token & ((1 << lsb_in_token) - 1);
+    token >>= lsb_in_token;
+    const size_t bits = br->PeekBits(nbits);
+    br->Consume(nbits);
+    size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1)))
+                    << nbits) |
+                   bits)
+                  << lsb_in_token) |
+                 low;
+    // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not
+    // fit uint32_t
+    return static_cast<uint32_t>(ret);
+  }
+
+  // Takes a *clustered* idx.
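+  // If an LZ77 copy is in progress, the next value comes straight from the
+  // window; otherwise a token is read, and tokens >= lz77_threshold_ start a
+  // copy whose length comes from the token itself and whose distance is
+  // decoded from the dedicated distance context.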
+  size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) {
+    if (JXL_UNLIKELY(num_to_copy_ > 0)) {
+      size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
+      num_to_copy_--;
+      lz77_window_[(num_decoded_++) & kWindowMask] = ret;
+      return ret;
+    }
+    br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
+    size_t token = ReadSymbolWithoutRefill(ctx, br);
+    if (JXL_UNLIKELY(token >= lz77_threshold_)) {
+      num_to_copy_ =
+          ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_,
+                               br) +
+          lz77_min_length_;
+      br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
+      // Distance code.
+      size_t token = ReadSymbolWithoutRefill(lz77_ctx_, br);
+      size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], token, br);
+      if (JXL_LIKELY(distance < num_special_distances_)) {
+        distance = special_distances_[distance];
+      } else {
+        distance = distance + 1 - num_special_distances_;
+      }
+      if (JXL_UNLIKELY(distance > num_decoded_)) {
+        distance = num_decoded_;
+      }
+      if (JXL_UNLIKELY(distance > kWindowSize)) {
+        distance = kWindowSize;
+      }
+      copy_pos_ = num_decoded_ - distance;
+      if (JXL_UNLIKELY(distance == 0)) {
+        JXL_DASSERT(lz77_window_ != nullptr);
+        // distance 0 -> num_decoded_ == copy_pos_ == 0
+        size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize);
+        memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0]));
+      }
+      // TODO(eustas): overflow; mark BitReader as unhealthy
+      if (num_to_copy_ < lz77_min_length_) return 0;
+      return ReadHybridUintClustered(ctx, br);  // will trigger a copy.
+    }
+    size_t ret = ReadHybridUintConfig(configs[ctx], token, br);
+    if (lz77_window_) lz77_window_[(num_decoded_++) & kWindowMask] = ret;
+    return ret;
+  }
+
+  JXL_INLINE size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br,
+                                   const std::vector<uint8_t>& context_map) {
+    return ReadHybridUintClustered(context_map[ctx], br);
+  }
+
+  // ctx is a *clustered* context!
+  // This function will modify the ANS state as if `count` symbols have been
+  // decoded.
+  bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) {
+    // TODO(veluca): No optimization for Huffman mode yet.
+    if (use_prefix_code_) return false;
+    // TODO(eustas): propagate "degenerate_symbol" to simplify this method.
+    const uint32_t res = state_ & (ANS_TAB_SIZE - 1u);
+    const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_];
+    AliasTable::Symbol symbol =
+        AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_);
+    if (symbol.freq != ANS_TAB_SIZE) return false;
+    if (configs[ctx].split_token <= symbol.value) return false;
+    if (symbol.value >= lz77_threshold_) return false;
+    *value = symbol.value;
+    if (lz77_window_) {
+      for (size_t i = 0; i < count; i++) {
+        lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value;
+      }
+    }
+    return true;
+  }
+
+ private:
+  const AliasTable::Entry* JXL_RESTRICT alias_tables_;  // not owned
+  const HuffmanDecodingData* huffman_data_;
+  bool use_prefix_code_;
+  uint32_t state_ = ANS_SIGNATURE << 16u;
+  const HybridUintConfig* JXL_RESTRICT configs;
+  uint32_t log_alpha_size_;
+  uint32_t log_entry_size_;
+  uint32_t entry_size_minus_1_;
+
+  // LZ77 structures and constants.
+  static constexpr size_t kWindowMask = kWindowSize - 1;
+  CacheAlignedUniquePtr lz77_window_storage_;
+  uint32_t* lz77_window_ = nullptr;
+  uint32_t num_decoded_ = 0;
+  uint32_t num_to_copy_ = 0;
+  uint32_t copy_pos_ = 0;
+  uint32_t lz77_ctx_ = 0;
+  uint32_t lz77_min_length_ = 0;
+  uint32_t lz77_threshold_ = 1 << 20;  // bigger than any symbol.
+  HybridUintConfig lz77_length_uint_;
+  uint32_t special_distances_[kNumSpecialDistances];
+  uint32_t num_special_distances_;
+};
+
+Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code,
+                        std::vector<uint8_t>* context_map,
+                        bool disallow_lz77 = false);
+
+// Exposed for tests.
+Status DecodeUintConfigs(size_t log_alpha_size,
+                         std::vector<HybridUintConfig>* uint_config,
+                         BitReader* br);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_ANS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_bit_reader.h b/third_party/jpeg-xl/lib/jxl/dec_bit_reader.h
new file mode 100644
index 000000000000..d4f456881351
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_bit_reader.h
@@ -0,0 +1,344 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_BIT_READER_H_
+#define LIB_JXL_DEC_BIT_READER_H_
+
+// Bounds-checked bit reader; 64-bit buffer with support for deferred refills
+// and switching to reading byte-aligned words.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>  // memcpy
+
+#ifdef __BMI2__
+#include <immintrin.h>
+#endif
+
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+
+namespace jxl {
+
+// Reads bits previously written to memory by BitWriter. Uses unaligned 8-byte
+// little-endian loads.
+class BitReader {
+ public:
+  static constexpr size_t kMaxBitsPerCall = 56;
+
+  // Constructs an invalid BitReader, to be overwritten before usage.
+  BitReader()
+      : buf_(0),
+        bits_in_buf_(0),
+        next_byte_{nullptr},
+        end_minus_8_{nullptr},
+        first_byte_(nullptr) {}
+  BitReader(const BitReader&) = delete;
+
+  // bytes need not be aligned nor padded!
+  template <class ArrayLike>
+  explicit BitReader(const ArrayLike& bytes)
+      : buf_(0),
+        bits_in_buf_(0),
+        next_byte_(bytes.data()),
+        // Assumes first_byte_ >= 8.
+        end_minus_8_(bytes.data() - 8 + bytes.size()),
+        first_byte_(bytes.data()) {
+    Refill();
+  }
+  ~BitReader() {
+    // Close() must be called before destroying an initialized bit reader.
+    // Invalid bit readers will have a nullptr in first_byte_.
+    JXL_ASSERT(close_called_ || !first_byte_);
+  }
+
+  // Move operator needs to invalidate the other BitReader such that it is
+  // irrelevant if we call Close() on it or not.
+  BitReader& operator=(BitReader&& other) noexcept {
+    // Ensure the current instance was already closed, before we overwrite it
+    // with other.
+    JXL_ASSERT(close_called_ || !first_byte_);
+
+    JXL_DASSERT(!other.close_called_);
+    buf_ = other.buf_;
+    bits_in_buf_ = other.bits_in_buf_;
+    next_byte_ = other.next_byte_;
+    end_minus_8_ = other.end_minus_8_;
+    first_byte_ = other.first_byte_;
+    overread_bytes_ = other.overread_bytes_;
+    close_called_ = other.close_called_;
+
+    other.first_byte_ = nullptr;
+    other.next_byte_ = nullptr;
+    return *this;
+  }
+  BitReader& operator=(const BitReader& other) = delete;
+
+  // For time-critical reads, refills can be shared by multiple reads.
+  // For time-critical reads, refills can be shared by multiple reads.
+  // Based on variant 4 (plus bounds-checking), see
+  // fgiesen.wordpress.com/2018/02/20/reading-bits-in-far-too-many-ways-part-2/
+  JXL_INLINE void Refill() {
+    if (JXL_UNLIKELY(next_byte_ > end_minus_8_)) {
+      BoundsCheckedRefill();
+    } else {
+      // It's safe to load 64 bits; insert valid (possibly nonzero) bits above
+      // bits_in_buf_. The shift requires bits_in_buf_ < 64.
+      buf_ |= LoadLE64(next_byte_) << bits_in_buf_;
+
+      // Advance by bytes fully absorbed into the buffer.
+      next_byte_ += (63 - bits_in_buf_) >> 3;
+
+      // We absorbed a multiple of 8 bits, so the lower 3 bits of bits_in_buf_
+      // must remain unchanged, otherwise the next refill's shifted bits will
+      // not align with buf_. Set the three upper bits so the result >= 56.
+      bits_in_buf_ |= 56;
+      JXL_DASSERT(56 <= bits_in_buf_ && bits_in_buf_ < 64);
+    }
+  }
+
+  // Returns the bits that would be returned by Read without calling Advance().
+  // It is legal to PEEK at more bits than present in the bitstream (required
+  // by Huffman), and those bits will be zero.
+  template <size_t N>
+  JXL_INLINE uint64_t PeekFixedBits() const {
+    static_assert(N <= kMaxBitsPerCall, "Reading too many bits in one call.");
+    JXL_DASSERT(!close_called_);
+    return buf_ & ((1ULL << N) - 1);
+  }
+
+  JXL_INLINE uint64_t PeekBits(size_t nbits) const {
+    JXL_DASSERT(nbits <= kMaxBitsPerCall);
+    JXL_DASSERT(!close_called_);
+
+    // Slightly faster but requires BMI2. It is infeasible to make the many
+    // callers reside between begin/end_target, especially because only the
+    // callers in dec_ans are time-critical. Therefore only enabled if the
+    // entire binary is compiled for (and thus requires) BMI2.
+#if defined(__BMI2__) && defined(__x86_64__)
+    return _bzhi_u64(buf_, nbits);
+#else
+    const uint64_t mask = (1ULL << nbits) - 1;
+    return buf_ & mask;
+#endif
+  }
+
+  // Removes bits from the buffer. Need not match the previous Peek size, but
+  // the buffer must contain at least num_bits (this prevents consuming more
+  // than the total number of bits).
+  JXL_INLINE void Consume(size_t num_bits) {
+    JXL_DASSERT(!close_called_);
+    JXL_DASSERT(bits_in_buf_ >= num_bits);
+#ifdef JXL_CRASH_ON_ERROR
+    // When JXL_CRASH_ON_ERROR is defined, it is a fatal error to read more
+    // bits than available in the stream. A non-zero overread_bytes_ implies
+    // that next_byte_ is already at the end of the stream, so we don't need
+    // to check that.
+    JXL_ASSERT(bits_in_buf_ >= num_bits + overread_bytes_ * kBitsPerByte);
+#endif
+    bits_in_buf_ -= num_bits;
+    buf_ >>= num_bits;
+  }
+
+  JXL_INLINE uint64_t ReadBits(size_t nbits) {
+    JXL_DASSERT(!close_called_);
+    Refill();
+    const uint64_t bits = PeekBits(nbits);
+    Consume(nbits);
+    return bits;
+  }
+
+  template <size_t N>
+  JXL_INLINE uint64_t ReadFixedBits() {
+    JXL_DASSERT(!close_called_);
+    Refill();
+    const uint64_t bits = PeekFixedBits<N>();
+    Consume(N);
+    return bits;
+  }
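The `bits_in_buf_ |= 56` step in Refill() is easiest to verify with concrete numbers. A self-contained check, using an assumed pre-refill fill level of 5 bits:

#include <cassert>
#include <cstddef>

// Worked example of the branchless refill bookkeeping above: with 5 bits
// left in the buffer, the refill absorbs 7 whole bytes (56 bits), leaving
// 61 valid bits. 61 == 5 | 56, so OR-ing with 56 equals adding the absorbed
// bytes, because the low 3 bits (5 % 8) never change.
int main() {
  size_t bits_in_buf = 5;
  const size_t bytes_absorbed = (63 - bits_in_buf) >> 3;     // 7 bytes
  const size_t expected = bits_in_buf + bytes_absorbed * 8;  // 61 bits
  bits_in_buf |= 56;
  assert(bits_in_buf == expected);
  assert(56 <= bits_in_buf && bits_in_buf < 64);
  return 0;
}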
+  // Equivalent to calling ReadFixedBits(1) `skip` times, but much faster.
+  // `skip` is typically large.
+  void SkipBits(size_t skip) {
+    JXL_DASSERT(!close_called_);
+    // Buffer is large enough - don't zero buf_ below.
+    if (JXL_UNLIKELY(skip <= bits_in_buf_)) {
+      Consume(skip);
+      return;
+    }
+
+    // First deduct what we can satisfy from the buffer
+    skip -= bits_in_buf_;
+    bits_in_buf_ = 0;
+    // Not enough to call Advance - that may leave some bits in the buffer
+    // which were previously ABOVE bits_in_buf.
+    buf_ = 0;
+
+    // Skip whole bytes
+    const size_t whole_bytes = skip / kBitsPerByte;
+    next_byte_ += whole_bytes;
+    skip %= kBitsPerByte;
+
+    Refill();
+    Consume(skip);
+  }
+
+  size_t TotalBitsConsumed() const {
+    const size_t bytes_read = static_cast<size_t>(next_byte_ - first_byte_);
+    return (bytes_read + overread_bytes_) * kBitsPerByte - bits_in_buf_;
+  }
+
+  Status JumpToByteBoundary() {
+    const size_t remainder = TotalBitsConsumed() % kBitsPerByte;
+    if (remainder == 0) return true;
+    if (JXL_UNLIKELY(ReadBits(kBitsPerByte - remainder) != 0)) {
+      return JXL_FAILURE("Non-zero padding bits");
+    }
+    return true;
+  }
+
+  // For interoperability with other bitreaders (for resuming at
+  // non-byte-aligned positions).
+  const uint8_t* FirstByte() const { return first_byte_; }
+  size_t TotalBytes() const {
+    return static_cast<size_t>(end_minus_8_ + 8 - first_byte_);
+  }
+
+  // Returns span of the remaining (unconsumed) bytes, e.g. for passing to
+  // external decoders such as Brotli.
+  Span<const uint8_t> GetSpan() const {
+    JXL_DASSERT(first_byte_ != nullptr);
+    JXL_ASSERT(TotalBitsConsumed() % kBitsPerByte == 0);
+    const size_t offset = TotalBitsConsumed() / kBitsPerByte;  // no remainder
+    JXL_ASSERT(offset <= TotalBytes());
+    return Span<const uint8_t>(first_byte_ + offset, TotalBytes() - offset);
+  }
+
+  // Returns whether all the bits read so far have been within the input
+  // bounds. When reading past the EOF, the Read*() and Consume() functions
+  // return zeros but flag a failure when calling Close() without checking
+  // this function.
+  Status AllReadsWithinBounds() {
+    // Mark up to which point the user checked the out of bounds condition. If
+    // the user handles the condition at higher level (e.g. fetch more bytes
+    // from network, return a custom JXL_FAILURE, ...), Close() should not
+    // output a debug error (which would break tests with JXL_CRASH_ON_ERROR
+    // even when legitimately handling the situation at higher level). This is
+    // used by Bundle::CanRead.
+    checked_out_of_bounds_bits_ = TotalBitsConsumed();
+    if (TotalBitsConsumed() > TotalBytes() * kBitsPerByte) {
+      return false;
+    }
+    return true;
+  }
+
+  // Close the bit reader and return whether all the previous reads were
+  // successful. Close must be called once.
+  Status Close() {
+    JXL_DASSERT(!close_called_);
+    close_called_ = true;
+    if (!first_byte_) return true;
+    if (TotalBitsConsumed() > checked_out_of_bounds_bits_ &&
+        TotalBitsConsumed() > TotalBytes() * kBitsPerByte) {
+      return JXL_FAILURE("Read more bits than available in the bit_reader");
+    }
+    return true;
+  }
+
+ private:
+  // Separate function avoids inlining this relatively cold code into callers.
+  JXL_NOINLINE void BoundsCheckedRefill() {
+    PROFILER_FUNC;
+    const uint8_t* end = end_minus_8_ + 8;
+
+    // Read whole bytes until we have [56, 64) bits (same as LoadLE64)
+    for (; bits_in_buf_ < 64 - kBitsPerByte; bits_in_buf_ += kBitsPerByte) {
+      if (next_byte_ >= end) break;
+      buf_ |= static_cast<uint64_t>(*next_byte_++) << bits_in_buf_;
+    }
+    JXL_DASSERT(bits_in_buf_ < 64);
+
+    // Add extra bytes as 0 at the end of the stream in the bit_buffer_. If
+    // these bits are read, Close() will return a failure.
+    size_t extra_bytes = (63 - bits_in_buf_) / kBitsPerByte;
+    overread_bytes_ += extra_bytes;
+    bits_in_buf_ += extra_bytes * kBitsPerByte;
+
+    JXL_DASSERT(bits_in_buf_ < 64);
+    JXL_DASSERT(bits_in_buf_ >= 56);
+  }
+
+  JXL_NOINLINE uint32_t BoundsCheckedReadByteAlignedWord() {
+    if (next_byte_ + 1 < end_minus_8_ + 8) {
+      uint32_t ret = LoadLE16(next_byte_);
+      next_byte_ += 2;
+      return ret;
+    }
+    overread_bytes_ += 2;
+    return 0;
+  }
+
+  uint64_t buf_;
+  size_t bits_in_buf_;  // [0, 64)
+  const uint8_t* JXL_RESTRICT next_byte_;
+  const uint8_t* end_minus_8_;  // for refill bounds check
+  const uint8_t* first_byte_;   // for GetSpan
+
+  // Number of bytes past the end that were loaded into the buf_. These bytes
+  // are not read from memory, but instead assumed 0. It is an error (likely
+  // due to an invalid stream) to Consume() more bits than specified in the
+  // range passed to the constructor.
+  uint64_t overread_bytes_{0};
+  bool close_called_{false};
+
+  uint64_t checked_out_of_bounds_bits_{0};
+};
+
+// Closes a BitReader when the BitReaderScopedCloser goes out of scope. When
+// closing the bit reader, if the status result was failure it sets this
+// failure to the passed variable pointer. Typical usage:
+//
+// Status ret = true;
+// {
+//   BitReader reader(...);
+//   BitReaderScopedCloser reader_closer(&reader, &ret);
+//
+//   // ... code that can return errors here ...
+// }
+// // ... more code that doesn't use the BitReader.
+// return ret;
+class BitReaderScopedCloser {
+ public:
+  BitReaderScopedCloser(BitReader* reader, Status* status)
+      : reader_(reader), status_(status) {
+    JXL_DASSERT(reader_ != nullptr);
+    JXL_DASSERT(status_ != nullptr);
+  }
+  ~BitReaderScopedCloser() {
+    Status close_ret = reader_->Close();
+    if (!close_ret) *status_ = close_ret;
+  }
+  BitReaderScopedCloser(const BitReaderScopedCloser&) = delete;
+
+ private:
+  BitReader* reader_;
+  Status* status_;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_BIT_READER_H_
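A compact usage sketch of the scoped closer, following the comment above; the 8-bit marker value checked here is invented:

#include <cstdint>
#include <vector>

#include "lib/jxl/base/span.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_bit_reader.h"

// Close() runs on every path out of the block, and a failed close (e.g. an
// over-read) is folded into `ret` without extra plumbing.
jxl::Status ParseHeaderish(const std::vector<uint8_t>& bytes) {
  jxl::Status ret = true;
  {
    jxl::BitReader reader(
        jxl::Span<const uint8_t>(bytes.data(), bytes.size()));
    jxl::BitReaderScopedCloser closer(&reader, &ret);
    if (reader.ReadFixedBits<8>() != 0x0A) {
      return JXL_FAILURE("bad marker");  // closer still closes the reader
    }
  }
  return ret;
}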
diff --git a/third_party/jpeg-xl/lib/jxl/dec_cache.h b/third_party/jpeg-xl/lib/jxl/dec_cache.h
new file mode 100644
index 000000000000..a339e602ae37
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_cache.h
@@ -0,0 +1,360 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_CACHE_H_
+#define LIB_JXL_DEC_CACHE_H_
+
+#include <stdint.h>
+
+#include <hwy/base.h>  // HWY_ALIGN_MAX
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/dec_group_border.h"
+#include "lib/jxl/dec_noise.h"
+#include "lib/jxl/dec_upsample.h"
+#include "lib/jxl/filters.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/passes_state.h"
+#include "lib/jxl/quant_weights.h"
+
+namespace jxl {
+
+// Per-frame decoder state. All the images here should be accessed through a
+// group rect (either with block units or pixel units).
+struct PassesDecoderState {
+  PassesSharedState shared_storage;
+  // Allows avoiding copies for encoder loop.
+  const PassesSharedState* JXL_RESTRICT shared = &shared_storage;
+
+  // Upsampler for the current frame.
+  Upsampler upsampler;
+
+  // DC upsampler
+  Upsampler dc_upsampler;
+
+  // Storage for RNG output for noise synthesis.
+  Image3F noise;
+
+  // Storage for pre-color-transform output for displayed
+  // save_before_color_transform frames.
+  Image3F pre_color_transform_frame;
+  // Non-empty (contains originals) if extra-channels were cropped.
+  std::vector<ImageF> pre_color_transform_ec;
+
+  // For ANS decoding.
+  std::vector<ANSCode> code;
+  std::vector<std::vector<uint8_t>> context_map;
+
+  // Multiplier to be applied to the quant matrices of the x channel.
+  float x_dm_multiplier;
+  float b_dm_multiplier;
+
+  // Decoded image.
+  Image3F decoded;
+  std::vector<ImageF> extra_channels;
+
+  // Borders between groups. Only allocated if `decoded` is *not* allocated.
+  // We also store the extremal borders for simplicity. Horizontal borders are
+  // stored in an image as wide as the main frame, in top-to-bottom order (top
+  // border of a group first, followed by the bottom border, followed by top
+  // border of the next group). Vertical borders are similarly stored.
+  Image3F borders_horizontal;
+  Image3F borders_vertical;
+
+  // RGB8 output buffer. If not nullptr, image data will be written to this
+  // buffer instead of being written to the output ImageBundle. The image data
+  // is assumed to have the stride given by `rgb_stride`, hence row `i` starts
+  // at position `i * rgb_stride`.
+  uint8_t* rgb_output;
+  size_t rgb_stride = 0;
+
+  // Whether to use int16 float-XYB-to-uint8-srgb conversion.
+  bool fast_xyb_srgb8_conversion;
+
+  // If true, rgb_output is RGBA using 4 instead of 3 bytes per pixel.
+  bool rgb_output_is_rgba;
+
+  // Seed for noise, to have different noise per-frame.
+  size_t noise_seed = 0;
+
+  // Keep track of the transform types used.
+  std::atomic<uint32_t> used_acs{0};
+
+  // Storage for coefficients if in "accumulate" mode.
+  std::unique_ptr<ACImage> coefficients = make_unique<ACImageT<int32_t>>(0, 0);
+
+  // Filter application pipeline used by ApplyImageFeatures. One entry is
+  // needed per thread.
+  std::vector<FilterPipeline> filter_pipelines;
+
+  // Input weights used by the filters. These are shared from multiple threads
+  // but are read-only for the filter application.
+  FilterWeights filter_weights;
+
+  // Manages the status of borders.
+  GroupBorderAssigner group_border_assigner;
+
+  // TODO(veluca): this should eventually become "iff no global modular
+  // transform was applied".
+  bool EagerFinalizeImageRect() const {
+    return shared->frame_header.chroma_subsampling.Is444() &&
+           shared->frame_header.encoding == FrameEncoding::kVarDCT &&
+           shared->frame_header.nonserialized_metadata->m.extra_channel_info
+               .empty();
+  }
+
+  // Amount of padding that will be accessed, in all directions, outside a
+  // rect during a call to FinalizeImageRect().
+  size_t FinalizeRectPadding() const {
+    // TODO(veluca): add YCbCr upsampling here too.
+    size_t padding = shared->frame_header.loop_filter.Padding();
+    padding += shared->frame_header.upsampling == 1 ? 0 : 2;
+    JXL_DASSERT(padding <= kMaxFinalizeRectPadding);
+    return padding;
+  }
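The border layout described for `borders_horizontal` implies simple stripe indexing; a sketch of that arithmetic (the helper name is ours, not part of libjxl):

#include <cstddef>

// Illustration of the borders_horizontal layout above: each group row `gy`
// owns two stripes of `bordery` rows each, top stripe first. Returns the
// first image row of the requested stripe.
size_t HorizontalBorderRow(size_t gy, bool bottom, size_t bordery) {
  return (2 * gy + (bottom ? 1 : 0)) * bordery;
}

// E.g. with bordery == 4: group 0 top -> row 0, group 0 bottom -> row 4,
// group 1 top -> row 8, matching "top, bottom, next group's top" order.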
+  // Storage for intermediate data during FinalizeRect steps.
+  // TODO(veluca): these buffers are larger than strictly necessary.
+  std::vector<Image3F> filter_input_storage;
+  std::vector<Image3F> padded_upsampling_input_storage;
+  std::vector<Image3F> upsampling_input_storage;
+  // We keep four arrays, one per upsampling level, to reduce memory usage in
+  // the common case of no upsampling.
+  std::vector<Image3F> output_pixel_data_storage[4] = {};
+
+  // Buffer for decoded pixel data for a group.
+  std::vector<Image3F> group_data;
+  static constexpr size_t kGroupDataYBorder = kMaxFinalizeRectPadding * 2;
+  static constexpr size_t kGroupDataXBorder =
+      RoundUpToBlockDim(kMaxFinalizeRectPadding) * 2 + kBlockDim;
+
+  void EnsureStorage(size_t num_threads) {
+    // We need one filter_storage per thread, ensure we have at least that
+    // many.
+    if (shared->frame_header.loop_filter.epf_iters != 0 ||
+        shared->frame_header.loop_filter.gab) {
+      if (filter_pipelines.size() < num_threads) {
+        filter_pipelines.resize(num_threads);
+      }
+    }
+    // We allocate filter_input_storage unconditionally to ensure that the
+    // image is allocated if we need it for DC upsampling.
+    for (size_t _ = filter_input_storage.size(); _ < num_threads; _++) {
+      // Extra padding along the x dimension to ensure memory accesses don't
+      // load out-of-bounds pixels.
+      filter_input_storage.emplace_back(
+          kApplyImageFeaturesTileDim + 2 * kGroupDataXBorder,
+          kApplyImageFeaturesTileDim + 2 * kGroupDataYBorder);
+    }
+    if (shared->frame_header.upsampling != 1) {
+      for (size_t _ = upsampling_input_storage.size(); _ < num_threads; _++) {
+        // At this point, we only need up to 2 pixels of border per side for
+        // upsampling, but we add an extra border for aligned access.
+        upsampling_input_storage.emplace_back(
+            kApplyImageFeaturesTileDim + 2 * kBlockDim,
+            kApplyImageFeaturesTileDim + 4);
+        padded_upsampling_input_storage.emplace_back(
+            kApplyImageFeaturesTileDim + 2 * kBlockDim,
+            kApplyImageFeaturesTileDim + 4);
+      }
+    }
+    for (size_t _ = group_data.size(); _ < num_threads; _++) {
+      group_data.emplace_back(kGroupDim + 2 * kGroupDataXBorder,
+                              kGroupDim + 2 * kGroupDataYBorder);
+#if MEMORY_SANITIZER
+      // Avoid errors due to loading vectors on the outermost padding.
+      ZeroFillImage(&group_data.back());
+#endif
+    }
+    if (rgb_output) {
+      size_t log2_upsampling =
+          CeilLog2Nonzero(shared->frame_header.upsampling);
+      for (size_t _ = output_pixel_data_storage[log2_upsampling].size();
+           _ < num_threads; _++) {
+        output_pixel_data_storage[log2_upsampling].emplace_back(
+            kApplyImageFeaturesTileDim << log2_upsampling,
+            kApplyImageFeaturesTileDim << log2_upsampling);
+      }
+    }
+  }
+
+  // Information for colour conversions.
+  OutputEncodingInfo output_encoding_info;
+
+  // Initializes decoder-specific structures using information from *shared.
+  void Init() {
+    x_dm_multiplier =
+        std::pow(1 / (1.25f), shared->frame_header.x_qm_scale - 2.0f);
+    b_dm_multiplier =
+        std::pow(1 / (1.25f), shared->frame_header.b_qm_scale - 2.0f);
+
+    rgb_output = nullptr;
+    rgb_output_is_rgba = false;
+    fast_xyb_srgb8_conversion = false;
+    used_acs = 0;
+
+    group_border_assigner.Init(shared->frame_dim);
+    const LoopFilter& lf = shared->frame_header.loop_filter;
+    filter_weights.Init(lf, shared->frame_dim);
+    for (auto& fp : filter_pipelines) {
+      // De-initialize FilterPipelines.
+      fp.num_filters = 0;
+    }
+  }
+
+  // Initialize the decoder state after all of DC is decoded.
+ void InitForAC(ThreadPool* pool) { + shared_storage.coeff_order_size = 0; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + if (((1 << o) & used_acs) == 0) continue; + uint8_t ord = kStrategyOrder[o]; + shared_storage.coeff_order_size = + std::max(kCoeffOrderOffset[3 * (ord + 1)] * kDCTBlockSize, + shared_storage.coeff_order_size); + } + size_t sz = shared_storage.frame_header.passes.num_passes * + shared_storage.coeff_order_size; + if (sz > shared_storage.coeff_orders.size()) { + shared_storage.coeff_orders.resize(sz); + } + if (shared->frame_header.flags & FrameHeader::kNoise) { + noise = Image3F(shared->frame_dim.xsize_upsampled_padded, + shared->frame_dim.ysize_upsampled_padded); + size_t num_x_groups = DivCeil(noise.xsize(), kGroupDim); + size_t num_y_groups = DivCeil(noise.ysize(), kGroupDim); + PROFILER_ZONE("GenerateNoise"); + auto generate_noise = [&](int group_index, int _) { + size_t gx = group_index % num_x_groups; + size_t gy = group_index / num_x_groups; + Rect rect(gx * kGroupDim, gy * kGroupDim, kGroupDim, kGroupDim, + noise.xsize(), noise.ysize()); + RandomImage3(noise_seed + group_index, rect, &noise); + }; + RunOnPool(pool, 0, num_x_groups * num_y_groups, ThreadPool::SkipInit(), + generate_noise, "Generate noise"); + { + PROFILER_ZONE("High pass noise"); + // 4 * (1 - box kernel) + WeightsSymmetric5 weights{{HWY_REP4(-3.84)}, {HWY_REP4(0.16)}, + {HWY_REP4(0.16)}, {HWY_REP4(0.16)}, + {HWY_REP4(0.16)}, {HWY_REP4(0.16)}}; + // TODO(veluca): avoid copy. + // TODO(veluca): avoid having a full copy of the image in main memory. + ImageF noise_tmp(noise.xsize(), noise.ysize()); + for (size_t c = 0; c < 3; c++) { + Symmetric5(noise.Plane(c), Rect(noise), weights, pool, &noise_tmp); + std::swap(noise.Plane(c), noise_tmp); + } + noise_seed += shared->frame_dim.num_groups; + } + } + EnsureBordersStorage(); + if (!EagerFinalizeImageRect()) { + // decoded must be padded to a multiple of kBlockDim rows since the last + // rows may be used by the filters even if they are outside the frame + // dimension. + decoded = Image3F(shared->frame_dim.xsize_padded, + shared->frame_dim.ysize_padded); + } +#if MEMORY_SANITIZER + // Avoid errors due to loading vectors on the outermost padding. + ZeroFillImage(&decoded); +#endif + } + + void EnsureBordersStorage() { + if (!EagerFinalizeImageRect()) return; + size_t padding = FinalizeRectPadding(); + size_t bordery = 2 * padding; + size_t borderx = padding + group_border_assigner.PaddingX(padding); + Rect horizontal = Rect(0, 0, shared->frame_dim.xsize_padded, + bordery * shared->frame_dim.ysize_groups * 2); + if (!SameSize(horizontal, borders_horizontal)) { + borders_horizontal = Image3F(horizontal.xsize(), horizontal.ysize()); + } + Rect vertical = Rect(0, 0, borderx * shared->frame_dim.xsize_groups * 2, + shared->frame_dim.ysize_padded); + if (!SameSize(vertical, borders_vertical)) { + borders_vertical = Image3F(vertical.xsize(), vertical.ysize()); + } + } +}; + +// Temp images required for decoding a single group. Reduces memory allocations +// for large images because we only initialize min(#threads, #groups) instances. +struct GroupDecCache { + void InitOnce(size_t num_passes, size_t used_acs) { + PROFILER_FUNC; + + for (size_t i = 0; i < num_passes; i++) { + if (num_nzeroes[i].xsize() == 0) { + // Allocate enough for a whole group - partial groups on the + // right/bottom border just use a subset. The valid size is passed via + // Rect. 
+
+        num_nzeroes[i] = Image3I(kGroupDimInBlocks, kGroupDimInBlocks);
+      }
+    }
+    size_t max_block_area = 0;
+
+    for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) {
+      AcStrategy acs = AcStrategy::FromRawStrategy(o);
+      if ((used_acs & (1 << o)) == 0) continue;
+      size_t area =
+          acs.covered_blocks_x() * acs.covered_blocks_y() * kDCTBlockSize;
+      max_block_area = std::max(area, max_block_area);
+    }
+
+    if (max_block_area > max_block_area_) {
+      max_block_area_ = max_block_area;
+      // We need 3x float blocks for dequantized coefficients and 1x for
+      // scratch space for transforms.
+      float_memory_ = hwy::AllocateAligned<float>(max_block_area_ * 4);
+      // We need 3x int32 or int16 blocks for quantized coefficients.
+      int32_memory_ = hwy::AllocateAligned<int32_t>(max_block_area_ * 3);
+      int16_memory_ = hwy::AllocateAligned<int16_t>(max_block_area_ * 3);
+    }
+
+    dec_group_block = float_memory_.get();
+    scratch_space = dec_group_block + max_block_area_ * 3;
+    dec_group_qblock = int32_memory_.get();
+    dec_group_qblock16 = int16_memory_.get();
+  }
+
+  // Scratch space used by DecGroupImpl().
+  float* dec_group_block;
+  int32_t* dec_group_qblock;
+  int16_t* dec_group_qblock16;
+
+  // For TransformToPixels.
+  float* scratch_space;
+  // Note that scratch_space is never used at the same time as
+  // dec_group_qblock. Moreover, only one of dec_group_qblock and
+  // dec_group_qblock16 is ever used.
+  // TODO(veluca): figure out if we can save allocations.
+
+  // AC decoding
+  Image3I num_nzeroes[kMaxNumPasses];
+
+ private:
+  hwy::AlignedFreeUniquePtr<float[]> float_memory_;
+  hwy::AlignedFreeUniquePtr<int32_t[]> int32_memory_;
+  hwy::AlignedFreeUniquePtr<int16_t[]> int16_memory_;
+  size_t max_block_area_ = 0;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_CACHE_H_
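GroupDecCache's single float allocation is carved up by pointer arithmetic rather than allocated piecewise; a reduced sketch of the same idiom (struct and field names invented):

#include <cstddef>

#include <hwy/aligned_allocator.h>

// Reduced illustration of the partitioning above: one aligned allocation of
// 4 block areas serves as 3 coefficient blocks plus 1 scratch block, so a
// resize is a single allocation and the blocks stay adjacent in memory.
struct BlockScratch {
  void Resize(size_t block_area) {
    memory_ = hwy::AllocateAligned<float>(block_area * 4);
    coeffs = memory_.get();             // 3 * block_area floats
    scratch = coeffs + block_area * 3;  // 1 * block_area floats
  }
  float* coeffs = nullptr;
  float* scratch = nullptr;

 private:
  hwy::AlignedFreeUniquePtr<float[]> memory_;
};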
diff --git a/third_party/jpeg-xl/lib/jxl/dec_context_map.cc b/third_party/jpeg-xl/lib/jxl/dec_context_map.cc
new file mode 100644
index 000000000000..d81caad6c2bd
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_context_map.cc
@@ -0,0 +1,114 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_context_map.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/entropy_coder.h"
+
+namespace jxl {
+
+namespace {
+
+void MoveToFront(uint8_t* v, uint8_t index) {
+  uint8_t value = v[index];
+  uint8_t i = index;
+  for (; i; --i) v[i] = v[i - 1];
+  v[0] = value;
+}
+
+void InverseMoveToFrontTransform(uint8_t* v, int v_len) {
+  uint8_t mtf[256];
+  int i;
+  for (i = 0; i < 256; ++i) {
+    mtf[i] = static_cast<uint8_t>(i);
+  }
+  for (i = 0; i < v_len; ++i) {
+    uint8_t index = v[i];
+    v[i] = mtf[index];
+    if (index) MoveToFront(mtf, index);
+  }
+}
+
+bool VerifyContextMap(const std::vector<uint8_t>& context_map,
+                      const size_t num_htrees) {
+  std::vector<bool> have_htree(num_htrees);
+  size_t num_found = 0;
+  for (const uint8_t htree : context_map) {
+    if (htree >= num_htrees) {
+      return JXL_FAILURE("Invalid histogram index in context map.");
+    }
+    if (!have_htree[htree]) {
+      have_htree[htree] = true;
+      ++num_found;
+    }
+  }
+  if (num_found != num_htrees) {
+    return JXL_FAILURE("Incomplete context map.");
+  }
+  return true;
+}
+
+}  // namespace
+
+bool DecodeContextMap(std::vector<uint8_t>* context_map, size_t* num_htrees,
+                      BitReader* input) {
+  bool is_simple = input->ReadFixedBits<1>();
+  if (is_simple) {
+    int bits_per_entry = input->ReadFixedBits<2>();
+    if (bits_per_entry != 0) {
+      for (size_t i = 0; i < context_map->size(); i++) {
+        (*context_map)[i] = input->ReadBits(bits_per_entry);
+      }
+    } else {
+      std::fill(context_map->begin(), context_map->end(), 0);
+    }
+  } else {
+    bool use_mtf = input->ReadFixedBits<1>();
+    ANSCode code;
+    std::vector<uint8_t> dummy_ctx_map;
+    // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't
+    // make sense in non-malicious bitstreams, and could cause a stack
+    // overflow in malicious bitstreams by making every context map require
+    // its own context map.
+    JXL_RETURN_IF_ERROR(
+        DecodeHistograms(input, 1, &code, &dummy_ctx_map,
+                         /*disallow_lz77=*/context_map->size() <= 2));
+    ANSSymbolReader reader(&code, input);
+    size_t i = 0;
+    while (i < context_map->size()) {
+      uint32_t sym = reader.ReadHybridUint(0, input, dummy_ctx_map);
+      if (sym >= kMaxClusters) {
+        return JXL_FAILURE("Invalid cluster ID");
+      }
+      (*context_map)[i] = sym;
+      i++;
+    }
+    if (!reader.CheckANSFinalState()) {
+      return JXL_FAILURE("Invalid context map");
+    }
+    if (use_mtf) {
+      InverseMoveToFrontTransform(context_map->data(), context_map->size());
+    }
+  }
+  *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1;
+  return VerifyContextMap(*context_map, *num_htrees);
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_context_map.h b/third_party/jpeg-xl/lib/jxl/dec_context_map.h
new file mode 100644
index 000000000000..0e30d3972e1d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_context_map.h
@@ -0,0 +1,39 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#ifndef LIB_JXL_DEC_CONTEXT_MAP_H_ +#define LIB_JXL_DEC_CONTEXT_MAP_H_ + +#include +#include + +#include + +#include "lib/jxl/dec_bit_reader.h" + +namespace jxl { + +// Context map uses uint8_t. +constexpr size_t kMaxClusters = 256; + +// Reads the context map from the bit stream. On calling this function, +// context_map->size() must be the number of possible context ids. +// Sets *num_htrees to the number of different histogram ids in +// *context_map. +bool DecodeContextMap(std::vector* context_map, size_t* num_htrees, + BitReader* input); + +} // namespace jxl + +#endif // LIB_JXL_DEC_CONTEXT_MAP_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_external_image.cc b/third_party/jpeg-xl/lib/jxl/dec_external_image.cc new file mode 100644 index 000000000000..d5f860b78c89 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_external_image.cc @@ -0,0 +1,449 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/dec_external_image.h" + +#include + +#include +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_external_image.cc" +#include +#include + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +void FloatToU32(const float* in, uint32_t* out, size_t num, float mul, + size_t bits_per_sample) { + const HWY_FULL(float) d; + const hwy::HWY_NAMESPACE::Rebind du; + size_t vec_num = num; + if (bits_per_sample == 32) { + // Conversion to real 32-bit *unsigned* integers requires more intermediate + // precision that what is given by the usual f32 -> i32 conversion + // instructions, so we run the non-SIMD path for those. + vec_num = 0; + } +#if JXL_IS_DEBUG_BUILD + // Avoid accessing partially-uninitialized vectors with memory sanitizer. + vec_num &= ~(Lanes(d) - 1); +#endif // JXL_IS_DEBUG_BUILD + + const auto one = Set(d, 1.0f); + const auto scale = Set(d, mul); + for (size_t x = 0; x < vec_num; x += Lanes(d)) { + auto v = Load(d, in + x); + // Check for NaNs. + JXL_DASSERT(AllTrue(v == v)); + // Clamp turns NaN to 'min'. + v = Clamp(v, Zero(d), one); + auto i = NearestInt(v * scale); + Store(BitCast(du, i), du, out + x); + } + for (size_t x = vec_num; x < num; x++) { + float v = in[x]; + JXL_DASSERT(!std::isnan(v)); + // Inverted condition grants that NaN is mapped to 0.0f. + v = (v >= 0.0f) ? (v > 1.0f ? mul : (v * mul)) : 0.0f; + out[x] = static_cast(v + 0.5f); + } +} + +void FloatToF16(const float* in, hwy::float16_t* out, size_t num) { + const HWY_FULL(float) d; + const hwy::HWY_NAMESPACE::Rebind du; + size_t vec_num = num; +#if JXL_IS_DEBUG_BUILD + // Avoid accessing partially-uninitialized vectors with memory sanitizer. 
+ vec_num &= ~(Lanes(d) - 1); +#endif // JXL_IS_DEBUG_BUILD + for (size_t x = 0; x < vec_num; x += Lanes(d)) { + auto v = Load(d, in + x); + auto v16 = DemoteTo(du, v); + Store(v16, du, out + x); + } + if (num != vec_num) { + const HWY_CAPPED(float, 1) d1; + const hwy::HWY_NAMESPACE::Rebind du1; + for (size_t x = vec_num; x < num; x++) { + auto v = Load(d1, in + x); + auto v16 = DemoteTo(du1, v); + Store(v16, du1, out + x); + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { +namespace { + +// Stores a float in big endian +void StoreBEFloat(float value, uint8_t* p) { + uint32_t u; + memcpy(&u, &value, 4); + StoreBE32(u, p); +} + +// Stores a float in little endian +void StoreLEFloat(float value, uint8_t* p) { + uint32_t u; + memcpy(&u, &value, 4); + StoreLE32(u, p); +} + +// The orientation may not be identity. +// TODO(lode): SIMDify where possible +template +void UndoOrientation(jxl::Orientation undo_orientation, const Plane& image, + Plane& out, jxl::ThreadPool* pool) { + const size_t xsize = image.xsize(); + const size_t ysize = image.ysize(); + + if (undo_orientation == Orientation::kFlipHorizontal) { + out = Plane(xsize, ysize); + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[xsize - x - 1] = row_in[x]; + } + }, + "UndoOrientation"); + } else if (undo_orientation == Orientation::kRotate180) { + out = Plane(xsize, ysize); + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(ysize - y - 1); + for (size_t x = 0; x < xsize; ++x) { + row_out[xsize - x - 1] = row_in[x]; + } + }, + "UndoOrientation"); + } else if (undo_orientation == Orientation::kFlipVertical) { + out = Plane(xsize, ysize); + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(ysize - y - 1); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x]; + } + }, + "UndoOrientation"); + } else if (undo_orientation == Orientation::kTranspose) { + out = Plane(ysize, xsize); + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(x)[y] = row_in[x]; + } + }, + "UndoOrientation"); + } else if (undo_orientation == Orientation::kRotate90) { + out = Plane(ysize, xsize); + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(x)[ysize - y - 1] = row_in[x]; + } + }, + "UndoOrientation"); + } else if (undo_orientation == Orientation::kAntiTranspose) { + out = Plane(ysize, xsize); + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(xsize - x - 1)[ysize 
- y - 1] = row_in[x]; + } + }, + "UndoOrientation"); + } else if (undo_orientation == Orientation::kRotate270) { + out = Plane(ysize, xsize); + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(xsize - x - 1)[y] = row_in[x]; + } + }, + "UndoOrientation"); + } +} +} // namespace + +HWY_EXPORT(FloatToU32); +HWY_EXPORT(FloatToF16); + +namespace { + +using StoreFuncType = void(uint32_t value, uint8_t* dest); +template +void StoreUintRow(uint32_t* JXL_RESTRICT* rows_u32, size_t num_channels, + size_t xsize, size_t bytes_per_sample, + uint8_t* JXL_RESTRICT out) { + for (size_t x = 0; x < xsize; ++x) { + for (size_t c = 0; c < num_channels; c++) { + StoreFunc(rows_u32[c][x], + out + (num_channels * x + c) * bytes_per_sample); + } + } +} + +template +void StoreFloatRow(const float* JXL_RESTRICT* rows_in, size_t num_channels, + size_t xsize, uint8_t* JXL_RESTRICT out) { + for (size_t x = 0; x < xsize; ++x) { + for (size_t c = 0; c < num_channels; c++) { + StoreFunc(rows_in[c][x], out + (num_channels * x + c) * sizeof(float)); + } + } +} + +void JXL_INLINE Store8(uint32_t value, uint8_t* dest) { *dest = value & 0xff; } + +} // namespace + +Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample, + bool float_out, size_t num_channels, + JxlEndianness endianness, size_t stride, + jxl::ThreadPool* pool, void* out_image, + size_t out_size, jxl::Orientation undo_orientation) { + if (bits_per_sample < 1 || bits_per_sample > 32) { + return JXL_FAILURE("Invalid bits_per_sample value."); + } + // TODO(deymo): Implement 1-bit per pixel packed in 8 samples per byte. + if (bits_per_sample == 1) { + return JXL_FAILURE("packed 1-bit per sample is not yet supported"); + } + size_t xsize = ib.xsize(); + size_t ysize = ib.ysize(); + + uint8_t* out = reinterpret_cast(out_image); + + bool want_alpha = num_channels == 2 || num_channels == 4; + size_t color_channels = num_channels <= 2 ? 1 : 3; + + // bytes_per_channel and is only valid for bits_per_sample > 1. + const size_t bytes_per_channel = DivCeil(bits_per_sample, jxl::kBitsPerByte); + const size_t bytes_per_pixel = num_channels * bytes_per_channel; + + const Image3F* color = &ib.color(); + Image3F temp_color; + const ImageF* alpha = ib.HasAlpha() ? 
&ib.alpha() : nullptr; + ImageF temp_alpha; + + if (undo_orientation != Orientation::kIdentity) { + Image3F transformed; + for (size_t c = 0; c < color_channels; ++c) { + UndoOrientation(undo_orientation, color->Plane(c), transformed.Plane(c), + pool); + } + transformed.Swap(temp_color); + color = &temp_color; + if (ib.HasAlpha()) { + UndoOrientation(undo_orientation, *alpha, temp_alpha, pool); + alpha = &temp_alpha; + } + + xsize = color->xsize(); + ysize = color->ysize(); + } + + if (stride < bytes_per_pixel * xsize) { + return JXL_FAILURE( + "stride is smaller than scanline width in bytes: %zu vs %zu", stride, + bytes_per_pixel * xsize); + } + + const bool little_endian = + endianness == JXL_LITTLE_ENDIAN || + (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian()); + + ImageF ones; + if (want_alpha && !ib.HasAlpha()) { + ones = ImageF(xsize, 1); + FillImage(1.0f, &ones); + } + + if (float_out) { + if (bits_per_sample == 16) { + bool swap_endianness = little_endian != IsLittleEndian(); + Plane f16_cache; + RunOnPool( + pool, 0, static_cast(ysize), + [&](size_t num_threads) { + f16_cache = + Plane(xsize, num_channels * num_threads); + return true; + }, + [&](const int task, int thread) { + const int64_t y = task; + const float* JXL_RESTRICT row_in[4]; + size_t c = 0; + for (; c < color_channels; c++) { + row_in[c] = color->PlaneRow(c, y); + } + if (want_alpha) { + row_in[c++] = ib.HasAlpha() ? alpha->Row(y) : ones.Row(0); + } + JXL_ASSERT(c == num_channels); + hwy::float16_t* JXL_RESTRICT row_f16[4]; + for (size_t r = 0; r < c; r++) { + row_f16[r] = f16_cache.Row(r + thread * num_channels); + HWY_DYNAMIC_DISPATCH(FloatToF16) + (row_in[r], row_f16[r], xsize); + } + // interleave the one scanline + hwy::float16_t* f16_out = &(reinterpret_cast( + out_image))[y * xsize * num_channels]; + for (size_t x = 0; x < xsize; x++) { + for (size_t r = 0; r < c; r++) { + f16_out[x * num_channels + r] = row_f16[r][x]; + } + } + if (swap_endianness) { + uint8_t* u8_out = &(reinterpret_cast( + out_image))[y * xsize * num_channels * 2]; + size_t size = xsize * num_channels * 2; + for (size_t i = 0; i < size; i += 2) { + std::swap(u8_out[i + 0], u8_out[i + 1]); + } + } + }, + "ConvertF16"); + } else if (bits_per_sample == 32) { + RunOnPool( + pool, 0, static_cast(ysize), ThreadPool::SkipInit(), + [&](const int task, int /*thread*/) { + const int64_t y = task; + size_t i = stride * y; + const float* JXL_RESTRICT row_in[4]; + size_t c = 0; + for (; c < color_channels; c++) { + row_in[c] = color->PlaneRow(c, y); + } + if (want_alpha) { + row_in[c++] = ib.HasAlpha() ? alpha->Row(y) : ones.Row(0); + } + JXL_ASSERT(c == num_channels); + if (little_endian) { + StoreFloatRow(row_in, c, xsize, out + i); + } else { + StoreFloatRow(row_in, c, xsize, out + i); + } + }, + "ConvertFloat"); + } else { + return JXL_FAILURE("float other than 16-bit and 32-bit not supported"); + } + } else { + // Multiplier to convert from floating point 0-1 range to the integer + // range. + float mul = (1ull << bits_per_sample) - 1; + Plane u32_cache; + RunOnPool( + pool, 0, static_cast(ysize), + [&](size_t num_threads) { + u32_cache = Plane(xsize, num_channels * num_threads); + return true; + }, + [&](const int task, int thread) { + const int64_t y = task; + size_t i = stride * y; + const float* JXL_RESTRICT row_in[4]; + size_t c = 0; + for (; c < color_channels; c++) { + row_in[c] = color->PlaneRow(c, y); + } + if (want_alpha) { + row_in[c++] = ib.HasAlpha() ? 
alpha->Row(y) : ones.Row(0);
+          }
+          JXL_ASSERT(c == num_channels);
+          uint32_t* JXL_RESTRICT row_u32[4];
+          for (size_t r = 0; r < c; r++) {
+            row_u32[r] = u32_cache.Row(r + thread * num_channels);
+            HWY_DYNAMIC_DISPATCH(FloatToU32)
+            (row_in[r], row_u32[r], xsize, mul, bits_per_sample);
+          }
+          // TODO(deymo): add bits_per_sample == 1 case here.
+          if (bits_per_sample <= 8) {
+            StoreUintRow<Store8>(row_u32, c, xsize, 1, out + i);
+          } else if (bits_per_sample <= 16) {
+            if (little_endian) {
+              StoreUintRow<StoreLE16>(row_u32, c, xsize, 2, out + i);
+            } else {
+              StoreUintRow<StoreBE16>(row_u32, c, xsize, 2, out + i);
+            }
+          } else if (bits_per_sample <= 24) {
+            if (little_endian) {
+              StoreUintRow<StoreLE24>(row_u32, c, xsize, 3, out + i);
+            } else {
+              StoreUintRow<StoreBE24>(row_u32, c, xsize, 3, out + i);
+            }
+          } else {
+            if (little_endian) {
+              StoreUintRow<StoreLE32>(row_u32, c, xsize, 4, out + i);
+            } else {
+              StoreUintRow<StoreBE32>(row_u32, c, xsize, 4, out + i);
+            }
+          }
+        },
+        "ConvertUint");
+  }
+
+  return true;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
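The ConvertUint path above writes planar rows out channel-interleaved. A standalone sketch of that interleaving for 2 channels x 3 pixels at 1 byte per sample (sample values invented):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Samples arrive as one planar row per channel and are written
// channel-interleaved, pixel by pixel, as StoreUintRow does above.
int main() {
  const uint32_t red[3] = {1, 2, 3};
  const uint32_t green[3] = {40, 50, 60};
  const uint32_t* rows[2] = {red, green};

  uint8_t out[6];
  const size_t num_channels = 2, xsize = 3;
  for (size_t x = 0; x < xsize; ++x) {
    for (size_t c = 0; c < num_channels; ++c) {
      out[num_channels * x + c] = static_cast<uint8_t>(rows[c][x] & 0xff);
    }
  }
  // Interleaved result: R0 G0 R1 G1 R2 G2.
  assert(out[0] == 1 && out[1] == 40 && out[4] == 3 && out[5] == 60);
  return 0;
}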
diff --git a/third_party/jpeg-xl/lib/jxl/dec_external_image.h b/third_party/jpeg-xl/lib/jxl/dec_external_image.h
new file mode 100644
index 000000000000..94f495d6501c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_external_image.h
@@ -0,0 +1,52 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_EXTERNAL_IMAGE_H_
+#define LIB_JXL_DEC_EXTERNAL_IMAGE_H_
+
+// Interleaved image for color transforms and Codec.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "jxl/types.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+
+// Converts ib to interleaved void* pixel buffer with the given format.
+// bits_per_sample: must be 8, 16 or 32, and must be 32 if float_out
+// is true. 1 and 32 int are not yet implemented.
+// num_channels: must be 1, 2, 3 or 4 for gray, gray+alpha, RGB, RGB+alpha.
+// This supports the features needed for the C API and does not perform
+// color space conversion.
+// TODO(lode): support 1-bit output (bits_per_sample == 1)
+// TODO(lode): support rectangle crop.
+// stride_out is output scanline size in bytes, must be >=
+// output_xsize * bytes_per_pixel.
+// undo_orientation is an EXIF orientation to undo. Depending on the
+// orientation, the output xsize and ysize are swapped compared to input
+// xsize and ysize.
+Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample,
+                         bool float_out, size_t num_channels,
+                         JxlEndianness endianness, size_t stride_out,
+                         jxl::ThreadPool* thread_pool, void* out_image,
+                         size_t out_size, jxl::Orientation undo_orientation);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_EXTERNAL_IMAGE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_external_image_gbench.cc b/third_party/jpeg-xl/lib/jxl/dec_external_image_gbench.cc
new file mode 100644
index 000000000000..6d7e7f9011ca
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_external_image_gbench.cc
@@ -0,0 +1,64 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "benchmark/benchmark.h"
+#include "lib/jxl/dec_external_image.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+namespace {
+
+// Decoder case, interleaves an internal float image.
+void BM_DecExternalImage_ConvertImageRGBA(benchmark::State& state) {
+  const size_t kNumIter = 5;
+  size_t xsize = state.range();
+  size_t ysize = state.range();
+  size_t num_channels = 4;
+
+  ImageMetadata im;
+  im.SetAlphaBits(8);
+  ImageBundle ib(&im);
+  Image3F color(xsize, ysize);
+  ZeroFillImage(&color);
+  ib.SetFromImage(std::move(color), ColorEncoding::SRGB());
+  ImageF alpha(xsize, ysize);
+  ZeroFillImage(&alpha);
+  ib.SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false);
+
+  const size_t bytes_per_row = xsize * num_channels;
+  std::vector<uint8_t> interleaved(bytes_per_row * ysize);
+
+  for (auto _ : state) {
+    for (size_t i = 0; i < kNumIter; ++i) {
+      JXL_CHECK(ConvertToExternal(
+          ib,
+          /*bits_per_sample=*/8,
+          /*float_out=*/false, num_channels, JXL_NATIVE_ENDIAN,
+          /*stride*/ bytes_per_row,
+          /*thread_pool=*/nullptr, interleaved.data(), interleaved.size(),
+          /*undo_orientation=*/jxl::Orientation::kIdentity));
+    }
+  }
+
+  // Pixels per second.
+  state.SetItemsProcessed(kNumIter * state.iterations() * xsize * ysize);
+  state.SetBytesProcessed(kNumIter * state.iterations() * interleaved.size());
+}
+
+BENCHMARK(BM_DecExternalImage_ConvertImageRGBA)
+    ->RangeMultiplier(2)
+    ->Range(256, 2048);
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_file.cc b/third_party/jpeg-xl/lib/jxl/dec_file.cc
new file mode 100644
index 000000000000..f33cfe17212f
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_file.cc
@@ -0,0 +1,192 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lib/jxl/dec_file.h" + +#include + +#include +#include + +#include "jxl/decode.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" + +namespace jxl { +namespace { + +Status DecodeHeaders(BitReader* reader, CodecInOut* io) { + JXL_RETURN_IF_ERROR(ReadSizeHeader(reader, &io->metadata.size)); + + JXL_RETURN_IF_ERROR(ReadImageMetadata(reader, &io->metadata.m)); + + io->metadata.transform_data.nonserialized_xyb_encoded = + io->metadata.m.xyb_encoded; + JXL_RETURN_IF_ERROR(Bundle::Read(reader, &io->metadata.transform_data)); + + return true; +} + +} // namespace + +Status DecodePreview(const DecompressParams& dparams, + const CodecMetadata& metadata, + BitReader* JXL_RESTRICT reader, ThreadPool* pool, + ImageBundle* JXL_RESTRICT preview, uint64_t* dec_pixels, + const SizeConstraints* constraints) { + // No preview present in file. + if (!metadata.m.have_preview) { + if (dparams.preview == Override::kOn) { + return JXL_FAILURE("preview == kOn but no preview present"); + } + return true; + } + + // Have preview; prepare to skip or read it. + JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + + if (dparams.preview == Override::kOff) { + JXL_RETURN_IF_ERROR(SkipFrame(metadata, reader, /*is_preview=*/true)); + return true; + } + + // Else: default or kOn => decode preview. + PassesDecoderState dec_state; + JXL_RETURN_IF_ERROR(dec_state.output_encoding_info.Set(metadata.m)); + JXL_RETURN_IF_ERROR(DecodeFrame(dparams, &dec_state, pool, reader, preview, + metadata, constraints, + /*is_preview=*/true)); + if (dec_pixels) { + *dec_pixels += dec_state.shared->frame_dim.xsize_upsampled * + dec_state.shared->frame_dim.ysize_upsampled; + } + return true; +} + +// To avoid the complexity of file I/O and buffering, we assume the bitstream +// is loaded (or for large images/sequences: mapped into) memory. +Status DecodeFile(const DecompressParams& dparams, + const Span file, CodecInOut* JXL_RESTRICT io, + ThreadPool* pool) { + PROFILER_ZONE("DecodeFile uninstrumented"); + + // Marker + JxlSignature signature = JxlSignatureCheck(file.data(), file.size()); + if (signature == JXL_SIG_NOT_ENOUGH_BYTES || signature == JXL_SIG_INVALID) { + return JXL_FAILURE("File does not start with known JPEG XL signature"); + } + + std::unique_ptr jpeg_data = nullptr; + if (dparams.keep_dct) { + if (io->Main().jpeg_data == nullptr) { + return JXL_FAILURE("Caller must set jpeg_data"); + } + jpeg_data = std::move(io->Main().jpeg_data); + } + + Status ret = true; + { + BitReader reader(file); + BitReaderScopedCloser reader_closer(&reader, &ret); + (void)reader.ReadFixedBits<16>(); // skip marker + + { + JXL_RETURN_IF_ERROR(DecodeHeaders(&reader, io)); + size_t xsize = io->metadata.xsize(); + size_t ysize = io->metadata.ysize(); + JXL_RETURN_IF_ERROR(VerifyDimensions(&io->constraints, xsize, ysize)); + } + + if (io->metadata.m.color_encoding.WantICC()) { + PaddedBytes icc; + JXL_RETURN_IF_ERROR(ReadICC(&reader, &icc)); + JXL_RETURN_IF_ERROR(io->metadata.m.color_encoding.SetICC(std::move(icc))); + } + // Set ICC profile in jpeg_data. 
+ if (jpeg_data) { + Status res = jpeg::SetJPEGDataFromICC(io->metadata.m.color_encoding.ICC(), + jpeg_data.get()); + if (!res) { + return res; + } + } + + JXL_RETURN_IF_ERROR(DecodePreview(dparams, io->metadata, &reader, pool, + &io->preview_frame, &io->dec_pixels, + &io->constraints)); + + // Only necessary if no ICC and no preview. + JXL_RETURN_IF_ERROR(reader.JumpToByteBoundary()); + if (io->metadata.m.have_animation && dparams.keep_dct) { + return JXL_FAILURE("Cannot decode to JPEG an animation"); + } + + PassesDecoderState dec_state; + JXL_RETURN_IF_ERROR(dec_state.output_encoding_info.Set(io->metadata.m)); + + io->frames.clear(); + Status dec_ok(false); + do { + io->frames.emplace_back(&io->metadata.m); + if (jpeg_data) { + io->frames.back().jpeg_data = std::move(jpeg_data); + } + // Skip frames that are not displayed. + do { + dec_ok = + DecodeFrame(dparams, &dec_state, pool, &reader, &io->frames.back(), + io->metadata, &io->constraints); + if (!dparams.allow_partial_files) { + JXL_RETURN_IF_ERROR(dec_ok); + } else if (!dec_ok) { + io->frames.pop_back(); + break; + } + } while (dec_state.shared->frame_header.frame_type != + FrameType::kRegularFrame && + dec_state.shared->frame_header.frame_type != + FrameType::kSkipProgressive); + io->dec_pixels += io->frames.back().xsize() * io->frames.back().ysize(); + } while (!dec_state.shared->frame_header.is_last && dec_ok); + + if (io->frames.empty()) return JXL_FAILURE("Not enough data."); + + if (dparams.check_decompressed_size && !dparams.allow_partial_files && + dparams.max_downsampling == 1) { + if (reader.TotalBitsConsumed() != file.size() * kBitsPerByte) { + return JXL_FAILURE("DecodeFile reader position not at EOF."); + } + } + // Suppress errors when decoding partial files with DC frames. + if (!reader.AllReadsWithinBounds() && dparams.allow_partial_files) { + (void)reader.Close(); + } + + io->CheckMetadata(); + // reader is closed here. + } + return ret; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_file.h b/third_party/jpeg-xl/lib/jxl/dec_file.h new file mode 100644 index 000000000000..2077ce02c8db --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_file.h @@ -0,0 +1,57 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_DEC_FILE_H_ +#define LIB_JXL_DEC_FILE_H_ + +// Top-level interface for JXL decoding. + +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/dec_params.h" + +namespace jxl { + +// Decodes the preview image, if present, and stores it in `preview`. +// Must be the first frame in the file. Does nothing if there is no preview +// frame present according to the metadata. 
+Status DecodePreview(const DecompressParams& dparams, + const CodecMetadata& metadata, + BitReader* JXL_RESTRICT reader, ThreadPool* pool, + ImageBundle* JXL_RESTRICT preview, uint64_t* dec_pixels, + const SizeConstraints* constraints); + +// Implementation detail: currently decodes to linear sRGB. The contract is: +// `io` appears 'identical' (modulo compression artifacts) to the encoder input +// in a color-aware viewer. Note that `io->metadata.m.color_encoding` +// identifies the color space that was passed to the encoder; clients that want +// that same encoding must call `io->TransformTo` afterwards. +Status DecodeFile(const DecompressParams& params, + const Span file, CodecInOut* io, + ThreadPool* pool = nullptr); + +static inline Status DecodeFile(const DecompressParams& params, + const PaddedBytes& file, CodecInOut* io, + ThreadPool* pool = nullptr) { + return DecodeFile(params, Span(file), io, pool); +} + +} // namespace jxl + +#endif // LIB_JXL_DEC_FILE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_frame.cc b/third_party/jpeg-xl/lib/jxl/dec_frame.cc new file mode 100644 index 000000000000..d994ffe90329 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_frame.cc @@ -0,0 +1,907 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lib/jxl/dec_frame.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_group.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/dec_params.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/dec_reconstruct.h" +#include "lib/jxl/dec_upsample.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/filters.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/luminance.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/splines.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +Status DecodeGlobalDCInfo(BitReader* reader, bool is_jpeg, + PassesDecoderState* state, ThreadPool* pool) { + PROFILER_FUNC; + JXL_RETURN_IF_ERROR(state->shared_storage.quantizer.Decode(reader)); + + JXL_RETURN_IF_ERROR( + DecodeBlockCtxMap(reader, &state->shared_storage.block_ctx_map)); + + JXL_RETURN_IF_ERROR(state->shared_storage.cmap.DecodeDC(reader)); + + // Pre-compute info for decoding a group. + if (is_jpeg) { + state->shared_storage.quantizer.ClearDCMul(); // Don't dequant DC + } + + state->shared_storage.ac_strategy.FillInvalid(); + return true; +} +} // namespace + +Status DecodeFrameHeader(BitReader* JXL_RESTRICT reader, + FrameHeader* JXL_RESTRICT frame_header) { + JXL_ASSERT(frame_header->nonserialized_metadata != nullptr); + JXL_RETURN_IF_ERROR(ReadFrameHeader(reader, frame_header)); + + if (frame_header->encoding == FrameEncoding::kModular) { + if (frame_header->chroma_subsampling.MaxHShift() != 0 || + frame_header->chroma_subsampling.MaxVShift() != 0) { + return JXL_FAILURE("Chroma subsampling in modular mode is not supported"); + } + } + + return true; +} + +Status SkipFrame(const CodecMetadata& metadata, BitReader* JXL_RESTRICT reader, + bool is_preview) { + FrameHeader header(&metadata); + header.nonserialized_is_preview = is_preview; + JXL_RETURN_IF_ERROR(DecodeFrameHeader(reader, &header)); + + // Read TOC. + std::vector group_offsets; + std::vector group_sizes; + uint64_t groups_total_size; + const bool has_ac_global = true; + const FrameDimensions frame_dim = header.ToFrameDimensions(); + const size_t toc_entries = + NumTocEntries(frame_dim.num_groups, frame_dim.num_dc_groups, + header.passes.num_passes, has_ac_global); + JXL_RETURN_IF_ERROR(ReadGroupOffsets(toc_entries, reader, &group_offsets, + &group_sizes, &groups_total_size)); + + // Pretend all groups are read. 
+  reader->SkipBits(groups_total_size * kBitsPerByte);
+  if (reader->TotalBitsConsumed() > reader->TotalBytes() * kBitsPerByte) {
+    return JXL_FAILURE("Group code extends after stream end");
+  }
+
+  return true;
+}
+
+static BitReader* GetReaderForSection(
+    size_t num_groups, size_t num_passes, size_t group_codes_begin,
+    const std::vector<uint64_t>& group_offsets,
+    const std::vector<uint32_t>& group_sizes, BitReader* JXL_RESTRICT reader,
+    BitReader* JXL_RESTRICT store, size_t index) {
+  if (num_groups == 1 && num_passes == 1) return reader;
+  const size_t group_offset = group_codes_begin + group_offsets[index];
+  const size_t next_group_offset =
+      group_codes_begin + group_offsets[index] + group_sizes[index];
+  // The order of these variables must be:
+  // group_codes_begin <= group_offset <= next_group_offset <= file.size()
+  JXL_DASSERT(group_codes_begin <= group_offset);
+  JXL_DASSERT(group_offset <= next_group_offset);
+  JXL_DASSERT(next_group_offset <= reader->TotalBytes());
+  const size_t group_size = next_group_offset - group_offset;
+  const size_t remaining_size = reader->TotalBytes() - group_offset;
+  const size_t size = std::min(group_size + 8, remaining_size);
+  *store = BitReader(
+      Span<const uint8_t>(reader->FirstByte() + group_offset, size));
+  return store;
+}
+
+Status DecodeFrame(const DecompressParams& dparams,
+                   PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool,
+                   BitReader* JXL_RESTRICT reader, ImageBundle* decoded,
+                   const CodecMetadata& metadata,
+                   const SizeConstraints* constraints, bool is_preview) {
+  PROFILER_ZONE("DecodeFrame uninstrumented");
+
+  FrameDecoder frame_decoder(dec_state, metadata, pool);
+
+  frame_decoder.SetFrameSizeLimits(constraints);
+
+  JXL_RETURN_IF_ERROR(frame_decoder.InitFrame(
+      reader, decoded, is_preview, dparams.allow_partial_files,
+      dparams.allow_partial_files && dparams.allow_more_progressive_steps));
+
+  // Handling of progressive decoding.
+  {
+    const FrameHeader& frame_header = frame_decoder.GetFrameHeader();
+    size_t max_passes = dparams.max_passes;
+    size_t max_downsampling = std::max(
+        dparams.max_downsampling >> (frame_header.dc_level * 3), size_t(1));
+    // TODO(veluca): deal with downsamplings >= 8.
+    if (max_downsampling >= 8) {
+      max_passes = 0;
+    } else {
+      for (uint32_t i = 0; i < frame_header.passes.num_downsample; ++i) {
+        if (max_downsampling >= frame_header.passes.downsample[i] &&
+            max_passes > frame_header.passes.last_pass[i]) {
+          max_passes = frame_header.passes.last_pass[i] + 1;
+        }
+      }
+    }
+    // Do not use downsampling for kReferenceOnly frames.
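+    // A kReferenceOnly frame is stored so later frames can reference it
+    // (e.g. as a patch source), so any pass limit chosen above would leak
+    // into every frame that uses it; it is therefore always decoded with all
+    // of its passes. Sketch of the resulting policy (illustrative only, the
+    // statements below are authoritative):
+    const auto sketch_pass_cap = [&frame_header](bool reference_only,
+                                                 size_t selected) {
+      return reference_only
+                 ? size_t(frame_header.passes.num_passes)
+                 : std::min<size_t>(selected, frame_header.passes.num_passes);
+    };
+    (void)sketch_pass_cap;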
+    if (frame_header.frame_type == FrameType::kReferenceOnly) {
+      max_passes = frame_header.passes.num_passes;
+    }
+    max_passes = std::min(max_passes, frame_header.passes.num_passes);
+    frame_decoder.SetMaxPasses(max_passes);
+  }
+
+  size_t processed_bytes = reader->TotalBitsConsumed() / kBitsPerByte;
+
+  Status close_ok = true;
+  std::vector<std::unique_ptr<BitReader>> section_readers;
+  {
+    std::vector<std::unique_ptr<BitReaderScopedCloser>> section_closers;
+    std::vector<FrameDecoder::SectionInfo> section_info;
+    std::vector<FrameDecoder::SectionStatus> section_status;
+    size_t bytes_to_skip = 0;
+    for (size_t i = 0; i < frame_decoder.NumSections(); i++) {
+      size_t b = frame_decoder.SectionOffsets()[i];
+      size_t e = b + frame_decoder.SectionSizes()[i];
+      bytes_to_skip += e - b;
+      size_t pos = reader->TotalBitsConsumed() / kBitsPerByte;
+      if (pos + e <= reader->TotalBytes()) {
+        auto br = make_unique<BitReader>(
+            Span<const uint8_t>(reader->FirstByte() + b + pos, e - b));
+        section_info.emplace_back(FrameDecoder::SectionInfo{br.get(), i});
+        section_closers.emplace_back(
+            make_unique<BitReaderScopedCloser>(br.get(), &close_ok));
+        section_readers.emplace_back(std::move(br));
+      } else if (!dparams.allow_partial_files) {
+        return JXL_FAILURE("Premature end of stream.");
+      }
+    }
+    // Skip over the to-be-decoded sections.
+    reader->SkipBits(kBitsPerByte * bytes_to_skip);
+    section_status.resize(section_info.size());
+
+    JXL_RETURN_IF_ERROR(frame_decoder.ProcessSections(
+        section_info.data(), section_info.size(), section_status.data()));
+
+    for (size_t i = 0; i < section_status.size(); i++) {
+      auto s = section_status[i];
+      if (s == FrameDecoder::kDone) {
+        processed_bytes += frame_decoder.SectionSizes()[i];
+        continue;
+      }
+      if (dparams.allow_more_progressive_steps &&
+          s == FrameDecoder::kPartial) {
+        continue;
+      }
+      if (dparams.max_downsampling > 1 && s == FrameDecoder::kSkipped) {
+        continue;
+      }
+      return JXL_FAILURE("Invalid section %zu status: %d", section_info[i].id,
+                         s);
+    }
+  }
+
+  JXL_RETURN_IF_ERROR(close_ok);
+
+  JXL_RETURN_IF_ERROR(frame_decoder.FinalizeFrame());
+  decoded->SetDecodedBytes(processed_bytes);
+  return true;
+}
+
+Status FrameDecoder::InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded,
+                               bool is_preview, bool allow_partial_frames,
+                               bool allow_partial_dc_global) {
+  PROFILER_FUNC;
+  decoded_ = decoded;
+  JXL_ASSERT(is_finalized_);
+
+  allow_partial_frames_ = allow_partial_frames;
+  allow_partial_dc_global_ = allow_partial_dc_global;
+
+  // Reset the dequantization matrices to their default values.
+  dec_state_->shared_storage.matrices = DequantMatrices();
+
+  frame_header_.nonserialized_is_preview = is_preview;
+  JXL_RETURN_IF_ERROR(DecodeFrameHeader(br, &frame_header_));
+  frame_dim_ = frame_header_.ToFrameDimensions();
+
+  const size_t num_passes = frame_header_.passes.num_passes;
+  const size_t xsize = frame_dim_.xsize;
+  const size_t ysize = frame_dim_.ysize;
+  const size_t num_groups = frame_dim_.num_groups;
+
+  // Check validity of frame dimensions.
+  JXL_RETURN_IF_ERROR(VerifyDimensions(constraints_, xsize, ysize));
+
+  // If the previous frame was not a kRegularFrame, `decoded` may have
+  // different dimensions; must reset to avoid errors.
+  decoded->RemoveColor();
+  decoded->ClearExtraChannels();
+
+  // Read TOC.
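+  // The TOC stores one (offset, size) pair per section; offsets are relative
+  // to the first byte after the TOC itself. Illustrative sketch of the
+  // absolute byte range of section i (GetReaderForSection above is the
+  // authoritative version, including its slack-byte handling):
+  const auto sketch_section_range = [](size_t group_codes_begin,
+                                       uint64_t offset, uint32_t size) {
+    const size_t begin = group_codes_begin + offset;
+    return std::make_pair(begin, begin + size);  // [begin, end)
+  };
+  (void)sketch_section_range;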
+  uint64_t groups_total_size;
+  const bool has_ac_global = true;
+  const size_t toc_entries = NumTocEntries(num_groups, frame_dim_.num_dc_groups,
+                                           num_passes, has_ac_global);
+  JXL_RETURN_IF_ERROR(ReadGroupOffsets(toc_entries, br, &section_offsets_,
+                                       &section_sizes_, &groups_total_size));
+
+  JXL_DASSERT((br->TotalBitsConsumed() % kBitsPerByte) == 0);
+  const size_t group_codes_begin = br->TotalBitsConsumed() / kBitsPerByte;
+  JXL_DASSERT(!section_offsets_.empty());
+
+  // Overflow check.
+  if (group_codes_begin + groups_total_size < group_codes_begin) {
+    return JXL_FAILURE("Invalid group codes");
+  }
+
+  if (!frame_header_.chroma_subsampling.Is444() &&
+      !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) &&
+      frame_header_.encoding == FrameEncoding::kVarDCT) {
+    // TODO(veluca): actually implement this.
+    return JXL_FAILURE(
+        "Non-444 chroma subsampling is not supported when adaptive DC "
+        "smoothing is enabled");
+  }
+  JXL_RETURN_IF_ERROR(
+      InitializePassesSharedState(frame_header_, &dec_state_->shared_storage));
+  dec_state_->Init();
+  modular_frame_decoder_.Init(frame_dim_);
+
+  if (decoded->IsJPEG()) {
+    if (frame_header_.encoding == FrameEncoding::kModular) {
+      return JXL_FAILURE("Cannot output JPEG from Modular");
+    }
+    jpeg::JPEGData* jpeg_data = decoded->jpeg_data.get();
+    if (jpeg_data->components.size() != 1 &&
+        jpeg_data->components.size() != 3) {
+      return JXL_FAILURE("Invalid number of components");
+    }
+    decoded->jpeg_data->width = frame_dim_.xsize;
+    decoded->jpeg_data->height = frame_dim_.ysize;
+    if (jpeg_data->components.size() == 1) {
+      jpeg_data->components[0].width_in_blocks = frame_dim_.xsize_blocks;
+      jpeg_data->components[0].height_in_blocks = frame_dim_.ysize_blocks;
+    } else {
+      for (size_t c = 0; c < 3; c++) {
+        jpeg_data->components[c < 2 ? c ^ 1 : c].width_in_blocks =
+            frame_dim_.xsize_blocks >>
+            frame_header_.chroma_subsampling.HShift(c);
+        jpeg_data->components[c < 2 ? c ^ 1 : c].height_in_blocks =
+            frame_dim_.ysize_blocks >>
+            frame_header_.chroma_subsampling.VShift(c);
+      }
+    }
+    for (size_t c = 0; c < jpeg_data->components.size(); c++) {
+      jpeg_data->components[c].h_samp_factor =
+          1 << frame_header_.chroma_subsampling.RawHShift(c < 2 ? c ^ 1 : c);
+      jpeg_data->components[c].v_samp_factor =
+          1 << frame_header_.chroma_subsampling.RawVShift(c < 2 ? c ^ 1 : c);
+    }
+    for (auto& v : jpeg_data->components) {
+      v.coeffs.resize(v.width_in_blocks * v.height_in_blocks *
+                      jxl::kDCTBlockSize);
+    }
+  }
+
+  dec_state_->upsampler.Init(
+      frame_header_.upsampling,
+      frame_header_.nonserialized_metadata->m.transform_data);
+
+  // Clear the state.
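+  // Sections may be delivered to ProcessSections out of order, in several
+  // batches, or more than once, so every piece of per-frame progress
+  // tracking is reset here. Sketch of the duplicate-section guard this
+  // enables (illustrative; ProcessSections is the authoritative version):
+  const auto sketch_mark_section = [](std::vector<uint8_t>* seen, size_t id) {
+    if ((*seen)[id]) return false;  // already processed -> kDuplicate
+    (*seen)[id] = 1;
+    return true;
+  };
+  (void)sketch_mark_section;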
+  decoded_dc_global_ = false;
+  decoded_ac_global_ = false;
+  is_finalized_ = false;
+  finalized_dc_ = false;
+  decoded_dc_groups_.clear();
+  decoded_dc_groups_.resize(frame_dim_.num_dc_groups);
+  decoded_passes_per_ac_group_.clear();
+  decoded_passes_per_ac_group_.resize(frame_dim_.num_groups, 0);
+  processed_section_.clear();
+  processed_section_.resize(section_offsets_.size());
+  max_passes_ = frame_header_.passes.num_passes;
+  num_renders_ = 0;
+
+  return true;
+}
+
+Status FrameDecoder::ProcessDCGlobal(BitReader* br) {
+  PROFILER_FUNC;
+  PassesSharedState& shared = dec_state_->shared_storage;
+  if (shared.frame_header.flags & FrameHeader::kPatches) {
+    JXL_RETURN_IF_ERROR(shared.image_features.patches.Decode(
+        br, frame_dim_.xsize_padded, frame_dim_.ysize_padded));
+  }
+  if (shared.frame_header.flags & FrameHeader::kSplines) {
+    JXL_RETURN_IF_ERROR(shared.image_features.splines.Decode(
+        br, frame_dim_.xsize * frame_dim_.ysize));
+  }
+  if (shared.frame_header.flags & FrameHeader::kNoise) {
+    JXL_RETURN_IF_ERROR(DecodeNoise(br, &shared.image_features.noise_params));
+  }
+
+  JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.DecodeDC(br));
+  if (frame_header_.encoding == FrameEncoding::kVarDCT) {
+    JXL_RETURN_IF_ERROR(
+        jxl::DecodeGlobalDCInfo(br, decoded_->IsJPEG(), dec_state_, pool_));
+  }
+  Status dec_status = modular_frame_decoder_.DecodeGlobalInfo(
+      br, frame_header_, allow_partial_dc_global_);
+  if (dec_status.IsFatalError()) return dec_status;
+  if (dec_status) {
+    decoded_dc_global_ = true;
+  }
+  return dec_status;
+}
+
+Status FrameDecoder::ProcessDCGroup(size_t dc_group_id, BitReader* br) {
+  PROFILER_FUNC;
+  const size_t gx = dc_group_id % frame_dim_.xsize_dc_groups;
+  const size_t gy = dc_group_id / frame_dim_.xsize_dc_groups;
+  if (frame_header_.encoding == FrameEncoding::kVarDCT &&
+      !(frame_header_.flags & FrameHeader::kUseDcFrame)) {
+    JXL_RETURN_IF_ERROR(
+        modular_frame_decoder_.DecodeVarDCTDC(dc_group_id, br, dec_state_));
+  }
+  const Rect mrect(gx * kDcGroupDim, gy * kDcGroupDim, kDcGroupDim,
+                   kDcGroupDim);
+  JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
+      mrect, br, 3, 1000, ModularStreamId::ModularDC(dc_group_id),
+      /*zerofill=*/false));
+  if (frame_header_.encoding == FrameEncoding::kVarDCT) {
+    JXL_RETURN_IF_ERROR(
+        modular_frame_decoder_.DecodeAcMetadata(dc_group_id, br, dec_state_));
+  }
+  decoded_dc_groups_[dc_group_id] = true;
+  return true;
+}
+
+void FrameDecoder::FinalizeDC() {
+  // Do adaptive DC smoothing if enabled. This *must* happen between all the
+  // ProcessDCGroup calls and the ProcessACGroup calls.
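+  // The reason for the ordering: the smoothing kernel reads each DC sample's
+  // neighbours, which may belong to a different DC group, so it needs all DC
+  // groups decoded; and AC reconstruction consumes the smoothed DC, so it
+  // must not start earlier. A minimal sketch of such a neighbourhood pass
+  // (illustrative only; AdaptiveDCSmoothing is the real, quantization-aware
+  // kernel):
+  const auto sketch_blur3x3 = [](const float* above, const float* mid,
+                                 const float* below, size_t x) {
+    return (above[x - 1] + above[x] + above[x + 1] + mid[x - 1] + mid[x] +
+            mid[x + 1] + below[x - 1] + below[x] + below[x + 1]) /
+           9.0f;
+  };
+  (void)sketch_blur3x3;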
+  if (frame_header_.encoding == FrameEncoding::kVarDCT &&
+      !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) &&
+      !(frame_header_.flags & FrameHeader::kUseDcFrame)) {
+    AdaptiveDCSmoothing(dec_state_->shared->quantizer.MulDC(),
+                        &dec_state_->shared_storage.dc_storage, pool_);
+  }
+
+  finalized_dc_ = true;
+}
+
+void FrameDecoder::AllocateOutput() {
+  const CodecMetadata& metadata = *frame_header_.nonserialized_metadata;
+  if (dec_state_->rgb_output == nullptr) {
+    decoded_->SetFromImage(Image3F(frame_dim_.xsize_upsampled_padded,
+                                   frame_dim_.ysize_upsampled_padded),
+                           dec_state_->output_encoding_info.color_encoding);
+  }
+  if (metadata.m.num_extra_channels > 0) {
+    for (size_t i = 0; i < metadata.m.num_extra_channels; i++) {
+      const auto eci = metadata.m.extra_channel_info[i];
+      dec_state_->extra_channels.emplace_back(
+          eci.Size(frame_dim_.xsize_upsampled_padded),
+          eci.Size(frame_dim_.ysize_upsampled_padded));
+#if MEMORY_SANITIZER
+      // Avoid errors due to loading vectors on the outermost padding.
+      for (size_t y = 0; y < eci.Size(frame_dim_.ysize_upsampled_padded);
+           y++) {
+        for (size_t x = eci.Size(frame_dim_.xsize_upsampled);
+             x < eci.Size(frame_dim_.xsize_upsampled_padded); x++) {
+          dec_state_->extra_channels.back().Row(y)[x] = 0;
+        }
+      }
+#endif
+    }
+  }
+  decoded_->origin = dec_state_->shared->frame_header.frame_origin;
+}
+
+Status FrameDecoder::ProcessACGlobal(BitReader* br) {
+  JXL_CHECK(finalized_dc_);
+  JXL_CHECK(decoded_->HasColor() || dec_state_->rgb_output != nullptr);
+
+  // Decode AC global information.
+  if (frame_header_.encoding == FrameEncoding::kVarDCT) {
+    JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.Decode(
+        br, &modular_frame_decoder_));
+
+    size_t num_histo_bits =
+        CeilLog2Nonzero(dec_state_->shared->frame_dim.num_groups);
+    dec_state_->shared_storage.num_histograms =
+        1 + br->ReadBits(num_histo_bits);
+
+    dec_state_->code.resize(kMaxNumPasses);
+    dec_state_->context_map.resize(kMaxNumPasses);
+    // Read coefficient orders and histograms.
+    size_t max_num_bits_ac = 0;
+    for (size_t i = 0;
+         i < dec_state_->shared_storage.frame_header.passes.num_passes; i++) {
+      uint16_t used_orders = U32Coder::Read(kOrderEnc, br);
+      JXL_RETURN_IF_ERROR(DecodeCoeffOrders(
+          used_orders, dec_state_->used_acs,
+          &dec_state_->shared_storage
+               .coeff_orders[i * dec_state_->shared_storage.coeff_order_size],
+          br));
+      size_t num_contexts =
+          dec_state_->shared->num_histograms *
+          dec_state_->shared_storage.block_ctx_map.NumACContexts();
+      JXL_RETURN_IF_ERROR(DecodeHistograms(
+          br, num_contexts, &dec_state_->code[i], &dec_state_->context_map[i]));
+      // Add extra values to enable the cheat in the hot loop of
+      // DecodeACVarBlock.
+      dec_state_->context_map[i].resize(
+          num_contexts + kZeroDensityContextLimit - kZeroDensityContextCount);
+      max_num_bits_ac =
+          std::max(max_num_bits_ac, dec_state_->code[i].max_num_bits);
+    }
+    max_num_bits_ac += CeilLog2Nonzero(
+        dec_state_->shared_storage.frame_header.passes.num_passes);
+    // 16-bit buffers for decoding to JPEG are not implemented.
+    // TODO(veluca): figure out the exact limit - 16 should still work with
+    // 16-bit buffers, but we are excluding it for safety.
+    bool use_16_bit = max_num_bits_ac < 16 && !decoded_->IsJPEG();
+    bool store = frame_header_.passes.num_passes > 1;
+    size_t xs = store ? kGroupDim * kGroupDim : 0;
+    size_t ys = store ? frame_dim_.num_groups : 0;
+    if (use_16_bit) {
+      dec_state_->coefficients = make_unique<ACImageT<int16_t>>(xs, ys);
+    } else {
+      dec_state_->coefficients = make_unique<ACImageT<int32_t>>(xs, ys);
+    }
+    if (store) {
+      dec_state_->coefficients->ZeroFill();
+    }
+  }
+
+  // Set JPEG decoding data.
+  if (decoded_->IsJPEG()) {
+    decoded_->color_transform = frame_header_.color_transform;
+    decoded_->chroma_subsampling = frame_header_.chroma_subsampling;
+    const std::vector<QuantEncoding>& qe =
+        dec_state_->shared_storage.matrices.encodings();
+    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
+        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
+      return JXL_FAILURE(
+          "Quantization table is not a JPEG quantization table.");
+    }
+    auto jpeg_c_map = JpegOrder(frame_header_.color_transform,
+                                decoded_->jpeg_data->components.size() == 1);
+    for (size_t c = 0; c < 3; c++) {
+      if (c != 1 && decoded_->jpeg_data->components.size() == 1) {
+        continue;
+      }
+      size_t jpeg_channel = jpeg_c_map[c];
+      size_t qpos = decoded_->jpeg_data->components[jpeg_channel].quant_idx;
+      JXL_CHECK(qpos != decoded_->jpeg_data->quant.size());
+      for (size_t x = 0; x < 8; x++) {
+        for (size_t y = 0; y < 8; y++) {
+          decoded_->jpeg_data->quant[qpos].values[x * 8 + y] =
+              (*qe[0].qraw.qtable)[c * 64 + y * 8 + x];
+        }
+      }
+    }
+  }
+  // Set memory buffer for pre-color-transform frame, if needed.
+  if (frame_header_.needs_color_transform() &&
+      frame_header_.save_before_color_transform) {
+    dec_state_->pre_color_transform_frame = Image3F(
+        frame_dim_.xsize_upsampled_padded, frame_dim_.ysize_upsampled_padded);
+  } else {
+    // Clear pre_color_transform_frame to ensure that previously moved-from
+    // images are not used.
+    dec_state_->pre_color_transform_frame = Image3F();
+  }
+  decoded_ac_global_ = true;
+  return true;
+}
+
+Status FrameDecoder::ProcessACGroup(size_t ac_group_id,
+                                    BitReader* JXL_RESTRICT* br,
+                                    size_t num_passes, size_t thread,
+                                    bool force_draw, bool dc_only) {
+  PROFILER_ZONE("process_group");
+  const size_t gx = ac_group_id % frame_dim_.xsize_groups;
+  const size_t gy = ac_group_id / frame_dim_.xsize_groups;
+  const size_t x = gx * frame_dim_.group_dim;
+  const size_t y = gy * frame_dim_.group_dim;
+
+  if (frame_header_.encoding == FrameEncoding::kVarDCT) {
+    group_dec_caches_[thread].InitOnce(frame_header_.passes.num_passes,
+                                       dec_state_->used_acs);
+    JXL_RETURN_IF_ERROR(DecodeGroup(
+        br, num_passes, ac_group_id, dec_state_, &group_dec_caches_[thread],
+        thread, decoded_, decoded_passes_per_ac_group_[ac_group_id],
+        force_draw, dc_only));
+  }
+
+  // Don't limit to image dimensions here (it is done in DecodeGroup).
+  const Rect mrect(x, y, frame_dim_.group_dim, frame_dim_.group_dim);
+  for (size_t i = 0; i < frame_header_.passes.num_passes; i++) {
+    int minShift, maxShift;
+    frame_header_.passes.GetDownsamplingBracket(i, minShift, maxShift);
+    if (i >= decoded_passes_per_ac_group_[ac_group_id] &&
+        i < decoded_passes_per_ac_group_[ac_group_id] + num_passes) {
+      JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
+          mrect, br[i - decoded_passes_per_ac_group_[ac_group_id]], minShift,
+          maxShift, ModularStreamId::ModularAC(ac_group_id, i),
+          /*zerofill=*/false));
+    } else if (i >= decoded_passes_per_ac_group_[ac_group_id] + num_passes &&
+               force_draw) {
+      JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
+          mrect, nullptr, minShift, maxShift,
+          ModularStreamId::ModularAC(ac_group_id, i), /*zerofill=*/true));
+    }
+  }
+  decoded_passes_per_ac_group_[ac_group_id] += num_passes;
+  return true;
+}
+
+Status FrameDecoder::ProcessSections(const SectionInfo* sections, size_t num,
+                                     SectionStatus* section_status) {
+  if (num == 0) return true;  // Nothing to process
+  std::fill(section_status, section_status + num, SectionStatus::kSkipped);
+  size_t dc_global_sec = num;
+  size_t ac_global_sec = num;
+  std::vector<size_t> dc_group_sec(frame_dim_.num_dc_groups, num);
+  std::vector<std::vector<size_t>> ac_group_sec(
+      frame_dim_.num_groups,
+      std::vector<size_t>(frame_header_.passes.num_passes, num));
+  std::vector<size_t> num_ac_passes(frame_dim_.num_groups);
+  if (frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1) {
+    JXL_ASSERT(num == 1);
+    JXL_ASSERT(sections[0].id == 0);
+    if (processed_section_[0] == false) {
+      processed_section_[0] = true;
+      ac_group_sec[0].resize(1);
+      dc_global_sec = ac_global_sec = dc_group_sec[0] = ac_group_sec[0][0] = 0;
+      num_ac_passes[0] = 1;
+    } else {
+      section_status[0] = SectionStatus::kDuplicate;
+    }
+  } else {
+    size_t ac_global_index = frame_dim_.num_dc_groups + 1;
+    for (size_t i = 0; i < num; i++) {
+      JXL_ASSERT(sections[i].id < processed_section_.size());
+      if (processed_section_[sections[i].id]) {
+        section_status[i] = SectionStatus::kDuplicate;
+        continue;
+      }
+      if (sections[i].id == 0) {
+        dc_global_sec = i;
+      } else if (sections[i].id < ac_global_index) {
+        dc_group_sec[sections[i].id - 1] = i;
+      } else if (sections[i].id == ac_global_index) {
+        ac_global_sec = i;
+      } else {
+        size_t ac_idx = sections[i].id - ac_global_index - 1;
+        size_t acg = ac_idx % frame_dim_.num_groups;
+        size_t acp = ac_idx / frame_dim_.num_groups;
+        if (acp >= frame_header_.passes.num_passes) {
+          return JXL_FAILURE("Invalid section ID");
+        }
+        if (acp >= max_passes_) {
+          continue;
+        }
+        ac_group_sec[acg][acp] = i;
+      }
+      processed_section_[sections[i].id] = true;
+    }
+    // Count number of new passes per group.
+    for (size_t g = 0; g < ac_group_sec.size(); g++) {
+      size_t j = 0;
+      for (; j + decoded_passes_per_ac_group_[g] < max_passes_; j++) {
+        if (ac_group_sec[g][j + decoded_passes_per_ac_group_[g]] == num) {
+          break;
+        }
+      }
+      num_ac_passes[g] = j;
+    }
+  }
+  if (dc_global_sec != num) {
+    Status dc_global_status = ProcessDCGlobal(sections[dc_global_sec].br);
+    if (dc_global_status.IsFatalError()) return dc_global_status;
+    if (dc_global_status) {
+      section_status[dc_global_sec] = SectionStatus::kDone;
+    } else {
+      section_status[dc_global_sec] = SectionStatus::kPartial;
+    }
+  }
+
+  std::atomic<bool> has_error{false};
+  if (decoded_dc_global_) {
+    RunOnPool(
+        pool_, 0, dc_group_sec.size(), ThreadPool::SkipInit(),
+        [this, &dc_group_sec, &num, &sections, &section_status, &has_error](
+            size_t i, size_t thread) {
+          if (dc_group_sec[i] != num) {
+            if (!ProcessDCGroup(i, sections[dc_group_sec[i]].br)) {
+              has_error = true;
+            } else {
+              section_status[dc_group_sec[i]] = SectionStatus::kDone;
+            }
+          }
+        },
+        "DecodeDCGroup");
+  }
+  if (has_error) return JXL_FAILURE("Error in DC group");
+
+  if (*std::min_element(decoded_dc_groups_.begin(),
+                        decoded_dc_groups_.end()) == true &&
+      !finalized_dc_) {
+    FinalizeDC();
+    AllocateOutput();
+  }
+
+  if (finalized_dc_) dec_state_->EnsureBordersStorage();
+  if (finalized_dc_ && ac_global_sec != num && !decoded_ac_global_) {
+    dec_state_->InitForAC(pool_);
+    JXL_RETURN_IF_ERROR(ProcessACGlobal(sections[ac_global_sec].br));
+    section_status[ac_global_sec] = SectionStatus::kDone;
+  }
+
+  if (decoded_ac_global_) {
+    // The decoded image requires padding for filtering. ProcessACGlobal added
+    // the padding, however when Flush is used, the image is shrunk to the
+    // output size. Add the padding back here. This is a cheap operation
+    // since the image has the original allocated size. The memory and
+    // original size are already there, but for safety we require the
+    // indicated xsize and ysize dimensions to match the working area, see
+    // PlaneRowBoundsCheck.
+    decoded_->ShrinkTo(frame_dim_.xsize_upsampled_padded,
+                       frame_dim_.ysize_upsampled_padded);
+
+    // Mark all the AC groups that we received as not complete yet.
+    for (size_t i = 0; i < ac_group_sec.size(); i++) {
+      if (num_ac_passes[i] == 0) continue;
+      dec_state_->group_border_assigner.ClearDone(i);
+    }
+
+    RunOnPool(
+        pool_, 0, ac_group_sec.size(),
+        [this](size_t num_threads) {
+          PrepareStorage(num_threads, decoded_passes_per_ac_group_.size());
+          return true;
+        },
+        [this, &ac_group_sec, &num_ac_passes, &num, &sections,
+         &section_status, &has_error](size_t g, size_t thread) {
+          if (num_ac_passes[g] == 0) {  // no new AC pass, nothing to do.
+            return;
+          }
+          (void)num;
+          size_t first_pass = decoded_passes_per_ac_group_[g];
+          BitReader* JXL_RESTRICT readers[kMaxNumPasses];
+          for (size_t i = 0; i < num_ac_passes[g]; i++) {
+            JXL_ASSERT(ac_group_sec[g][first_pass + i] != num);
+            readers[i] = sections[ac_group_sec[g][first_pass + i]].br;
+          }
+          if (!ProcessACGroup(g, readers, num_ac_passes[g],
+                              GetStorageLocation(thread, g),
+                              /*force_draw=*/false, /*dc_only=*/false)) {
+            has_error = true;
+          } else {
+            for (size_t i = 0; i < num_ac_passes[g]; i++) {
+              section_status[ac_group_sec[g][first_pass + i]] =
+                  SectionStatus::kDone;
+            }
+          }
+        },
+        "DecodeGroup");
+  }
+  if (has_error) return JXL_FAILURE("Error in AC group");
+
+  for (size_t i = 0; i < num; i++) {
+    if (section_status[i] == SectionStatus::kSkipped ||
+        section_status[i] == SectionStatus::kPartial) {
+      processed_section_[sections[i].id] = false;
+    }
+  }
+  return true;
+}
+
+Status FrameDecoder::Flush() {
+  bool has_blending = frame_header_.blending_info.mode != BlendMode::kReplace ||
+                      frame_header_.custom_size_or_origin;
+  for (const auto& blending_info_ec :
+       frame_header_.extra_channel_blending_info) {
+    if (blending_info_ec.mode != BlendMode::kReplace) has_blending = true;
+  }
+  // No early Flush() if blending is enabled.
+  if (has_blending && !is_finalized_) {
+    return false;
+  }
+  if (decoded_->IsJPEG()) {
+    // Nothing to do.
+    return true;
+  }
+  uint32_t completely_decoded_ac_pass = *std::min_element(
+      decoded_passes_per_ac_group_.begin(), decoded_passes_per_ac_group_.end());
+  if (completely_decoded_ac_pass < frame_header_.passes.num_passes) {
+    // We don't have all AC yet: force a draw of all the missing areas.
+    dec_state_->dc_upsampler.Init(
+        /*upsampling=*/8,
+        frame_header_.nonserialized_metadata->m.transform_data);
+    // Mark all sections as not complete.
+    for (size_t i = 0; i < decoded_passes_per_ac_group_.size(); i++) {
+      if (decoded_passes_per_ac_group_[i] == frame_header_.passes.num_passes)
+        continue;
+      dec_state_->group_border_assigner.ClearDone(i);
+    }
+    std::atomic<bool> has_error{false};
+    RunOnPool(
+        pool_, 0, decoded_passes_per_ac_group_.size(),
+        [this](size_t num_threads) {
+          PrepareStorage(num_threads, decoded_passes_per_ac_group_.size());
+          return true;
+        },
+        [this, &has_error](size_t g, size_t thread) {
+          if (decoded_passes_per_ac_group_[g] ==
+              frame_header_.passes.num_passes) {
+            // This group was drawn already, nothing to do.
+            return;
+          }
+          BitReader* JXL_RESTRICT readers[kMaxNumPasses] = {};
+          bool ok = ProcessACGroup(
+              g, readers, /*num_passes=*/0, GetStorageLocation(thread, g),
+              /*force_draw=*/true, /*dc_only=*/!decoded_ac_global_);
+          if (!ok) has_error = true;
+        },
+        "ForceDrawGroup");
+    if (has_error) {
+      return JXL_FAILURE("Drawing groups failed");
+    }
+  }
+  // TODO(veluca): the rest of this function should be removed once we have
+  // full support for per-group decoding.
+
+  // Undo global modular transforms and copy int pixel buffers to float ones.
+  JXL_RETURN_IF_ERROR(
+      modular_frame_decoder_.FinalizeDecoding(dec_state_, pool_, decoded_));
+
+  JXL_RETURN_IF_ERROR(FinalizeFrameDecoding(decoded_, dec_state_, pool_,
+                                            /*force_fir=*/false,
+                                            /*skip_blending=*/false));
+
+  num_renders_++;
+  return true;
+}
+
+Status FrameDecoder::FinalizeFrame() {
+  if (is_finalized_) {
+    return JXL_FAILURE("FinalizeFrame called multiple times");
+  }
+  is_finalized_ = true;
+  if (decoded_->IsJPEG()) {
+    // Nothing to do.
+    return true;
+  }
+  if (!finalized_dc_) {
+    // We don't have all of DC: EPF might not behave correctly (and is not
+    // particularly useful anyway on upsampling results), so we disable it.
+    dec_state_->shared_storage.frame_header.loop_filter.epf_iters = 0;
+  }
+  if ((!decoded_dc_global_ || !decoded_ac_global_ ||
+       *std::min_element(decoded_dc_groups_.begin(),
+                         decoded_dc_groups_.end()) != 1 ||
+       *std::min_element(decoded_passes_per_ac_group_.begin(),
+                         decoded_passes_per_ac_group_.end()) < max_passes_) &&
+      !allow_partial_frames_) {
+    return JXL_FAILURE(
+        "FinalizeFrame called before the frame was fully decoded");
+  }
+
+  JXL_RETURN_IF_ERROR(Flush());
+
+  if (dec_state_->shared->frame_header.CanBeReferenced()) {
+    size_t id = dec_state_->shared->frame_header.save_as_reference;
+    if (dec_state_->pre_color_transform_frame.xsize() == 0) {
+      dec_state_->shared_storage.reference_frames[id].storage =
+          decoded_->Copy();
+    } else {
+      dec_state_->shared_storage.reference_frames[id].storage =
+          ImageBundle(decoded_->metadata());
+      dec_state_->shared_storage.reference_frames[id].storage.SetFromImage(
+          std::move(dec_state_->pre_color_transform_frame),
+          decoded_->c_current());
+      if (decoded_->HasExtraChannels()) {
+        const std::vector<ImageF>* ecs = &dec_state_->pre_color_transform_ec;
+        if (ecs->empty()) ecs = &decoded_->extra_channels();
+        std::vector<ImageF> extra_channels;
+        for (const auto& ec : *ecs) {
+          extra_channels.push_back(CopyImage(ec));
+        }
+        dec_state_->shared_storage.reference_frames[id]
+            .storage.SetExtraChannels(std::move(extra_channels));
+      }
+    }
+    dec_state_->shared_storage.reference_frames[id].frame =
+        &dec_state_->shared_storage.reference_frames[id].storage;
+    dec_state_->shared_storage.reference_frames[id].ib_is_in_xyb =
+        dec_state_->shared->frame_header.save_before_color_transform;
+  }
+  if (dec_state_->shared->frame_header.dc_level != 0) {
+    dec_state_->shared_storage
+        .dc_frames[dec_state_->shared->frame_header.dc_level - 1] =
+        std::move(*decoded_->color());
+    decoded_->RemoveColor();
+  }
+  if (frame_header_.nonserialized_is_preview) {
+    // Fix possible larger image size (multiple of kBlockDim).
+    // TODO(lode): verify if and when that happens.
+    decoded_->ShrinkTo(frame_dim_.xsize, frame_dim_.ysize);
+  } else if (!decoded_->IsJPEG()) {
+    // A kRegularFrame is blended with the other frames, and thus results in a
+    // coalesced frame of size equal to image dimensions. Other frames are not
+    // blended, thus their final size is the size that was defined in the
+    // frame_header.
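+    // Illustrative helper (not library API): the output size chosen below.
+    // Regular and skip-progressive frames coalesce onto the full image
+    // canvas; all other frame types keep the (possibly upsampled) size from
+    // their own frame header.
+    const auto sketch_final_size = [&](FrameType type) {
+      const auto& md =
+          *dec_state_->shared->frame_header.nonserialized_metadata;
+      if (type == kRegularFrame || type == kSkipProgressive) {
+        return std::make_pair(md.xsize(), md.ysize());
+      }
+      return std::make_pair(frame_dim_.xsize_upsampled,
+                            frame_dim_.ysize_upsampled);
+    };
+    (void)sketch_final_size;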
+    if (frame_header_.frame_type == kRegularFrame ||
+        frame_header_.frame_type == kSkipProgressive) {
+      decoded_->ShrinkTo(
+          dec_state_->shared->frame_header.nonserialized_metadata->xsize(),
+          dec_state_->shared->frame_header.nonserialized_metadata->ysize());
+    } else {
+      // xsize_upsampled is the actual frame size, after any upsampling has
+      // been applied.
+      decoded_->ShrinkTo(frame_dim_.xsize_upsampled,
+                         frame_dim_.ysize_upsampled);
+    }
+  }
+
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_frame.h b/third_party/jpeg-xl/lib/jxl/dec_frame.h
new file mode 100644
index 000000000000..21aa7f62f733
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_frame.h
@@ -0,0 +1,228 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_FRAME_H_
+#define LIB_JXL_DEC_FRAME_H_
+
+#include <stdint.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/dec_modular.h"
+#include "lib/jxl/dec_params.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+
+// TODO(veluca): remove DecodeFrameHeader once the API migrates to
+// FrameDecoder.
+
+// `frame_header` must have nonserialized_metadata and
+// nonserialized_is_preview set.
+Status DecodeFrameHeader(BitReader* JXL_RESTRICT reader,
+                         FrameHeader* JXL_RESTRICT frame_header);
+
+// Decodes a frame. Groups may be processed in parallel by `pool`.
+// See DecodeFile for an explanation of the decoded color encoding.
+// `constraints`, if not null, limits the maximum image size. Also updates
+// `dec_state` with the new frame header.
+// `metadata` is the metadata that applies to all frames of the codestream.
+// `decoded->metadata` must already be set and must match metadata.m.
+Status DecodeFrame(const DecompressParams& dparams,
+                   PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool,
+                   BitReader* JXL_RESTRICT reader, ImageBundle* decoded,
+                   const CodecMetadata& metadata,
+                   const SizeConstraints* constraints,
+                   bool is_preview = false);
+
+// Leaves the reader in the same state as DecodeFrame would. Used to skip the
+// preview frame. Also updates `dec_state` with the new frame header.
+Status SkipFrame(const CodecMetadata& metadata, BitReader* JXL_RESTRICT reader,
+                 bool is_preview = false);
+
+// TODO(veluca): implement "forced drawing".
+class FrameDecoder {
+ public:
+  // All parameters must outlive the FrameDecoder.
+  FrameDecoder(PassesDecoderState* dec_state, const CodecMetadata& metadata,
+               ThreadPool* pool)
+      : dec_state_(dec_state), pool_(pool), frame_header_(&metadata) {}
+
+  // `constraints` must outlive the FrameDecoder if not null, or stay alive
+  // until the next call to SetFrameSizeLimits.
+  void SetFrameSizeLimits(const SizeConstraints* constraints) {
+    constraints_ = constraints;
+  }
+
+  // Reads the FrameHeader and table of contents from the given BitReader.
+  // Also checks the frame dimensions against their limits, and sets the
+  // output image buffer.
+  // TODO(veluca): remove the `allow_partial_frames` flag - this should be
+  // moved to callers.
+  Status InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded,
+                   bool is_preview, bool allow_partial_frames,
+                   bool allow_partial_dc_global);
+
+  struct SectionInfo {
+    BitReader* JXL_RESTRICT br;
+    size_t id;
+  };
+
+  enum SectionStatus {
+    // Processed correctly.
+    kDone = 0,
+    // Skipped because other required sections were not yet processed.
+    kSkipped = 1,
+    // Skipped because the section was already processed.
+    kDuplicate = 2,
+    // Only partially decoded: the section will need to be processed again.
+    kPartial = 3,
+  };
+
+  // Processes `num` sections; each SectionInfo contains the index
+  // of the section and a BitReader that only contains the data of the
+  // section. `section_status` should point to `num` elements, and will be
+  // filled with information about whether each section was processed or not.
+  // A section is a part of the encoded file that is indexed by the TOC.
+  Status ProcessSections(const SectionInfo* sections, size_t num,
+                         SectionStatus* section_status);
+
+  // Flushes all the data decoded so far to pixels.
+  Status Flush();
+
+  // Runs final operations once all of a frame's data has been decoded.
+  // Must be called exactly once per frame, after all calls to
+  // ProcessSections.
+  Status FinalizeFrame();
+
+  // Returns the byte offsets of the sections relative to the end of the TOC.
+  // The end of the TOC is the byte position of the bit reader after InitFrame
+  // was called.
+  const std::vector<uint64_t>& SectionOffsets() const {
+    return section_offsets_;
+  }
+  const std::vector<uint32_t>& SectionSizes() const { return section_sizes_; }
+  size_t NumSections() const { return section_sizes_.size(); }
+
+  // TODO(veluca): remove once we remove the --downsampling flag.
+  void SetMaxPasses(size_t max_passes) { max_passes_ = max_passes; }
+  const FrameHeader& GetFrameHeader() const { return frame_header_; }
+
+  // Returns whether a DC image has been decoded, accessible at low resolution
+  // at passes.shared_storage.dc_storage.
+  bool HasDecodedDC() const {
+    return frame_header_.encoding == FrameEncoding::kVarDCT && finalized_dc_;
+  }
+
+  // If the image has the default EXIF orientation, no blending is used, and
+  // the current frame cannot be referenced by future frames, sets the buffer
+  // to which uint8 sRGB pixels will be decoded.
+  // TODO(veluca): reduce this set of restrictions.
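+  // The restrictions boil down to: the fast path can only be taken when the
+  // decoded pixels can be written straight to `rgb_output` with no later
+  // reorientation, blending, or reuse as a reference frame. A condensed,
+  // illustrative view of the checks performed below:
+  //
+  //   bool can_use_rgb8 =
+  //       orientation == Orientation::kIdentity &&
+  //       blending_mode == BlendMode::kReplace &&
+  //       !custom_size_or_origin && !frame_header.CanBeReferenced();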
+  void MaybeSetRGB8OutputBuffer(uint8_t* rgb_output, size_t stride,
+                                bool is_rgba) const {
+    if (decoded_->metadata()->GetOrientation() != Orientation::kIdentity) {
+      return;
+    }
+    if (frame_header_.blending_info.mode != BlendMode::kReplace ||
+        frame_header_.custom_size_or_origin) {
+      return;
+    }
+    if (frame_header_.CanBeReferenced()) {
+      return;
+    }
+    dec_state_->rgb_output = rgb_output;
+    dec_state_->rgb_output_is_rgba = is_rgba;
+    dec_state_->rgb_stride = stride;
+#if !JXL_HIGH_PRECISION
+    if (!is_rgba && decoded_->metadata()->xyb_encoded &&
+        dec_state_->output_encoding_info.color_encoding.IsSRGB() &&
+        dec_state_->output_encoding_info.all_default_opsin &&
+        HasFastXYBTosRGB8() && frame_header_.needs_color_transform()) {
+      dec_state_->fast_xyb_srgb8_conversion = true;
+    }
+#endif
+  }
+
+  // Returns true if the rgb output buffer passed by MaybeSetRGB8OutputBuffer
+  // has been/will be populated by Flush() / FinalizeFrame().
+  bool HasRGBBuffer() const { return dec_state_->rgb_output != nullptr; }
+
+ private:
+  Status ProcessDCGlobal(BitReader* br);
+  Status ProcessDCGroup(size_t dc_group_id, BitReader* br);
+  void FinalizeDC();
+  void AllocateOutput();
+  Status ProcessACGlobal(BitReader* br);
+  Status ProcessACGroup(size_t ac_group_id, BitReader* JXL_RESTRICT* br,
+                        size_t num_passes, size_t thread, bool force_draw,
+                        bool dc_only);
+
+  // Allocates storage for parallel decoding using up to `num_threads` threads
+  // of up to `num_tasks` tasks. The value of `thread` passed to
+  // `GetStorageLocation` must be smaller than the `num_threads` value passed
+  // here. The value of `task` passed to `GetStorageLocation` must be smaller
+  // than the value of `num_tasks` passed here.
+  void PrepareStorage(size_t num_threads, size_t num_tasks) {
+    size_t storage_size = std::min(num_threads, num_tasks);
+    if (storage_size > group_dec_caches_.size()) {
+      group_dec_caches_.resize(storage_size);
+    }
+    dec_state_->EnsureStorage(storage_size);
+    use_task_id_ = num_threads > num_tasks;
+  }
+
+  size_t GetStorageLocation(size_t thread, size_t task) {
+    if (use_task_id_) return task;
+    return thread;
+  }
+
+  PassesDecoderState* dec_state_;
+  ThreadPool* pool_;
+  std::vector<uint64_t> section_offsets_;
+  std::vector<uint32_t> section_sizes_;
+  size_t max_passes_;
+  // TODO(veluca): figure out the duplication between these and dec_state_.
+  FrameHeader frame_header_;
+  FrameDimensions frame_dim_;
+  ImageBundle* decoded_;
+  ModularFrameDecoder modular_frame_decoder_;
+  bool allow_partial_frames_;
+  bool allow_partial_dc_global_;
+
+  std::vector<uint8_t> processed_section_;
+  std::vector<uint8_t> decoded_passes_per_ac_group_;
+  std::vector<uint8_t> decoded_dc_groups_;
+  bool decoded_dc_global_;
+  bool decoded_ac_global_;
+  bool finalized_dc_ = true;
+  bool is_finalized_ = true;
+  size_t num_renders_ = 0;
+
+  std::vector<GroupDecCache> group_dec_caches_;
+
+  // Frame size limits.
+  const SizeConstraints* constraints_ = nullptr;
+
+  // Whether or not the task id should be used for storage indexing, instead
+  // of the thread id.
+  bool use_task_id_ = false;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_FRAME_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_group.cc b/third_party/jpeg-xl/lib/jxl/dec_group.cc
new file mode 100644
index 000000000000..d17f1e3f2471
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_group.cc
@@ -0,0 +1,869 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_group.h"
+
+#include <stdint.h>
+#include <string.h>
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+
+#include "lib/jxl/frame_header.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_context.h"
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/dec_reconstruct.h"
+#include "lib/jxl/dec_transforms-inl.h"
+#include "lib/jxl/dec_xyb.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/epf.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/quant_weights.h"
+#include "lib/jxl/quantizer-inl.h"
+#include "lib/jxl/quantizer.h"
+
+#ifndef LIB_JXL_DEC_GROUP_CC
+#define LIB_JXL_DEC_GROUP_CC
+namespace jxl {
+
+// Interface for reading groups for DecodeGroupImpl.
+class GetBlock {
+ public:
+  virtual void StartRow(size_t by) = 0;
+  virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs,
+                           size_t size, size_t log2_covered_blocks,
+                           ACPtr block[3], ACType ac_type) = 0;
+  virtual ~GetBlock() {}
+};
+
+// Controls whether DecodeGroupImpl renders to pixels or not.
+enum DrawMode {
+  // Render to pixels.
+  kDraw = 0,
+  // Don't render to pixels.
+  kDontDraw = 1,
+  // Don't do IDCT or dequantization, but just postprocessing. Used for
+  // progressive DC.
+  kOnlyImageFeatures = 2,
+};
+
+}  // namespace jxl
+#endif  // LIB_JXL_DEC_GROUP_CC
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Rebind;
+using hwy::HWY_NAMESPACE::ShiftRight;
+
+using D = HWY_FULL(float);
+using DU = HWY_FULL(uint32_t);
+using DI = HWY_FULL(int32_t);
+using DI16 = Rebind<int16_t, DI>;
+constexpr D d;
+constexpr DI di;
+constexpr DI16 di16;
+
+// TODO(veluca): consider SIMDfying.
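+// Forward declaration plus an illustrative (unused) property check: the
+// scalar transpose below is an involution, which makes a convenient test
+// oracle for any future SIMD version.
+void Transpose8x8InPlace(int32_t* JXL_RESTRICT block);
+inline bool SketchTransposeTwiceIsIdentity(int32_t* JXL_RESTRICT block) {
+  int32_t copy[64];
+  memcpy(copy, block, sizeof(copy));
+  Transpose8x8InPlace(block);
+  Transpose8x8InPlace(block);
+  return memcmp(copy, block, sizeof(copy)) == 0;
+}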
+void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
+  for (size_t x = 0; x < 8; x++) {
+    for (size_t y = x + 1; y < 8; y++) {
+      std::swap(block[y * 8 + x], block[x * 8 + y]);
+    }
+  }
+}
+
+template <ACType ac_type>
+void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y,
+                 Vec<D> scaled_dequant_b,
+                 const float* JXL_RESTRICT dequant_matrices, size_t dq_ofs,
+                 size_t size, size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
+                 const float* JXL_RESTRICT biases, ACPtr qblock[3],
+                 float* JXL_RESTRICT block) {
+  const auto x_mul = Load(d, dequant_matrices + dq_ofs + k) * scaled_dequant_x;
+  const auto y_mul =
+      Load(d, dequant_matrices + dq_ofs + size + k) * scaled_dequant_y;
+  const auto b_mul =
+      Load(d, dequant_matrices + dq_ofs + 2 * size + k) * scaled_dequant_b;
+
+  Vec<DI> quantized_x_int;
+  Vec<DI> quantized_y_int;
+  Vec<DI> quantized_b_int;
+  if (ac_type == ACType::k16) {
+    Rebind<int16_t, DI> di16;
+    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
+    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
+    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
+  } else {
+    quantized_x_int = Load(di, qblock[0].ptr32 + k);
+    quantized_y_int = Load(di, qblock[1].ptr32 + k);
+    quantized_b_int = Load(di, qblock[2].ptr32 + k);
+  }
+
+  const auto dequant_x_cc =
+      AdjustQuantBias(di, 0, quantized_x_int, biases) * x_mul;
+  const auto dequant_y =
+      AdjustQuantBias(di, 1, quantized_y_int, biases) * y_mul;
+  const auto dequant_b_cc =
+      AdjustQuantBias(di, 2, quantized_b_int, biases) * b_mul;
+
+  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
+  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
+  Store(dequant_x, d, block + k);
+  Store(dequant_y, d, block + size + k);
+  Store(dequant_b, d, block + 2 * size + k);
+}
+
+template <ACType ac_type>
+void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant,
+                  float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul,
+                  Vec<D> b_cc_mul, size_t kind, size_t size,
+                  const Quantizer& quantizer,
+                  const float* JXL_RESTRICT dequant_matrices,
+                  size_t covered_blocks, const size_t* sbx,
+                  const float* JXL_RESTRICT* JXL_RESTRICT dc_row,
+                  size_t dc_stride, const float* JXL_RESTRICT biases,
+                  ACPtr qblock[3], float* JXL_RESTRICT block) {
+  PROFILER_FUNC;
+
+  const auto scaled_dequant_s = inv_global_scale / quant;
+
+  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
+  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
+  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
+
+  const size_t dq_ofs = quantizer.DequantMatrixOffset(kind, 0);
+
+  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
+    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
+                         dequant_matrices, dq_ofs, size, k, x_cc_mul, b_cc_mul,
+                         biases, qblock, block);
+  }
+  for (size_t c = 0; c < 3; c++) {
+    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
+                            block + c * size);
+  }
+}
+
+Status DecodeGroupImpl(GetBlock* JXL_RESTRICT get_block,
+                       GroupDecCache* JXL_RESTRICT group_dec_cache,
+                       PassesDecoderState* JXL_RESTRICT dec_state,
+                       size_t thread, size_t group_idx, ImageBundle* decoded,
+                       DrawMode draw) {
+  // TODO(veluca): investigate cache usage in this function.
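+  // Traversal sketch (illustrative, not library API): the AC strategy map
+  // marks the top-left 8x8 block of each varblock as its "first" block, so
+  // the column index advances by the varblock width in blocks:
+  //
+  //   for (size_t bx = 0; bx < xsize_blocks;) {
+  //     AcStrategy acs = acs_row[bx];
+  //     if (acs.IsFirstBlock()) { /* decode one varblock at bx */ }
+  //     bx += acs.covered_blocks_x();
+  //   }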
+  PROFILER_FUNC;
+  constexpr size_t kGroupDataXBorder = PassesDecoderState::kGroupDataXBorder;
+  constexpr size_t kGroupDataYBorder = PassesDecoderState::kGroupDataYBorder;
+
+  const Rect block_rect = dec_state->shared->BlockGroupRect(group_idx);
+  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
+
+  const size_t xsize_blocks = block_rect.xsize();
+  const size_t ysize_blocks = block_rect.ysize();
+
+  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
+
+  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
+  const float* JXL_RESTRICT dequant_matrices =
+      dec_state->shared->quantizer.DequantMatrix(0, 0);
+
+  const YCbCrChromaSubsampling& cs =
+      dec_state->shared->frame_header.chroma_subsampling;
+
+  const size_t idct_stride = dec_state->EagerFinalizeImageRect()
+                                 ? dec_state->group_data[thread].PixelsPerRow()
+                                 : dec_state->decoded.PixelsPerRow();
+
+  HWY_ALIGN int32_t scaled_qtable[64 * 3];
+
+  ACType ac_type = dec_state->coefficients->Type();
+  auto dequant_block = (ac_type == ACType::k16 ? DequantBlock<ACType::k16>
+                                               : DequantBlock<ACType::k32>);
+  // Whether or not coefficients should be stored for future usage, and/or
+  // read from past usage.
+  bool accumulate = !dec_state->coefficients->IsEmpty();
+  // Offset of the current block in the group.
+  size_t offset = 0;
+
+  std::array<int, 3> jpeg_c_map;
+  std::array<int, 3> dcoff = {};
+
+  // TODO(veluca): all of this should be done only once per image.
+  if (decoded->IsJPEG()) {
+    if (!dec_state->shared->cmap.IsJPEGCompatible()) {
+      return JXL_FAILURE("The CfL map is not JPEG-compatible");
+    }
+    jpeg_c_map = JpegOrder(dec_state->shared->frame_header.color_transform,
+                           decoded->jpeg_data->components.size() == 1);
+    const std::vector<QuantEncoding>& qe =
+        dec_state->shared->matrices.encodings();
+    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
+        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
+      return JXL_FAILURE(
+          "Quantization table is not a JPEG quantization table.");
+    }
+    for (size_t c = 0; c < 3; c++) {
+      if (dec_state->shared->frame_header.color_transform ==
+          ColorTransform::kNone) {
+        dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c];
+      }
+      for (size_t i = 0; i < 64; i++) {
+        // Transpose the matrix, as it will be used on the transposed block.
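+        // A tiny illustrative helper for the same row/column swap used in
+        // the index expression below (not library code):
+        const auto sketch_transposed_index = [](size_t i) {
+          return (i % 8) * 8 + (i / 8);  // e.g. 1 -> 8, 8 -> 1, 9 -> 9
+        };
+        (void)sketch_transposed_index;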
+        int n = qe[0].qraw.qtable->at(64 + i);
+        int d = qe[0].qraw.qtable->at(64 * c + i);
+        if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
+          return JXL_FAILURE("Invalid JPEG quantization table");
+        }
+        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
+            (1 << kCFLFixedPointPrecision) * n / d;
+      }
+    }
+  }
+
+  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
+  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
+  Rect r[3];
+  for (size_t i = 0; i < 3; i++) {
+    r[i] =
+        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
+             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
+  }
+
+  for (size_t by = 0; by < ysize_blocks; ++by) {
+    if (draw == kOnlyImageFeatures) break;
+    get_block->StartRow(by);
+    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
+
+    const int32_t* JXL_RESTRICT row_quant =
+        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
+
+    const float* JXL_RESTRICT dc_rows[3] = {
+        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
+        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
+        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
+    };
+
+    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
+    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
+
+    const int8_t* JXL_RESTRICT row_cmap[3] = {
+        dec_state->shared->cmap.ytox_map.ConstRow(ty),
+        nullptr,
+        dec_state->shared->cmap.ytob_map.ConstRow(ty),
+    };
+
+    float* JXL_RESTRICT idct_row[3];
+    int16_t* JXL_RESTRICT jpeg_row[3];
+    for (size_t c = 0; c < 3; c++) {
+      if (dec_state->EagerFinalizeImageRect()) {
+        idct_row[c] = dec_state->group_data[thread].PlaneRow(
+                          c, sby[c] * kBlockDim + kGroupDataYBorder) +
+                      kGroupDataXBorder;
+      } else {
+        idct_row[c] =
+            dec_state->decoded.PlaneRow(c, (r[c].y0() + sby[c]) * kBlockDim) +
+            r[c].x0() * kBlockDim;
+      }
+      if (decoded->IsJPEG()) {
+        auto& component = decoded->jpeg_data->components[jpeg_c_map[c]];
+        jpeg_row[c] =
+            component.coeffs.data() +
+            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
+                kDCTBlockSize;
+      }
+    }
+
+    size_t bx = 0;
+    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
+         tx++) {
+      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
+      auto x_cc_mul =
+          Set(d, dec_state->shared->cmap.YtoXRatio(row_cmap[0][abs_tx]));
+      auto b_cc_mul =
+          Set(d, dec_state->shared->cmap.YtoBRatio(row_cmap[2][abs_tx]));
+      // Increment bx by llf_x because those iterations would otherwise
+      // immediately continue (!IsFirstBlock). Reduces mispredictions.
+      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
+        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
+        AcStrategy acs = acs_row[bx];
+        const size_t llf_x = acs.covered_blocks_x();
+
+        // Can only happen in the second or lower rows of a varblock.
+        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
+          bx += llf_x;
+          continue;
+        }
+        PROFILER_ZONE("DecodeGroupImpl inner");
+        const size_t log2_covered_blocks = acs.log2_covered_blocks();
+
+        const size_t covered_blocks = 1 << log2_covered_blocks;
+        const size_t size = covered_blocks * kDCTBlockSize;
+
+        ACPtr qblock[3];
+        if (accumulate) {
+          for (size_t c = 0; c < 3; c++) {
+            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
+          }
+        } else {
+          // No point in reading from the bitstream if we are neither
+          // accumulating nor drawing.
+          JXL_ASSERT(draw == kDraw);
+          if (ac_type == ACType::k16) {
+            memset(group_dec_cache->dec_group_qblock16, 0,
+                   size * 3 * sizeof(int16_t));
+            for (size_t c = 0; c < 3; c++) {
+              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
+            }
+          } else {
+            memset(group_dec_cache->dec_group_qblock, 0,
+                   size * 3 * sizeof(int32_t));
+            for (size_t c = 0; c < 3; c++) {
+              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
+            }
+          }
+        }
+        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
+            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
+        offset += size;
+        if (draw == kDontDraw) {
+          bx += llf_x;
+          continue;
+        }
+
+        if (JXL_UNLIKELY(decoded->IsJPEG())) {
+          if (acs.Strategy() != AcStrategy::Type::DCT) {
+            return JXL_FAILURE(
+                "Can only decode to JPEG if only DCT-8 is used.");
+          }
+
+          HWY_ALIGN int32_t transposed_dct_y[64];
+          for (size_t c : {1, 0, 2}) {
+            if (decoded->jpeg_data->components.size() == 1 && c != 1) {
+              continue;
+            }
+            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
+              continue;
+            }
+            int16_t* JXL_RESTRICT jpeg_pos =
+                jpeg_row[c] + sbx[c] * kDCTBlockSize;
+            // JPEG XL is transposed, JPEG is not.
+            auto transposed_dct = qblock[c].ptr32;
+            Transpose8x8InPlace(transposed_dct);
+            // No CfL - no need to store the y block converted to integers.
+            if (!cs.Is444() ||
+                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
+              for (size_t i = 0; i < 64; i += Lanes(d)) {
+                const auto ini = Load(di, transposed_dct + i);
+                const auto ini16 = DemoteTo(di16, ini);
+                StoreU(ini16, di16, jpeg_pos + i);
+              }
+            } else if (c == 1) {
+              // Y channel: save for restoring X/B, but nothing else to do.
+              for (size_t i = 0; i < 64; i += Lanes(d)) {
+                const auto ini = Load(di, transposed_dct + i);
+                Store(ini, di, transposed_dct_y + i);
+                const auto ini16 = DemoteTo(di16, ini);
+                StoreU(ini16, di16, jpeg_pos + i);
+              }
+            } else {
+              // transposed_dct_y contains the y channel block, transposed.
+              const auto scale = Set(
+                  di, dec_state->shared->cmap.RatioJPEG(row_cmap[c][abs_tx]));
+              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
+              for (int i = 0; i < 64; i += Lanes(d)) {
+                auto in = Load(di, transposed_dct + i);
+                auto in_y = Load(di, transposed_dct_y + i);
+                auto qt = Load(di, scaled_qtable + c * size + i);
+                auto coeff_scale =
+                    ShiftRight<kCFLFixedPointPrecision>(qt * scale + round);
+                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
+                    in_y * coeff_scale + round);
+                StoreU(DemoteTo(di16, in + cfl_factor), di16, jpeg_pos + i);
+              }
+            }
+            jpeg_pos[0] =
+                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
+          }
+        } else {
+          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
+          // Dequantize and add predictions.
+          dequant_block(
+              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
+              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(),
+              size, dec_state->shared->quantizer, dequant_matrices,
+              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
+              dc_stride,
+              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
+              block);
+
+          for (size_t c : {1, 0, 2}) {
+            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
+              continue;
+            }
+            // IDCT
+            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
+            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
+                              idct_stride, group_dec_cache->scratch_space);
+          }
+        }
+        bx += llf_x;
+      }
+    }
+  }
+  if (draw == kDontDraw) {
+    return true;
+  }
+  // No ApplyImageFeatures in JPEG mode, or if using chroma subsampling; it
+  // will be done after decoding the whole image (this allows it to work on
+  // the chroma channels too).
+  if (dec_state->EagerFinalizeImageRect() && !decoded->IsJPEG()) {
+    // Copy the group borders to the border storage.
+    size_t xsize_groups = dec_state->shared->frame_dim.xsize_groups;
+    size_t xsize_padded = dec_state->shared->frame_dim.xsize_padded;
+    size_t ysize_padded = dec_state->shared->frame_dim.ysize_padded;
+    size_t gx = group_idx % xsize_groups;
+    size_t gy = group_idx / xsize_groups;
+    Image3F* group_data = &dec_state->group_data[thread];
+    size_t x0 = block_rect.x0() * kBlockDim;
+    size_t x1 = (block_rect.x0() + block_rect.xsize()) * kBlockDim;
+    size_t y0 = block_rect.y0() * kBlockDim;
+    size_t y1 = (block_rect.y0() + block_rect.ysize()) * kBlockDim;
+    size_t padding = dec_state->FinalizeRectPadding();
+    size_t borderx = dec_state->group_border_assigner.PaddingX(padding);
+    size_t bordery = padding;
+    size_t borderx_write = padding + borderx;
+    size_t bordery_write = padding + bordery;
+    CopyImageTo(
+        Rect(kGroupDataXBorder, kGroupDataYBorder, x1 - x0, bordery_write),
+        *group_data,
+        Rect(x0, (gy * 2) * bordery_write, x1 - x0, bordery_write),
+        &dec_state->borders_horizontal);
+    CopyImageTo(
+        Rect(kGroupDataXBorder, kGroupDataYBorder + y1 - y0 - bordery_write,
+             x1 - x0, bordery_write),
+        *group_data,
+        Rect(x0, (gy * 2 + 1) * bordery_write, x1 - x0, bordery_write),
+        &dec_state->borders_horizontal);
+    CopyImageTo(
+        Rect(kGroupDataXBorder, kGroupDataYBorder, borderx_write, y1 - y0),
+        *group_data,
+        Rect((gx * 2) * borderx_write, y0, borderx_write, y1 - y0),
+        &dec_state->borders_vertical);
+    CopyImageTo(
+        Rect(kGroupDataXBorder + x1 - x0 - borderx_write, kGroupDataYBorder,
+             borderx_write, y1 - y0),
+        *group_data,
+        Rect((gx * 2 + 1) * borderx_write, y0, borderx_write, y1 - y0),
+        &dec_state->borders_vertical);
+    Rect fir_rects[GroupBorderAssigner::kMaxToFinalize];
+    size_t num_fir_rects = 0;
+    dec_state->group_border_assigner.GroupDone(
+        group_idx, dec_state->FinalizeRectPadding(), fir_rects,
+        &num_fir_rects);
+    for (size_t i = 0; i < num_fir_rects; i++) {
+      const Rect& r = fir_rects[i];
+      // Limits of the area to copy from, in image coordinates.
+      JXL_DASSERT(r.x0() == 0 || r.x0() >= borderx);
+      size_t x0src = r.x0() == 0 ? r.x0() : r.x0() - borderx;
+      size_t x1src = r.x0() + r.xsize() +
+                     (r.x0() + r.xsize() == xsize_padded ? 0 : borderx);
+      JXL_DASSERT(r.y0() == 0 || r.y0() >= bordery);
+      size_t y0src = r.y0() == 0 ? r.y0() : r.y0() - bordery;
+      size_t y1src = r.y0() + r.ysize() +
+                     (r.y0() + r.ysize() == ysize_padded ? 0 : bordery);
+      // Copy other groups' borders from the border storage.
+    if (y0src < y0) {
+      CopyImageTo(Rect(x0src, (gy * 2 - 1) * bordery_write, x1src - x0src,
+                       bordery_write),
+                  dec_state->borders_horizontal,
+                  Rect(kGroupDataXBorder + x0src - x0,
+                       kGroupDataYBorder - bordery_write, x1src - x0src,
+                       bordery_write),
+                  group_data);
+    }
+    if (y1src > y1) {
+      CopyImageTo(
+          Rect(x0src, (gy * 2 + 2) * bordery_write, x1src - x0src,
+               bordery_write),
+          dec_state->borders_horizontal,
+          Rect(kGroupDataXBorder + x0src - x0, kGroupDataYBorder + y1 - y0,
+               x1src - x0src, bordery_write),
+          group_data);
+    }
+    if (x0src < x0) {
+      CopyImageTo(
+          Rect((gx * 2 - 1) * borderx_write, y0src, borderx_write,
+               y1src - y0src),
+          dec_state->borders_vertical,
+          Rect(kGroupDataXBorder - borderx_write,
+               kGroupDataYBorder + y0src - y0, borderx_write, y1src - y0src),
+          group_data);
+    }
+    if (x1src > x1) {
+      CopyImageTo(
+          Rect((gx * 2 + 2) * borderx_write, y0src, borderx_write,
+               y1src - y0src),
+          dec_state->borders_vertical,
+          Rect(kGroupDataXBorder + x1 - x0, kGroupDataYBorder + y0src - y0,
+               borderx_write, y1src - y0src),
+          group_data);
+    }
+    Rect group_data_rect(kGroupDataXBorder + r.x0() - x0,
+                         kGroupDataYBorder + r.y0() - y0, r.xsize(),
+                         r.ysize());
+    JXL_RETURN_IF_ERROR(FinalizeImageRect(group_data, group_data_rect, {},
+                                          dec_state, thread, decoded, r));
+    }
+  }
+  return true;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+namespace {
+// Decode quantized AC coefficients of DCT blocks.
+// LLF components in the output block will not be modified.
+template <ACType ac_type>
+Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks,
+                        int32_t* JXL_RESTRICT row_nzeros,
+                        const int32_t* JXL_RESTRICT row_nzeros_top,
+                        size_t nzeros_stride, size_t c, size_t bx, size_t by,
+                        size_t lbx, AcStrategy acs,
+                        const coeff_order_t* JXL_RESTRICT coeff_order,
+                        BitReader* JXL_RESTRICT br,
+                        ANSSymbolReader* JXL_RESTRICT decoder,
+                        const std::vector<uint8_t>& context_map,
+                        const uint8_t* qdc_row, const int32_t* qf_row,
+                        const BlockCtxMap& block_ctx_map, ACPtr block,
+                        size_t shift = 0) {
+  PROFILER_FUNC;
+  // Equal to number of LLF coefficients.
+  const size_t covered_blocks = 1 << log2_covered_blocks;
+  const size_t size = covered_blocks * kDCTBlockSize;
+  int32_t predicted_nzeros =
+      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
+
+  size_t ord = kStrategyOrder[acs.RawStrategy()];
+  const coeff_order_t* JXL_RESTRICT order =
+      &coeff_order[CoeffOrderOffset(ord, c)];
+
+  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
+  const int32_t nzero_ctx =
+      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
+
+  size_t nzeros = decoder->ReadHybridUint(nzero_ctx, br, context_map);
+  if (nzeros + covered_blocks > size) {
+    return JXL_FAILURE("Invalid AC: nzeros too large");
+  }
+  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
+    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
+      row_nzeros[bx + x + y * nzeros_stride] =
+          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
+    }
+  }
+
+  const size_t histo_offset =
+      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
+
+  // Skip LLF
+  {
+    PROFILER_ZONE("AcDecSkipLLF, reader");
+    size_t prev = (nzeros > size / 16 ? 0 : 1);
+    for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
+      const size_t ctx =
+          histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
+                                            log2_covered_blocks, prev);
+      const size_t u_coeff = decoder->ReadHybridUint(ctx, br, context_map);
+      // Hand-rolled version of UnpackSigned, shifting before the conversion
+      // to signed integer to avoid undefined behavior of shifting negative
+      // numbers.
+      const size_t magnitude = u_coeff >> 1;
+      const size_t neg_sign = (~u_coeff) & 1;
+      const intptr_t coeff =
+          static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
+      if (ac_type == ACType::k16) {
+        block.ptr16[order[k]] += coeff;
+      } else {
+        block.ptr32[order[k]] += coeff;
+      }
+      prev = static_cast<size_t>(u_coeff != 0);
+      nzeros -= prev;
+    }
+    if (JXL_UNLIKELY(nzeros != 0)) {
+      return JXL_FAILURE(
+          "Invalid AC: nzeros not 0. Block (%zu, %zu), channel %zu", bx, by,
+          c);
+    }
+  }
+  return true;
+}
+
+// Structs used by DecodeGroupImpl to get a quantized block.
+// GetBlockFromBitstream uses ANS decoding (and thus keeps track of row
+// pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient
+// image provided by the encoder.
+
+struct GetBlockFromBitstream : public GetBlock {
+  void StartRow(size_t by) override {
+    qf_row = rect.ConstRow(*qf, by);
+    for (size_t c = 0; c < 3; c++) {
+      size_t sby = by >> vshift[c];
+      quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0();
+      for (size_t i = 0; i < num_passes; i++) {
+        row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby);
+        row_nzeros_top[i][c] =
+            sby == 0
+                ? nullptr
+                : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1);
+      }
+    }
+  }
+
+  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
+                   size_t log2_covered_blocks, ACPtr block[3],
+                   ACType ac_type) override {
+    auto decode_ac_varblock = ac_type == ACType::k16
+                                  ? DecodeACVarBlock<ACType::k16>
+                                  : DecodeACVarBlock<ACType::k32>;
+    for (size_t c : {1, 0, 2}) {
+      size_t sbx = bx >> hshift[c];
+      size_t sby = by >> vshift[c];
+      if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) {
+        continue;
+      }
+
+      for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) {
+        JXL_RETURN_IF_ERROR(decode_ac_varblock(
+            ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c],
+            row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs,
+            &coeff_orders[pass * coeff_order_size], readers[pass],
+            &decoders[pass], context_map[pass], quant_dc_row, qf_row,
+            *block_ctx_map, block[c], shift_for_pass[pass]));
+      }
+    }
+    return true;
+  }
+
+  Status Init(BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes,
+              size_t group_idx, size_t histo_selector_bits, const Rect& rect,
+              GroupDecCache* JXL_RESTRICT group_dec_cache,
+              PassesDecoderState* dec_state, size_t first_pass) {
+    for (size_t i = 0; i < 3; i++) {
+      hshift[i] = dec_state->shared->frame_header.chroma_subsampling.HShift(i);
+      vshift[i] = dec_state->shared->frame_header.chroma_subsampling.VShift(i);
+    }
+    this->coeff_order_size = dec_state->shared->coeff_order_size;
+    this->coeff_orders =
+        dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size;
+    this->context_map = dec_state->context_map.data() + first_pass;
+    this->readers = readers;
+    this->num_passes = num_passes;
+    this->shift_for_pass =
+        dec_state->shared->frame_header.passes.shift + first_pass;
+    this->group_dec_cache = group_dec_cache;
+    this->rect = rect;
+    block_ctx_map = &dec_state->shared->block_ctx_map;
+    qf = &dec_state->shared->raw_quant_field;
+    quant_dc = &dec_state->shared->quant_dc;
+
+    for (size_t pass = 0; pass < num_passes; pass++) {
+      // Select which histogram set to use among those of the current pass.
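+      // For example, with num_histograms == 4 the selector below is a 2-bit
+      // value in [0, 3] (histo_selector_bits == CeilLog2Nonzero(4) == 2);
+      // with a single histogram set no selector bits are read at all.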
+      size_t cur_histogram = 0;
+      if (histo_selector_bits != 0) {
+        cur_histogram = readers[pass]->ReadBits(histo_selector_bits);
+      }
+      if (cur_histogram >= dec_state->shared->num_histograms) {
+        return JXL_FAILURE("Invalid histogram selector");
+      }
+      ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts();
+
+      decoders[pass] =
+          ANSSymbolReader(&dec_state->code[pass + first_pass], readers[pass]);
+    }
+    nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow();
+    for (size_t i = 0; i < num_passes; i++) {
+      JXL_ASSERT(
+          nzeros_stride ==
+          static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow()));
+    }
+    return true;
+  }
+
+  const uint32_t* shift_for_pass = nullptr;  // not owned
+  const coeff_order_t* JXL_RESTRICT coeff_orders;
+  size_t coeff_order_size;
+  const std::vector<uint8_t>* JXL_RESTRICT context_map;
+  ANSSymbolReader decoders[kMaxNumPasses];
+  BitReader* JXL_RESTRICT* JXL_RESTRICT readers;
+  size_t num_passes;
+  size_t ctx_offset[kMaxNumPasses];
+  size_t nzeros_stride;
+  int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3];
+  const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3];
+  GroupDecCache* JXL_RESTRICT group_dec_cache;
+  const BlockCtxMap* block_ctx_map;
+  const ImageI* qf;
+  const ImageB* quant_dc;
+  const int32_t* qf_row;
+  const uint8_t* quant_dc_row;
+  Rect rect;
+  size_t hshift[3], vshift[3];
+};
+
+struct GetBlockFromEncoder : public GetBlock {
+  void StartRow(size_t by) override {}
+
+  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
+                   size_t log2_covered_blocks, ACPtr block[3],
+                   ACType ac_type) override {
+    JXL_DASSERT(ac_type == ACType::k32);
+    for (size_t c = 0; c < 3; c++) {
+      // for each pass
+      for (size_t i = 0; i < quantized_ac->size(); i++) {
+        for (size_t k = 0; k < size; k++) {
+          // TODO(veluca): SIMD.
+          block[c].ptr32[k] += rows[i][c][offset + k];
+        }
+      }
+    }
+    offset += size;
+    return true;
+  }
+
+  GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac,
+                      size_t group_idx)
+      : quantized_ac(&ac) {
+    // TODO(veluca): not supported with chroma subsampling.
+    for (size_t i = 0; i < quantized_ac->size(); i++) {
+      JXL_CHECK((*quantized_ac)[i]->Type() == ACType::k32);
+      for (size_t c = 0; c < 3; c++) {
+        rows[i][c] = (*quantized_ac)[i]->PlaneRow(c, group_idx, 0).ptr32;
+      }
+    }
+  }
+
+  const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac;
+  size_t offset = 0;
+  const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3];
+};
+
+HWY_EXPORT(DecodeGroupImpl);
+
+}  // namespace
+
+Status DecodeGroup(BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
+                   size_t num_passes, size_t group_idx,
+                   PassesDecoderState* JXL_RESTRICT dec_state,
+                   GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
+                   ImageBundle* JXL_RESTRICT decoded, size_t first_pass,
+                   bool force_draw, bool dc_only) {
+  PROFILER_FUNC;
+
+  DrawMode draw = (num_passes + first_pass ==
+                   dec_state->shared->frame_header.passes.num_passes) ||
+                          force_draw
+                      ? kDraw
+                      : kDontDraw;
+
+  if (draw == kDraw && num_passes == 0 && first_pass == 0) {
+    // We reuse filter_input_storage here as it is not currently in use.
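+    // What follows upsamples the DC (one sample per 8x8 block) by 8x in each
+    // direction to give a preview of this group: dst_rect is src_rect scaled
+    // by kBlockDim.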
+    const Rect src_rect = dec_state->shared->BlockGroupRect(group_idx);
+    const Rect copy_rect(kBlockDim, 2, src_rect.xsize(), src_rect.ysize());
+    CopyImageToWithPadding(src_rect, *dec_state->shared->dc, 2, copy_rect,
+                           &dec_state->filter_input_storage[thread]);
+    EnsurePaddingInPlace(&dec_state->filter_input_storage[thread], copy_rect,
+                         src_rect, dec_state->shared->frame_dim.xsize_blocks,
+                         dec_state->shared->frame_dim.ysize_blocks, 2, 2);
+    Image3F* upsampling_dst = &dec_state->decoded;
+    Rect dst_rect(src_rect.x0() * 8, src_rect.y0() * 8, src_rect.xsize() * 8,
+                  src_rect.ysize() * 8);
+    if (dec_state->EagerFinalizeImageRect()) {
+      upsampling_dst = &dec_state->group_data[thread];
+      dst_rect = Rect(PassesDecoderState::kGroupDataXBorder,
+                      PassesDecoderState::kGroupDataYBorder, dst_rect.xsize(),
+                      dst_rect.ysize());
+    }
+    dec_state->dc_upsampler.UpsampleRect(
+        dec_state->filter_input_storage[thread], copy_rect, upsampling_dst,
+        dst_rect,
+        static_cast<ssize_t>(src_rect.y0()) -
+            static_cast<ssize_t>(copy_rect.y0()),
+        dec_state->shared->frame_dim.ysize_blocks);
+    draw = kOnlyImageFeatures;
+  }
+
+  size_t histo_selector_bits = 0;
+  if (dc_only) {
+    JXL_ASSERT(num_passes == 0);
+  } else {
+    JXL_ASSERT(dec_state->shared->num_histograms > 0);
+    histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms);
+  }
+
+  GetBlockFromBitstream get_block;
+  JXL_RETURN_IF_ERROR(
+      get_block.Init(readers, num_passes, group_idx, histo_selector_bits,
+                     dec_state->shared->BlockGroupRect(group_idx),
+                     group_dec_cache, dec_state, first_pass));
+
+  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
+      &get_block, group_dec_cache, dec_state, thread, group_idx, decoded,
+      draw));
+
+  for (size_t pass = 0; pass < num_passes; pass++) {
+    if (!get_block.decoders[pass].CheckANSFinalState()) {
+      return JXL_FAILURE("ANS checksum failure.");
+    }
+  }
+  return true;
+}
+
+Status DecodeGroupForRoundtrip(const std::vector<std::unique_ptr<ACImage>>& ac,
+                               size_t group_idx,
+                               PassesDecoderState* JXL_RESTRICT dec_state,
+                               GroupDecCache* JXL_RESTRICT group_dec_cache,
+                               size_t thread, ImageBundle* JXL_RESTRICT decoded,
+                               AuxOut* aux_out) {
+  PROFILER_FUNC;
+
+  GetBlockFromEncoder get_block(ac, group_idx);
+  group_dec_cache->InitOnce(
+      /*num_passes=*/0,
+      /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1);
+
+  return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(&get_block, group_dec_cache,
+                                               dec_state, thread, group_idx,
+                                               decoded, kDraw);
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/dec_group.h b/third_party/jpeg-xl/lib/jxl/dec_group.h
new file mode 100644
index 000000000000..519107883b86
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_group.h
@@ -0,0 +1,56 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_GROUP_H_
+#define LIB_JXL_DEC_GROUP_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/dct_util.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/dec_params.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/quantizer.h"
+
+namespace jxl {
+
+Status DecodeGroup(BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
+                   size_t num_passes, size_t group_idx,
+                   PassesDecoderState* JXL_RESTRICT dec_state,
+                   GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
+                   ImageBundle* JXL_RESTRICT decoded, size_t first_pass,
+                   bool force_draw, bool dc_only);
+
+Status DecodeGroupForRoundtrip(const std::vector<std::unique_ptr<ACImage>>& ac,
+                               size_t group_idx,
+                               PassesDecoderState* JXL_RESTRICT dec_state,
+                               GroupDecCache* JXL_RESTRICT group_dec_cache,
+                               size_t thread, ImageBundle* JXL_RESTRICT decoded,
+                               AuxOut* aux_out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_GROUP_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_group_border.cc b/third_party/jpeg-xl/lib/jxl/dec_group_border.cc
new file mode 100644
index 000000000000..6b587b58cb28
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_group_border.cc
@@ -0,0 +1,192 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_group_border.h"
+
+#include <atomic>
+
+namespace jxl {
+
+void GroupBorderAssigner::Init(const FrameDimensions& frame_dim) {
+  frame_dim_ = frame_dim;
+  size_t num_corners =
+      (frame_dim_.xsize_groups + 1) * (frame_dim_.ysize_groups + 1);
+  counters_.reset(new std::atomic<uint8_t>[num_corners]);
+  // Initialize counters.
+  for (size_t y = 0; y < frame_dim_.ysize_groups + 1; y++) {
+    for (size_t x = 0; x < frame_dim_.xsize_groups + 1; x++) {
+      // Counters at image borders don't have anything on the other side, so
+      // we pre-fill their value to have more uniform handling afterwards.
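+      // E.g. the top-left corner of the image starts out with kTopLeft,
+      // kTopRight and kBottomLeft already set, so the single adjacent group
+      // only needs to report kBottomRight to make that corner complete.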
+ uint8_t init_value = 0; + if (x == 0) { + init_value |= kTopLeft | kBottomLeft; + } + if (x == frame_dim_.xsize_groups) { + init_value |= kTopRight | kBottomRight; + } + if (y == 0) { + init_value |= kTopLeft | kTopRight; + } + if (y == frame_dim_.ysize_groups) { + init_value |= kBottomLeft | kBottomRight; + } + counters_[y * (frame_dim_.xsize_groups + 1) + x] = init_value; + } + } +} + +void GroupBorderAssigner::ClearDone(size_t group_id) { + size_t x = group_id % frame_dim_.xsize_groups; + size_t y = group_id / frame_dim_.xsize_groups; + size_t top_left_idx = y * (frame_dim_.xsize_groups + 1) + x; + size_t top_right_idx = y * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_right_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_left_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x; + counters_[top_left_idx].fetch_and(~kBottomRight); + counters_[top_right_idx].fetch_and(~kBottomLeft); + counters_[bottom_left_idx].fetch_and(~kTopRight); + counters_[bottom_right_idx].fetch_and(~kTopLeft); +} + +// Looking at each corner between groups, we can guarantee that the four +// involved groups will agree between each other regarding the order in which +// each of the four groups terminated. Thus, the last of the four groups +// gets the responsibility of handling the corner. For borders, every border +// is assigned to its top corner (for vertical borders) or to its left corner +// (for horizontal borders): the order as seen on those corners will decide who +// handles that border. + +void GroupBorderAssigner::GroupDone(size_t group_id, size_t padding, + Rect* rects_to_finalize, + size_t* num_to_finalize) { + size_t x = group_id % frame_dim_.xsize_groups; + size_t y = group_id / frame_dim_.xsize_groups; + Rect block_rect(x * frame_dim_.group_dim / kBlockDim, + y * frame_dim_.group_dim / kBlockDim, + frame_dim_.group_dim / kBlockDim, + frame_dim_.group_dim / kBlockDim, frame_dim_.xsize_blocks, + frame_dim_.ysize_blocks); + + size_t top_left_idx = y * (frame_dim_.xsize_groups + 1) + x; + size_t top_right_idx = y * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_right_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_left_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x; + + auto fetch_status = [this](size_t idx, uint8_t bit) { + // Note that the acq-rel semantics of this fetch are actually needed to + // ensure that the pixel data of the group is already written to memory. + size_t status = counters_[idx].fetch_or(bit); + JXL_DASSERT((bit & status) == 0); + return bit | status; + }; + + size_t top_left_status = fetch_status(top_left_idx, kBottomRight); + size_t top_right_status = fetch_status(top_right_idx, kBottomLeft); + size_t bottom_right_status = fetch_status(bottom_right_idx, kTopLeft); + size_t bottom_left_status = fetch_status(bottom_left_idx, kTopRight); + + size_t padx = PaddingX(padding); + size_t pady = padding; + + size_t x1 = block_rect.x0() + block_rect.xsize(); + size_t y1 = block_rect.y0() + block_rect.ysize(); + + bool is_last_group_x = frame_dim_.xsize_groups == x + 1; + bool is_last_group_y = frame_dim_.ysize_groups == y + 1; + + // Start of border of neighbouring group, end of border of this group, start + // of border of this group (on the other side), end of border of next group. + size_t xpos[4] = { + block_rect.x0() == 0 ? 0 : block_rect.x0() * kBlockDim - padx, + block_rect.x0() == 0 ? 0 : block_rect.x0() * kBlockDim + padx, + is_last_group_x ? frame_dim_.xsize_padded : x1 * kBlockDim - padx, + is_last_group_x ? 
frame_dim_.xsize_padded : x1 * kBlockDim + padx}; + size_t ypos[4] = { + block_rect.y0() == 0 ? 0 : block_rect.y0() * kBlockDim - pady, + block_rect.y0() == 0 ? 0 : block_rect.y0() * kBlockDim + pady, + is_last_group_y ? frame_dim_.ysize_padded : y1 * kBlockDim - pady, + is_last_group_y ? frame_dim_.ysize_padded : y1 * kBlockDim + pady}; + + *num_to_finalize = 0; + auto append_rect = [&](size_t x0, size_t x1, size_t y0, size_t y1) { + Rect rect(xpos[x0], ypos[y0], xpos[x1] - xpos[x0], ypos[y1] - ypos[y0]); + if (rect.xsize() == 0 || rect.ysize() == 0) return; + JXL_DASSERT(*num_to_finalize < kMaxToFinalize); + rects_to_finalize[(*num_to_finalize)++] = rect; + }; + + // Because of how group borders are assigned, it is impossible that we need to + // process the left and right side of some area but not the center area. Thus, + // we compute the first/last part to process in every horizontal strip and + // merge them together. We first collect a mask of what parts should be + // processed. + // We do this horizontally rather than vertically because horizontal borders + // are larger. + bool available_parts_mask[3][3] = {}; // [x][y] + // Center + available_parts_mask[1][1] = true; + // Corners + if (top_left_status == 0xF) available_parts_mask[0][0] = true; + if (top_right_status == 0xF) available_parts_mask[2][0] = true; + if (bottom_right_status == 0xF) available_parts_mask[2][2] = true; + if (bottom_left_status == 0xF) available_parts_mask[0][2] = true; + // Other borders + if (top_left_status & kTopRight) available_parts_mask[1][0] = true; + if (top_left_status & kBottomLeft) available_parts_mask[0][1] = true; + if (top_right_status & kBottomRight) available_parts_mask[2][1] = true; + if (bottom_left_status & kBottomRight) available_parts_mask[1][2] = true; + + // Collect horizontal ranges. 
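+  // horizontal_segments[y] will hold the half-open [first, second) range of
+  // available parts in strip y; strips with identical ranges are then merged
+  // into taller rects by the if-chain below.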
+  constexpr size_t kNoSegment = 3;
+  std::pair<size_t, size_t> horizontal_segments[3] = {{kNoSegment, kNoSegment},
+                                                      {kNoSegment, kNoSegment},
+                                                      {kNoSegment, kNoSegment}};
+  for (size_t y = 0; y < 3; y++) {
+    for (size_t x = 0; x < 3; x++) {
+      if (!available_parts_mask[x][y]) continue;
+      JXL_DASSERT(horizontal_segments[y].second == kNoSegment ||
+                  horizontal_segments[y].second == x);
+      JXL_DASSERT((horizontal_segments[y].first == kNoSegment) ==
+                  (horizontal_segments[y].second == kNoSegment));
+      if (horizontal_segments[y].first == kNoSegment) {
+        horizontal_segments[y].first = x;
+      }
+      horizontal_segments[y].second = x + 1;
+    }
+  }
+  if (horizontal_segments[0] == horizontal_segments[1] &&
+      horizontal_segments[0] == horizontal_segments[2]) {
+    append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0,
+                3);
+  } else if (horizontal_segments[0] == horizontal_segments[1]) {
+    append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0,
+                2);
+    append_rect(horizontal_segments[2].first, horizontal_segments[2].second, 2,
+                3);
+  } else if (horizontal_segments[1] == horizontal_segments[2]) {
+    append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0,
+                1);
+    append_rect(horizontal_segments[1].first, horizontal_segments[1].second, 1,
+                3);
+  } else {
+    append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0,
+                1);
+    append_rect(horizontal_segments[1].first, horizontal_segments[1].second, 1,
+                2);
+    append_rect(horizontal_segments[2].first, horizontal_segments[2].second, 2,
+                3);
+  }
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_group_border.h b/third_party/jpeg-xl/lib/jxl/dec_group_border.h
new file mode 100644
index 000000000000..946b9f6adb30
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_group_border.h
@@ -0,0 +1,67 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_GROUP_BORDER_H_
+#define LIB_JXL_DEC_GROUP_BORDER_H_
+
+#include <stddef.h>
+
+#include <atomic>
+
+#include "lib/jxl/base/arch_macros.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+class GroupBorderAssigner {
+ public:
+  // Prepare the GroupBorderAssigner to handle a given frame.
+  void Init(const FrameDimensions& frame_dim);
+  // Marks a group as done, and returns the (at most 3) rects to run
+  // FinalizeImageRect on. `block_rect` must be the rect corresponding
+  // to the given `group_id`, measured in blocks.
+  void GroupDone(size_t group_id, size_t padding, Rect* rects_to_finalize,
+                 size_t* num_to_finalize);
+  // Marks a group as not-done, for running re-paints.
+  void ClearDone(size_t group_id);
+
+  static constexpr size_t kMaxToFinalize = 3;
+
+  // Vectors on ARM NEON are never wider than 4 floats, so rounding to
+  // multiples of 4 is enough.
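+  // As a worked example (illustrative only): PaddingX(2) below evaluates to
+  // RoundUpTo(2, 4) == 4 on NEON, but to RoundUpTo(2, kBlockDim) == 8
+  // elsewhere.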
+#if defined(__ARM_NEON) || defined(__ARM_NEON__)
+  static constexpr size_t kPaddingXRound = 4;
+#else
+  static constexpr size_t kPaddingXRound = kBlockDim;
+#endif
+
+  // Returns the necessary amount of padding for the X axis.
+  size_t PaddingX(size_t padding) { return RoundUpTo(padding, kPaddingXRound); }
+
+ private:
+  FrameDimensions frame_dim_;
+  std::unique_ptr<std::atomic<uint8_t>[]> counters_;
+
+  // Constants to identify group positions relative to the corners.
+  static constexpr uint8_t kTopLeft = 0x01;
+  static constexpr uint8_t kTopRight = 0x02;
+  static constexpr uint8_t kBottomRight = 0x04;
+  static constexpr uint8_t kBottomLeft = 0x08;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_GROUP_BORDER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_huffman.cc b/third_party/jpeg-xl/lib/jxl/dec_huffman.cc
new file mode 100644
index 000000000000..f934c3ec0aba
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_huffman.cc
@@ -0,0 +1,264 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_huffman.h"
+
+#include <string.h> /* for memset */
+
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/huffman_table.h"
+
+namespace jxl {
+
+static const int kCodeLengthCodes = 18;
+static const uint8_t kCodeLengthCodeOrder[kCodeLengthCodes] = {
+    1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+};
+static const uint8_t kDefaultCodeLength = 8;
+static const uint8_t kCodeLengthRepeatCode = 16;
+
+int ReadHuffmanCodeLengths(const uint8_t* code_length_code_lengths,
+                           int num_symbols, uint8_t* code_lengths,
+                           BitReader* br) {
+  int symbol = 0;
+  uint8_t prev_code_len = kDefaultCodeLength;
+  int repeat = 0;
+  uint8_t repeat_code_len = 0;
+  int space = 32768;
+  HuffmanCode table[32];
+
+  uint16_t counts[16] = {0};
+  for (int i = 0; i < kCodeLengthCodes; ++i) {
+    ++counts[code_length_code_lengths[i]];
+  }
+  if (!BuildHuffmanTable(table, 5, code_length_code_lengths, kCodeLengthCodes,
+                         &counts[0])) {
+    return 0;
+  }
+
+  while (symbol < num_symbols && space > 0) {
+    const HuffmanCode* p = table;
+    uint8_t code_len;
+    br->Refill();
+    p += br->PeekFixedBits<5>();
+    br->Consume(p->bits);
+    code_len = (uint8_t)p->value;
+    if (code_len < kCodeLengthRepeatCode) {
+      repeat = 0;
+      code_lengths[symbol++] = code_len;
+      if (code_len != 0) {
+        prev_code_len = code_len;
+        space -= 32768u >> code_len;
+      }
+    } else {
+      const int extra_bits = code_len - 14;
+      int old_repeat;
+      int repeat_delta;
+      uint8_t new_len = 0;
+      if (code_len == kCodeLengthRepeatCode) {
+        new_len = prev_code_len;
+      }
+      if (repeat_code_len != new_len) {
+        repeat = 0;
+        repeat_code_len = new_len;
+      }
+      old_repeat = repeat;
+      if (repeat > 0) {
+        repeat -= 2;
+        repeat <<= extra_bits;
+      }
+      repeat += (int)br->ReadBits(extra_bits) + 3;
+      repeat_delta = repeat - old_repeat;
+      if (symbol + repeat_delta > num_symbols) {
+        return 0;
+      }
+      memset(&code_lengths[symbol], repeat_code_len, (size_t)repeat_delta);
+      symbol += repeat_delta;
+      if (repeat_code_len != 0) {
+        space -= repeat_delta << (15 - repeat_code_len);
+      }
+    }
+  }
+  if (space != 0) {
+    return 0;
+  }
+  memset(&code_lengths[symbol], 0, (size_t)(num_symbols - symbol));
+  return true;
+}
+
+static JXL_INLINE bool ReadSimpleCode(size_t alphabet_size, BitReader* br,
+                                      HuffmanCode* table) {
+  size_t max_bits =
+      (alphabet_size > 1u) ? FloorLog2Nonzero(alphabet_size - 1u) + 1 : 0;
+
+  size_t num_symbols = br->ReadFixedBits<2>() + 1;
+
+  uint16_t symbols[4] = {0};
+  for (size_t i = 0; i < num_symbols; ++i) {
+    uint16_t symbol = br->ReadBits(max_bits);
+    if (symbol >= alphabet_size) {
+      return false;
+    }
+    symbols[i] = symbol;
+  }
+
+  for (size_t i = 0; i < num_symbols - 1; ++i) {
+    for (size_t j = i + 1; j < num_symbols; ++j) {
+      if (symbols[i] == symbols[j]) return false;
+    }
+  }
+
+  // With 4 symbols, one extra bit selects between the two possible
+  // code-length assignments (see cases 4 and 5 below).
+  if (num_symbols == 4) num_symbols += br->ReadFixedBits<1>();
+
+  const auto swap_symbols = [&symbols](size_t i, size_t j) {
+    uint16_t t = symbols[j];
+    symbols[j] = symbols[i];
+    symbols[i] = t;
+  };
+
+  size_t table_size = 1;
+  switch (num_symbols) {
+    case 1:
+      table[0] = {0, symbols[0]};
+      break;
+    case 2:
+      if (symbols[0] > symbols[1]) swap_symbols(0, 1);
+      table[0] = {1, symbols[0]};
+      table[1] = {1, symbols[1]};
+      table_size = 2;
+      break;
+    case 3:
+      if (symbols[1] > symbols[2]) swap_symbols(1, 2);
+      table[0] = {1, symbols[0]};
+      table[2] = {1, symbols[0]};
+      table[1] = {2, symbols[1]};
+      table[3] = {2, symbols[2]};
+      table_size = 4;
+      break;
+    case 4: {
+      for (size_t i = 0; i < 3; ++i) {
+        for (size_t j = i + 1; j < 4; ++j) {
+          if (symbols[i] > symbols[j]) swap_symbols(i, j);
+        }
+      }
+      table[0] = {2, symbols[0]};
+      table[2] = {2, symbols[1]};
+      table[1] = {2, symbols[2]};
+      table[3] = {2, symbols[3]};
+      table_size = 4;
+      break;
+    }
+    case 5: {
+      if (symbols[2] > symbols[3]) swap_symbols(2, 3);
+      table[0] = {1, symbols[0]};
+      table[1] = {2, symbols[1]};
+      table[2] = {1, symbols[0]};
+      table[3] = {3, symbols[2]};
+      table[4] = {1, symbols[0]};
+      table[5] = {2, symbols[1]};
+      table[6] = {1, symbols[0]};
+      table[7] = {3, symbols[3]};
+      table_size = 8;
+      break;
+    }
+    default: {
+      // Unreachable.
+      return false;
+    }
+  }
+
+  const uint32_t goal_size = 1u << kHuffmanTableBits;
+  while (table_size != goal_size) {
+    memcpy(&table[table_size], &table[0],
+           (size_t)table_size * sizeof(table[0]));
+    table_size <<= 1;
+  }
+
+  return true;
+}
+
+bool HuffmanDecodingData::ReadFromBitStream(size_t alphabet_size,
+                                            BitReader* br) {
+  if (alphabet_size > (1 << PREFIX_MAX_BITS)) return false;
+
+  /* simple_code_or_skip is used as follows:
+     1 for simple code;
+     0 for no skipping, 2 skips 2 code lengths, 3 skips 3 code lengths */
+  uint32_t simple_code_or_skip = br->ReadFixedBits<2>();
+  if (simple_code_or_skip == 1u) {
+    table_.resize(1u << kHuffmanTableBits);
+    return ReadSimpleCode(alphabet_size, br, table_.data());
+  }
+
+  std::vector<uint8_t> code_lengths(alphabet_size, 0);
+  uint8_t code_length_code_lengths[kCodeLengthCodes] = {0};
+  int space = 32;
+  int num_codes = 0;
+  /* Static Huffman code for the code length code lengths */
+  static const HuffmanCode huff[16] = {
+      {2, 0}, {2, 4}, {2, 3}, {3, 2}, {2, 0}, {2, 4}, {2, 3}, {4, 1},
+      {2, 0}, {2, 4}, {2, 3}, {3, 2}, {2, 0}, {2, 4}, {2, 3}, {4, 5},
+  };
+  for (size_t i = simple_code_or_skip; i < kCodeLengthCodes && space > 0; ++i) {
+    const int code_len_idx = kCodeLengthCodeOrder[i];
+    const HuffmanCode* p = huff;
+    uint8_t v;
+    br->Refill();
+    p += br->PeekFixedBits<4>();
+    br->Consume(p->bits);
+    v = (uint8_t)p->value;
+    code_length_code_lengths[code_len_idx] = v;
+    if (v != 0) {
+      space -= (32u >> v);
+      ++num_codes;
+    }
+  }
+  bool ok = (num_codes == 1 || space == 0) &&
+            ReadHuffmanCodeLengths(code_length_code_lengths, alphabet_size,
+                                   &code_lengths[0], br);
+
+  if (!ok) return false;
+  uint16_t counts[16] = {0};
+  for (size_t i = 0; i < alphabet_size; ++i) {
+    ++counts[code_lengths[i]];
+  }
+  table_.resize(alphabet_size + 376);
+  uint32_t table_size =
+      BuildHuffmanTable(table_.data(), kHuffmanTableBits, &code_lengths[0],
+                        alphabet_size, &counts[0]);
+  table_.resize(table_size);
+  return (table_size > 0);
+}
+
+// Decodes the next Huffman coded symbol from the bit-stream.
+uint16_t HuffmanDecodingData::ReadSymbol(BitReader* br) const {
+  size_t n_bits;
+  const HuffmanCode* table = table_.data();
+  table += br->PeekBits(kHuffmanTableBits);
+  n_bits = table->bits;
+  if (n_bits > kHuffmanTableBits) {
+    br->Consume(kHuffmanTableBits);
+    n_bits -= kHuffmanTableBits;
+    table += table->value;
+    table += br->PeekBits(n_bits);
+  }
+  br->Consume(table->bits);
+  return table->value;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_huffman.h b/third_party/jpeg-xl/lib/jxl/dec_huffman.h
new file mode 100644
index 000000000000..333c39172a7c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_huffman.h
@@ -0,0 +1,41 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_HUFFMAN_H_
+#define LIB_JXL_DEC_HUFFMAN_H_
+
+#include <memory>
+#include <vector>
+
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/huffman_table.h"
+
+namespace jxl {
+
+static constexpr size_t kHuffmanTableBits = 8u;
+
+struct HuffmanDecodingData {
+  // Decodes the Huffman code lengths from the bit-stream and fills in the
+  // pre-allocated table with the corresponding 2-level Huffman decoding table.
+  // Returns false if the Huffman code lengths cannot be decoded.
+  bool ReadFromBitStream(size_t alphabet_size, BitReader* br);
+
+  uint16_t ReadSymbol(BitReader* br) const;
+
+  std::vector<HuffmanCode> table_;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_HUFFMAN_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_modular.cc b/third_party/jpeg-xl/lib/jxl/dec_modular.cc
new file mode 100644
index 000000000000..73a7686dd236
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_modular.cc
@@ -0,0 +1,590 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_modular.h"
+
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/frame_header.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/compressed_dc.h"
+#include "lib/jxl/epf.h"
+#include "lib/jxl/modular/encoding/encoding.h"
+#include "lib/jxl/modular/modular_image.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Rebind;
+
+void MultiplySum(const size_t xsize,
+                 const pixel_type* const JXL_RESTRICT row_in,
+                 const pixel_type* const JXL_RESTRICT row_in_Y,
+                 const float factor, float* const JXL_RESTRICT row_out) {
+  const HWY_FULL(float) df;
+  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
+  const auto factor_v = Set(df, factor);
+  for (size_t x = 0; x < xsize; x += Lanes(di)) {
+    const auto in = Load(di, row_in + x) + Load(di, row_in_Y + x);
+    const auto out = ConvertTo(df, in) * factor_v;
+    Store(out, df, row_out + x);
+  }
+}
+
+void RgbFromSingle(const size_t xsize,
+                   const pixel_type* const JXL_RESTRICT row_in,
+                   const float factor, Image3F* decoded, size_t /*c*/,
+                   size_t y) {
+  const HWY_FULL(float) df;
+  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
+
+  float* const JXL_RESTRICT row_out_r = decoded->PlaneRow(0, y);
+  float* const JXL_RESTRICT row_out_g = decoded->PlaneRow(1, y);
+  float* const JXL_RESTRICT row_out_b = decoded->PlaneRow(2, y);
+
+  const auto factor_v = Set(df, factor);
+  for (size_t x = 0; x < xsize; x += Lanes(di)) {
+    const auto in = Load(di, row_in + x);
+    const auto out = ConvertTo(df, in) * factor_v;
+    Store(out, df, row_out_r + x);
+    Store(out, df, row_out_g + x);
+    Store(out, df, row_out_b + x);
+  }
+}
+
+// Same signature as RgbFromSingle so we can assign to the same pointer.
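+// (FinalizeDecoding below picks between RgbFromSingle and SingleFromSingle
+// per channel at runtime, depending on whether a grayscale image is being
+// expanded to RGB.)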
+void SingleFromSingle(const size_t xsize,
+                      const pixel_type* const JXL_RESTRICT row_in,
+                      const float factor, Image3F* decoded, size_t c,
+                      size_t y) {
+  const HWY_FULL(float) df;
+  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
+
+  float* const JXL_RESTRICT row_out = decoded->PlaneRow(c, y);
+
+  const auto factor_v = Set(df, factor);
+  for (size_t x = 0; x < xsize; x += Lanes(di)) {
+    const auto in = Load(di, row_in + x);
+    const auto out = ConvertTo(df, in) * factor_v;
+    Store(out, df, row_out + x);
+  }
+}
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(MultiplySum);       // Local function
+HWY_EXPORT(RgbFromSingle);     // Local function
+HWY_EXPORT(SingleFromSingle);  // Local function
+
+// convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as
+// int back to binary32 float
+void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
+                  float* const JXL_RESTRICT row_out, const size_t xsize,
+                  const int bits, const int exp_bits) {
+  if (bits == 32) {
+    JXL_ASSERT(sizeof(pixel_type) == sizeof(float));
+    JXL_ASSERT(exp_bits == 8);
+    memcpy(row_out, row_in, xsize * sizeof(float));
+    return;
+  }
+  int exp_bias = (1 << (exp_bits - 1)) - 1;
+  int sign_shift = bits - 1;
+  int mant_bits = bits - exp_bits - 1;
+  int mant_shift = 23 - mant_bits;
+  for (size_t x = 0; x < xsize; ++x) {
+    uint32_t f;
+    memcpy(&f, &row_in[x], 4);
+    int signbit = (f >> sign_shift);
+    f &= (1 << sign_shift) - 1;
+    if (f == 0) {
+      row_out[x] = (signbit ? -0.f : 0.f);
+      continue;
+    }
+    int exp = (f >> mant_bits);
+    int mantissa = (f & ((1 << mant_bits) - 1));
+    mantissa <<= mant_shift;
+    // Try to normalize only if there is space for maneuver.
+    if (exp == 0 && exp_bits < 8) {
+      // subnormal number
+      while ((mantissa & 0x800000) == 0) {
+        mantissa <<= 1;
+        exp--;
+      }
+      exp++;
+      // remove leading 1 because it is implicit now
+      mantissa &= 0x7fffff;
+    }
+    exp -= exp_bias;
+    // broke up the arbitrary float into its parts, now reassemble into
+    // binary32
+    exp += 127;
+    JXL_ASSERT(exp >= 0);
+    f = (signbit ? 0x80000000 : 0);
+    f |= (exp << 23);
+    f |= mantissa;
+    memcpy(&row_out[x], &f, 4);
+  }
+}
+
+Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
+                                             const FrameHeader& frame_header,
+                                             bool allow_truncated_group) {
+  bool decode_color = frame_header.encoding == FrameEncoding::kModular;
+  const auto& metadata = frame_header.nonserialized_metadata->m;
+  bool is_gray = metadata.color_encoding.IsGray();
+  size_t nb_chans = 3;
+  if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
+    nb_chans = 1;
+  }
+  bool has_tree = reader->ReadBits(1);
+  if (has_tree) {
+    size_t tree_size_limit =
+        1024 + frame_dim.xsize * frame_dim.ysize * nb_chans;
+    JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit));
+    JXL_RETURN_IF_ERROR(
+        DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map));
+  }
+  do_color = decode_color;
+  if (!do_color) nb_chans = 0;
+  size_t nb_extra = metadata.extra_channel_info.size();
+
+  bool fp = metadata.bit_depth.floating_point_sample;
+
+  // bits_per_sample is just metadata for XYB images.
+  if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
+      frame_header.color_transform != ColorTransform::kXYB) {
+    if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
+      // TODO(lode): does modular support uint32_t? maxval is signed int so
+      // cannot represent 32 bits.
+      return JXL_FAILURE("uint32_t not supported in dec_modular");
+    } else if (metadata.bit_depth.bits_per_sample > 32) {
+      return JXL_FAILURE("bits_per_sample > 32 not supported");
+    }
+  }
+  // TODO(lode): must handle metadata.floating_point_channel?
+  int maxval =
+      (fp ? 1
+          : (1u << static_cast<uint32_t>(metadata.bit_depth.bits_per_sample)) -
+                1);
+
+  Image gi(frame_dim.xsize, frame_dim.ysize, maxval, nb_chans + nb_extra);
+
+  for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) {
+    const ExtraChannelInfo& eci = metadata.extra_channel_info[ec];
+    gi.channel[c].resize(eci.Size(frame_dim.xsize), eci.Size(frame_dim.ysize));
+    gi.channel[c].hshift = gi.channel[c].vshift = eci.dim_shift;
+  }
+
+  ModularOptions options;
+  options.max_chan_size = frame_dim.group_dim;
+  Status dec_status = ModularGenericDecompress(
+      reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
+      &options,
+      /*undo_transforms=*/-2, &tree, &code, &context_map,
+      allow_truncated_group);
+  if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
+  if (dec_status.IsFatalError()) {
+    return JXL_FAILURE("Failed to decode global modular info");
+  }
+
+  // TODO(eustas): are we sure this can be done after partial decode?
+  // ensure all the channel buffers are allocated
+  have_something = false;
+  for (size_t c = 0; c < gi.channel.size(); c++) {
+    Channel& gic = gi.channel[c];
+    if (c >= gi.nb_meta_channels && gic.w < frame_dim.group_dim &&
+        gic.h < frame_dim.group_dim)
+      have_something = true;
+    gic.resize();
+  }
+  full_image = std::move(gi);
+  return dec_status;
+}
+
+Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader,
+                                        int minShift, int maxShift,
+                                        const ModularStreamId& stream,
+                                        bool zerofill) {
+  JXL_DASSERT(stream.kind == ModularStreamId::kModularDC ||
+              stream.kind == ModularStreamId::kModularAC);
+  const size_t xsize = rect.xsize();
+  const size_t ysize = rect.ysize();
+  int maxval = full_image.maxval;
+  Image gi(xsize, ysize, maxval, 0);
+  // start at the first bigger-than-groupsize non-metachannel
+  size_t c = full_image.nb_meta_channels;
+  for (; c < full_image.channel.size(); c++) {
+    Channel& fc = full_image.channel[c];
+    if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
+  }
+  size_t beginc = c;
+  for (; c < full_image.channel.size(); c++) {
+    Channel& fc = full_image.channel[c];
+    int shift = std::min(fc.hshift, fc.vshift);
+    if (shift > maxShift) continue;
+    if (shift < minShift) continue;
+    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
+           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
+    if (r.xsize() == 0 || r.ysize() == 0) continue;
+    Channel gc(r.xsize(), r.ysize());
+    gc.hshift = fc.hshift;
+    gc.vshift = fc.vshift;
+    gi.channel.emplace_back(std::move(gc));
+  }
+  gi.nb_channels = gi.channel.size();
+  gi.real_nb_channels = gi.nb_channels;
+  if (zerofill) {
+    int gic = 0;
+    for (c = beginc; c < full_image.channel.size(); c++) {
+      Channel& fc = full_image.channel[c];
+      int shift = std::min(fc.hshift, fc.vshift);
+      if (shift > maxShift) continue;
+      if (shift < minShift) continue;
+      Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
+             rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
+      if (r.xsize() == 0 || r.ysize() == 0) continue;
+      for (size_t y = 0; y < r.ysize(); ++y) {
+        pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
+        memset(row_out, 0, r.xsize() * sizeof(*row_out));
+      }
+      gic++;
+    }
+    return true;
+  }
+  ModularOptions options;
+  if (!ModularGenericDecompress(
+          reader, gi,
/*header=*/nullptr, stream.ID(frame_dim), &options, + /*undo_transforms=*/-1, &tree, &code, &context_map)) + return JXL_FAILURE("Failed to decode modular group"); + int gic = 0; + for (c = beginc; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + for (size_t y = 0; y < r.ysize(); ++y) { + pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y); + const pixel_type* const JXL_RESTRICT row_in = gi.channel[gic].Row(y); + for (size_t x = 0; x < r.xsize(); ++x) { + row_out[x] = row_in[x]; + } + } + gic++; + } + return true; +} +Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state) { + const Rect r = dec_state->shared->DCGroupRect(group_id); + // TODO(eustas): investigate if we could reduce the impact of + // EvalRationalPolynomial; generally speaking, the limit is + // 2**(128/(3*magic)), where 128 comes from IEEE 754 exponent, + // 3 comes from XybToRgb that cubes the values, and "magic" is + // the sum of all other contributions. 2**18 is known to lead + // to NaN on input found by fuzzing (see commit message). + constexpr const int kRawDcLimit = 1 << 17; + Image image(r.xsize(), r.ysize(), kRawDcLimit, 3); + image.minval = -kRawDcLimit; + size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim); + reader->Refill(); + size_t extra_precision = reader->ReadFixedBits<2>(); + float mul = 1.0f / (1 << extra_precision); + ModularOptions options; + for (size_t c = 0; c < 3; c++) { + Channel& ch = image.channel[c < 2 ? 
c ^ 1 : c]; + ch.w >>= dec_state->shared->frame_header.chroma_subsampling.HShift(c); + ch.h >>= dec_state->shared->frame_header.chroma_subsampling.VShift(c); + ch.resize(); + } + if (!ModularGenericDecompress( + reader, image, /*header=*/nullptr, stream_id, &options, + /*undo_transforms=*/0, &tree, &code, &context_map)) { + return JXL_FAILURE("Failed to decode modular DC group"); + } + DequantDC(r, &dec_state->shared_storage.dc_storage, + &dec_state->shared_storage.quant_dc, image, + dec_state->shared->quantizer.MulDC(), mul, + dec_state->shared->cmap.DCFactors(), + dec_state->shared->frame_header.chroma_subsampling, + dec_state->shared->block_ctx_map); + return true; +} + +Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state) { + const Rect r = dec_state->shared->DCGroupRect(group_id); + size_t upper_bound = r.xsize() * r.ysize(); + reader->Refill(); + size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1; + size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim); + // YToX, YToB, ACS + QF, EPF + Image image(r.xsize(), r.ysize(), 255, 4); + static_assert(kColorTileDimInBlocks == 8, "Color tile size changed"); + Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3); + image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[2] = Channel(count, 2, 0, 0); + ModularOptions options; + if (!ModularGenericDecompress( + reader, image, /*header=*/nullptr, stream_id, &options, + /*undo_transforms=*/-1, &tree, &code, &context_map)) { + return JXL_FAILURE("Failed to decode AC metadata"); + } + ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr, + &dec_state->shared_storage.cmap.ytox_map); + ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr, + &dec_state->shared_storage.cmap.ytob_map); + size_t num = 0; + bool is444 = dec_state->shared->frame_header.chroma_subsampling.Is444(); + auto& ac_strategy = dec_state->shared_storage.ac_strategy; + size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize()); + size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize()); + uint32_t local_used_acs = 0; + for (size_t iy = 0; iy < r.ysize(); iy++) { + size_t y = r.y0() + iy; + int* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy); + uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy); + int* row_in_1 = image.channel[2].plane.Row(0); + int* row_in_2 = image.channel[2].plane.Row(1); + int* row_in_3 = image.channel[3].plane.Row(iy); + for (size_t ix = 0; ix < r.xsize(); ix++) { + size_t x = r.x0() + ix; + int sharpness = row_in_3[ix]; + if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) { + return JXL_FAILURE("Corrupted sharpness field"); + } + row_epf[ix] = sharpness; + if (ac_strategy.IsValid(x, y)) { + continue; + } + + if (num >= count) return JXL_FAILURE("Corrupted stream"); + + if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) { + return JXL_FAILURE("Invalid AC strategy"); + } + local_used_acs |= 1u << row_in_1[num]; + AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]); + if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) && + !is444) { + return JXL_FAILURE( + "AC strategy not compatible with chroma subsampling"); + } + // Ensure that blocks do not overflow *AC* groups. 
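+      // E.g. with kGroupDimInBlocks == 32 (256x256-pixel groups), a 2x2-block
+      // strategy whose top-left block sits at x == 31 would spill into the
+      // next group and is rejected below.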
+ size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks; + size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks; + size_t next_x_dct_block = x + acs.covered_blocks_x(); + size_t next_y_dct_block = y + acs.covered_blocks_y(); + if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) { + return JXL_FAILURE("Invalid AC strategy, x overflow"); + } + if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) { + return JXL_FAILURE("Invalid AC strategy, y overflow"); + } + JXL_RETURN_IF_ERROR( + ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num]))); + row_qf[ix] = + 1 + std::max(0, std::min(Quantizer::kQuantMax - 1, row_in_2[num])); + num++; + } + } + dec_state->used_acs |= local_used_acs; + if (dec_state->shared->frame_header.loop_filter.epf_iters > 0) { + ComputeSigma(r, dec_state); + } + return true; +} + +Status ModularFrameDecoder::FinalizeDecoding(PassesDecoderState* dec_state, + jxl::ThreadPool* pool, + ImageBundle* output) { + Image& gi = full_image; + size_t xsize = gi.w; + size_t ysize = gi.h; + + const auto& frame_header = dec_state->shared->frame_header; + const auto* metadata = frame_header.nonserialized_metadata; + + // Don't use threads if total image size is smaller than a group + if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr; + + // Undo the global transforms + gi.undo_transforms(global_header.wp_header, -1, pool); + if (gi.error) return JXL_FAILURE("Undoing transforms failed"); + + auto& decoded = dec_state->decoded; + + int c = 0; + if (do_color) { + const bool rgb_from_gray = + metadata->m.color_encoding.IsGray() && + frame_header.color_transform == ColorTransform::kNone; + const bool fp = metadata->m.bit_depth.floating_point_sample; + + for (; c < 3; c++) { + float factor = 1.f / (float)full_image.maxval; + int c_in = c; + if (frame_header.color_transform == ColorTransform::kXYB) { + factor = dec_state->shared->matrices.DCQuants()[c]; + // XYB is encoded as YX(B-Y) + if (c < 2) c_in = 1 - c; + } else if (rgb_from_gray) { + c_in = 0; + } + // TODO(eustas): could we detect it on earlier stage? 
+ if (gi.channel[c_in].w == 0 || gi.channel[c_in].h == 0) { + return JXL_FAILURE("Empty image"); + } + if (frame_header.color_transform == ColorTransform::kXYB && c == 2) { + JXL_ASSERT(!fp); + RunOnPool( + pool, 0, ysize, jxl::ThreadPool::SkipInit(), + [&](const int task, const int thread) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + gi.channel[c_in].Row(y); + const pixel_type* const JXL_RESTRICT row_in_Y = + gi.channel[0].Row(y); + float* const JXL_RESTRICT row_out = decoded.PlaneRow(c, y); + HWY_DYNAMIC_DISPATCH(MultiplySum) + (xsize, row_in, row_in_Y, factor, row_out); + }, + "ModularIntToFloat"); + } else if (fp) { + int bits = metadata->m.bit_depth.bits_per_sample; + int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample; + RunOnPool( + pool, 0, ysize, jxl::ThreadPool::SkipInit(), + [&](const int task, const int thread) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + gi.channel[c_in].Row(y); + float* const JXL_RESTRICT row_out = decoded.PlaneRow(c, y); + int_to_float(row_in, row_out, xsize, bits, exp_bits); + }, + "ModularIntToFloat_losslessfloat"); + } else { + RunOnPool( + pool, 0, ysize, jxl::ThreadPool::SkipInit(), + [&](const int task, const int thread) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + gi.channel[c_in].Row(y); + if (rgb_from_gray) { + HWY_DYNAMIC_DISPATCH(RgbFromSingle) + (xsize, row_in, factor, &decoded, c, y); + } else { + HWY_DYNAMIC_DISPATCH(SingleFromSingle) + (xsize, row_in, factor, &decoded, c, y); + } + }, + "ModularIntToFloat"); + } + if (rgb_from_gray) { + break; + } + } + if (rgb_from_gray) { + c = 1; + } + } + for (size_t ec = 0; ec < dec_state->extra_channels.size(); ec++, c++) { + const ExtraChannelInfo& eci = output->metadata()->extra_channel_info[ec]; + int bits = eci.bit_depth.bits_per_sample; + int exp_bits = eci.bit_depth.exponent_bits_per_sample; + bool fp = eci.bit_depth.floating_point_sample; + JXL_ASSERT(fp || bits < 32); + const float mul = fp ? 0 : (1.0f / ((1u << bits) - 1)); + const size_t ec_xsize = eci.Size(xsize); // includes shift + const size_t ec_ysize = eci.Size(ysize); + for (size_t y = 0; y < ec_ysize; ++y) { + float* const JXL_RESTRICT row_out = dec_state->extra_channels[ec].Row(y); + const pixel_type* const JXL_RESTRICT row_in = gi.channel[c].Row(y); + if (fp) { + int_to_float(row_in, row_out, ec_xsize, bits, exp_bits); + } else { + for (size_t x = 0; x < ec_xsize; ++x) { + row_out[x] = row_in[x] * mul; + } + } + } + } + return true; +} + +static constexpr const float kAlmostZero = 1e-8f; + +Status ModularFrameDecoder::DecodeQuantTable( + size_t required_size_x, size_t required_size_y, BitReader* br, + QuantEncoding* encoding, size_t idx, + ModularFrameDecoder* modular_frame_decoder) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den)); + if (encoding->qraw.qtable_den < kAlmostZero) { + // qtable[] values are already checked for <= 0 so the denominator may not + // be negative. 
+ return JXL_FAILURE("Invalid qtable_den: value too small"); + } + Image image(required_size_x, required_size_y, 255, 3); + ModularOptions options; + if (modular_frame_decoder) { + JXL_RETURN_IF_ERROR(ModularGenericDecompress( + br, image, /*header=*/nullptr, + ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim), + &options, /*undo_transforms=*/-1, &modular_frame_decoder->tree, + &modular_frame_decoder->code, &modular_frame_decoder->context_map)); + } else { + JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr, + 0, &options, + /*undo_transforms=*/-1)); + } + if (!encoding->qraw.qtable) { + encoding->qraw.qtable = new std::vector(); + } + encoding->qraw.qtable->resize(required_size_x * required_size_y * 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < required_size_y; y++) { + int* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < required_size_x; x++) { + (*encoding->qraw.qtable)[c * required_size_x * required_size_y + + y * required_size_x + x] = row[x]; + if (row[x] <= 0) { + return JXL_FAILURE("Invalid raw quantization table"); + } + } + } + } + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_modular.h b/third_party/jpeg-xl/lib/jxl/dec_modular.h new file mode 100644 index 000000000000..0cf06449651c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_modular.h @@ -0,0 +1,134 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_DEC_MODULAR_H_ +#define LIB_JXL_DEC_MODULAR_H_ + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +struct ModularStreamId { + enum Kind { + kGlobalData, + kVarDCTDC, + kModularDC, + kACMetadata, + kQuantTable, + kModularAC + }; + Kind kind; + size_t quant_table_id; + size_t group_id; // DC or AC group id. + size_t pass_id; // Only for kModularAC. 
+ size_t ID(const FrameDimensions& frame_dim) const { + size_t id = 0; + switch (kind) { + case kGlobalData: + id = 0; + break; + case kVarDCTDC: + id = 1 + group_id; + break; + case kModularDC: + id = 1 + frame_dim.num_dc_groups + group_id; + break; + case kACMetadata: + id = 1 + 2 * frame_dim.num_dc_groups + group_id; + break; + case kQuantTable: + id = 1 + 3 * frame_dim.num_dc_groups + quant_table_id; + break; + case kModularAC: + id = 1 + 3 * frame_dim.num_dc_groups + DequantMatrices::kNum + + frame_dim.num_groups * pass_id + group_id; + break; + }; + return id; + } + static ModularStreamId Global() { + return ModularStreamId{kGlobalData, 0, 0, 0}; + } + static ModularStreamId VarDCTDC(size_t group_id) { + return ModularStreamId{kVarDCTDC, 0, group_id, 0}; + } + static ModularStreamId ModularDC(size_t group_id) { + return ModularStreamId{kModularDC, 0, group_id, 0}; + } + static ModularStreamId ACMetadata(size_t group_id) { + return ModularStreamId{kACMetadata, 0, group_id, 0}; + } + static ModularStreamId QuantTable(size_t quant_table_id) { + JXL_ASSERT(quant_table_id < DequantMatrices::kNum); + return ModularStreamId{kQuantTable, quant_table_id, 0, 0}; + } + static ModularStreamId ModularAC(size_t group_id, size_t pass_id) { + return ModularStreamId{kModularAC, 0, group_id, pass_id}; + } + static size_t Num(const FrameDimensions& frame_dim, size_t passes) { + return ModularAC(0, passes).ID(frame_dim); + } +}; + +class ModularFrameDecoder { + public: + void Init(const FrameDimensions& frame_dim) { this->frame_dim = frame_dim; } + Status DecodeGlobalInfo(BitReader* reader, const FrameHeader& frame_header, + bool allow_truncated_group = false); + Status DecodeGroup(const Rect& rect, BitReader* reader, int minShift, + int maxShift, const ModularStreamId& stream, + bool zerofill); + // Decodes a VarDCT DC group (`group_id`) from the given `reader`. + Status DecodeVarDCTDC(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state); + // Decodes a VarDCT AC Metadata group (`group_id`) from the given `reader`. + Status DecodeAcMetadata(size_t group_id, BitReader* reader, + PassesDecoderState* dec_state); + // Decodes a RAW quant table from `br` into the given `encoding`, of size + // `required_size_x x required_size_y`. If `modular_frame_decoder` is passed, + // its global tree is used, otherwise no global tree is used. + static Status DecodeQuantTable(size_t required_size_x, size_t required_size_y, + BitReader* br, QuantEncoding* encoding, + size_t idx, + ModularFrameDecoder* modular_frame_decoder); + Status FinalizeDecoding(PassesDecoderState* dec_state, jxl::ThreadPool* pool, + ImageBundle* output); + bool have_dc() const { return have_something; } + + private: + Image full_image; + FrameDimensions frame_dim; + bool do_color; + bool have_something; + Tree tree; + ANSCode code; + std::vector context_map; + GroupHeader global_header; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_MODULAR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_noise.cc b/third_party/jpeg-xl/lib/jxl/dec_noise.cc new file mode 100644 index 000000000000..f010055fd6e3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_noise.cc @@ -0,0 +1,245 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_noise.h"
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dec_noise.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/xorshift128plus-inl.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::ShiftRight;
+using hwy::HWY_NAMESPACE::Vec;
+
+using D = HWY_CAPPED(float, 1);
+
+// Converts one vector's worth of random bits to floats in [1, 2).
+// NOTE: as the convolution kernel sums to 0, it doesn't matter if inputs are
+// in [0, 1) or in [1, 2).
+void BitsToFloat(const uint32_t* JXL_RESTRICT random_bits,
+                 float* JXL_RESTRICT floats) {
+  const HWY_FULL(float) df;
+  const HWY_FULL(uint32_t) du;
+
+  const auto bits = Load(du, random_bits);
+  // 1.0 + 23 random mantissa bits = [1, 2)
+  const auto rand12 = BitCast(df, ShiftRight<9>(bits) | Set(du, 0x3F800000));
+  Store(rand12, df, floats);
+}
+
+void RandomImage(Xorshift128Plus* rng, const Rect& rect,
+                 ImageF* JXL_RESTRICT noise) {
+  const size_t xsize = rect.xsize();
+  const size_t ysize = rect.ysize();
+
+  // May exceed the vector size, hence we have two loops over x below.
+  constexpr size_t kFloatsPerBatch =
+      Xorshift128Plus::N * sizeof(uint64_t) / sizeof(float);
+  HWY_ALIGN uint64_t batch[Xorshift128Plus::N];
+
+  const HWY_FULL(float) df;
+  const size_t N = Lanes(df);
+
+  for (size_t y = 0; y < ysize; ++y) {
+    float* JXL_RESTRICT row = rect.Row(noise, y);
+
+    size_t x = 0;
+    // Only entire batches (avoids exceeding the image padding).
+    for (; x + kFloatsPerBatch <= xsize; x += kFloatsPerBatch) {
+      rng->Fill(batch);
+      for (size_t i = 0; i < kFloatsPerBatch; i += Lanes(df)) {
+        BitsToFloat(reinterpret_cast<const uint32_t*>(batch) + i,
+                    row + x + i);
+      }
+    }
+
+    // Any remaining pixels, rounded up to vectors (safe due to padding).
+    rng->Fill(batch);
+    size_t batch_pos = 0;  // < kFloatsPerBatch
+    for (; x < xsize; x += N) {
+      BitsToFloat(reinterpret_cast<const uint32_t*>(batch) + batch_pos,
+                  row + x);
+      batch_pos += N;
+    }
+  }
+}
+
+// [0, max_value]
+template <class D, class V>
+static HWY_INLINE V Clamp0ToMax(D d, const V x, const V max_value) {
+  const auto clamped = Min(x, max_value);
+  return ZeroIfNegative(clamped);
+}
+
+// x is in [0+delta, 1+delta], delta ~= 0.06
+template <class StrengthEval>
+typename StrengthEval::V NoiseStrength(const StrengthEval& eval,
+                                       const typename StrengthEval::V x) {
+  return Clamp0ToMax(D(), eval(x), Set(D(), 1.0f));
+}
+
+// TODO(veluca): SIMD-fy.
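+// A scalar sketch of the computation performed by StrengthEvalLut below;
+// IndexAndFrac is assumed to split x into (LUT bin, fractional position):
+//
+//   float EvalNoiseLut(const NoiseParams& p, float x) {
+//     std::pair<int, float> pos = IndexAndFrac(x);         // (bin, frac)
+//     float low = p.lut[pos.first];                        // left knot
+//     float hi = p.lut[pos.first + 1];                     // right knot
+//     return low * (1.0f - pos.second) + hi * pos.second;  // linear blend
+//   }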
+class StrengthEvalLut {
+ public:
+  using V = Vec<D>;
+
+  explicit StrengthEvalLut(const NoiseParams& noise_params)
+      : noise_params_(noise_params) {}
+
+  V operator()(const V vx) const {
+    float x;
+    Store(vx, D(), &x);
+    std::pair<int, float> pos = IndexAndFrac(x);
+    JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) <
+                                      NoiseParams::kNumNoisePoints - 1);
+    float low = noise_params_.lut[pos.first];
+    float hi = noise_params_.lut[pos.first + 1];
+    return Set(D(), low * (1.0f - pos.second) + hi * pos.second);
+  }
+
+ private:
+  const NoiseParams noise_params_;
+};
+
+template <class D>
+void AddNoiseToRGB(const D d, const Vec<D> rnd_noise_r,
+                   const Vec<D> rnd_noise_g, const Vec<D> rnd_noise_cor,
+                   const Vec<D> noise_strength_g,
+                   const Vec<D> noise_strength_r, float ytox, float ytob,
+                   float* JXL_RESTRICT out_x, float* JXL_RESTRICT out_y,
+                   float* JXL_RESTRICT out_b) {
+  const auto kRGCorr = Set(d, 0.9921875f);   // 127/128
+  const auto kRGNCorr = Set(d, 0.0078125f);  // 1/128
+
+  const auto red_noise = kRGNCorr * rnd_noise_r * noise_strength_r +
+                         kRGCorr * rnd_noise_cor * noise_strength_r;
+  const auto green_noise = kRGNCorr * rnd_noise_g * noise_strength_g +
+                           kRGCorr * rnd_noise_cor * noise_strength_g;
+
+  auto vx = Load(d, out_x);
+  auto vy = Load(d, out_y);
+  auto vb = Load(d, out_b);
+
+  vx += red_noise - green_noise + Set(d, ytox) * (red_noise + green_noise);
+  vy += red_noise + green_noise;
+  vb += Set(d, ytob) * (red_noise + green_noise);
+
+  Store(vx, d, out_x);
+  Store(vy, d, out_y);
+  Store(vb, d, out_b);
+}
+
+void AddNoise(const NoiseParams& noise_params, const Rect& noise_rect,
+              const Image3F& noise, const Rect& opsin_rect,
+              const ColorCorrelationMap& cmap, Image3F* opsin) {
+  if (!noise_params.HasAny()) return;
+  const StrengthEvalLut noise_model(noise_params);
+  D d;
+  const auto half = Set(d, 0.5f);
+
+  const size_t xsize = opsin_rect.xsize();
+  const size_t ysize = opsin_rect.ysize();
+
+  // With the prior subtract-random Laplacian approximation, rnd_* ranges were
+  // about [-1.5, 1.6]; Laplacian3 about doubles this to [-3.6, 3.6], so the
+  // normalizer is half of what it was before (0.5).
+ const auto norm_const = Set(d, 0.22f); + + float ytox = cmap.YtoXRatio(0); + float ytob = cmap.YtoBRatio(0); + + for (size_t y = 0; y < ysize; ++y) { + float* JXL_RESTRICT row_x = opsin_rect.PlaneRow(opsin, 0, y); + float* JXL_RESTRICT row_y = opsin_rect.PlaneRow(opsin, 1, y); + float* JXL_RESTRICT row_b = opsin_rect.PlaneRow(opsin, 2, y); + const float* JXL_RESTRICT row_rnd_r = noise_rect.ConstPlaneRow(noise, 0, y); + const float* JXL_RESTRICT row_rnd_g = noise_rect.ConstPlaneRow(noise, 1, y); + const float* JXL_RESTRICT row_rnd_c = noise_rect.ConstPlaneRow(noise, 2, y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto vx = Load(d, row_x + x); + const auto vy = Load(d, row_y + x); + const auto in_g = vy - vx; + const auto in_r = vy + vx; + const auto noise_strength_g = NoiseStrength(noise_model, in_g * half); + const auto noise_strength_r = NoiseStrength(noise_model, in_r * half); + const auto addit_rnd_noise_red = Load(d, row_rnd_r + x) * norm_const; + const auto addit_rnd_noise_green = Load(d, row_rnd_g + x) * norm_const; + const auto addit_rnd_noise_correlated = + Load(d, row_rnd_c + x) * norm_const; + AddNoiseToRGB(D(), addit_rnd_noise_red, addit_rnd_noise_green, + addit_rnd_noise_correlated, noise_strength_g, + noise_strength_r, ytox, ytob, row_x + x, row_y + x, + row_b + x); + } + } +} + +void RandomImage3(size_t seed, const Rect& rect, Image3F* JXL_RESTRICT noise) { + HWY_ALIGN Xorshift128Plus rng(seed); + RandomImage(&rng, rect, &noise->Plane(0)); + RandomImage(&rng, rect, &noise->Plane(1)); + RandomImage(&rng, rect, &noise->Plane(2)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(AddNoise); +void AddNoise(const NoiseParams& noise_params, const Rect& noise_rect, + const Image3F& noise, const Rect& opsin_rect, + const ColorCorrelationMap& cmap, Image3F* opsin) { + return HWY_DYNAMIC_DISPATCH(AddNoise)(noise_params, noise_rect, noise, + opsin_rect, cmap, opsin); +} + +HWY_EXPORT(RandomImage3); +void RandomImage3(size_t seed, const Rect& rect, Image3F* JXL_RESTRICT noise) { + return HWY_DYNAMIC_DISPATCH(RandomImage3)(seed, rect, noise); +} + +void DecodeFloatParam(float precision, float* val, BitReader* br) { + const int absval_quant = br->ReadFixedBits<10>(); + *val = absval_quant / precision; +} + +Status DecodeNoise(BitReader* br, NoiseParams* noise_params) { + for (float& i : noise_params->lut) { + DecodeFloatParam(kNoisePrecision, &i, br); + } + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_noise.h b/third_party/jpeg-xl/lib/jxl/dec_noise.h new file mode 100644 index 000000000000..3ea9d5442f88 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_noise.h @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_DEC_NOISE_H_ +#define LIB_JXL_DEC_NOISE_H_ + +// Noise synthesis. Currently disabled. 
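+// A minimal usage sketch of this header (local names and the seed choice are
+// illustrative assumptions, not mandated by this API):
+//
+//   NoiseParams params;
+//   JXL_RETURN_IF_ERROR(DecodeNoise(br, &params));  // per-frame LUT
+//   Image3F noise(frame_xsize, frame_ysize);
+//   RandomImage3(seed, Rect(noise), &noise);        // random bits -> floats
+//   AddNoise(params, Rect(noise), noise, Rect(opsin), cmap, &opsin);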
+ +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" +#include "lib/jxl/noise.h" + +namespace jxl { + +// Add a noise to Opsin image, loading generated random noise from `noise_rect` +// in `noise`. +void AddNoise(const NoiseParams& noise_params, const Rect& noise_rect, + const Image3F& noise, const Rect& opsin_rect, + const ColorCorrelationMap& cmap, Image3F* opsin); + +void RandomImage3(size_t seed, const Rect& rect, Image3F* JXL_RESTRICT noise); + +// Must only call if FrameHeader.flags.kNoise. +Status DecodeNoise(BitReader* br, NoiseParams* noise_params); + +} // namespace jxl + +#endif // LIB_JXL_DEC_NOISE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_params.h b/third_party/jpeg-xl/lib/jxl/dec_params.h new file mode 100644 index 000000000000..83ef9b385c1e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_params.h @@ -0,0 +1,68 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_DEC_PARAMS_H_ +#define LIB_JXL_DEC_PARAMS_H_ + +// Parameters and flags that govern JXL decompression. + +#include +#include + +#include + +#include "lib/jxl/base/override.h" + +namespace jxl { + +struct DecompressParams { + // If true, checks at the end of decoding that all of the compressed data + // was consumed by the decoder. + bool check_decompressed_size = true; + + // If true, skip dequant and iDCT and decode to JPEG (only if possible) + bool keep_dct = false; + + // These cannot be kOn because they need encoder support. + Override preview = Override::kDefault; + + // How many passes to decode at most. By default, decode everything. + uint32_t max_passes = std::numeric_limits::max(); + // Alternatively, one can specify the maximum tolerable downscaling factor + // with respect to the full size of the image. By default, nothing less than + // the full size is requested. + size_t max_downsampling = 1; + + // Try to decode as much as possible of a truncated codestream, but only whole + // sections at a time. + bool allow_partial_files = false; + // Allow even more progression. 
+  bool allow_more_progressive_steps = false;
+
+  bool operator==(const DecompressParams& other) const {
+    return check_decompressed_size == other.check_decompressed_size &&
+           keep_dct == other.keep_dct && preview == other.preview &&
+           max_passes == other.max_passes &&
+           max_downsampling == other.max_downsampling &&
+           allow_partial_files == other.allow_partial_files &&
+           allow_more_progressive_steps == other.allow_more_progressive_steps;
+  }
+  bool operator!=(const DecompressParams& other) const {
+    return !(*this == other);
+  }
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_PARAMS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc
new file mode 100644
index 000000000000..ef5997d75852
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc
@@ -0,0 +1,187 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_patch_dictionary.h"
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_frame.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/patch_dictionary_internal.h"
+
+namespace jxl {
+
+constexpr int kMaxPatches = 1 << 24;
+
+Status PatchDictionary::Decode(BitReader* br, size_t xsize, size_t ysize) {
+  std::vector<uint8_t> context_map;
+  ANSCode code;
+  JXL_RETURN_IF_ERROR(
+      DecodeHistograms(br, kNumPatchDictionaryContexts, &code, &context_map));
+  ANSSymbolReader decoder(&code, br);
+
+  auto read_num = [&](size_t context) {
+    size_t r = decoder.ReadHybridUint(context, br, context_map);
+    return r;
+  };
+
+  size_t num_ref_patch = read_num(kNumRefPatchContext);
+  // TODO(veluca): does this make sense?
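+  // (kMaxPatches, 1 << 24, bounds the allocations below: roughly 16.7M
+  // PatchPosition entries already occupy on the order of a gigabyte, so a
+  // larger count is treated as a corrupt stream rather than honored.)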
+ if (num_ref_patch > kMaxPatches) { + return JXL_FAILURE("Too many patches in dictionary"); + } + + for (size_t id = 0; id < num_ref_patch; id++) { + PatchReferencePosition ref_pos; + ref_pos.ref = read_num(kReferenceFrameContext); + if (ref_pos.ref >= kMaxNumReferenceFrames || + shared_->reference_frames[ref_pos.ref].frame->xsize() == 0) { + return JXL_FAILURE("Invalid reference frame ID"); + } + const ImageBundle& ib = *shared_->reference_frames[ref_pos.ref].frame; + ref_pos.x0 = read_num(kPatchReferencePositionContext); + ref_pos.y0 = read_num(kPatchReferencePositionContext); + ref_pos.xsize = read_num(kPatchSizeContext) + 1; + ref_pos.ysize = read_num(kPatchSizeContext) + 1; + if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) { + return JXL_FAILURE("Invalid position specified in reference frame"); + } + if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) { + return JXL_FAILURE("Invalid position specified in reference frame"); + } + size_t id_count = read_num(kPatchCountContext) + 1; + if (id_count > kMaxPatches) { + return JXL_FAILURE("Too many patches in dictionary"); + } + positions_.reserve(positions_.size() + id_count); + for (size_t i = 0; i < id_count; i++) { + PatchPosition pos; + pos.ref_pos = ref_pos; + if (i == 0) { + pos.x = read_num(kPatchPositionContext); + pos.y = read_num(kPatchPositionContext); + } else { + pos.x = + positions_.back().x + UnpackSigned(read_num(kPatchOffsetContext)); + pos.y = + positions_.back().y + UnpackSigned(read_num(kPatchOffsetContext)); + } + if (pos.x + ref_pos.xsize > xsize) { + return JXL_FAILURE("Invalid patch x: at %zu + %zu > %zu", pos.x, + ref_pos.xsize, xsize); + } + if (pos.y + ref_pos.ysize > ysize) { + return JXL_FAILURE("Invalid patch y: at %zu + %zu > %zu", pos.y, + ref_pos.ysize, ysize); + } + for (size_t i = 0; i < shared_->metadata->m.extra_channel_info.size() + 1; + i++) { + uint32_t blend_mode = read_num(kPatchBlendModeContext); + if (blend_mode >= uint32_t(PatchBlendMode::kNumBlendModes)) { + return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode); + } + PatchBlending info; + info.mode = static_cast(blend_mode); + if (i != 0 && info.mode != PatchBlendMode::kNone) { + return JXL_FAILURE( + "Blending of extra channels with patches is not supported yet"); + } + if (info.mode != PatchBlendMode::kAdd && + info.mode != PatchBlendMode::kNone && + info.mode != PatchBlendMode::kReplace) { + return JXL_FAILURE("Blending mode not supported yet: %u", blend_mode); + } + if (UsesAlpha(info.mode) && + shared_->metadata->m.extra_channel_info.size() > 1) { + info.alpha_channel = read_num(kPatchAlphaChannelContext); + if (info.alpha_channel >= + shared_->metadata->m.extra_channel_info.size()) { + return JXL_FAILURE( + "Invalid alpha channel for blending: %u out of %u\n", + info.alpha_channel, + (uint32_t)shared_->metadata->m.extra_channel_info.size()); + } + } + if (UsesClamp(info.mode)) { + info.clamp = read_num(kPatchClampContext); + } + pos.blending.push_back(info); + } + positions_.push_back(std::move(pos)); + } + } + + if (!decoder.CheckANSFinalState()) { + return JXL_FAILURE("ANS checksum failure."); + } + if (!HasAny()) { + return JXL_FAILURE("Decoded patch dictionary but got none"); + } + + ComputePatchCache(); + return true; +} + +void PatchDictionary::ComputePatchCache() { + if (positions_.empty()) return; + std::vector> sorted_patches_y; + for (size_t i = 0; i < positions_.size(); i++) { + const PatchPosition& pos = positions_[i]; + for (size_t y = pos.y; y < pos.y + pos.ref_pos.ysize; y++) { + sorted_patches_y.emplace_back(y, i); + } + } 
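+  // Sketch of the cache built below: patch 0 covering rows 0..1 and patch 1
+  // covering rows 1..2 produce
+  //   sorted_patches_y = {(0,0), (1,0), (1,1), (2,1)}
+  // which the sort and the two loops below turn into
+  //   sorted_patches_ = {0, 0, 1, 1}
+  //   patch_starts_   = {0, 1, 3, 4}
+  // so row y intersects the patches in positions [patch_starts_[y],
+  // patch_starts_[y + 1]) of sorted_patches_.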
+ // The relative order of patches that affect the same pixels is preserved. + // This is important for patches that have a blend mode different from kAdd. + std::sort(sorted_patches_y.begin(), sorted_patches_y.end()); + patch_starts_.resize(sorted_patches_y.back().first + 2, + sorted_patches_y.size()); + sorted_patches_.resize(sorted_patches_y.size()); + for (size_t i = 0; i < sorted_patches_y.size(); i++) { + sorted_patches_[i] = sorted_patches_y[i].second; + patch_starts_[sorted_patches_y[i].first] = + std::min(patch_starts_[sorted_patches_y[i].first], i); + } + for (size_t i = patch_starts_.size() - 1; i > 0; i--) { + patch_starts_[i - 1] = std::min(patch_starts_[i], patch_starts_[i - 1]); + } +} + +void PatchDictionary::AddTo(Image3F* opsin, const Rect& opsin_rect, + const Rect& image_rect) const { + Apply(opsin, opsin_rect, image_rect); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.h b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.h new file mode 100644 index 000000000000..a210fd0f9675 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.h @@ -0,0 +1,204 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_DEC_PATCH_DICTIONARY_H_ +#define LIB_JXL_DEC_PATCH_DICTIONARY_H_ + +// Chooses reference patches, and avoids encoding them once per occurrence. + +#include +#include +#include + +#include +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { + +constexpr size_t kMaxPatchSize = 32; + +enum class PatchBlendMode : uint8_t { + // The new values are the old ones. Useful to skip some channels. + kNone = 0, + // The new values (in the crop) replace the old ones: sample = new + kReplace = 1, + // The new values (in the crop) get added to the old ones: sample = old + new + kAdd = 2, + // The new values (in the crop) get multiplied by the old ones: + // sample = old * new + // This blend mode is only supported if BlendColorSpace is kEncoded. The + // range of the new value matters for multiplication purposes, and its + // nominal range of 0..1 is computed the same way as this is done for the + // alpha values in kBlend and kAlphaWeightedAdd. + kMul = 3, + // The new values (in the crop) replace the old ones if alpha>0: + // For first alpha channel: + // alpha = old + new * (1 - old) + // For other channels if !alpha_associated: + // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha + // For other channels if alpha_associated: + // sample = (1 - new_alpha) * old + new + // The alpha formula applies to the alpha used for the division in the other + // channels formula, and applies to the alpha channel itself if its + // blend_channel value matches itself. + // If using kBlendAbove, new is the patch and old is the original image; if + // using kBlendBelow, the meaning is inverted. 
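+  // Worked example of the formulas above (kBlendAbove, first alpha channel,
+  // !alpha_associated): with old alpha 0.5 and patch (new) alpha 0.4,
+  //   alpha  = 0.5 + 0.4 * (1 - 0.5) = 0.7
+  // and a color channel with old = 0.2, new = 0.8 blends to
+  //   sample = ((1 - 0.4) * 0.2 * 0.5 + 0.4 * 0.8) / 0.7 ~= 0.543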
+  kBlendAbove = 4,
+  kBlendBelow = 5,
+  // The new values (in the crop) are added to the old ones if alpha>0:
+  // For first alpha channel: sample = old + new * (1 - old)
+  // For other channels: sample = old + alpha * new
+  kAlphaWeightedAddAbove = 6,
+  kAlphaWeightedAddBelow = 7,
+  kNumBlendModes,
+};
+
+inline bool UsesAlpha(PatchBlendMode mode) {
+  return mode == PatchBlendMode::kBlendAbove ||
+         mode == PatchBlendMode::kBlendBelow ||
+         mode == PatchBlendMode::kAlphaWeightedAddAbove ||
+         mode == PatchBlendMode::kAlphaWeightedAddBelow;
+}
+inline bool UsesClamp(PatchBlendMode mode) {
+  return UsesAlpha(mode) || mode == PatchBlendMode::kMul;
+}
+
+struct PatchBlending {
+  PatchBlendMode mode;
+  uint32_t alpha_channel;
+  bool clamp;
+};
+
+struct QuantizedPatch {
+  size_t xsize;
+  size_t ysize;
+  QuantizedPatch() {
+    for (size_t i = 0; i < 3; i++) {
+      pixels[i].resize(kMaxPatchSize * kMaxPatchSize);
+      fpixels[i].resize(kMaxPatchSize * kMaxPatchSize);
+    }
+  }
+  std::vector<int8_t> pixels[3] = {};
+  // Not compared. Used only to retrieve original pixels to construct the
+  // reference image.
+  std::vector<float> fpixels[3] = {};
+  bool operator==(const QuantizedPatch& other) const {
+    if (xsize != other.xsize) return false;
+    if (ysize != other.ysize) return false;
+    for (size_t c = 0; c < 3; c++) {
+      if (memcmp(pixels[c].data(), other.pixels[c].data(),
+                 sizeof(int8_t) * xsize * ysize) != 0)
+        return false;
+    }
+    return true;
+  }
+
+  bool operator<(const QuantizedPatch& other) const {
+    if (xsize != other.xsize) return xsize < other.xsize;
+    if (ysize != other.ysize) return ysize < other.ysize;
+    for (size_t c = 0; c < 3; c++) {
+      int cmp = memcmp(pixels[c].data(), other.pixels[c].data(),
+                       sizeof(int8_t) * xsize * ysize);
+      if (cmp > 0) return false;
+      if (cmp < 0) return true;
+    }
+    return false;
+  }
+};
+
+// Pair (patch, vector of occurrences).
+using PatchInfo =
+    std::pair<QuantizedPatch, std::vector<std::pair<uint32_t, uint32_t>>>;
+
+// Position and size of the patch in the reference frame.
+struct PatchReferencePosition {
+  size_t ref, x0, y0, xsize, ysize;
+  bool operator<(const PatchReferencePosition& oth) const {
+    return std::make_tuple(ref, x0, y0, xsize, ysize) <
+           std::make_tuple(oth.ref, oth.x0, oth.y0, oth.xsize, oth.ysize);
+  }
+  bool operator==(const PatchReferencePosition& oth) const {
+    return !(*this < oth) && !(oth < *this);
+  }
+};
+
+struct PatchPosition {
+  // Position of top-left corner of the patch in the image.
+  size_t x, y;
+  // Different blend mode for color and extra channels.
+  std::vector<PatchBlending> blending;
+  PatchReferencePosition ref_pos;
+  bool operator<(const PatchPosition& oth) const {
+    return std::make_tuple(ref_pos, x, y) <
+           std::make_tuple(oth.ref_pos, oth.x, oth.y);
+  }
+};
+
+struct PassesSharedState;
+
+// Encoder-side helper class to encode the PatchesDictionary.
+class PatchDictionaryEncoder;
+
+class PatchDictionary {
+ public:
+  PatchDictionary() = default;
+
+  void SetPassesSharedState(const PassesSharedState* shared) {
+    shared_ = shared;
+  }
+
+  bool HasAny() const { return !positions_.empty(); }
+
+  Status Decode(BitReader* br, size_t xsize, size_t ysize);
+
+  // Only adds patches that belong to the `image_rect` area of the decoded
+  // image, writing them to the `opsin_rect` area of `opsin`.
+  void AddTo(Image3F* opsin, const Rect& opsin_rect,
+             const Rect& image_rect) const;
+
+ private:
+  friend class PatchDictionaryEncoder;
+
+  const PassesSharedState* shared_;
+  std::vector<PatchPosition> positions_;
+
+  // Patch occurrences sorted by y.
+  std::vector<size_t> sorted_patches_;
+  // Index of the first patch for each y value.
+  std::vector<size_t> patch_starts_;
+
+  // Patch IDs in position [patch_starts_[y], patch_starts_[y+1]) of
+  // sorted_patches_ are all the patches that intersect the horizontal line
+  // at y. The relative order of patches that affect the same pixels is
+  // preserved, which matters when applying patches is noncommutative.
+
+  // Computes sorted_patches_ and patch_starts_ after updating positions_.
+  void ComputePatchCache();
+
+  // Implemented in patch_dictionary_internal.h
+  template
+  void Apply(Image3F* opsin, const Rect& opsin_rect,
+             const Rect& image_rect) const;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_PATCH_DICTIONARY_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc b/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc
new file mode 100644
index 000000000000..0488c98adc00
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_reconstruct.cc
@@ -0,0 +1,853 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_reconstruct.h"
+
+#include
+#include
+
+#include "lib/jxl/filters.h"
+#include "lib/jxl/image_ops.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dec_reconstruct.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/blending.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_noise.h"
+#include "lib/jxl/dec_upsample.h"
+#include "lib/jxl/dec_xyb-inl.h"
+#include "lib/jxl/dec_xyb.h"
+#include "lib/jxl/epf.h"
+#include "lib/jxl/fast_math-inl.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/loop_filter.h"
+#include "lib/jxl/passes_state.h"
+#include "lib/jxl/transfer_functions-inl.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+Status UndoXYBInPlace(Image3F* idct, const Rect& rect,
+                      const OutputEncodingInfo& output_encoding_info) {
+  PROFILER_ZONE("UndoXYB");
+
+  // The size of `rect` might not be a multiple of Lanes(d), but is
+  // guaranteed to be a multiple of kBlockDim or at the margin of the image.
+ for (size_t y = 0; y < rect.ysize(); y++) { + float* JXL_RESTRICT row0 = rect.PlaneRow(idct, 0, y); + float* JXL_RESTRICT row1 = rect.PlaneRow(idct, 1, y); + float* JXL_RESTRICT row2 = rect.PlaneRow(idct, 2, y); + + const HWY_CAPPED(float, GroupBorderAssigner::kPaddingXRound) d; + + if (output_encoding_info.color_encoding.tf.IsLinear()) { + for (size_t x = 0; x < rect.xsize(); x += Lanes(d)) { + const auto in_opsin_x = Load(d, row0 + x); + const auto in_opsin_y = Load(d, row1 + x); + const auto in_opsin_b = Load(d, row2 + x); + JXL_COMPILER_FENCE; + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, + output_encoding_info.opsin_params, &linear_r, &linear_g, + &linear_b); + + Store(linear_r, d, row0 + x); + Store(linear_g, d, row1 + x); + Store(linear_b, d, row2 + x); + } + } else if (output_encoding_info.color_encoding.tf.IsSRGB()) { + for (size_t x = 0; x < rect.xsize(); x += Lanes(d)) { + const auto in_opsin_x = Load(d, row0 + x); + const auto in_opsin_y = Load(d, row1 + x); + const auto in_opsin_b = Load(d, row2 + x); + JXL_COMPILER_FENCE; + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, + output_encoding_info.opsin_params, &linear_r, &linear_g, + &linear_b); + +#if JXL_HIGH_PRECISION + Store(TF_SRGB().EncodedFromDisplay(d, linear_r), d, row0 + x); + Store(TF_SRGB().EncodedFromDisplay(d, linear_g), d, row1 + x); + Store(TF_SRGB().EncodedFromDisplay(d, linear_b), d, row2 + x); +#else + Store(FastLinearToSRGB(d, linear_r), d, row0 + x); + Store(FastLinearToSRGB(d, linear_g), d, row1 + x); + Store(FastLinearToSRGB(d, linear_b), d, row2 + x); +#endif + } + } else if (output_encoding_info.color_encoding.tf.IsPQ()) { + for (size_t x = 0; x < rect.xsize(); x += Lanes(d)) { + const auto in_opsin_x = Load(d, row0 + x); + const auto in_opsin_y = Load(d, row1 + x); + const auto in_opsin_b = Load(d, row2 + x); + JXL_COMPILER_FENCE; + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, + output_encoding_info.opsin_params, &linear_r, &linear_g, + &linear_b); + Store(TF_PQ().EncodedFromDisplay(d, linear_r), d, row0 + x); + Store(TF_PQ().EncodedFromDisplay(d, linear_g), d, row1 + x); + Store(TF_PQ().EncodedFromDisplay(d, linear_b), d, row2 + x); + } + } else if (output_encoding_info.color_encoding.tf.IsGamma()) { + auto gamma_tf = [&](hwy::HWY_NAMESPACE::Vec v) { + return IfThenZeroElse( + v <= Set(d, 1e-5f), + FastPowf(d, v, Set(d, output_encoding_info.inverse_gamma))); + }; + for (size_t x = 0; x < rect.xsize(); x += Lanes(d)) { + const auto in_opsin_x = Load(d, row0 + x); + const auto in_opsin_y = Load(d, row1 + x); + const auto in_opsin_b = Load(d, row2 + x); + JXL_COMPILER_FENCE; + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, + output_encoding_info.opsin_params, &linear_r, &linear_g, + &linear_b); + Store(gamma_tf(linear_r), d, row0 + x); + Store(gamma_tf(linear_g), d, row1 + x); + Store(gamma_tf(linear_b), d, row2 + x); + } + } else { + return JXL_FAILURE("Invalid target encoding"); + } + } + return true; +} + +template +void StoreRGBA(D d, V r, V g, V b, V a, bool alpha, size_t n, size_t extra, + uint8_t* buf) { +#if HWY_TARGET == HWY_SCALAR + buf[0] = r.raw; + buf[1] = g.raw; + buf[2] = b.raw; + if (alpha) { + buf[3] = 
a.raw; + } +#elif HWY_TARGET == HWY_NEON + if (alpha) { + uint8x8x4_t data = {r.raw, g.raw, b.raw, a.raw}; + if (extra >= 8) { + vst4_u8(buf, data); + } else { + uint8_t tmp[8 * 4]; + vst4_u8(tmp, data); + memcpy(buf, tmp, n * 4); + } + } else { + uint8x8x3_t data = {r.raw, g.raw, b.raw}; + if (extra >= 8) { + vst3_u8(buf, data); + } else { + uint8_t tmp[8 * 3]; + vst3_u8(tmp, data); + memcpy(buf, tmp, n * 3); + } + } +#else + // TODO(veluca): implement this for x86. + size_t mul = alpha ? 4 : 3; + HWY_ALIGN uint8_t bytes[16]; + Store(r, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[mul * i] = bytes[i]; + } + Store(g, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[mul * i + 1] = bytes[i]; + } + Store(b, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[mul * i + 2] = bytes[i]; + } + if (alpha) { + Store(a, d, bytes); + for (size_t i = 0; i < n; i++) { + buf[4 * i + 3] = bytes[i]; + } + } +#endif +} + +// Outputs floating point image to RGBA 8-bit buffer. Does not support alpha +// channel in the input, but outputs opaque alpha channel for the case where the +// output buffer to write to is in the 4-byte per pixel RGBA format. +void FloatToRGBA8(const Image3F& input, const Rect& input_rect, bool is_rgba, + const ImageF* alpha_in, const Rect& alpha_rect, + const Rect& output_buf_rect, uint8_t* JXL_RESTRICT output_buf, + size_t stride) { + size_t bytes = is_rgba ? 4 : 3; + for (size_t y = 0; y < output_buf_rect.ysize(); y++) { + const float* JXL_RESTRICT row_in_r = input_rect.ConstPlaneRow(input, 0, y); + const float* JXL_RESTRICT row_in_g = input_rect.ConstPlaneRow(input, 1, y); + const float* JXL_RESTRICT row_in_b = input_rect.ConstPlaneRow(input, 2, y); + const float* JXL_RESTRICT row_in_a = + alpha_in ? alpha_rect.ConstRow(*alpha_in, y) : nullptr; + size_t base_ptr = + (y + output_buf_rect.y0()) * stride + bytes * output_buf_rect.x0(); + using D = HWY_CAPPED(float, 4); + const D d; + D::Rebind du; + auto zero = Zero(d); + auto one = Set(d, 1.0f); + auto mul = Set(d, 255.0f); + for (size_t x = 0; x < output_buf_rect.xsize(); x += Lanes(d)) { + auto rf = Clamp(zero, Load(d, row_in_r + x), one) * mul; + auto gf = Clamp(zero, Load(d, row_in_g + x), one) * mul; + auto bf = Clamp(zero, Load(d, row_in_b + x), one) * mul; + auto af = row_in_a ? Clamp(zero, Load(d, row_in_a + x), one) * mul + : Set(d, 255.0f); + auto r8 = U8FromU32(BitCast(du, NearestInt(rf))); + auto g8 = U8FromU32(BitCast(du, NearestInt(gf))); + auto b8 = U8FromU32(BitCast(du, NearestInt(bf))); + auto a8 = U8FromU32(BitCast(du, NearestInt(af))); + size_t n = output_buf_rect.xsize() - x; + if (JXL_LIKELY(n >= Lanes(d))) { + StoreRGBA(D::Rebind(), r8, g8, b8, a8, is_rgba, Lanes(d), n, + output_buf + base_ptr + bytes * x); + } else { + StoreRGBA(D::Rebind(), r8, g8, b8, a8, is_rgba, n, n, + output_buf + base_ptr + bytes * x); + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(UndoXYBInPlace); +HWY_EXPORT(FloatToRGBA8); + +void UndoXYB(const Image3F& src, Image3F* dst, + const OutputEncodingInfo& output_info, ThreadPool* pool) { + CopyImageTo(src, dst); + pool->Run(0, src.ysize(), ThreadPool::SkipInit(), [&](int y, int /*thread*/) { + JXL_CHECK(HWY_DYNAMIC_DISPATCH(UndoXYBInPlace)(dst, Rect(*dst).Line(y), + output_info)); + }); +} + +namespace { +// Implements EnsurePaddingInPlace, but allows processing data one row at a +// time. 
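+// Example of the x-padding produced below: a three-pixel row |a b c| with
+// xpadding = 2 becomes
+//   b a | a b c | c b
+// i.e. mirroring duplicates the edge pixel, matching Mirror()'s convention.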
+class EnsurePaddingInPlaceRowByRow { + public: + void Init(Image3F* img, const Rect& rect, const Rect& image_rect, + size_t image_xsize, size_t image_ysize, size_t xpadding, + size_t ypadding, ssize_t* y0, ssize_t* y1) { + // coordinates relative to rect. + JXL_DASSERT(SameSize(rect, image_rect)); + *y0 = -std::min(image_rect.y0(), ypadding); + *y1 = rect.ysize() + std::min(ypadding, image_ysize - image_rect.ysize() - + image_rect.y0()); + if (image_rect.x0() >= xpadding && + image_rect.x0() + image_rect.xsize() + xpadding <= image_xsize) { + // Nothing to do. + strategy_ = kSkip; + } else if (image_xsize >= 2 * xpadding) { + strategy_ = kFast; + } else { + strategy_ = kSlow; + } + img_ = img; + y0_ = rect.y0(); + JXL_DASSERT(rect.x0() >= xpadding); + x0_ = x1_ = rect.x0() - xpadding; + // If close to the left border - do mirroring. + if (image_rect.x0() < xpadding) x1_ = rect.x0() - image_rect.x0(); + x2_ = x3_ = rect.x0() + rect.xsize() + xpadding; + // If close to the right border - do mirroring. + if (image_rect.x0() + image_rect.xsize() + xpadding > image_xsize) { + x2_ = rect.x0() + image_xsize - image_rect.x0(); + } + JXL_DASSERT(x3_ <= img->xsize()); + JXL_DASSERT(image_xsize == (x2_ - x1_) || + (x1_ - x0_ <= x2_ - x1_ && x3_ - x2_ <= x2_ - x1_)); + } + // To be called when row `y` of the input is available, for all the values in + // [*y0, *y1). + void Process(ssize_t y) { + switch (strategy_) { + case kSkip: + break; + case kFast: + // Image is wide enough that a single Mirror() step is sufficient. + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT row = img_->PlaneRow(c, y + y0_); + for (size_t x = x0_; x < x1_; x++) { + row[x] = row[2 * x1_ - x - 1]; + } + for (size_t x = x2_; x < x3_; x++) { + row[x] = row[2 * x2_ - x - 1]; + } + } + break; + case kSlow: + // Slow case for small images. + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT row = img_->PlaneRow(c, y + y0_) + x1_; + for (ssize_t x = x0_ - x1_; x < 0; x++) { + *(row + x) = row[Mirror(x, x2_ - x1_)]; + } + for (size_t x = x2_ - x1_; x < x3_ - x1_; x++) { + *(row + x) = row[Mirror(x, x2_ - x1_)]; + } + } + break; + } + } + + private: + // Initialized to silence spurious compiler warnings. + Image3F* img_ = nullptr; + // Will fill [x0_, x1_) and [x2_, x3_) on every row. + // The [x1_, x2_) range contains valid image pixels. We guarantee that either + // x1_ - x0_ <= x2_ - x1_, (and similarly for x2_, x3_), or that the [x1_, + // x2_) contains a full horizontal line of the original image. + size_t x0_ = 0, x1_ = 0, x2_ = 0, x3_ = 0; + size_t y0_ = 0; + // kSlow: use calls to Mirror(), for the case where the border might be larger + // than the image. + // kFast: directly use the result of Mirror() when it can be computed in a + // single iteration. + // kSkip: do nothing. 
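+  // Illustration of the strategy chosen by Init() above, with xpadding = 2:
+  //   rect at x0 = 10 in a 100-pixel-wide image -> kSkip (border is valid)
+  //   rect at x0 = 0  in a 100-pixel-wide image -> kFast (one mirror step)
+  //   rect at x0 = 0  in a 3-pixel-wide image   -> kSlow (3 < 2 * xpadding)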
+ enum Strategy { kFast, kSlow, kSkip }; + Strategy strategy_ = kSkip; +}; +} // namespace + +void EnsurePaddingInPlace(Image3F* img, const Rect& rect, + const Rect& image_rect, size_t image_xsize, + size_t image_ysize, size_t xpadding, + size_t ypadding) { + ssize_t y0, y1; + EnsurePaddingInPlaceRowByRow impl; + impl.Init(img, rect, image_rect, image_xsize, image_ysize, xpadding, ypadding, + &y0, &y1); + for (ssize_t y = y0; y < y1; y++) { + impl.Process(y); + } +} + +Status FinalizeImageRect( + Image3F* input_image, const Rect& input_rect, + const std::vector>& extra_channels, + PassesDecoderState* dec_state, size_t thread, + ImageBundle* JXL_RESTRICT output_image, const Rect& output_rect) { + const ImageFeatures& image_features = dec_state->shared->image_features; + const FrameHeader& frame_header = dec_state->shared->frame_header; + const ImageMetadata& metadata = frame_header.nonserialized_metadata->m; + const LoopFilter& lf = frame_header.loop_filter; + const FrameDimensions& frame_dim = dec_state->shared->frame_dim; + JXL_DASSERT(output_rect.xsize() <= kApplyImageFeaturesTileDim); + JXL_DASSERT(output_rect.ysize() <= kApplyImageFeaturesTileDim); + JXL_DASSERT(input_rect.xsize() == output_rect.xsize()); + JXL_DASSERT(input_rect.ysize() == output_rect.ysize()); + JXL_DASSERT(output_rect.x0() % GroupBorderAssigner::kPaddingXRound == 0); + JXL_DASSERT(output_rect.xsize() % GroupBorderAssigner::kPaddingXRound == 0 || + output_rect.xsize() + output_rect.x0() == frame_dim.xsize || + output_rect.xsize() + output_rect.x0() == frame_dim.xsize_padded); + + // +----------------------------- STEP 1 ------------------------------+ + // | Compute the rects on which patches and splines will be applied. | + // | In case we are applying upsampling, we need to apply patches on a | + // | slightly larger image. | + // +-------------------------------------------------------------------+ + + // If we are applying upsampling, we need 2 more pixels around the actual rect + // for border. Thus, we also need to apply patches and splines to those + // pixels. We compute here + // - The portion of image that corresponds to the area we are applying IF. + // (rect_for_if) + // - The rect where that pixel data is stored in upsampling_input_storage. + // (rect_for_if_storage) + // - The rect where the pixel data that we need to upsample is stored. + // (rect_for_upsampling) + // - The source rect for the pixel data in `input_image`. It is assumed that, + // if `output_rect` is not on an image border, `input_image:input_rect` has + // enough border available. (rect_for_if_input) + + Image3F* output_color = + dec_state->rgb_output == nullptr ? 
output_image->color() : nullptr; + + Image3F* storage_for_if = output_color; + Rect rect_for_if = output_rect; + Rect rect_for_if_storage = output_rect; + Rect rect_for_upsampling = output_rect; + Rect rect_for_if_input = input_rect; + size_t extra_rows_t = 0; + size_t extra_rows_b = 0; + if (frame_header.upsampling != 1) { + size_t ifbx0 = 0; + size_t ifbx1 = 0; + size_t ifby0 = 0; + size_t ifby1 = 0; + if (output_rect.x0() >= 2) { + JXL_DASSERT(input_rect.x0() >= 2); + ifbx0 = 2; + } + if (output_rect.y0() >= 2) { + JXL_DASSERT(input_rect.y0() >= 2); + extra_rows_t = ifby0 = 2; + } + for (size_t extra : {1, 2}) { + if (output_rect.x0() + output_rect.xsize() + extra <= + dec_state->shared->frame_dim.xsize_padded) { + JXL_DASSERT(input_rect.x0() + input_rect.xsize() + extra <= + input_image->xsize()); + ifbx1 = extra; + } + if (output_rect.y0() + output_rect.ysize() + extra <= + dec_state->shared->frame_dim.ysize_padded) { + JXL_DASSERT(input_rect.y0() + input_rect.ysize() + extra <= + input_image->ysize()); + extra_rows_b = ifby1 = extra; + } + } + rect_for_if = Rect(output_rect.x0() - ifbx0, output_rect.y0() - ifby0, + output_rect.xsize() + ifbx0 + ifbx1, + output_rect.ysize() + ifby0 + ifby1); + // Storage for pixel data does not necessarily start at (0, 0) as we need to + // have the left border of upsampling_rect aligned to a multiple of + // GroupBorderAssigner::kPaddingXRound. + rect_for_if_storage = + Rect(kBlockDim + RoundUpTo(ifbx0, GroupBorderAssigner::kPaddingXRound) - + ifbx0, + kBlockDim, rect_for_if.xsize(), rect_for_if.ysize()); + rect_for_upsampling = + Rect(kBlockDim + RoundUpTo(ifbx0, GroupBorderAssigner::kPaddingXRound), + kBlockDim + ifby0, output_rect.xsize(), output_rect.ysize()); + rect_for_if_input = + Rect(input_rect.x0() - ifbx0, input_rect.y0() - ifby0, + rect_for_if_storage.xsize(), rect_for_if_storage.ysize()); + storage_for_if = &dec_state->upsampling_input_storage[thread]; + } + + // Variables for upsampling and filtering. + Rect upsampled_output_rect(output_rect.x0() * frame_header.upsampling, + output_rect.y0() * frame_header.upsampling, + output_rect.xsize() * frame_header.upsampling, + output_rect.ysize() * frame_header.upsampling); + EnsurePaddingInPlaceRowByRow ensure_padding_upsampling; + ssize_t ensure_padding_upsampling_y0 = 0; + ssize_t ensure_padding_upsampling_y1 = 0; + + EnsurePaddingInPlaceRowByRow ensure_padding_filter; + FilterPipeline* fp = nullptr; + ssize_t ensure_padding_filter_y0 = 0; + ssize_t ensure_padding_filter_y1 = 0; + Rect image_padded_rect; + if (lf.epf_iters != 0 || lf.gab) { + fp = &dec_state->filter_pipelines[thread]; + size_t xextra = + rect_for_if_input.x0() % GroupBorderAssigner::kPaddingXRound; + image_padded_rect = Rect(rect_for_if.x0() - xextra, rect_for_if.y0(), + rect_for_if.xsize() + xextra, rect_for_if.ysize()); + } + + // Also prepare rect for memorizing the pre-color-transform frame. + const Rect pre_color_output_rect = + upsampled_output_rect.Crop(dec_state->pre_color_transform_frame); + Rect pre_color_output_rect_storage = pre_color_output_rect; + + // +----------------------------- STEP 2 ------------------------------+ + // | Change rects and buffer to not use `output_image` if direct | + // | output to rgb8 is requested. 
| + // +-------------------------------------------------------------------+ + Image3F* output_pixel_data_storage = output_color; + Rect upsampled_output_rect_for_storage = upsampled_output_rect; + if (dec_state->rgb_output) { + size_t log2_upsampling = CeilLog2Nonzero(frame_header.upsampling); + if (storage_for_if == output_color) { + storage_for_if = + &dec_state->output_pixel_data_storage[log2_upsampling][thread]; + rect_for_if_storage = + Rect(0, 0, rect_for_if_storage.xsize(), rect_for_if_storage.ysize()); + } + output_pixel_data_storage = + &dec_state->output_pixel_data_storage[log2_upsampling][thread]; + upsampled_output_rect_for_storage = Rect( + 0, 0, upsampled_output_rect.xsize(), upsampled_output_rect.ysize()); + if (frame_header.upsampling == 1 && fp == nullptr) { + upsampled_output_rect_for_storage = rect_for_if_storage = + rect_for_if_input; + output_pixel_data_storage = storage_for_if = input_image; + } + pre_color_output_rect_storage = + Rect(upsampled_output_rect_for_storage.x0(), + upsampled_output_rect_for_storage.y0(), + pre_color_output_rect.xsize(), pre_color_output_rect.ysize()); + } + // Set up alpha channel. + const size_t ec = + metadata.Find(ExtraChannel::kAlpha) - metadata.extra_channel_info.data(); + const ImageF* alpha = nullptr; + Rect alpha_rect = upsampled_output_rect; + if (ec < metadata.extra_channel_info.size()) { + JXL_ASSERT(ec < extra_channels.size()); + alpha = extra_channels[ec].first; + JXL_ASSERT(upsampled_output_rect.x0() >= extra_channels[ec].second.x0()); + JXL_ASSERT(upsampled_output_rect.y0() >= extra_channels[ec].second.y0()); + alpha_rect = + Rect(upsampled_output_rect.x0() - extra_channels[ec].second.x0(), + upsampled_output_rect.y0() - extra_channels[ec].second.y0(), + upsampled_output_rect.xsize(), upsampled_output_rect.ysize()); + } + + // +----------------------------- STEP 3 ------------------------------+ + // | Set up upsampling. | + // +-------------------------------------------------------------------+ + if (frame_header.upsampling != 1) { + ensure_padding_upsampling.Init( + storage_for_if, rect_for_upsampling, output_rect, + frame_dim.xsize_padded, frame_dim.ysize_padded, 2, 2, + &ensure_padding_upsampling_y0, &ensure_padding_upsampling_y1); + } + + // +----------------------------- STEP 4 ------------------------------+ + // | Set up the filter pipeline. | + // +-------------------------------------------------------------------+ + if (fp) { + // If `rect_for_if_input` does not start at a multiple of + // GroupBorderAssigner::kPaddingXRound, we extend the rect we run EPF on by + // one full padding length to ensure sigma is handled correctly. We also + // extend the output and image rects accordingly. To do this, we need 2x the + // border. 
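+    // Hypothetical numbers: with kPaddingXRound = 8 and rect_for_if_input
+    // starting at x0 = 13, xextra = 5, so the padded rects below start at
+    // x = 8 and EPF runs over five extra aligned columns.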
+ size_t xextra = + rect_for_if_input.x0() % GroupBorderAssigner::kPaddingXRound; + Rect filter_input_padded_rect( + rect_for_if_input.x0() - xextra, rect_for_if_input.y0(), + rect_for_if_input.xsize() + xextra, rect_for_if_input.ysize()); + ensure_padding_filter.Init( + input_image, rect_for_if_input, rect_for_if, frame_dim.xsize_padded, + frame_dim.ysize_padded, lf.Padding(), lf.Padding(), + &ensure_padding_filter_y0, &ensure_padding_filter_y1); + Rect filter_output_padded_rect( + rect_for_if_storage.x0() - xextra, rect_for_if_storage.y0(), + rect_for_if_storage.xsize() + xextra, rect_for_if_storage.ysize()); + fp = PrepareFilterPipeline(dec_state, image_padded_rect, *input_image, + filter_input_padded_rect, frame_dim.ysize_padded, + thread, storage_for_if, + filter_output_padded_rect); + } + + // +----------------------------- STEP 5 ------------------------------+ + // | Run the prepared pipeline of operations. | + // +-------------------------------------------------------------------+ + + // y values are relative to rect_for_if. + // Automatic mirroring in fp->ApplyFiltersRow() implies that we should ensure + // that padding for the first lines of the image is already present before + // calling ApplyFiltersRow() with "virtual" rows. + // Here we rely on the fact that virtual rows at the beginning of the image + // are only present if input_rect.y0() == 0. + ssize_t first_ensure_padding_y = ensure_padding_filter_y0; + if (output_rect.y0() == 0) { + JXL_DASSERT(ensure_padding_filter_y0 == 0); + first_ensure_padding_y = + std::min(lf.Padding(), ensure_padding_filter_y1); + for (ssize_t y = 0; y < first_ensure_padding_y; y++) { + ensure_padding_filter.Process(y); + } + } + + for (ssize_t y = -lf.Padding(); + y < static_cast(lf.Padding() + rect_for_if.ysize()); y++) { + if (fp) { + if (y >= first_ensure_padding_y && y < ensure_padding_filter_y1) { + ensure_padding_filter.Process(y); + } + fp->ApplyFiltersRow(lf, dec_state->filter_weights, image_padded_rect, y); + } else if (output_pixel_data_storage != input_image) { + for (size_t c = 0; c < 3; c++) { + memcpy(rect_for_if_storage.PlaneRow(storage_for_if, c, y), + rect_for_if_input.ConstPlaneRow(*input_image, c, y), + rect_for_if_input.xsize() * sizeof(float)); + } + } + if (y < static_cast(lf.Padding())) continue; + // At this point, row `y - lf.Padding()` of `rect_for_if` has been produced + // by the filters. + ssize_t available_y = y - lf.Padding(); + image_features.patches.AddTo(storage_for_if, + rect_for_if_storage.Line(available_y), + rect_for_if.Line(available_y)); + JXL_RETURN_IF_ERROR(image_features.splines.AddTo( + storage_for_if, rect_for_if_storage.Line(available_y), + rect_for_if.Line(available_y), dec_state->shared->cmap)); + size_t num_ys = 1; + if (frame_header.upsampling != 1) { + // Upsampling `y` values are relative to `rect_for_upsampling`, not to + // `rect_for_if`. + ssize_t shifted_y = available_y - extra_rows_t; + if (shifted_y >= ensure_padding_upsampling_y0 && + shifted_y < ensure_padding_upsampling_y1) { + ensure_padding_upsampling.Process(shifted_y); + } + // Upsampling will access two rows of border, so the first upsampling + // output will be available after shifted_y is at least 2, *unless* image + // height is <= 2. + if (shifted_y < 2 && + shifted_y + 1 != static_cast(frame_dim.ysize_padded)) { + continue; + } + // Value relative to upsampled_output_rect. 
+ size_t input_y = std::max(shifted_y - 2, 0); + size_t upsampled_available_y = frame_header.upsampling * input_y; + size_t num_input_rows = 1; + // If we are going to mirror the last output rows, then we already have 3 + // input lines ready. This happens iff we did not extend rect_for_if on + // the bottom *and* we are at the last `y` value. + if (extra_rows_b != 2 && + static_cast(y) + 1 == lf.Padding() + rect_for_if.ysize()) { + num_input_rows = 3; + } + num_input_rows = std::min(num_input_rows, frame_dim.ysize_padded); + num_ys = num_input_rows * frame_header.upsampling; + Rect upsample_input_rect = + rect_for_upsampling.Lines(input_y, num_input_rows); + dec_state->upsampler.UpsampleRect( + *storage_for_if, upsample_input_rect, output_pixel_data_storage, + upsampled_output_rect_for_storage.Lines(upsampled_available_y, + num_ys), + static_cast(output_rect.y0()) - + static_cast(rect_for_upsampling.y0()), + frame_dim.ysize_padded); + available_y = upsampled_available_y; + } + + // The image data is now unconditionally in + // `output_image_storage:upsampled_output_rect_for_storage`. + if (frame_header.flags & FrameHeader::kNoise) { + PROFILER_ZONE("AddNoise"); + AddNoise(image_features.noise_params, + upsampled_output_rect.Lines(available_y, num_ys), + dec_state->noise, + upsampled_output_rect_for_storage.Lines(available_y, num_ys), + dec_state->shared_storage.cmap, output_pixel_data_storage); + } + + if (dec_state->pre_color_transform_frame.xsize() != 0) { + for (size_t c = 0; c < 3; c++) { + for (size_t y = available_y; + y < num_ys && y < pre_color_output_rect.ysize(); y++) { + float* JXL_RESTRICT row_out = pre_color_output_rect.PlaneRow( + &dec_state->pre_color_transform_frame, c, y); + const float* JXL_RESTRICT row_in = + pre_color_output_rect_storage.ConstPlaneRow( + *output_pixel_data_storage, c, y); + memcpy(row_out, row_in, + pre_color_output_rect.xsize() * sizeof(*row_in)); + } + } + } + + // We skip the color transform entirely if save_before_color_transform and + // the frame is not supposed to be displayed. + + if (dec_state->fast_xyb_srgb8_conversion) { + FastXYBTosRGB8( + *output_pixel_data_storage, + upsampled_output_rect_for_storage.Lines(available_y, num_ys), + upsampled_output_rect.Lines(available_y, num_ys) + .Crop(Rect(0, 0, frame_dim.xsize, frame_dim.ysize)), + dec_state->rgb_output, frame_dim.xsize); + } else { + if (frame_header.needs_color_transform()) { + if (frame_header.color_transform == ColorTransform::kXYB) { + JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(UndoXYBInPlace)( + output_pixel_data_storage, + upsampled_output_rect_for_storage.Lines(available_y, num_ys), + dec_state->output_encoding_info)); + } else if (frame_header.color_transform == ColorTransform::kYCbCr) { + YcbcrToRgb( + *output_pixel_data_storage, output_pixel_data_storage, + upsampled_output_rect_for_storage.Lines(available_y, num_ys)); + } + } + + // TODO(veluca): all blending should happen here. 
+ + if (dec_state->rgb_output != nullptr) { + HWY_DYNAMIC_DISPATCH(FloatToRGBA8) + (*output_pixel_data_storage, + upsampled_output_rect_for_storage.Lines(available_y, num_ys), + dec_state->rgb_output_is_rgba, alpha, + alpha_rect.Lines(available_y, num_ys), + upsampled_output_rect.Lines(available_y, num_ys) + .Crop(Rect(0, 0, frame_dim.xsize, frame_dim.ysize)), + dec_state->rgb_output, dec_state->rgb_stride); + } + } + } + + return true; +} + +Status FinalizeFrameDecoding(ImageBundle* decoded, + PassesDecoderState* dec_state, ThreadPool* pool, + bool force_fir, bool skip_blending) { + const LoopFilter& lf = dec_state->shared->frame_header.loop_filter; + const FrameHeader& frame_header = dec_state->shared->frame_header; + const FrameDimensions& frame_dim = dec_state->shared->frame_dim; + + const Image3F* finalize_image_rect_input = &dec_state->decoded; + Image3F chroma_upsampled_image; + // If we used chroma subsampling, we upsample chroma now and run + // ApplyImageFeatures after. + // TODO(veluca): this should part of the FinalizeImageRect() pipeline. + if (!frame_header.chroma_subsampling.Is444()) { + chroma_upsampled_image = CopyImage(dec_state->decoded); + finalize_image_rect_input = &chroma_upsampled_image; + for (size_t c = 0; c < 3; c++) { + ImageF& plane = const_cast(chroma_upsampled_image.Plane(c)); + plane.ShrinkTo( + frame_dim.xsize_padded >> frame_header.chroma_subsampling.HShift(c), + frame_dim.ysize_padded >> frame_header.chroma_subsampling.VShift(c)); + for (size_t i = 0; i < frame_header.chroma_subsampling.HShift(c); i++) { + plane.InitializePaddingForUnalignedAccesses(); + plane = UpsampleH2(plane, pool); + } + for (size_t i = 0; i < frame_header.chroma_subsampling.VShift(c); i++) { + plane.InitializePaddingForUnalignedAccesses(); + plane = UpsampleV2(plane, pool); + } + JXL_DASSERT(SameSize(plane, chroma_upsampled_image)); + } + } + // FinalizeImageRect was not yet run, or we are forcing a run. 
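+  // Tiling sketch for the loops below, assuming kGroupDim = 256 (its value
+  // elsewhere in this tree): a 300x200 padded frame yields the rects
+  // (0,0,256,200) and (256,0,44,200); the bounded Rect constructor clamps
+  // each tile against the frame dimensions.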
+ if (!dec_state->EagerFinalizeImageRect() || force_fir) { + if (lf.epf_iters > 0 && frame_header.encoding == FrameEncoding::kModular) { + FillImage(kInvSigmaNum / lf.epf_sigma_for_modular, + &dec_state->filter_weights.sigma); + } + std::vector rects_to_process; + for (size_t y = 0; y < frame_dim.ysize_padded; y += kGroupDim) { + for (size_t x = 0; x < frame_dim.xsize_padded; x += kGroupDim) { + Rect rect(x, y, kGroupDim, kGroupDim, frame_dim.xsize_padded, + frame_dim.ysize_padded); + if (rect.xsize() == 0 || rect.ysize() == 0) continue; + rects_to_process.push_back(rect); + } + } + const auto allocate_storage = [&](size_t num_threads) { + dec_state->EnsureStorage(num_threads); + return true; + }; + + decoded->SetExtraChannels(std::move(dec_state->extra_channels)); + + std::atomic apply_features_ok{true}; + auto run_apply_features = [&](size_t rect_id, size_t thread) { + size_t xstart = kBlockDim + dec_state->group_border_assigner.PaddingX( + dec_state->FinalizeRectPadding()); + size_t ystart = dec_state->FinalizeRectPadding(); + Rect group_data_rect(xstart, ystart, rects_to_process[rect_id].xsize(), + rects_to_process[rect_id].ysize()); + CopyImageToWithPadding(rects_to_process[rect_id], + *finalize_image_rect_input, + dec_state->FinalizeRectPadding(), group_data_rect, + &dec_state->group_data[thread]); + std::vector> ec_rects; + ec_rects.reserve(decoded->extra_channels().size()); + Rect r(rects_to_process[rect_id].x0() * frame_header.upsampling, + rects_to_process[rect_id].y0() * frame_header.upsampling, + rects_to_process[rect_id].xsize() * frame_header.upsampling, + rects_to_process[rect_id].ysize() * frame_header.upsampling); + for (size_t i = 0; i < decoded->extra_channels().size(); i++) { + ec_rects.emplace_back(&decoded->extra_channels()[i], r); + } + if (!FinalizeImageRect(&dec_state->group_data[thread], group_data_rect, + ec_rects, dec_state, thread, decoded, + rects_to_process[rect_id])) { + apply_features_ok = false; + } + }; + + RunOnPool(pool, 0, rects_to_process.size(), allocate_storage, + run_apply_features, "ApplyFeatures"); + + if (!apply_features_ok) { + return JXL_FAILURE("FinalizeImageRect failed"); + } + } + + const size_t xsize = frame_dim.xsize_upsampled; + const size_t ysize = frame_dim.ysize_upsampled; + + decoded->ShrinkTo(xsize, ysize); + if (dec_state->pre_color_transform_frame.xsize() != 0) { + dec_state->pre_color_transform_frame.ShrinkTo(xsize, ysize); + } + + if (!skip_blending) { + ImageBlender blender; + ImageBundle foreground = std::move(*decoded); + JXL_RETURN_IF_ERROR( + blender.PrepareBlending(dec_state, &foreground, /*output=*/decoded)); + + std::vector rects_to_process; + for (size_t y = 0; y < frame_dim.ysize; y += kGroupDim) { + for (size_t x = 0; x < frame_dim.xsize; x += kGroupDim) { + Rect rect(x, y, kGroupDim, kGroupDim, frame_dim.xsize, frame_dim.ysize); + if (rect.xsize() == 0 || rect.ysize() == 0) continue; + rects_to_process.push_back(rect); + } + } + + std::atomic blending_ok{true}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, rects_to_process.size(), ThreadPool::SkipInit(), + [&](size_t i, size_t /*thread*/) { + const Rect& rect = rects_to_process[i]; + const auto rect_blender = blender.PrepareRect(rect, foreground); + for (size_t y = 0; y < rect.ysize(); ++y) { + if (!rect_blender.DoBlending(y)) { + blending_ok = false; + return; + } + } + }, + "Blend")); + JXL_RETURN_IF_ERROR(blending_ok.load()); + } + + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_reconstruct.h 
b/third_party/jpeg-xl/lib/jxl/dec_reconstruct.h
new file mode 100644
index 000000000000..a5ce18d04360
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_reconstruct.h
@@ -0,0 +1,78 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_RECONSTRUCT_H_
+#define LIB_JXL_DEC_RECONSTRUCT_H_
+
+#include <stddef.h>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/loop_filter.h"
+#include "lib/jxl/quantizer.h"
+#include "lib/jxl/splines.h"
+
+namespace jxl {
+
+// Finalizes the decoding of a frame by applying image features if necessary,
+// doing color transforms (unless the frame header specifies
+// `SaveBeforeColorTransform()`) and applying upsampling.
+//
+// Writes pixels in the appropriate colorspace to `idct`, shrinking it if
+// necessary.
+// `skip_blending` is necessary because the encoder butteraugli loop does not
+// (yet) handle blending.
+// TODO(veluca): remove the "force_fir" parameter, and call EPF directly in
+// those use cases where this is needed.
+Status FinalizeFrameDecoding(ImageBundle* JXL_RESTRICT decoded,
+                             PassesDecoderState* dec_state, ThreadPool* pool,
+                             bool force_fir, bool skip_blending);
+
+// Renders the `output_rect` portion of the final image to `output_image`
+// (unless the frame is upsampled - in which case, `output_rect` is scaled
+// accordingly). `input_rect` should have the same shape. `input_rect` always
+// refers to the non-padded pixels. `output_rect.x0()` is guaranteed to be a
+// multiple of GroupBorderAssigner::kPaddingRoundX. `output_rect.xsize()` is
+// either a multiple of GroupBorderAssigner::kPaddingRoundX, or is such that
+// `output_rect.x0() + output_rect.xsize() == frame_dim.xsize`. `input_image`
+// may be mutated by adding padding. If `output_rect` is on an image border,
+// the input will be padded. Otherwise, appropriate padding must already be
+// present.
+Status FinalizeImageRect(
+    Image3F* input_image, const Rect& input_rect,
+    const std::vector<std::pair<ImageF*, Rect>>& extra_channels,
+    PassesDecoderState* dec_state, size_t thread,
+    ImageBundle* JXL_RESTRICT output_image, const Rect& output_rect);
+
+// Fills padding around `img:rect` in the x direction by mirroring. Padding is
+// applied so that a full border of xpadding and ypadding is available, except
+// if `image_rect` points to an area of the full image that touches the top or
+// the bottom. It is expected that padding is already in place for inputs such
+// that the corresponding image_rect is not at an image border.
+void EnsurePaddingInPlace(Image3F* img, const Rect& rect,
+                          const Rect& image_rect, size_t image_xsize,
+                          size_t image_ysize, size_t xpadding,
+                          size_t ypadding);
+
+// For DC in the API.
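+// (Clarifying note, an assumption based on how the decoder uses this function
+// rather than upstream documentation: UndoXYB converts an XYB DC image into
+// the colorspace described by `output_info`, so the API can expose a 1:8
+// downscaled preview before the full frame is reconstructed.)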
+void UndoXYB(const Image3F& src, Image3F* dst,
+             const OutputEncodingInfo& output_info, ThreadPool* pool);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_RECONSTRUCT_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_transforms-inl.h b/third_party/jpeg-xl/lib/jxl/dec_transforms-inl.h
new file mode 100644
index 000000000000..40cdbf6a6614
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_transforms-inl.h
@@ -0,0 +1,876 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if defined(LIB_JXL_DEC_TRANSFORMS_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_DEC_TRANSFORMS_INL_H_
+#undef LIB_JXL_DEC_TRANSFORMS_INL_H_
+#else
+#define LIB_JXL_DEC_TRANSFORMS_INL_H_
+#endif
+
+#include <stddef.h>
+
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/dct-inl.h"
+#include "lib/jxl/dct_scales.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+template <size_t ROWS, size_t COLS>
+struct DoDCT {
+  template <typename From>
+  void operator()(const From& from, float* JXL_RESTRICT to,
+                  float* JXL_RESTRICT scratch_space) {
+    ComputeScaledDCT<ROWS, COLS>()(from, to, scratch_space);
+  }
+};
+
+template <size_t N>
+struct DoDCT<N, N> {
+  template <typename From>
+  void operator()(const From& from, float* JXL_RESTRICT to,
+                  float* JXL_RESTRICT scratch_space) {
+    ComputeTransposedScaledDCT<N>()(from, to, scratch_space);
+  }
+};
+
+// Computes the lowest-frequency LF_ROWSxLF_COLS-sized square in output, which
+// is a DCT_ROWS*DCT_COLS-sized DCT block, by doing a ROWS*COLS DCT on the
+// input block.
+template <size_t DCT_ROWS, size_t DCT_COLS, size_t LF_ROWS, size_t LF_COLS,
+          size_t ROWS, size_t COLS>
+JXL_INLINE void ReinterpretingDCT(const float* input, const size_t input_stride,
+                                  float* output, const size_t output_stride) {
+  static_assert(LF_ROWS == ROWS,
+                "ReinterpretingDCT should only be called with LF == N");
+  static_assert(LF_COLS == COLS,
+                "ReinterpretingDCT should only be called with LF == N");
+  HWY_ALIGN float block[ROWS * COLS];
+
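+  // (Worked example, not upstream text: LowestFrequenciesFromDC below
+  // instantiates this for Type::DCT16X16 as
+  //   ReinterpretingDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim,
+  //                     /*LF_ROWS=*/2, /*LF_COLS=*/2, /*ROWS=*/2, /*COLS=*/2>,
+  // i.e. a 2x2 DCT of the DC block whose outputs, after the
+  // DCTTotalResampleScale factors, equal the lowest 2x2 coefficients of the
+  // full 16x16 DCT.)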
+  // ROWS, COLS <= 8, so we can put scratch space on the stack.
+  HWY_ALIGN float scratch_space[ROWS * COLS];
+  DoDCT<ROWS, COLS>()(DCTFrom(input, input_stride), block, scratch_space);
+  if (ROWS < COLS) {
+    for (size_t y = 0; y < LF_ROWS; y++) {
+      for (size_t x = 0; x < LF_COLS; x++) {
+        output[y * output_stride + x] =
+            block[y * COLS + x] * DCTTotalResampleScale<DCT_ROWS, ROWS>(y) *
+            DCTTotalResampleScale<DCT_COLS, COLS>(x);
+      }
+    }
+  } else {
+    for (size_t y = 0; y < LF_COLS; y++) {
+      for (size_t x = 0; x < LF_ROWS; x++) {
+        output[y * output_stride + x] =
+            block[y * ROWS + x] * DCTTotalResampleScale<DCT_COLS, COLS>(y) *
+            DCTTotalResampleScale<DCT_ROWS, ROWS>(x);
+      }
+    }
+  }
+}
+
+template <size_t S>
+void IDCT2TopBlock(const float* block, size_t stride_out, float* out) {
+  static_assert(kBlockDim % S == 0, "S should be a divisor of kBlockDim");
+  static_assert(S % 2 == 0, "S should be even");
+  float temp[kDCTBlockSize];
+  constexpr size_t num_2x2 = S / 2;
+  for (size_t y = 0; y < num_2x2; y++) {
+    for (size_t x = 0; x < num_2x2; x++) {
+      float c00 = block[y * kBlockDim + x];
+      float c01 = block[y * kBlockDim + num_2x2 + x];
+      float c10 = block[(y + num_2x2) * kBlockDim + x];
+      float c11 = block[(y + num_2x2) * kBlockDim + num_2x2 + x];
+      float r00 = c00 + c01 + c10 + c11;
+      float r01 = c00 + c01 - c10 - c11;
+      float r10 = c00 - c01 + c10 - c11;
+      float r11 = c00 - c01 - c10 + c11;
+      temp[y * 2 * kBlockDim + x * 2] = r00;
+      temp[y * 2 * kBlockDim + x * 2 + 1] = r01;
+      temp[(y * 2 + 1) * kBlockDim + x * 2] = r10;
+      temp[(y * 2 + 1) * kBlockDim + x * 2 + 1] = r11;
+    }
+  }
+  for (size_t y = 0; y < S; y++) {
+    for (size_t x = 0; x < S; x++) {
+      out[y * stride_out + x] = temp[y * kBlockDim + x];
+    }
+  }
+}
+
+void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels) {
+  HWY_ALIGN static constexpr float k4x4AFVBasis[16][16] = {
+      {
+          0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
+          0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
+      },
+      {
+          0.876902929799142f, 0.2206518106944235f,
+          -0.10140050393753763f, -0.1014005039375375f,
+          0.2206518106944236f, -0.10140050393753777f,
+          -0.10140050393753772f, -0.10140050393753763f,
+          -0.10140050393753758f, -0.10140050393753769f,
+          -0.1014005039375375f, -0.10140050393753768f,
+          -0.10140050393753768f, -0.10140050393753759f,
+          -0.10140050393753763f, -0.10140050393753741f,
+      },
+      {
+          0.0, 0.0, 0.40670075830260755f, 0.44444816619734445f,
+          0.0, 0.0, 0.19574399372042936f, 0.2929100136981264f,
+          -0.40670075830260716f, -0.19574399372042872f, 0.0,
+          0.11379074460448091f, -0.44444816619734384f,
+          -0.29291001369812636f, -0.1137907446044814f, 0.0,
+      },
+      {
+          0.0, 0.0, -0.21255748058288748f, 0.3085497062849767f,
+          0.0, 0.4706702258572536f, -0.1621205195722993f, 0.0,
+          -0.21255748058287047f, -0.16212051957228327f,
+          -0.47067022585725277f, -0.1464291867126764f,
+          0.3085497062849487f, 0.0, -0.14642918671266536f,
+          0.4251149611657548f,
+      },
+      {
+          0.0, -0.7071067811865474f, 0.0, 0.0,
+          0.7071067811865476f, 0.0, 0.0, 0.0,
+          0.0, 0.0, 0.0, 0.0,
+          0.0, 0.0, 0.0, 0.0,
+      },
+      {
+          -0.4105377591765233f, 0.6235485373547691f,
+          -0.06435071657946274f, -0.06435071657946266f,
+          0.6235485373547694f, -0.06435071657946284f,
+          -0.0643507165794628f, -0.06435071657946274f,
+          -0.06435071657946272f, -0.06435071657946279f,
+          -0.06435071657946266f, -0.06435071657946277f,
+          -0.06435071657946277f, -0.06435071657946273f,
+          -0.06435071657946274f, -0.0643507165794626f,
+      },
+      {
+          0.0, 0.0, -0.4517556589999482f, 0.15854503551840063f,
+          0.0, -0.04038515160822202f, 0.0074182263792423875f,
+          0.39351034269210167f, -0.45175565899994635f,
+          0.007418226379244351f, 0.1107416575309343f,
+          0.08298163094882051f, 0.15854503551839705f,
+          0.3935103426921022f, 0.0829816309488214f,
+          -0.45175565899994796f,
+      },
+      {
+          0.0, 0.0, -0.304684750724869f, 0.5112616136591823f,
+          0.0, 0.0, -0.290480129728998f, -0.06578701549142804f,
+          0.304684750724884f, 0.2904801297290076f, 0.0,
+          -0.23889773523344604f, -0.5112616136592012f,
+          0.06578701549142545f, 0.23889773523345467f, 0.0,
+      },
+      {
+          0.0, 0.0, 0.3017929516615495f, 0.25792362796341184f,
+          0.0, 0.16272340142866204f, 0.09520022653475037f, 0.0,
+          0.3017929516615503f, 0.09520022653475055f,
+          -0.16272340142866173f, -0.35312385449816297f,
+          0.25792362796341295f, 0.0, -0.3531238544981624f,
+          -0.6035859033230976f,
+      },
+      {
+          0.0, 0.0, 0.40824829046386274f, 0.0,
+          0.0, 0.0, 0.0, -0.4082482904638628f,
+          -0.4082482904638635f, 0.0, 0.0, -0.40824829046386296f,
+          0.0, 0.4082482904638634f, 0.408248290463863f, 0.0,
+      },
+      {
+          0.0, 0.0, 0.1747866975480809f, 0.0812611176717539f,
+          0.0, 0.0, -0.3675398009862027f, -0.307882213957909f,
+          -0.17478669754808135f, 0.3675398009862011f, 0.0,
+          0.4826689115059883f, -0.08126111767175039f,
+          0.30788221395790305f, -0.48266891150598584f, 0.0,
+      },
+      {
+          0.0, 0.0, -0.21105601049335784f, 0.18567180916109802f,
+          0.0, 0.0, 0.49215859013738733f, -0.38525013709251915f,
+          0.21105601049335806f, -0.49215859013738905f, 0.0,
+          0.17419412659916217f, -0.18567180916109904f,
+          0.3852501370925211f, -0.1741941265991621f, 0.0,
+      },
+      {
+          0.0, 0.0, -0.14266084808807264f, -0.3416446842253372f,
+          0.0, 0.7367497537172237f, 0.24627107722075148f,
+          -0.08574019035519306f, -0.14266084808807344f,
+          0.24627107722075137f, 0.14883399227113567f,
+          -0.04768680350229251f, -0.3416446842253373f,
+          -0.08574019035519267f, -0.047686803502292804f,
+          -0.14266084808807242f,
+      },
+      {
+          0.0, 0.0, -0.13813540350758585f, 0.3302282550303788f,
+          0.0, 0.08755115000587084f, -0.07946706605909573f,
+          -0.4613374887461511f, -0.13813540350758294f,
+          -0.07946706605910261f, 0.49724647109535086f,
+          0.12538059448563663f, 0.3302282550303805f,
+          -0.4613374887461554f, 0.12538059448564315f,
+          -0.13813540350758452f,
+      },
+      {
+          0.0, 0.0, -0.17437602599651067f, 0.0702790691196284f,
+          0.0, -0.2921026642334881f, 0.3623817333531167f, 0.0,
+          -0.1743760259965108f, 0.36238173335311646f,
+          0.29210266423348785f, -0.4326608024727445f,
+          0.07027906911962818f, 0.0, -0.4326608024727457f,
+          0.34875205199302267f,
+      },
+      {
+          0.0, 0.0, 0.11354987314994337f, -0.07417504595810355f,
+          0.0, 0.19402893032594343f, -0.435190496523228f,
+          0.21918684838857466f, 0.11354987314994257f,
+          -0.4351904965232251f, 0.5550443808910661f,
+          -0.25468277124066463f, -0.07417504595810233f,
+          0.2191868483885728f, -0.25468277124066413f,
+          0.1135498731499429f,
+      },
+  };
+
+  const HWY_CAPPED(float, 16) d;
+  for (size_t i = 0; i < 16; i += Lanes(d)) {
+    auto pixel = Zero(d);
+    for (size_t j = 0; j < 16; j++) {
+      auto cf = Set(d, coeffs[j]);
+      auto basis = Load(d, k4x4AFVBasis[j] + i);
+      pixel = MulAdd(cf, basis, pixel);
+    }
+    Store(pixel, d, pixels + i);
+  }
+}
+
+template <size_t afv_kind>
+void AFVTransformToPixels(const float* JXL_RESTRICT coefficients,
+                          float* JXL_RESTRICT pixels, size_t pixels_stride) {
+  HWY_ALIGN float scratch_space[4 * 8];
+  size_t afv_x = afv_kind & 1;
+  size_t afv_y = afv_kind / 2;
+  float dcs[3] = {};
+  float block00 = coefficients[0];
float block01 = coefficients[1]; + float block10 = coefficients[8]; + dcs[0] = (block00 + block10 + block01) * 4.0f; + dcs[1] = (block00 + block10 - block01); + dcs[2] = block00 - block10; + // IAFV: (even, even) positions. + HWY_ALIGN float coeff[4 * 4]; + coeff[0] = dcs[0]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + coeff[iy * 4 + ix] = coefficients[iy * 2 * 8 + ix * 2]; + } + } + HWY_ALIGN float block[4 * 8]; + AFVIDCT4x4(coeff, block); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + pixels[(iy + afv_y * 4) * pixels_stride + afv_x * 4 + ix] = + block[(afv_y == 1 ? 3 - iy : iy) * 4 + (afv_x == 1 ? 3 - ix : ix)]; + } + } + // IDCT4x4 in (odd, even) positions. + block[0] = dcs[1]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 4 + ix] = coefficients[iy * 2 * 8 + ix * 2 + 1]; + } + } + ComputeTransposedScaledIDCT<4>()( + block, + DCTTo(pixels + afv_y * 4 * pixels_stride + (afv_x == 1 ? 0 : 4), + pixels_stride), + scratch_space); + // IDCT4x8. + block[0] = dcs[2]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(1 + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<4, 8>()( + block, + DCTTo(pixels + (afv_y == 1 ? 0 : 4) * pixels_stride, pixels_stride), + scratch_space); +} + +HWY_MAYBE_UNUSED void TransformToPixels(const AcStrategy::Type strategy, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, + size_t pixels_stride, + float* scratch_space) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::IDENTITY: { + PROFILER_ZONE("IDCT Identity"); + float dcs[4] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + dcs[0] = block00 + block01 + block10 + block11; + dcs[1] = block00 + block01 - block10 - block11; + dcs[2] = block00 - block01 + block10 - block11; + dcs[3] = block00 - block01 - block10 + block11; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + float block_dc = dcs[y * 2 + x]; + float residual_sum = 0; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + residual_sum += coefficients[(y + iy * 2) * 8 + x + ix * 2]; + } + } + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1] = + block_dc - residual_sum * (1.0f / 16); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 1 && iy == 1) continue; + pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix] = + coefficients[(y + iy * 2) * 8 + x + ix * 2] + + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1]; + } + } + pixels[y * 4 * pixels_stride + x * 4] = + coefficients[(y + 2) * 8 + x + 2] + + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1]; + } + } + break; + } + case Type::DCT8X4: { + PROFILER_ZONE("IDCT 8x4"); + float dcs[2] = {}; + float block0 = coefficients[0]; + float block1 = coefficients[8]; + dcs[0] = block0 + block1; + dcs[1] = block0 - block1; + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 8]; + block[0] = dcs[x]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(x + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<8, 4>()(block, DCTTo(pixels + x * 4, pixels_stride), + scratch_space); + } + break; + } + case 
Type::DCT4X8: { + PROFILER_ZONE("IDCT 4x8"); + float dcs[2] = {}; + float block0 = coefficients[0]; + float block1 = coefficients[8]; + dcs[0] = block0 + block1; + dcs[1] = block0 - block1; + for (size_t y = 0; y < 2; y++) { + HWY_ALIGN float block[4 * 8]; + block[0] = dcs[y]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(y + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<4, 8>()( + block, DCTTo(pixels + y * 4 * pixels_stride, pixels_stride), + scratch_space); + } + break; + } + case Type::DCT4X4: { + PROFILER_ZONE("IDCT 4"); + float dcs[4] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + dcs[0] = block00 + block01 + block10 + block11; + dcs[1] = block00 + block01 - block10 - block11; + dcs[2] = block00 - block01 + block10 - block11; + dcs[3] = block00 - block01 - block10 + block11; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 4]; + block[0] = dcs[y * 2 + x]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 4 + ix] = coefficients[(y + iy * 2) * 8 + x + ix * 2]; + } + } + ComputeTransposedScaledIDCT<4>()( + block, + DCTTo(pixels + y * 4 * pixels_stride + x * 4, pixels_stride), + scratch_space); + } + } + break; + } + case Type::DCT2X2: { + PROFILER_ZONE("IDCT 2"); + HWY_ALIGN float coeffs[kDCTBlockSize]; + memcpy(coeffs, coefficients, sizeof(float) * kDCTBlockSize); + IDCT2TopBlock<2>(coeffs, kBlockDim, coeffs); + IDCT2TopBlock<4>(coeffs, kBlockDim, coeffs); + IDCT2TopBlock<8>(coeffs, kBlockDim, coeffs); + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + pixels[y * pixels_stride + x] = coeffs[y * kBlockDim + x]; + } + } + break; + } + case Type::DCT16X16: { + PROFILER_ZONE("IDCT 16"); + ComputeTransposedScaledIDCT<16>()( + coefficients, DCTTo(pixels, pixels_stride), scratch_space); + break; + } + case Type::DCT16X8: { + PROFILER_ZONE("IDCT 16x8"); + ComputeScaledIDCT<16, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT8X16: { + PROFILER_ZONE("IDCT 8x16"); + ComputeScaledIDCT<8, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X8: { + PROFILER_ZONE("IDCT 32x8"); + ComputeScaledIDCT<32, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT8X32: { + PROFILER_ZONE("IDCT 8x32"); + ComputeScaledIDCT<8, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X16: { + PROFILER_ZONE("IDCT 32x16"); + ComputeScaledIDCT<32, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT16X32: { + PROFILER_ZONE("IDCT 16x32"); + ComputeScaledIDCT<16, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X32: { + PROFILER_ZONE("IDCT 32"); + ComputeTransposedScaledIDCT<32>()( + coefficients, DCTTo(pixels, pixels_stride), scratch_space); + break; + } + case Type::DCT: { + PROFILER_ZONE("IDCT 8"); + ComputeTransposedScaledIDCT<8>()( + coefficients, DCTTo(pixels, pixels_stride), scratch_space); + break; + } + case Type::AFV0: { + PROFILER_ZONE("IAFV0"); + AFVTransformToPixels<0>(coefficients, pixels, pixels_stride); + break; + } + case Type::AFV1: { + PROFILER_ZONE("IAFV1"); + 
AFVTransformToPixels<1>(coefficients, pixels, pixels_stride);
+      break;
+    }
+    case Type::AFV2: {
+      PROFILER_ZONE("IAFV2");
+      AFVTransformToPixels<2>(coefficients, pixels, pixels_stride);
+      break;
+    }
+    case Type::AFV3: {
+      PROFILER_ZONE("IAFV3");
+      AFVTransformToPixels<3>(coefficients, pixels, pixels_stride);
+      break;
+    }
+    case Type::DCT64X32: {
+      PROFILER_ZONE("IDCT 64x32");
+      ComputeScaledIDCT<64, 32>()(coefficients, DCTTo(pixels, pixels_stride),
+                                  scratch_space);
+      break;
+    }
+    case Type::DCT32X64: {
+      PROFILER_ZONE("IDCT 32x64");
+      ComputeScaledIDCT<32, 64>()(coefficients, DCTTo(pixels, pixels_stride),
+                                  scratch_space);
+      break;
+    }
+    case Type::DCT64X64: {
+      PROFILER_ZONE("IDCT 64");
+      ComputeTransposedScaledIDCT<64>()(
+          coefficients, DCTTo(pixels, pixels_stride), scratch_space);
+      break;
+    }
+    case Type::DCT128X64: {
+      PROFILER_ZONE("IDCT 128x64");
+      ComputeScaledIDCT<128, 64>()(coefficients, DCTTo(pixels, pixels_stride),
+                                   scratch_space);
+      break;
+    }
+    case Type::DCT64X128: {
+      PROFILER_ZONE("IDCT 64x128");
+      ComputeScaledIDCT<64, 128>()(coefficients, DCTTo(pixels, pixels_stride),
+                                   scratch_space);
+      break;
+    }
+    case Type::DCT128X128: {
+      PROFILER_ZONE("IDCT 128");
+      ComputeTransposedScaledIDCT<128>()(
+          coefficients, DCTTo(pixels, pixels_stride), scratch_space);
+      break;
+    }
+    case Type::DCT256X128: {
+      PROFILER_ZONE("IDCT 256x128");
+      ComputeScaledIDCT<256, 128>()(coefficients, DCTTo(pixels, pixels_stride),
+                                    scratch_space);
+      break;
+    }
+    case Type::DCT128X256: {
+      PROFILER_ZONE("IDCT 128x256");
+      ComputeScaledIDCT<128, 256>()(coefficients, DCTTo(pixels, pixels_stride),
+                                    scratch_space);
+      break;
+    }
+    case Type::DCT256X256: {
+      PROFILER_ZONE("IDCT 256");
+      ComputeTransposedScaledIDCT<256>()(
+          coefficients, DCTTo(pixels, pixels_stride), scratch_space);
+      break;
+    }
+    case Type::kNumValidStrategies:
+      JXL_ABORT("Invalid strategy");
+  }
+}
+
+HWY_MAYBE_UNUSED void LowestFrequenciesFromDC(const AcStrategy::Type strategy,
+                                              const float* dc, size_t dc_stride,
+                                              float* llf) {
+  using Type = AcStrategy::Type;
+  switch (strategy) {
+    case Type::DCT16X8: {
+      ReinterpretingDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/kBlockDim,
+                        /*LF_ROWS=*/2, /*LF_COLS=*/1, /*ROWS=*/2, /*COLS=*/1>(
+          dc, dc_stride, llf, 2 * kBlockDim);
+      break;
+    }
+    case Type::DCT8X16: {
+      ReinterpretingDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/2 * kBlockDim,
+                        /*LF_ROWS=*/1, /*LF_COLS=*/2, /*ROWS=*/1, /*COLS=*/2>(
+          dc, dc_stride, llf, 2 * kBlockDim);
+      break;
+    }
+    case Type::DCT16X16: {
+      ReinterpretingDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim,
+                        /*LF_ROWS=*/2, /*LF_COLS=*/2, /*ROWS=*/2, /*COLS=*/2>(
+          dc, dc_stride, llf, 2 * kBlockDim);
+      break;
+    }
+    case Type::DCT32X8: {
+      ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/kBlockDim,
+                        /*LF_ROWS=*/4, /*LF_COLS=*/1, /*ROWS=*/4, /*COLS=*/1>(
+          dc, dc_stride, llf, 4 * kBlockDim);
+      break;
+    }
+    case Type::DCT8X32: {
+      ReinterpretingDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+                        /*LF_ROWS=*/1, /*LF_COLS=*/4, /*ROWS=*/1, /*COLS=*/4>(
+          dc, dc_stride, llf, 4 * kBlockDim);
+      break;
+    }
+    case Type::DCT32X16: {
+      ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim,
+                        /*LF_ROWS=*/4, /*LF_COLS=*/2, /*ROWS=*/4, /*COLS=*/2>(
+          dc, dc_stride, llf, 4 * kBlockDim);
+      break;
+    }
+    case Type::DCT16X32: {
+      ReinterpretingDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+                        /*LF_ROWS=*/2, /*LF_COLS=*/4, /*ROWS=*/2, /*COLS=*/4>(
+          dc, dc_stride, llf, 4 * kBlockDim);
+      break;
+    }
+    case Type::DCT32X32: {
+      ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+                        /*LF_ROWS=*/4, /*LF_COLS=*/4, /*ROWS=*/4, /*COLS=*/4>(
+          dc, dc_stride, llf, 4 * kBlockDim);
+      break;
+    }
+    case Type::DCT64X32: {
+      ReinterpretingDCT</*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+                        /*LF_ROWS=*/8, /*LF_COLS=*/4, /*ROWS=*/8, /*COLS=*/4>(
+          dc, dc_stride, llf, 8 * kBlockDim);
+      break;
+    }
+    case Type::DCT32X64: {
+      ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim,
+                        /*LF_ROWS=*/4, /*LF_COLS=*/8, /*ROWS=*/4, /*COLS=*/8>(
+          dc, dc_stride, llf, 8 * kBlockDim);
+      break;
+    }
+    case Type::DCT64X64: {
+      ReinterpretingDCT</*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim,
+                        /*LF_ROWS=*/8, /*LF_COLS=*/8, /*ROWS=*/8, /*COLS=*/8>(
+          dc, dc_stride, llf, 8 * kBlockDim);
+      break;
+    }
+    case Type::DCT128X64: {
+      ReinterpretingDCT<
+          /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim,
+          /*LF_ROWS=*/16, /*LF_COLS=*/8, /*ROWS=*/16, /*COLS=*/8>(
+          dc, dc_stride, llf, 16 * kBlockDim);
+      break;
+    }
+    case Type::DCT64X128: {
+      ReinterpretingDCT<
+          /*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim,
+          /*LF_ROWS=*/8, /*LF_COLS=*/16, /*ROWS=*/8, /*COLS=*/16>(
+          dc, dc_stride, llf, 16 * kBlockDim);
+      break;
+    }
+    case Type::DCT128X128: {
+      ReinterpretingDCT<
+          /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim,
+          /*LF_ROWS=*/16, /*LF_COLS=*/16, /*ROWS=*/16, /*COLS=*/16>(
+          dc, dc_stride, llf, 16 * kBlockDim);
+      break;
+    }
+    case Type::DCT256X128: {
+      ReinterpretingDCT<
+          /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim,
+          /*LF_ROWS=*/32, /*LF_COLS=*/16, /*ROWS=*/32, /*COLS=*/16>(
+          dc, dc_stride, llf, 32 * kBlockDim);
+      break;
+    }
+    case Type::DCT128X256: {
+      ReinterpretingDCT<
+          /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim,
+          /*LF_ROWS=*/16, /*LF_COLS=*/32, /*ROWS=*/16, /*COLS=*/32>(
+          dc, dc_stride, llf, 32 * kBlockDim);
+      break;
+    }
+    case Type::DCT256X256: {
+      ReinterpretingDCT<
+          /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim,
+          /*LF_ROWS=*/32, /*LF_COLS=*/32, /*ROWS=*/32, /*COLS=*/32>(
+          dc, dc_stride, llf, 32 * kBlockDim);
+      break;
+    }
+    case Type::DCT:
+    case Type::DCT2X2:
+    case Type::DCT4X4:
+    case Type::DCT4X8:
+    case Type::DCT8X4:
+    case Type::AFV0:
+    case Type::AFV1:
+    case Type::AFV2:
+    case Type::AFV3:
+    case Type::IDENTITY:
+      llf[0] = dc[0];
+      break;
+    case Type::kNumValidStrategies:
+      JXL_ABORT("Invalid strategy");
+  };
+}
+
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // LIB_JXL_DEC_TRANSFORMS_INL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.cc b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.cc
new file mode 100644
index 000000000000..a561026b4ab8
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.cc
@@ -0,0 +1,50 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_transforms_testonly.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dec_transforms_testonly.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/dec_transforms-inl.h"
+
+namespace jxl {
+
+#if HWY_ONCE
+HWY_EXPORT(TransformToPixels);
+void TransformToPixels(AcStrategy::Type strategy,
+                       float* JXL_RESTRICT coefficients,
+                       float* JXL_RESTRICT pixels, size_t pixels_stride,
+                       float* scratch_space) {
+  return HWY_DYNAMIC_DISPATCH(TransformToPixels)(strategy, coefficients,
+                                                 pixels, pixels_stride,
+                                                 scratch_space);
+}
+
+HWY_EXPORT(LowestFrequenciesFromDC);
+void LowestFrequenciesFromDC(const jxl::AcStrategy::Type strategy,
+                             const float* dc, size_t dc_stride, float* llf) {
+  return HWY_DYNAMIC_DISPATCH(LowestFrequenciesFromDC)(strategy, dc, dc_stride,
+                                                       llf);
+}
+
+HWY_EXPORT(AFVIDCT4x4);
+void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels) {
+  return HWY_DYNAMIC_DISPATCH(AFVIDCT4x4)(coeffs, pixels);
+}
+#endif  // HWY_ONCE
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.h b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.h
new file mode 100644
index 000000000000..c8f56398799a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.h
@@ -0,0 +1,41 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_
+#define LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_
+
+// Facade for (non-inlined) inverse integral transforms.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/base/compiler_specific.h"
+
+namespace jxl {
+
+void TransformToPixels(AcStrategy::Type strategy,
+                       float* JXL_RESTRICT coefficients,
+                       float* JXL_RESTRICT pixels, size_t pixels_stride,
+                       float* JXL_RESTRICT scratch_space);
+
+// Equivalent of the above for DC image.
+void LowestFrequenciesFromDC(const jxl::AcStrategy::Type strategy,
+                             const float* dc, size_t dc_stride, float* llf);
+
+void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_upsample.cc b/third_party/jpeg-xl/lib/jxl/dec_upsample.cc
new file mode 100644
index 000000000000..02b6fb674f47
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_upsample.cc
@@ -0,0 +1,124 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_upsample.h"
+
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+namespace {
+
+template <size_t N>
+void InitKernel(const float* weights, float kernel[4][4][5][5]) {
+  static_assert(N == 1 || N == 2 || N == 4,
+                "Upsampling kernel init only implemented for N = 1,2,4");
+  for (size_t i = 0; i < 5 * N; i++) {
+    for (size_t j = 0; j < 5 * N; j++) {
+      size_t y = std::min(i, j);
+      size_t x = std::max(i, j);
+      kernel[j / 5][i / 5][j % 5][i % 5] =
+          weights[5 * N * y - y * (y - 1) / 2 + x - y];
+    }
+  }
+}
+
+template <size_t N>
+float Kernel(size_t x, size_t y, size_t ix, size_t iy,
+             const float kernel[4][4][5][5]) {
+  if (N == 2) {
+    return kernel[0][0][y % 2 ? 4 - iy : iy][x % 2 ? 4 - ix : ix];
+  }
+  if (N == 4) {
+    return kernel[y % 4 < 2 ? y % 2 : 1 - y % 2][x % 4 < 2 ? x % 2 : 1 - x % 2]
+                 [y % 4 < 2 ? iy : 4 - iy][x % 4 < 2 ? ix : 4 - ix];
+  }
+  if (N == 8) {
+    return kernel[y % 8 < 4 ? y % 4 : 3 - y % 4][x % 8 < 4 ? x % 4 : 3 - x % 4]
+                 [y % 8 < 4 ? iy : 4 - iy][x % 8 < 4 ? ix : 4 - ix];
+  }
+  JXL_ABORT("Invalid upsample");
+}
+
+template <size_t N>
+void Upsample(const Image3F& src, const Rect& src_rect, Image3F* dst,
+              const Rect& dst_rect, const float kernel[4][4][5][5],
+              ssize_t image_y_offset, size_t image_ysize) {
+  JXL_DASSERT(src_rect.x0() >= 2);
+  JXL_DASSERT(src_rect.x0() + src_rect.xsize() + 2 <= src.xsize());
+  for (size_t c = 0; c < 3; c++) {
+    for (size_t y = 0; y < dst_rect.ysize(); y++) {
+      float* dst_row = dst_rect.PlaneRow(dst, c, y);
+      const float* src_rows[5];
+      for (int iy = -2; iy <= 2; iy++) {
+        ssize_t image_y =
+            static_cast<ssize_t>(y / N + src_rect.y0() + iy) + image_y_offset;
+        src_rows[iy + 2] =
+            src.PlaneRow(c, Mirror(image_y, image_ysize) - image_y_offset);
+      }
+      for (size_t x = 0; x < dst_rect.xsize(); x++) {
+        size_t xbase = x / N + src_rect.x0() - 2;
+        float result = 0;
+        float min = src_rows[0][xbase];
+        float max = src_rows[0][xbase];
+        for (size_t iy = 0; iy < 5; iy++) {
+          for (size_t ix = 0; ix < 5; ix++) {
+            float v = src_rows[iy][xbase + ix];
+            result += Kernel<N>(x, y, ix, iy, kernel) * v;
+            min = std::min(v, min);
+            max = std::max(v, max);
+          }
+        }
+        // Avoid overshooting.
+        dst_row[x] = std::min(std::max(result, min), max);
+      }
+    }
+  }
+}
+}  // namespace
+
+void Upsampler::Init(size_t upsampling, const CustomTransformData& data) {
+  upsampling_ = upsampling;
+  if (upsampling_ == 1) return;
+  if (upsampling_ == 2) {
+    InitKernel<1>(data.upsampling2_weights, kernel_);
+  } else if (upsampling_ == 4) {
+    InitKernel<2>(data.upsampling4_weights, kernel_);
+  } else if (upsampling_ == 8) {
+    InitKernel<4>(data.upsampling8_weights, kernel_);
+  } else {
+    JXL_ABORT("Invalid upsample");
+  }
+}
+
+void Upsampler::UpsampleRect(const Image3F& src, const Rect& src_rect,
+                             Image3F* dst, const Rect& dst_rect,
+                             ssize_t image_y_offset,
+                             size_t image_ysize) const {
+  if (upsampling_ == 1) return;
+  JXL_ASSERT(dst_rect.xsize() == src_rect.xsize() * upsampling_);
+  JXL_ASSERT(dst_rect.ysize() == src_rect.ysize() * upsampling_);
+  if (upsampling_ == 2) {
+    Upsample<2>(src, src_rect, dst, dst_rect, kernel_, image_y_offset,
+                image_ysize);
+  } else if (upsampling_ == 4) {
+    Upsample<4>(src, src_rect, dst, dst_rect, kernel_, image_y_offset,
+                image_ysize);
+  } else if (upsampling_ == 8) {
+    Upsample<8>(src, src_rect, dst, dst_rect, kernel_, image_y_offset,
+                image_ysize);
+  } else {
+    JXL_ABORT("Not implemented");
+  }
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/dec_upsample.h b/third_party/jpeg-xl/lib/jxl/dec_upsample.h
new file mode 100644
index 000000000000..3520edc9c9ae
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_upsample.h
@@ -0,0 +1,42 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_DEC_UPSAMPLE_H_
+#define LIB_JXL_DEC_UPSAMPLE_H_
+
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_metadata.h"
+
+namespace jxl {
+
+struct Upsampler {
+  void Init(size_t upsampling, const CustomTransformData& data);
+
+  // The caller must guarantee that `src:src_rect` has two pixels of padding
+  // available on each side of the x dimension. `image_ysize` is the total
+  // height of the frame that the source area belongs to (not the buffer);
+  // `image_y_offset` is the difference between `src_rect.y0()` and the
+  // corresponding y value in the full frame.
+  void UpsampleRect(const Image3F& src, const Rect& src_rect, Image3F* dst,
+                    const Rect& dst_rect, ssize_t image_y_offset,
+                    size_t image_ysize) const;
+
+ private:
+  size_t upsampling_ = 1;
+  float kernel_[4][4][5][5];
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_DEC_UPSAMPLE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_xyb-inl.h b/third_party/jpeg-xl/lib/jxl/dec_xyb-inl.h
new file mode 100644
index 000000000000..a54c521a8ccb
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_xyb-inl.h
@@ -0,0 +1,330 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// XYB -> linear sRGB helper function.
+
+#if defined(LIB_JXL_DEC_XYB_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_DEC_XYB_INL_H_
+#undef LIB_JXL_DEC_XYB_INL_H_
+#else
+#define LIB_JXL_DEC_XYB_INL_H_
+#endif
+
+#include <hwy/highway.h>
+
+#include "lib/jxl/dec_xyb.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Broadcast;
+
+// Inverts the pixel-wise RGB->XYB conversion in OpsinDynamicsImage()
+// (including the gamma mixing and simple gamma). Avoids clamping to [0, 1] -
+// out of (sRGB) gamut values may be in-gamut after transforming to a wider
+// space. "inverse_matrix" points to 9 broadcasted vectors, which are the 3x3
+// entries of the (row-major) opsin absorbance matrix inverse. Pre-multiplying
+// its entries by c is equivalent to multiplying linear_* by c afterwards.
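+// (Illustrative summary, not upstream text: per pixel the function below
+// computes
+//   gamma_r = (opsin_y + opsin_x) - bias_cbrt[0]
+//   gamma_g = (opsin_y - opsin_x) - bias_cbrt[1]
+//   gamma_b =  opsin_b            - bias_cbrt[2]
+//   mixed_c = gamma_c^3 + bias[c]
+// followed by linear_rgb = inverse_matrix * mixed_rgb, vectorized over
+// Lanes(d) pixels at a time.)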
+template <class D, class V>
+HWY_INLINE HWY_MAYBE_UNUSED void XybToRgb(D d, const V opsin_x, const V opsin_y,
+                                          const V opsin_b,
+                                          const OpsinParams& opsin_params,
+                                          V* const HWY_RESTRICT linear_r,
+                                          V* const HWY_RESTRICT linear_g,
+                                          V* const HWY_RESTRICT linear_b) {
+#if HWY_TARGET == HWY_SCALAR
+  const auto neg_bias_r = Set(d, opsin_params.opsin_biases[0]);
+  const auto neg_bias_g = Set(d, opsin_params.opsin_biases[1]);
+  const auto neg_bias_b = Set(d, opsin_params.opsin_biases[2]);
+#else
+  const auto neg_bias_rgb = LoadDup128(d, opsin_params.opsin_biases);
+  const auto neg_bias_r = Broadcast<0>(neg_bias_rgb);
+  const auto neg_bias_g = Broadcast<1>(neg_bias_rgb);
+  const auto neg_bias_b = Broadcast<2>(neg_bias_rgb);
+#endif
+
+  // Color space: XYB -> RGB
+  auto gamma_r = opsin_y + opsin_x;
+  auto gamma_g = opsin_y - opsin_x;
+  auto gamma_b = opsin_b;
+
+  gamma_r -= Set(d, opsin_params.opsin_biases_cbrt[0]);
+  gamma_g -= Set(d, opsin_params.opsin_biases_cbrt[1]);
+  gamma_b -= Set(d, opsin_params.opsin_biases_cbrt[2]);
+
+  // Undo gamma compression: linear = gamma^3 for efficiency.
+  const auto gamma_r2 = gamma_r * gamma_r;
+  const auto gamma_g2 = gamma_g * gamma_g;
+  const auto gamma_b2 = gamma_b * gamma_b;
+  const auto mixed_r = MulAdd(gamma_r2, gamma_r, neg_bias_r);
+  const auto mixed_g = MulAdd(gamma_g2, gamma_g, neg_bias_g);
+  const auto mixed_b = MulAdd(gamma_b2, gamma_b, neg_bias_b);
+
+  const float* HWY_RESTRICT inverse_matrix = opsin_params.inverse_opsin_matrix;
+
+  // Unmix (multiply by 3x3 inverse_matrix)
+  *linear_r = LoadDup128(d, &inverse_matrix[0 * 4]) * mixed_r;
+  *linear_g = LoadDup128(d, &inverse_matrix[3 * 4]) * mixed_r;
+  *linear_b = LoadDup128(d, &inverse_matrix[6 * 4]) * mixed_r;
+  *linear_r = MulAdd(LoadDup128(d, &inverse_matrix[1 * 4]), mixed_g, *linear_r);
+  *linear_g = MulAdd(LoadDup128(d, &inverse_matrix[4 * 4]), mixed_g, *linear_g);
+  *linear_b = MulAdd(LoadDup128(d, &inverse_matrix[7 * 4]), mixed_g, *linear_b);
+  *linear_r = MulAdd(LoadDup128(d, &inverse_matrix[2 * 4]), mixed_b, *linear_r);
+  *linear_g = MulAdd(LoadDup128(d, &inverse_matrix[5 * 4]), mixed_b, *linear_g);
+  *linear_b = MulAdd(LoadDup128(d, &inverse_matrix[8 * 4]), mixed_b, *linear_b);
+}
+
+bool HasFastXYBTosRGB8() {
+#if HWY_TARGET == HWY_NEON
+  return true;
+#else
+  return false;
+#endif
+}
+
+void FastXYBTosRGB8(const Image3F& input, const Rect& input_rect,
+                    const Rect& output_buf_rect,
+                    uint8_t* JXL_RESTRICT output_buf, size_t xsize) {
+  // This function is very NEON-specific. As such, it uses intrinsics directly.
+#if HWY_TARGET == HWY_NEON
+  // WARNING: doing fixed point arithmetic correctly is very complicated.
+  // Changes to this function should be thoroughly tested.
+
+  // Note that the input is assumed to have 13 bits of mantissa, and the
+  // output will have 14 bits.
+  auto srgb_tf = [&](int16x8_t v16) {
+    int16x8_t clz = vclzq_s16(v16);
+    // Convert to [0.25, 0.5) range.
+    int16x8_t v025_05_16 = vqshlq_s16(v16, vqsubq_s16(clz, vdupq_n_s16(2)));
+
+    // third degree polynomial approximation between 0.25 and 0.5
+    // of 1.055/2^(7/2.4) * x^(1/2.4) / 32.
+    // poly ~ ((0.95x-1.75)*x+1.72)*x+0.29
+    // We actually compute ~ ((0.47x-0.87)*x+0.86)*(2x)+0.29 as 1.75 and 1.72
+    // overflow our fixed point representation.
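+    // (Worked example, not upstream text: with 13 mantissa bits, v16 = 8192
+    // represents 1.0; it has clz = 2, so vqshlq_s16 shifts it left by
+    // clz - 2 = 0 and it stays 8192, which is 0.25 in Q15. Every positive
+    // input is normalized into [0.25, 0.5) this way, with its power-of-two
+    // exponent recorded in clz and consumed below via exp16.)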
+
+    int16x8_t twov = vqaddq_s16(v025_05_16, v025_05_16);
+
+    // 0.47 * x
+    int16x8_t step1 = vqrdmulhq_n_s16(v025_05_16, 15706);
+    // - 0.87
+    int16x8_t step2 = vsubq_s16(step1, vdupq_n_s16(28546));
+    // * x
+    int16x8_t step3 = vqrdmulhq_s16(step2, v025_05_16);
+    // + 0.86
+    int16x8_t step4 = vaddq_s16(step3, vdupq_n_s16(28302));
+    // * 2x
+    int16x8_t step5 = vqrdmulhq_s16(step4, twov);
+    // + 0.29
+    int16x8_t mul16 = vaddq_s16(step5, vdupq_n_s16(9485));
+
+    int16x8_t exp16 = vsubq_s16(vdupq_n_s16(11), clz);
+    // Compute 2**(1/2.4*exp16)/32. Values of exp16 that would overflow are
+    // capped to 1.
+    // Generated with the following Python script:
+    // a = []
+    // b = []
+    //
+    // for i in range(0, 16):
+    //   v = 2**(5/12.*i)
+    //   v /= 16
+    //   v *= 256 * 128
+    //   v = int(v)
+    //   a.append(v // 256)
+    //   b.append(v % 256)
+    //
+    // print(", ".join("0x%02x" % x for x in a))
+    //
+    // print(", ".join("0x%02x" % x for x in b))
+    HWY_ALIGN constexpr uint8_t k2to512powersm1div32_high[16] = {
+        0x08, 0x0a, 0x0e, 0x13, 0x19, 0x21, 0x2d, 0x3c,
+        0x50, 0x6b, 0x8f, 0x8f, 0x8f, 0x8f, 0x8f, 0x8f,
+    };
+    HWY_ALIGN constexpr uint8_t k2to512powersm1div32_low[16] = {
+        0x00, 0xad, 0x41, 0x06, 0x65, 0xe7, 0x41, 0x68,
+        0xa2, 0xa2, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    };
+    // Using the highway implementation here since vqtbl1q is aarch64-only.
+    using hwy::HWY_NAMESPACE::Vec128;
+    uint8x16_t pow_low =
+        TableLookupBytes(
+            Vec128<uint8_t, 16>(vld1q_u8(k2to512powersm1div32_low)),
+            Vec128<uint8_t, 16>(vreinterpretq_u8_s16(exp16)))
+            .raw;
+    uint8x16_t pow_high =
+        TableLookupBytes(
+            Vec128<uint8_t, 16>(vld1q_u8(k2to512powersm1div32_high)),
+            Vec128<uint8_t, 16>(vreinterpretq_u8_s16(exp16)))
+            .raw;
+    int16x8_t pow16 = vreinterpretq_s16_u16(vsliq_n_u16(
+        vreinterpretq_u16_u8(pow_low), vreinterpretq_u16_u8(pow_high), 8));
+
+    // approximation of v * 12.92, divided by 2
+    // Note that our input is using 13 mantissa bits instead of 15.
+    int16x8_t v16_linear = vrshrq_n_s16(vmulq_n_s16(v16, 826), 5);
+    // 1.055*pow(v, 1/2.4) - 0.055, divided by 2
+    auto v16_pow = vsubq_s16(vqrdmulhq_s16(mul16, pow16), vdupq_n_s16(901));
+    // > 0.0031308f (note that v16 has 13 mantissa bits)
+    return vbslq_s16(vcgeq_s16(v16, vdupq_n_s16(26)), v16_pow, v16_linear);
+  };
+  for (size_t y = 0; y < output_buf_rect.ysize(); y++) {
+    const float* JXL_RESTRICT row_in_x = input_rect.ConstPlaneRow(input, 0, y);
+    const float* JXL_RESTRICT row_in_y = input_rect.ConstPlaneRow(input, 1, y);
+    const float* JXL_RESTRICT row_in_b = input_rect.ConstPlaneRow(input, 2, y);
+    size_t base_ptr =
+        3 * (y + output_buf_rect.y0()) * xsize + 3 * output_buf_rect.x0();
+    for (size_t x = 0; x < output_buf_rect.xsize(); x += 8) {
+      // Normal ranges for xyb for in-gamut sRGB colors:
+      // x: -0.015386  0.028100
+      // y:  0.000000  0.845308
+      // b:  0.000000  0.845308
+
+      // We actually want x * 8 to have some extra precision.
+      // TODO(veluca): consider different approaches here, like vld1q_f32_x2.
+      float32x4_t opsin_x_left = vld1q_f32(row_in_x + x);
+      int16x4_t opsin_x16_times8_left =
+          vqmovn_s32(vcvtq_n_s32_f32(opsin_x_left, 18));
+      float32x4_t opsin_x_right =
+          vld1q_f32(row_in_x + x + (x + 4 < output_buf_rect.xsize() ?
4 : 0)); + int16x4_t opsin_x16_times8_right = + vqmovn_s32(vcvtq_n_s32_f32(opsin_x_right, 18)); + int16x8_t opsin_x16_times8 = + vcombine_s16(opsin_x16_times8_left, opsin_x16_times8_right); + + float32x4_t opsin_y_left = vld1q_f32(row_in_y + x); + int16x4_t opsin_y16_left = vqmovn_s32(vcvtq_n_s32_f32(opsin_y_left, 15)); + float32x4_t opsin_y_right = + vld1q_f32(row_in_y + x + (x + 4 < output_buf_rect.xsize() ? 4 : 0)); + int16x4_t opsin_y16_right = + vqmovn_s32(vcvtq_n_s32_f32(opsin_y_right, 15)); + int16x8_t opsin_y16 = vcombine_s16(opsin_y16_left, opsin_y16_right); + + float32x4_t opsin_b_left = vld1q_f32(row_in_b + x); + int16x4_t opsin_b16_left = vqmovn_s32(vcvtq_n_s32_f32(opsin_b_left, 15)); + float32x4_t opsin_b_right = + vld1q_f32(row_in_b + x + (x + 4 < output_buf_rect.xsize() ? 4 : 0)); + int16x4_t opsin_b16_right = + vqmovn_s32(vcvtq_n_s32_f32(opsin_b_right, 15)); + int16x8_t opsin_b16 = vcombine_s16(opsin_b16_left, opsin_b16_right); + + int16x8_t neg_bias16 = vdupq_n_s16(-124); // -0.0037930732552754493 + int16x8_t neg_bias_cbrt16 = vdupq_n_s16(-5110); // -0.155954201 + int16x8_t neg_bias_half16 = vdupq_n_s16(-62); + + // Color space: XYB -> RGB + // Compute ((y+x-bias_cbrt)^3-(y-x-bias_cbrt)^3)/2, + // ((y+x-bias_cbrt)^3+(y-x-bias_cbrt)^3)/2+bias, (b-bias_cbrt)^3+bias. + // Note that ignoring x2 in the formulas below (as x << y) results in + // errors of at least 3 in the final sRGB values. + int16x8_t opsin_yp16 = vqsubq_s16(opsin_y16, neg_bias_cbrt16); + int16x8_t ysq16 = vqrdmulhq_s16(opsin_yp16, opsin_yp16); + int16x8_t twentyfourx16 = vmulq_n_s16(opsin_x16_times8, 3); + int16x8_t twentyfourxy16 = vqrdmulhq_s16(opsin_yp16, twentyfourx16); + int16x8_t threexsq16 = + vrshrq_n_s16(vqrdmulhq_s16(opsin_x16_times8, twentyfourx16), 6); + + // We can ignore x^3 here. Note that this is multiplied by 8. + int16x8_t mixed_rmg16 = vqrdmulhq_s16(twentyfourxy16, opsin_yp16); + + int16x8_t mixed_rpg_sos_half = vhaddq_s16(ysq16, threexsq16); + int16x8_t mixed_rpg16 = vhaddq_s16( + vqrdmulhq_s16(opsin_yp16, mixed_rpg_sos_half), neg_bias_half16); + + int16x8_t gamma_b16 = vqsubq_s16(opsin_b16, neg_bias_cbrt16); + int16x8_t gamma_bsq16 = vqrdmulhq_s16(gamma_b16, gamma_b16); + int16x8_t gamma_bcb16 = vqrdmulhq_s16(gamma_bsq16, gamma_b16); + int16x8_t mixed_b16 = vqaddq_s16(gamma_bcb16, neg_bias16); + // mixed_rpg and mixed_b are in 0-1 range. + // mixed_rmg has a smaller range (-0.035 to 0.035 for valid sRGB). Note + // that at this point it is already multiplied by 8. + + // We multiply all the mixed values by 1/4 (i.e. shift them to 13-bit + // fixed point) to ensure intermediate quantities are in range. Note that + // r-g is not shifted, and was x8 before here; this corresponds to a x32 + // overall multiplicative factor and ensures that all the matrix constants + // are in 0-1 range. + // Similarly, mixed_rpg16 is already multiplied by 1/4 because of the two + // vhadd + using neg_bias_half. + mixed_b16 = vshrq_n_s16(mixed_b16, 2); + + // Unmix (multiply by 3x3 inverse_matrix) + // For increased precision, we use a matrix for converting from + // ((mixed_r - mixed_g)/2, (mixed_r + mixed_g)/2, mixed_b) to rgb. This + // avoids cancellation effects when computing (y+x)^3-(y-x)^3. + // We compute mixed_rpg - mixed_b because the (1+c)*mixed_rpg - c * + // mixed_b pattern is repeated frequently in the code below. This allows + // us to save a multiply per channel, and removes the presence of + // some constants above 1. 
Moreover, mixed_rmg - mixed_b is in (-1, 1)
+      // range, so the subtraction is safe.
+      // All the magic-looking constants here are derived by computing the
+      // inverse opsin matrix for the transformation modified as described
+      // above.
+
+      // Precomputation common to multiple color values.
+      int16x8_t mixed_rpgmb16 = vqsubq_s16(mixed_rpg16, mixed_b16);
+      int16x8_t mixed_rpgmb_times_016 = vqrdmulhq_n_s16(mixed_rpgmb16, 5394);
+      int16x8_t mixed_rg16 = vqaddq_s16(mixed_rpgmb_times_016, mixed_rpg16);
+
+      // R
+      int16x8_t linear_r16 =
+          vqaddq_s16(mixed_rg16, vqrdmulhq_n_s16(mixed_rmg16, 21400));
+
+      // G
+      int16x8_t linear_g16 =
+          vqaddq_s16(mixed_rg16, vqrdmulhq_n_s16(mixed_rmg16, -7857));
+
+      // B
+      int16x8_t linear_b16 = vqrdmulhq_n_s16(mixed_rpgmb16, -30996);
+      linear_b16 = vqaddq_s16(linear_b16, mixed_b16);
+      linear_b16 = vqaddq_s16(linear_b16, vqrdmulhq_n_s16(mixed_rmg16, -6525));
+
+      // Apply SRGB transfer function.
+      int16x8_t r = srgb_tf(linear_r16);
+      int16x8_t g = srgb_tf(linear_g16);
+      int16x8_t b = srgb_tf(linear_b16);
+
+      uint8x8_t r8 =
+          vqmovun_s16(vrshrq_n_s16(vsubq_s16(r, vshrq_n_s16(r, 8)), 6));
+      uint8x8_t g8 =
+          vqmovun_s16(vrshrq_n_s16(vsubq_s16(g, vshrq_n_s16(g, 8)), 6));
+      uint8x8_t b8 =
+          vqmovun_s16(vrshrq_n_s16(vsubq_s16(b, vshrq_n_s16(b, 8)), 6));
+
+      size_t n = output_buf_rect.xsize() - x;
+      uint8x8x3_t data = {r8, g8, b8};
+      uint8_t* buf = output_buf + base_ptr + 3 * x;
+      if (n >= 8) {
+        vst3_u8(buf, data);
+      } else {
+        uint8_t tmp[8 * 3];
+        vst3_u8(tmp, data);
+        memcpy(buf, tmp, n * 3);
+      }
+    }
+  }
+#else
+  JXL_ABORT("Unreachable");
+#endif
+}
+
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // LIB_JXL_DEC_XYB_INL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/dec_xyb.cc b/third_party/jpeg-xl/lib/jxl/dec_xyb.cc
new file mode 100644
index 000000000000..c719a8ad7d08
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/dec_xyb.cc
@@ -0,0 +1,413 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/dec_xyb.h"
+
+#include <string.h>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/dec_xyb.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_group_border.h"
+#include "lib/jxl/dec_xyb-inl.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/quantizer.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
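+// (Clarifying note, an assumption about the rationale: on some targets the
+// underlying vector types are compiler builtins with no associated namespace,
+// so argument-dependent lookup cannot resolve ops like Broadcast; the
+// using-declaration below makes them visible here.)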
+using hwy::HWY_NAMESPACE::Broadcast;
+
+void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool,
+                          const OpsinParams& opsin_params) {
+  PROFILER_FUNC;
+
+  const size_t xsize = inout->xsize();  // not padded
+  RunOnPool(
+      pool, 0, inout->ysize(), ThreadPool::SkipInit(),
+      [&](const int task, const int thread) {
+        const size_t y = task;
+
+        // Faster than adding via ByteOffset at end of loop.
+        float* JXL_RESTRICT row0 = inout->PlaneRow(0, y);
+        float* JXL_RESTRICT row1 = inout->PlaneRow(1, y);
+        float* JXL_RESTRICT row2 = inout->PlaneRow(2, y);
+
+        const HWY_FULL(float) d;
+
+        for (size_t x = 0; x < xsize; x += Lanes(d)) {
+          const auto in_opsin_x = Load(d, row0 + x);
+          const auto in_opsin_y = Load(d, row1 + x);
+          const auto in_opsin_b = Load(d, row2 + x);
+          JXL_COMPILER_FENCE;
+          auto linear_r = Undefined(d);
+          auto linear_g = Undefined(d);
+          auto linear_b = Undefined(d);
+          XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params,
+                   &linear_r, &linear_g, &linear_b);
+
+          Store(linear_r, d, row0 + x);
+          Store(linear_g, d, row1 + x);
+          Store(linear_b, d, row2 + x);
+        }
+      },
+      "OpsinToLinear");
+}
+
+// Same, but not in-place.
+void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool,
+                   Image3F* JXL_RESTRICT linear,
+                   const OpsinParams& opsin_params) {
+  PROFILER_FUNC;
+
+  JXL_ASSERT(SameSize(rect, *linear));
+
+  RunOnPool(
+      pool, 0, static_cast<int>(rect.ysize()), ThreadPool::SkipInit(),
+      [&](const int task, int /*thread*/) {
+        const size_t y = static_cast<size_t>(task);
+
+        // Faster than adding via ByteOffset at end of loop.
+        const float* JXL_RESTRICT row_opsin_0 = rect.ConstPlaneRow(opsin, 0, y);
+        const float* JXL_RESTRICT row_opsin_1 = rect.ConstPlaneRow(opsin, 1, y);
+        const float* JXL_RESTRICT row_opsin_2 = rect.ConstPlaneRow(opsin, 2, y);
+        float* JXL_RESTRICT row_linear_0 = linear->PlaneRow(0, y);
+        float* JXL_RESTRICT row_linear_1 = linear->PlaneRow(1, y);
+        float* JXL_RESTRICT row_linear_2 = linear->PlaneRow(2, y);
+
+        const HWY_FULL(float) d;
+
+        for (size_t x = 0; x < rect.xsize(); x += Lanes(d)) {
+          const auto in_opsin_x = Load(d, row_opsin_0 + x);
+          const auto in_opsin_y = Load(d, row_opsin_1 + x);
+          const auto in_opsin_b = Load(d, row_opsin_2 + x);
+          JXL_COMPILER_FENCE;
+          auto linear_r = Undefined(d);
+          auto linear_g = Undefined(d);
+          auto linear_b = Undefined(d);
+          XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params,
+                   &linear_r, &linear_g, &linear_b);
+
+          Store(linear_r, d, row_linear_0 + x);
+          Store(linear_g, d, row_linear_1 + x);
+          Store(linear_b, d, row_linear_2 + x);
+        }
+      },
+      "OpsinToLinear(Rect)");
+}
+
+// Transform YCbCr to RGB.
+// Could be performed in-place (i.e. Y, Cb and Cr could alias R, G and B).
+void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect) {
+  const HWY_CAPPED(float, GroupBorderAssigner::kPaddingXRound) df;
+  const size_t S = Lanes(df);  // Step.
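+  // (Assumption, not an upstream comment: capping the vector length to
+  // kPaddingXRound lanes keeps every store inside the row padding that
+  // GroupBorderAssigner rounds rects up to, so the x loop below can run in
+  // full vector steps without a scalar tail.)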
+ + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + if ((xsize == 0) || (ysize == 0)) return; + + // Full-range BT.601 as defined by JFIF Clause 7: + // https://www.itu.int/rec/T-REC-T.871-201105-I/en + const auto c128 = Set(df, 128.0f / 255); + const auto crcr = Set(df, 1.402f); + const auto cgcb = Set(df, -0.114f * 1.772f / 0.587f); + const auto cgcr = Set(df, -0.299f * 1.402f / 0.587f); + const auto cbcb = Set(df, 1.772f); + + for (size_t y = 0; y < ysize; y++) { + const float* y_row = rect.ConstPlaneRow(ycbcr, 1, y); + const float* cb_row = rect.ConstPlaneRow(ycbcr, 0, y); + const float* cr_row = rect.ConstPlaneRow(ycbcr, 2, y); + float* r_row = rect.PlaneRow(rgb, 0, y); + float* g_row = rect.PlaneRow(rgb, 1, y); + float* b_row = rect.PlaneRow(rgb, 2, y); + for (size_t x = 0; x < xsize; x += S) { + const auto y_vec = Load(df, y_row + x) + c128; + const auto cb_vec = Load(df, cb_row + x); + const auto cr_vec = Load(df, cr_row + x); + const auto r_vec = crcr * cr_vec + y_vec; + const auto g_vec = cgcr * cr_vec + cgcb * cb_vec + y_vec; + const auto b_vec = cbcb * cb_vec + y_vec; + Store(r_vec, df, r_row + x); + Store(g_vec, df, g_row + x); + Store(b_vec, df, b_row + x); + } + } +} + +/* Vertical upsampling: + * input: + * (a, b, c) := |a1 a2 a3 a4| + * |b1 b2 b3 b4| <- current line + * |c1 c2 c3 c4| + * intermediate: + * u := a + 3 * b + * d := c + 3 * b + * output: + * |u1 u2 u3 u4| =: (u, d) + * |d1 d2 d3 d4| + */ +ImageF UpsampleV2(const ImageF& src, ThreadPool* pool) { + const HWY_FULL(float) df; + const size_t S = Lanes(df); + const auto c14 = Set(df, 0.25f); + const auto c34 = Set(df, 0.75f); + + const size_t xsize = src.xsize(); + const size_t ysize = src.ysize(); + JXL_ASSERT(xsize != 0); + JXL_ASSERT(ysize != 0); + ImageF dst(xsize, ysize * 2); + if (ysize == 1) { + memcpy(dst.Row(0), src.Row(0), xsize * sizeof(*src.Row(0))); + memcpy(dst.Row(1), src.Row(0), xsize * sizeof(*src.Row(0))); + } else { + constexpr size_t kGroupArea = kGroupDim * kGroupDim; + const size_t lines_per_group = DivCeil(kGroupArea, xsize); + const size_t num_stripes = DivCeil(ysize, lines_per_group); + const auto upsample = [&](int idx, int /* thread*/) { + const size_t y0 = idx * lines_per_group; + const size_t y1 = std::min(y0 + lines_per_group, ysize); + for (size_t y = y0; y < y1; ++y) { + const float* JXL_RESTRICT prev_row = src.ConstRow(y == 0 ? 1 : y - 1); + const float* JXL_RESTRICT current_row = src.ConstRow(y); + const float* JXL_RESTRICT next_row = + src.ConstRow(y == ysize - 1 ? 
ysize - 2 : y + 1);
+        float* JXL_RESTRICT dst1_row = dst.Row(2 * y);
+        float* JXL_RESTRICT dst2_row = dst.Row(2 * y + 1);
+        for (size_t x = 0; x < xsize; x += S) {
+          const auto current34 = Load(df, current_row + x) * c34;
+          const auto prev = Load(df, prev_row + x);
+          const auto next = Load(df, next_row + x);
+          Store(MulAdd(prev, c14, current34), df, dst1_row + x);
+          Store(MulAdd(next, c14, current34), df, dst2_row + x);
+        }
+      }
+    };
+    RunOnPool(pool, 0, static_cast<int>(num_stripes), ThreadPool::SkipInit(),
+              upsample, "UpsampleV2");
+  }
+  return dst;
+}
+
+/* Horizontal upsampling:
+ *  input:
+ *    (a, b, c) := |a1 a2 a3 a4 b1 b2 b3 b4 c1 c2 c3 c4|
+ *                              ^^^^^^^^^^^
+ *                             current block
+ *  intermediate:
+ *    l := (a << 3) {0001} (b >> 1) = [a4 b1 b2 b3]
+ *    r := (c >> 3) {1000} (b << 1) = [b2 b3 b4 c1]
+ *    o := 3 * b + l
+ *    e := 3 * b + r
+ *  output:
+ *    |o1 e1 o2 e2 o3 e3 o4 e4| =: (o, e)
+ */
+ImageF UpsampleH2(const ImageF& src, ThreadPool* pool) {
+  const size_t xsize = src.xsize();
+  const size_t ysize = src.ysize();
+  JXL_ASSERT(xsize != 0);
+  JXL_ASSERT(ysize != 0);
+  ImageF dst(xsize * 2, ysize);
+
+  constexpr size_t kGroupArea = kGroupDim * kGroupDim;
+  const size_t lines_per_group = DivCeil(kGroupArea, xsize);
+  const size_t num_stripes = DivCeil(ysize, lines_per_group);
+
+  HWY_CAPPED(float, 4) d;  // necessary for interleaving.
+
+  const auto upsample = [&](int idx, int /*thread*/) {
+    const size_t y0 = idx * lines_per_group;
+    const size_t y1 = std::min(y0 + lines_per_group, ysize);
+    for (size_t y = y0; y < y1; ++y) {
+      const float* JXL_RESTRICT current_row = src.ConstRow(y);
+      float* JXL_RESTRICT dst_row = dst.Row(y);
+      const auto c34 = Set(d, 0.75f);
+      const auto c14 = Set(d, 0.25f);
+      for (size_t x = 1; x < xsize - 1; x += Lanes(d)) {
+        auto current = LoadU(d, current_row + x) * c34;
+        auto prev = LoadU(d, current_row + x - 1);
+        auto next = LoadU(d, current_row + x + 1);
+        auto left = MulAdd(c14, prev, current);
+        auto right = MulAdd(c14, next, current);
+#if HWY_TARGET == HWY_SCALAR
+        StoreU(left, d, dst_row + x * 2);
+        StoreU(right, d, dst_row + x * 2 + 1);
+#else
+        StoreU(InterleaveLower(left, right), d, dst_row + x * 2);
+        StoreU(InterleaveUpper(left, right), d, dst_row + x * 2 + Lanes(d));
+#endif
+      }
+      if (xsize == 1) {
+        dst_row[0] = dst_row[1] = current_row[0];
+      } else {
+        const float leftmost = current_row[0] * 0.75f + current_row[1] * 0.25f;
+        dst_row[0] = dst_row[1] = leftmost;
+        const float rightmost =
+            current_row[xsize - 1] * 0.75f + current_row[xsize - 2] * 0.25f;
+        dst_row[xsize * 2 - 2] = dst_row[xsize * 2 - 1] = rightmost;
+      }
+    }
+  };
+  RunOnPool(pool, 0, static_cast<int>(num_stripes), ThreadPool::SkipInit(),
+            upsample, "UpsampleH2");
+  return dst;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(OpsinToLinearInplace);
+void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool,
+                          const OpsinParams& opsin_params) {
+  return HWY_DYNAMIC_DISPATCH(OpsinToLinearInplace)(inout, pool, opsin_params);
+}
+
+HWY_EXPORT(OpsinToLinear);
+void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool,
+                   Image3F* JXL_RESTRICT linear,
+                   const OpsinParams& opsin_params) {
+  return HWY_DYNAMIC_DISPATCH(OpsinToLinear)(opsin, rect, pool, linear,
+                                             opsin_params);
+}
+
+HWY_EXPORT(YcbcrToRgb);
+void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect) {
+  return HWY_DYNAMIC_DISPATCH(YcbcrToRgb)(ycbcr, rgb, rect);
+} + +HWY_EXPORT(UpsampleV2); +ImageF UpsampleV2(const ImageF& src, ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(UpsampleV2)(src, pool); +} + +HWY_EXPORT(UpsampleH2); +ImageF UpsampleH2(const ImageF& src, ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(UpsampleH2)(src, pool); +} + +HWY_EXPORT(HasFastXYBTosRGB8); +bool HasFastXYBTosRGB8() { return HWY_DYNAMIC_DISPATCH(HasFastXYBTosRGB8)(); } + +HWY_EXPORT(FastXYBTosRGB8); +void FastXYBTosRGB8(const Image3F& input, const Rect& input_rect, + const Rect& output_buf_rect, + uint8_t* JXL_RESTRICT output_buf, size_t xsize) { + return HWY_DYNAMIC_DISPATCH(FastXYBTosRGB8)( + input, input_rect, output_buf_rect, output_buf, xsize); +} + +void OpsinParams::Init(float intensity_target) { + InitSIMDInverseMatrix(GetOpsinAbsorbanceInverseMatrix(), inverse_opsin_matrix, + intensity_target); + memcpy(opsin_biases, kNegOpsinAbsorbanceBiasRGB, + sizeof(kNegOpsinAbsorbanceBiasRGB)); + memcpy(quant_biases, kDefaultQuantBias, sizeof(kDefaultQuantBias)); + for (size_t c = 0; c < 4; c++) { + opsin_biases_cbrt[c] = cbrtf(opsin_biases[c]); + } +} + +Status OutputEncodingInfo::Set(const ImageMetadata& metadata) { + const auto& im = metadata.transform_data.opsin_inverse_matrix; + float inverse_matrix[9]; + memcpy(inverse_matrix, im.inverse_matrix, sizeof(inverse_matrix)); + float intensity_target = metadata.IntensityTarget(); + if (metadata.xyb_encoded) { + const auto& orig_color_encoding = metadata.color_encoding; + color_encoding = ColorEncoding::LinearSRGB(orig_color_encoding.IsGray()); + // Figure out if we can output to this color encoding. + do { + if (!orig_color_encoding.HaveFields()) break; + // TODO(veluca): keep in sync with dec_reconstruct.cc + if (!orig_color_encoding.tf.IsPQ() && !orig_color_encoding.tf.IsSRGB() && + !orig_color_encoding.tf.IsGamma() && + !orig_color_encoding.tf.IsLinear()) { + break; + } + if (orig_color_encoding.tf.IsGamma()) { + inverse_gamma = 1.0f / orig_color_encoding.tf.GetGamma(); + } + if (orig_color_encoding.IsGray() && + orig_color_encoding.white_point != WhitePoint::kD65) { + // TODO(veluca): figure out what should happen here. 
+ break; + } + + if ((orig_color_encoding.primaries != Primaries::kSRGB || + orig_color_encoding.white_point != WhitePoint::kD65) && + !orig_color_encoding.IsGray()) { + all_default_opsin = false; + float srgb_to_xyzd50[9]; + const auto& srgb = ColorEncoding::SRGB(/*is_gray=*/false); + JXL_CHECK(PrimariesToXYZD50( + srgb.GetPrimaries().r.x, srgb.GetPrimaries().r.y, + srgb.GetPrimaries().g.x, srgb.GetPrimaries().g.y, + srgb.GetPrimaries().b.x, srgb.GetPrimaries().b.y, + srgb.GetWhitePoint().x, srgb.GetWhitePoint().y, srgb_to_xyzd50)); + float xyzd50_to_original[9]; + JXL_RETURN_IF_ERROR(PrimariesToXYZD50( + orig_color_encoding.GetPrimaries().r.x, + orig_color_encoding.GetPrimaries().r.y, + orig_color_encoding.GetPrimaries().g.x, + orig_color_encoding.GetPrimaries().g.y, + orig_color_encoding.GetPrimaries().b.x, + orig_color_encoding.GetPrimaries().b.y, + orig_color_encoding.GetWhitePoint().x, + orig_color_encoding.GetWhitePoint().y, xyzd50_to_original)); + Inv3x3Matrix(xyzd50_to_original); + float srgb_to_original[9]; + MatMul(xyzd50_to_original, srgb_to_xyzd50, 3, 3, 3, srgb_to_original); + MatMul(srgb_to_original, im.inverse_matrix, 3, 3, 3, inverse_matrix); + } + color_encoding = orig_color_encoding; + color_encoding_is_original = true; + if (color_encoding.tf.IsPQ()) { + intensity_target = 10000; + } + } while (false); + } else { + color_encoding = metadata.color_encoding; + } + if (std::abs(intensity_target - 255.0) > 0.1f || !im.all_default) { + all_default_opsin = false; + } + InitSIMDInverseMatrix(inverse_matrix, opsin_params.inverse_opsin_matrix, + intensity_target); + std::copy(std::begin(im.opsin_biases), std::end(im.opsin_biases), + opsin_params.opsin_biases); + for (int i = 0; i < 3; ++i) { + opsin_params.opsin_biases_cbrt[i] = cbrtf(opsin_params.opsin_biases[i]); + } + opsin_params.opsin_biases_cbrt[3] = opsin_params.opsin_biases[3] = 1; + std::copy(std::begin(im.quant_biases), std::end(im.quant_biases), + opsin_params.quant_biases); + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_xyb.h b/third_party/jpeg-xl/lib/jxl/dec_xyb.h new file mode 100644 index 000000000000..b8fc24894e43 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_xyb.h @@ -0,0 +1,80 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_DEC_XYB_H_ +#define LIB_JXL_DEC_XYB_H_ + +// XYB -> linear sRGB. + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { + +// Parameters for XYB->sRGB conversion. 
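+// The inverse opsin matrix below is stored with each of its 9 coefficients
+// replicated 4 times (hence 9 * 4 floats), which lets SIMD code broadcast a
+// coefficient with a single load. A sketch of how a consumer might read it
+// (illustration only, not an actual helper in this header):
+//
+//   const HWY_FULL(float) df;
+//   // Coefficient m[r][c] of the 3x3 matrix, broadcast across all lanes:
+//   const auto m_rc = LoadDup128(df, opsin_params.inverse_opsin_matrix +
+//                                        (r * 3 + c) * 4);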
+struct OpsinParams { + float inverse_opsin_matrix[9 * 4]; + float opsin_biases[4]; + float opsin_biases_cbrt[4]; + float quant_biases[4]; + void Init(float intensity_target); +}; + +struct OutputEncodingInfo { + ColorEncoding color_encoding; + float inverse_gamma; + // Contains an opsin matrix that converts to the primaries of the output + // encoding. + OpsinParams opsin_params; + Status Set(const ImageMetadata& metadata); + bool all_default_opsin = true; + bool color_encoding_is_original = false; +}; + +// Converts `inout` (not padded) from opsin to linear sRGB in-place. Called from +// per-pass postprocessing, hence parallelized. +void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool, + const OpsinParams& opsin_params); + +// Converts `opsin:rect` (opsin may be padded, rect.x0 must be vector-aligned) +// to linear sRGB. Called from whole-frame encoder, hence parallelized. +void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool, + Image3F* JXL_RESTRICT linear, + const OpsinParams& opsin_params); + +// Bt.601 to match JPEG/JFIF. Inputs are _signed_ YCbCr values suitable for DCT, +// see F.1.1.3 of T.81 (because our data type is float, there is no need to add +// a bias to make the values unsigned). +void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect); + +ImageF UpsampleV2(const ImageF& src, ThreadPool* pool); + +// WARNING: this uses unaligned accesses, so the caller must first call +// src.InitializePaddingForUnalignedAccesses() to avoid msan crashes. +ImageF UpsampleH2(const ImageF& src, ThreadPool* pool); + +bool HasFastXYBTosRGB8(); +void FastXYBTosRGB8(const Image3F& input, const Rect& input_rect, + const Rect& output_buf_rect, + uint8_t* JXL_RESTRICT output_buf, size_t xsize); + +} // namespace jxl + +#endif // LIB_JXL_DEC_XYB_H_ diff --git a/third_party/jpeg-xl/lib/jxl/decode.cc b/third_party/jpeg-xl/lib/jxl/decode.cc new file mode 100644 index 000000000000..efe2ba7b7c3f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/decode.cc @@ -0,0 +1,2076 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "jxl/decode.h" + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/dec_reconstruct.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/jpeg/dec_jpeg_data.h" +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/memory_manager_internal.h" +#include "lib/jxl/toc.h" + +namespace { + +// If set (by fuzzer) then some operations will fail, if those would require +// allocating large objects. Actual memory usage might be two orders of +// magnitude bigger. 
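+// For example (hypothetical limit): with memory_limit_base_ = 1 << 30, a
+// 65536 x 65536 image (2^32 pixels) is rejected by CheckSizeLimit below,
+// while a 16384 x 16384 image (2^28 pixels) still passes.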
+// TODO(eustas): this is a poor-mans replacement for memory-manager approach; +// remove, once memory-manager actually works. +size_t memory_limit_base_ = 0; + +bool CheckSizeLimit(size_t xsize, size_t ysize) { + if (!memory_limit_base_) return true; + if (xsize == 0 || ysize == 0) return true; + size_t num_pixels = xsize * ysize; + if (num_pixels / xsize != ysize) return false; // overflow + if (num_pixels > memory_limit_base_) return false; + return true; +} + +// Checks if a + b > size, taking possible integer overflow into account. +bool OutOfBounds(size_t a, size_t b, size_t size) { + size_t pos = a + b; + if (pos > size) return true; + if (pos < a) return true; // overflow happened + return false; +} + +// Checks if a + b + c > size, taking possible integer overflow into account. +bool OutOfBounds(size_t a, size_t b, size_t c, size_t size) { + size_t pos = a + b; + if (pos < b) return true; // overflow happened + pos += c; + if (pos < c) return true; // overflow happened + if (pos > size) return true; + return false; +} + +bool SumOverflows(size_t a, size_t b, size_t c) { + size_t sum = a + b; + if (sum < b) return true; + sum += c; + if (sum < c) return true; + return false; +} + +JXL_INLINE size_t InitialBasicInfoSizeHint() { + // Amount of bytes before the start of the codestream in the container format, + // assuming that the codestream is the first box after the signature and + // filetype boxes. 12 bytes signature box + 20 bytes filetype box + 16 bytes + // codestream box length + name + optional XLBox length. + const size_t container_header_size = 48; + + // Worst-case amount of bytes for basic info of the JPEG XL codestream header, + // that is all information up to and including extra_channel_bits. Up to + // around 2 bytes signature + 8 bytes SizeHeader + 31 bytes ColorEncoding + 4 + // bytes rest of ImageMetadata + 5 bytes part of ImageMetadata2. + // TODO(lode): recompute and update this value when alpha_bits is moved to + // extra channels info. + const size_t max_codestream_basic_info_size = 50; + + return container_header_size + max_codestream_basic_info_size; +} + +// Debug-printing failure macro similar to JXL_FAILURE, but for the status code +// JXL_DEC_ERROR +#ifdef JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(format, ...) \ + (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort(), JXL_DEC_ERROR) +#else // JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(format, ...) \ + (((JXL_DEBUG_ON_ERROR) && \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__)), \ + JXL_DEC_ERROR) +#endif // JXL_CRASH_ON_ERROR + +JxlDecoderStatus ConvertStatus(JxlDecoderStatus status) { return status; } + +JxlDecoderStatus ConvertStatus(jxl::Status status) { + return status ? 
JXL_DEC_SUCCESS : JXL_DEC_ERROR; +} + +JxlSignature ReadSignature(const uint8_t* buf, size_t len, size_t* pos) { + if (*pos >= len) return JXL_SIG_NOT_ENOUGH_BYTES; + + buf += *pos; + len -= *pos; + + // JPEG XL codestream: 0xff 0x0a + if (len >= 1 && buf[0] == 0xff) { + if (len < 2) { + return JXL_SIG_NOT_ENOUGH_BYTES; + } else if (buf[1] == jxl::kCodestreamMarker) { + *pos += 2; + return JXL_SIG_CODESTREAM; + } else { + return JXL_SIG_INVALID; + } + } + + // JPEG XL container + if (len >= 1 && buf[0] == 0) { + if (len < 12) { + return JXL_SIG_NOT_ENOUGH_BYTES; + } else if (buf[1] == 0 && buf[2] == 0 && buf[3] == 0xC && buf[4] == 'J' && + buf[5] == 'X' && buf[6] == 'L' && buf[7] == ' ' && + buf[8] == 0xD && buf[9] == 0xA && buf[10] == 0x87 && + buf[11] == 0xA) { + *pos += 12; + return JXL_SIG_CONTAINER; + } else { + return JXL_SIG_INVALID; + } + } + + return JXL_SIG_INVALID; +} + +} // namespace + +uint32_t JxlDecoderVersion(void) { + return JPEGXL_MAJOR_VERSION * 1000000 + JPEGXL_MINOR_VERSION * 1000 + + JPEGXL_PATCH_VERSION; +} + +JxlSignature JxlSignatureCheck(const uint8_t* buf, size_t len) { + size_t pos = 0; + return ReadSignature(buf, len, &pos); +} + +size_t BitsPerChannel(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_BOOLEAN: + return 1; + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_UINT32: + return 32; + case JXL_TYPE_FLOAT: + return 32; + case JXL_TYPE_FLOAT16: + return 16; + // No default, give compiler error if new type not handled. + } + return 0; // Indicate invalid data type. +} + +enum class DecoderStage : uint32_t { + kInited, // Decoder created, no JxlDecoderProcessInput called yet + kStarted, // Running JxlDecoderProcessInput calls + kFinished, // Everything done, nothing left to process + kError, // Error occured, decoder object no longer useable +}; + +enum class FrameStage : uint32_t { + kHeader, // Must parse frame header. dec->frame_start must be set up + // correctly already. + kTOC, // Must parse TOC + kDC, // Must parse DC pixels + kDCOutput, // Must output DC pixels + kFull, // Must parse full pixels + kFullOutput, // Must output full pixels +}; + +// Manages the sections for the FrameDecoder based on input bytes received. +struct Sections { + // sections_begin = position in the frame where the sections begin, after + // the frame header and TOC, so sections_begin = sum of frame header size and + // TOC size. + Sections(jxl::FrameDecoder* frame_dec, size_t frame_size, + size_t sections_begin) + : frame_dec_(frame_dec), + frame_size_(frame_size), + sections_begin_(sections_begin) {} + + Sections(const Sections&) = delete; + Sections& operator=(const Sections&) = delete; + Sections(Sections&&) = delete; + Sections& operator=(Sections&&) = delete; + + ~Sections() { + // Avoid memory leaks if the JXL decoder quits early and doesn't end up + // calling CloseInput(). + CloseInput(); + } + + // frame_dec_ must have been Inited already, but not yet done ProcessSections. + JxlDecoderStatus Init() { + section_received.resize(frame_dec_->NumSections(), 0); + + const auto& offsets = frame_dec_->SectionOffsets(); + const auto& sizes = frame_dec_->SectionSizes(); + + // Ensure none of the sums of section offset and size overflow. + for (size_t i = 0; i < frame_dec_->NumSections(); i++) { + if (OutOfBounds(sections_begin_, offsets[i], sizes[i], frame_size_)) { + return JXL_API_ERROR("section out of bounds"); + } + } + + return JXL_DEC_SUCCESS; + } + + // Sets the input data for the frame. 
The frame pointer must point to the
+  // beginning of the frame, size is the number of bytes received so far and
+  // should increase with next calls until the full frame is loaded.
+  // TODO(lode): allow caller to provide only later chunks of memory when
+  // earlier sections are fully processed already.
+  void SetInput(const uint8_t* frame, size_t size) {
+    const auto& offsets = frame_dec_->SectionOffsets();
+    const auto& sizes = frame_dec_->SectionSizes();
+
+    for (size_t i = 0; i < frame_dec_->NumSections(); i++) {
+      if (section_received[i]) continue;
+      if (!OutOfBounds(sections_begin_, offsets[i], sizes[i], size)) {
+        section_received[i] = 1;
+        section_info.emplace_back(jxl::FrameDecoder::SectionInfo{nullptr, i});
+        section_status.emplace_back();
+      }
+    }
+    // Reset all the bitreaders, because the address of the frame pointer may
+    // change, even if it always represents the same frame start.
+    for (size_t i = 0; i < section_info.size(); i++) {
+      size_t id = section_info[i].id;
+      JXL_ASSERT(section_info[i].br == nullptr);
+      section_info[i].br = new jxl::BitReader(jxl::Span<const uint8_t>(
+          frame + sections_begin_ + offsets[id], sizes[id]));
+    }
+  }
+
+  JxlDecoderStatus CloseInput() {
+    bool out_of_bounds = false;
+    for (size_t i = 0; i < section_info.size(); i++) {
+      if (!section_info[i].br) continue;
+      if (!section_info[i].br->AllReadsWithinBounds()) {
+        // Mark out of bounds section, but keep closing and deleting the next
+        // ones as well.
+        out_of_bounds = true;
+      }
+      JXL_ASSERT(section_info[i].br->Close());
+      delete section_info[i].br;
+      section_info[i].br = nullptr;
+    }
+    if (out_of_bounds) {
+      // If any bit reader indicates out of bounds, it's an error, not just
+      // needing more input, since we ensure only bit readers containing
+      // a complete section are provided to the FrameDecoder.
+      return JXL_API_ERROR("frame out of bounds");
+    }
+    return JXL_DEC_SUCCESS;
+  }
+
+  // Not managed by us.
+  jxl::FrameDecoder* frame_dec_;
+
+  size_t frame_size_;
+  size_t sections_begin_;
+
+  std::vector<jxl::FrameDecoder::SectionInfo> section_info;
+  std::vector<jxl::FrameDecoder::SectionStatus> section_status;
+  std::vector<uint8_t> section_received;
+};
+
+struct JxlDecoderStruct {
+  JxlDecoderStruct() = default;
+
+  JxlMemoryManager memory_manager;
+  std::unique_ptr<jxl::ThreadPool> thread_pool;
+
+  DecoderStage stage;
+
+  // Status of progression, internal.
+  bool got_signature;
+  bool first_codestream_seen;
+  // Indicates we know that we've seen the last codestream; however, this is
+  // not guaranteed to be true for the last box because a jxl file may have
+  // multiple "jxlp" boxes and it is possible (and permitted) that the last
+  // one is not a final box that uses size 0 to indicate the end.
+  bool last_codestream_seen;
+  bool got_basic_info;
+  bool got_all_headers;  // Codestream metadata headers
+
+  // This means either we actually got the preview image, or determined we
+  // cannot get it or there is none.
+  bool got_preview_image;
+
+  // Processed the last frame, so flags such as got_dc_image no longer mean
+  // there is more work to do.
+  bool last_frame_reached;
+
+  // Position of next_in in the original file including box format if present
+  // (as opposed to position in the codestream)
+  size_t file_pos;
+  // Begin and end of the content of the current codestream box. This could be
+  // a partial codestream box.
+  // codestream_begin 0 is used to indicate the begin is not yet known.
+  // codestream_end 0 is used to indicate uncapped (until end of file, for the
+  // last box if this box doesn't indicate its actual size).
+  // Not used if the file is a direct codestream. 
+  size_t codestream_begin;
+  size_t codestream_end;
+
+  // Settings
+  bool keep_orientation;
+
+  // Bitfield, for which informative events (JXL_DEC_BASIC_INFO, etc...) the
+  // decoder returns a status. By default, do not return for any of the
+  // events, only return when the decoder cannot continue because it needs
+  // more input or output data.
+  int events_wanted;
+  int orig_events_wanted;
+
+  // Fields for reading the basic info from the header.
+  size_t basic_info_size_hint;
+  bool have_container;
+
+  // Whether the DC out buffer was set. It is possible for dc_out_buffer to
+  // be nullptr and dc_out_buffer_set be true, indicating it was deliberately
+  // set to nullptr.
+  bool preview_out_buffer_set;
+  bool dc_out_buffer_set;
+  // Idem for the image buffer.
+  bool image_out_buffer_set;
+
+  // Owned by the caller, buffers for DC image and full resolution images
+  void* preview_out_buffer;
+  void* dc_out_buffer;
+  void* image_out_buffer;
+
+  size_t preview_out_size;
+  size_t dc_out_size;
+  size_t image_out_size;
+
+  // TODO(lode): merge these?
+  JxlPixelFormat preview_out_format;
+  JxlPixelFormat dc_out_format;
+  JxlPixelFormat image_out_format;
+
+  jxl::CodecMetadata metadata;
+  std::unique_ptr<jxl::ImageBundle> ib;
+
+  std::unique_ptr<jxl::PassesDecoderState> passes_state;
+  std::unique_ptr<jxl::FrameDecoder> frame_dec;
+  std::unique_ptr<Sections> sections;
+  // The FrameDecoder is initialized, and not yet finalized
+  bool frame_dec_in_progress;
+
+  // Headers and TOC for the current frame. When got_toc is true, this is
+  // always the frame header of the last frame of the current still series,
+  // that is, the displayed frame.
+  std::unique_ptr<jxl::FrameHeader> frame_header;
+  jxl::FrameDimensions frame_dim;
+
+  // Start of the current frame being processed, as offset from the beginning
+  // of the codestream.
+  size_t frame_start;
+  size_t frame_size;
+  size_t dc_size;
+  FrameStage frame_stage;
+  // The currently processed frame is the last of the current composite still,
+  // and so must be returned as pixels
+  bool is_last_of_still;
+  // The currently processed frame is the last of the codestream
+  bool is_last_total;
+
+  // Codestream input data is stored here, when the decoder takes in and
+  // stores the user input bytes. If the decoder does not do that (e.g. in the
+  // one-shot case), this field is unused.
+  // TODO(lode): avoid needing this field once the C++ decoder doesn't need
+  // all bytes at once, to save memory. Find alternative to std::vector
+  // doubling strategy to prevent some memory usage.
+  std::vector<uint8_t> codestream;
+
+  // Content of the most recently parsed JPEG reconstruction box is stored
+  // here.
+  std::vector<uint8_t> jpeg_reconstruction_buffer;
+  // Decoded content of the most recently parsed JPEG reconstruction box is
+  // stored here.
+  std::unique_ptr<jxl::jpeg::JPEGData> jpeg_reconstruction_data;
+  // True if the decoder is currently reading bytes inside a JPEG
+  // reconstruction box.
+  bool inside_jpeg_reconstruction_box = false;
+  // True if the JPEG reconstruction box had undefined size (all remaining
+  // bytes).
+  bool jpeg_reconstruction_box_until_eof = false;
+  // Size of most recently parsed JPEG reconstruction box contents.
+  size_t jpeg_reconstruction_size = 0;
+  // Next bytes to write JPEG reconstruction to.
+  uint8_t* next_jpeg_reconstruction_out = nullptr;
+  // Available bytes to write JPEG reconstruction to.
+  size_t avail_jpeg_reconstruction_size = 0;
+
+  // Position in the actual codestream, which codestream.begin() points to.
+  // Non-zero once earlier parts of the codestream vector have been erased. 
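+  // For example, if 100 leading codestream bytes have already been processed
+  // and erased, codestream_pos is 100 and an absolute codestream offset p is
+  // found at codestream[p - codestream_pos].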
+ size_t codestream_pos; + + // Statistics which CodecInOut can keep + uint64_t dec_pixels; + + const uint8_t* next_in; + size_t avail_in; +}; + +// TODO(zond): Make this depend on the data loaded into the decoder. +JxlDecoderStatus JxlDecoderDefaultPixelFormat(const JxlDecoder* dec, + JxlPixelFormat* format) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + *format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + return JXL_DEC_SUCCESS; +} + +void JxlDecoderReset(JxlDecoder* dec) { + dec->thread_pool.reset(); + dec->stage = DecoderStage::kInited; + dec->got_signature = false; + dec->first_codestream_seen = false; + dec->last_codestream_seen = false; + dec->got_basic_info = false; + dec->got_all_headers = false; + dec->got_preview_image = false; + dec->last_frame_reached = false; + dec->file_pos = 0; + dec->codestream_pos = 0; + dec->codestream_begin = 0; + dec->codestream_end = 0; + dec->keep_orientation = false; + dec->events_wanted = 0; + dec->orig_events_wanted = 0; + dec->basic_info_size_hint = InitialBasicInfoSizeHint(); + dec->have_container = 0; + dec->preview_out_buffer_set = false; + dec->dc_out_buffer_set = false; + dec->image_out_buffer_set = false; + dec->preview_out_buffer = nullptr; + dec->dc_out_buffer = nullptr; + dec->image_out_buffer = nullptr; + dec->preview_out_size = 0; + dec->dc_out_size = 0; + dec->image_out_size = 0; + dec->dec_pixels = 0; + dec->next_in = 0; + dec->avail_in = 0; + + dec->passes_state.reset(nullptr); + dec->frame_dec.reset(nullptr); + dec->sections.reset(nullptr); + dec->frame_dec_in_progress = false; + + dec->ib.reset(); + dec->metadata = jxl::CodecMetadata(); + dec->frame_header.reset(new jxl::FrameHeader(&dec->metadata)); + dec->frame_dim = jxl::FrameDimensions(); + dec->codestream.clear(); + + dec->frame_stage = FrameStage::kHeader; + dec->frame_start = 0; + dec->frame_size = 0; + dec->dc_size = 0; + dec->is_last_of_still = false; + dec->is_last_total = false; +} + +JxlDecoder* JxlDecoderCreate(const JxlMemoryManager* memory_manager) { + JxlMemoryManager local_memory_manager; + if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) + return nullptr; + + void* alloc = + jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlDecoder)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + JxlDecoder* dec = new (alloc) JxlDecoder(); + dec->memory_manager = local_memory_manager; + + JxlDecoderReset(dec); + + return dec; +} + +void JxlDecoderDestroy(JxlDecoder* dec) { + if (dec) { + // Call destructor directly since custom free function is used. + dec->~JxlDecoder(); + jxl::MemoryManagerFree(&dec->memory_manager, dec); + } +} + +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetParallelRunner(JxlDecoder* dec, JxlParallelRunner parallel_runner, + void* parallel_runner_opaque) { + if (dec->thread_pool) return JXL_API_ERROR("parallel runner already set"); + dec->thread_pool.reset( + new jxl::ThreadPool(parallel_runner, parallel_runner_opaque)); + return JXL_DEC_SUCCESS; +} + +size_t JxlDecoderSizeHintBasicInfo(const JxlDecoder* dec) { + if (dec->got_basic_info) return 0; + return dec->basic_info_size_hint; +} + +JxlDecoderStatus JxlDecoderSubscribeEvents(JxlDecoder* dec, int events_wanted) { + if (dec->stage != DecoderStage::kInited) { + return JXL_DEC_ERROR; // Cannot subscribe to events after having started. + } + if (events_wanted & 63) { + return JXL_DEC_ERROR; // Can only subscribe to informative events. 
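+    // The informative event flags start at JXL_DEC_BASIC_INFO (0x40); values
+    // 0..63 are reserved for return statuses like JXL_DEC_NEED_MORE_INPUT.
+    // A typical subscription (usage illustration) would be:
+    //   JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO |
+    //                                  JXL_DEC_COLOR_ENCODING |
+    //                                  JXL_DEC_FULL_IMAGE);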
+  }
+  dec->events_wanted = events_wanted;
+  dec->orig_events_wanted = events_wanted;
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderSetKeepOrientation(JxlDecoder* dec,
+                                              JXL_BOOL keep_orientation) {
+  if (dec->stage != DecoderStage::kInited) {
+    return JXL_API_ERROR("Must set keep_orientation option before starting");
+  }
+  dec->keep_orientation = !!keep_orientation;
+  return JXL_DEC_SUCCESS;
+}
+
+namespace jxl {
+namespace {
+
+template <class T>
+bool CanRead(Span<const uint8_t> data, BitReader* reader, T* JXL_RESTRICT t) {
+  // Use a copy of the bit reader because CanRead advances bits.
+  BitReader reader2(data);
+  reader2.SkipBits(reader->TotalBitsConsumed());
+  bool result = Bundle::CanRead(&reader2, t);
+  JXL_ASSERT(reader2.Close());
+  return result;
+}
+
+// Returns JXL_DEC_SUCCESS if the full bundle was successfully read, status
+// indicating either error or need more input otherwise.
+template <class T>
+JxlDecoderStatus ReadBundle(Span<const uint8_t> data, BitReader* reader,
+                            T* JXL_RESTRICT t) {
+  if (!CanRead(data, reader, t)) {
+    return JXL_DEC_NEED_MORE_INPUT;
+  }
+  if (!Bundle::Read(reader, t)) {
+    return JXL_DEC_ERROR;
+  }
+  return JXL_DEC_SUCCESS;
+}
+
+#define JXL_API_RETURN_IF_ERROR(expr)               \
+  {                                                 \
+    JxlDecoderStatus status_ = ConvertStatus(expr); \
+    if (status_ != JXL_DEC_SUCCESS) return status_; \
+  }
+
+std::unique_ptr<BitReader, std::function<void(BitReader*)>> GetBitReader(
+    Span<const uint8_t> span) {
+  BitReader* reader = new BitReader(span);
+  return std::unique_ptr<BitReader, std::function<void(BitReader*)>>(
+      reader, [](BitReader* reader) {
+        // We can't allow Close to abort the program if the reader is out of
+        // bounds, or all return paths in the code, even those that already
+        // return failure, would have to manually call AllReadsWithinBounds().
+        // Invalid JXL codestream should not cause program to quit.
+        (void)reader->AllReadsWithinBounds();
+        (void)reader->Close();
+        delete reader;
+      });
+}
+
+JxlDecoderStatus JxlDecoderReadBasicInfo(JxlDecoder* dec, const uint8_t* in,
+                                         size_t size) {
+  size_t pos = 0;
+
+  // Check and skip the codestream signature
+  JxlSignature signature = ReadSignature(in, size, &pos);
+  if (signature == JXL_SIG_NOT_ENOUGH_BYTES) {
+    return JXL_DEC_NEED_MORE_INPUT;
+  }
+  if (signature == JXL_SIG_CONTAINER) {
+    // There is a container signature where we expect a codestream, container
+    // is handled at a higher level already. 
+    return JXL_API_ERROR("invalid: nested container");
+  }
+  if (signature != JXL_SIG_CODESTREAM) {
+    return JXL_API_ERROR("invalid signature");
+  }
+
+  Span<const uint8_t> span(in + pos, size - pos);
+  auto reader = GetBitReader(span);
+  JXL_API_RETURN_IF_ERROR(ReadBundle(span, reader.get(), &dec->metadata.size));
+
+  dec->metadata.m.nonserialized_only_parse_basic_info = true;
+  JXL_API_RETURN_IF_ERROR(ReadBundle(span, reader.get(), &dec->metadata.m));
+  dec->metadata.m.nonserialized_only_parse_basic_info = false;
+  dec->got_basic_info = true;
+  dec->basic_info_size_hint = 0;
+
+  if (!CheckSizeLimit(dec->metadata.size.xsize(),
+                      dec->metadata.size.ysize())) {
+    return JXL_API_ERROR("image is too large");
+  }
+
+  return JXL_DEC_SUCCESS;
+}
+
+// Reads all codestream headers (but not frame headers)
+JxlDecoderStatus JxlDecoderReadAllHeaders(JxlDecoder* dec, const uint8_t* in,
+                                          size_t size) {
+  size_t pos = 0;
+
+  // Check and skip the codestream signature
+  JxlSignature signature = ReadSignature(in, size, &pos);
+  if (signature == JXL_SIG_CONTAINER) {
+    return JXL_API_ERROR("invalid: nested container");
+  }
+  if (signature != JXL_SIG_CODESTREAM) {
+    return JXL_API_ERROR("invalid signature");
+  }
+
+  Span<const uint8_t> span(in + pos, size - pos);
+  auto reader = GetBitReader(span);
+  SizeHeader dummy_size_header;
+  JXL_API_RETURN_IF_ERROR(ReadBundle(span, reader.get(), &dummy_size_header));
+
+  // We already decoded the metadata to dec->metadata.m, no reason to
+  // overwrite it, use a dummy metadata instead.
+  ImageMetadata dummy_metadata;
+  JXL_API_RETURN_IF_ERROR(ReadBundle(span, reader.get(), &dummy_metadata));
+
+  JXL_API_RETURN_IF_ERROR(
+      ReadBundle(span, reader.get(), &dec->metadata.transform_data));
+
+  if (dec->metadata.m.color_encoding.WantICC()) {
+    PaddedBytes icc;
+    jxl::Status status = ReadICC(reader.get(), &icc, memory_limit_base_);
+    // Always check AllReadsWithinBounds, not all the C++ decoder
+    // implementation handles reader out of bounds correctly yet (e.g. context
+    // map). Not checking AllReadsWithinBounds can cause reader->Close() to
+    // trigger an assert, but we don't want library to quit program for
+    // invalid codestream.
+    if (!reader->AllReadsWithinBounds()) {
+      return JXL_DEC_NEED_MORE_INPUT;
+    }
+    if (!status) {
+      if (status.code() == StatusCode::kNotEnoughBytes) {
+        return JXL_DEC_NEED_MORE_INPUT;
+      }
+      // Other non-successful status is an error
+      return JXL_DEC_ERROR;
+    }
+    if (!dec->metadata.m.color_encoding.SetICCRaw(std::move(icc))) {
+      return JXL_DEC_ERROR;
+    }
+  }
+
+  dec->got_all_headers = true;
+  JXL_API_RETURN_IF_ERROR(reader->JumpToByteBoundary());
+
+  dec->frame_start = pos + reader->TotalBitsConsumed() / jxl::kBitsPerByte;
+
+  if (!dec->passes_state) {
+    dec->passes_state.reset(new jxl::PassesDecoderState());
+  }
+  JXL_API_RETURN_IF_ERROR(
+      dec->passes_state->output_encoding_info.Set(dec->metadata.m));
+
+  return JXL_DEC_SUCCESS;
+}
+
+static size_t GetStride(const JxlDecoder* dec, const JxlPixelFormat& format,
+                        const jxl::ImageBundle* frame = nullptr) {
+  size_t xsize = dec->metadata.xsize();
+  if (!dec->keep_orientation && dec->metadata.m.orientation > 4) {
+    xsize = dec->metadata.ysize();
+  }
+  if (frame) {
+    xsize = dec->keep_orientation ? 
frame->xsize() : frame->oriented_xsize();
+  }
+  size_t stride = xsize * (BitsPerChannel(format.data_type) *
+                           format.num_channels / jxl::kBitsPerByte);
+  if (format.align > 1) {
+    stride = jxl::DivCeil(stride, format.align) * format.align;
+  }
+  return stride;
+}
+
+static JxlDecoderStatus ConvertImageInternal(const JxlDecoder* dec,
+                                             const jxl::ImageBundle& frame,
+                                             const JxlPixelFormat& format,
+                                             void* out_image,
+                                             size_t out_size) {
+  // TODO(lode): handle mismatch of RGB/grayscale color profiles and pixel
+  // data color/grayscale format
+  const auto& metadata = dec->metadata.m;
+
+  const size_t stride = GetStride(dec, format, &frame);
+
+  bool float_format = format.data_type == JXL_TYPE_FLOAT ||
+                      format.data_type == JXL_TYPE_FLOAT16;
+
+  jxl::Orientation undo_orientation = dec->keep_orientation
+                                          ? jxl::Orientation::kIdentity
+                                          : metadata.GetOrientation();
+  JXL_DASSERT(!dec->frame_dec || !dec->frame_dec->HasRGBBuffer());
+  jxl::Status status = jxl::ConvertToExternal(
+      frame, BitsPerChannel(format.data_type), float_format,
+      format.num_channels, format.endianness, stride, dec->thread_pool.get(),
+      out_image, out_size, undo_orientation);
+
+  return status ? JXL_DEC_SUCCESS : JXL_DEC_ERROR;
+}
+
+// Reads all frame headers and computes the total size in bytes of the frame.
+// Stores information in dec->frame_header and dec->frame_dim.
+// Outputs optional variables, unless set to nullptr:
+// frame_size: total frame size
+// header_size: size of the frame header and TOC within the frame
+// dc_size: size of DC groups within the frame, or 0 if there's no DC or we're
+// unable to compute its size.
+// Can finish successfully if reader has headers and TOC available, does not
+// read groups themselves.
+// TODO(lode): merge this with FrameDecoder
+JxlDecoderStatus ParseFrameHeader(JxlDecoder* dec,
+                                  jxl::FrameHeader* frame_header,
+                                  const uint8_t* in, size_t size, size_t pos,
+                                  bool is_preview, size_t* frame_size,
+                                  size_t* dc_size) {
+  Span<const uint8_t> span(in + pos, size - pos);
+  auto reader = GetBitReader(span);
+
+  frame_header->nonserialized_is_preview = is_preview;
+  jxl::Status status = DecodeFrameHeader(reader.get(), frame_header);
+  dec->frame_dim = frame_header->ToFrameDimensions();
+  if (!CheckSizeLimit(dec->frame_dim.xsize_upsampled_padded,
+                      dec->frame_dim.ysize_upsampled_padded)) {
+    return JXL_API_ERROR("frame is too large");
+  }
+
+  if (status.code() == StatusCode::kNotEnoughBytes) {
+    // TODO(lode): prevent asking for way too many input bytes in case of an
+    // invalid header that the decoder thinks is a very long user extension
+    // instead. Example: fields can currently print something like this:
+    // "../lib/jxl/fields.cc:416: Skipping 71467322-bit extension(s)"
+    // Maybe fields.cc should return error in the above case rather than
+    // print a message.
+    return JXL_DEC_NEED_MORE_INPUT;
+  } else if (!status) {
+    return JXL_API_ERROR("invalid frame header");
+  }
+
+  // Read TOC.
+  uint64_t groups_total_size;
+  const bool has_ac_global = true;
+  const size_t toc_entries =
+      NumTocEntries(dec->frame_dim.num_groups, dec->frame_dim.num_dc_groups,
+                    frame_header->passes.num_passes, has_ac_global);
+
+  std::vector<uint64_t> group_offsets;
+  std::vector<uint32_t> group_sizes;
+  status = ReadGroupOffsets(toc_entries, reader.get(), &group_offsets,
+                            &group_sizes, &groups_total_size);
+
+  // TODO(lode): we're actually relying on AllReadsWithinBounds() here
+  // instead of on status.code(), change the internal TOC C++ code to
+  // correctly set the status.code() instead so we can rely on that one. 
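+  // TOC size illustration (derived from the formulas above): a 1024x1024
+  // frame with a single pass has 16 256x256 AC groups and one DC group, so
+  // the TOC holds 1 (DcGlobal) + 1 (DcGroup) + 1 (AcGlobal) + 16 (AcGroup)
+  // = 19 entries, while a frame that fits in a single group with one pass
+  // collapses to a single TOC entry.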
+ if (!reader->AllReadsWithinBounds() || + status.code() == StatusCode::kNotEnoughBytes) { + return JXL_DEC_NEED_MORE_INPUT; + } else if (!status) { + return JXL_API_ERROR("invalid toc entries"); + } + + if (dc_size) { + bool can_get_dc = true; + if (frame_header->passes.num_passes == 1 && + dec->frame_dim.num_groups == 1) { + // If there is one pass and one group, the TOC only has one entry and + // doesn't allow to distinguish the DC size, so it's not easy to tell + // whether we got all DC bytes or not. This will happen for very small + // images only. + can_get_dc = false; + } + + *dc_size = 0; + if (can_get_dc) { + // one DcGlobal entry, N dc group entries. + size_t num_dc_toc_entries = 1 + dec->frame_dim.num_dc_groups; + if (group_sizes.size() < num_dc_toc_entries) { + JXL_ABORT("too small TOC"); + } + for (size_t i = 0; i < num_dc_toc_entries; i++) { + *dc_size = + std::max(*dc_size, group_sizes[i] + group_offsets[i]); + } + } + } + + JXL_DASSERT((reader->TotalBitsConsumed() % kBitsPerByte) == 0); + JXL_API_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + size_t header_size = (reader->TotalBitsConsumed() >> 3); + *frame_size = header_size + groups_total_size; + + return JXL_DEC_SUCCESS; +} + +// TODO(eustas): no CodecInOut -> no image size reinforcement -> possible OOM. +JxlDecoderStatus JxlDecoderProcessInternal(JxlDecoder* dec, const uint8_t* in, + size_t size) { + // If no parallel runner is set, use the default + // TODO(lode): move this initialization to an appropriate location once the + // runner is used to decode pixels. + if (!dec->thread_pool) { + dec->thread_pool.reset(new jxl::ThreadPool(nullptr, nullptr)); + } + + // No matter what events are wanted, the basic info is always required. + if (!dec->got_basic_info) { + JxlDecoderStatus status = JxlDecoderReadBasicInfo(dec, in, size); + if (status != JXL_DEC_SUCCESS) return status; + } + + if (dec->events_wanted & JXL_DEC_BASIC_INFO) { + dec->events_wanted &= ~JXL_DEC_BASIC_INFO; + return JXL_DEC_BASIC_INFO; + } + + if (!dec->got_all_headers) { + JxlDecoderStatus status = JxlDecoderReadAllHeaders(dec, in, size); + if (status != JXL_DEC_SUCCESS) return status; + } + + if (dec->events_wanted & JXL_DEC_EXTENSIONS) { + dec->events_wanted &= ~JXL_DEC_EXTENSIONS; + if (dec->metadata.m.extensions != 0) { + return JXL_DEC_EXTENSIONS; + } + } + + if (dec->events_wanted & JXL_DEC_COLOR_ENCODING) { + dec->events_wanted &= ~JXL_DEC_COLOR_ENCODING; + return JXL_DEC_COLOR_ENCODING; + } + + // Decode to pixels, only if required for the events the user wants. + if (!dec->got_preview_image) { + // Parse the preview, or at least its TOC to be able to skip the frame, if + // any frame or image decoding is desired. 
+ bool parse_preview = + (dec->events_wanted & (JXL_DEC_PREVIEW_IMAGE | JXL_DEC_FRAME | + JXL_DEC_DC_IMAGE | JXL_DEC_FULL_IMAGE)); + + if (!dec->metadata.m.have_preview) { + // There is no preview, mark this as done and go to next step + dec->got_preview_image = true; + } else if (!parse_preview) { + // No preview parsing needed, mark this step as done + dec->got_preview_image = true; + } else { + // Want to decode the preview, not just skip the frame + bool want_preview = (dec->events_wanted & JXL_DEC_PREVIEW_IMAGE); + size_t frame_size; + size_t pos = dec->frame_start; + dec->frame_header.reset(new FrameHeader(&dec->metadata)); + JxlDecoderStatus status = + ParseFrameHeader(dec, dec->frame_header.get(), in, size, pos, true, + &frame_size, /*dc_size=*/nullptr); + if (status != JXL_DEC_SUCCESS) return status; + if (OutOfBounds(pos, frame_size, size)) { + return JXL_DEC_NEED_MORE_INPUT; + } + + if (want_preview && !dec->preview_out_buffer_set) { + return JXL_DEC_NEED_PREVIEW_OUT_BUFFER; + } + + jxl::Span compressed(in + dec->frame_start, + size - dec->frame_start); + auto reader = GetBitReader(compressed); + jxl::DecompressParams dparams; + dparams.preview = want_preview ? jxl::Override::kOn : jxl::Override::kOff; + jxl::ImageBundle ib(&dec->metadata.m); + PassesDecoderState preview_dec_state; + JXL_API_RETURN_IF_ERROR( + preview_dec_state.output_encoding_info.Set(dec->metadata.m)); + if (!DecodeFrame(dparams, &preview_dec_state, dec->thread_pool.get(), + reader.get(), &ib, dec->metadata, + /*constraints=*/nullptr, + /*is_preview=*/true)) { + return JXL_API_ERROR("decoding preview failed"); + } + + // Set frame_start to the first non-preview frame. + dec->frame_start += DivCeil(reader->TotalBitsConsumed(), kBitsPerByte); + dec->got_preview_image = true; + + if (want_preview) { + if (dec->preview_out_buffer) { + JxlDecoderStatus status = ConvertImageInternal( + dec, ib, dec->preview_out_format, dec->preview_out_buffer, + dec->preview_out_size); + if (status != JXL_DEC_SUCCESS) return status; + } + return JXL_DEC_PREVIEW_IMAGE; + } + } + } + + // Handle frames + for (;;) { + if (!(dec->events_wanted & + (JXL_DEC_FULL_IMAGE | JXL_DEC_DC_IMAGE | JXL_DEC_FRAME))) { + break; + } + if (dec->frame_stage == FrameStage::kHeader && dec->is_last_total) { + break; + } + + if (dec->frame_stage == FrameStage::kHeader) { + size_t pos = dec->frame_start - dec->codestream_pos; + if (pos >= size) { + return JXL_DEC_NEED_MORE_INPUT; + } + dec->frame_header.reset(new FrameHeader(&dec->metadata)); + JxlDecoderStatus status = ParseFrameHeader( + dec, dec->frame_header.get(), in, size, pos, + /*is_preview=*/false, &dec->frame_size, &dec->dc_size); + if (status != JXL_DEC_SUCCESS) return status; + + // is last in entire codestream + dec->is_last_total = dec->frame_header->is_last; + // is last of current still + dec->is_last_of_still = + dec->is_last_total || dec->frame_header->animation_frame.duration > 0; + + dec->frame_stage = FrameStage::kTOC; + + if ((dec->events_wanted & JXL_DEC_FRAME) && dec->is_last_of_still) { + // Only return this for the last of a series of stills: patches frames + // etc... before this one do not contain the correct information such + // as animation timing, ... 
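+        // For example, a composite still stored as two reference-only patch
+        // frames plus a final blending frame triggers JXL_DEC_FRAME only
+        // once, for that final frame.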
+        return JXL_DEC_FRAME;
+      }
+    }
+
+    if (dec->frame_stage == FrameStage::kTOC) {
+      size_t pos = dec->frame_start - dec->codestream_pos;
+      Span<const uint8_t> span(in + pos, size - pos);
+      auto reader = GetBitReader(span);
+
+      if (!dec->passes_state) {
+        dec->passes_state.reset(new jxl::PassesDecoderState());
+      }
+      if (!dec->ib) {
+        dec->ib.reset(new jxl::ImageBundle(&dec->metadata.m));
+      }
+
+      dec->frame_dec.reset(new FrameDecoder(
+          dec->passes_state.get(), dec->metadata, dec->thread_pool.get()));
+
+      if (dec->next_jpeg_reconstruction_out != nullptr &&
+          dec->jpeg_reconstruction_data != nullptr) {
+        // If JPEG reconstruction is wanted and possible, set the jpeg_data of
+        // the ImageBundle.
+        if (!jxl::jpeg::SetJPEGDataFromICC(
+                dec->ib->metadata()->color_encoding.ICC(),
+                dec->jpeg_reconstruction_data.get())) {
+          return JXL_DEC_ERROR;
+        }
+        dec->ib->jpeg_data = std::move(dec->jpeg_reconstruction_data);
+      }
+
+      jxl::Status status = dec->frame_dec->InitFrame(
+          reader.get(), dec->ib.get(), /*is_preview=*/false,
+          /*allow_partial_frames=*/false, /*allow_partial_dc_global=*/false);
+      if (!status) JXL_API_RETURN_IF_ERROR(status);
+
+      if (dec->image_out_format.data_type == JXL_TYPE_UINT8 &&
+          dec->image_out_format.num_channels >= 3) {
+        bool is_rgba = dec->image_out_format.num_channels == 4;
+        dec->frame_dec->MaybeSetRGB8OutputBuffer(
+            reinterpret_cast<uint8_t*>(dec->image_out_buffer),
+            GetStride(dec, dec->image_out_format), is_rgba);
+      }
+      size_t sections_begin =
+          DivCeil(reader->TotalBitsConsumed(), kBitsPerByte);
+
+      dec->sections.reset(
+          new Sections(dec->frame_dec.get(), dec->frame_size, sections_begin));
+      JXL_API_RETURN_IF_ERROR(dec->sections->Init());
+
+      dec->frame_dec_in_progress = true;
+      dec->frame_stage = FrameStage::kDC;
+    }
+
+    if (dec->frame_stage == FrameStage::kDC) {
+      if (!(dec->events_wanted & JXL_DEC_DC_IMAGE)) {
+        dec->frame_stage = FrameStage::kFull;
+      }
+    }
+
+    bool return_full_image = false;
+
+    if (dec->frame_stage == FrameStage::kFull ||
+        dec->frame_stage == FrameStage::kDC) {
+      if (dec->events_wanted & JXL_DEC_FULL_IMAGE) {
+        if (!dec->image_out_buffer_set &&
+            (dec->next_jpeg_reconstruction_out == nullptr ||
+             dec->ib->jpeg_data == nullptr) &&
+            dec->is_last_of_still) {
+          return JXL_DEC_NEED_IMAGE_OUT_BUFFER;
+        }
+      }
+      size_t pos = dec->frame_start - dec->codestream_pos;
+
+      bool get_dc = dec->is_last_of_still &&
+                    (dec->frame_stage == FrameStage::kDC) && dec->dc_size != 0;
+      dec->sections->SetInput(in + pos, size - pos);
+      jxl::Status status = dec->frame_dec->ProcessSections(
+          dec->sections->section_info.data(),
+          dec->sections->section_info.size(),
+          dec->sections->section_status.data());
+      JXL_API_RETURN_IF_ERROR(dec->sections->CloseInput());
+      if (status.IsFatalError()) {
+        return JXL_API_ERROR("decoding frame failed");
+      }
+
+      // TODO(lode): allow next_in to move forward if sections from the
+      // beginning of the stream have been processed
+
+      if (get_dc) {
+        // Not all DC sections have been processed yet
+        if (OutOfBounds(pos, dec->dc_size, size)) {
+          return JXL_DEC_NEED_MORE_INPUT;
+        }
+
+        if (!dec->frame_dec->HasDecodedDC()) {
+          // DC not available, e.g. if the frame was not encoded as VarDCT.
+          get_dc = false;
+        }
+        if (dec->frame_header->custom_size_or_origin ||
+            dec->frame_header->dc_level != 0 ||
+            dec->frame_header->upsampling != 1) {
+          // We don't support JXL_DEC_DC_IMAGE if the frame size doesn't match
+          // the image size. 
+ get_dc = false; + } + if (get_dc) { + dec->frame_stage = FrameStage::kDCOutput; + } else { + dec->frame_stage = FrameStage::kFull; + } + } + + if (!get_dc) { + if (status.code() == StatusCode::kNotEnoughBytes || + dec->sections->section_info.size() < + dec->frame_dec->NumSections()) { + // Not all sections have been processed yet + return JXL_DEC_NEED_MORE_INPUT; + } + + if (!dec->frame_dec->FinalizeFrame()) { + return JXL_API_ERROR("decoding frame failed"); + } + dec->frame_dec_in_progress = false; + + dec->frame_stage = FrameStage::kFullOutput; + } + } + + if (dec->frame_stage == FrameStage::kDCOutput) { + if (!dec->dc_out_buffer_set) { + return JXL_DEC_NEED_DC_OUT_BUFFER; + } + + PassesDecoderState& passes = *dec->passes_state.get(); + PassesSharedState& shared = passes.shared_storage; + Image3F dc(shared.dc_storage.xsize(), shared.dc_storage.ysize()); + UndoXYB(shared.dc_storage, &dc, dec->passes_state->output_encoding_info, + dec->thread_pool.get()); + // TODO(lode): use the real metadata instead, this requires matching + // all the extra channels. Support DC with alpha too. + jxl::ImageMetadata dummy; + ImageBundle dc_bundle(&dummy); + // TODO(lode): check whether LinearSRGB is always the correct color + // space to set here. + dc_bundle.SetFromImage( + std::move(dc), + dec->passes_state->output_encoding_info.color_encoding); + JXL_API_RETURN_IF_ERROR( + ConvertImageInternal(dec, dc_bundle, dec->dc_out_format, + dec->dc_out_buffer, dec->dc_out_size)); + dec->frame_stage = FrameStage::kFull; + return JXL_DEC_DC_IMAGE; + } + + if (dec->frame_stage == FrameStage::kFullOutput) { + if (dec->is_last_of_still) { + if (dec->events_wanted & JXL_DEC_FULL_IMAGE) { + dec->events_wanted &= ~JXL_DEC_FULL_IMAGE; + return_full_image = true; + } + + if (!dec->last_frame_reached) { + dec->events_wanted = + dec->orig_events_wanted & + (JXL_DEC_FULL_IMAGE | JXL_DEC_DC_IMAGE | JXL_DEC_FRAME); + } + + // If no output buffer was set, we merely return the JXL_DEC_FULL_IMAGE + // status without outputting pixels. + if (dec->next_jpeg_reconstruction_out != nullptr && + dec->ib->jpeg_data != nullptr) { + // Copy JPEG bytestream if desired. + uint8_t* tmp_next_out = dec->next_jpeg_reconstruction_out; + size_t tmp_avail_size = dec->avail_jpeg_reconstruction_size; + auto write = [&tmp_next_out, &tmp_avail_size](const uint8_t* buf, + size_t len) { + size_t to_write = std::min(tmp_avail_size, len); + memcpy(tmp_next_out, buf, to_write); + tmp_next_out += to_write; + tmp_avail_size -= to_write; + return to_write; + }; + Status write_result = + jxl::jpeg::WriteJpeg(*(dec->ib->jpeg_data.get()), write); + if (!write_result) { + if (tmp_avail_size == 0) { + return JXL_DEC_JPEG_NEED_MORE_OUTPUT; + } + return JXL_DEC_ERROR; + } + dec->next_jpeg_reconstruction_out = tmp_next_out; + dec->avail_jpeg_reconstruction_size = tmp_avail_size; + } else if (return_full_image && dec->image_out_buffer_set) { + if (!dec->frame_dec->HasRGBBuffer()) { + // Copy pixels if desired. + JxlDecoderStatus status = ConvertImageInternal( + dec, *dec->ib, dec->image_out_format, dec->image_out_buffer, + dec->image_out_size); + if (status != JXL_DEC_SUCCESS) return status; + } + dec->image_out_buffer_set = false; + } + } + } + + // The pixels have been output or are not needed, do not keep them in + // memory here. 
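+    // After this cleanup the loop either parses the next frame header, or,
+    // if is_last_total was set, breaks out at the top and the decoder is
+    // marked DecoderStage::kFinished.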
+    dec->ib.reset();
+    dec->frame_stage = FrameStage::kHeader;
+    dec->frame_start += dec->frame_size;
+    if (return_full_image) {
+      return JXL_DEC_FULL_IMAGE;
+    }
+  }
+
+  dec->stage = DecoderStage::kFinished;
+  // Return success, this means there is nothing more to do.
+  return JXL_DEC_SUCCESS;
+}
+
+}  // namespace
+}  // namespace jxl
+
+JxlDecoderStatus JxlDecoderSetInput(JxlDecoder* dec, const uint8_t* data,
+                                    size_t size) {
+  if (dec->next_in) return JXL_DEC_ERROR;
+
+  dec->next_in = data;
+  dec->avail_in = size;
+  return JXL_DEC_SUCCESS;
+}
+
+size_t JxlDecoderReleaseInput(JxlDecoder* dec) {
+  size_t result = dec->avail_in;
+  dec->next_in = nullptr;
+  dec->avail_in = 0;
+  return result;
+}
+
+JxlDecoderStatus JxlDecoderSetJPEGBuffer(JxlDecoder* dec, uint8_t* data,
+                                         size_t size) {
+  if (dec->next_jpeg_reconstruction_out) return JXL_DEC_ERROR;
+
+  dec->next_jpeg_reconstruction_out = data;
+  dec->avail_jpeg_reconstruction_size = size;
+  return JXL_DEC_SUCCESS;
+}
+
+size_t JxlDecoderReleaseJPEGBuffer(JxlDecoder* dec) {
+  size_t result = dec->avail_jpeg_reconstruction_size;
+  dec->next_jpeg_reconstruction_out = nullptr;
+  dec->avail_jpeg_reconstruction_size = 0;
+  return result;
+}
+
+// Consumes data from next_in/avail_in to reconstruct JPEG data.
+// Uses dec->jpeg_reconstruction_size, dec->inside_jpeg_reconstruction_box,
+// and dec->jpeg_reconstruction_box_until_eof to calculate how much to
+// consume. Potentially stores unparsed data in
+// dec->jpeg_reconstruction_buffer. Potentially populates
+// dec->jpeg_reconstruction_data. Potentially updates
+// dec->inside_jpeg_reconstruction_box.
+JxlDecoderStatus JxlDecoderProcessJPEGReconstruction(JxlDecoder* dec,
+                                                     const uint8_t** next_in,
+                                                     size_t* avail_in) {
+  if (!dec->inside_jpeg_reconstruction_box) {
+    JXL_ABORT(
+        "processing of JPEG reconstruction data outside JPEG reconstruction "
+        "box");
+  }
+  jxl::Span<const uint8_t> to_decode;
+  if (dec->jpeg_reconstruction_box_until_eof) {
+    // Until EOF means consume all data.
+    to_decode = jxl::Span<const uint8_t>(*next_in, *avail_in);
+    *next_in += *avail_in;
+    *avail_in = 0;
+  } else {
+    // Defined size means consume min(available, needed).
+    size_t avail_recon_in =
+        std::min(*avail_in, dec->jpeg_reconstruction_size -
+                                dec->jpeg_reconstruction_buffer.size());
+    to_decode = jxl::Span<const uint8_t>(*next_in, avail_recon_in);
+    *next_in += avail_recon_in;
+    *avail_in -= avail_recon_in;
+  }
+  bool old_data_exists = !dec->jpeg_reconstruction_buffer.empty();
+  if (old_data_exists) {
+    // Append incoming data to buffer if we already had data in the buffer.
+    dec->jpeg_reconstruction_buffer.insert(
+        dec->jpeg_reconstruction_buffer.end(), to_decode.data(),
+        to_decode.data() + to_decode.size());
+    to_decode =
+        jxl::Span<const uint8_t>(dec->jpeg_reconstruction_buffer.data(),
+                                 dec->jpeg_reconstruction_buffer.size());
+  }
+  if (!dec->jpeg_reconstruction_box_until_eof &&
+      to_decode.size() > dec->jpeg_reconstruction_size) {
+    JXL_ABORT("JPEG reconstruction data to decode larger than expected");
+  }
+  if (dec->jpeg_reconstruction_box_until_eof ||
+      to_decode.size() == dec->jpeg_reconstruction_size) {
+    // If undefined size, or the right size, try to decode.
+    dec->jpeg_reconstruction_data = jxl::make_unique<jxl::jpeg::JPEGData>();
+    const auto status = jxl::jpeg::DecodeJPEGData(
+        to_decode, dec->jpeg_reconstruction_data.get());
+    if (status.IsFatalError()) return JXL_DEC_ERROR;
+    if (status) {
+      // Successful decoding, emit event after updating state to track that
+      // we are no longer parsing JPEG reconstruction data. 
+ dec->inside_jpeg_reconstruction_box = false; + return JXL_DEC_JPEG_RECONSTRUCTION; + } + if (dec->jpeg_reconstruction_box_until_eof) { + // Unsuccessful decoding and undefined size, assume incomplete data. Copy + // the data if we haven't already. + if (!old_data_exists) { + dec->jpeg_reconstruction_buffer.insert( + dec->jpeg_reconstruction_buffer.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + } + } else { + // Unsuccessful decoding of correct amount of data, assume error. + return JXL_DEC_ERROR; + } + } else { + // Not enough data, copy the data if we haven't already. + if (!old_data_exists) { + dec->jpeg_reconstruction_buffer.insert( + dec->jpeg_reconstruction_buffer.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + } + } + return JXL_DEC_NEED_MORE_INPUT; +} + +JxlDecoderStatus JxlDecoderProcessInput(JxlDecoder* dec) { + const uint8_t** next_in = &dec->next_in; + size_t* avail_in = &dec->avail_in; + if (dec->stage == DecoderStage::kInited) { + dec->stage = DecoderStage::kStarted; + } + if (dec->stage == DecoderStage::kError) { + return JXL_API_ERROR( + "Cannot keep using decoder after it encountered an error, use " + "JxlDecoderReset to reset it"); + } + if (dec->stage == DecoderStage::kFinished) { + return JXL_API_ERROR( + "Cannot keep using decoder after it finished, use JxlDecoderReset to " + "reset it"); + } + + if (!dec->got_signature) { + JxlSignature sig = JxlSignatureCheck(*next_in, *avail_in); + if (sig == JXL_SIG_INVALID) return JXL_API_ERROR("invalid signature"); + if (sig == JXL_SIG_NOT_ENOUGH_BYTES) return JXL_DEC_NEED_MORE_INPUT; + + dec->got_signature = true; + + if (sig == JXL_SIG_CONTAINER) { + dec->have_container = 1; + } + } + + // Available codestream bytes, may differ from *avail_in if there is another + // box behind the current position, in the dec->have_container case. + size_t csize = *avail_in; + + if (dec->have_container) { + /* + Process bytes as follows: + *) find the box(es) containing the codestream + *) support codestream split over multiple partial boxes + *) avoid copying bytes to the codestream vector if the decoding will be + one-shot, when the user already provided everything contiguously in + memory + *) copy to codestream vector, and update next_in so user can delete the data + on their side, once we know it's not oneshot. This relieves the user from + continuing to store the data. + *) also copy to codestream if one-shot but the codestream is split across + multiple boxes: this copying can be avoided in the future if the C++ + decoder is updated for streaming, but for now it requires all consecutive + data at once. + */ + + if (dec->first_codestream_seen && !dec->last_codestream_seen && + dec->codestream_end != 0 && dec->file_pos < dec->codestream_end && + dec->file_pos + *avail_in >= dec->codestream_end && + !dec->codestream.empty()) { + // dec->file_pos in a codestream, not in surrounding box format bytes, but + // the end of the current codestream part is in the current input, and + // boxes that can contain a next part of the codestream could be present. + // Therefore, store the known codestream part, and ensure processing of + // boxes below will trigger. This is only done if + // !dec->codestream.empty(), that is, we're already streaming. + + // Size of the codestream, excluding potential boxes that come after it. 
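+      // Worked example (hypothetical numbers): with file_pos = 100,
+      // codestream_end = 150 and *avail_in = 80, only the first 50 bytes of
+      // the input are codestream; the remaining 30 belong to the next box.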
+ csize = *avail_in; + if (dec->codestream_end && csize > dec->codestream_end - dec->file_pos) { + csize = dec->codestream_end - dec->file_pos; + } + dec->codestream.insert(dec->codestream.end(), *next_in, *next_in + csize); + dec->file_pos += csize; + *next_in += csize; + *avail_in -= csize; + } + + if (dec->inside_jpeg_reconstruction_box) { + // We are inside a JPEG reconstruction box. + JxlDecoderStatus recon_result = + JxlDecoderProcessJPEGReconstruction(dec, next_in, avail_in); + if (recon_result == JXL_DEC_JPEG_RECONSTRUCTION) { + // If successful JPEG reconstruction, return the success if the user + // cares about it, otherwise continue. + if (dec->events_wanted & recon_result) { + dec->events_wanted &= ~recon_result; + return recon_result; + } + } else { + // If anything else, return the result. + return recon_result; + } + } + + if (!dec->last_codestream_seen && + (dec->codestream_begin == 0 || + (dec->codestream_end != 0 && dec->file_pos >= dec->codestream_end))) { + size_t pos = 0; + // after this for loop, either we should be in a part of the data that is + // codestream (not boxes), or have returned that we need more input. + for (;;) { + const uint8_t* in = *next_in; + size_t size = *avail_in; + if (size == pos) { + // If the remaining size is 0, we are exactly after a full box. We + // can't know for sure if this is the last box or not since more bytes + // can follow, but do not return NEED_MORE_INPUT, instead break and + // let the codestream-handling code determine if we need more. + break; + } + if (OutOfBounds(pos, 8, size)) { + dec->basic_info_size_hint = + InitialBasicInfoSizeHint() + pos + 8 - dec->file_pos; + return JXL_DEC_NEED_MORE_INPUT; + } + size_t box_start = pos; + uint64_t box_size = LoadBE32(in + pos); + char type[5] = {0}; + memcpy(type, in + pos + 4, 4); + pos += 8; + if (box_size == 1) { + if (OutOfBounds(pos, 8, size)) return JXL_DEC_NEED_MORE_INPUT; + box_size = LoadBE64(in + pos); + pos += 8; + } + size_t header_size = pos - box_start; + if (box_size > 0 && box_size < header_size) { + return JXL_API_ERROR("invalid box size"); + } + if (SumOverflows(dec->file_pos, pos, box_size)) { + return JXL_API_ERROR("Box size overflow"); + } + size_t contents_size = + (box_size == 0) ? 0 : (box_size - pos + box_start); + if (strcmp(type, "jxlc") == 0 || strcmp(type, "jxlp") == 0) { + size_t codestream_size = contents_size; + // Whether this is the last codestream box, either when it is a jxlc + // box, or when it is a jxlp box that has the final bit set. + // The codestream is either contained within a single jxlc box, or + // within one or more jxlp boxes. The final jxlp box is marked as last + // by setting the high bit of its 4-byte box-index value. + bool last_codestream = false; + if (strcmp(type, "jxlp") == 0) { + if (OutOfBounds(pos, 4, size)) return JXL_DEC_NEED_MORE_INPUT; + if (box_size != 0 && contents_size < 4) { + return JXL_API_ERROR("jxlp box too small to contain index"); + } + codestream_size -= 4; + size_t jxlp_index = LoadBE32(in + pos); + pos += 4; + // The high bit of jxlp_index indicates whether this is the last + // jxlp box. 
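+            // For example, indices 0x00000000, 0x00000001, 0x80000002 would
+            // describe a codestream split over three jxlp boxes, with the
+            // third marked final.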
+ if (jxlp_index & 0x80000000) last_codestream = true; + } else if (strcmp(type, "jxlc") == 0) { + last_codestream = true; + } + if (!last_codestream && box_size == 0) { + return JXL_API_ERROR( + "final box has unbounded size, but is a non-final codestream " + "box"); + } + dec->first_codestream_seen = true; + if (last_codestream) dec->last_codestream_seen = true; + if (dec->codestream_begin != 0 && dec->codestream.empty()) { + // We've already seen a codestream part, so it's a stream spanning + // multiple boxes. + // We have no choice but to copy contents to the codestream + // vector to make it a contiguous stream for the C++ decoder. + // This appends the previous codestream box that we had seen to + // dec->codestream. + if (dec->codestream_begin < dec->file_pos) { + return JXL_API_ERROR("earlier codestream box out of range"); + } + size_t begin = dec->codestream_begin - dec->file_pos; + size_t end = dec->codestream_end - dec->file_pos; + JXL_ASSERT(end <= *avail_in); + dec->codestream.insert(dec->codestream.end(), *next_in + begin, + *next_in + end); + } + dec->codestream_begin = dec->file_pos + pos; + dec->codestream_end = + (box_size == 0) ? 0 : (dec->codestream_begin + codestream_size); + size_t avail_codestream_size = + (box_size == 0) + ? (size - pos) + : std::min(size - pos, box_size - pos + box_start); + // If already appending codestream, append what we have here too + if (!dec->codestream.empty()) { + size_t begin = pos; + size_t end = + std::min(*avail_in, begin + avail_codestream_size); + dec->codestream.insert(dec->codestream.end(), *next_in + begin, + *next_in + end); + pos += (end - begin); + dec->file_pos += pos; + *next_in += pos; + *avail_in -= pos; + pos = 0; + // TODO(lode): check if this should break always instead, and + // process what we have of the codestream so far, to support + // progressive decoding, and get events such as basic info faster. + // The user could have given 1.5 boxes here, and the first one could + // contain useful parts of codestream that can already be processed. + // Similar to several other exact avail_size checks. This may not + // need to be changed here, but instead at the point in this for + // loop where it returns "NEED_MORE_INPUT", it could instead break + // and allow decoding what we have of the codestream so far. + if (*avail_in == 0) break; + } else { + // skip only the header, so next_in points to the start of this new + // codestream part, for the one-shot case where user data is not + // (yet) copied to dec->codestream. + dec->file_pos += pos; + *next_in += pos; + *avail_in -= pos; + pos = 0; + // Update pos to be after the box contents with codestream + if (avail_codestream_size == *avail_in) { + break; // the rest is codestream, this loop is done + } + pos += avail_codestream_size; + } + } else if (strcmp(type, "jbrd") == 0) { + // This is a JPEG reconstruction metadata box. + // A new box implies that we clear the buffer. + dec->jpeg_reconstruction_buffer.clear(); + dec->inside_jpeg_reconstruction_box = true; + if (box_size == 0) { + dec->jpeg_reconstruction_box_until_eof = true; + } else { + dec->jpeg_reconstruction_size = contents_size; + } + dec->file_pos += pos; + *next_in += pos; + *avail_in -= pos; + JxlDecoderStatus recon_result = + JxlDecoderProcessJPEGReconstruction(dec, next_in, avail_in); + pos = 0; + if (recon_result == JXL_DEC_JPEG_RECONSTRUCTION) { + // If successful JPEG reconstruction, return the success if the user + // cares about it, otherwise continue. 
+            if (dec->events_wanted & recon_result) {
+              dec->events_wanted &= ~recon_result;
+              return recon_result;
+            }
+          } else {
+            // If anything else, return the result.
+            return recon_result;
+          }
+        } else {
+          if (box_size == 0) {
+            // Final box with unknown size, but it's not a codestream box, so
+            // nothing more to do.
+            if (!dec->first_codestream_seen) {
+              return JXL_API_ERROR("didn't find any codestream box");
+            }
+            break;
+          }
+          if (OutOfBounds(pos, contents_size, size)) {
+            // Indicate how many more bytes needed starting from *next_in.
+            dec->basic_info_size_hint =
+                InitialBasicInfoSizeHint() + pos + contents_size -
+                dec->file_pos;
+            return JXL_DEC_NEED_MORE_INPUT;
+          }
+          pos += contents_size;
+          if (!(dec->codestream.empty() && dec->first_codestream_seen)) {
+            // Last box no longer needed since we have copied the codestream
+            // buffer, remove from input so user can release memory.
+            dec->file_pos += pos;
+            *next_in += pos;
+            *avail_in -= pos;
+            pos = 0;
+          }
+        }
+      }
+    }
+
+    // Size of the codestream, excluding potential boxes that come after it.
+    csize = *avail_in;
+    if (dec->codestream_end && csize > dec->codestream_end - dec->file_pos) {
+      csize = dec->codestream_end - dec->file_pos;
+    }
+  }
+
+  // Whether we are taking the input directly from the user (oneshot case,
+  // without copying bytes), or appending parts of input to dec->codestream
+  // (streaming)
+  bool detected_streaming = !dec->codestream.empty();
+  JxlDecoderStatus result;
+  JXL_DASSERT(csize <= *avail_in);
+
+  if (detected_streaming) {
+    dec->codestream.insert(dec->codestream.end(), *next_in, *next_in + csize);
+    dec->file_pos += csize;
+    *next_in += csize;
+    *avail_in -= csize;
+    result = jxl::JxlDecoderProcessInternal(dec, dec->codestream.data(),
+                                            dec->codestream.size());
+  } else {
+    // No data copied to codestream buffer yet, the user input may contain the
+    // full codestream.
+    result = jxl::JxlDecoderProcessInternal(dec, *next_in, csize);
+    // Copy the user's input bytes to the codestream once we are able to and
+    // it is needed. Before we got the basic info, we're still parsing the box
+    // format instead. If the result is not JXL_DEC_NEED_MORE_INPUT, then
+    // there is no reason yet to copy since the user may have a full buffer
+    // allowing one-shot. Once JXL_DEC_NEED_MORE_INPUT occurred at least once,
+    // start copying over the codestream bytes and allow the user to free them
+    // instead. On the next call, detected_streaming will be true.
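+    // In other words, the first JXL_DEC_NEED_MORE_INPUT seen after the basic
+    // info is available permanently switches this decoder instance from the
+    // zero-copy one-shot path to the buffered streaming path; a non-empty
+    // dec->codestream is what records that transition.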
+    if (dec->got_basic_info && result == JXL_DEC_NEED_MORE_INPUT) {
+      dec->codestream.insert(dec->codestream.end(), *next_in,
+                             *next_in + csize);
+      dec->file_pos += csize;
+      *next_in += csize;
+      *avail_in -= csize;
+    }
+  }
+
+  return result;
+}
+
+JxlDecoderStatus JxlDecoderGetBasicInfo(const JxlDecoder* dec,
+                                        JxlBasicInfo* info) {
+  if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT;
+
+  if (info) {
+    const jxl::ImageMetadata& meta = dec->metadata.m;
+
+    info->have_container = dec->have_container;
+    info->xsize = dec->metadata.size.xsize();
+    info->ysize = dec->metadata.size.ysize();
+    info->uses_original_profile = !meta.xyb_encoded;
+
+    info->bits_per_sample = meta.bit_depth.bits_per_sample;
+    info->exponent_bits_per_sample = meta.bit_depth.exponent_bits_per_sample;
+
+    info->have_preview = meta.have_preview;
+    info->have_animation = meta.have_animation;
+    // TODO(janwas): intrinsic_size
+    info->orientation = static_cast<JxlOrientation>(meta.orientation);
+
+    if (!dec->keep_orientation) {
+      if (info->orientation >= JXL_ORIENT_TRANSPOSE) {
+        std::swap(info->xsize, info->ysize);
+      }
+      info->orientation = JXL_ORIENT_IDENTITY;
+    }
+
+    info->intensity_target = meta.IntensityTarget();
+    info->min_nits = meta.tone_mapping.min_nits;
+    info->relative_to_max_display = meta.tone_mapping.relative_to_max_display;
+    info->linear_below = meta.tone_mapping.linear_below;
+
+    const jxl::ExtraChannelInfo* alpha = meta.Find(jxl::ExtraChannel::kAlpha);
+    if (alpha != nullptr) {
+      info->alpha_bits = alpha->bit_depth.bits_per_sample;
+      info->alpha_exponent_bits = alpha->bit_depth.exponent_bits_per_sample;
+      info->alpha_premultiplied = alpha->alpha_associated;
+    } else {
+      info->alpha_bits = 0;
+      info->alpha_exponent_bits = 0;
+      info->alpha_premultiplied = 0;
+    }
+
+    info->num_color_channels =
+        meta.color_encoding.GetColorSpace() == jxl::ColorSpace::kGray ? 1 : 3;
+
+    info->num_extra_channels = meta.num_extra_channels;
+
+    if (info->have_preview) {
+      info->preview.xsize = dec->metadata.m.preview_size.xsize();
+      info->preview.ysize = dec->metadata.m.preview_size.ysize();
+    }
+
+    if (info->have_animation) {
+      info->animation.tps_numerator = dec->metadata.m.animation.tps_numerator;
+      info->animation.tps_denominator =
+          dec->metadata.m.animation.tps_denominator;
+      info->animation.num_loops = dec->metadata.m.animation.num_loops;
+      info->animation.have_timecodes =
+          dec->metadata.m.animation.have_timecodes;
+    }
+  }
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderGetExtraChannelInfo(const JxlDecoder* dec,
+                                               size_t index,
+                                               JxlExtraChannelInfo* info) {
+  if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT;
+
+  const std::vector<jxl::ExtraChannelInfo>& channels =
+      dec->metadata.m.extra_channel_info;
+
+  if (index >= channels.size()) return JXL_DEC_ERROR;  // out of bounds
+  const jxl::ExtraChannelInfo& channel = channels[index];
+
+  info->type = static_cast<JxlExtraChannelType>(channel.type);
+  info->bits_per_sample = channel.bit_depth.bits_per_sample;
+  info->exponent_bits_per_sample =
+      channel.bit_depth.floating_point_sample
+          ? channel.bit_depth.exponent_bits_per_sample
+          : 0;
+  info->dim_shift = channel.dim_shift;
+  info->name_length = channel.name.size();
+  info->alpha_associated = channel.alpha_associated;
+  info->spot_color[0] = channel.spot_color[0];
+  info->spot_color[1] = channel.spot_color[1];
+  info->spot_color[2] = channel.spot_color[2];
+  info->spot_color[3] = channel.spot_color[3];
+  info->cfa_channel = channel.cfa_channel;
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderGetExtraChannelName(const JxlDecoder* dec,
+                                               size_t index, char* name,
+                                               size_t size) {
+  if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT;
+
+  const std::vector<jxl::ExtraChannelInfo>& channels =
+      dec->metadata.m.extra_channel_info;
+
+  if (index >= channels.size()) return JXL_DEC_ERROR;  // out of bounds
+  const jxl::ExtraChannelInfo& channel = channels[index];
+
+  // Also need null-termination character
+  if (channel.name.size() + 1 > size) return JXL_DEC_ERROR;
+
+  memcpy(name, channel.name.c_str(), channel.name.size() + 1);
+
+  return JXL_DEC_SUCCESS;
+}
+
+namespace {
+
+// Gets the jxl::ColorEncoding for the desired target, and checks errors.
+// Returns the object regardless of whether the actual color space is in ICC,
+// but ensures that if the color encoding is not the encoding from the
+// codestream header metadata, it cannot require ICC profile.
+JxlDecoderStatus GetColorEncodingForTarget(
+    const JxlDecoder* dec, const JxlPixelFormat* format,
+    JxlColorProfileTarget target, const jxl::ColorEncoding** encoding) {
+  if (!dec->got_all_headers) return JXL_DEC_NEED_MORE_INPUT;
+
+  *encoding = nullptr;
+  if (target == JXL_COLOR_PROFILE_TARGET_DATA && dec->metadata.m.xyb_encoded) {
+    *encoding = &dec->passes_state->output_encoding_info.color_encoding;
+  } else {
+    *encoding = &dec->metadata.m.color_encoding;
+  }
+  return JXL_DEC_SUCCESS;
+}
+}  // namespace
+
+JxlDecoderStatus JxlDecoderGetColorAsEncodedProfile(
+    const JxlDecoder* dec, const JxlPixelFormat* format,
+    JxlColorProfileTarget target, JxlColorEncoding* color_encoding) {
+  const jxl::ColorEncoding* jxl_color_encoding = nullptr;
+  JxlDecoderStatus status =
+      GetColorEncodingForTarget(dec, format, target, &jxl_color_encoding);
+  if (status) return status;
+
+  if (jxl_color_encoding->WantICC())
+    return JXL_DEC_ERROR;  // Indicate no encoded profile available.
+
+  if (color_encoding) {
+    ConvertInternalToExternalColorEncoding(*jxl_color_encoding,
+                                           color_encoding);
+  }
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderGetICCProfileSize(const JxlDecoder* dec,
+                                             const JxlPixelFormat* format,
+                                             JxlColorProfileTarget target,
+                                             size_t* size) {
+  const jxl::ColorEncoding* jxl_color_encoding = nullptr;
+  JxlDecoderStatus status =
+      GetColorEncodingForTarget(dec, format, target, &jxl_color_encoding);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  if (jxl_color_encoding->WantICC()) {
+    jxl::ColorSpace color_space =
+        dec->metadata.m.color_encoding.GetColorSpace();
+    if (color_space == jxl::ColorSpace::kUnknown ||
+        color_space == jxl::ColorSpace::kXYB) {
+      // This indicates there's no ICC profile available
+      // TODO(lode): for the XYB case, do we want to craft an ICC profile that
+      // represents XYB as an RGB profile? It may be possible, but not with
+      // only 1D transfer functions.
+      return JXL_DEC_ERROR;
+    }
+  }
+
+  if (size) {
+    *size = jxl_color_encoding->ICC().size();
+  }
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderGetColorAsICCProfile(const JxlDecoder* dec,
+                                                const JxlPixelFormat* format,
+                                                JxlColorProfileTarget target,
+                                                uint8_t* icc_profile,
+                                                size_t size) {
+  size_t wanted_size;
+  // This also checks the NEED_MORE_INPUT and the unknown/xyb cases
+  JxlDecoderStatus status =
+      JxlDecoderGetICCProfileSize(dec, format, target, &wanted_size);
+  if (status != JXL_DEC_SUCCESS) return status;
+  if (size < wanted_size) return JXL_API_ERROR("ICC profile output too small");
+
+  const jxl::ColorEncoding* jxl_color_encoding = nullptr;
+  status = GetColorEncodingForTarget(dec, format, target, &jxl_color_encoding);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  memcpy(icc_profile, jxl_color_encoding->ICC().data(),
+         jxl_color_encoding->ICC().size());
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderGetInverseOpsinMatrix(
+    const JxlDecoder* dec, JxlInverseOpsinMatrix* matrix) {
+  memcpy(matrix->opsin_inv_matrix,
+         dec->metadata.transform_data.opsin_inverse_matrix.inverse_matrix,
+         sizeof(matrix->opsin_inv_matrix));
+  memcpy(matrix->opsin_biases,
+         dec->metadata.transform_data.opsin_inverse_matrix.opsin_biases,
+         sizeof(matrix->opsin_biases));
+  memcpy(matrix->quant_biases,
+         dec->metadata.transform_data.opsin_inverse_matrix.quant_biases,
+         sizeof(matrix->quant_biases));
+
+  return JXL_DEC_SUCCESS;
+}
+
+namespace {
+// Returns the number of bits per channel needed for computing a memory buffer
+// size, and does all error checking required for size checking and format
+// validity.
+JxlDecoderStatus PrepareSizeCheck(const JxlDecoder* dec,
+                                  const JxlPixelFormat* format, size_t* bits) {
+  if (!dec->got_basic_info) {
+    // Don't know image dimensions yet, cannot check for valid size.
+    return JXL_DEC_NEED_MORE_INPUT;
+  }
+  if (format->num_channels > 4) {
+    return JXL_API_ERROR("More than 4 channels not supported");
+  }
+  if (format->num_channels < 3 && !dec->metadata.m.color_encoding.IsGray()) {
+    return JXL_API_ERROR("Grayscale output not possible for color image");
+  }
+  if (format->data_type == JXL_TYPE_BOOLEAN) {
+    return JXL_API_ERROR("Boolean data type not yet supported");
+  }
+  if (format->data_type == JXL_TYPE_UINT32) {
+    return JXL_API_ERROR("uint32 data type not yet supported");
+  }
+
+  *bits = BitsPerChannel(format->data_type);
+
+  if (*bits == 0) {
+    return JXL_API_ERROR("Invalid data type");
+  }
+
+  return JXL_DEC_SUCCESS;
+}
+}  // namespace
+
+JxlDecoderStatus JxlDecoderFlushImage(JxlDecoder* dec) {
+  if (!dec->image_out_buffer) return JXL_DEC_ERROR;
+  if (!dec->sections || dec->sections->section_info.empty()) {
+    return JXL_DEC_ERROR;
+  }
+  if (!dec->frame_dec || !dec->frame_dec_in_progress) {
+    return JXL_DEC_ERROR;
+  }
+  if (!dec->frame_dec->HasDecodedDC()) {
+    // FrameDecoder::Flush currently requires DC to have been decoded already
+    // to work correctly.
+    return JXL_DEC_ERROR;
+  }
+  if (dec->frame_header->encoding != jxl::FrameEncoding::kVarDCT) {
+    // Flushing does not yet work correctly if the frame uses modular encoding.
+    return JXL_DEC_ERROR;
+  }
+  if (dec->metadata.m.num_extra_channels > 0) {
+    // Flushing does not yet work correctly if there are extra channels, which
+    // use modular encoding.
+    return JXL_DEC_ERROR;
+  }
+
+  if (!dec->frame_dec->Flush()) {
+    return JXL_DEC_ERROR;
+  }
+
+  if (dec->frame_dec->HasRGBBuffer()) {
+    return JXL_DEC_SUCCESS;
+  }
+
+  JxlDecoderStatus status =
+      jxl::ConvertImageInternal(dec, *dec->ib, dec->image_out_format,
+                                dec->image_out_buffer, dec->image_out_size);
+  if (status != JXL_DEC_SUCCESS) return status;
+  return JXL_DEC_SUCCESS;
+}
+
+JXL_EXPORT JxlDecoderStatus JxlDecoderPreviewOutBufferSize(
+    const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) {
+  size_t bits;
+  JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  const auto& metadata = dec->metadata.m;
+  size_t xsize = metadata.preview_size.xsize();
+  size_t ysize = metadata.preview_size.ysize();
+
+  size_t row_size =
+      jxl::DivCeil(xsize * format->num_channels * bits, jxl::kBitsPerByte);
+  if (format->align > 1) {
+    row_size = jxl::DivCeil(row_size, format->align) * format->align;
+  }
+  *size = row_size * ysize;
+  return JXL_DEC_SUCCESS;
+}
+
+JXL_EXPORT JxlDecoderStatus JxlDecoderSetPreviewOutBuffer(
+    JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size) {
+  if (!dec->got_basic_info || !dec->metadata.m.have_preview ||
+      !(dec->orig_events_wanted & JXL_DEC_PREVIEW_IMAGE)) {
+    return JXL_API_ERROR("No preview out buffer needed at this time");
+  }
+
+  size_t min_size;
+  // This also checks whether the format is valid and supported and basic info
+  // is available.
+  JxlDecoderStatus status =
+      JxlDecoderPreviewOutBufferSize(dec, format, &min_size);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  if (size < min_size) return JXL_DEC_ERROR;
+
+  dec->preview_out_buffer_set = true;
+  dec->preview_out_buffer = buffer;
+  dec->preview_out_size = size;
+  dec->preview_out_format = *format;
+
+  return JXL_DEC_SUCCESS;
+}
+
+JXL_EXPORT JxlDecoderStatus JxlDecoderDCOutBufferSize(
+    const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) {
+  size_t bits;
+  JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  size_t xsize = jxl::DivCeil(dec->metadata.size.xsize(), jxl::kBlockDim);
+  size_t ysize = jxl::DivCeil(dec->metadata.size.ysize(), jxl::kBlockDim);
+
+  size_t row_size =
+      jxl::DivCeil(xsize * format->num_channels * bits, jxl::kBitsPerByte);
+  if (format->align > 1) {
+    row_size = jxl::DivCeil(row_size, format->align) * format->align;
+  }
+  *size = row_size * ysize;
+  return JXL_DEC_SUCCESS;
+}
+
+JXL_EXPORT JxlDecoderStatus JxlDecoderSetDCOutBuffer(
+    JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size) {
+  if (!dec->got_basic_info || !(dec->orig_events_wanted & JXL_DEC_DC_IMAGE)) {
+    return JXL_API_ERROR("No dc out buffer needed at this time");
+  }
+  size_t min_size;
+  // This also checks whether the format is valid and supported and basic info
+  // is available.
+  JxlDecoderStatus status = JxlDecoderDCOutBufferSize(dec, format, &min_size);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  if (size < min_size) return JXL_DEC_ERROR;
+
+  dec->dc_out_buffer_set = true;
+  dec->dc_out_buffer = buffer;
+  dec->dc_out_size = size;
+  dec->dc_out_format = *format;
+
+  return JXL_DEC_SUCCESS;
+}
+
+JXL_EXPORT JxlDecoderStatus JxlDecoderImageOutBufferSize(
+    const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) {
+  size_t bits;
+  JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  size_t row_size =
+      jxl::DivCeil(dec->metadata.size.xsize() * format->num_channels * bits,
+                   jxl::kBitsPerByte);
+  if (format->align > 1) {
+    row_size = jxl::DivCeil(row_size, format->align) * format->align;
+  }
+  *size = row_size * dec->metadata.size.ysize();
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderSetImageOutBuffer(JxlDecoder* dec,
+                                             const JxlPixelFormat* format,
+                                             void* buffer, size_t size) {
+  if (!dec->got_basic_info || !(dec->orig_events_wanted & JXL_DEC_FULL_IMAGE)) {
+    return JXL_API_ERROR("No image out buffer needed at this time");
+  }
+  size_t min_size;
+  // This also checks whether the format is valid and supported and basic info
+  // is available.
+  JxlDecoderStatus status =
+      JxlDecoderImageOutBufferSize(dec, format, &min_size);
+  if (status != JXL_DEC_SUCCESS) return status;
+
+  if (size < min_size) return JXL_DEC_ERROR;
+
+  dec->image_out_buffer_set = true;
+  dec->image_out_buffer = buffer;
+  dec->image_out_size = size;
+  dec->image_out_format = *format;
+
+  if (format->data_type == JXL_TYPE_UINT8 && format->num_channels >= 3 &&
+      dec->frame_dec_in_progress) {
+    bool is_rgba = format->num_channels == 4;
+    dec->frame_dec->MaybeSetRGB8OutputBuffer(
+        reinterpret_cast<uint8_t*>(buffer), jxl::GetStride(dec, *format),
+        is_rgba);
+  }
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderGetFrameHeader(const JxlDecoder* dec,
+                                          JxlFrameHeader* header) {
+  if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) {
+    return JXL_API_ERROR("no frame header available");
+  }
+  const auto& metadata = dec->metadata.m;
+  if (metadata.have_animation) {
+    header->duration = dec->frame_header->animation_frame.duration;
+    if (metadata.animation.have_timecodes) {
+      header->timecode = dec->frame_header->animation_frame.timecode;
+    }
+  }
+  header->name_length = dec->frame_header->name.size();
+
+  return JXL_DEC_SUCCESS;
+}
+
+JxlDecoderStatus JxlDecoderGetFrameName(const JxlDecoder* dec, char* name,
+                                        size_t size) {
+  if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) {
+    return JXL_API_ERROR("no frame header available");
+  }
+  if (size < dec->frame_header->name.size() + 1) {
+    return JXL_API_ERROR("too small frame name output buffer");
+  }
+  memcpy(name, dec->frame_header->name.c_str(),
+         dec->frame_header->name.size() + 1);
+
+  return JXL_DEC_SUCCESS;
+}
+
+#if JXL_IS_DEBUG_BUILD || defined(JXL_ENABLE_FUZZERS)
+void SetDecoderMemoryLimitBase_(size_t memory_limit_base) {
+  memory_limit_base_ = memory_limit_base;
+}
+#endif
diff --git a/third_party/jpeg-xl/lib/jxl/decode_test.cc b/third_party/jpeg-xl/lib/jxl/decode_test.cc
new file mode 100644
index 000000000000..e8ad58d4205b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/decode_test.cc
@@ -0,0 +1,2546 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "jxl/decode.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "jxl/decode_cxx.h"
+#include "jxl/thread_parallel_runner.h"
+#include "lib/extras/codec.h"
+#include "lib/extras/codec_jpg.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/file_io.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_file.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_external_image.h"
+#include "lib/jxl/enc_file.h"
+#include "lib/jxl/enc_gamma_correct.h"
+#include "lib/jxl/enc_icc_codec.h"
+#include "lib/jxl/encode_internal.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/icc_codec.h"
+#include "lib/jxl/jpeg/enc_jpeg_data.h"
+#include "lib/jxl/test_utils.h"
+#include "lib/jxl/testdata.h"
+#include "tools/box/box.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace {
+void AppendU32BE(uint32_t u32, jxl::PaddedBytes* bytes) {
+  bytes->push_back(u32 >> 24);
+  bytes->push_back(u32 >> 16);
+  bytes->push_back(u32 >> 8);
+  bytes->push_back(u32 >> 0);
+}
+
+// What type of codestream format in the boxes to use for testing
+enum CodeStreamBoxFormat {
+  // Do not use box format at all, only pure codestream
+  kCSBF_None,
+  // Have a single codestream box, with its actual size given in the box
+  kCSBF_Single,
+  // Have a single codestream box, with box size 0 (final box running to end)
+  kCSBF_Single_Zero_Terminated,
+  // Single codestream box, with another unknown box behind it
+  kCSBF_Single_other,
+  // Have multiple partial codestream boxes
+  kCSBF_Multi,
+  // Have multiple partial codestream boxes, with final box size 0 (running
+  // to end)
+  kCSBF_Multi_Zero_Terminated,
+  // Have multiple partial codestream boxes, terminated by non-codestream box
+  kCSBF_Multi_Other_Terminated,
+  // Have multiple partial codestream boxes, terminated by non-codestream box
+  // that has its size set to 0 (running to end)
+  kCSBF_Multi_Other_Zero_Terminated,
+  // Have multiple partial codestream boxes, and the first one has a content
+  // of zero length
+  kCSBF_Multi_First_Empty,
+  // Not a value but used for counting the number of enum entries
+  kCSBF_NUM_ENTRIES,
+};
+
+// Returns an ICC profile output by the JPEG XL decoder for RGB_D65_SRG_Rel_Lin,
+// but with, on purpose, rXYZ, bXYZ and gXYZ (the RGB primaries) switched to a
+// different order to ensure the profile does not match any known profile, so
+// the encoder cannot encode it in a compact struct instead.
+jxl::PaddedBytes GetIccTestProfile() {
+  const uint8_t* profile = reinterpret_cast<const uint8_t*>(
+      "\0\0\3\200lcms\0040\0\0mntrRGB XYZ "
+      "\a\344\0\a\0\27\0\21\0$"
+      "\0\37acspAPPL\0\0\0\1\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1\0\0\366"
+      "\326\0\1\0\0\0\0\323-lcms\372c\207\36\227\200{"
+      "\2\232s\255\327\340\0\n\26\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
+      "\0\0\0\0\0\0\0\0\rdesc\0\0\1 "
+      "\0\0\0Bcprt\0\0\1d\0\0\1\0wtpt\0\0\2d\0\0\0\24chad\0\0\2x\0\0\0,"
+      "bXYZ\0\0\2\244\0\0\0\24gXYZ\0\0\2\270\0\0\0\24rXYZ\0\0\2\314\0\0\0\24rTR"
+      "C\0\0\2\340\0\0\0 gTRC\0\0\2\340\0\0\0 bTRC\0\0\2\340\0\0\0 "
+      "chrm\0\0\3\0\0\0\0$dmnd\0\0\3$\0\0\0("
+      "dmdd\0\0\3L\0\0\0002mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0&"
+      "\0\0\0\34\0R\0G\0B\0_\0D\0006\0005\0_\0S\0R\0G\0_\0R\0e\0l\0_"
+      "\0L\0i\0n\0\0mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\344\0\0\0\34\0C\0o\0"
+      "p\0y\0r\0i\0g\0h\0t\0 \0002\0000\0001\08\0 \0G\0o\0o\0g\0l\0e\0 "
+      "\0L\0L\0C\0,\0 \0C\0C\0-\0B\0Y\0-\0S\0A\0 \0003\0.\0000\0 "
+      "\0U\0n\0p\0o\0r\0t\0e\0d\0 "
+      "\0l\0i\0c\0e\0n\0s\0e\0(\0h\0t\0t\0p\0s\0:\0/\0/"
+      "\0c\0r\0e\0a\0t\0i\0v\0e\0c\0o\0m\0m\0o\0n\0s\0.\0o\0r\0g\0/"
+      "\0l\0i\0c\0e\0n\0s\0e\0s\0/\0b\0y\0-\0s\0a\0/\0003\0.\0000\0/"
+      "\0l\0e\0g\0a\0l\0c\0o\0d\0e\0)XYZ "
+      "\0\0\0\0\0\0\366\326\0\1\0\0\0\0\323-"
+      "sf32\0\0\0\0\0\1\fB\0\0\5\336\377\377\363%"
+      "\0\0\a\223\0\0\375\220\377\377\373\241\377\377\375\242\0\0\3\334\0\0\300"
+      "nXYZ \0\0\0\0\0\0o\240\0\08\365\0\0\3\220XYZ "
+      "\0\0\0\0\0\0$\237\0\0\17\204\0\0\266\304XYZ "
+      "\0\0\0\0\0\0b\227\0\0\267\207\0\0\30\331para\0\0\0\0\0\3\0\0\0\1\0\0\0\1"
+      "\0\0\0\0\0\0\0\1\0\0\0\0\0\0chrm\0\0\0\0\0\3\0\0\0\0\243\327\0\0T|"
+      "\0\0L\315\0\0\231\232\0\0&"
+      "g\0\0\17\\mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\f\0\0\0\34\0G\0o\0o\0g"
+      "\0l\0emluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\26\0\0\0\34\0I\0m\0a\0g\0e"
+      "\0 \0c\0o\0d\0e\0c\0\0");
+  size_t profile_size = 896;
+  jxl::PaddedBytes icc_profile;
+  icc_profile.assign(profile, profile + profile_size);
+  return icc_profile;
+}
+
+}  // namespace
+
+namespace jxl {
+namespace {
+
+// Input pixels always given as 16-bit RGBA, 8 bytes per pixel.
+// include_alpha determines if the encoded image should contain the alpha
+// channel.
+// add_icc_profile: if false, encodes the image as sRGB using the JXL fields,
+// for grayscale or RGB images. If true, encodes the image using the ICC
+// profile returned by GetIccTestProfile, without the JXL fields; this
+// requires the image to be RGB, not grayscale.
+// Providing jpeg_codestream will populate the jpeg_codestream with compressed
+// JPEG bytes, and make it possible to reconstruct those exact JPEG bytes
+// using the return value _if_ add_container indicates a box format.
+PaddedBytes CreateTestJXLCodestream(Span<const uint8_t> pixels, size_t xsize,
+                                    size_t ysize, size_t num_channels,
+                                    const CompressParams& cparams,
+                                    CodeStreamBoxFormat add_container,
+                                    bool add_preview,
+                                    bool add_icc_profile = false,
+                                    PaddedBytes* jpeg_codestream = nullptr) {
+  // Compress the pixels with JPEG XL.
+  bool grayscale = (num_channels <= 2);
+  bool include_alpha = !(num_channels & 1) && jpeg_codestream == nullptr;
+  size_t bitdepth = jpeg_codestream == nullptr ? 16 : 8;
+  CodecInOut io;
+  io.SetSize(xsize, ysize);
+  ColorEncoding color_encoding =
+      jxl::ColorEncoding::SRGB(/*is_gray=*/grayscale);
+  if (add_icc_profile) {
+    // The hardcoded ICC profile we attach requires RGB.
+    EXPECT_EQ(false, grayscale);
+    EXPECT_TRUE(color_encoding.SetICC(GetIccTestProfile()));
+  }
+  ThreadPool pool(nullptr, nullptr);
+  io.metadata.m.SetUintSamples(bitdepth);
+  if (include_alpha) {
+    io.metadata.m.SetAlphaBits(bitdepth);
+  }
+  // Make the grayscale-ness of the io metadata color_encoding and the packed
+  // image match.
+  io.metadata.m.color_encoding = color_encoding;
+  EXPECT_TRUE(ConvertFromExternal(
+      pixels, xsize, ysize, color_encoding, /*has_alpha=*/include_alpha,
+      /*alpha_is_premultiplied=*/false, bitdepth, JXL_BIG_ENDIAN,
+      /*flipped_y=*/false, &pool, &io.Main()));
+  jxl::PaddedBytes jpeg_data;
+  if (jpeg_codestream != nullptr) {
+#if JPEGXL_ENABLE_JPEG
+    jxl::PaddedBytes jpeg_bytes;
+    EXPECT_TRUE(EncodeImageJPG(&io, jxl::JpegEncoder::kLibJpeg, /*quality=*/70,
+                               jxl::YCbCrChromaSubsampling(), &pool,
+                               &jpeg_bytes, jxl::DecodeTarget::kPixels));
+    jpeg_codestream->append(jpeg_bytes.data(),
+                            jpeg_bytes.data() + jpeg_bytes.size());
+    EXPECT_TRUE(jxl::jpeg::DecodeImageJPG(
+        jxl::Span<const uint8_t>(jpeg_bytes.data(), jpeg_bytes.size()), &io));
+    EXPECT_TRUE(EncodeJPEGData(*io.Main().jpeg_data, &jpeg_data));
+    io.metadata.m.xyb_encoded = false;
+#else   // JPEGXL_ENABLE_JPEG
+    JXL_ABORT(
+        "unable to create reconstructible JPEG without JPEG support enabled");
+#endif  // JPEGXL_ENABLE_JPEG
+  }
+  if (add_preview) {
+    io.preview_frame = io.Main().Copy();
+    io.preview_frame.ShrinkTo(xsize / 7, ysize / 7);
+    io.metadata.m.have_preview = true;
+    EXPECT_TRUE(io.metadata.m.preview_size.Set(io.preview_frame.xsize(),
+                                               io.preview_frame.ysize()));
+  }
+  AuxOut aux_out;
+  PaddedBytes compressed;
+  PassesEncoderState enc_state;
+  EXPECT_TRUE(
+      EncodeFile(cparams, &io, &enc_state, &compressed, &aux_out, &pool));
+  if (add_container != kCSBF_None) {
+    // Header with signature box and ftyp box.
+    const uint8_t header[] = {0,    0,    0,    0xc,  0x4a, 0x58, 0x4c, 0x20,
+                              0xd,  0xa,  0x87, 0xa,  0,    0,    0,    0x14,
+                              0x66, 0x74, 0x79, 0x70, 0x6a, 0x78, 0x6c, 0x20,
+                              0,    0,    0,    0,    0x6a, 0x78, 0x6c, 0x20};
+    // Unknown box, could be a box added by user, decoder must be able to skip
+    // over it. Type is set to 'unkn', size to 24, contents to 16 0's.
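+    // (Spelled out: a 4-byte big-endian size 0x18, the type bytes
+    // 'u','n','k','n', and 16 zero payload bytes; 8 header bytes plus 16
+    // content bytes gives the stated size of 24.)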
+    const uint8_t unknown[] = {0, 0, 0, 0x18, 0x75, 0x6e, 0x6b, 0x6e,
+                               0, 0, 0, 0,    0,    0,    0,    0,
+                               0, 0, 0, 0,    0,    0,    0,    0};
+    // Same as the unknown box, but with size set to 0; this can only be a
+    // final box.
+    const uint8_t unknown_end[] = {0, 0, 0, 0, 0x75, 0x6e, 0x6b, 0x6e,
+                                   0, 0, 0, 0, 0,    0,    0,    0,
+                                   0, 0, 0, 0, 0,    0,    0,    0};
+
+    bool is_multi = add_container == kCSBF_Multi ||
+                    add_container == kCSBF_Multi_Zero_Terminated ||
+                    add_container == kCSBF_Multi_Other_Terminated ||
+                    add_container == kCSBF_Multi_Other_Zero_Terminated ||
+                    add_container == kCSBF_Multi_First_Empty;
+
+    if (is_multi) {
+      size_t third = compressed.size() / 3;
+      std::vector<uint8_t> compressed0(compressed.data(),
+                                       compressed.data() + third);
+      std::vector<uint8_t> compressed1(compressed.data() + third,
+                                       compressed.data() + 2 * third);
+      std::vector<uint8_t> compressed2(compressed.data() + 2 * third,
+                                       compressed.data() + compressed.size());
+
+      PaddedBytes c;
+      c.append(header, header + sizeof(header));
+      if (jpeg_codestream != nullptr) {
+        jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false,
+                             &c);
+        c.append(jpeg_data.data(), jpeg_data.data() + jpeg_data.size());
+      }
+      uint32_t jxlp_index = 0;
+      if (add_container == kCSBF_Multi_First_Empty) {
+        // Dummy (empty) codestream part
+        AppendU32BE(12, &c);
+        c.push_back('j');
+        c.push_back('x');
+        c.push_back('l');
+        c.push_back('p');
+        AppendU32BE(jxlp_index++, &c);
+      }
+      // First codestream part
+      AppendU32BE(compressed0.size() + 12, &c);
+      c.push_back('j');
+      c.push_back('x');
+      c.push_back('l');
+      c.push_back('p');
+      AppendU32BE(jxlp_index++, &c);
+      c.append(compressed0.data(), compressed0.data() + compressed0.size());
+      // A few non-codestream boxes in between
+      c.append(unknown, unknown + sizeof(unknown));
+      c.append(unknown, unknown + sizeof(unknown));
+      // Dummy (empty) codestream part
+      AppendU32BE(12, &c);
+      c.push_back('j');
+      c.push_back('x');
+      c.push_back('l');
+      c.push_back('p');
+      AppendU32BE(jxlp_index++, &c);
+      // Second codestream part
+      AppendU32BE(compressed1.size() + 12, &c);
+      c.push_back('j');
+      c.push_back('x');
+      c.push_back('l');
+      c.push_back('p');
+      AppendU32BE(jxlp_index++, &c);
+      c.append(compressed1.data(), compressed1.data() + compressed1.size());
+      // Third codestream part
+      AppendU32BE(add_container == kCSBF_Multi ? (compressed2.size() + 12) : 0,
+                  &c);
+      c.push_back('j');
+      c.push_back('x');
+      c.push_back('l');
+      c.push_back('p');
+      AppendU32BE(jxlp_index++ | 0x80000000, &c);
+      c.append(compressed2.data(), compressed2.data() + compressed2.size());
+      if (add_container == kCSBF_Multi_Other_Terminated) {
+        c.append(unknown, unknown + sizeof(unknown));
+      }
+      if (add_container == kCSBF_Multi_Other_Zero_Terminated) {
+        c.append(unknown_end, unknown_end + sizeof(unknown_end));
+      }
+      compressed.swap(c);
+    } else {
+      PaddedBytes c;
+      c.append(header, header + sizeof(header));
+      if (jpeg_codestream != nullptr) {
+        jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false,
+                             &c);
+        c.append(jpeg_data.data(), jpeg_data.data() + jpeg_data.size());
+      }
+      AppendU32BE(add_container == kCSBF_Single_Zero_Terminated
+                      ? 0
+                      : (compressed.size() + 8),
+                  &c);
+      c.push_back('j');
+      c.push_back('x');
+      c.push_back('l');
+      c.push_back('c');
+      c.append(compressed.data(), compressed.data() + compressed.size());
+      if (add_container == kCSBF_Single_other) {
+        c.append(unknown, unknown + sizeof(unknown));
+      }
+      compressed.swap(c);
+    }
+  }
+
+  return compressed;
+}
+
+// Decodes one-shot with the API for non-streaming decoding tests.
+std::vector<uint8_t> DecodeWithAPI(JxlDecoder* dec,
+                                   Span<const uint8_t> compressed,
+                                   const JxlPixelFormat& format) {
+  void* runner = JxlThreadParallelRunnerCreate(
+      NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads());
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner));
+
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(
+                dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE));
+
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSetInput(dec, compressed.data(), compressed.size()));
+  EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec));
+  size_t buffer_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderImageOutBufferSize(dec, &format, &buffer_size));
+  JxlBasicInfo info;
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info));
+  std::vector<uint8_t> pixels(buffer_size);
+
+  EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec));
+
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer(
+                                 dec, &format, pixels.data(), pixels.size()));
+
+  EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec));
+
+  // After the full image has been output, JxlDecoderProcessInput should
+  // return success to indicate all is done.
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec));
+
+  JxlThreadParallelRunnerDestroy(runner);
+
+  return pixels;
+}
+
+// Decodes one-shot with the API for non-streaming decoding tests.
+std::vector<uint8_t> DecodeWithAPI(Span<const uint8_t> compressed,
+                                   const JxlPixelFormat& format) {
+  JxlDecoder* dec = JxlDecoderCreate(NULL);
+  std::vector<uint8_t> pixels = DecodeWithAPI(dec, compressed, format);
+  JxlDecoderDestroy(dec);
+  return pixels;
+}
+
+}  // namespace
+}  // namespace jxl
+
+namespace {
+bool Near(double expected, double value, double max_dist) {
+  double dist = expected > value ? expected - value : value - expected;
+  return dist <= max_dist;
+}
+
+// Loads a Big-Endian float
+float LoadBEFloat(const uint8_t* p) {
+  uint32_t u = LoadBE32(p);
+  float result;
+  memcpy(&result, &u, 4);
+  return result;
+}
+
+// Loads a Little-Endian float
+float LoadLEFloat(const uint8_t* p) {
+  uint32_t u = LoadLE32(p);
+  float result;
+  memcpy(&result, &u, 4);
+  return result;
+}
+
+// Based on highway scalar implementation, for testing
+float LoadFloat16(uint16_t bits16) {
+  const uint32_t sign = bits16 >> 15;
+  const uint32_t biased_exp = (bits16 >> 10) & 0x1F;
+  const uint32_t mantissa = bits16 & 0x3FF;
+
+  // Subnormal or zero
+  if (biased_exp == 0) {
+    const float subnormal = (1.0f / 16384) * (mantissa * (1.0f / 1024));
+    return sign ? -subnormal : subnormal;
+  }
+
+  // Normalized: convert the representation directly (faster than
+  // ldexp/tables).
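+  // Worked example: bits16 = 0x3C00 has sign 0, biased_exp 15 and mantissa 0;
+  // rebiasing gives biased_exp32 = 15 + 112 = 127, so bits32 = 0x3F800000,
+  // which is exactly 1.0f.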
+  const uint32_t biased_exp32 = biased_exp + (127 - 15);
+  const uint32_t mantissa32 = mantissa << (23 - 10);
+  const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32;
+
+  float result;
+  memcpy(&result, &bits32, 4);
+  return result;
+}
+
+float LoadLEFloat16(const uint8_t* p) {
+  uint16_t bits16 = LoadLE16(p);
+  return LoadFloat16(bits16);
+}
+
+float LoadBEFloat16(const uint8_t* p) {
+  uint16_t bits16 = LoadBE16(p);
+  return LoadFloat16(bits16);
+}
+
+size_t GetPrecision(JxlDataType data_type) {
+  switch (data_type) {
+    case JXL_TYPE_BOOLEAN:
+      return 1;
+    case JXL_TYPE_UINT8:
+      return 8;
+    case JXL_TYPE_UINT16:
+      return 16;
+    case JXL_TYPE_UINT32:
+      return 32;
+    case JXL_TYPE_FLOAT:
+      // Floating point mantissa precision
+      return 24;
+    case JXL_TYPE_FLOAT16:
+      return 11;
+  }
+  JXL_ASSERT(false);  // unknown type
+}
+
+size_t GetDataBits(JxlDataType data_type) {
+  switch (data_type) {
+    case JXL_TYPE_BOOLEAN:
+      return 1;
+    case JXL_TYPE_UINT8:
+      return 8;
+    case JXL_TYPE_UINT16:
+      return 16;
+    case JXL_TYPE_UINT32:
+      return 32;
+    case JXL_TYPE_FLOAT:
+      return 32;
+    case JXL_TYPE_FLOAT16:
+      return 16;
+  }
+  JXL_ASSERT(false);  // unknown type
+}
+
+// Procedure to convert pixels to double precision, not efficient, but
+// well-controlled for testing. It uses double, to be able to represent all
+// precisions needed for the maximum data types the API supports: uint32_t
+// integers and single precision float. The values are in range 0-1 for SDR.
+std::vector<double> ConvertToRGBA32(const uint8_t* pixels, size_t xsize,
+                                    size_t ysize,
+                                    const JxlPixelFormat& format) {
+  std::vector<double> result(xsize * ysize * 4);
+  size_t num_channels = format.num_channels;
+  bool gray = num_channels == 1 || num_channels == 2;
+  bool alpha = num_channels == 2 || num_channels == 4;
+
+  size_t stride =
+      xsize * jxl::DivCeil(GetDataBits(format.data_type) * num_channels,
+                           jxl::kBitsPerByte);
+  if (format.align > 1) stride = jxl::RoundUpTo(stride, format.align);
+
+  if (format.data_type == JXL_TYPE_BOOLEAN) {
+    for (size_t y = 0; y < ysize; ++y) {
+      jxl::BitReader br(jxl::Span<const uint8_t>(pixels + stride * y, stride));
+      for (size_t x = 0; x < xsize; ++x) {
+        size_t j = (y * xsize + x) * 4;
+        double r = br.ReadBits(1);
+        double g = gray ? r : br.ReadBits(1);
+        double b = gray ? r : br.ReadBits(1);
+        double a = alpha ? br.ReadBits(1) : 1;
+        result[j + 0] = r;
+        result[j + 1] = g;
+        result[j + 2] = b;
+        result[j + 3] = a;
+      }
+      JXL_CHECK(br.Close());
+    }
+  } else if (format.data_type == JXL_TYPE_UINT8) {
+    double mul = 1.0 / 255.0;  // Multiplier to bring to 0-1.0 range
+    for (size_t y = 0; y < ysize; ++y) {
+      for (size_t x = 0; x < xsize; ++x) {
+        size_t j = (y * xsize + x) * 4;
+        size_t i = y * stride + x * num_channels;
+        double r = pixels[i];
+        double g = gray ? r : pixels[i + 1];
+        double b = gray ? r : pixels[i + 2];
+        double a = alpha ? pixels[i + num_channels - 1] : 255;
+        result[j + 0] = r * mul;
+        result[j + 1] = g * mul;
+        result[j + 2] = b * mul;
+        result[j + 3] = a * mul;
+      }
+    }
+  } else if (format.data_type == JXL_TYPE_UINT16) {
+    double mul = 1.0 / 65535.0;  // Multiplier to bring to 0-1.0 range
+    for (size_t y = 0; y < ysize; ++y) {
+      for (size_t x = 0; x < xsize; ++x) {
+        size_t j = (y * xsize + x) * 4;
+        size_t i = y * stride + x * num_channels * 2;
+        double r, g, b, a;
+        if (format.endianness == JXL_BIG_ENDIAN) {
+          r = (pixels[i + 0] << 8) + pixels[i + 1];
+          g = gray ? r : (pixels[i + 2] << 8) + pixels[i + 3];
+          b = gray ? r : (pixels[i + 4] << 8) + pixels[i + 5];
+          a = alpha ? (pixels[i + num_channels * 2 - 2] << 8) +
+                          pixels[i + num_channels * 2 - 1]
+                    : 65535;
+        } else {
+          r = (pixels[i + 1] << 8) + pixels[i + 0];
+          g = gray ? r : (pixels[i + 3] << 8) + pixels[i + 2];
+          b = gray ? r : (pixels[i + 5] << 8) + pixels[i + 4];
+          a = alpha ? (pixels[i + num_channels * 2 - 1] << 8) +
+                          pixels[i + num_channels * 2 - 2]
+                    : 65535;
+        }
+        result[j + 0] = r * mul;
+        result[j + 1] = g * mul;
+        result[j + 2] = b * mul;
+        result[j + 3] = a * mul;
+      }
+    }
+  } else if (format.data_type == JXL_TYPE_UINT32) {
+    double mul = 1.0 / 4294967295.0;  // Multiplier to bring to 0-1.0 range
+    for (size_t y = 0; y < ysize; ++y) {
+      for (size_t x = 0; x < xsize; ++x) {
+        size_t j = (y * xsize + x) * 4;
+        size_t i = y * stride + x * num_channels * 4;
+        double r, g, b, a;
+        if (format.endianness == JXL_BIG_ENDIAN) {
+          r = LoadBE32(pixels + i);
+          g = gray ? r : LoadBE32(pixels + i + 4);
+          b = gray ? r : LoadBE32(pixels + i + 8);
+          a = alpha ? LoadBE32(pixels + i + num_channels * 4 - 4) : 4294967295;
+        } else {
+          r = LoadLE32(pixels + i);
+          g = gray ? r : LoadLE32(pixels + i + 4);
+          b = gray ? r : LoadLE32(pixels + i + 8);
+          a = alpha ? LoadLE32(pixels + i + num_channels * 4 - 4) : 4294967295;
+        }
+        result[j + 0] = r * mul;
+        result[j + 1] = g * mul;
+        result[j + 2] = b * mul;
+        result[j + 3] = a * mul;
+      }
+    }
+  } else if (format.data_type == JXL_TYPE_FLOAT) {
+    for (size_t y = 0; y < ysize; ++y) {
+      for (size_t x = 0; x < xsize; ++x) {
+        size_t j = (y * xsize + x) * 4;
+        size_t i = y * stride + x * num_channels * 4;
+        double r, g, b, a;
+        if (format.endianness == JXL_BIG_ENDIAN) {
+          r = LoadBEFloat(pixels + i);
+          g = gray ? r : LoadBEFloat(pixels + i + 4);
+          b = gray ? r : LoadBEFloat(pixels + i + 8);
+          a = alpha ? LoadBEFloat(pixels + i + num_channels * 4 - 4) : 1.0;
+        } else {
+          r = LoadLEFloat(pixels + i);
+          g = gray ? r : LoadLEFloat(pixels + i + 4);
+          b = gray ? r : LoadLEFloat(pixels + i + 8);
+          a = alpha ? LoadLEFloat(pixels + i + num_channels * 4 - 4) : 1.0;
+        }
+        result[j + 0] = r;
+        result[j + 1] = g;
+        result[j + 2] = b;
+        result[j + 3] = a;
+      }
+    }
+  } else if (format.data_type == JXL_TYPE_FLOAT16) {
+    for (size_t y = 0; y < ysize; ++y) {
+      for (size_t x = 0; x < xsize; ++x) {
+        size_t j = (y * xsize + x) * 4;
+        size_t i = y * stride + x * num_channels * 2;
+        double r, g, b, a;
+        if (format.endianness == JXL_BIG_ENDIAN) {
+          r = LoadBEFloat16(pixels + i);
+          g = gray ? r : LoadBEFloat16(pixels + i + 2);
+          b = gray ? r : LoadBEFloat16(pixels + i + 4);
+          a = alpha ? LoadBEFloat16(pixels + i + num_channels * 2 - 2) : 1.0;
+        } else {
+          r = LoadLEFloat16(pixels + i);
+          g = gray ? r : LoadLEFloat16(pixels + i + 2);
+          b = gray ? r : LoadLEFloat16(pixels + i + 4);
+          a = alpha ? LoadLEFloat16(pixels + i + num_channels * 2 - 2) : 1.0;
+        }
+        result[j + 0] = r;
+        result[j + 1] = g;
+        result[j + 2] = b;
+        result[j + 3] = a;
+      }
+    }
+  } else {
+    JXL_ASSERT(false);  // Unsupported type
+  }
+  return result;
+}
+
+// Returns the number of pixels which differ between the two pictures. Image b
+// is the image after the roundtrip, image a the image before the roundtrip.
+// There are stricter requirements for the alpha channel and grayscale values
+// of the output image.
+size_t ComparePixels(const uint8_t* a, const uint8_t* b, size_t xsize,
+                     size_t ysize, const JxlPixelFormat& format_a,
+                     const JxlPixelFormat& format_b) {
+  // Convert both images to equal full precision for comparison.
+  std::vector<double> a_full = ConvertToRGBA32(a, xsize, ysize, format_a);
+  std::vector<double> b_full = ConvertToRGBA32(b, xsize, ysize, format_b);
+  bool gray_a = format_a.num_channels < 3;
+  bool gray_b = format_b.num_channels < 3;
+  bool alpha_a = !(format_a.num_channels & 1);
+  bool alpha_b = !(format_b.num_channels & 1);
+  size_t bits_a = GetPrecision(format_a.data_type);
+  size_t bits_b = GetPrecision(format_b.data_type);
+  size_t bits = std::min(bits_a, bits_b);
+  // How much distance is allowed in case of pixels with lower bit depths,
+  // given that the double precision float images use range 0-1.0.
+  // E.g. in case of 1-bit this is 0.5 since 0.499 must map to 0 and 0.501
+  // must map to 1.
+  double precision = 0.5 / ((1ull << bits) - 1ull);
+  if (format_a.data_type == JXL_TYPE_FLOAT16 ||
+      format_b.data_type == JXL_TYPE_FLOAT16) {
+    // Lower the precision for float16, because it currently looks like the
+    // scalar and wasm implementations of hwy have 1 less bit of precision
+    // than the x86 implementations.
+    // TODO(lode): Set the required precision back to 11 bits when possible.
+    precision = 0.5 / ((1ull << (bits - 1)) - 1ull);
+  }
+  size_t numdiff = 0;
+  for (size_t y = 0; y < ysize; y++) {
+    for (size_t x = 0; x < xsize; x++) {
+      size_t i = (y * xsize + x) * 4;
+      bool ok = true;
+      if (gray_a || gray_b) {
+        if (!Near(a_full[i + 0], b_full[i + 0], precision)) ok = false;
+        // If the input was grayscale and the output not, then the output must
+        // have all channels equal.
+        if (gray_a && (b_full[i + 0] != b_full[i + 1] ||
+                       b_full[i + 1] != b_full[i + 2])) {
+          ok = false;
+        }
+      } else {
+        if (!Near(a_full[i + 0], b_full[i + 0], precision) ||
+            !Near(a_full[i + 1], b_full[i + 1], precision) ||
+            !Near(a_full[i + 2], b_full[i + 2], precision)) {
+          ok = false;
+        }
+      }
+      if (alpha_a && alpha_b) {
+        if (!Near(a_full[i + 3], b_full[i + 3], precision)) ok = false;
+      } else {
+        // If the input had no alpha channel, the output should be opaque
+        // after roundtrip.
+        if (alpha_b && !Near(1.0, b_full[i + 3], precision)) ok = false;
+      }
+      if (!ok) numdiff++;
+    }
+  }
+  return numdiff;
+}
+
+}  // namespace
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(DecodeTest, JxlSignatureCheckTest) {
+  std::vector<std::pair<JxlSignature, std::vector<uint8_t>>> tests = {
+      // No JPEGXL header starts with 'a'.
+      {JXL_SIG_INVALID, {'a'}},
+      {JXL_SIG_INVALID, {'a', 'b', 'c', 'd', 'e', 'f'}},
+
+      // Empty file is not enough bytes.
+      {JXL_SIG_NOT_ENOUGH_BYTES, {}},
+
+      // JPEGXL headers.
+      {JXL_SIG_NOT_ENOUGH_BYTES, {0xff}},  // Part of a signature.
+      {JXL_SIG_INVALID, {0xff, 0xD8}},  // JPEG-1
+      {JXL_SIG_CODESTREAM, {0xff, 0x0a}},
+
+      // JPEGXL container file.
+      {JXL_SIG_CONTAINER,
+       {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87, 0xA}},
+      // Ending with invalid byte.
+      {JXL_SIG_INVALID, {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87, 0}},
+      // Part of signature.
+      {JXL_SIG_NOT_ENOUGH_BYTES,
+       {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87}},
+      {JXL_SIG_NOT_ENOUGH_BYTES, {0}},
+  };
+  for (const auto& test : tests) {
+    EXPECT_EQ(test.first,
+              JxlSignatureCheck(test.second.data(), test.second.size()))
+        << "Where test data is " << ::testing::PrintToString(test.second);
+  }
+}
+
+TEST(DecodeTest, DefaultAllocTest) {
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+  EXPECT_NE(nullptr, dec);
+  JxlDecoderDestroy(dec);
+}
+
+TEST(DecodeTest, CustomAllocTest) {
+  struct CalledCounters {
+    int allocs = 0;
+    int frees = 0;
+  } counters;
+
+  JxlMemoryManager mm;
+  mm.opaque = &counters;
+  mm.alloc = [](void* opaque, size_t size) {
+    reinterpret_cast<CalledCounters*>(opaque)->allocs++;
+    return malloc(size);
+  };
+  mm.free = [](void* opaque, void* address) {
+    reinterpret_cast<CalledCounters*>(opaque)->frees++;
+    free(address);
+  };
+
+  JxlDecoder* dec = JxlDecoderCreate(&mm);
+  EXPECT_NE(nullptr, dec);
+  EXPECT_LE(1, counters.allocs);
+  EXPECT_EQ(0, counters.frees);
+  JxlDecoderDestroy(dec);
+  EXPECT_LE(1, counters.frees);
+}
+
+// TODO(lode): add multi-threaded test when multithreaded pixel decoding from
+// API is implemented.
+TEST(DecodeTest, DefaultParallelRunnerTest) {
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+  EXPECT_NE(nullptr, dec);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSetParallelRunner(dec, nullptr, nullptr));
+  JxlDecoderDestroy(dec);
+}
+
+// Creates the header of a JPEG XL file with various custom parameters for
+// testing.
+// xsize, ysize: image dimensions to store in the SizeHeader, max 512.
+// bits_per_sample, orientation: a selection of header parameters to test with.
+// orientation: image orientation to set in the metadata
+// alpha_bits: if non-0, alpha extra channel bits to set in the metadata. Also
+// gives the alpha channel the name "alpha_test"
+// have_container: add box container format around the codestream.
+// metadata_default: if true, ImageMetadata is set to default and
+// bits_per_sample, orientation and alpha_bits are ignored.
+// insert_extra_box: insert an extra box before the codestream box, making the
+// header farther away from the front than is ideal. Only used if
+// have_container.
+std::vector<uint8_t> GetTestHeader(size_t xsize, size_t ysize,
+                                   size_t bits_per_sample, size_t orientation,
+                                   size_t alpha_bits, bool xyb_encoded,
+                                   bool have_container, bool metadata_default,
+                                   bool insert_extra_box,
+                                   const jxl::PaddedBytes& icc_profile) {
+  jxl::BitWriter writer;
+  jxl::BitWriter::Allotment allotment(&writer, 65536);  // Large enough
+
+  if (have_container) {
+    const std::vector<uint8_t> signature_box = {0,   0,   0,   0xc, 'J', 'X',
+                                                'L', ' ', 0xd, 0xa, 0x87, 0xa};
+    const std::vector<uint8_t> filetype_box = {
+        0,   0,   0, 0x14, 'f', 't', 'y', 'p', 'j', 'x',
+        'l', ' ', 0, 0,    0,   0,   'j', 'x', 'l', ' '};
+    const std::vector<uint8_t> extra_box_header = {0,   0,   0,   0xff,
+                                                   't', 'e', 's', 't'};
+    // Beginning of codestream box, with an arbitrary size certainly large
+    // enough to contain the header
+    const std::vector<uint8_t> codestream_box_header = {0,   0,   0,   0xff,
+                                                        'j', 'x', 'l', 'c'};
+
+    for (size_t i = 0; i < signature_box.size(); i++) {
+      writer.Write(8, signature_box[i]);
+    }
+    for (size_t i = 0; i < filetype_box.size(); i++) {
+      writer.Write(8, filetype_box[i]);
+    }
+    if (insert_extra_box) {
+      for (size_t i = 0; i < extra_box_header.size(); i++) {
+        writer.Write(8, extra_box_header[i]);
+      }
+      for (size_t i = 0; i < 255 - 8; i++) {
+        writer.Write(8, 0);
+      }
+    }
+    for (size_t i = 0; i < codestream_box_header.size(); i++) {
+      writer.Write(8, codestream_box_header[i]);
+    }
+  }
+
+  // JXL signature
+  writer.Write(8, 0xff);
+  writer.Write(8, 0x0a);
+
+  // SizeHeader
+  jxl::CodecMetadata metadata;
+  EXPECT_TRUE(metadata.size.Set(xsize, ysize));
+  EXPECT_TRUE(WriteSizeHeader(metadata.size, &writer, 0, nullptr));
+
+  if (!metadata_default) {
+    metadata.m.SetUintSamples(bits_per_sample);
+    metadata.m.orientation = orientation;
+    metadata.m.SetAlphaBits(alpha_bits);
+    metadata.m.xyb_encoded = xyb_encoded;
+    if (alpha_bits != 0) {
+      metadata.m.extra_channel_info[0].name = "alpha_test";
+    }
+  }
+
+  if (!icc_profile.empty()) {
+    jxl::PaddedBytes copy = icc_profile;
+    EXPECT_TRUE(metadata.m.color_encoding.SetICC(std::move(copy)));
+  }
+
+  EXPECT_TRUE(jxl::Bundle::Write(metadata.m, &writer, 0, nullptr));
+  metadata.transform_data.nonserialized_xyb_encoded = metadata.m.xyb_encoded;
+  EXPECT_TRUE(jxl::Bundle::Write(metadata.transform_data, &writer, 0, nullptr));
+
+  if (!icc_profile.empty()) {
+    EXPECT_TRUE(metadata.m.color_encoding.WantICC());
+    EXPECT_TRUE(jxl::WriteICC(icc_profile, &writer, 0, nullptr));
+  }
+
+  writer.ZeroPadToByte();
+  ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+  return std::vector<uint8_t>(
+      writer.GetSpan().data(),
+      writer.GetSpan().data() + writer.GetSpan().size());
+}
+
+TEST(DecodeTest, BasicInfoTest) {
+  size_t xsize[2] = {50, 33};
+  size_t ysize[2] = {50, 77};
+  size_t bits_per_sample[2] = {8, 23};
+  size_t orientation[2] = {3, 5};
+  size_t alpha_bits[2] = {0, 8};
+  size_t have_container[2] = {0, 1};
+  bool xyb_encoded = false;
+
+  std::vector<std::vector<uint8_t>> test_samples;
+  // Test with direct codestream
+  test_samples.push_back(GetTestHeader(
+      xsize[0], ysize[0], bits_per_sample[0], orientation[0], alpha_bits[0],
+      xyb_encoded, have_container[0], /*metadata_default=*/false,
+      /*insert_extra_box=*/false, {}));
+  // Test with container and different parameters
+  test_samples.push_back(GetTestHeader(
+      xsize[1], ysize[1], bits_per_sample[1], orientation[1], alpha_bits[1],
+      xyb_encoded, have_container[1], /*metadata_default=*/false,
+      /*insert_extra_box=*/false, {}));
+
+  for (size_t i = 0; i < test_samples.size(); ++i) {
+    const std::vector<uint8_t>& data = test_samples[i];
+    // Test decoding too small header first, until we
reach the final byte. + for (size_t size = 0; size <= data.size(); ++size) { + // Test with a new decoder for each tested byte size. + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + const uint8_t* next_in = data.data(); + size_t avail_in = size; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + + JxlBasicInfo info; + bool have_basic_info = !JxlDecoderGetBasicInfo(dec, &info); + + if (size == data.size()) { + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + + // All header bytes given so the decoder must have the basic info. + EXPECT_EQ(true, have_basic_info); + EXPECT_EQ(have_container[i], info.have_container); + EXPECT_EQ(alpha_bits[i], info.alpha_bits); + // Orientations 5..8 swap the dimensions + if (orientation[i] >= 5) { + EXPECT_EQ(xsize[i], info.ysize); + EXPECT_EQ(ysize[i], info.xsize); + } else { + EXPECT_EQ(xsize[i], info.xsize); + EXPECT_EQ(ysize[i], info.ysize); + } + // The API should set the orientation to identity by default since it + // already applies the transformation internally by default. + EXPECT_EQ(1, info.orientation); + + EXPECT_EQ(3, info.num_color_channels); + + if (alpha_bits[i] != 0) { + // Expect an extra channel + EXPECT_EQ(1, info.num_extra_channels); + JxlExtraChannelInfo extra; + EXPECT_EQ(0, JxlDecoderGetExtraChannelInfo(dec, 0, &extra)); + EXPECT_EQ(alpha_bits[i], extra.bits_per_sample); + EXPECT_EQ(JXL_CHANNEL_ALPHA, extra.type); + EXPECT_EQ(0, extra.alpha_associated); + // Verify the name "alpha_test" given to the alpha channel + EXPECT_EQ(10, extra.name_length); + char name[11]; + EXPECT_EQ(0, + JxlDecoderGetExtraChannelName(dec, 0, name, sizeof(name))); + EXPECT_EQ(std::string("alpha_test"), std::string(name)); + } else { + EXPECT_EQ(0, info.num_extra_channels); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + } else { + // If we did not give the full header, the basic info should not be + // available. Allow a few bytes of slack due to some bits for default + // opsinmatrix/extension bits. 
+        if (size + 2 < data.size()) {
+          EXPECT_EQ(false, have_basic_info);
+          EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, status);
+        }
+      }
+
+      // Test that the decoder doesn't allow a setting that must be configured
+      // at the beginning, unless it's reset
+      EXPECT_EQ(JXL_DEC_ERROR,
+                JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO));
+      JxlDecoderReset(dec);
+      EXPECT_EQ(JXL_DEC_SUCCESS,
+                JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO));
+
+      JxlDecoderDestroy(dec);
+    }
+  }
+}
+
+TEST(DecodeTest, BufferSizeTest) {
+  size_t xsize = 33;
+  size_t ysize = 77;
+  size_t bits_per_sample = 8;
+  size_t orientation = 1;
+  size_t alpha_bits = 8;
+  bool have_container = false;
+  bool xyb_encoded = false;
+
+  std::vector<uint8_t> header =
+      GetTestHeader(xsize, ysize, bits_per_sample, orientation, alpha_bits,
+                    xyb_encoded, have_container, /*metadata_default=*/false,
+                    /*insert_extra_box=*/false, {});
+
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO));
+  const uint8_t* next_in = header.data();
+  size_t avail_in = header.size();
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in));
+  JxlDecoderStatus status = JxlDecoderProcessInput(dec);
+  EXPECT_EQ(JXL_DEC_BASIC_INFO, status);
+
+  JxlBasicInfo info;
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info));
+  EXPECT_EQ(xsize, info.xsize);
+  EXPECT_EQ(ysize, info.ysize);
+
+  JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0};
+  size_t image_out_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderImageOutBufferSize(dec, &format, &image_out_size));
+  EXPECT_EQ(xsize * ysize * 4, image_out_size);
+
+  size_t dc_out_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderDCOutBufferSize(dec, &format, &dc_out_size));
+  // Expected DC size: ceil(33 / 8) * ceil(77 / 8) * 4 channels
+  EXPECT_EQ(5 * 10 * 4, dc_out_size);
+
+  JxlDecoderDestroy(dec);
+}
+
+TEST(DecodeTest, BasicInfoSizeHintTest) {
+  // Test on a file where the size hint is too small initially due to inserting
+  // a box before the codestream (something that is normally not recommended)
+  size_t xsize = 50;
+  size_t ysize = 50;
+  size_t bits_per_sample = 16;
+  size_t orientation = 1;
+  size_t alpha_bits = 0;
+  bool xyb_encoded = false;
+  std::vector<uint8_t> data = GetTestHeader(
+      xsize, ysize, bits_per_sample, orientation, alpha_bits, xyb_encoded,
+      /*have_container=*/true, /*metadata_default=*/false,
+      /*insert_extra_box=*/true, {});
+
+  JxlDecoderStatus status;
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO));
+
+  size_t hint0 = JxlDecoderSizeHintBasicInfo(dec);
+  // Test that the test works as intended: we construct a file on purpose to
+  // be larger than the first hint by having that extra box.
+  EXPECT_LT(hint0, data.size());
+  const uint8_t* next_in = data.data();
+  // Act as if we have only as many bytes available as indicated by the hint.
+  size_t avail_in = std::min(hint0, data.size());
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in));
+  status = JxlDecoderProcessInput(dec);
+  EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, status);
+  // Basic info cannot be available yet due to the extra inserted box.
+ EXPECT_EQ(false, !JxlDecoderGetBasicInfo(dec, nullptr)); + + size_t num_read = avail_in - JxlDecoderReleaseInput(dec); + EXPECT_LT(num_read, data.size()); + + size_t hint1 = JxlDecoderSizeHintBasicInfo(dec); + // The hint must be larger than the previous hint (taking already processed + // bytes into account, the hint is a hint for the next avail_in) since the + // decoder now knows there is a box in between. + EXPECT_GT(hint1 + num_read, hint0); + avail_in = std::min(hint1, data.size() - num_read); + next_in += num_read; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + status = JxlDecoderProcessInput(dec); + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + JxlBasicInfo info; + // We should have the basic info now, since we only added one box in-between, + // and the decoder should have known its size, its implementation can return + // a correct hint. + EXPECT_EQ(true, !JxlDecoderGetBasicInfo(dec, &info)); + + // Also test if the basic info is correct. + EXPECT_EQ(1, info.have_container); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_EQ(orientation, info.orientation); + EXPECT_EQ(bits_per_sample, info.bits_per_sample); + + JxlDecoderDestroy(dec); +} + +std::vector GetIccTestHeader(const jxl::PaddedBytes& icc_profile, + bool xyb_encoded) { + size_t xsize = 50; + size_t ysize = 50; + size_t bits_per_sample = 16; + size_t orientation = 1; + size_t alpha_bits = 0; + return GetTestHeader(xsize, ysize, bits_per_sample, orientation, alpha_bits, + xyb_encoded, + /*have_container=*/false, /*metadata_default=*/false, + /*insert_extra_box=*/false, icc_profile); +} + +// Tests the case where pixels and metadata ICC profile are the same +TEST(DecodeTest, IccProfileTestOriginal) { + jxl::PaddedBytes icc_profile = GetIccTestProfile(); + bool xyb_encoded = false; + std::vector data = GetIccTestHeader(icc_profile, xyb_encoded); + JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Expect the opposite of xyb_encoded for uses_original_profile + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(JXL_TRUE, info.uses_original_profile); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + // the encoded color profile expected to be not available, since the image + // has an ICC profile instead + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + size_t dec_profile_size; + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, &dec_profile_size)); + + // Check that can get return status with NULL size + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + // The profiles must be equal. This requires they have equal size, and if + // they do, we can get the profile and compare the contents. 
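+  // (The size check below also guards the read: the profile bytes are only
+  // fetched and compared when the sizes match.)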
+ EXPECT_EQ(icc_profile.size(), dec_profile_size); + if (icc_profile.size() == dec_profile_size) { + jxl::PaddedBytes icc_profile2(icc_profile.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + icc_profile2.data(), icc_profile2.size())); + EXPECT_EQ(icc_profile, icc_profile2); + } + + // the data is not xyb_encoded, so same result expected for the pixel data + // color profile + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, nullptr)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, + &dec_profile_size)); + EXPECT_EQ(icc_profile.size(), dec_profile_size); + + JxlDecoderDestroy(dec); +} + +// Tests the case where pixels and metadata ICC profile are different +TEST(DecodeTest, IccProfileTestXybEncoded) { + jxl::PaddedBytes icc_profile = GetIccTestProfile(); + bool xyb_encoded = true; + std::vector data = GetIccTestHeader(icc_profile, xyb_encoded); + JxlPixelFormat format = {4, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + JxlPixelFormat format_int = {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Expect the opposite of xyb_encoded for uses_original_profile + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(JXL_FALSE, info.uses_original_profile); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + // the encoded color profile expected to be not available, since the image + // has an ICC profile instead + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + // Check that can get return status with NULL size + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + size_t dec_profile_size; + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, &dec_profile_size)); + + // The profiles must be equal. This requires they have equal size, and if + // they do, we can get the profile and compare the contents. + EXPECT_EQ(icc_profile.size(), dec_profile_size); + if (icc_profile.size() == dec_profile_size) { + jxl::PaddedBytes icc_profile2(icc_profile.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + icc_profile2.data(), icc_profile2.size())); + EXPECT_EQ(icc_profile, icc_profile2); + } + + // Data is xyb_encoded, so the data profile is a different profile, encoded + // as structured profile. + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, nullptr)); + JxlColorEncoding pixel_encoding; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, &format, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_PRIMARIES_SRGB, pixel_encoding.primaries); + // The API returns LINEAR because the colorspace cannot be represented by enum + // values. + EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function); + + // Test the same but with integer format. 
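+  // The encoded color profile describes the image, not the output buffer, so
+  // requesting it for a UINT8 pixel format must give the same primaries and
+  // transfer function.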
+  EXPECT_EQ(
+      JXL_DEC_SUCCESS,
+      JxlDecoderGetColorAsEncodedProfile(
+          dec, &format_int, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding));
+  EXPECT_EQ(JXL_PRIMARIES_SRGB, pixel_encoding.primaries);
+  EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function);
+
+  // The decoder can also output this as a generated ICC profile anyway, and
+  // we're certain that it will differ from the above defined profile since
+  // the sRGB data should not have swapped R/G/B primaries.
+
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetICCProfileSize(
+                                 dec, &format, JXL_COLOR_PROFILE_TARGET_DATA,
+                                 &dec_profile_size));
+  // We don't need to dictate exactly what size the generated ICC profile
+  // must be (since there are many ways to represent the same color space),
+  // but it should not be zero.
+  EXPECT_NE(0, dec_profile_size);
+  if (0 != dec_profile_size) {
+    jxl::PaddedBytes icc_profile2(dec_profile_size);
+    EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile(
+                                   dec, &format, JXL_COLOR_PROFILE_TARGET_DATA,
+                                   icc_profile2.data(), icc_profile2.size()));
+    // The profiles are expected to differ.
+    EXPECT_NE(icc_profile, icc_profile2);
+  }
+
+  JxlDecoderDestroy(dec);
+}
+
+// Test decoding ICC from partial files byte for byte.
+// This test must also pass if JXL_CRASH_ON_ERROR is enabled, that is, the
+// decoding of the ANS histogram and stream of the encoded ICC profile must
+// handle the case of not enough input bytes with StatusCode::kNotEnoughBytes
+// rather than fatal error status codes.
+TEST(DecodeTest, ICCPartialTest) {
+  jxl::PaddedBytes icc_profile = GetIccTestProfile();
+  std::vector<uint8_t> data = GetIccTestHeader(icc_profile, false);
+  JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0};
+
+  const uint8_t* next_in = data.data();
+  size_t avail_in = 0;
+
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(
+                dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING));
+
+  bool seen_basic_info = false;
+  bool seen_color_encoding = false;
+  size_t total_size = 0;
+
+  for (;;) {
+    EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in));
+    JxlDecoderStatus status = JxlDecoderProcessInput(dec);
+    size_t remaining = JxlDecoderReleaseInput(dec);
+    EXPECT_LE(remaining, avail_in);
+    next_in += avail_in - remaining;
+    avail_in = remaining;
+    if (status == JXL_DEC_NEED_MORE_INPUT) {
+      if (total_size >= data.size()) {
+        // The end of the partial codestream with codestream headers and ICC
+        // profile was reached; it should not require more input since the
+        // full image is not requested.
+        FAIL();
+        break;
+      }
+      size_t increment = 1;
+      if (total_size + increment > data.size()) {
+        increment = data.size() - total_size;
+      }
+      total_size += increment;
+      avail_in += increment;
+    } else if (status == JXL_DEC_BASIC_INFO) {
+      EXPECT_FALSE(seen_basic_info);
+      seen_basic_info = true;
+    } else if (status == JXL_DEC_COLOR_ENCODING) {
+      EXPECT_TRUE(seen_basic_info);
+      EXPECT_FALSE(seen_color_encoding);
+      seen_color_encoding = true;
+
+      // Sanity check that the ICC profile was decoded correctly.
+      size_t dec_profile_size;
+      EXPECT_EQ(JXL_DEC_SUCCESS,
+                JxlDecoderGetICCProfileSize(dec, &format,
+                                            JXL_COLOR_PROFILE_TARGET_ORIGINAL,
+                                            &dec_profile_size));
+      EXPECT_EQ(icc_profile.size(), dec_profile_size);
+
+    } else if (status == JXL_DEC_SUCCESS) {
+      EXPECT_TRUE(seen_color_encoding);
+      break;
+    } else {
+      // We do not expect any other events or errors.
+      FAIL();
+      break;
+    }
+  }
+
+  EXPECT_TRUE(seen_basic_info);
+  EXPECT_TRUE(seen_color_encoding);
+
+  JxlDecoderDestroy(dec);
+}
+
+TEST(DecodeTest, PixelTest) {
+  JxlDecoder* dec = JxlDecoderCreate(NULL);
+
+  for (int include_alpha = 0; include_alpha <= 1; include_alpha++) {
+    uint32_t orig_channels = include_alpha ? 4 : 3;
+    for (size_t box = 0; box < kCSBF_NUM_ENTRIES; ++box) {
+      CodeStreamBoxFormat add_container = (CodeStreamBoxFormat)box;
+      size_t xsize = 123, ysize = 77;
+      size_t num_pixels = xsize * ysize;
+      std::vector<uint8_t> pixels =
+          jxl::test::GetSomeTestImage(xsize, ysize, orig_channels, 0);
+      JxlPixelFormat format_orig = {orig_channels, JXL_TYPE_UINT16,
+                                    JXL_BIG_ENDIAN, 0};
+      jxl::CompressParams cparams;
+      // Lossless to verify pixels exactly after roundtrip.
+      cparams.SetLossless();
+      // For variation: some have container and no preview, others have
+      // preview and no container.
+      jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream(
+          jxl::Span<const uint8_t>(pixels.data(), pixels.size()), xsize, ysize,
+          orig_channels, cparams, add_container, false);
+      jxl::PaddedBytes compressed_with_preview = jxl::CreateTestJXLCodestream(
+          jxl::Span<const uint8_t>(pixels.data(), pixels.size()), xsize, ysize,
+          orig_channels, cparams, add_container, true);
+
+      const JxlEndianness endiannesses[] = {JXL_NATIVE_ENDIAN,
+                                            JXL_LITTLE_ENDIAN, JXL_BIG_ENDIAN};
+      for (JxlEndianness endianness : endiannesses) {
+        for (uint32_t channels = 3; channels <= orig_channels; ++channels) {
+          {
+            JxlPixelFormat format = {channels, JXL_TYPE_UINT8, endianness, 0};
+
+            std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI(
+                dec,
+                jxl::Span<const uint8_t>(compressed.data(), compressed.size()),
+                format);
+            JxlDecoderReset(dec);
+            EXPECT_EQ(num_pixels * channels, pixels2.size());
+            EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize,
+                                       ysize, format_orig, format));
+          }
+          {
+            JxlPixelFormat format = {channels, JXL_TYPE_UINT16, endianness, 0};
+
+            // Test with the container for one of the pixel formats.
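+            // (This variant uses compressed_with_preview, so the decoder also
+            // has to skip over the preview frame before emitting the full
+            // image.)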
+ std::vector pixels2 = jxl::DecodeWithAPI( + dec, + jxl::Span(compressed_with_preview.data(), + compressed_with_preview.size()), + format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 2, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + +#if 0 // Disabled since external_image doesn't currently support uint32_t + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT32, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI(dec, + jxl::Span(compressed.data(), + compressed.size()), format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), + xsize, ysize, format_orig, format)); + } +#endif + + { + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, + jxl::Span(compressed.data(), compressed.size()), + format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + + { + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT16, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, + jxl::Span(compressed.data(), compressed.size()), + format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 2, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + } + } + } + } + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, PixelTestWithICCProfileLossless) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::CompressParams cparams; + // Lossless to verify pixels exactly after roundtrip. + cparams.SetLossless(); + // For variation: some have container and no preview, others have preview + // and no container. + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + cparams, kCSBF_None, false, true); + + for (uint32_t channels = 3; channels <= 4; ++channels) { + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}; + + // Test with the container for one of the pixel formats. 
+ std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 2, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } + + { + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } + } + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, PixelTestWithICCProfileLossy) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::CompressParams cparams; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + cparams, kCSBF_None, /*add_preview=*/false, /*add_icc_profile=*/true); + uint32_t channels = 3; + + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + dec, jxl::Span(compressed.data(), compressed.size()), + format); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + + // The input pixels use the profile matching GetIccTestProfile, since we set + // add_icc_profile for CreateTestJXLCodestream to true. + jxl::ColorEncoding color_encoding0; + EXPECT_TRUE(color_encoding0.SetICC(GetIccTestProfile())); + jxl::Span span0(pixels.data(), pixels.size()); + jxl::CodecInOut io0; + io0.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal( + span0, xsize, ysize, color_encoding0, + /*has_alpha=*/false, false, 16, format_orig.endianness, + /*flipped_y=*/false, /*pool=*/nullptr, &io0.Main())); + + // The output pixels are expected to be in the same colorspace as the input + // profile, as the profile can be represented by enum values. 
+  jxl::ColorEncoding color_encoding1 = color_encoding0;
+  jxl::Span<const uint8_t> span1(pixels2.data(), pixels2.size());
+  jxl::CodecInOut io1;
+  io1.SetSize(xsize, ysize);
+  EXPECT_TRUE(
+      ConvertFromExternal(span1, xsize, ysize, color_encoding1,
+                          /*has_alpha=*/false, false, 32, format.endianness,
+                          /*flipped_y=*/false, /*pool=*/nullptr, &io1.Main()));
+
+  jxl::ButteraugliParams ba;
+  EXPECT_LE(ButteraugliDistance(io0, io1, ba, /*distmap=*/nullptr, nullptr),
+            2.0f);
+
+  JxlDecoderDestroy(dec);
+}
+
+// Tests the case of a lossy sRGB image without alpha channel, decoded to RGB8
+// and to RGBA8.
+TEST(DecodeTest, PixelTestOpaqueSrgbLossy) {
+  for (unsigned channels = 3; channels <= 4; channels++) {
+    JxlDecoder* dec = JxlDecoderCreate(NULL);
+
+    size_t xsize = 123, ysize = 77;
+    size_t num_pixels = xsize * ysize;
+    std::vector<uint8_t> pixels =
+        jxl::test::GetSomeTestImage(xsize, ysize, 3, 0);
+    JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+    jxl::CompressParams cparams;
+    jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream(
+        jxl::Span<const uint8_t>(pixels.data(), pixels.size()), xsize, ysize,
+        3, cparams, kCSBF_None, /*add_preview=*/false,
+        /*add_icc_profile=*/false);
+
+    JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0};
+
+    std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI(
+        dec, jxl::Span<const uint8_t>(compressed.data(), compressed.size()),
+        format);
+    JxlDecoderReset(dec);
+    EXPECT_EQ(num_pixels * channels, pixels2.size());
+
+    // add_icc_profile is false here, so the input pixels are in the default
+    // sRGB color encoding; compare both sides as sRGB.
+    jxl::ColorEncoding color_encoding0 = jxl::ColorEncoding::SRGB(false);
+    jxl::Span<const uint8_t> span0(pixels.data(), pixels.size());
+    jxl::CodecInOut io0;
+    io0.SetSize(xsize, ysize);
+    EXPECT_TRUE(ConvertFromExternal(
+        span0, xsize, ysize, color_encoding0,
+        /*has_alpha=*/false, false, 16, format_orig.endianness,
+        /*flipped_y=*/false, /*pool=*/nullptr, &io0.Main()));
+
+    jxl::ColorEncoding color_encoding1 = jxl::ColorEncoding::SRGB(false);
+    jxl::Span<const uint8_t> span1(pixels2.data(), pixels2.size());
+    jxl::CodecInOut io1;
+    if (channels == 4) {
+      io1.metadata.m.SetAlphaBits(8);
+      io1.SetSize(xsize, ysize);
+      EXPECT_TRUE(ConvertFromExternal(
+          span1, xsize, ysize, color_encoding1,
+          /*has_alpha=*/true, false, 8, format.endianness,
+          /*flipped_y=*/false, /*pool=*/nullptr, &io1.Main()));
+      io1.metadata.m.SetAlphaBits(0);
+      io1.Main().ClearExtraChannels();
+    } else {
+      EXPECT_TRUE(ConvertFromExternal(
+          span1, xsize, ysize, color_encoding1,
+          /*has_alpha=*/false, false, 8, format.endianness,
+          /*flipped_y=*/false, /*pool=*/nullptr, &io1.Main()));
+    }
+
+    jxl::ButteraugliParams ba;
+    EXPECT_LE(ButteraugliDistance(io0, io1, ba, /*distmap=*/nullptr, nullptr),
+              2.0f);
+
+    JxlDecoderDestroy(dec);
+  }
+}
+
+// Opaque image with noise enabled, decoded to RGB8 and RGBA8.
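+// The noise is synthesized by the decoder (cparams.noise = kOn at encode
+// time), so the butteraugli threshold below is the same as in the noiseless
+// lossy test above.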
+TEST(DecodeTest, PixelTestOpaqueSrgbLossyNoise) {
+  for (unsigned channels = 3; channels <= 4; channels++) {
+    JxlDecoder* dec = JxlDecoderCreate(NULL);
+
+    size_t xsize = 512, ysize = 300;
+    size_t num_pixels = xsize * ysize;
+    std::vector<uint8_t> pixels =
+        jxl::test::GetSomeTestImage(xsize, ysize, 3, 0);
+    JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+    jxl::CompressParams cparams;
+    cparams.noise = jxl::Override::kOn;
+    jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream(
+        jxl::Span<const uint8_t>(pixels.data(), pixels.size()), xsize, ysize,
+        3, cparams, kCSBF_None, /*add_preview=*/false,
+        /*add_icc_profile=*/false);
+
+    JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0};
+
+    std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI(
+        dec, jxl::Span<const uint8_t>(compressed.data(), compressed.size()),
+        format);
+    JxlDecoderReset(dec);
+    EXPECT_EQ(num_pixels * channels, pixels2.size());
+
+    // add_icc_profile is false here, so the input pixels are in the default
+    // sRGB color encoding; compare both sides as sRGB.
+    jxl::ColorEncoding color_encoding0 = jxl::ColorEncoding::SRGB(false);
+    jxl::Span<const uint8_t> span0(pixels.data(), pixels.size());
+    jxl::CodecInOut io0;
+    io0.SetSize(xsize, ysize);
+    EXPECT_TRUE(ConvertFromExternal(
+        span0, xsize, ysize, color_encoding0,
+        /*has_alpha=*/false, false, 16, format_orig.endianness,
+        /*flipped_y=*/false, /*pool=*/nullptr, &io0.Main()));
+
+    jxl::ColorEncoding color_encoding1 = jxl::ColorEncoding::SRGB(false);
+    jxl::Span<const uint8_t> span1(pixels2.data(), pixels2.size());
+    jxl::CodecInOut io1;
+    if (channels == 4) {
+      io1.metadata.m.SetAlphaBits(8);
+      io1.SetSize(xsize, ysize);
+      EXPECT_TRUE(ConvertFromExternal(
+          span1, xsize, ysize, color_encoding1,
+          /*has_alpha=*/true, false, 8, format.endianness,
+          /*flipped_y=*/false, /*pool=*/nullptr, &io1.Main()));
+      io1.metadata.m.SetAlphaBits(0);
+      io1.Main().ClearExtraChannels();
+    } else {
+      EXPECT_TRUE(ConvertFromExternal(
+          span1, xsize, ysize, color_encoding1,
+          /*has_alpha=*/false, false, 8, format.endianness,
+          /*flipped_y=*/false, /*pool=*/nullptr, &io1.Main()));
+    }
+
+    jxl::ButteraugliParams ba;
+    EXPECT_LE(ButteraugliDistance(io0, io1, ba, /*distmap=*/nullptr, nullptr),
+              2.0f);
+
+    JxlDecoderDestroy(dec);
+  }
+}
+
+TEST(DecodeTest, GrayscaleTest) {
+  size_t xsize = 123, ysize = 77;
+  size_t num_pixels = xsize * ysize;
+  std::vector<uint8_t> pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 2, 0);
+  JxlPixelFormat format_orig = {2, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+
+  jxl::CompressParams cparams;
+  cparams.SetLossless();  // Lossless to verify pixels exactly after roundtrip.
+ jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 2, + cparams, kCSBF_None, true); + + const JxlEndianness endiannesses[] = {JXL_NATIVE_ENDIAN, JXL_LITTLE_ENDIAN, + JXL_BIG_ENDIAN}; + for (JxlEndianness endianness : endiannesses) { + // The compressed image is grayscale, but the output can be tested with + // up to 4 channels (RGBA) + for (uint32_t channels = 1; channels <= 4; ++channels) { + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + jxl::Span(compressed.data(), compressed.size()), + format); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } + + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT16, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + jxl::Span(compressed.data(), compressed.size()), + format); + EXPECT_EQ(num_pixels * channels * 2, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } + +#if 0 // Disabled since external_image doesn't currently support uint32_t + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT32, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + jxl::Span(compressed.data(), + compressed.size()), format); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } +#endif + + { + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + jxl::Span(compressed.data(), compressed.size()), + format); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } + + { + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT16, endianness, 0}; + + std::vector pixels2 = jxl::DecodeWithAPI( + jxl::Span(compressed.data(), compressed.size()), + format); + EXPECT_EQ(num_pixels * channels * 2, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); + } + } + } +} + +void TestPartialStream(bool reconstructible_jpeg) { + size_t xsize = 123, ysize = 77; + uint32_t channels = 4; + if (reconstructible_jpeg) { + channels = 3; + } + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, 0); + JxlPixelFormat format_orig = {channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::CompressParams cparams; + if (reconstructible_jpeg) { + cparams.color_transform = jxl::ColorTransform::kNone; + } else { + cparams + .SetLossless(); // Lossless to verify pixels exactly after roundtrip. + } + + std::vector pixels2; + pixels2.resize(pixels.size()); + + jxl::PaddedBytes jpeg_output(64); + size_t used_jpeg_output = 0; + + std::vector codestreams(kCSBF_NUM_ENTRIES); + std::vector jpeg_codestreams(kCSBF_NUM_ENTRIES); + for (size_t i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + CodeStreamBoxFormat add_container = (CodeStreamBoxFormat)i; + + codestreams[i] = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + channels, cparams, add_container, /*add_preview=*/true, + /*add_icc_profile=*/false, + reconstructible_jpeg ? &jpeg_codestreams[i] : nullptr); + } + + // Test multiple step sizes, to test different combinations of the streaming + // box parsing. 
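+  // Small steps exercise partially received box headers; the larger steps
+  // deliver one or several whole boxes per JxlDecoderSetInput call.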
+ std::vector increments = {1, 3, 17, 23, 120, 700, 1050}; + + for (size_t index = 0; index < increments.size(); index++) { + for (size_t i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + if (reconstructible_jpeg && + (CodeStreamBoxFormat)i == CodeStreamBoxFormat::kCSBF_None) { + continue; + } + const jxl::PaddedBytes& data = codestreams[i]; + const uint8_t* next_in = data.data(); + size_t avail_in = 0; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | + JXL_DEC_JPEG_RECONSTRUCTION)); + + bool seen_basic_info = false; + bool seen_full_image = false; + bool seen_jpeg_recon = false; + + size_t total_size = 0; + + for (;;) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_size >= data.size()) { + // End of test data reached, it should have successfully decoded the + // image now. + FAIL(); + break; + } + + size_t increment = increments[index]; + // End of the file reached, should be the final test. + if (total_size + increment > data.size()) { + increment = data.size() - total_size; + } + total_size += increment; + avail_in += increment; + } else if (status == JXL_DEC_BASIC_INFO) { + // This event should happen exactly once + EXPECT_FALSE(seen_basic_info); + if (seen_basic_info) break; + seen_basic_info = true; + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + } else if (status == JXL_DEC_JPEG_RECONSTRUCTION) { + EXPECT_FALSE(seen_basic_info); + EXPECT_FALSE(seen_full_image); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec, jpeg_output.data(), + jpeg_output.size())); + seen_jpeg_recon = true; + } else if (status == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + EXPECT_TRUE(seen_jpeg_recon); + used_jpeg_output = + jpeg_output.size() - JxlDecoderReleaseJPEGBuffer(dec); + jpeg_output.resize(jpeg_output.size() * 2); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer( + dec, jpeg_output.data() + used_jpeg_output, + jpeg_output.size() - used_jpeg_output)); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer( + dec, &format_orig, pixels2.data(), pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + // This event should happen exactly once + EXPECT_FALSE(seen_full_image); + if (seen_full_image) break; + // This event should happen after basic info + EXPECT_TRUE(seen_basic_info); + seen_full_image = true; + if (reconstructible_jpeg) { + used_jpeg_output = + jpeg_output.size() - JxlDecoderReleaseJPEGBuffer(dec); + EXPECT_EQ(used_jpeg_output, jpeg_codestreams[i].size()); + EXPECT_EQ(0, memcmp(jpeg_output.data(), jpeg_codestreams[i].data(), + used_jpeg_output)); + } else { + EXPECT_EQ(pixels, pixels2); + } + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_TRUE(seen_full_image); + break; + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + // Ensure the decoder emitted the basic info and full image events + EXPECT_TRUE(seen_basic_info); + EXPECT_TRUE(seen_full_image); + + JxlDecoderDestroy(dec); + } + } +} + +// Tests the return status when trying to decode pixels on incomplete file: it +// should return 
JXL_DEC_NEED_MORE_INPUT, not an error.
+TEST(DecodeTest, PixelPartialTest) { TestPartialStream(false); }
+
+#if JPEGXL_ENABLE_JPEG
+// Tests the return status when trying to decode JPEG bytes on an incomplete
+// file.
+TEST(DecodeTest, JPEGPartialTest) { TestPartialStream(true); }
+#endif  // JPEGXL_ENABLE_JPEG
+
+TEST(DecodeTest, DCTest) {
+  using jxl::kBlockDim;
+
+  // TODO(lode): test with a completely black image, with alpha channel
+  // 65536, since that gave an error during debugging for getting the DC
+  // image (namely: "Failed to decode AC metadata").
+
+  // Ensure one dimension is larger than 256 so that there are multiple
+  // groups; otherwise getting the DC does not work due to how the TOC is
+  // then laid out.
+  size_t xsize = 260, ysize = 77;
+  std::vector<uint8_t> pixels =
+      jxl::test::GetSomeTestImage(xsize, ysize, 3, 0);
+
+  // Set the params to lossy, since getting the DC with the API is only
+  // supported for lossy at this time.
+  jxl::CompressParams cparams;
+  jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream(
+      jxl::Span<const uint8_t>(pixels.data(), pixels.size()), xsize, ysize, 3,
+      cparams, kCSBF_Multi, true);
+
+  JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0};
+
+  // Binary search for the DC size: the first byte count where trying to get
+  // the DC returns JXL_DEC_NEED_DC_OUT_BUFFER rather than
+  // JXL_DEC_NEED_MORE_INPUT. This search is a test on its own, verifying that
+  // the decoder succeeds after this point and needs more input before it,
+  // without errors. It also allows the main test below to work on a partial
+  // file with only DC.
+  size_t start = 0;
+  size_t end = compressed.size();
+  size_t dc_size;
+  for (;;) {
+    dc_size = (start + end) / 2;
+    JxlDecoderStatus status;
+    JxlDecoder* dec = JxlDecoderCreate(NULL);
+    EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(
+                                   dec, JXL_DEC_BASIC_INFO | JXL_DEC_DC_IMAGE));
+    EXPECT_EQ(JXL_DEC_SUCCESS,
+              JxlDecoderSetInput(dec, compressed.data(), dc_size));
+    status = JxlDecoderProcessInput(dec);
+    EXPECT_TRUE(status == JXL_DEC_BASIC_INFO ||
+                status == JXL_DEC_NEED_MORE_INPUT);
+    if (status != JXL_DEC_NEED_MORE_INPUT) {
+      status = JxlDecoderProcessInput(dec);
+      EXPECT_TRUE(status == JXL_DEC_NEED_DC_OUT_BUFFER ||
+                  status == JXL_DEC_NEED_MORE_INPUT);
+    }
+    JxlDecoderDestroy(dec);
+    if (status == JXL_DEC_NEED_MORE_INPUT) {
+      start = dc_size;
+      if (start == end || start + 1 == end) {
+        dc_size++;
+        break;
+      }
+    } else {
+      end = dc_size;
+      if (start == end || start + 1 == end) {
+        break;
+      }
+    }
+  }
+
+  // Test that dc_size is within expected limits: larger than 0 and smaller
+  // than the entire file; 90% is used here since 50% would be too
+  // optimistic.
+ EXPECT_LE(dc_size, compressed.size() * 9 / 10); + EXPECT_GT(dc_size, 0); + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_DC_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), dc_size)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderDCOutBufferSize(dec, &format, &buffer_size)); + + size_t xsize_dc = (xsize + kBlockDim - 1) / kBlockDim; + size_t ysize_dc = (ysize + kBlockDim - 1) / kBlockDim; + EXPECT_EQ(xsize_dc * ysize_dc * 3, buffer_size); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + EXPECT_EQ(JXL_DEC_NEED_DC_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + std::vector dc(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetDCOutBuffer(dec, &format, dc.data(), dc.size())); + + EXPECT_EQ(JXL_DEC_DC_IMAGE, JxlDecoderProcessInput(dec)); + + jxl::Image3F dc0(xsize_dc, ysize_dc); + jxl::Image3F dc1(xsize_dc, ysize_dc); + + // Downscale the original image 8x8 to allow comparing with the DC. + for (size_t y = 0; y < ysize_dc; y++) { + for (size_t x = 0; x < xsize_dc; x++) { + double r = 0, g = 0, b = 0; + size_t num = 0; + for (size_t by = 0; by < kBlockDim; by++) { + size_t y2 = y * kBlockDim + by; + if (y2 >= ysize) break; + for (size_t bx = 0; bx < kBlockDim; bx++) { + size_t x2 = x * kBlockDim + bx; + if (x2 >= xsize) break; + // Use linear RGB for correct downscaling. + r += jxl::Srgb8ToLinearDirect((1.f / 255) * + pixels[(y2 * xsize + x2) * 6 + 0]); + g += jxl::Srgb8ToLinearDirect((1.f / 255) * + pixels[(y2 * xsize + x2) * 6 + 2]); + b += jxl::Srgb8ToLinearDirect((1.f / 255) * + pixels[(y2 * xsize + x2) * 6 + 4]); + num++; + } + } + // Take average per block. + double mul = 1.0 / num; + r *= mul; + g *= mul; + b *= mul; + dc0.PlaneRow(0, y)[x] = r; + dc0.PlaneRow(1, y)[x] = g; + dc0.PlaneRow(2, y)[x] = b; + dc1.PlaneRow(0, y)[x] = (1.f / 255) * (dc[(y * xsize_dc + x) * 3 + 0]); + dc1.PlaneRow(1, y)[x] = (1.f / 255) * (dc[(y * xsize_dc + x) * 3 + 1]); + dc1.PlaneRow(2, y)[x] = (1.f / 255) * (dc[(y * xsize_dc + x) * 3 + 2]); + } + } + + // dc0 is in linear sRGB because we converted it to linear in the downscaling + // above. + jxl::CodecInOut dc0_io; + dc0_io.SetFromImage(std::move(dc0), jxl::ColorEncoding::LinearSRGB(false)); + // dc1 is in non-linear sRGB because the C decoding API outputs non-linear + // sRGB for VarDCT to integer output types + jxl::CodecInOut dc1_io; + dc1_io.SetFromImage(std::move(dc1), jxl::ColorEncoding::SRGB(false)); + + // Check with butteraugli that the DC is close to the 8x8 downscaled original + // image. We don't expect a score of 0, since the downscaling done may not + // 100% match what is stored for the DC, and the lossy codec is used. + // A reasonable butteraugli distance shows that the DC works, the color + // encoding (transfer function) is correct and geometry (shifts, ...) is + // correct. 
+ jxl::ButteraugliParams ba; + EXPECT_LE(ButteraugliDistance(dc0_io, dc1_io, ba, + /*distmap=*/nullptr, nullptr), + 3.0f); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, DCNotGettableTest) { + // 1x1 pixel JXL image + std::string compressed( + "\377\n\0\20\260\23\0H\200(" + "\0\334\0U\17\0\0\250P\31e\334\340\345\\\317\227\37:," + "\246m\\gh\253m\vK\22E\306\261I\252C&pH\22\353 " + "\363\6\22\bp\0\200\237\34\231W2d\255$\1", + 68); + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_DC_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput( + dec, reinterpret_cast(compressed.data()), + compressed.size())); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Since the image is only 1x1 pixel, there is only 1 group, the decoder is + // unable to get DC size from this, and will not return the DC at all. Since + // no full image is requested either, it is expected to return success. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, PreviewTest) { + size_t xsize = 77, ysize = 120; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + + jxl::CompressParams cparams; + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 3, + cparams, kCSBF_Multi, /*add_preview=*/true); + + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_PREVIEW_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + + // GetSomeTestImage is hardcoded to use a top-left cropped preview with + // floor of 1/7th of the size + size_t xsize_preview = (xsize / 7); + size_t ysize_preview = (ysize / 7); + EXPECT_EQ(xsize_preview, info.preview.xsize); + EXPECT_EQ(ysize_preview, info.preview.ysize); + EXPECT_EQ(xsize_preview * ysize_preview * 3, buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_PREVIEW_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + std::vector preview(xsize_preview * ysize_preview * 3); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetPreviewOutBuffer( + dec, &format, preview.data(), preview.size())); + + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, JxlDecoderProcessInput(dec)); + + jxl::Image3F preview0(xsize_preview, ysize_preview); + jxl::Image3F preview1(xsize_preview, ysize_preview); + + // For preview0, the original: top-left crop the preview image the way + // GetSomeTestImage does. 
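+  // (pixels holds 3 x UINT16 big-endian samples, i.e. 6 bytes per pixel, so
+  // the offsets 0, 2 and 4 below read the high byte of each channel.)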
+ for (size_t y = 0; y < ysize_preview; y++) { + for (size_t x = 0; x < xsize_preview; x++) { + preview0.PlaneRow(0, y)[x] = + (1.f / 255) * (pixels[(y * xsize + x) * 6 + 0]); + preview0.PlaneRow(1, y)[x] = + (1.f / 255) * (pixels[(y * xsize + x) * 6 + 2]); + preview0.PlaneRow(2, y)[x] = + (1.f / 255) * (pixels[(y * xsize + x) * 6 + 4]); + preview1.PlaneRow(0, y)[x] = + (1.f / 255) * (preview[(y * xsize_preview + x) * 3 + 0]); + preview1.PlaneRow(1, y)[x] = + (1.f / 255) * (preview[(y * xsize_preview + x) * 3 + 1]); + preview1.PlaneRow(2, y)[x] = + (1.f / 255) * (preview[(y * xsize_preview + x) * 3 + 2]); + } + } + + jxl::CodecInOut io0; + io0.SetFromImage(std::move(preview0), jxl::ColorEncoding::SRGB(false)); + jxl::CodecInOut io1; + io1.SetFromImage(std::move(preview1), jxl::ColorEncoding::SRGB(false)); + + jxl::ButteraugliParams ba; + // TODO(lode): this ButteraugliDistance silently returns 0 (dangerous for + // tests) if xsize or ysize is < 8, no matter how different the images, a tiny + // size that could happen for a preview. ButteraugliDiffmap does support + // smaller than 8x8, but jxl's ButteraugliDistance does not. Perhaps move + // butteraugli's <8x8 handling from ButteraugliDiffmap to + // ButteraugliComparator::Diffmap in butteraugli.cc. + EXPECT_LE(ButteraugliDistance(io0, io1, ba, + /*distmap=*/nullptr, nullptr), + 0.9f); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, AlignTest) { + size_t xsize = 123, ysize = 77; + std::vector pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + jxl::PaddedBytes compressed = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, 4, + cparams, kCSBF_None, false); + + size_t align = 17; + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, align}; + // On purpose not using jxl::RoundUpTo to test it independently. 
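+  // For example, xsize = 123 gives 369 bytes per row, rounded up to 374 for
+  // align = 17.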
+ size_t expected_line_bytes = (1 * 3 * xsize + align - 1) / align * align; + + std::vector pixels2 = jxl::DecodeWithAPI( + jxl::Span(compressed.data(), compressed.size()), format); + EXPECT_EQ(expected_line_bytes * ysize, pixels2.size()); + EXPECT_EQ(0, ComparePixels(pixels.data(), pixels2.data(), xsize, ysize, + format_orig, format)); +} + +TEST(DecodeTest, AnimationTest) { + size_t xsize = 123, ysize = 77; + static const size_t num_frames = 2; + std::vector frames[2]; + frames[0] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + frames[1] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 1); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frames[i].data(), frames[i].size()), xsize, + ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*has_alpha=*/false, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*flipped_y=*/false, /*pool=*/nullptr, &bundle)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, &aux_out, + nullptr)); + + // Decode and test the animation frames + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + std::vector pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + EXPECT_EQ(0, frame_header.name_length); + // For now, test with empty name, there's currently no easy way to encode + // a jxl file with a frame name because ImageBundle doesn't have a + // jxl::FrameHeader to set the name in. We can test the null termination + // character though. 
+ char name; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameName(dec, &name, 1)); + EXPECT_EQ(0, name); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0, ComparePixels(frames[i].data(), pixels.data(), xsize, ysize, + format, format)); + } + + // After all frames gotten, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, AnimationTestStreaming) { + size_t xsize = 123, ysize = 77; + static const size_t num_frames = 2; + std::vector frames[2]; + frames[0] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + frames[1] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 1); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + + EXPECT_TRUE(ConvertFromExternal( + jxl::Span(frames[i].data(), frames[i].size()), xsize, + ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), /*has_alpha=*/false, + /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, + JXL_BIG_ENDIAN, /*flipped_y=*/false, /*pool=*/nullptr, &bundle)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. 
+ jxl::AuxOut aux_out; + jxl::PaddedBytes compressed; + jxl::PassesEncoderState enc_state; + EXPECT_TRUE(jxl::EncodeFile(cparams, &io, &enc_state, &compressed, &aux_out, + nullptr)); + + // Decode and test the animation frames + + const size_t step_size = 16; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = 0; + size_t frame_headers_seen = 0; + size_t frames_seen = 0; + bool seen_basic_info = false; + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + std::vector frames2[2]; + for (size_t i = 0; i < num_frames; ++i) { + frames2[i].resize(frames[i].size()); + } + + size_t total_in = 0; + size_t loop_count = 0; + + for (;;) { + if (loop_count++ > compressed.size()) { + fprintf(stderr, "Too many loops\n"); + FAIL(); + break; + } + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + auto status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + + if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_in >= compressed.size()) { + fprintf(stderr, "Already gave all input data\n"); + FAIL(); + break; + } + size_t amount = step_size; + if (total_in + amount > compressed.size()) { + amount = compressed.size() - total_in; + } + avail_in += amount; + total_in += amount; + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, frames2[frames_seen].data(), + frames2[frames_seen].size())); + } else if (status == JXL_DEC_BASIC_INFO) { + EXPECT_EQ(false, seen_basic_info); + seen_basic_info = true; + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + } else if (status == JXL_DEC_FRAME) { + EXPECT_EQ(true, seen_basic_info); + frame_headers_seen++; + } else if (status == JXL_DEC_FULL_IMAGE) { + frames_seen++; + EXPECT_EQ(frame_headers_seen, frames_seen); + } else { + fprintf(stderr, "Unexpected status: %d\n", (int)status); + FAIL(); + } + } + + EXPECT_EQ(true, seen_basic_info); + EXPECT_EQ(num_frames, frames_seen); + EXPECT_EQ(num_frames, frame_headers_seen); + for (size_t i = 0; i < num_frames; ++i) { + EXPECT_EQ(frames[i], frames2[i]); + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, FlushTest) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::CompressParams cparams; + jxl::PaddedBytes data = jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + num_channels, cparams, kCSBF_None, true); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | 
JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. The DC takes up more than 50% of the + // image generated here. + size_t first_part = data.size() * 3 / 4; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + // Note: actual pixel data not tested here, it should look similar to the + // input image, but with less fine detail. Instead the expected events are + // tested here. + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +void VerifyJPEGReconstruction(const jxl::PaddedBytes& container, + const jxl::PaddedBytes& jpeg_bytes) { + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec.get(), JXL_DEC_JPEG_RECONSTRUCTION | JXL_DEC_FULL_IMAGE)); + JxlDecoderSetInput(dec.get(), container.data(), container.size()); + EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec.get())); + std::vector reconstructed_buffer(128); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data(), + reconstructed_buffer.size())); + size_t used = 0; + JxlDecoderStatus process_result = JXL_DEC_JPEG_NEED_MORE_OUTPUT; + while (process_result == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + reconstructed_buffer.resize(reconstructed_buffer.size() * 2); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data() + used, + reconstructed_buffer.size() - used)); + process_result = JxlDecoderProcessInput(dec.get()); + } + ASSERT_EQ(JXL_DEC_FULL_IMAGE, process_result); + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + ASSERT_EQ(used, jpeg_bytes.size()); + EXPECT_EQ(0, memcmp(reconstructed_buffer.data(), jpeg_bytes.data(), used)); +} + +#if JPEGXL_ENABLE_JPEG +TEST(DecodeTest, JPEGReconstructTestCodestream) { + size_t xsize = 123; + size_t ysize = 77; + size_t channels = 3; + std::vector pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, /*seed=*/0); + jxl::CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kNone; + jxl::PaddedBytes jpeg_codestream; + jxl::PaddedBytes compressed = 
jxl::CreateTestJXLCodestream( + jxl::Span(pixels.data(), pixels.size()), xsize, ysize, + channels, cparams, kCSBF_Single, /*add_preview=*/true, + /*add_icc_profile=*/false, &jpeg_codestream); + VerifyJPEGReconstruction(compressed, jpeg_codestream); +} +#endif // JPEGXL_ENABLE_JPEG + +TEST(DecodeTest, JPEGReconstructionTest) { + const std::string jpeg_path = + "imagecompression.info/flower_foveon.png.im_q85_420.jpg"; + const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path); + jxl::CodecInOut orig_io; + ASSERT_TRUE( + jxl::jpeg::DecodeImageJPG(jxl::Span(orig), &orig_io)); + orig_io.metadata.m.xyb_encoded = false; + jxl::BitWriter writer; + ASSERT_TRUE(WriteHeaders(&orig_io.metadata, &writer, nullptr)); + writer.ZeroPadToByte(); + jxl::PassesEncoderState enc_state; + jxl::CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kNone; + ASSERT_TRUE(jxl::EncodeFrame(cparams, jxl::FrameInfo{}, &orig_io.metadata, + orig_io.Main(), &enc_state, + /*pool=*/nullptr, &writer, + /*aux_out=*/nullptr)); + + jxl::PaddedBytes jpeg_data; + ASSERT_TRUE(EncodeJPEGData(*orig_io.Main().jpeg_data.get(), &jpeg_data)); + jxl::PaddedBytes container; + container.append(jxl::kContainerHeader, + jxl::kContainerHeader + sizeof(jxl::kContainerHeader)); + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, + &container); + container.append(jpeg_data.data(), jpeg_data.data() + jpeg_data.size()); + jxl::AppendBoxHeader(jxl::MakeBoxType("jxlc"), 0, true, &container); + jxl::PaddedBytes codestream = std::move(writer).TakeBytes(); + container.append(codestream.data(), codestream.data() + codestream.size()); + VerifyJPEGReconstruction(container, orig); +} diff --git a/third_party/jpeg-xl/lib/jxl/descriptive_statistics_test.cc b/third_party/jpeg-xl/lib/jxl/descriptive_statistics_test.cc new file mode 100644 index 000000000000..d0633f7a5d9f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/descriptive_statistics_test.cc @@ -0,0 +1,161 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/base/descriptive_statistics.h" + +#include +#include + +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/noise_distributions.h" + +namespace jxl { +namespace { + +// Assigns x to one of two streams so we can later test Assimilate. 
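+// (Stats::Assimilate merges another accumulator's moments; splitting the
+// samples randomly lets the tests check that the merged result matches the
+// single-stream accumulator.)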
+template <typename Random>
+void NotifyEither(float x, Random* rng, Stats* JXL_RESTRICT stats1,
+                  Stats* JXL_RESTRICT stats2) {
+  if ((*rng)() & 128) {
+    stats1->Notify(x);
+  } else {
+    stats2->Notify(x);
+  }
+}
+
+TEST(StatsTest, TestGaussian) {
+  Stats stats;
+  Stats stats1, stats2;
+  const float mean = 5.0f;
+  const float stddev = 4.0f;
+  NoiseGaussian noise(stddev);
+  std::mt19937 rng(129);
+  for (size_t i = 0; i < 1000 * 1000; ++i) {
+    const float x = noise(mean, &rng);
+    stats.Notify(x);
+    NotifyEither(x, &rng, &stats1, &stats2);
+  }
+  EXPECT_NEAR(mean, stats.Mean(), 0.01);
+  EXPECT_NEAR(stddev, stats.StandardDeviation(), 0.02);
+  EXPECT_NEAR(0.0, stats.Skewness(), 0.02);
+  EXPECT_NEAR(0.0, stats.Kurtosis() - 3, 0.02);
+  printf("%s\n", stats.ToString().c_str());
+
+  // Same results after merging both accumulators.
+  stats1.Assimilate(stats2);
+  EXPECT_NEAR(mean, stats1.Mean(), 0.01);
+  EXPECT_NEAR(stddev, stats1.StandardDeviation(), 0.02);
+  EXPECT_NEAR(0.0, stats1.Skewness(), 0.02);
+  EXPECT_NEAR(0.0, stats1.Kurtosis() - 3, 0.02);
+}
+
+TEST(StatsTest, TestUniform) {
+  Stats stats;
+  Stats stats1, stats2;
+  NoiseUniform noise(0, 256);
+  std::mt19937 rng(129), rng_split(65537);
+  for (size_t i = 0; i < 1000 * 1000; ++i) {
+    const float x = noise(0.0f, &rng);
+    stats.Notify(x);
+    NotifyEither(x, &rng_split, &stats1, &stats2);
+  }
+  EXPECT_NEAR(128.0, stats.Mean(), 0.05);
+  EXPECT_NEAR(0.0, stats.Min(), 0.01);
+  EXPECT_NEAR(256.0, stats.Max(), 0.01);
+  EXPECT_NEAR(70, stats.StandardDeviation(), 10);
+  // No outliers.
+  EXPECT_NEAR(-1.2, stats.Kurtosis() - 3, 0.1);
+  printf("%s\n", stats.ToString().c_str());
+
+  // Same results after merging both accumulators.
+  stats1.Assimilate(stats2);
+  EXPECT_NEAR(128.0, stats1.Mean(), 0.05);
+  EXPECT_NEAR(0.0, stats1.Min(), 0.01);
+  EXPECT_NEAR(256.0, stats1.Max(), 0.01);
+  EXPECT_NEAR(70, stats1.StandardDeviation(), 10);
+}
+
+TEST(StatsTest, CompareCentralMomentsAgainstTwoPass) {
+  // Vary the seed so the thresholds are not specific to one distribution.
+  for (int rep = 0; rep < 200; ++rep) {
+    // Uniform avoids outliers.
+    NoiseUniform noise(0, 256);
+    std::mt19937 rng(129 + 13 * rep), rng_split(65537);
+
+    // Small count so the bias (population vs. sample) is visible.
+    const size_t kSamples = 20;
+
+    // First pass: compute the mean.
+    std::vector<float> samples;
+    samples.reserve(kSamples);
+    double sum = 0.0;
+    for (size_t i = 0; i < kSamples; ++i) {
+      const float x = noise(0.0f, &rng);
+      samples.push_back(x);
+      sum += x;
+    }
+    const double mean = sum / kSamples;
+
+    // Second pass: compute stats and central moments.
+    Stats stats;
+    Stats stats1, stats2;
+    double sum2 = 0.0;
+    double sum3 = 0.0;
+    double sum4 = 0.0;
+    for (const double x : samples) {
+      const double d = x - mean;
+      sum2 += d * d;
+      sum3 += d * d * d;
+      sum4 += d * d * d * d;
+
+      stats.Notify(x);
+      NotifyEither(x, &rng_split, &stats1, &stats2);
+    }
+    const double mu1 = mean;
+    const double mu2 = sum2 / kSamples;
+    const double mu3 = sum3 / kSamples;
+    const double mu4 = sum4 / kSamples;
+
+    // Central moments (note: the first central moment is zero by definition,
+    // so Mu1 is compared against the mean).
+    EXPECT_NEAR(mu1, stats.Mu1(), 1E-13);
+    EXPECT_NEAR(mu2, stats.Mu2(), 1E-11);
+    EXPECT_NEAR(mu3, stats.Mu3(), 1E-9);
+    EXPECT_NEAR(mu4, stats.Mu4(), 1E-6);
+
+    // Same results after merging both accumulators.
+ stats1.Assimilate(stats2); + EXPECT_NEAR(mu1, stats1.Mu1(), 1E-13); + EXPECT_NEAR(mu2, stats1.Mu2(), 1E-11); + EXPECT_NEAR(mu3, stats1.Mu3(), 1E-9); + EXPECT_NEAR(mu4, stats1.Mu4(), 1E-6); + + const double sample_variance = mu2; + // Scaling factor for sampling bias + const double r = (kSamples - 1.0) / kSamples; + const double skewness = mu3 * pow(r / mu2, 1.5); + const double kurtosis = mu4 * pow(r / mu2, 2.0); + + EXPECT_NEAR(sample_variance, stats.SampleVariance(), + sample_variance * 1E-12); + EXPECT_NEAR(skewness, stats.Skewness(), std::abs(skewness * 1E-11)); + EXPECT_NEAR(kurtosis, stats.Kurtosis(), kurtosis * 1E-12); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/docs/color_management.md b/third_party/jpeg-xl/lib/jxl/docs/color_management.md new file mode 100644 index 000000000000..56f4a2856c88 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/docs/color_management.md @@ -0,0 +1,68 @@ +# Color Management + +[TOC] + + + +## Why + +The vast majority of web images are still sRGB. However, wide-gamut material is +increasingly being produced (photography, cinema, 4K). Screens covering most of +the Adobe RGB gamut are readily available and some also cover most of DCI P3 +(iPhone, Pixel2) or even BT.2020. + +Currently, after a camera records a very saturated red pixel, most raw +processors would clip it to the rather small sRGB gamut before saving as JPEG. +In keeping with our high-quality goal, we prevent such loss by allowing wider +input color spaces. + +## Which color space + +Even wide gamuts could be expressed relative to the sRGB primaries, but the +resulting coordinates may be outside the valid 0..1 range. Surprisingly, such +'unbounded' coordinates can be passed through color transforms provided the +transfer functions are expressed as parametric functions (not lookup tables). +However, most image file formats (including PNG and PNM) lack min/max metadata +and thus do not support unbounded coordinates. + +Instead, we need a larger working gamut to ensure most pixel coordinates are +within bounds and thus not clipped. However, larger gamuts result in lower +precision/resolution when using <= 16 bit encodings (as opposed to 32-bit float +in PFM). BT.2100 or P3 DCI appear to be good compromises. + +## CMS library + +Transforms with unbounded pixels are desirable because they reduce round-trip +error in tests. This requires parametric curves, which are only supported for +the common sRGB case in ICC v4 profiles. ArgyllCMS does not support v4. The +other popular open-source CMS is LittleCMS. It is also used by color-managed +editors (Krita/darktable), which increases the chances of interoperability. +However, LCMS has race conditions and overflow issues that prevent fuzzing. We +will later switch to the newer skcms. Note that this library does not intend to +support multiProcessElements, so HDR transfer functions cannot be represented +accurately. Thus in the long term, we will probably migrate away from ICC +profiles entirely. + +## Which viewer + +On Linux, Krita and darktable support loading our PNG output images and their +ICC profile. + +## How to compress/decompress + +### Embedded ICC profile + +- Create an 8-bit or 16-bit PNG with an iCCP chunk, e.g. using darktable. +- Pass it to `cjxl`, then `djxl` with no special arguments. The decoded output + will have the same bit depth (can override with `--output_bit_depth`) and + color space. + +### Images without metadata (e.g. HDR) + +- Create a PGM/PPM/PFM file in a known color space. 
+- Invoke `cjxl` with `-x color_space=RGB_D65_202_Rel_Lin` (linear 2020). For + details/possible values, see color_encoding.cc `Description`. +- Invoke `djxl` as above with no special arguments. diff --git a/third_party/jpeg-xl/lib/jxl/docs/dc_predictor.md b/third_party/jpeg-xl/lib/jxl/docs/dc_predictor.md new file mode 100644 index 000000000000..478b048f6551 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/docs/dc_predictor.md @@ -0,0 +1,129 @@ +# DC prediction + +[TOC] + + + +## Background + +After the DCT integral transform, we wish to reduce the encoded size of DC +coefficients (or original pixels if the transform is omitted). +pik/dc_predictor.h provides `Shrink` and `Expand` functions to reduce the +magnitude of pixels by subtracting predicted values computed from their +neighbors. In the ideal case of flat or slowly varying regions, the resulting +"residuals" are zero and thus very efficiently encodable. + +## Design goals + +1. Benefit from SIMD. +1. Use prior horizontal neighbor for prediction because it has the highest + correlation. +1. Use previously decoded planes (i.e. luminance) for improving prediction of + chrominance. +1. Lossless (no overflow/rounding errors) for signed integers in [-32768, + 32768). + +## Prior approaches + +CALIC GAP and LOCO-I/JPEG-LS MED are somewhat adaptive: they choose from one of +several simple predictors based on properties of the pixel neighborhood (let N +denote the pixel above the current one, W = left, and NW above-left). History +Based Blending introduced more freedom by adding adaptive weights and several +more predictors including averaging, gradients and neighbors. We use a similar +approach, but with selecting rather than blending predictors to avoid +multiplications, and a more efficient cost estimator. + +## Basic algorithm + +``` +Compute all predictors for the already decoded N and W pixel. +See which single predictor performs best (lowest total absolute residual). +Residual := current pixel - best predictor's prediction for current pixel. +To expand again, add this same prediction to the stored residual. +``` + +## Details + +### Dependencies between SIMD lanes + +The SIMD and prior horizontal neighbor requirements seem to be contradictory. If +all steps are computed in units of SIMD vectors, then the prior neighbor may be +part of the same vector and thus not computed beforehand. Instead, we compute +multiple _predictors_ at a time using SIMD for the same pixel. This is feasible +because several predictors use the same computation (e.g. Average) on +independent inputs in the SIMD vector lanes. + +### Predictors + +We use several kinds of predictors. + +1. **Average**: Noisy areas have outliers; computing the average is a simple + way to reduce their impact without expensive non-linear computations. We use + a (saturated) add and shift rather than the average+round SIMD instruction. +1. **Gradient**: Smooth areas benefit from a gradient; we find it helpful to + clamp the JPEG-LS-style N+W-NW to the min/max of all neighbor pixels. +1. **Edge**: Finally, unchanged N/W neighbor values are also helpful for + regions with edges because they violate the assumptions underlying the + average and gradient predictors (namely a shared central tendency or + smoothly varying neighborhood). + +We chose predictors subject to the constraint that 128-bit SIMD instruction sets +can compute eight of them at a time, with a simple optimization algorithm that +tried various permutations of these basic predictors. 
Note that order matters:
+the first predictor with the minimum cost is chosen. The test data set was
+small, so it is possible that the choice of predictors could be improved, but
+that involves rewriting some intricate SIMD code.
+
+### Finding best predictor
+
+Given a SIMD vector with all prediction results, we compute each of their
+residuals by subtracting from the (broadcasted) current pixel. X86 SSE4 provides
+special support for finding the minimum element in u16x8 vectors. This lane can
+be efficiently moved to the first lane via shuffling. On other architectures, we
+scan through the 8 lanes manually.
+
+### Threading
+
+No separate parallelization is needed because this is part of the JXL decoder,
+which ensures all pixel decoding steps are independent (in 512x512 groups) and
+computed in parallel.
+
+### Cross-channel correlations
+
+The adaptive predictor exploits correlations between neighboring pixels. To also
+reduce correlation between channels, we require the Y channel to be decoded
+first (green is in the middle of the frequency band and thus has higher
+correlation with red/blue). Then, rather than computing the cost (prediction
+residual) at the N and W neighbors, we compute the cost at the same position in
+the previously decoded luminance image.
+Note: this assumes there is some remaining correlation (despite chroma from
+luma already having been applied), which is plausible because the current CfL
+model for DC is global, i.e. unable to remove all correlations. Once that
+changes, it may be better to no longer use Y to influence the chroma channels.
+
+### Bounds checking
+
+Predictors only use immediately adjacent pixels, and the cost estimator also
+applies predictors at immediately adjacent neighbors, so we need a border of two
+pixels on each side. For the first pixel, we have no predictor; subsequent
+pixels in the same row use their left neighbor. In the second row, we can apply
+the adaptive predictor while supplying an imaginary extra top row that is
+identical to the first row.
+
+### Side information
+
+Only causal (previously decoded in left-to-right, then top-to-bottom line scan
+order) neighbors are used, so the decoder-side `Expand` can run without any side
+information.
+
+### Related work
+
+lossless16 by Alex Rhatushnyak implements a more powerful predictor which uses
+the error history as feedback for the prediction weights. lossless16 is indeed
+better than dc_predictor for lossless photo compression. However, using
+lossless16 for encoding DCs has not yet proven to be advantageous - even after
+entropy-coding of the channel compact transforms, and encoding the entire DC
+image in a single call.
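+
+### Selection sketch
+
+A scalar sketch of the selection step described above (illustrative only: it
+ignores SIMD, image borders and the luma-guided cost, and the predictor set
+and names do not match the actual implementation):
+
+```
+#include <algorithm>
+#include <climits>
+#include <cstdint>
+#include <cstdlib>
+
+using Predictor = int16_t (*)(int16_t n, int16_t w, int16_t nw);
+
+int16_t Average(int16_t n, int16_t w, int16_t /*nw*/) {
+  return static_cast<int16_t>((n + w) >> 1);
+}
+int16_t ClampedGradient(int16_t n, int16_t w, int16_t nw) {
+  const int16_t lo = std::min(n, w), hi = std::max(n, w);
+  return std::clamp(static_cast<int16_t>(n + w - nw), lo, hi);
+}
+int16_t North(int16_t n, int16_t /*w*/, int16_t /*nw*/) { return n; }
+int16_t West(int16_t /*n*/, int16_t w, int16_t /*nw*/) { return w; }
+
+constexpr Predictor kPredictors[] = {Average, ClampedGradient, North, West};
+
+// r0 = two rows above, r1 = one row above, r2 = current row; x >= 2.
+// Returns the prediction for r2[x]; Shrink stores r2[x] - prediction and
+// Expand adds the identical prediction back to the stored residual.
+int16_t Predict(const int16_t* r0, const int16_t* r1, const int16_t* r2,
+                size_t x) {
+  int best_cost = INT_MAX;
+  int16_t best_guess = 0;
+  for (const Predictor p : kPredictors) {
+    // Cost: total absolute residual of this predictor at the already
+    // decoded N and W positions.
+    const int cost = std::abs(p(r0[x], r1[x - 1], r0[x - 1]) - r1[x]) +
+                     std::abs(p(r1[x - 1], r2[x - 2], r1[x - 2]) - r2[x - 1]);
+    if (cost < best_cost) {  // strict '<': on ties the first predictor wins
+      best_cost = cost;
+      best_guess = p(r1[x], r2[x - 1], r1[x - 1]);
+    }
+  }
+  return best_guess;
+}
+```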
+ +In "regular" (non-raw) mode \\(\text{distribution}\\) is interpreted as array +\\(\text{L}\\) containing 4 uint8 values. + +To decode regular field, first 2 bits are read; +those bits represent "selector" value \\(\text{S}\\). + +Depending on mode \\(\text{M} = \text{L}[\text{S}]\\): + +- if \\(\text{M} \ge \texttt{0x80}\\), + then \\(\text{M} - \texttt{0x80}\\) is the "direct" encoded value +- else if \\(\text{M} \ge \texttt{0x40}\\), + let \\(\text{V}\\) be the \\((\text{M} \& \texttt{0x7}) + 1\\) + following bits, "shifted" encoded value is + \\(\text{V} + ((\text{M} \gg 3) \& \texttt{0x7}) + 1\\) +- otherwise \\(\text{M}\\) following bits represent the encoded value + +Source code: cs/jxl/fields.h + +## Uint64 field + +This field supports bigger values than [Uint32](#uint32-field), but has single +fixed distribution. + +Value is decoded as following: + +- "selector" \\(\text{S}\\) 2 bits +- if \\(\text{S} = 0\\) then value is 0 +- if \\(\text{S} = 1\\) then next 4 bits represent \\(\text{value} - 1\\) +- if \\(\text{S} = 2\\) then next 8 bits represent \\(\text{value} - 17\\) +- if \\(\text{S} = 3\\) then: + - 12 bits represent the lowest 12 bits of value + - while next bit is 1: + - if less than 60 value bits are already read, + then 8 bits represent higher 8 bits of value + - otherwise 4 bits represent the highest 4 bits of value and 'while' + loop is finished + +Source code: cs/jxl/fields.h + +## Byte array field + +Byte array field holds (optionally compressed) array of bytes. Byte array is +encoded as: + +- \\(\text{type}\\) [field](#uint32-field) / + L = [direct 0, direct 1, direct 2, 3 bits + 3] +- if \\(\text{type}\\) is \\(\text{None} = 0\\), then byte array is empty +- if \\(\text{type}\\) is \\(\text{Raw} = 1\\) or \\(\text{Brotli} = 2\\), + then: + - \\(\text{size}\\) [field](#uint32-field) / + L = [8 bits, 16 bits, 24 bits, 32 bits] + - uint8 repeated \\(\text{size}\\) times; + - payload is compressed if \\(\text{type}\\) is \\(\text{Brotli}\\) + +Source code: cs/jxl/fields.h + +## VarSignedMantissa + +VarMantissa is a mapping from \\(\text{L}\\) bits to signed values. +The principle is that \\(\text{L} + 1\\) bits may only represent values whose +absolute values are bigger than absolute values that could be represented with +\\(\text{L}\\) bits. + +- 0-bit values represent \\(\{0\}\\) +- 1-bit values represent \\(\{-1, 1\}\\) +- 2-bit values represent \\(\{-3, -2, 2, 3\}\\) +- L-bit values represent \\(\{\pm 2^{L - 1} .. \pm(2^L - 1)\}\\) + +## VarUnsignedMantissa + +Analogous to [VarSignedMantissa](#varsignedmantissa), but for unsigned values. + +TODO: provide examples + +## VarLenUint8 + +- \\(\text{zero}\\) 1 bit +- if \\(\text{zero}\\) is \\(0\\), then value is \\(0\\), otherwise: + - \\(\text{L}\\) 3 bits + - \\(\text{value} - 1\\) encoded as + \\(\text{L}\\)-bit [VarUnsignedMantissa](#vaunrsignedmantissa) + +## VarLenUint16 + +- \\(\text{zero}\\) 1 bit +- if \\(\text{zero}\\) is \\(0\\), then value is \\(0\\), otherwise: + - \\(\text{L}\\) 4 bits + - \\(\text{value} - 1\\) encoded as + \\(\text{L}\\)-bit [VarUnsignedMantissa](#varunsignedmantissa) diff --git a/third_party/jpeg-xl/lib/jxl/docs/file_format.md b/third_party/jpeg-xl/lib/jxl/docs/file_format.md new file mode 100644 index 000000000000..429ee669d14b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/docs/file_format.md @@ -0,0 +1,129 @@ +# File format overview + +[TOC] + + + +This document describes high level PIK format and metadata. 
diff --git a/third_party/jpeg-xl/lib/jxl/docs/file_format.md b/third_party/jpeg-xl/lib/jxl/docs/file_format.md
new file mode 100644
index 000000000000..429ee669d14b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/docs/file_format.md
@@ -0,0 +1,129 @@
+# File format overview
+
+[TOC]
+
+
+
+This document describes the high-level PIK format and metadata.
+
+## FileHeader
+
+The topmost structure; it contains the most general image-related metadata and a
+file signature used to distinguish PIK files.
+
+Structure:
+
+- "signature" 32 bits; must be equal to the little-endian representation of
+  the \\(\texttt{0x0A4CD74A}\\) value; \\(\texttt{0x0A}\\) causes files opened in
+  text mode to be rejected, and \\(\texttt{0xD7}\\) detects 7-bit transfers
+- "xsize_minus_1" [field](entropy_coding_basic.md#uint32-field) /
+  L = [9 bits, 11 bits, 13 bits, 32 bits]; 13-bit values support most existing
+  cameras (up to 8K x 8K images), 32-bit values provide support for up
+  to 4G x 4G images; "xsize" is derived from this field by adding \\(1\\),
+  thus zero-width images are not supported
+- "ysize_minus_1" [field](entropy_coding_basic.md#uint32-field) /
+  L = [9 bits, 11 bits, 13 bits, 32 bits]; similar to "xsize_minus_1"
+- "orientation" indicates one of the 8 possible orientations, as defined in EXIF
+- nested ["metadata"](#metadata) structure
+- nested ["preview"](#preview) structure
+- nested ["animation"](#animation) structure
+- ["extensions"](#extensions) stub that allows future format extension
+
+## Metadata
+
+[Optional structure](#optional-structures) that contains meta-information about
+the image and other supporting information.
+
+Structure:
+- "all_default" 1 bit; see ["optional structures"](#optional-structures)
+- nested ["transcoded"](#transcoded) structure
+- "target_nits_div50" [field](entropy_coding_basic.md#uint32-field) /
+  L = [direct 2, direct 5, direct 80, 8 bits];
+  the most common values (100, 250, 4000) are encoded directly; the maximal
+  expressible intensity is 12750
+- "exif" compressed [byte array](entropy_coding_basic.md#byte-array-field);
+  original image metadata
+- "iptc" compressed [byte array](entropy_coding_basic.md#byte-array-field);
+  original image metadata
+- "xmp" compressed [byte array](entropy_coding_basic.md#byte-array-field);
+  original image metadata
+
+## Transcoded
+
+[Optional structure](#optional-structures) that contains information about the
+original image.
+
+Structure:
+
+- "all_default" 1 bit; see ["optional structures"](#optional-structures)
+- "original_bit_depth" [field](entropy_coding_basic.md#uint32-field) /
+  L = [direct 8, direct 16, direct 32, 5 bits]
+- "original_color_encoding" nested [ColorEncoding](#colorencoding) structure
+- "original_bytes_per_alpha" [field](entropy_coding_basic.md#uint32-field) /
+  L = [direct 0, direct 1, direct 2, direct 4]
+
+
+## Preview
+
+[Optional structure](#optional-structures) that contains information about the
+"discardable preview" image. Unlike early "progressive" passes, this image is
+completely independent and optimized for a better "preview" experience, i.e.
+appropriately preprocessed and non-power-of-two scaled.
+
+Preview images have different size constraints than the main image and are
+currently limited to 8K x 8K; the size limit is set to 128 MiB.
+
+Structure:
+
+- "all_default" 1 bit; see ["optional structures"](#optional-structures)
+- "size_bits" [field](entropy_coding_basic.md#uint32-field) /
+  L = [12 bits, 16 bits, 20 bits, 28 bits]
+- "xsize" [field](entropy_coding_basic.md#uint32-field) /
+  L = [7 bits, 9 bits, 11 bits, 13 bits]
+- "ysize" [field](entropy_coding_basic.md#uint32-field) /
+  L = [7 bits, 9 bits, 11 bits, 13 bits]
+
+## Animation
+
+[Optional structure](#optional-structures) that contains meta-information about
+the animation (image sequence).
+
+Structure:
+
+- "all_default" 1 bit; see ["optional structures"](#optional-structures)
+- "num_loops" [field](entropy_coding_basic.md#uint32-field) /
+  L = [direct 0, 3 bits, 16 bits, 32 bits]; \\(0\\) means to repeat infinitely
+- "ticks_numerator" [field](entropy_coding_basic.md#uint32-field) /
+  L = [direct 1, 9 bits, 20 bits, 32 bits]
+- "ticks_denominator" [field](entropy_coding_basic.md#uint32-field) /
+  L = [direct 1, 9 bits, 20 bits, 32 bits]
+
+## Extensions
+
+This "structure" is usually put at the end of other structures that may be
+extended in the future. It allows earlier versions of decoders to skip the
+newer fields.
+
+Structure:
+
+- "extensions" [field](entropy_coding_basic.md#uint64-field)
+- if "extensions" is \\(0\\), then no extra information follows
+- otherwise:
+  - "extension_bits" [field](entropy_coding_basic.md#uint64-field) - the number
+    of bits to be skipped by a decoder that does not expect extensions here
+
+## Optional structures
+
+Some structures are "optional". If all the field values are equal to their
+defaults, the encoder may represent the structure with a single bit.
+
+Structures that contain a non-empty [extensions](#extensions) tail are
+ineligible for 1-bit encoding.
+
+Technically, this bit is represented as an "all_default" field that comes first;
+if the value of this field is \\(1\\), then the rest of the structure is not
+decoded.
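+
+In decoder pseudo-C++ (a sketch with hypothetical names, not the actual
+implementation):
+
+```
+// Sketch: reading an optional structure guarded by an "all_default" bit.
+// BitReader/ReadBits and Metadata are hypothetical names.
+bool DecodeMetadata(BitReader* br, Metadata* metadata) {
+  *metadata = Metadata();  // reset every field to its default value
+  const bool all_default = br->ReadBits(1) != 0;
+  if (all_default) return true;  // a single bit encodes the whole structure
+  // Otherwise decode each field in order, e.g.:
+  // metadata->target_nits_div50 = DecodeU32Field(br, kNitsDistribution);
+  // ...
+  return true;
+}
+```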
diff --git a/third_party/jpeg-xl/lib/jxl/docs/upsample.md b/third_party/jpeg-xl/lib/jxl/docs/upsample.md
new file mode 100644
index 000000000000..bcbc54c580d4
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/docs/upsample.md
@@ -0,0 +1,151 @@
+# Upsampling
+
+[TOC]
+
+
+
+jxl/resample.h provides `Upsampler8` for fast and high-quality 8x upsampling by
+4x4 (separable/non-separable) or 6x6 (non-separable) floating-point kernels.
+This was previously used for the "smooth predictor", which has been removed. It
+would still be useful for a progressive mode that upsamples DC for a preview,
+though this code is not yet used for that purpose.
+
+See the 'Separability' section below for the surprising result that
+non-separable can be faster than separable and possibly better.
+
+## Performance evaluation
+
+### 4x4 separable/non-separable
+
+__Single-core__: 5.1 GB/s (single-channel, floating-point, 160x96 input)
+
+* 40x speedup vs. an unoptimized Nehab/Hoppe bicubic upsampler
+* 25x or 6x speedup vs. [Pillow 4.3](http://python-pillow.org/pillow-perf/)
+  bicubic (AVX2: 66M RGB 8-bit per second) in terms of bytes or samples
+* 2.6x AVX2 speedup vs. our SSE4 version (similar algorithm)
+
+__12 independent instances on 12 cores__: same speed for each instance (linear
+scalability: not limited by memory bandwidth).
+
+__Multicore__: 15-18 GB/s (single-channel, floating-point, 320x192 input)
+
+* 9-11x or 2.3-2.8x speedup vs. a parallel Halide bicubic (1.6G single-channel
+  8-bit per second, 320x192 input) in terms of bytes or samples.
+
+### 6x6 non-separable
+
+Note that a separable (outer-product) kernel only requires 6+6 (12)
+multiplications per output pixel/vector. However, we do not assume anything
+about the kernel rank (separability), and thus require 6x6 (36) multiplications.
+
+__Single-core__: 2.8 GB/s (single-channel, floating-point, 320x192 input)
+
+* 5x speedup vs. an optimized (outer-product of two 1D) Lanczos without SIMD
+
+__Multicore__: 9-10 GB/s (single-channel, floating-point, 320x192 input)
+
+* 7-8x or 2x speedup vs. a parallel Halide 6-tap tensor-product Lanczos (1.3G
+  single-channel 8-bit per second, 320x192 input) in terms of bytes or
+  samples.
+
+## Implementation details
+
+### Data type
+
+Our input pixels are 16-bit, so 8-bit integer multiplications are insufficient.
+16-bit fixed-point arithmetic is fast but risks overflow or loss of precision
+unless the value intervals are carefully managed. 32-bit integers reduce this
+concern, but 32-bit multiplies are slow on Haswell (one mullo32 per cycle with
+10+1 cycle latency). We instead use single-precision float, which enables 2 FMA
+per cycle with 5 cycle latency. The extra bandwidth vs. 16-bit types may be a
+concern, but we can run 12 instances (on a 12-core socket) with zero slowdown,
+indicating that memory bandwidth is not the bottleneck at the moment.
+
+### Pixel layout
+
+We require planar inputs, e.g. all red channel samples in a 2D matrix. This
+allows better utilization of 8-lane SIMD compared to dedicating four SIMD lanes
+to R,G,B,A samples. If upsampling is the only operation, this requires an extra
+deinterleaving step, but our application involves a larger processing pipeline.
+
+### No prefilter
+
+Nehab and Hoppe (http://hhoppe.com/filtering.pdf) advocate generalized sampling
+with an additional digital filter step. They claim "significantly higher
+quality" for upsampling operations, which is contradicted by minimal MSSIM
+differences between cardinal and ordinary cubic interpolators in their
+experiment on page 65. We also see only minor Butteraugli differences between
+Catmull-Rom and B-spline3 or OMOMS3 when upsampling 8x. As they note on page 29,
+a prefilter is often omitted when upsampling because the reconstruction filter's
+frequency cutoff is lower than that of the prefilter. The prefilter is claimed
+to be efficient (relative to normal separable convolutions), but still involves
+separate horizontal and vertical passes. Avoiding cache thrashing would require
+a separate ring buffer of rows, which may be less efficient than our single-pass
+algorithm, which only writes final outputs.
+
+### 8x upsampling
+
+Our code is currently specific to 8x upsampling because that is what is required
+for our (DCT prediction) application. This happens to be a particularly good fit
+for both 4-lane and 8-lane (AVX2) SIMD. It would be relatively easy to adapt the
+code to 4x upsampling.
+
+### Kernel support
+
+Even-sized (asymmetric) kernels are often used to reduce computation. Many
+upsampling applications use 4-tap cubic interpolation kernels, but at such
+extreme magnifications (8x) we find 6x6 to be better.
+
+### Separability
+
+Separable kernels can be expressed as the (outer) product of two 1D kernels. For
+n x n kernels, this requires n + n multiplications per pixel rather than n x n.
+However, the Fourier transform of such kernels (except the Gaussian) has a
+square rather than disc shape.
+
+We instead allow arbitrary (possibly non-separable) kernels. This can avoid
+structural dependencies between output pixels within the same 8x8 block.
+
+Surprisingly, 4x4 non-separable is actually faster than the separable version
+due to better utilization of FMA units. For 6x6, we only implement the
+non-separable version because our application benefits from such kernels.
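+
+In scalar form the difference is just the loop structure (a sketch with
+illustrative names; the real code vectorizes this and reuses row results, see
+"Single pass" below):
+
+```
+#include <cstddef>
+
+// Sketch: one output sample from a 4x4 neighborhood. kWeights / kWx / kWy
+// are illustrative precomputed kernel tables, `in` points at the top-left
+// input pixel of the window and `stride` is the input row stride.
+float NonSeparable4x4(const float* in, size_t stride,
+                      const float kWeights[4][4]) {
+  float sum = 0.0f;
+  for (int ky = 0; ky < 4; ++ky) {
+    for (int kx = 0; kx < 4; ++kx) {  // 16 multiplies; arbitrary kernel rank
+      sum += kWeights[ky][kx] * in[ky * stride + kx];
+    }
+  }
+  return sum;
+}
+
+float Separable4x4(const float* in, size_t stride,
+                   const float kWx[4], const float kWy[4]) {
+  float sum = 0.0f;
+  for (int ky = 0; ky < 4; ++ky) {
+    float row = 0.0f;
+    for (int kx = 0; kx < 4; ++kx) row += kWx[kx] * in[ky * stride + kx];
+    // The n + n cost only materializes when `row` results are reused
+    // across neighboring outputs; as written this is still 4x4 + 4.
+    sum += kWy[ky] * row;
+  }
+  return sum;
+}
+```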
+
+### Dual grid
+
+A primal grid with `n` grid points at coordinates `k/n` would be convenient for
+index computations, but is asymmetric at the borders (0 and `(n-1)/n` != 1) and
+does not compensate for the asymmetric (even) kernel support. We instead use a
+dual grid with coordinates offset by half the sample spacing, which only adds an
+integer shift term to the index computations. Note that the offset still leads
+to four identical input pixels in 8x upsampling, which is convenient for 4 and
+even 8-lane SIMD.
+
+### Kernel
+
+We use Catmull-Rom splines out of convenience. Computational cost does not
+matter because the weights are precomputed. Also, Catmull-Rom splines pass
+through the control points, but that does not matter in this application.
+B-splines or other kernels (arbitrary coefficients, even non-separable) can also
+be used.
+
+### Single pixel loads
+
+For SIMD convolution, we have the choice of whether to broadcast inputs or
+weights. To ensure we write a unit-stride vector of output pixels, we need to
+broadcast the input pixels and vary the weights. This has the additional benefit
+of avoiding complexity at the borders, where it is not safe to load an entire
+vector. Instead, we only load and broadcast single pixels, with bounds checking
+at the borders but not in the interior.
+
+### Single pass
+
+Separable 2D convolutions are often implemented as two separate 1D passes.
+However, this leads to cache thrashing in the vertical pass, assuming the image
+does not fit into L2 caches. We instead use a single-pass algorithm that
+generates the four (or six, see "Kernel support" above) horizontal convolution
+results and immediately convolves them in the vertical direction to produce a
+final output. This is simpler than sliding windows and has a smaller working
+set, thus enabling the major speedups reported above.
diff --git a/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc
new file mode 100644
index 000000000000..af6773595e16
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc
@@ -0,0 +1,907 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_ac_strategy.h"
+
+#include
+#include
+
+#include
+#include
+#include
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_ac_strategy.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/enc_transforms-inl.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/fast_math-inl.h"
+
+// This must come before the begin/end_target, but HWY_ONCE is only true
+// after that, so use an "include guard".
+#ifndef LIB_JXL_ENC_AC_STRATEGY_
+#define LIB_JXL_ENC_AC_STRATEGY_
+// Parameters of the heuristic are marked with an OPTIMIZE comment.
+namespace jxl {
+
+// Debugging utilities.
+
+// Returns a linear sRGB color (as bytes) for each AC strategy.
+const uint8_t* TypeColor(const uint8_t& raw_strategy) { + JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy)); + static_assert(AcStrategy::kNumValidStrategies == 27, "Change colors"); + static constexpr uint8_t kColors[][3] = { + {0xFF, 0xFF, 0x00}, // DCT8 + {0xFF, 0x80, 0x80}, // HORNUSS + {0xFF, 0x80, 0x80}, // DCT2x2 + {0xFF, 0x80, 0x80}, // DCT4x4 + {0x80, 0xFF, 0x00}, // DCT16x16 + {0x00, 0xC0, 0x00}, // DCT32x32 + {0xC0, 0xFF, 0x00}, // DCT16x8 + {0xC0, 0xFF, 0x00}, // DCT8x16 + {0x00, 0xFF, 0x00}, // DCT32x8 + {0x00, 0xFF, 0x00}, // DCT8x32 + {0x00, 0xFF, 0x00}, // DCT32x16 + {0x00, 0xFF, 0x00}, // DCT16x32 + {0xFF, 0x80, 0x00}, // DCT4x8 + {0xFF, 0x80, 0x00}, // DCT8x4 + {0xFF, 0xFF, 0x80}, // AFV0 + {0xFF, 0xFF, 0x80}, // AFV1 + {0xFF, 0xFF, 0x80}, // AFV2 + {0xFF, 0xFF, 0x80}, // AFV3 + {0x00, 0xC0, 0xFF}, // DCT64x64 + {0x00, 0xFF, 0xFF}, // DCT64x32 + {0x00, 0xFF, 0xFF}, // DCT32x64 + {0x00, 0x40, 0xFF}, // DCT128x128 + {0x00, 0x80, 0xFF}, // DCT128x64 + {0x00, 0x80, 0xFF}, // DCT64x128 + {0x00, 0x00, 0xC0}, // DCT256x256 + {0x00, 0x00, 0xFF}, // DCT256x128 + {0x00, 0x00, 0xFF}, // DCT128x256 + }; + return kColors[raw_strategy]; +} + +const uint8_t* TypeMask(const uint8_t& raw_strategy) { + JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy)); + static_assert(AcStrategy::kNumValidStrategies == 27, "Add masks"); + // implicitly, first row and column is made dark + static constexpr uint8_t kMask[][64] = { + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // DCT8 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 1, 1, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // HORNUSS + { + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + }, // 2x2 + { + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + }, // 4x4 + {}, // DCT16x16 (unused) + {}, // DCT32x32 (unused) + {}, // DCT16x8 (unused) + {}, // DCT8x16 (unused) + {}, // DCT32x8 (unused) + {}, // DCT8x32 (unused) + {}, // DCT32x16 (unused) + {}, // DCT16x32 (unused) + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // DCT4x8 + { + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + }, // DCT8x4 + { + 1, 1, 1, 1, 1, 0, 0, 0, // + 1, 1, 1, 1, 0, 0, 0, 0, // + 1, 1, 1, 0, 0, 0, 0, 0, // + 1, 1, 0, 0, 0, 0, 0, 0, // + 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // AFV0 + { + 0, 0, 0, 0, 1, 1, 1, 1, // + 0, 0, 0, 0, 0, 1, 1, 1, // + 0, 0, 0, 0, 0, 0, 1, 1, // + 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 0, 
0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // AFV1 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 0, 0, 0, 0, // + }, // AFV2 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 1, 1, // + 0, 0, 0, 0, 0, 1, 1, 1, // + }, // AFV3 + }; + return kMask[raw_strategy]; +} + +void DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize, + size_t ysize, const char* tag, AuxOut* aux_out) { + Image3F color_acs(xsize, ysize); + for (size_t y = 0; y < ysize; y++) { + float* JXL_RESTRICT rows[3] = { + color_acs.PlaneRow(0, y), + color_acs.PlaneRow(1, y), + color_acs.PlaneRow(2, y), + }; + const AcStrategyRow acs_row = ac_strategy.ConstRow(y / kBlockDim); + for (size_t x = 0; x < xsize; x++) { + AcStrategy acs = acs_row[x / kBlockDim]; + const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy()); + for (size_t c = 0; c < 3; c++) { + rows[c][x] = color[c] / 255.f; + } + } + } + size_t stride = color_acs.PixelsPerRow(); + for (size_t c = 0; c < 3; c++) { + for (size_t by = 0; by < DivCeil(ysize, kBlockDim); by++) { + float* JXL_RESTRICT row = color_acs.PlaneRow(c, by * kBlockDim); + const AcStrategyRow acs_row = ac_strategy.ConstRow(by); + for (size_t bx = 0; bx < DivCeil(xsize, kBlockDim); bx++) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy()); + const uint8_t* JXL_RESTRICT mask = TypeMask(acs.RawStrategy()); + if (acs.covered_blocks_x() == 1 && acs.covered_blocks_y() == 1) { + for (size_t iy = 0; iy < kBlockDim && by * kBlockDim + iy < ysize; + iy++) { + for (size_t ix = 0; ix < kBlockDim && bx * kBlockDim + ix < xsize; + ix++) { + if (mask[iy * kBlockDim + ix]) { + row[iy * stride + bx * kBlockDim + ix] = color[c] / 800.f; + } + } + } + } + // draw block edges + for (size_t ix = 0; ix < kBlockDim * acs.covered_blocks_x() && + bx * kBlockDim + ix < xsize; + ix++) { + row[0 * stride + bx * kBlockDim + ix] = color[c] / 350.f; + } + for (size_t iy = 0; iy < kBlockDim * acs.covered_blocks_y() && + by * kBlockDim + iy < ysize; + iy++) { + row[iy * stride + bx * kBlockDim + 0] = color[c] / 350.f; + } + } + } + } + aux_out->DumpImage(tag, color_acs); +} + +// AC strategy selection: recursive block splitting. + +namespace { +template +size_t ACSCandidates(const AcStrategy::Type (&in)[N], + AcStrategy::Type* JXL_RESTRICT out) { + memcpy(out, in, N * sizeof(AcStrategy::Type)); + return N; +} + +// Order in which transforms are tested for max delta: the first +// acceptable one is chosen as initial guess. +constexpr AcStrategy::Type kACSOrder[] = { + AcStrategy::Type::DCT64X64, + AcStrategy::Type::DCT64X32, + AcStrategy::Type::DCT32X64, + AcStrategy::Type::DCT32X32, + AcStrategy::Type::DCT32X16, + AcStrategy::Type::DCT16X32, + AcStrategy::Type::DCT16X16, + // TODO(Jyrki): Restore these when we have better heuristics. + // AcStrategy::Type::DCT8X32, + // AcStrategy::Type::DCT32X8, + AcStrategy::Type::DCT16X8, + AcStrategy::Type::DCT8X16, + // DCT8x8 is the "fallback" option if no bigger transform can be used. 
+ AcStrategy::Type::DCT, +}; + +size_t ACSPossibleReplacements(AcStrategy::Type current, + AcStrategy::Type* JXL_RESTRICT out) { + // TODO(veluca): is this decision tree optimal? + if (current == AcStrategy::Type::DCT64X64) { + return ACSCandidates( + {AcStrategy::Type::DCT64X32, AcStrategy::Type::DCT32X64, + AcStrategy::Type::DCT32X32, AcStrategy::Type::DCT16X16, + AcStrategy::Type::DCT}, + out); + } + if (current == AcStrategy::Type::DCT64X32 || + current == AcStrategy::Type::DCT32X64) { + return ACSCandidates({AcStrategy::Type::DCT32X32, + AcStrategy::Type::DCT16X16, AcStrategy::Type::DCT}, + out); + } + if (current == AcStrategy::Type::DCT32X32) { + return ACSCandidates( + {AcStrategy::Type::DCT32X16, AcStrategy::Type::DCT16X32, + AcStrategy::Type::DCT16X16, AcStrategy::Type::DCT16X8, + AcStrategy::Type::DCT8X16, AcStrategy::Type::DCT}, + out); + } + if (current == AcStrategy::Type::DCT32X16) { + return ACSCandidates({AcStrategy::Type::DCT32X8, AcStrategy::Type::DCT16X16, + AcStrategy::Type::DCT}, + out); + } + if (current == AcStrategy::Type::DCT16X32) { + return ACSCandidates({AcStrategy::Type::DCT8X32, AcStrategy::Type::DCT16X16, + AcStrategy::Type::DCT}, + out); + } + if (current == AcStrategy::Type::DCT32X8) { + return ACSCandidates({AcStrategy::Type::DCT16X8, AcStrategy::Type::DCT}, + out); + } + if (current == AcStrategy::Type::DCT8X32) { + return ACSCandidates({AcStrategy::Type::DCT8X16, AcStrategy::Type::DCT}, + out); + } + if (current == AcStrategy::Type::DCT16X16) { + return ACSCandidates({AcStrategy::Type::DCT8X16, AcStrategy::Type::DCT16X8}, + out); + } + if (current == AcStrategy::Type::DCT16X8 || + current == AcStrategy::Type::DCT8X16) { + return ACSCandidates({AcStrategy::Type::DCT}, out); + } + if (current == AcStrategy::Type::DCT) { + return ACSCandidates({AcStrategy::Type::DCT4X8, AcStrategy::Type::DCT8X4, + AcStrategy::Type::DCT4X4, AcStrategy::Type::DCT2X2, + AcStrategy::Type::IDENTITY, AcStrategy::Type::AFV0, + AcStrategy::Type::AFV1, AcStrategy::Type::AFV2, + AcStrategy::Type::AFV3}, + out); + } + // Other 8x8 have no replacements - they already were chosen as the best + // between all the 8x8s. + return 0; +} + +void InitEntropyAdjustTable(float* entropy_adjust) { + // Precomputed FMA: premultiply `add` by `mul` so that the previous + // entropy *= add; entropy *= mul becomes entropy = MulAdd(entropy, mul, add). 
+ const auto set = [entropy_adjust](size_t raw_strategy, float add, float mul) { + entropy_adjust[2 * raw_strategy + 0] = add * mul; + entropy_adjust[2 * raw_strategy + 1] = mul; + }; + set(AcStrategy::Type::DCT, 0.0f, 0.80f); + set(AcStrategy::Type::DCT4X4, 4.0f, 0.79f); + set(AcStrategy::Type::DCT2X2, 4.0f, 1.1f); + set(AcStrategy::Type::DCT16X16, 0.0f, 0.83f); + set(AcStrategy::Type::DCT64X64, 0.0f, 1.3f); + set(AcStrategy::Type::DCT64X32, 0.0f, 1.15f); + set(AcStrategy::Type::DCT32X64, 0.0f, 1.15f); + set(AcStrategy::Type::DCT32X32, 0.0f, 0.97f); + set(AcStrategy::Type::DCT16X32, 0.0f, 0.94f); + set(AcStrategy::Type::DCT32X16, 0.0f, 0.94f); + set(AcStrategy::Type::DCT32X8, 0.0f, 2.261390410971102f); + set(AcStrategy::Type::DCT8X32, 0.0f, 2.261390410971102f); + set(AcStrategy::Type::DCT16X8, 0.0f, 0.86f); + set(AcStrategy::Type::DCT8X16, 0.0f, 0.86f); + set(AcStrategy::Type::DCT4X8, 3.0f, 0.81f); + set(AcStrategy::Type::DCT8X4, 3.0f, 0.81f); + set(AcStrategy::Type::IDENTITY, 8.0f, 1.2f); + set(AcStrategy::Type::AFV0, 3.0f, 0.77f); + set(AcStrategy::Type::AFV1, 3.0f, 0.77f); + set(AcStrategy::Type::AFV2, 3.0f, 0.77f); + set(AcStrategy::Type::AFV3, 3.0f, 0.77f); + set(AcStrategy::Type::DCT128X128, 0.0f, 1.0f); + set(AcStrategy::Type::DCT128X64, 0.0f, 0.73f); + set(AcStrategy::Type::DCT64X128, 0.0f, 0.73f); + set(AcStrategy::Type::DCT256X256, 0.0f, 1.0f); + set(AcStrategy::Type::DCT256X128, 0.0f, 0.73f); + set(AcStrategy::Type::DCT128X256, 0.0f, 0.73f); + static_assert(AcStrategy::kNumValidStrategies == 27, "Keep in sync"); +} +} // namespace + +} // namespace jxl +#endif // LIB_JXL_ENC_AC_STRATEGY_ + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; + +float EstimateEntropy(const AcStrategy& acs, size_t x, size_t y, + const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, float* block, + float* scratch_space, uint32_t* quantized) { + const size_t size = (1 << acs.log2_covered_blocks()) * kDCTBlockSize; + + // Apply transform. + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT block_c = block + size * c; + TransformFromPixels(acs.Strategy(), &config.Pixel(c, x, y), + config.src_stride, block_c, scratch_space); + } + + HWY_FULL(float) df; + HWY_FULL(int) di; + + const size_t num_blocks = acs.covered_blocks_x() * acs.covered_blocks_y(); + float quant = 0; + float quant_norm8 = 0; + float masking = 0; + { + float masking_norm2 = 0; + float masking_max = 0; + // Load QF value, calculate empirical heuristic on masking field + // for weighting the information loss. Information loss manifests + // itself as ringing, and masking could hide it. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + float qval = config.Quant(x / 8 + ix, y / 8 + iy); + quant = std::max(quant, qval); + qval *= qval; + qval *= qval; + quant_norm8 += qval * qval; + float maskval = config.Masking(x / 8 + ix, y / 8 + iy); + masking_max = std::max(masking_max, maskval); + masking_norm2 += maskval * maskval; + } + } + quant_norm8 /= num_blocks; + quant_norm8 = FastPowf(quant_norm8, 1.0f / 8.0f); + masking_norm2 = sqrt(masking_norm2 / num_blocks); + // This is a highly empirical formula. + masking = (masking_norm2 + masking_max); + } + const auto q = Set(df, quant_norm8); + + // Compute entropy. 
+ float entropy = config.base_entropy; + auto info_loss = Zero(df); + + for (size_t c = 0; c < 3; c++) { + const float* inv_matrix = config.dequant->InvMatrix(acs.RawStrategy(), c); + const auto cmap_factor = Set(df, cmap_factors[c]); + + auto entropy_v = Zero(df); + auto nzeros_v = Zero(di); + auto cost1 = Set(df, config.cost1); + auto cost2 = Set(df, config.cost2); + auto cost_delta = Set(df, config.cost_delta); + for (size_t i = 0; i < num_blocks * kDCTBlockSize; i += Lanes(df)) { + const auto in = Load(df, block + c * size + i); + const auto in_y = Load(df, block + size + i) * cmap_factor; + const auto im = Load(df, inv_matrix + i); + const auto val = (in - in_y) * im * q; + const auto rval = Round(val); + info_loss += AbsDiff(val, rval); + const auto q = Abs(rval); + const auto q_is_zero = q == Zero(df); + entropy_v += IfThenElseZero(q >= Set(df, 0.5f), cost1); + entropy_v += IfThenElseZero(q >= Set(df, 1.5f), cost2); + // We used to have q * C here, but that cost model seems to + // be punishing large values more than necessary. Sqrt tries + // to avoid large values less aggressively. Having high accuracy + // around zero is most important at low qualities, and there + // we have directly specified costs for 0, 1, and 2. + entropy_v += Sqrt(q) * cost_delta; + nzeros_v += + BitCast(di, IfThenZeroElse(q_is_zero, BitCast(df, Set(di, 1)))); + } + entropy += GetLane(SumOfLanes(entropy_v)); + size_t num_nzeros = GetLane(SumOfLanes(nzeros_v)); + // Add #bit of num_nonzeros, as an estimate of the cost for encoding the + // number of non-zeros of the block. + size_t nbits = CeilLog2Nonzero(num_nzeros + 1) + 1; + // Also add #bit of #bit of num_nonzeros, to estimate the ANS cost, with a + // bias. + entropy += config.zeros_mul * (CeilLog2Nonzero(nbits + 17) + nbits); + } + float ret = entropy + masking * config.info_loss_multiplier * + GetLane(SumOfLanes(info_loss)); + return ret; +} + +void MaybeReplaceACS(size_t bx, size_t by, const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, + AcStrategyImage* JXL_RESTRICT ac_strategy, + const float* JXL_RESTRICT entropy_adjust, + float* JXL_RESTRICT entropy_estimate, float* block, + float* scratch_space, uint32_t* quantized) { + AcStrategy::Type current = + AcStrategy::Type(ac_strategy->ConstRow(by)[bx].RawStrategy()); + AcStrategy::Type candidates[AcStrategy::kNumValidStrategies]; + size_t num_candidates = ACSPossibleReplacements(current, candidates); + if (num_candidates == 0) return; + size_t best = num_candidates; + float best_ee = entropy_estimate[0]; + // For each candidate replacement strategy, keep track of its entropy + // estimate. 
+ constexpr size_t kFit64X64DctInBlocks = 64 * 64 / (8 * 8); + float ee_val[AcStrategy::kNumValidStrategies][kFit64X64DctInBlocks]; + AcStrategy current_acs = AcStrategy::FromRawStrategy(current); + for (size_t cand = 0; cand < num_candidates; cand++) { + AcStrategy acs = AcStrategy::FromRawStrategy(candidates[cand]); + size_t idx = 0; + float total_entropy = 0; + for (size_t iy = 0; iy < current_acs.covered_blocks_y(); + iy += acs.covered_blocks_y()) { + for (size_t ix = 0; ix < current_acs.covered_blocks_x(); + ix += acs.covered_blocks_x()) { + const HWY_CAPPED(float, 1) df1; + auto entropy1 = + Set(df1, + EstimateEntropy(acs, (bx + ix) * 8, (by + iy) * 8, config, + cmap_factors, block, scratch_space, quantized)); + entropy1 = MulAdd(entropy1, + Set(df1, entropy_adjust[2 * acs.RawStrategy() + 1]), + Set(df1, entropy_adjust[2 * acs.RawStrategy() + 0])); + const float entropy = GetLane(entropy1); + ee_val[cand][idx] = entropy; + total_entropy += entropy; + idx++; + } + } + if (total_entropy < best_ee) { + best_ee = total_entropy; + best = cand; + } + } + // Nothing changed. + if (best == num_candidates) return; + AcStrategy acs = AcStrategy::FromRawStrategy(candidates[best]); + size_t idx = 0; + for (size_t y = 0; y < current_acs.covered_blocks_y(); + y += acs.covered_blocks_y()) { + for (size_t x = 0; x < current_acs.covered_blocks_x(); + x += acs.covered_blocks_x()) { + ac_strategy->Set(bx + x, by + y, candidates[best]); + for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) { + for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) { + entropy_estimate[iy * 8 + ix] = ee_val[best][idx]; + } + } + idx++; + } + } +} + +void ProcessRectACS(PassesEncoderState* JXL_RESTRICT enc_state, + const ACSConfig& config, float* entropy_adjust, + const Rect& rect) { + const CompressParams& cparams = enc_state->cparams; + const float butteraugli_target = cparams.butteraugli_distance; + AcStrategyImage* ac_strategy = &enc_state->shared.ac_strategy; + + const size_t xsize_blocks = enc_state->shared.frame_dim.xsize_blocks; + const size_t ysize_blocks = enc_state->shared.frame_dim.ysize_blocks; + + // Maximum delta that every strategy type is allowed to have in the area + // it covers. Ignored for 8x8 transforms. This heuristic is now mostly + // disabled. + const float kMaxDelta = + 0.5f * std::sqrt(butteraugli_target + 0.5); // OPTIMIZE + + // TODO(veluca): reuse allocations + auto mem = hwy::AllocateAligned(5 * AcStrategy::kMaxCoeffArea); + auto qmem = hwy::AllocateAligned(AcStrategy::kMaxCoeffArea); + uint32_t* JXL_RESTRICT quantized = qmem.get(); + float* JXL_RESTRICT block = mem.get(); + float* JXL_RESTRICT scratch_space = mem.get() + 3 * AcStrategy::kMaxCoeffArea; + size_t bx = rect.x0(); + size_t by = rect.y0(); + JXL_ASSERT(rect.xsize() <= 8); + JXL_ASSERT(rect.ysize() <= 8); + size_t tx = bx / kColorTileDimInBlocks; + size_t ty = by / kColorTileDimInBlocks; + const float cmap_factors[3] = { + enc_state->shared.cmap.YtoXRatio( + enc_state->shared.cmap.ytox_map.ConstRow(ty)[tx]), + 0.0f, + enc_state->shared.cmap.YtoBRatio( + enc_state->shared.cmap.ytob_map.ConstRow(ty)[tx]), + }; + HWY_CAPPED(float, kBlockDim) d; + HWY_CAPPED(uint32_t, kBlockDim) di; + + // Padded, see UpdateMaxFlatness. + HWY_ALIGN float pixels[3][8 + 64 + 8]; + for (size_t c = 0; c < 3; ++c) { + pixels[c][8 - 2] = pixels[c][8 - 1] = 0.0f; // value does not matter + pixels[c][64] = pixels[c][64 + 1] = 0.0f; // value does not matter + } + + // Scale of channels when computing delta. 
+ const float kDeltaScale[3] = {3.0f, 1.0f, 0.2f}; + + // Pre-compute maximum delta in each 8x8 block. + // Find a minimum delta of three options: + // 1) all, 2) not accounting vertical, 3) not accounting horizontal + float max_delta[3][64] = {}; + float entropy_estimate[64] = {}; + for (size_t c = 0; c < 3; c++) { + for (size_t iy = 0; iy < rect.ysize(); iy++) { + size_t dy = by + iy; + for (size_t ix = 0; ix < rect.xsize(); ix++) { + size_t dx = bx + ix; + for (size_t y = 0; y < 8; y++) { + for (size_t x = 0; x < 8; x += Lanes(d)) { + const auto v = Load(d, &config.Pixel(c, dx * 8 + x, dy * 8 + y)); + Store(v, d, &pixels[c][y * 8 + x + 8]); + } + } + + auto delta = Zero(d); + for (size_t x = 0; x < 8; x += Lanes(d)) { + HWY_ALIGN const uint32_t kMask[] = {0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, 0u}; + auto mask = BitCast(d, Load(di, kMask + x)); + for (size_t y = 1; y < 7; y++) { + float* pix = &pixels[c][y * 8 + x + 8]; + const auto p = Load(d, pix); + const auto n = Load(d, pix + 8); + const auto s = Load(d, pix - 8); + const auto w = LoadU(d, pix - 1); + const auto e = LoadU(d, pix + 1); + // Compute amount of per-pixel variation. + const auto m1 = Max(AbsDiff(n, p), AbsDiff(s, p)); + const auto m2 = Max(AbsDiff(w, p), AbsDiff(e, p)); + const auto m3 = Max(AbsDiff(e, w), AbsDiff(s, n)); + const auto m4 = Max(m1, m2); + const auto m5 = Max(m3, m4); + delta = Max(delta, m5); + } + const float mdelta = GetLane(MaxOfLanes(And(mask, delta))); + max_delta[c][iy * 8 + ix] = + std::max(max_delta[c][iy * 8 + ix], mdelta * kDeltaScale[c]); + } + } + } + } + + // Choose the first transform that can be used to cover each block. + uint8_t chosen_mask[64] = {0}; + for (size_t iy = 0; iy < rect.ysize(); iy++) { + for (size_t ix = 0; ix < rect.xsize(); ix++) { + if (chosen_mask[iy * 8 + ix]) continue; + for (auto i : kACSOrder) { + AcStrategy acs = AcStrategy::FromRawStrategy(i); + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + // Only blocks up to a certain size if targeting faster decoding. 
+ if (cparams.decoding_speed_tier >= 1) { + if (cx * cy > 16) continue; + } + if (cparams.decoding_speed_tier >= 2) { + if (cx * cy > 8) continue; + } + float max_delta_v[3] = {max_delta[0][iy * 8 + ix], + max_delta[1][iy * 8 + ix], + max_delta[2][iy * 8 + ix]}; + float max2_delta_v[3] = {0, 0, 0}; + float max_delta_acs = + std::max(std::max(max_delta_v[0], max_delta_v[1]), max_delta_v[2]); + float min_delta_v[3] = {1e30f, 1e30f, 1e30f}; + float ave_delta_v[3] = {}; + // Check if strategy is usable + if (cx != 1 || cy != 1) { + // Alignment + if ((iy & (cy - 1)) != 0) continue; + if ((ix & (cx - 1)) != 0) continue; + // Out of block64 bounds + if (iy + cy > 8) continue; + if (ix + cx > 8) continue; + // Out of image bounds + if (by + iy + cy > ysize_blocks) continue; + if (bx + ix + cx > xsize_blocks) continue; + // Block would overwrite an already-chosen block + bool overwrites_covered = false; + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx; x++) { + if (chosen_mask[(y + iy) * 8 + x + ix]) overwrites_covered = true; + } + } + if (overwrites_covered) continue; + for (size_t c = 0; c < 3; ++c) { + max_delta_v[c] = 0; + max2_delta_v[c] = 0; + min_delta_v[c] = 1e30f; + ave_delta_v[c] = 0; + // Max delta in covered area + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx; x++) { + int pix = (iy + y) * 8 + ix + x; + if (max_delta_v[c] < max_delta[c][pix]) { + max2_delta_v[c] = max_delta_v[c]; + max_delta_v[c] = max_delta[c][pix]; + } else if (max2_delta_v[c] < max_delta[c][pix]) { + max2_delta_v[c] = max_delta[c][pix]; + } + min_delta_v[c] = std::min(min_delta_v[c], max_delta[c][pix]); + ave_delta_v[c] += max_delta[c][pix]; + } + } + ave_delta_v[c] -= max_delta_v[c]; + if (cy * cx >= 5) { + ave_delta_v[c] -= max2_delta_v[c]; + ave_delta_v[c] /= (cy * cx - 2); + } else { + ave_delta_v[c] /= (cy * cx - 1); + } + max_delta_v[c] -= 0.03f * max2_delta_v[c]; + max_delta_v[c] -= 0.25f * min_delta_v[c]; + max_delta_v[c] -= 0.25f * ave_delta_v[c]; + } + max_delta_acs = max_delta_v[0] + max_delta_v[1] + max_delta_v[2]; + max_delta_acs *= std::pow(1.044f, cx * cy); + if (max_delta_acs > kMaxDelta) continue; + } + // Estimate entropy and qf value + float entropy = 0.0f; + // In modes faster than Wombat mode, AC strategy replacement is not + // attempted: no need to estimate entropy. + if (cparams.speed_tier <= SpeedTier::kWombat) { + entropy = + EstimateEntropy(acs, (bx + ix) * 8, (by + iy) * 8, config, + cmap_factors, block, scratch_space, quantized); + entropy *= entropy_adjust[i * 2 + 1]; + } + // In modes faster than Hare mode, we don't use InitialQuantField - + // hence, we need to come up with quant field values. + if (cparams.speed_tier > SpeedTier::kHare && + cparams.uniform_quant <= 0) { + // OPTIMIZE + float quant = 1.1f / (1.0f + max_delta_acs) / butteraugli_target; + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx; x++) { + config.SetQuant(bx + ix + x, by + iy + y, quant); + } + } + } + // Mark blocks as chosen and write to acs image. + ac_strategy->Set(bx + ix, by + iy, i); + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx; x++) { + chosen_mask[(y + iy) * 8 + x + ix] = 1; + entropy_estimate[(iy + y) * 8 + ix + x] = entropy; + } + } + break; + } + } + } + // Do not try to replace ACS in modes faster than wombat mode. + if (cparams.speed_tier > SpeedTier::kWombat) return; + // Iterate through the 32-block attempting to replace the current strategy. 
+ // If replaced, repeat for the top-left new block and let the other ones be + // taken care of by future iterations. + uint8_t computed_mask[64] = {}; + for (size_t iy = 0; iy < rect.ysize(); iy++) { + for (size_t ix = 0; ix < rect.xsize(); ix++) { + if (computed_mask[iy * 8 + ix]) continue; + uint8_t prev = AcStrategy::kNumValidStrategies; + while (prev != ac_strategy->ConstRow(by + iy)[bx + ix].RawStrategy()) { + prev = ac_strategy->ConstRow(by + iy)[bx + ix].RawStrategy(); + MaybeReplaceACS(bx + ix, by + iy, config, cmap_factors, ac_strategy, + entropy_adjust, entropy_estimate + (iy * 8 + ix), block, + scratch_space, quantized); + } + AcStrategy acs = ac_strategy->ConstRow(by + iy)[bx + ix]; + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + computed_mask[(iy + y) * 8 + ix + x] = 1; + } + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ProcessRectACS); + +void AcStrategyHeuristics::Init(const Image3F& src, + PassesEncoderState* enc_state) { + this->enc_state = enc_state; + const CompressParams& cparams = enc_state->cparams; + const float butteraugli_target = cparams.butteraugli_distance; + + config.dequant = &enc_state->shared.matrices; + + // Image row pointers and strides. + config.quant_field_row = enc_state->initial_quant_field.Row(0); + config.quant_field_stride = enc_state->initial_quant_field.PixelsPerRow(); + auto& mask = enc_state->initial_quant_masking; + if (mask.xsize() > 0 && mask.ysize() > 0) { + config.masking_field_row = mask.Row(0); + config.masking_field_stride = mask.PixelsPerRow(); + } + + config.src_rows[0] = src.ConstPlaneRow(0, 0); + config.src_rows[1] = src.ConstPlaneRow(1, 0); + config.src_rows[2] = src.ConstPlaneRow(2, 0); + config.src_stride = src.PixelsPerRow(); + + InitEntropyAdjustTable(entropy_adjust); + + // Entropy estimate is composed of two factors: + // - estimate of the number of bits that will be used by the block + // - information loss due to quantization + // The following constant controls the relative weights of these components. + // TODO(jyrki): better choice of constants/parameterization. + config.info_loss_multiplier = 39.2; + config.base_entropy = 30.0; + config.zeros_mul = 0.3; // Possibly a bigger value would work better. + if (butteraugli_target < 2) { + config.cost1 = 2.1467536133280064f; + config.cost2 = 4.5233239814548617f; + config.cost_delta = 2.7192877948074784f; + } else if (butteraugli_target < 4) { + config.cost1 = 3.3478899662356103f; + config.cost2 = 3.2493410394508086f; + config.cost_delta = 2.9192251887428096f; + } else if (butteraugli_target < 8) { + config.cost1 = 3.9758237938237959f; + config.cost2 = 1.2423859153559777f; + config.cost_delta = 3.1181324266623122f; + } else if (butteraugli_target < 16) { + config.cost1 = 2.5; + config.cost2 = 2.2630019747782897f; + config.cost_delta = 3.8409539247825222f; + } else { + config.cost1 = 1.5; + config.cost2 = 2.6952503610099059f; + config.cost_delta = 4.316274170126156f; + } + + JXL_ASSERT(enc_state->shared.ac_strategy.xsize() == + enc_state->shared.frame_dim.xsize_blocks); + JXL_ASSERT(enc_state->shared.ac_strategy.ysize() == + enc_state->shared.frame_dim.ysize_blocks); +} + +void AcStrategyHeuristics::ProcessRect(const Rect& rect) { + PROFILER_FUNC; + const CompressParams& cparams = enc_state->cparams; + // In Falcon mode, use DCT8 everywhere and uniform quantization. 
+ if (cparams.speed_tier == SpeedTier::kFalcon) { + enc_state->shared.ac_strategy.FillDCT8(rect); + return; + } + HWY_DYNAMIC_DISPATCH(ProcessRectACS) + (enc_state, config, entropy_adjust, rect); +} + +void AcStrategyHeuristics::Finalize(AuxOut* aux_out) { + const auto& ac_strategy = enc_state->shared.ac_strategy; + // Accounting and debug output. + if (aux_out != nullptr) { + aux_out->num_dct2_blocks = + 32 * (ac_strategy.CountBlocks(AcStrategy::Type::DCT32X64) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT64X32)); + aux_out->num_dct4_blocks = + 64 * ac_strategy.CountBlocks(AcStrategy::Type::DCT64X64); + aux_out->num_dct4x8_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT4X8) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X4); + aux_out->num_afv_blocks = ac_strategy.CountBlocks(AcStrategy::Type::AFV0) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV1) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV2) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV3); + aux_out->num_dct8_blocks = ac_strategy.CountBlocks(AcStrategy::Type::DCT); + aux_out->num_dct8x16_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X16) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X8); + aux_out->num_dct8x32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X32) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X8); + aux_out->num_dct16_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X16); + aux_out->num_dct16x32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X32) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X16); + aux_out->num_dct32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X32); + } + + if (WantDebugOutput(aux_out)) { + DumpAcStrategy(ac_strategy, enc_state->shared.frame_dim.xsize, + enc_state->shared.frame_dim.ysize, "ac_strategy", aux_out); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.h b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.h new file mode 100644 index 000000000000..fe4f55562478 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.h @@ -0,0 +1,88 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_AC_STRATEGY_H_ +#define LIB_JXL_ENC_AC_STRATEGY_H_ + +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" + +// `FindBestAcStrategy` uses heuristics to choose which AC strategy should be +// used in each block, as well as the initial quantization field. + +namespace jxl { + +// AC strategy selection: utility struct. 
+ +struct ACSConfig { + const DequantMatrices* JXL_RESTRICT dequant; + float info_loss_multiplier; + float* JXL_RESTRICT quant_field_row; + size_t quant_field_stride; + float* JXL_RESTRICT masking_field_row; + size_t masking_field_stride; + const float* JXL_RESTRICT src_rows[3]; + size_t src_stride; + // Cost for 1 (-1), 2 (-2) explicitly, cost for others computed with cost1 + + // cost2 + sqrt(q) * cost_delta. + float cost1; + float cost2; + float cost_delta; + float base_entropy; + float zeros_mul; + const float& Pixel(size_t c, size_t x, size_t y) const { + return src_rows[c][y * src_stride + x]; + } + float Masking(size_t bx, size_t by) const { + JXL_DASSERT(masking_field_row[by * masking_field_stride + bx] > 0); + return masking_field_row[by * masking_field_stride + bx]; + } + float Quant(size_t bx, size_t by) const { + JXL_DASSERT(quant_field_row[by * quant_field_stride + bx] > 0); + return quant_field_row[by * quant_field_stride + bx]; + } + void SetQuant(size_t bx, size_t by, float value) const { + JXL_DASSERT(value > 0); + quant_field_row[by * quant_field_stride + bx] = value; + } +}; + +struct AcStrategyHeuristics { + void Init(const Image3F& src, PassesEncoderState* enc_state); + void ProcessRect(const Rect& rect); + void Finalize(AuxOut* aux_out); + ACSConfig config; + PassesEncoderState* enc_state; + float entropy_adjust[2 * AcStrategy::kNumValidStrategies]; +}; + +// Debug. +void DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize, + size_t ysize, const char* tag, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_AC_STRATEGY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc new file mode 100644 index 000000000000..59c42798a6d9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc @@ -0,0 +1,1090 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lib/jxl/enc_adaptive_quantization.h"
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+#include <vector>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/butteraugli/butteraugli.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/dec_group.h"
+#include "lib/jxl/dec_reconstruct.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_group.h"
+#include "lib/jxl/enc_modular.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/enc_transforms-inl.h"
+#include "lib/jxl/epf.h"
+#include "lib/jxl/fast_math-inl.h"
+#include "lib/jxl/gauss_blur.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/quant_weights.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Rebind;
+
+// The following functions modulate an exponent (out_val) and return the
+// updated value. Their descriptor is limited to 8 lanes for 8x8 blocks.
+
+// Hack for mask estimation. Eventually replace this code with butteraugli's
+// masking.
+float ComputeMaskForAcStrategyUse(const float out_val) {
+  const float kMul = 1.0f;
+  const float kOffset = 0.4f;
+  return kMul / (out_val + kOffset);
+}
+
+template <class D, class V>
+V ComputeMask(const D d, const V out_val) {
+  const auto kBase = Set(d, -0.74174993f);
+  const auto kMul4 = Set(d, 3.2353257320940401f);
+  const auto kMul2 = Set(d, 12.906028311180409f);
+  const auto kOffset2 = Set(d, 305.04035728311436f);
+  const auto kMul3 = Set(d, 5.0220313103171232f);
+  const auto kOffset3 = Set(d, 2.1925739705298404f);
+  const auto kOffset4 = Set(d, 0.25f) * kOffset3;
+  const auto kMul0 = Set(d, 0.74760422233706747f);
+  const auto k1 = Set(d, 1.0f);
+
+  // Avoid division by zero.
+  const auto v1 = Max(out_val * kMul0, Set(d, 1e-3f));
+  const auto v2 = k1 / (v1 + kOffset2);
+  const auto v3 = k1 / MulAdd(v1, v1, kOffset3);
+  const auto v4 = k1 / MulAdd(v1, v1, kOffset4);
+  // TODO(jyrki):
+  // A log or two here could make sense. In butteraugli we have effectively
+  // log(log(x + C)) for this kind of use, as a single log is used in
+  // saturating visual masking and here the modulation values are exponential,
+  // another log would counter that.
+  return kBase + MulAdd(kMul4, v4, MulAdd(kMul2, v2, kMul3 * v3));
+}
+
+// For converting full vectors to a subset. Assumes `vfull` lanes are
+// identical.
+template <class D, class VFull>
+Vec<D> CapTo(const D d, VFull vfull) {
+  using T = typename D::T;
+  const HWY_FULL(T) dfull;
+  HWY_ALIGN T lanes[MaxLanes(dfull)];
+  Store(vfull, dfull, lanes);
+  return Load(d, lanes);
+}
+
+// mul and mul2 represent a scaling difference between jxl and butteraugli.
+static const float kSGmul = 226.0480446705883f;
+static const float kSGmul2 = 1.0f / 73.377132366608819f;
+static const float kLog2 = 0.693147181f;
+// Includes correction factor for std::log -> log2.
+static const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2;
+static const float kSGRetAdd = kSGmul2 * -20.2789020414f;
+static const float kSGVOffset = 7.14672470003f;
+
+template <class D, class V>
+V SimpleGamma(const D d, V v) {
+  // A simple HDR compatible gamma function.
+  const auto mul = Set(d, kSGmul);
+  const auto kRetMul = Set(d, kSGRetMul);
+  const auto kRetAdd = Set(d, kSGRetAdd);
+  const auto kVOffset = Set(d, kSGVOffset);
+
+  v *= mul;
+
+  // This should happen rarely, but may lead to a NaN, which is rather
+  // undesirable. Since negative photons don't exist we solve the NaNs by
+  // clamping here.
+  // TODO(veluca): with FastLog2f, this no longer leads to NaNs.
+  v = ZeroIfNegative(v);
+  return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd;
+}
+
+template <bool invert, class D, class V>
+V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) {
+  // The opsin space in jxl is the cubic root of photons, i.e., v * v * v
+  // is related to the number of photons.
+  //
+  // SimpleGamma(v * v * v) is the psychovisual space in butteraugli.
+  // This ratio allows quantization to move from jxl's opsin space to
+  // butteraugli's log-gamma space.
+  v = ZeroIfNegative(v);
+  const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul);
+  const auto kVOffset = Set(d, kSGVOffset * kLog2);
+  const auto kDenMul = Set(d, kLog2 * kSGmul);
+
+  const auto v2 = v * v;
+
+  const auto num = kNumMul * v2;
+  const auto den = MulAdd(kDenMul * v, v2, kVOffset);
+  return invert ? num / den : den / num;
+}
+
+template <bool invert = false>
+static float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) {
+  using DScalar = HWY_CAPPED(float, 1);
+  auto vscalar = Load(DScalar(), &v);
+  return GetLane(
+      RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar));
+}
+
+// TODO(veluca): this function computes an approximation of the derivative of
+// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or
+// exact derivatives.
+template <class D, class V>
+V GammaModulation(const D d, const size_t x, const size_t y,
+                  const ImageF& xyb_x, const ImageF& xyb_y, const V out_val) {
+  const float kBias = 0.16f;
+  JXL_DASSERT(kBias > kOpsinAbsorbanceBias[0]);
+  JXL_DASSERT(kBias > kOpsinAbsorbanceBias[1]);
+  JXL_DASSERT(kBias > kOpsinAbsorbanceBias[2]);
+  auto overall_ratio = Zero(d);
+  auto bias = Set(d, kBias);
+  auto half = Set(d, 0.5f);
+  for (size_t dy = 0; dy < 8; ++dy) {
+    const float* const JXL_RESTRICT row_in_x = xyb_x.Row(y + dy);
+    const float* const JXL_RESTRICT row_in_y = xyb_y.Row(y + dy);
+    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
+      const auto iny = Load(d, row_in_y + x + dx) + bias;
+      const auto inx = Load(d, row_in_x + x + dx);
+      const auto r = iny - inx;
+      const auto g = iny + inx;
+      const auto ratio_r =
+          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r);
+      const auto ratio_g =
+          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g);
+      const auto avg_ratio = half * (ratio_r + ratio_g);
+
+      overall_ratio += avg_ratio;
+    }
+  }
+  overall_ratio = SumOfLanes(overall_ratio);
+  overall_ratio *= Set(d, 1.0f / 64);
+  // ideally -1.0, but likely optimal correction adds some entropy, so slightly
+  // less than that.
+  // ln(2) constant folded in because we want std::log but have FastLog2f.
+  const auto kGam = Set(d, -0.15526878023684174f * 0.693147180559945f);
+  return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val);
+}
+
+// Change precision in 8x8 blocks that have high frequency content.
+template <class D, class V>
+V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb,
+               const V out_val) {
+  // Zero out the invalid differences for the rightmost value per row.
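+  // (kMaskRight below keeps lanes 0..6 and zeroes lane 7, so the wrap-around
+  // difference computed at the right edge of each 8-pixel row is discarded.)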
+  const Rebind<uint32_t, D> du;
+  HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u,
+                                                        ~0u, ~0u, ~0u, 0};
+
+  auto sum = Zero(d);  // sum of absolute differences with right and below
+
+  for (size_t dy = 0; dy < 8; ++dy) {
+    const float* JXL_RESTRICT row_in = xyb.Row(y + dy) + x;
+    const float* JXL_RESTRICT row_in_next =
+        dy == 7 ? row_in : xyb.Row(y + dy + 1) + x;
+
+    // In SCALAR, there is no guarantee of having extra row padding.
+    // Hence, we need to ensure we don't access pixels outside the row itself.
+    // In SIMD modes, however, rows are padded, so it's safe to access one
+    // garbage value after the row. The vector then gets masked with kMaskRight
+    // to remove the influence of that value.
+#if HWY_TARGET != HWY_SCALAR
+    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
+#else
+    for (size_t dx = 0; dx < 7; dx += Lanes(d)) {
+#endif
+      const auto p = Load(d, row_in + dx);
+      const auto pr = LoadU(d, row_in + dx + 1);
+      const auto mask = BitCast(d, Load(du, kMaskRight + dx));
+      sum += And(mask, AbsDiff(p, pr));
+
+      const auto pd = Load(d, row_in_next + dx);
+      sum += AbsDiff(p, pd);
+    }
+  }
+
+  sum = SumOfLanes(sum);
+  return MulAdd(sum, Set(d, -2.0052193233688884f / 112), out_val);
+}
+
+void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x,
+                         const ImageF& xyb_y, const float scale,
+                         const Rect& rect, ImageF* out) {
+  JXL_ASSERT(SameSize(xyb_x, xyb_y));
+  JXL_ASSERT(DivCeil(xyb_x.xsize(), kBlockDim) == out->xsize());
+  JXL_ASSERT(DivCeil(xyb_x.ysize(), kBlockDim) == out->ysize());
+
+  float base_level = 0.5f * scale;
+  float kDampenRampStart = 7.0f;
+  float kDampenRampEnd = 14.0f;
+  float dampen = 1.0f;
+  if (butteraugli_target >= kDampenRampStart) {
+    dampen = 1.0f - ((butteraugli_target - kDampenRampStart) /
+                     (kDampenRampEnd - kDampenRampStart));
+    if (dampen < 0) {
+      dampen = 0;
+    }
+  }
+  const float mul = scale * dampen;
+  const float add = (1.0f - dampen) * base_level;
+  for (size_t iy = rect.y0(); iy < rect.y0() + rect.ysize(); iy++) {
+    const size_t y = iy * 8;
+    float* const JXL_RESTRICT row_out = out->Row(iy);
+    const HWY_CAPPED(float, kBlockDim) df;
+    for (size_t ix = rect.x0(); ix < rect.x0() + rect.xsize(); ix++) {
+      size_t x = ix * 8;
+      auto out_val = Set(df, row_out[ix]);
+      out_val = ComputeMask(df, out_val);
+      out_val = HfModulation(df, x, y, xyb_y, out_val);
+      out_val = GammaModulation(df, x, y, xyb_x, xyb_y, out_val);
+      // We want multiplicative quantization field, so everything
+      // until this point has been modulating the exponent.
+      row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add;
+    }
+  }
+}
+
+template <class D, class V>
+V MaskingSqrt(const D d, V v) {
+  static const float kLogOffset = 26.481471032459346f;
+  static const float kMul = 211.50759899638012f;
+  const auto mul_v = Set(d, kMul * 1e8);
+  const auto offset_v = Set(d, kLogOffset);
+  return Set(d, 0.25f) * Sqrt(MulAdd(v, Sqrt(mul_v), offset_v));
+}
+
+float MaskingSqrt(const float v) {
+  using DScalar = HWY_CAPPED(float, 1);
+  auto vscalar = Load(DScalar(), &v);
+  return GetLane(MaskingSqrt(DScalar(), vscalar));
+}
+
+void StoreMin3(const float v, float& min0, float& min1, float& min2) {
+  if (v < min2) {
+    if (v < min0) {
+      min2 = min1;
+      min1 = min0;
+      min0 = v;
+    } else if (v < min1) {
+      min2 = min1;
+      min1 = v;
+    } else {
+      min2 = v;
+    }
+  }
+}
+
+// Look for smooth areas near the area of degradation.
+// If the areas are generally smooth, don't do masking.
+// Output is downsampled 2x.
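+// Each output value blends the centre pixel with the three smallest values
+// of its 3x3 neighborhood (kMulC/kMul0/kMul1/kMul2 below), and each 2x2 quad
+// of blended values is summed into one output pixel.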
+void FuzzyErosion(const Rect& from_rect, const ImageF& from,
+                  const Rect& to_rect, ImageF* to) {
+  const size_t xsize = from.xsize();
+  const size_t ysize = from.ysize();
+  constexpr int kStep = 1;
+  static_assert(kStep == 1, "Step must be 1");
+  JXL_ASSERT(to_rect.xsize() * 2 == from_rect.xsize());
+  JXL_ASSERT(to_rect.ysize() * 2 == from_rect.ysize());
+  for (size_t fy = 0; fy < from_rect.ysize(); ++fy) {
+    size_t y = fy + from_rect.y0();
+    size_t ym1 = y >= kStep ? y - kStep : y;
+    size_t yp1 = y + kStep < ysize ? y + kStep : y;
+    const float* rowt = from.Row(ym1);
+    const float* row = from.Row(y);
+    const float* rowb = from.Row(yp1);
+    float* row_out = to_rect.Row(to, fy / 2);
+    for (size_t fx = 0; fx < from_rect.xsize(); ++fx) {
+      size_t x = fx + from_rect.x0();
+      size_t xm1 = x >= kStep ? x - kStep : x;
+      size_t xp1 = x + kStep < xsize ? x + kStep : x;
+      float min0 = row[x];
+      float min1 = min0;
+      float min2 = min1;
+      StoreMin3(row[xm1], min0, min1, min2);
+      StoreMin3(row[xp1], min0, min1, min2);
+      StoreMin3(rowt[xm1], min0, min1, min2);
+      StoreMin3(rowt[x], min0, min1, min2);
+      StoreMin3(rowt[xp1], min0, min1, min2);
+      StoreMin3(rowb[xm1], min0, min1, min2);
+      StoreMin3(rowb[x], min0, min1, min2);
+      StoreMin3(rowb[xp1], min0, min1, min2);
+      static const float kMulC = 0.029598804634393225 * 0.25f;
+      static const float kMul0 = 0.561331076516815 * 0.25f;
+      static const float kMul1 = 0.16504828561110252 * 0.25f;
+      static const float kMul2 = 0.2440218332376892 * 0.25f;
+      float v = kMulC * row[x] + kMul0 * min0 + kMul1 * min1 + kMul2 * min2;
+      if (fx % 2 == 0 && fy % 2 == 0) {
+        row_out[fx / 2] = v;
+      } else {
+        row_out[fx / 2] += v;
+      }
+    }
+  }
+}
+
+struct AdaptiveQuantizationImpl {
+  void Init(const Image3F& xyb) {
+    JXL_DASSERT(xyb.xsize() % kBlockDim == 0);
+    JXL_DASSERT(xyb.ysize() % kBlockDim == 0);
+    const size_t xsize = xyb.xsize();
+    const size_t ysize = xyb.ysize();
+    aq_map = ImageF(xsize / kBlockDim, ysize / kBlockDim);
+  }
+  void PrepareBuffers(size_t num_threads) {
+    diff_buffer = ImageF(kEncTileDim + 8, num_threads);
+    for (size_t i = pre_erosion.size(); i < num_threads; i++) {
+      pre_erosion.emplace_back(kEncTileDimInBlocks * 2 + 2,
+                               kEncTileDimInBlocks * 2 + 2);
+    }
+  }
+
+  void ComputeTile(float butteraugli_target, float scale, const Image3F& xyb,
+                   const Rect& rect, const int thread, ImageF* mask) {
+    PROFILER_ZONE("aq DiffPrecompute");
+    const size_t xsize = xyb.xsize();
+    const size_t ysize = xyb.ysize();
+
+    // The XYB gamma is 3.0 to be able to decode faster with two muls.
+    // Butteraugli's gamma is matching the gamma of human eye, around 2.6.
+    // We approximate the gamma difference by adding one cubic root into
+    // the adaptive quantization. This gives us a total gamma of 2.6666
+    // for quantization uses.
+    const float match_gamma_offset = 0.019;
+
+    const HWY_FULL(float) df;
+    const float kXMul = 30.49302140275616f;
+    const auto kXMulv = Set(df, kXMul);
+
+    size_t y_start = rect.y0() * 8;
+    size_t y_end = y_start + rect.ysize() * 8;
+
+    size_t x0 = rect.x0() * 8;
+    size_t x1 = x0 + rect.xsize() * 8;
+    if (x0 != 0) x0 -= 4;
+    if (x1 != xyb.xsize()) x1 += 4;
+    if (y_start != 0) y_start -= 4;
+    if (y_end != xyb.ysize()) y_end += 4;
+    pre_erosion[thread].ShrinkTo((x1 - x0) / 4, (y_end - y_start) / 4);
+
+    // Computes image (padded to multiple of 8x8) of local pixel differences.
+    // Subsample both directions by 4.
+    for (size_t y = y_start; y < y_end; ++y) {
+      size_t y2 = y + 1 < ysize ? y + 1 : y;
+      size_t y1 = y > 0 ? y - 1 : y;
+
+      const float* row_in = xyb.PlaneRow(1, y);
+      const float* row_in1 = xyb.PlaneRow(1, y1);
+      const float* row_in2 = xyb.PlaneRow(1, y2);
+      const float* row_x_in = xyb.PlaneRow(0, y);
+      const float* row_x_in1 = xyb.PlaneRow(0, y1);
+      const float* row_x_in2 = xyb.PlaneRow(0, y2);
+      float* JXL_RESTRICT row_out = diff_buffer.Row(thread);
+
+      auto scalar_pixel = [&](size_t x) {
+        const size_t x2 = x + 1 < xsize ? x + 1 : x;
+        const size_t x1 = x > 0 ? x - 1 : x;
+        const float base =
+            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
+        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
+            row_in[x] + match_gamma_offset);
+        float diff = gammac * (row_in[x] - base);
+        diff *= diff;
+        const float base_x =
+            0.25f * (row_x_in2[x] + row_x_in1[x] + row_x_in[x1] + row_x_in[x2]);
+        float diff_x = gammac * (row_x_in[x] - base_x);
+        diff_x *= diff_x;
+        diff += kXMul * diff_x;
+        diff = MaskingSqrt(diff);
+        if ((y % 4) != 0) {
+          row_out[x - x0] += diff;
+        } else {
+          row_out[x - x0] = diff;
+        }
+      };
+
+      size_t x = x0;
+      // First pixel of the row.
+      if (x0 == 0) {
+        scalar_pixel(x0);
+        ++x;
+      }
+      // SIMD
+      const auto match_gamma_offset_v = Set(df, match_gamma_offset);
+      const auto quarter = Set(df, 0.25f);
+      for (; x + 1 + Lanes(df) < x1; x += Lanes(df)) {
+        const auto in = LoadU(df, row_in + x);
+        const auto in_r = LoadU(df, row_in + x + 1);
+        const auto in_l = LoadU(df, row_in + x - 1);
+        const auto in_t = LoadU(df, row_in2 + x);
+        const auto in_b = LoadU(df, row_in1 + x);
+        auto base = quarter * (in_r + in_l + in_t + in_b);
+        auto gammacv =
+            RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>(
+                df, in + match_gamma_offset_v);
+        auto diff = gammacv * (in - base);
+        diff *= diff;
+
+        const auto in_x = LoadU(df, row_x_in + x);
+        const auto in_x_r = LoadU(df, row_x_in + x + 1);
+        const auto in_x_l = LoadU(df, row_x_in + x - 1);
+        const auto in_x_t = LoadU(df, row_x_in2 + x);
+        const auto in_x_b = LoadU(df, row_x_in1 + x);
+        auto base_x = quarter * (in_x_r + in_x_l + in_x_t + in_x_b);
+        auto diff_x = gammacv * (in_x - base_x);
+        diff_x *= diff_x;
+        diff += kXMulv * diff_x;
+        diff = MaskingSqrt(df, diff);
+        if ((y & 3) != 0) {
+          diff += LoadU(df, row_out + x - x0);
+        }
+        StoreU(diff, df, row_out + x - x0);
+      }
+      // Scalar
+      for (; x < x1; ++x) {
+        scalar_pixel(x);
+      }
+      if (y % 4 == 3) {
+        float* row_dout = pre_erosion[thread].Row((y - y_start) / 4);
+        for (size_t x = 0; x < (x1 - x0) / 4; x++) {
+          row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] +
+                         row_out[x * 4 + 2] + row_out[x * 4 + 3]) *
+                        0.25f;
+        }
+      }
+    }
+    Rect from_rect(x0 % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1,
+                   rect.xsize() * 2, rect.ysize() * 2);
+    FuzzyErosion(from_rect, pre_erosion[thread], rect, &aq_map);
+    for (size_t y = 0; y < rect.ysize(); ++y) {
+      const float* aq_map_row = rect.ConstRow(aq_map, y);
+      float* mask_row = rect.Row(mask, y);
+      for (size_t x = 0; x < rect.xsize(); ++x) {
+        mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]);
+      }
+    }
+    PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1), scale,
+                        rect, &aq_map);
+  }
+  std::vector<ImageF> pre_erosion;
+  ImageF aq_map;
+  ImageF diff_buffer;
+};
+
+ImageF AdaptiveQuantizationMap(const float butteraugli_target,
+                               const Image3F& xyb,
+                               const FrameDimensions& frame_dim, float scale,
+                               ThreadPool* pool, ImageF* mask) {
+  PROFILER_ZONE("aq AdaptiveQuantMap");
+
+  AdaptiveQuantizationImpl impl;
+  impl.Init(xyb);
+  *mask = ImageF(frame_dim.xsize_blocks, frame_dim.ysize_blocks);
+  RunOnPool(
+      pool, 0,
+      DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) *
+          DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks),
+      [&](size_t num_threads) {
+        impl.PrepareBuffers(num_threads);
+        return true;
+      },
+      [&](const int tid, int thread) {
+        size_t n_enc_tiles =
+            DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks);
+        size_t tx = tid % n_enc_tiles;
+        size_t ty = tid / n_enc_tiles;
+        size_t by0 = ty * kEncTileDimInBlocks;
+        size_t by1 =
+            std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks);
+        size_t bx0 = tx * kEncTileDimInBlocks;
+        size_t bx1 =
+            std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks);
+        Rect r(bx0, by0, bx1 - bx0, by1 - by0);
+        impl.ComputeTile(butteraugli_target, scale, xyb, r, thread, mask);
+      },
+      "AQ DiffPrecompute");
+
+  return std::move(impl).aq_map;
+}
+
+}  // namespace
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(AdaptiveQuantizationMap);
+
+namespace {
+bool FLAGS_log_search_state = false;
+// If true, prints the quantization maps at each iteration.
+bool FLAGS_dump_quant_state = false;
+
+bool AdjustQuantVal(float* const JXL_RESTRICT q, const float d,
+                    const float factor, const float quant_max) {
+  if (*q >= 0.999f * quant_max) return false;
+  const float inv_q = 1.0f / *q;
+  const float adj_inv_q = inv_q - factor / (d + 1.0f);
+  *q = 1.0f / std::max(1.0f / quant_max, adj_inv_q);
+  return true;
+}
+
+void DumpHeatmap(const AuxOut* aux_out, const std::string& label,
+                 const ImageF& image, float good_threshold,
+                 float bad_threshold) {
+  Image3F heatmap = CreateHeatMapImage(image, good_threshold, bad_threshold);
+  char filename[200];
+  snprintf(filename, sizeof(filename), "%s%05d", label.c_str(),
+           aux_out->num_butteraugli_iters);
+  aux_out->DumpImage(filename, heatmap);
+}
+
+void DumpHeatmaps(const AuxOut* aux_out, float ba_target,
+                  const ImageF& quant_field, const ImageF& tile_heatmap,
+                  const ImageF& bt_diffmap) {
+  if (!WantDebugOutput(aux_out)) return;
+  ImageF inv_qmap(quant_field.xsize(), quant_field.ysize());
+  for (size_t y = 0; y < quant_field.ysize(); ++y) {
+    const float* JXL_RESTRICT row_q = quant_field.ConstRow(y);
+    float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y);
+    for (size_t x = 0; x < quant_field.xsize(); ++x) {
+      row_inv_q[x] = 1.0f / row_q[x];  // never zero
+    }
+  }
+  DumpHeatmap(aux_out, "quant_heatmap", inv_qmap, 4.0f * ba_target,
+              6.0f * ba_target);
+  DumpHeatmap(aux_out, "tile_heatmap", tile_heatmap, ba_target,
+              1.5f * ba_target);
+  // matches heat maps produced by the command line tool.
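+  // (ButteraugliFuzzyInverse maps a fuzzy quality score back to the distance
+  // that would produce it, giving the good/bad thresholds for this map.)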
+  DumpHeatmap(aux_out, "bt_diffmap", bt_diffmap, ButteraugliFuzzyInverse(1.5),
+              ButteraugliFuzzyInverse(0.5));
+}
+
+ImageF TileDistMap(const ImageF& distmap, int tile_size, int margin,
+                   const AcStrategyImage& ac_strategy) {
+  PROFILER_FUNC;
+  const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size;
+  const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size;
+  ImageF tile_distmap(tile_xsize, tile_ysize);
+  size_t distmap_stride = tile_distmap.PixelsPerRow();
+  for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) {
+    AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y);
+    float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y);
+    for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) {
+      AcStrategy acs = ac_strategy_row[tile_x];
+      if (!acs.IsFirstBlock()) continue;
+      int this_tile_xsize = acs.covered_blocks_x() * tile_size;
+      int this_tile_ysize = acs.covered_blocks_y() * tile_size;
+      int y_begin = std::max(0, tile_size * tile_y - margin);
+      int y_end = std::min<int>(distmap.ysize(),
+                                tile_size * tile_y + this_tile_ysize + margin);
+      int x_begin = std::max(0, tile_size * tile_x - margin);
+      int x_end = std::min<int>(distmap.xsize(),
+                                tile_size * tile_x + this_tile_xsize + margin);
+      float dist_norm = 0.0;
+      double pixels = 0;
+      for (int y = y_begin; y < y_end; ++y) {
+        float ymul = 1.0;
+        constexpr float kBorderMul = 0.98f;
+        constexpr float kCornerMul = 0.7f;
+        if (margin != 0 && (y == y_begin || y == y_end - 1)) {
+          ymul = kBorderMul;
+        }
+        const float* const JXL_RESTRICT row = distmap.Row(y);
+        for (int x = x_begin; x < x_end; ++x) {
+          float xmul = ymul;
+          if (margin != 0 && (x == x_begin || x == x_end - 1)) {
+            if (xmul == 1.0) {
+              xmul = kBorderMul;
+            } else {
+              xmul = kCornerMul;
+            }
+          }
+          float v = row[x];
+          v *= v;
+          v *= v;
+          v *= v;
+          v *= v;
+          dist_norm += xmul * v;
+          pixels += xmul;
+        }
+      }
+      if (pixels == 0) pixels = 1;
+      // 16th norm is less than the max norm, we reduce the difference
+      // with this normalization factor.
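+      // (the four squarings above raise each pixel distance to the 16th
+      // power, so tile_dist below is a weighted 16-norm:
+      // kTileNorm * (sum(xmul * d^16) / sum(xmul))^(1/16).)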
+      constexpr float kTileNorm = 1.2f;
+      const float tile_dist =
+          kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f);
+      dist_row[tile_x] = tile_dist;
+      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
+        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
+          dist_row[tile_x + distmap_stride * iy + ix] = tile_dist;
+        }
+      }
+    }
+  }
+  return tile_distmap;
+}
+
+ImageF DistToPeakMap(const ImageF& field, float peak_min, int local_radius,
+                     float peak_weight) {
+  ImageF result(field.xsize(), field.ysize());
+  FillImage(-1.0f, &result);
+  for (size_t y0 = 0; y0 < field.ysize(); ++y0) {
+    for (size_t x0 = 0; x0 < field.xsize(); ++x0) {
+      int x_min = std::max(0, static_cast<int>(x0) - local_radius);
+      int y_min = std::max(0, static_cast<int>(y0) - local_radius);
+      int x_max = std::min<int>(field.xsize(), x0 + 1 + local_radius);
+      int y_max = std::min<int>(field.ysize(), y0 + 1 + local_radius);
+      float local_max = peak_min;
+      for (int y = y_min; y < y_max; ++y) {
+        for (int x = x_min; x < x_max; ++x) {
+          local_max = std::max(local_max, field.Row(y)[x]);
+        }
+      }
+      if (field.Row(y0)[x0] >
+          (1.0f - peak_weight) * peak_min + peak_weight * local_max) {
+        for (int y = y_min; y < y_max; ++y) {
+          for (int x = x_min; x < x_max; ++x) {
+            float dist = std::max(std::abs(y - static_cast<int>(y0)),
+                                  std::abs(x - static_cast<int>(x0)));
+            float cur_dist = result.Row(y)[x];
+            if (cur_dist < 0.0 || cur_dist > dist) {
+              result.Row(y)[x] = dist;
+            }
+          }
+        }
+      }
+    }
+  }
+  return result;
+}
+
+constexpr float kDcQuantPow = 0.57f;
+static const float kDcQuant = 1.12f;
+static const float kAcQuant = 0.79f;
+
+void FindBestQuantization(const ImageBundle& linear, const Image3F& opsin,
+                          PassesEncoderState* enc_state, ThreadPool* pool,
+                          AuxOut* aux_out) {
+  const CompressParams& cparams = enc_state->cparams;
+  Quantizer& quantizer = enc_state->shared.quantizer;
+  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;
+  ImageF& quant_field = enc_state->initial_quant_field;
+
+  const float butteraugli_target = cparams.butteraugli_distance;
+  ButteraugliParams params = cparams.ba_params;
+  params.intensity_target = linear.metadata()->IntensityTarget();
+  // Hack the default intensity target value to be 80.0, the intensity
+  // target of sRGB images and a more reasonable viewing default than
+  // JPEG XL file format's default.
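+  // (255 nits is the codestream default intensity target; 80 nits is the
+  // sRGB reference display luminance.)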
+ if (fabs(params.intensity_target - 255.0f) < 1e-3) { + params.intensity_target = 80.0f; + } + JxlButteraugliComparator comparator(params); + JXL_CHECK(comparator.SetReferenceImage(linear)); + bool lower_is_better = + (comparator.GoodQualityScore() < comparator.BadQualityScore()); + const float initial_quant_dc = InitialQuantDC(butteraugli_target); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + &quant_field); + ImageF tile_distmap; + ImageF initial_quant_field = CopyImage(quant_field); + + float initial_qf_min, initial_qf_max; + ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max); + float initial_qf_ratio = initial_qf_max / initial_qf_min; + float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio); + float asymmetry = 2; + if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low; + float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low); + float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry); + + JXL_ASSERT(qf_higher / qf_lower < 253); + + constexpr int kOriginalComparisonRound = 1; + int iters = cparams.max_butteraugli_iters; + if (iters > 7) { + iters = 7; + } + if (cparams.speed_tier != SpeedTier::kTortoise) { + iters = 2; + } + for (int i = 0; i < iters + 1; ++i) { + if (FLAGS_dump_quant_state) { + printf("\nQuantization field:\n"); + for (size_t y = 0; y < quant_field.ysize(); ++y) { + for (size_t x = 0; x < quant_field.xsize(); ++x) { + printf(" %.5f", quant_field.Row(y)[x]); + } + printf("\n"); + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + ImageBundle linear = RoundtripImage(opsin, enc_state, pool); + PROFILER_ZONE("enc Butteraugli"); + float score; + ImageF diffmap; + JXL_CHECK(comparator.CompareWith(linear, &diffmap, &score)); + if (!lower_is_better) { + score = -score; + diffmap = ScaleImage(-1.0f, diffmap); + } + tile_distmap = TileDistMap(diffmap, 8, 0, enc_state->shared.ac_strategy); + if (WantDebugOutput(aux_out)) { + aux_out->DumpImage(("dec" + ToString(i)).c_str(), *linear.color()); + DumpHeatmaps(aux_out, butteraugli_target, quant_field, tile_distmap, + diffmap); + } + if (aux_out != nullptr) ++aux_out->num_butteraugli_iters; + if (FLAGS_log_search_state) { + float minval, maxval; + ImageMinMax(quant_field, &minval, &maxval); + printf("\nButteraugli iter: %d/%d\n", i, cparams.max_butteraugli_iters); + printf("Butteraugli distance: %f\n", score); + printf("quant range: %f ... %f DC quant: %f\n", minval, maxval, + initial_quant_dc); + if (FLAGS_dump_quant_state) { + quantizer.DumpQuantizationMap(raw_quant_field); + } + } + + if (i == iters) break; + + double kPow[8] = { + 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + double kPowMod[8] = { + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + if (i == kOriginalComparisonRound) { + // Don't allow optimization to make the quant field a lot worse than + // what the initial guess was. This allows the AC field to have enough + // precision to reduce the oscillations due to the dc reconstruction. 
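+      // (each value is clamped from below to
+      // kOneMinusInitMul * current + kInitMul * initial, then kept inside
+      // [qf_lower, qf_higher].)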
+ double kInitMul = 0.6; + const double kOneMinusInitMul = 1.0 - kInitMul; + for (size_t y = 0; y < quant_field.ysize(); ++y) { + float* const JXL_RESTRICT row_q = quant_field.Row(y); + const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x]; + if (row_q[x] < clamp) { + row_q[x] = clamp; + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + + double cur_pow = 0.0; + if (i < 7) { + cur_pow = kPow[i] + (butteraugli_target - 1.0) * kPowMod[i]; + if (cur_pow < 0) { + cur_pow = 0; + } + } + if (cur_pow == 0.0) { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / butteraugli_target; + if (diff > 1.0f) { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } else { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / butteraugli_target; + if (diff <= 1.0f) { + row_q[x] *= std::pow(diff, cur_pow); + } else { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +void FindBestQuantizationMaxError(const Image3F& opsin, + PassesEncoderState* enc_state, + ThreadPool* pool, AuxOut* aux_out) { + // TODO(veluca): this only works if opsin is in XYB. The current encoder does + // not have code paths that produce non-XYB opsin here. + JXL_CHECK(enc_state->shared.frame_header.color_transform == + ColorTransform::kXYB); + const CompressParams& cparams = enc_state->cparams; + Quantizer& quantizer = enc_state->shared.quantizer; + ImageI& raw_quant_field = enc_state->shared.raw_quant_field; + ImageF& quant_field = enc_state->initial_quant_field; + + // TODO(veluca): better choice of this value. 
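+  // E.g. at butteraugli distance 1.0 this yields 16 * sqrt(0.1) ~= 5.06.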
+ const float initial_quant_dc = + 16 * std::sqrt(0.1f / cparams.butteraugli_distance); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + &quant_field); + + const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0], + 1.0f / enc_state->cparams.max_error[1], + 1.0f / enc_state->cparams.max_error[2]}; + + for (int i = 0; i < cparams.max_butteraugli_iters + 1; ++i) { + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + if (aux_out) { + aux_out->DumpXybImage(("ops" + ToString(i)).c_str(), opsin); + } + ImageBundle decoded = RoundtripImage(opsin, enc_state, pool); + if (aux_out) { + aux_out->DumpXybImage(("dec" + ToString(i)).c_str(), *decoded.color()); + } + + for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) { + AcStrategyRow ac_strategy_row = + enc_state->shared.ac_strategy.ConstRow(by); + for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) { + AcStrategy acs = ac_strategy_row[bx]; + if (!acs.IsFirstBlock()) continue; + float max_error = 0; + for (size_t c = 0; c < 3; c++) { + for (size_t y = by * kBlockDim; + y < (by + acs.covered_blocks_y()) * kBlockDim; y++) { + if (y >= decoded.ysize()) continue; + const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y); + const float* JXL_RESTRICT dec_row = + decoded.color()->ConstPlaneRow(c, y); + for (size_t x = bx * kBlockDim; + x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) { + if (x >= decoded.xsize()) continue; + max_error = std::max( + std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error); + } + } + } + // Target an error between max_error/2 and max_error. + // If the error in the varblock is above the target, increase the qf to + // compensate. If the error is below the target, decrease the qf. + // However, to avoid an excessive increase of the qf, only do so if the + // error is less than half the maximum allowed error. + const float qf_mul = (max_error < 0.5f) ? max_error * 2.0f + : (max_error > 1.0f) ? max_error + : 1.0f; + for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) { + float* JXL_RESTRICT quant_field_row = quant_field.Row(qy); + for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) { + quant_field_row[qx] *= qf_mul; + } + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +} // namespace + +void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect, + ImageF* quant_field) { + // Replace the whole quant_field in non-8x8 blocks with the maximum of each + // 8x8 block. + size_t stride = quant_field->PixelsPerRow(); + for (size_t y = 0; y < rect.ysize(); ++y) { + AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y); + float* JXL_RESTRICT quant_row = rect.Row(quant_field, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + AcStrategy acs = ac_strategy_row[x]; + if (!acs.IsFirstBlock()) continue; + JXL_ASSERT(x + acs.covered_blocks_x() <= quant_field->xsize()); + JXL_ASSERT(y + acs.covered_blocks_y() <= quant_field->ysize()); + float max = quant_row[x]; + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + max = std::max(quant_row[x + ix + iy * stride], max); + } + } + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + quant_row[x + ix + iy * stride] = max; + } + } + } + } +} + +float InitialQuantDC(float butteraugli_target) { + const float kDcMul = 2.9; // Butteraugli target where non-linearity kicks in. 
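+  // E.g. for butteraugli_target = 1.0: 2.9 * (1/2.9)^0.57 ~= 1.58, which the
+  // min/max clamp to a butteraugli_target_dc of 1.0, so the function returns
+  // kDcQuant / 1.0 = 1.12.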
+  const float butteraugli_target_dc = std::max(
+      0.5f * butteraugli_target,
+      std::min(butteraugli_target,
+               kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target,
+                                 kDcQuantPow)));
+  // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc.
+  // The maximum DC value might not be in the kXybRange because of inverse
+  // gaborish, so we add some slack to the maximum theoretical quant obtained
+  // this way (64).
+  return std::min(kDcQuant / butteraugli_target_dc, 50.f);
+}
+
+ImageF InitialQuantField(const float butteraugli_target, const Image3F& opsin,
+                         const FrameDimensions& frame_dim, ThreadPool* pool,
+                         float rescale, ImageF* mask) {
+  PROFILER_FUNC;
+  const float quant_ac = kAcQuant / butteraugli_target;
+  return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)(
+      butteraugli_target, opsin, frame_dim, quant_ac * rescale, pool, mask);
+}
+
+void FindBestQuantizer(const ImageBundle* linear, const Image3F& opsin,
+                       PassesEncoderState* enc_state, ThreadPool* pool,
+                       AuxOut* aux_out, double rescale) {
+  const CompressParams& cparams = enc_state->cparams;
+  if (cparams.max_error_mode) {
+    PROFILER_ZONE("enc find best maxerr");
+    FindBestQuantizationMaxError(opsin, enc_state, pool, aux_out);
+  } else if (cparams.speed_tier <= SpeedTier::kKitten) {
+    // Normal encoding to a butteraugli score.
+    PROFILER_ZONE("enc find best2");
+    FindBestQuantization(*linear, opsin, enc_state, pool, aux_out);
+  }
+}
+
+ImageBundle RoundtripImage(const Image3F& opsin, PassesEncoderState* enc_state,
+                           ThreadPool* pool) {
+  PROFILER_ZONE("enc roundtrip");
+  std::unique_ptr<PassesDecoderState> dec_state =
+      jxl::make_unique<PassesDecoderState>();
+  JXL_CHECK(
+      dec_state->output_encoding_info.Set(enc_state->shared.metadata->m));
+  dec_state->shared = &enc_state->shared;
+  JXL_ASSERT(opsin.ysize() % kBlockDim == 0);
+
+  const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim);
+  const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim);
+  const size_t num_groups = xsize_groups * ysize_groups;
+
+  size_t num_special_frames = enc_state->special_frames.size();
+
+  std::unique_ptr<ModularFrameEncoder> modular_frame_encoder =
+      jxl::make_unique<ModularFrameEncoder>(enc_state->shared.frame_header,
+                                            enc_state->cparams);
+  InitializePassesEncoder(opsin, pool, enc_state, modular_frame_encoder.get(),
+                          nullptr);
+  dec_state->Init();
+  dec_state->InitForAC(pool);
+
+  ImageBundle decoded(&enc_state->shared.metadata->m);
+  decoded.origin = enc_state->shared.frame_header.frame_origin;
+  decoded.SetFromImage(Image3F(opsin.xsize(), opsin.ysize()),
+                       dec_state->output_encoding_info.color_encoding);
+
+  // Same as dec_state->shared->frame_header.nonserialized_metadata->m
+  const ImageMetadata& metadata = *decoded.metadata();
+  if (!metadata.extra_channel_info.empty()) {
+    // Add dummy extra channels to the dec_state: FinalizeFrameDecoding moves
+    // these extra channels to the ImageBundle, and it is required that the
+    // number of extra channels matches metadata()->extra_channel_info.size().
+    // Normally we'd place these extra channels in the ImageBundle, but in this
+    // case FinalizeFrameDecoding is the one that does this.
+    std::vector<ImageF> extra_channels;
+    extra_channels.reserve(metadata.extra_channel_info.size());
+    for (size_t i = 0; i < metadata.extra_channel_info.size(); i++) {
+      const auto& eci = metadata.extra_channel_info[i];
+      extra_channels.emplace_back(eci.Size(decoded.xsize()),
+                                  eci.Size(decoded.ysize()));
+      // Must initialize the image with data to not affect blending with
+      // uninitialized memory.
+      ZeroFillImage(&extra_channels.back());
+    }
+    dec_state->extra_channels = std::move(extra_channels);
+  }
+
+  hwy::AlignedUniquePtr<GroupDecCache[]> group_dec_caches;
+  const auto allocate_storage = [&](size_t num_threads) {
+    dec_state->EnsureStorage(num_threads);
+    group_dec_caches = hwy::MakeUniqueAlignedArray<GroupDecCache>(num_threads);
+    return true;
+  };
+  const auto process_group = [&](const int group_index, const int thread) {
+    if (dec_state->shared->frame_header.loop_filter.epf_iters > 0) {
+      ComputeSigma(dec_state->shared->BlockGroupRect(group_index),
+                   dec_state.get());
+    }
+    JXL_CHECK(DecodeGroupForRoundtrip(
+        enc_state->coeffs, group_index, dec_state.get(),
+        &group_dec_caches[thread], thread, &decoded, nullptr));
+  };
+  RunOnPool(pool, 0, num_groups, allocate_storage, process_group, "AQ loop");
+
+  // Fine to do a JXL_ASSERT instead of error handling, since this only happens
+  // on the encoder side where we can't be fed with invalid data.
+  JXL_CHECK(FinalizeFrameDecoding(&decoded, dec_state.get(), pool,
+                                  /*force_fir=*/false, /*skip_blending=*/true));
+  // Ensure we don't create any new special frames.
+  enc_state->special_frames.resize(num_special_frames);
+
+  return decoded;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.h b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.h
new file mode 100644
index 000000000000..f04936f70b7d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.h
@@ -0,0 +1,74 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_
+#define LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_
+
+#include <stddef.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/loop_filter.h"
+#include "lib/jxl/quant_weights.h"
+#include "lib/jxl/quantizer.h"
+#include "lib/jxl/splines.h"
+
+// Heuristics to find a good quantizer for a given image. InitialQuantField
+// produces a quantization field (i.e. relative quantization amounts for each
+// block) out of an opsin-space image. `InitialQuantField` uses heuristics,
+// `FindBestQuantizer` (in non-fast mode) will run multiple encoding-decoding
+// steps and try to improve the given quant field.
+
+namespace jxl {
+
+// Computes the decoded image for a given set of compression parameters. Mainly
+// used in the FindBestQuantization loops and in some tests.
+// TODO(veluca): this doesn't seem the best possible file for this function.
+ImageBundle RoundtripImage(const Image3F& opsin, PassesEncoderState* enc_state,
+                           ThreadPool* pool);
+
+// Returns an image subsampled by kBlockDim in each direction. If the value
+// at pixel (x,y) in the returned image is greater than 1.0, it means that
+// more fine-grained quantization should be used in the corresponding block
+// of the input image, while a value less than 1.0 indicates that less
+// fine-grained quantization should be enough. Returns a mask, too, which
+// can later be used to make better decisions about ac strategy.
+ImageF InitialQuantField(float butteraugli_target, const Image3F& opsin,
+                         const FrameDimensions& frame_dim, ThreadPool* pool,
+                         float rescale, ImageF* initial_quant_mask);
+
+float InitialQuantDC(float butteraugli_target);
+
+void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect,
+                      ImageF* quant_field);
+
+// Returns a quantizer that uses an adjusted version of the provided
+// quant_field. Also computes the dequant_map corresponding to the given
+// dequant_float_map and chosen quantization levels.
+// `linear` is only used in Kitten mode or slower.
+void FindBestQuantizer(const ImageBundle* linear, const Image3F& opsin,
+                       PassesEncoderState* enc_state, ThreadPool* pool,
+                       AuxOut* aux_out, double rescale = 1.0);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans.cc b/third_party/jpeg-xl/lib/jxl/enc_ans.cc
new file mode 100644
index 000000000000..746ece90d809
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_ans.cc
@@ -0,0 +1,1635 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_ans.h"
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <limits>
+#include <numeric>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/ans_common.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/enc_cluster.h"
+#include "lib/jxl/enc_context_map.h"
+#include "lib/jxl/enc_huffman.h"
+#include "lib/jxl/fast_math-inl.h"
+#include "lib/jxl/fields.h"
+
+namespace jxl {
+
+namespace {
+
+bool ans_fuzzer_friendly_ = false;
+
+static const int kMaxNumSymbolsForSmallCode = 4;
+
+void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
+                       size_t alphabet_size, size_t log_alpha_size,
+                       ANSEncSymbolInfo* info) {
+  size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
+  size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
+  // create valid alias table for empty streams.
+  for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
+    const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
+    info[s].freq_ = static_cast<uint16_t>(freq);
+#ifdef USE_MULT_BY_RECIPROCAL
+    if (freq != 0) {
+      info[s].ifreq_ =
+          ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
+    } else {
+      info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
+    }
+#endif
+    info[s].reverse_map_.resize(freq);
+  }
+  for (int i = 0; i < ANS_TAB_SIZE; i++) {
+    AliasTable::Symbol s =
+        AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
+    info[s.value].reverse_map_[s.offset] = i;
+  }
+}
+
+float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
+                       size_t len) {
+  float sum = 0.0f;
+  int total_histogram = 0;
+  int total_counts = 0;
+  for (size_t i = 0; i < len; ++i) {
+    total_histogram += histogram[i];
+    total_counts += counts[i];
+    if (histogram[i] > 0) {
+      JXL_ASSERT(counts[i] > 0);
+      // += histogram[i] * -log(counts[i]/total_counts)
+      sum += histogram[i] *
+             std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
+    }
+  }
+  if (total_histogram > 0) {
+    JXL_ASSERT(total_counts == ANS_TAB_SIZE);
+  }
+  return sum;
+}
+
+float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
+  const float flat_bits = std::max(FastLog2f(len), 0.0f);
+  int total_histogram = 0;
+  for (size_t i = 0; i < len; ++i) {
+    total_histogram += histogram[i];
+  }
+  return total_histogram * flat_bits;
+}
+
+// Static Huffman code for encoding logcounts. The last symbol is used as RLE
+// sequence.
+static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
+    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
+};
+static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
+    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
+};
+
+// Returns the difference between largest count that can be represented and is
+// smaller than "count" and smallest representable count larger than "count".
+static int SmallestIncrement(uint32_t count, uint32_t shift) {
+  int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
+  int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
+  return drop_bits < 0 ? 1 : (1 << drop_bits);
+}
+
+template <bool minimize_error_of_sum>
+bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
+                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
+  int sum = 0;
+  float sum_nonrounded = 0.0;
+  int remainder_pos = 0;  // if all of them are handled in first loop
+  int remainder_log = -1;
+  for (int n = 0; n < max_symbol; ++n) {
+    if (targets[n] > 0 && targets[n] < 1.0f) {
+      counts[n] = 1;
+      sum_nonrounded += targets[n];
+      sum += counts[n];
+    }
+  }
+  const float discount_ratio =
+      (table_size - sum) / (table_size - sum_nonrounded);
+  JXL_ASSERT(discount_ratio > 0);
+  JXL_ASSERT(discount_ratio <= 1.0f);
+  // Invariant for minimize_error_of_sum == true:
+  // abs(sum - sum_nonrounded)
+  //   <= SmallestIncrement(max(targets[])) + max_symbol
+  for (int n = 0; n < max_symbol; ++n) {
+    if (targets[n] >= 1.0f) {
+      sum_nonrounded += targets[n];
+      counts[n] =
+          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
+      if (counts[n] == 0) counts[n] = 1;
+      if (counts[n] == table_size) counts[n] = table_size - 1;
+      // Round the count to the closest nonzero multiple of SmallestIncrement
+      // (when minimize_error_of_sum is false) or one of two closest so as to
+      // keep the sum as close as possible to sum_nonrounded.
+      int inc = SmallestIncrement(counts[n], shift);
+      counts[n] -= counts[n] & (inc - 1);
+      // TODO(robryk): Should we rescale targets[n]?
+      const float target =
+          minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
+      if (counts[n] == 0 ||
+          (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
+        counts[n] += inc;
+      }
+      sum += counts[n];
+      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
+      if (count_log > remainder_log) {
+        remainder_pos = n;
+        remainder_log = count_log;
+      }
+    }
+  }
+  JXL_ASSERT(remainder_pos != -1);
+  // NOTE: This is the only place where counts could go negative. We could
+  // detect that, return false and make ANSHistBin uint32_t.
+  counts[remainder_pos] -= sum - table_size;
+  *omit_pos = remainder_pos;
+  return counts[remainder_pos] > 0;
+}
+
+Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
+                       const int precision_bits, uint32_t shift,
+                       int* num_symbols, int* symbols) {
+  const int32_t table_size = 1 << precision_bits;  // target sum / table size
+  uint64_t total = 0;
+  int max_symbol = 0;
+  int symbol_count = 0;
+  for (int n = 0; n < length; ++n) {
+    total += counts[n];
+    if (counts[n] > 0) {
+      if (symbol_count < kMaxNumSymbolsForSmallCode) {
+        symbols[symbol_count] = n;
+      }
+      ++symbol_count;
+      max_symbol = n + 1;
+    }
+  }
+  *num_symbols = symbol_count;
+  if (symbol_count == 0) {
+    return true;
+  }
+  if (symbol_count == 1) {
+    counts[symbols[0]] = table_size;
+    return true;
+  }
+  if (symbol_count > table_size)
+    return JXL_FAILURE("Too many entries in an ANS histogram");
+
+  const float norm = 1.f * table_size / total;
+  std::vector<float> targets(max_symbol);
+  for (size_t n = 0; n < targets.size(); ++n) {
+    targets[n] = norm * counts[n];
+  }
+  if (!RebalanceHistogram</*minimize_error_of_sum=*/false>(
+          &targets[0], max_symbol, table_size, shift, omit_pos, counts)) {
+    // Use an alternative rebalancing mechanism if the one above failed
+    // to create a histogram that is positive wherever the original one was.
+    if (!RebalanceHistogram</*minimize_error_of_sum=*/true>(
+            &targets[0], max_symbol, table_size, shift, omit_pos, counts)) {
+      return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
+    }
+  }
+  return true;
+}
+
+struct SizeWriter {
+  size_t size = 0;
+  void Write(size_t num, size_t bits) { size += num; }
+};
+
+template <typename Writer>
+void StoreVarLenUint8(size_t n, Writer* writer) {
+  JXL_DASSERT(n <= 255);
+  if (n == 0) {
+    writer->Write(1, 0);
+  } else {
+    writer->Write(1, 1);
+    size_t nbits = FloorLog2Nonzero(n);
+    writer->Write(3, nbits);
+    writer->Write(nbits, n - (1ULL << nbits));
+  }
+}
+
+template <typename Writer>
+void StoreVarLenUint16(size_t n, Writer* writer) {
+  JXL_DASSERT(n <= 65535);
+  if (n == 0) {
+    writer->Write(1, 0);
+  } else {
+    writer->Write(1, 1);
+    size_t nbits = FloorLog2Nonzero(n);
+    writer->Write(4, nbits);
+    writer->Write(nbits, n - (1ULL << nbits));
+  }
+}
+
+template <typename Writer>
+bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
+                  const int omit_pos, const int num_symbols, uint32_t shift,
+                  const int* symbols, Writer* writer) {
+  bool ok = true;
+  if (num_symbols <= 2) {
+    // Small tree marker to encode 1-2 symbols.
+    writer->Write(1, 1);
+    if (num_symbols == 0) {
+      writer->Write(1, 0);
+      StoreVarLenUint8(0, writer);
+    } else {
+      writer->Write(1, num_symbols - 1);
+      for (int i = 0; i < num_symbols; ++i) {
+        StoreVarLenUint8(symbols[i], writer);
+      }
+    }
+    if (num_symbols == 2) {
+      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
+    }
+  } else {
+    // Mark non-small tree.
+    writer->Write(1, 0);
+    // Mark non-flat histogram.
+    writer->Write(1, 0);
+
+    // Precompute sequences for RLE encoding. Contains the number of identical
+    // values starting at a given index. Only contains the value at the first
+    // element of the series.
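+    // E.g. counts {7, 3, 3, 3, 3, 3, 3, 3, 9} with omit_pos == 8 give
+    // same = {1, 0, 6, 0, 0, 0, 0, 0, 0}: a run of length 1 at index 0 and a
+    // run of length 6 starting at index 2; only runs longer than kMinReps are
+    // RLE-coded below.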
+    std::vector<int> same(alphabet_size, 0);
+    int last = 0;
+    for (int i = 1; i < alphabet_size; i++) {
+      // Store the sequence length once different symbol reached, or we're at
+      // the end, or the length is longer than we can encode, or we are at
+      // the omit_pos. We don't support including the omit_pos in an RLE
+      // sequence because this value may use a different amount of log2 bits
+      // than standard, it is too complex to handle in the decoder.
+      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
+          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
+        same[last] = (i - last);
+        last = i + 1;
+      }
+    }
+
+    int length = 0;
+    std::vector<int> logcounts(alphabet_size);
+    int omit_log = 0;
+    for (int i = 0; i < alphabet_size; ++i) {
+      JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
+      JXL_ASSERT(counts[i] >= 0);
+      if (i == omit_pos) {
+        length = i + 1;
+      } else if (counts[i] > 0) {
+        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
+        length = i + 1;
+        if (i < omit_pos) {
+          omit_log = std::max(omit_log, logcounts[i] + 1);
+        } else {
+          omit_log = std::max(omit_log, logcounts[i]);
+        }
+      }
+    }
+    logcounts[omit_pos] = omit_log;
+
+    // Elias gamma-like code for shift. Only difference is that if the number
+    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
+    // the terminating 0 in unary coding.
+    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
+    int log = FloorLog2Nonzero(shift + 1);
+    writer->Write(log, (1 << log) - 1);
+    if (log != upper_bound_log) writer->Write(1, 0);
+    writer->Write(log, ((1 << log) - 1) & (shift + 1));
+
+    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
+    // length - 3.
+    if (length - 3 > 255) {
+      // Pretend that everything is OK, but complain about correctness later.
+      StoreVarLenUint8(255, writer);
+      ok = false;
+    } else {
+      StoreVarLenUint8(length - 3, writer);
+    }
+
+    // The logcount values are encoded with a static Huffman code.
+    static const size_t kMinReps = 4;
+    size_t rep = ANS_LOG_TAB_SIZE + 1;
+    for (int i = 0; i < length; ++i) {
+      if (i > 0 && same[i - 1] > kMinReps) {
+        // Encode the RLE symbol and skip the repeated ones.
+        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
+        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
+        i += same[i - 1] - 2;
+        continue;
+      }
+      writer->Write(kLogCountBitLengths[logcounts[i]],
+                    kLogCountSymbols[logcounts[i]]);
+    }
+    for (int i = 0; i < length; ++i) {
+      if (i > 0 && same[i - 1] > kMinReps) {
+        // Skip symbols encoded by RLE.
+        i += same[i - 1] - 2;
+        continue;
+      }
+      if (logcounts[i] > 1 && i != omit_pos) {
+        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
+        int drop_bits = logcounts[i] - 1 - bitcount;
+        JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
+        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
+      }
+    }
+  }
+  return ok;
+}
+
+void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
+  // Mark non-small tree.
+  writer->Write(1, 0);
+  // Mark uniform histogram.
+  writer->Write(1, 1);
+  JXL_ASSERT(alphabet_size > 0);
+  // Encode alphabet size.
+  StoreVarLenUint8(alphabet_size - 1, writer);
+}
+
+float ComputeHistoAndDataCost(const ANSHistBin* histogram,
+                              size_t alphabet_size, uint32_t method) {
+  if (method == 0) {  // Flat code
+    return ANS_LOG_TAB_SIZE + 2 +
+           EstimateDataBitsFlat(histogram, alphabet_size);
+  }
+  // Non-flat: shift = method-1.
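+  // Method 0 is the flat code above; method k > 0 evaluates an ANS histogram
+  // whose counts are quantized with shift = k - 1 (a smaller shift keeps
+  // fewer significant bits per count: cheaper to store, but less exact).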
+  uint32_t shift = method - 1;
+  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
+  int omit_pos = 0;
+  int num_symbols;
+  int symbols[kMaxNumSymbolsForSmallCode] = {};
+  JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
+                            ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
+  SizeWriter writer;
+  // Ignore the correctness, no real encoding happens at this stage.
+  (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
+                     shift, symbols, &writer);
+  return writer.size +
+         EstimateDataBits(histogram, counts.data(), alphabet_size);
+}
+
+uint32_t ComputeBestMethod(
+    const ANSHistBin* histogram, size_t alphabet_size, float* cost,
+    HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
+  size_t method = 0;
+  float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
+  for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE;
+       ans_histogram_strategy != HistogramParams::ANSHistogramStrategy::kPrecise
+           ? shift += 2
+           : shift++) {
+    float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
+    if (c < fcost) {
+      method = shift + 1;
+      fcost = c;
+    } else if (ans_histogram_strategy ==
+               HistogramParams::ANSHistogramStrategy::kFast) {
+      // do not be as precise if estimating cost.
+      break;
+    }
+  }
+  *cost = fcost;
+  return method;
+}
+
+}  // namespace
+
+// Returns an estimate of the cost of encoding this histogram and the
+// corresponding data.
+size_t BuildAndStoreANSEncodingData(
+    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
+    const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
+    bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
+  if (use_prefix_code) {
+    if (alphabet_size <= 1) return 0;
+    std::vector<uint32_t> histo(alphabet_size);
+    size_t total = 0;
+    for (size_t i = 0; i < alphabet_size; i++) {
+      histo[i] = histogram[i];
+      JXL_CHECK(histogram[i] >= 0);
+      total += histo[i];
+    }
+    size_t cost = 0;
+    {
+      std::vector<uint8_t> depths(alphabet_size);
+      std::vector<uint16_t> bits(alphabet_size);
+      BitWriter tmp_writer;
+      BitWriter* w = writer ? writer : &tmp_writer;
+      size_t start = w->BitsWritten();
+      BitWriter::Allotment allotment(
+          w, 8 * alphabet_size + 8);  // safe upper bound
+      BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
+                               bits.data(), w);
+      ReclaimAndCharge(w, &allotment, 0, /*aux_out=*/nullptr);
+
+      for (size_t i = 0; i < alphabet_size; i++) {
+        info[i].bits = depths[i] == 0 ? 0 : bits[i];
+        info[i].depth = depths[i];
+      }
+      cost = w->BitsWritten() - start;
+    }
+    // Estimate data cost.
+    for (size_t i = 0; i < alphabet_size; i++) {
+      cost += histogram[i] * info[i].depth;
+    }
+    return cost;
+  }
+  JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
+  // Ensure we ignore trailing zeros in the histogram.
+  if (alphabet_size != 0) {
+    size_t largest_symbol = 0;
+    for (size_t i = 0; i < alphabet_size; i++) {
+      if (histogram[i] != 0) largest_symbol = i;
+    }
+    alphabet_size = largest_symbol + 1;
+  }
+  float cost;
+  uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
+                                      ans_histogram_strategy);
+  JXL_ASSERT(cost >= 0);
+  int num_symbols;
+  int symbols[kMaxNumSymbolsForSmallCode] = {};
+  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
+  if (!counts.empty()) {
+    size_t sum = 0;
+    for (size_t i = 0; i < counts.size(); i++) {
+      sum += counts[i];
+    }
+    if (sum == 0) {
+      counts[0] = ANS_TAB_SIZE;
+    }
+  }
+  if (method == 0) {
+    counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
+    AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
+    InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
+    ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
+    if (writer != nullptr) {
+      EncodeFlatHistogram(alphabet_size, writer);
+    }
+    return cost;
+  }
+  int omit_pos = 0;
+  uint32_t shift = method - 1;
+  JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
+                            ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
+  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
+  InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
+  ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
+  if (writer != nullptr) {
+    bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
+                           shift, symbols, writer);
+    (void)ok;
+    JXL_DASSERT(ok);
+  }
+  return cost;
+}
+
+float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
+  float c;
+  ComputeBestMethod(data, alphabet_size, &c,
+                    HistogramParams::ANSHistogramStrategy::kFast);
+  return c;
+}
+
+template <typename Writer>
+void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
+                      size_t log_alpha_size) {
+  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
+                uint_config.split_exponent);
+  if (uint_config.split_exponent == log_alpha_size) {
+    return;  // msb/lsb don't matter.
+  }
+  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
+  writer->Write(nbits, uint_config.msb_in_token);
+  nbits = CeilLog2Nonzero(uint_config.split_exponent -
+                          uint_config.msb_in_token + 1);
+  writer->Write(nbits, uint_config.lsb_in_token);
+}
+template <typename Writer>
+void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
+                       Writer* writer, size_t log_alpha_size) {
+  // TODO(veluca): RLE?
+  for (size_t i = 0; i < uint_config.size(); i++) {
+    EncodeUintConfig(uint_config[i], writer, log_alpha_size);
+  }
+}
+template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
+                                BitWriter*, size_t);
+
+namespace {
+
+void ChooseUintConfigs(const HistogramParams& params,
+                       const std::vector<std::vector<Token>>& tokens,
+                       const std::vector<uint8_t>& context_map,
+                       std::vector<Histogram>* clustered_histograms,
+                       EntropyEncodingData* codes, size_t* log_alpha_size) {
+  codes->uint_config.resize(clustered_histograms->size());
+  if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
+  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
+    codes->uint_config.clear();
+    codes->uint_config.resize(clustered_histograms->size(),
+                              HybridUintConfig(2, 0, 1));
+    return;
+  }
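The candidate configurations listed below only make sense with the hybrid-uint token layout in mind: values below `1 << split_exponent` become their own token, while larger values keep the exponent plus `msb_in_token`/`lsb_in_token` mantissa bits in the token and send the rest verbatim. This is an illustrative sketch of that split (reconstructed from the decoder-side layout, so treat it as a reference rather than the vendored implementation):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the hybrid-uint split behind HybridUintConfig(a, b, c) below.
struct HybridUint {
  uint32_t split_exponent, msb_in_token, lsb_in_token;

  static uint32_t FloorLog2(uint32_t x) {
    uint32_t log = 0;
    while (x >>= 1) log++;
    return log;
  }

  void Encode(uint32_t value, uint32_t* token, uint32_t* nbits,
              uint32_t* bits) const {
    uint32_t split = 1u << split_exponent;
    if (value < split) {
      *token = value;  // small values are coded directly as tokens
      *nbits = *bits = 0;
    } else {
      uint32_t n = FloorLog2(value);
      uint32_t m = value - (1u << n);  // mantissa without the leading 1
      *token = split +
               ((n - split_exponent) << (msb_in_token + lsb_in_token)) +
               ((m >> (n - msb_in_token)) << lsb_in_token) +
               (m & ((1u << lsb_in_token) - 1));
      *nbits = n - msb_in_token - lsb_in_token;  // raw bits sent verbatim
      *bits = (value >> lsb_in_token) & ((1u << *nbits) - 1);
    }
  }
};

int main() {
  HybridUint cfg{4, 2, 0};  // the "default" configuration in the list below
  uint32_t token, nbits, bits;
  cfg.Encode(100, &token, &nbits, &bits);
  // 100 = 0b1100100 -> token 26 (exponent 6, top mantissa bits 10),
  // plus 4 raw bits with value 0b0100.
  std::printf("token=%u nbits=%u bits=%u\n", token, nbits, bits);
}
```

Larger `split_exponent` therefore means more "direct coding" (bigger alphabet, fewer raw bits), which is exactly the trade-off the brute-force search below is probing.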
+  // Brute-force method that tries a few options.
+  std::vector<HybridUintConfig> configs;
+  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
+    configs = {
+        HybridUintConfig(4, 2, 0),  // default
+        HybridUintConfig(4, 1, 0),  // less precise
+        HybridUintConfig(4, 2, 1),  // add sign
+        HybridUintConfig(4, 2, 2),  // add sign+parity
+        HybridUintConfig(4, 1, 2),  // add parity but less msb
+        // Same as above, but more direct coding.
+        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
+        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
+        HybridUintConfig(5, 1, 2),
+        // Same as above, but less direct coding.
+        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
+        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
+        // For near-lossless.
+        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
+        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
+        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
+        // Other
+        HybridUintConfig(0, 0, 0),   // varlenuint
+        HybridUintConfig(2, 0, 1),   // works well for ctx map
+        HybridUintConfig(7, 0, 0),   // direct coding
+        HybridUintConfig(8, 0, 0),   // direct coding
+        HybridUintConfig(9, 0, 0),   // direct coding
+        HybridUintConfig(10, 0, 0),  // direct coding
+        HybridUintConfig(11, 0, 0),  // direct coding
+        HybridUintConfig(12, 0, 0),  // direct coding
+    };
+  } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
+    configs = {
+        HybridUintConfig(4, 2, 0),  // default
+        HybridUintConfig(4, 1, 2),  // add parity but less msb
+        HybridUintConfig(0, 0, 0),  // smallest histograms
+        HybridUintConfig(2, 0, 1),  // works well for ctx map
+    };
+  }
+
+  std::vector<float> costs(clustered_histograms->size(),
+                           std::numeric_limits<float>::max());
+  std::vector<uint32_t> extra_bits(clustered_histograms->size());
+  std::vector<uint8_t> is_valid(clustered_histograms->size());
+  size_t max_alpha =
+      codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
+  for (HybridUintConfig cfg : configs) {
+    std::fill(is_valid.begin(), is_valid.end(), true);
+    std::fill(extra_bits.begin(), extra_bits.end(), 0);
+
+    for (size_t i = 0; i < clustered_histograms->size(); i++) {
+      (*clustered_histograms)[i].Clear();
+    }
+    for (size_t i = 0; i < tokens.size(); ++i) {
+      for (size_t j = 0; j < tokens[i].size(); ++j) {
+        const Token token = tokens[i][j];
+        // TODO(veluca): do not ignore lz77 commands.
+        if (token.is_lz77_length) continue;
+        size_t histo = context_map[token.context];
+        uint32_t tok, nbits, bits;
+        cfg.Encode(token.value, &tok, &nbits, &bits);
+        if (tok >= max_alpha ||
+            (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
+          is_valid[histo] = false;
+          continue;
+        }
+        extra_bits[histo] += nbits;
+        (*clustered_histograms)[histo].Add(tok);
+      }
+    }
+
+    for (size_t i = 0; i < clustered_histograms->size(); i++) {
+      if (!is_valid[i]) continue;
+      float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
+      if (cost < costs[i]) {
+        codes->uint_config[i] = cfg;
+        costs[i] = cost;
+      }
+    }
+  }
+
+  // Rebuild histograms.
+  for (size_t i = 0; i < clustered_histograms->size(); i++) {
+    (*clustered_histograms)[i].Clear();
+  }
+  *log_alpha_size = 4;
+  for (size_t i = 0; i < tokens.size(); ++i) {
+    for (size_t j = 0; j < tokens[i].size(); ++j) {
+      const Token token = tokens[i][j];
+      uint32_t tok, nbits, bits;
+      size_t histo = context_map[token.context];
+      (token.is_lz77_length ? codes->lz77.length_uint_config
+                            : codes->uint_config[histo])
+          .Encode(token.value, &tok, &nbits, &bits);
+      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
+      (*clustered_histograms)[histo].Add(tok);
+      while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
+    }
+  }
+#if JXL_ENABLE_ASSERT
+  size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
+  JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
+#endif
+}
+
+class HistogramBuilder {
+ public:
+  explicit HistogramBuilder(const size_t num_contexts)
+      : histograms_(num_contexts) {}
+
+  void VisitSymbol(int symbol, size_t histo_idx) {
+    JXL_DASSERT(histo_idx < histograms_.size());
+    histograms_[histo_idx].Add(symbol);
+  }
+
+  // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
+  size_t BuildAndStoreEntropyCodes(
+      const HistogramParams& params,
+      const std::vector<std::vector<Token>>& tokens,
+      EntropyEncodingData* codes, std::vector<uint8_t>* context_map,
+      bool use_prefix_code, BitWriter* writer, size_t layer,
+      AuxOut* aux_out) const {
+    size_t cost = 0;
+    codes->encoding_info.clear();
+    std::vector<Histogram> clustered_histograms(histograms_);
+    context_map->resize(histograms_.size());
+    if (histograms_.size() > 1) {
+      if (!ans_fuzzer_friendly_) {
+        std::vector<uint32_t> histogram_symbols;
+        ClusterHistograms(params, histograms_, histograms_.size(),
+                          kClustersLimit, &clustered_histograms,
+                          &histogram_symbols);
+        for (size_t c = 0; c < histograms_.size(); ++c) {
+          (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
+        }
+      } else {
+        std::fill(context_map->begin(), context_map->end(), 0);
+        size_t max_symbol = 0;
+        for (const Histogram& h : histograms_) {
+          max_symbol = std::max(h.data_.size(), max_symbol);
+        }
+        size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
+        clustered_histograms.resize(1);
+        clustered_histograms[0].Clear();
+        for (size_t i = 0; i < num_symbols; i++) {
+          clustered_histograms[0].Add(i);
+        }
+      }
+      if (writer != nullptr) {
+        EncodeContextMap(*context_map, clustered_histograms.size(), writer);
+      }
+    }
+    if (aux_out != nullptr) {
+      for (size_t i = 0; i < clustered_histograms.size(); ++i) {
+        aux_out->layers[layer].clustered_entropy +=
+            clustered_histograms[i].ShannonEntropy();
+      }
+    }
+    codes->use_prefix_code = use_prefix_code;
+    size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
+    if (ans_fuzzer_friendly_) {
+      codes->uint_config.clear();
+      codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
+    } else {
+      ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
+                        codes, &log_alpha_size);
+    }
+    if (log_alpha_size < 5) log_alpha_size = 5;
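The `size_writer` declared just below relies on the serialization helpers being templated on the writer type. A counting writer of roughly this shape (a sketch; the real one lives elsewhere in this file) is all that is needed to reuse the exact same header-writing code for pure cost estimation:

```cpp
#include <cstddef>

// Sketch of a counting writer: it mimics BitWriter::Write(nbits, bits) but
// only accumulates the number of bits. Passing it through the templated
// EncodeUintConfigs() above measures header size without emitting anything.
struct SizeWriter {
  size_t size = 0;
  template <typename T>
  void Write(size_t nbits, T /*bits*/) {
    size += nbits;
  }
};
```

This is why `BuildAndStoreEntropyCodes` can return a meaningful bit count even when `writer == nullptr`.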
+    SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
+    cost += 1;
+    if (writer) writer->Write(1, use_prefix_code);
+
+    if (use_prefix_code) {
+      log_alpha_size = PREFIX_MAX_BITS;
+    } else {
+      cost += 2;
+    }
+    if (writer == nullptr) {
+      EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
+    } else {
+      if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
+      EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
+    }
+    if (use_prefix_code) {
+      for (size_t c = 0; c < clustered_histograms.size(); ++c) {
+        size_t num_symbol = 1;
+        for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
+          if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
+        }
+        if (writer) {
+          StoreVarLenUint16(num_symbol - 1, writer);
+        } else {
+          StoreVarLenUint16(num_symbol - 1, &size_writer);
+        }
+      }
+    }
+    cost += size_writer.size;
+    for (size_t c = 0; c < clustered_histograms.size(); ++c) {
+      size_t num_symbol = 1;
+      for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
+        if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
+      }
+      codes->encoding_info.emplace_back();
+      codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
+
+      BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
+      cost += BuildAndStoreANSEncodingData(
+          params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
+          num_symbol, log_alpha_size, use_prefix_code,
+          codes->encoding_info.back().data(), writer);
+      allotment.FinishedHistogram(writer);
+      ReclaimAndCharge(writer, &allotment, layer, aux_out);
+    }
+    return cost;
+  }
+
+  const Histogram& Histo(size_t i) const { return histograms_[i]; }
+
+ private:
+  std::vector<Histogram> histograms_;
+};
+
+class SymbolCostEstimator {
+ public:
+  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
+                      const std::vector<std::vector<Token>>& tokens,
+                      const LZ77Params& lz77) {
+    HistogramBuilder builder(num_contexts);
+    // Build histograms for estimating lz77 savings.
+    HybridUintConfig uint_config;
+    for (size_t i = 0; i < tokens.size(); ++i) {
+      for (size_t j = 0; j < tokens[i].size(); ++j) {
+        const Token token = tokens[i][j];
+        uint32_t tok, nbits, bits;
+        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
+            .Encode(token.value, &tok, &nbits, &bits);
+        tok += token.is_lz77_length ? lz77.min_symbol : 0;
+        builder.VisitSymbol(tok, token.context);
+      }
+    }
+    max_alphabet_size_ = 0;
+    for (size_t i = 0; i < num_contexts; i++) {
+      max_alphabet_size_ =
+          std::max(max_alphabet_size_, builder.Histo(i).data_.size());
+    }
+    bits_.resize(num_contexts * max_alphabet_size_);
+    // TODO(veluca): SIMD?
+    add_symbol_cost_.resize(num_contexts);
+    for (size_t i = 0; i < num_contexts; i++) {
+      float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
+      float total_cost = 0;
+      for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
+        size_t cnt = builder.Histo(i).data_[j];
+        float cost = 0;
+        if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
+          cost = -FastLog2f(cnt * inv_total);
+          if (force_huffman) cost = std::ceil(cost);
+        } else if (cnt == 0) {
+          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
+        }
+        bits_[i * max_alphabet_size_ + j] = cost;
+        total_cost += cost * builder.Histo(i).data_[j];
+      }
+      // Penalty for adding a lz77 symbol to this context (only used for
+      // static cost model). Higher penalty for contexts that have a very
+      // low per-symbol entropy.
+      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
+    }
+  }
+  float Bits(size_t ctx, size_t sym) const {
+    return bits_[ctx * max_alphabet_size_ + sym];
+  }
+  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
+    uint32_t nbits, bits, tok;
+    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
+    tok += lz77.min_symbol;
+    return nbits + Bits(ctx, tok);
+  }
+  float DistCost(size_t len, const LZ77Params& lz77) const {
+    uint32_t nbits, bits, tok;
+    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
+    return nbits + Bits(lz77.nonserialized_distance_context, tok);
+  }
+  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
+
+ private:
+  size_t max_alphabet_size_;
+  std::vector<float> bits_;
+  std::vector<float> add_symbol_cost_;
+};
+
+void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
+                   const std::vector<std::vector<Token>>& tokens,
+                   LZ77Params& lz77,
+                   std::vector<std::vector<Token>>& tokens_lz77) {
+  // TODO(veluca): tune heuristics here.
+  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
+  float bit_decrease = 0;
+  size_t total_symbols = 0;
+  tokens_lz77.resize(tokens.size());
+  std::vector<float> sym_cost;
+  HybridUintConfig uint_config;
+  for (size_t stream = 0; stream < tokens.size(); stream++) {
+    size_t distance_multiplier =
+        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
+    const auto& in = tokens[stream];
+    auto& out = tokens_lz77[stream];
+    total_symbols += in.size();
+    // Cumulative sum of bit costs.
+    sym_cost.resize(in.size() + 1);
+    for (size_t i = 0; i < in.size(); i++) {
+      uint32_t tok, nbits, unused_bits;
+      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
+      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
+    }
+    out.reserve(in.size());
+    for (size_t i = 0; i < in.size(); i++) {
+      size_t num_to_copy = 0;
+      size_t distance_symbol = 0;  // 1 for RLE.
+      if (distance_multiplier != 0) {
+        distance_symbol = 1;  // Special distance 1 if enabled.
+        JXL_DASSERT(kSpecialDistances[1][0] == 1);
+        JXL_DASSERT(kSpecialDistances[1][1] == 0);
+      }
+      if (i > 0) {
+        for (; i + num_to_copy < in.size(); num_to_copy++) {
+          if (in[i + num_to_copy].value != in[i - 1].value) {
+            break;
+          }
+        }
+      }
+      if (num_to_copy == 0) {
+        out.push_back(in[i]);
+        continue;
+      }
+      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
+      // This subtraction might overflow, but that's OK.
+      size_t lz77_len = num_to_copy - lz77.min_length;
+      float lz77_cost = num_to_copy >= lz77.min_length
+                            ? CeilLog2Nonzero(lz77_len + 1) + 1
+                            : 0;
+      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
+        for (size_t j = 0; j < num_to_copy; j++) {
+          out.push_back(in[i + j]);
+        }
+        i += num_to_copy - 1;
+        continue;
+      }
+      // Output the LZ77 length
+      out.emplace_back(in[i].context, lz77_len);
+      out.back().is_lz77_length = true;
+      i += num_to_copy - 1;
+      bit_decrease += cost - lz77_cost;
+      // Output the LZ77 copy distance.
+      out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
+    }
+  }
+
+  if (bit_decrease > total_symbols * 0.2 + 16) {
+    lz77.enabled = true;
+  }
+}
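ApplyLZ77_RLE above only ever matches against the immediately preceding value, which is run-length encoding expressed in LZ77 form. A toy version of that rewrite with the bit-cost heuristics stripped out (an assumption made purely for illustration: every run of at least `kMinLength` repeats becomes a length/distance pair) shows the token shape it produces:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy illustration of the RLE-as-LZ77 rewrite above, without cost modeling:
// any run of >= kMinLength copies of the previous value becomes a
// (length, distance=1) token pair.
struct Tok {
  uint32_t value;
  bool is_len = false, is_dist = false;
};

std::vector<Tok> RunLengthTokens(const std::vector<uint32_t>& in,
                                 size_t kMinLength = 3) {
  std::vector<Tok> out;
  for (size_t i = 0; i < in.size(); i++) {
    size_t run = 0;
    if (i > 0) {
      while (i + run < in.size() && in[i + run] == in[i - 1]) run++;
    }
    if (run < kMinLength) {
      out.push_back({in[i]});
      continue;
    }
    // Length tokens store run - kMinLength, as with lz77.min_length above.
    out.push_back({uint32_t(run - kMinLength), /*is_len=*/true});
    out.push_back({1, false, /*is_dist=*/true});  // distance 1 == "repeat"
    i += run - 1;
  }
  return out;
}

int main() {
  std::vector<uint32_t> in = {7, 7, 7, 7, 7, 3, 9};
  for (const Tok& t : RunLengthTokens(in)) {
    std::printf("%s%u ", t.is_len ? "L" : t.is_dist ? "D" : "", t.value);
  }
  std::printf("\n");  // prints: 7 L1 D1 3 9
}
```

The real function additionally compares the estimated literal cost of the run against the length/distance cost, and only enables LZ77 globally when the total saving is significant.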
+
+// Hash chain for LZ77 matching
+struct HashChain {
+  size_t size_;
+  std::vector<uint32_t> data_;
+
+  unsigned hash_num_values_ = 32768;
+  unsigned hash_mask_ = hash_num_values_ - 1;
+  unsigned hash_shift_ = 5;
+
+  std::vector<int> head;
+  std::vector<uint32_t> chain;
+  std::vector<int> val;
+
+  // Speed up repetitions of zero
+  std::vector<int> headz;
+  std::vector<uint32_t> chainz;
+  std::vector<uint32_t> zeros;
+  uint32_t numzeros = 0;
+
+  size_t window_size_;
+  size_t window_mask_;
+  size_t min_length_;
+  size_t max_length_;
+
+  // Map of special distance codes.
+  std::unordered_map<int, int> special_dist_table_;
+  size_t num_special_distances_ = 0;
+
+  uint32_t maxchainlength = 256;  // window_size_ to allow all
+
+  HashChain(const Token* data, size_t size, size_t window_size,
+            size_t min_length, size_t max_length, size_t distance_multiplier)
+      : size_(size),
+        window_size_(window_size),
+        window_mask_(window_size - 1),
+        min_length_(min_length),
+        max_length_(max_length) {
+    data_.resize(size);
+    for (size_t i = 0; i < size; i++) {
+      data_[i] = data[i].value;
+    }
+
+    head.resize(hash_num_values_, -1);
+    val.resize(window_size_, -1);
+    chain.resize(window_size_);
+    for (uint32_t i = 0; i < window_size_; ++i) {
+      chain[i] = i;  // same value as index indicates uninitialized
+    }
+
+    zeros.resize(window_size_);
+    headz.resize(window_size_ + 1, -1);
+    chainz.resize(window_size_);
+    for (uint32_t i = 0; i < window_size_; ++i) {
+      chainz[i] = i;
+    }
+    // Translate distance to special distance code.
+    if (distance_multiplier) {
+      // Count down, so if due to small distance multiplier multiple distances
+      // map to the same code, the smallest code will be used in the end.
+      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
+        int xi = kSpecialDistances[i][0];
+        int yi = kSpecialDistances[i][1];
+        int distance = yi * distance_multiplier + xi;
+        // Ensure that we map distance 1 to the lowest symbols.
+        if (distance < 1) distance = 1;
+        special_dist_table_[distance] = i;
+      }
+      num_special_distances_ = kNumSpecialDistances;
+    }
+  }
+
+  uint32_t GetHash(size_t pos) const {
+    uint32_t result = 0;
+    if (pos + 2 < size_) {
+      // TODO(lode): take the MSBs of the uint32_t values into account as
+      // well, given that the hash code itself is less than 32 bits.
+      result ^= (uint32_t)(data_[pos + 0] << 0u);
+      result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
+      result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
+    } else {
+      // No need to compute hash of last 2 bytes, the length 2 is too short.
+      return 0;
+    }
+    return result & hash_mask_;
+  }
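`GetHash` folds three consecutive token values into a 15-bit bucket index by XORing progressively shifted copies. A quick standalone check of that folding (same shift and mask; collisions are expected and are resolved by walking the chain):

```cpp
#include <cstdint>
#include <cstdio>

// Standalone copy of the 3-symbol hash above: XOR of three consecutive
// values, each shifted 5 bits further, masked to the 32768-entry table.
uint32_t HashTriple(uint32_t a, uint32_t b, uint32_t c) {
  const uint32_t kShift = 5, kMask = 32768 - 1;
  return (a ^ (b << kShift) ^ (c << (2 * kShift))) & kMask;
}

int main() {
  // Only the low bits of each value influence the bucket, which is what the
  // TODO in GetHash is about.
  std::printf("%u\n", HashTriple(1, 2, 3));          // 1 ^ 64 ^ 3072 = 3137
  std::printf("%u\n", HashTriple(1 + 32768, 2, 3));  // same bucket: 3137
}
```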
+
+  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
+    size_t end = pos + window_size_;
+    if (end > size_) end = size_;
+    if (prevzeros > 0) {
+      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
+          end == pos + window_size_) {
+        return prevzeros;
+      } else {
+        return prevzeros - 1;
+      }
+    }
+    uint32_t num = 0;
+    while (pos + num < end && data_[pos + num] == 0) num++;
+    return num;
+  }
+
+  void Update(size_t pos) {
+    uint32_t hashval = GetHash(pos);
+    uint32_t wpos = pos & window_mask_;
+
+    val[wpos] = (int)hashval;
+    if (head[hashval] != -1) chain[wpos] = head[hashval];
+    head[hashval] = wpos;
+
+    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
+    numzeros = CountZeros(pos, numzeros);
+
+    zeros[wpos] = numzeros;
+    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
+    headz[numzeros] = wpos;
+  }
+
+  void Update(size_t pos, size_t len) {
+    for (size_t i = 0; i < len; i++) {
+      Update(pos + i);
+    }
+  }
+
+  template <typename CB>
+  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
+    uint32_t wpos = pos & window_mask_;
+    uint32_t hashval = GetHash(pos);
+    uint32_t hashpos = chain[wpos];
+
+    int prev_dist = 0;
+    int end = std::min(pos + max_length_, size_);
+    uint32_t chainlength = 0;
+    uint32_t best_len = 0;
+    for (;;) {
+      int dist = (hashpos <= wpos) ? (wpos - hashpos)
+                                   : (wpos - hashpos + window_mask_ + 1);
+      if (dist < prev_dist) break;
+      prev_dist = dist;
+      uint32_t len = 0;
+      if (dist > 0) {
+        int i = pos;
+        int j = pos - dist;
+        if (numzeros > 3) {
+          int r = std::min(numzeros - 1, zeros[hashpos]);
+          if (i + r >= end) r = end - i - 1;
+          i += r;
+          j += r;
+        }
+        while (i < end && data_[i] == data_[j]) {
+          i++;
+          j++;
+        }
+        len = i - pos;
+        // This can trigger even if the new length is slightly smaller than
+        // the best length, because it is possible for a slightly cheaper
+        // distance symbol to occur.
+        if (len >= min_length_ && len + 2 >= best_len) {
+          auto it = special_dist_table_.find(dist);
+          int dist_symbol = (it == special_dist_table_.end())
+                                ?
(num_special_distances_ + dist - 1) + : it->second; + found_match(len, dist_symbol); + if (len > best_len) best_len = len; + } + } + + chainlength++; + if (chainlength >= maxchainlength) break; + + if (numzeros >= 3 && len > numzeros) { + if (hashpos == chainz[hashpos]) break; + hashpos = chainz[hashpos]; + if (zeros[hashpos] != numzeros) break; + } else { + if (hashpos == chain[hashpos]) break; + hashpos = chain[hashpos]; + if (val[hashpos] != (int)hashval) break; // outdated hash value + } + } + } + void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol, + size_t* result_len) const { + *result_dist_symbol = 0; + *result_len = 1; + FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) { + if (len > *result_len || + (len == *result_len && *result_dist_symbol > dist_symbol)) { + *result_len = len; + *result_dist_symbol = dist_symbol; + } + }); + } +}; + +float LenCost(size_t len) { + uint32_t nbits, bits, tok; + HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits); + constexpr float kCostTable[] = { + 2.797667318563126, 3.213177690381199, 2.5706009246743737, + 2.408392498667534, 2.829649191872326, 3.3923087753324577, + 4.029267451554331, 4.415576699706408, 4.509357574741465, + 9.21481543803004, 10.020590190114898, 11.858671627804766, + 12.45853300490526, 11.713105831990857, 12.561996324849314, + 13.775477692278367, 13.174027068768641, + }; + size_t table_size = sizeof kCostTable / sizeof *kCostTable; + if (tok >= table_size) tok = table_size - 1; + return kCostTable[tok] + nbits; +} + +// TODO(veluca): this does not take into account usage or non-usage of distance +// multipliers. +float DistCost(size_t dist) { + uint32_t nbits, bits, tok; + HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits); + constexpr float kCostTable[] = { + 6.368282626312716, 5.680793277090298, 8.347404197105247, + 7.641619201599141, 6.914328374119438, 7.959808291537444, + 8.70023120759855, 8.71378518934703, 9.379132523982769, + 9.110472749092708, 9.159029569270908, 9.430936766731973, + 7.278284055315169, 7.8278514904267755, 10.026641158289236, + 9.976049229827066, 9.64351607048908, 9.563403863480442, + 10.171474111762747, 10.45950155077234, 9.994813912104219, + 10.322524683741156, 8.465808729388186, 8.756254166066853, + 10.160930174662234, 10.247329273413435, 10.04090403724809, + 10.129398517544082, 9.342311691539546, 9.07608009102374, + 10.104799540677513, 10.378079384990906, 10.165828974075072, + 10.337595322341553, 7.940557464567944, 10.575665823319431, + 11.023344321751955, 10.736144698831827, 11.118277044595054, + 7.468468230648442, 10.738305230932939, 10.906980780216568, + 10.163468216353817, 10.17805759656433, 11.167283670483565, + 11.147050200274544, 10.517921919244333, 10.651764778156886, + 10.17074446448919, 11.217636876224745, 11.261630721139484, + 11.403140815247259, 10.892472096873417, 11.1859607804481, + 8.017346947551262, 7.895143720278828, 11.036577113822025, + 11.170562110315794, 10.326988722591086, 10.40872184751056, + 11.213498225466386, 11.30580635516863, 10.672272515665442, + 10.768069466228063, 11.145257364153565, 11.64668307145549, + 10.593156194627339, 11.207499484844943, 10.767517766396908, + 10.826629811407042, 10.737764794499988, 10.6200448518045, + 10.191315385198092, 8.468384171390085, 11.731295299170432, + 11.824619886654398, 10.41518844301179, 10.16310536548649, + 10.539423685097576, 10.495136599328031, 10.469112847728267, + 11.72057686174922, 10.910326337834674, 11.378921834673758, + 11.847759036098536, 11.92071647623854, 10.810628276345282, + 
      11.008601085273893, 11.910326337834674, 11.949212023423133,
+      11.298614839104337, 11.611603659010392, 10.472930394619985,
+      11.835564720850282, 11.523267392285337, 12.01055816679611,
+      8.413029688994023,  11.895784139536406, 11.984679534970505,
+      11.220654278717394, 11.716311684833672, 10.61036646226114,
+      10.89849965960364,  10.203762898863669, 10.997560826267238,
+      11.484217379438984, 11.792836176993665, 12.24310468755171,
+      11.464858097919262, 12.212747017409377, 11.425595666074955,
+      11.572048533398757, 12.742093965163013, 11.381874288645637,
+      12.191870445817015, 11.683156920035426, 11.152442115262197,
+      11.90303691580457,  11.653292787169159, 11.938615382266098,
+      16.970641701570223, 16.853602280380002, 17.26240782594733,
+      16.644655390108507, 17.14310889757499,  16.910935455445955,
+      17.505678976959697, 17.213498225466388, 2.4162310293553024,
+      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
+      3.098390886949687,  3.343454654302911,  3.588847442290287,
+      4.14614790111827,   5.152948641990529,  7.433696808092598,
+      9.716311684833672,
+  };
+  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
+  if (tok >= table_size) tok = table_size - 1;
+  return kCostTable[tok] + nbits;
+}
+
+void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
+                    const std::vector<std::vector<Token>>& tokens,
+                    LZ77Params& lz77,
+                    std::vector<std::vector<Token>>& tokens_lz77) {
+  // TODO(veluca): tune heuristics here.
+  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
+  float bit_decrease = 0;
+  size_t total_symbols = 0;
+  tokens_lz77.resize(tokens.size());
+  HybridUintConfig uint_config;
+  std::vector<float> sym_cost;
+  for (size_t stream = 0; stream < tokens.size(); stream++) {
+    size_t distance_multiplier =
+        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
+    const auto& in = tokens[stream];
+    auto& out = tokens_lz77[stream];
+    total_symbols += in.size();
+    // Cumulative sum of bit costs.
+    sym_cost.resize(in.size() + 1);
+    for (size_t i = 0; i < in.size(); i++) {
+      uint32_t tok, nbits, unused_bits;
+      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
+      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
+    }
+
+    out.reserve(in.size());
+    size_t max_distance = in.size();
+    size_t min_length = lz77.min_length;
+    JXL_ASSERT(min_length >= 3);
+    size_t max_length = in.size();
+
+    // Use next power of two as window size.
+    size_t window_size = 1;
+    while (window_size < max_distance && window_size < kWindowSize) {
+      window_size <<= 1;
+    }
+
+    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
+                    distance_multiplier);
+    size_t len, dist_symbol;
+
+    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
+
+    // Whether the next symbol was already updated (to test lazy matching)
+    bool already_updated = false;
+    for (size_t i = 0; i < in.size(); i++) {
+      out.push_back(in[i]);
+      if (!already_updated) chain.Update(i);
+      already_updated = false;
+      chain.FindMatch(i, max_distance, &dist_symbol, &len);
+      if (len >= min_length) {
+        if (len < max_lazy_match_len && i + 1 < in.size()) {
+          // Try length at next symbol lazy matching
+          chain.Update(i + 1);
+          already_updated = true;
+          size_t len2, dist_symbol2;
+          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
+          if (len2 > len) {
+            // Use the lazy match. Add literal, and use the next length
+            // starting from the next byte.
+            ++i;
+            already_updated = false;
+            len = len2;
+            dist_symbol = dist_symbol2;
+            out.push_back(in[i]);
+          }
+        }
+
+        float cost = sym_cost[i + len] - sym_cost[i];
+        size_t lz77_len = len - lz77.min_length;
+        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
+                          sce.AddSymbolCost(out.back().context);
+
+        if (lz77_cost <= cost) {
+          out.back().value = len - min_length;
+          out.back().is_lz77_length = true;
+          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
+          bit_decrease += cost - lz77_cost;
+        } else {
+          // LZ77 match ignored, and symbol already pushed. Push all other
+          // symbols and skip.
+          for (size_t j = 1; j < len; j++) {
+            out.push_back(in[i + j]);
+          }
+        }
+
+        if (already_updated) {
+          chain.Update(i + 2, len - 2);
+          already_updated = false;
+        } else {
+          chain.Update(i + 1, len - 1);
+        }
+        i += len - 1;
+      } else {
+        // Literal, already pushed
+      }
+    }
+  }
+
+  if (bit_decrease > total_symbols * 0.2 + 16) {
+    lz77.enabled = true;
+  }
+}
+
+void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
+                       const std::vector<std::vector<Token>>& tokens,
+                       LZ77Params& lz77,
+                       std::vector<std::vector<Token>>& tokens_lz77) {
+  std::vector<std::vector<Token>> tokens_for_cost_estimate;
+  ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
+  // If greedy-LZ77 does not give better compression than no-lz77, there is no
+  // reason to run the optimal matching.
+  if (!lz77.enabled) return;
+  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
+                          tokens_for_cost_estimate, lz77);
+  size_t total_symbols = 0;
+  tokens_lz77.resize(tokens.size());
+  HybridUintConfig uint_config;
+  std::vector<float> sym_cost;
+  std::vector<uint32_t> dist_symbols;
+  for (size_t stream = 0; stream < tokens.size(); stream++) {
+    size_t distance_multiplier =
+        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
+    const auto& in = tokens[stream];
+    auto& out = tokens_lz77[stream];
+    total_symbols += in.size();
+    // Cumulative sum of bit costs.
+    sym_cost.resize(in.size() + 1);
+    for (size_t i = 0; i < in.size(); i++) {
+      uint32_t tok, nbits, unused_bits;
+      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
+      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
+    }
+
+    out.reserve(in.size());
+    size_t max_distance = in.size();
+    size_t min_length = lz77.min_length;
+    JXL_ASSERT(min_length >= 3);
+    size_t max_length = in.size();
+
+    // Use next power of two as window size.
+    size_t window_size = 1;
+    while (window_size < max_distance && window_size < kWindowSize) {
+      window_size <<= 1;
+    }
+
+    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
+                    distance_multiplier);
+
+    struct MatchInfo {
+      uint32_t len;
+      uint32_t dist_symbol;
+      uint32_t ctx;
+      float total_cost = std::numeric_limits<float>::max();
+    };
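The `prefix_costs` array filled in below implements a shortest-path search over token positions: node i means "first i symbols encoded", every literal is an edge weighted by its estimated symbol cost, and every match reported by the hash chain is an edge weighted by its length/distance cost. A toy version with a fixed cost model (assumptions: unit literal cost, a constant 2.5-bit match cost, matches only against the previous value) shows the same fill-then-backtrack pattern:

```cpp
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

// Toy prefix-cost DP mirroring the loop below: literals cost 1 bit, any
// run-match of length 3..5 against the previous value costs 2.5 bits.
int main() {
  const std::vector<int> in = {7, 7, 7, 7, 7, 7, 9};
  const float kMatchCost = 2.5f;
  struct Node {
    float cost = std::numeric_limits<float>::max();
    int len = 1;
  };
  std::vector<Node> dp(in.size() + 1);
  dp[0].cost = 0;
  for (size_t i = 0; i < in.size(); i++) {
    if (dp[i].cost + 1 < dp[i + 1].cost) dp[i + 1] = {dp[i].cost + 1, 1};
    size_t run = 0;  // run of copies of in[i-1], as in the RLE case above
    if (i > 0) {
      while (i + run < in.size() && in[i + run] == in[i - 1]) run++;
    }
    for (size_t len = 3; len <= std::min<size_t>(run, 5); len++) {
      if (dp[i].cost + kMatchCost < dp[i + len].cost) {
        dp[i + len] = {dp[i].cost + kMatchCost, (int)len};
      }
    }
  }
  // Backtrack from the end, exactly like the `while (pos > 0)` loop below;
  // the parse is printed end-to-start, matching the order before
  // std::reverse() is applied.
  for (size_t pos = in.size(); pos > 0; pos -= dp[pos].len) {
    std::printf("%s(len=%d, cost=%.1f)\n", dp[pos].len > 1 ? "match" : "lit",
                dp[pos].len, dp[pos].cost);
  }
}
```

On this input the optimal parse is literal 7, a match of length 5, then literal 9, at a total cost of 4.5 bits.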
+    // Total cost to encode the first N symbols.
+    std::vector<MatchInfo> prefix_costs(in.size() + 1);
+    prefix_costs[0].total_cost = 0;
+
+    size_t rle_length = 0;
+    size_t skip_lz77 = 0;
+    for (size_t i = 0; i < in.size(); i++) {
+      chain.Update(i);
+      float lit_cost =
+          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
+      if (prefix_costs[i + 1].total_cost > lit_cost) {
+        prefix_costs[i + 1].dist_symbol = 0;
+        prefix_costs[i + 1].len = 1;
+        prefix_costs[i + 1].ctx = in[i].context;
+        prefix_costs[i + 1].total_cost = lit_cost;
+      }
+      if (skip_lz77 > 0) {
+        skip_lz77--;
+        continue;
+      }
+      dist_symbols.clear();
+      chain.FindMatches(i, max_distance,
+                        [&dist_symbols](size_t len, size_t dist_symbol) {
+                          if (dist_symbols.size() <= len) {
+                            dist_symbols.resize(len + 1, dist_symbol);
+                          }
+                          if (dist_symbol < dist_symbols[len]) {
+                            dist_symbols[len] = dist_symbol;
+                          }
+                        });
+      if (dist_symbols.size() <= min_length) continue;
+      {
+        size_t best_cost = dist_symbols.back();
+        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
+          if (dist_symbols[j] < best_cost) {
+            best_cost = dist_symbols[j];
+          }
+          dist_symbols[j] = best_cost;
+        }
+      }
+      for (size_t j = min_length; j < dist_symbols.size(); j++) {
+        // Cost model that uses results from lazy LZ77.
+        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
+                          sce.DistCost(dist_symbols[j], lz77);
+        float cost = prefix_costs[i].total_cost + lz77_cost;
+        if (prefix_costs[i + j].total_cost > cost) {
+          prefix_costs[i + j].len = j;
+          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
+          prefix_costs[i + j].ctx = in[i].context;
+          prefix_costs[i + j].total_cost = cost;
+        }
+      }
+      // We are in an RLE sequence: skip all the symbols except the first 8
+      // and the last 8. This avoids quadratic costs for sequences with long
+      // runs of the same symbol.
+      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
+          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
+        rle_length++;
+      } else {
+        rle_length = 0;
+      }
+      if (rle_length >= 8 && dist_symbols.size() > 9) {
+        skip_lz77 = dist_symbols.size() - 10;
+        rle_length = 0;
+      }
+    }
+    size_t pos = in.size();
+    while (pos > 0) {
+      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
+      if (is_lz77_length) {
+        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
+        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
+      }
+      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
+                                  : in[pos - 1].value;
+      out.emplace_back(prefix_costs[pos].ctx, val);
+      out.back().is_lz77_length = is_lz77_length;
+      pos -= prefix_costs[pos].len;
+    }
+    std::reverse(out.begin(), out.end());
+  }
+}
+
+void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
+               const std::vector<std::vector<Token>>& tokens,
+               LZ77Params& lz77,
+               std::vector<std::vector<Token>>& tokens_lz77) {
+  lz77.enabled = false;
+  if (params.force_huffman) {
+    lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
+  } else {
+    lz77.min_symbol = 224;
+  }
+  if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
+    return;
+  } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
+    ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
+  } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
+    ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
+  } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
+    ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
+  } else {
+    JXL_ABORT("Not implemented");
+  }
+}
+}  // namespace
+
+size_t BuildAndEncodeHistograms(const HistogramParams& params,
+                                size_t num_contexts,
+                                std::vector<std::vector<Token>>& tokens,
+                                EntropyEncodingData* codes,
+                                std::vector<uint8_t>* context_map,
+                                BitWriter* writer, size_t layer,
+                                AuxOut* aux_out) {
+  size_t total_bits = 0;
+  codes->lz77.nonserialized_distance_context = num_contexts;
+  std::vector<std::vector<Token>> tokens_lz77;
+  ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
+  if (ans_fuzzer_friendly_) {
+    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
+    codes->lz77.min_symbol = 2048;
+  }
+
+  const size_t max_contexts = std::min(num_contexts, kClustersLimit);
+  BitWriter::Allotment allotment(writer,
+                                 128 + num_contexts * 40 + max_contexts * 96);
+  if (writer) {
+    JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
+  } else {
+    size_t ebits, bits;
+    JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
+    total_bits += bits;
+  }
+  if (codes->lz77.enabled) {
+    if (writer) {
+      size_t b = writer->BitsWritten();
+      EncodeUintConfig(codes->lz77.length_uint_config, writer,
+                       /*log_alpha_size=*/8);
+      total_bits += writer->BitsWritten() - b;
+    } else {
+      SizeWriter size_writer;
+      EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
+                       /*log_alpha_size=*/8);
+      total_bits += size_writer.size;
+    }
+    num_contexts += 1;
+    tokens = std::move(tokens_lz77);
+  }
+  size_t total_tokens = 0;
+  // Build histograms.
+  HistogramBuilder builder(num_contexts);
+  HybridUintConfig uint_config;  // Default config for clustering.
+  // Unless we are using the kContextMap histogram option.
+  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
+    uint_config = HybridUintConfig(2, 0, 1);
+  }
+  if (ans_fuzzer_friendly_) {
+    uint_config = HybridUintConfig(10, 0, 0);
+  }
+  for (size_t i = 0; i < tokens.size(); ++i) {
+    for (size_t j = 0; j < tokens[i].size(); ++j) {
+      const Token token = tokens[i][j];
+      total_tokens++;
+      uint32_t tok, nbits, bits;
+      (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
+          .Encode(token.value, &tok, &nbits, &bits);
+      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
+      builder.VisitSymbol(tok, token.context);
+    }
+  }
+
+  bool use_prefix_code =
+      params.force_huffman || total_tokens < 100 ||
+      params.clustering == HistogramParams::ClusteringType::kFastest ||
+      ans_fuzzer_friendly_;
+  if (!use_prefix_code) {
+    bool all_singleton = true;
+    for (size_t i = 0; i < num_contexts; i++) {
+      if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
+        all_singleton = false;
+      }
+    }
+    if (all_singleton) {
+      use_prefix_code = true;
+    }
+  }
+
+  // Encode histograms.
+  total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
+                                                  context_map, use_prefix_code,
+                                                  writer, layer, aux_out);
+  allotment.FinishedHistogram(writer);
+  ReclaimAndCharge(writer, &allotment, layer, aux_out);
+
+  if (aux_out != nullptr) {
+    aux_out->layers[layer].num_clustered_histograms +=
+        codes->encoding_info.size();
+  }
+  return total_bits;
+}
+
+size_t WriteTokens(const std::vector<Token>& tokens,
+                   const EntropyEncodingData& codes,
+                   const std::vector<uint8_t>& context_map,
+                   BitWriter* writer) {
+  size_t num_extra_bits = 0;
+  if (codes.use_prefix_code) {
+    for (size_t i = 0; i < tokens.size(); i++) {
+      uint32_t tok, nbits, bits;
+      const Token& token = tokens[i];
+      size_t histo = context_map[token.context];
+      (token.is_lz77_length ? codes.lz77.length_uint_config
+                            : codes.uint_config[histo])
+          .Encode(token.value, &tok, &nbits, &bits);
+      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
+      // Combine two calls to the BitWriter. Equivalent to:
+      //   writer->Write(codes.encoding_info[histo][tok].depth,
+      //                 codes.encoding_info[histo][tok].bits);
+      //   writer->Write(nbits, bits);
+      uint64_t data = codes.encoding_info[histo][tok].bits;
+      data |= bits << codes.encoding_info[histo][tok].depth;
+      writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
+      num_extra_bits += nbits;
+    }
+    return num_extra_bits;
+  }
+  std::vector<uint64_t> out;
+  std::vector<uint8_t> out_nbits;
+  out.reserve(tokens.size());
+  out_nbits.reserve(tokens.size());
+  uint64_t allbits = 0;
+  size_t numallbits = 0;
+  // Writes in *reversed* order.
+  auto addbits = [&](size_t bits, size_t nbits) {
+    JXL_DASSERT(bits >> nbits == 0);
+    if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
+      out.push_back(allbits);
+      out_nbits.push_back(numallbits);
+      numallbits = allbits = 0;
+    }
+    allbits <<= nbits;
+    allbits |= bits;
+    numallbits += nbits;
+  };
+  const int end = tokens.size();
+  ANSCoder ans;
+  for (int i = end - 1; i >= 0; --i) {
+    const Token token = tokens[i];
+    const uint8_t histo = context_map[token.context];
+    uint32_t tok, nbits, bits;
+    (token.is_lz77_length ? codes.lz77.length_uint_config
+                          : codes.uint_config[histo])
+        .Encode(tokens[i].value, &tok, &nbits, &bits);
+    tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
+    const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
+    // Extra bits first as this is reversed.
+    addbits(bits, nbits);
+    num_extra_bits += nbits;
+    uint8_t ans_nbits = 0;
+    uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
+    addbits(ans_bits, ans_nbits);
+  }
+  const uint32_t state = ans.GetState();
+  writer->Write(32, state);
+  writer->Write(numallbits, allbits);
+  for (int i = out.size(); i > 0; --i) {
+    writer->Write(out_nbits[i - 1], out[i - 1]);
+  }
+  return num_extra_bits;
+}
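The backwards loop in `WriteTokens` exists because rANS is LIFO: the decoder pops symbols in the reverse of the order the encoder pushed them, so the encoder walks the token list from the end, and the renormalization words are emitted in reverse as well. A self-contained round-trip with a fixed two-symbol table (assumptions: a toy 2-bit tab size instead of ANS_LOG_TAB_SIZE, 16-bit renormalization mirroring `ANSCoder::PutSymbol`) demonstrates both points:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Minimal rANS round-trip showing why encoding runs over the tokens in
// reverse. Fixed table: symbol 0 has frequency 3/4, symbol 1 has 1/4.
constexpr uint32_t kLogTab = 2;
constexpr uint32_t kFreq[2] = {3, 1};
constexpr uint32_t kStart[2] = {0, 3};

int main() {
  const int syms[8] = {0, 0, 1, 0, 1, 0, 0, 0};
  std::vector<uint16_t> words;  // 16-bit chunks, consumed LIFO by the decoder
  uint32_t state = 1u << 16;

  for (int i = 7; i >= 0; i--) {  // encode in REVERSE symbol order
    uint32_t f = kFreq[syms[i]];
    while (state >= (f << (32 - kLogTab))) {  // renormalize: flush 16 bits
      words.push_back(state & 0xffff);
      state >>= 16;
    }
    state = ((state / f) << kLogTab) + (state % f) + kStart[syms[i]];
  }

  // Decode forwards: the final encoder state is transmitted first, exactly
  // like the writer->Write(32, state) call above.
  for (int i = 0; i < 8; i++) {
    uint32_t res = state & ((1u << kLogTab) - 1);
    int s = res >= kStart[1] ? 1 : 0;
    state = kFreq[s] * (state >> kLogTab) + res - kStart[s];
    if (state < (1u << 16) && !words.empty()) {
      state = (state << 16) | words.back();  // LIFO: reverse of flush order
      words.pop_back();
    }
    assert(s == syms[i]);
  }
  assert(state == 1u << 16);
  std::printf("round-trip ok\n");
}
```

The buffered `out`/`out_nbits` vectors in the real code serve the same purpose as `words` here: bits produced while iterating backwards have to be written to the stream back-to-front.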
+
+void WriteTokens(const std::vector<Token>& tokens,
+                 const EntropyEncodingData& codes,
+                 const std::vector<uint8_t>& context_map, BitWriter* writer,
+                 size_t layer, AuxOut* aux_out) {
+  BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
+  size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
+  ReclaimAndCharge(writer, &allotment, layer, aux_out);
+  if (aux_out != nullptr) {
+    aux_out->layers[layer].extra_bits += num_extra_bits;
+  }
+}
+
+void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
+#if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
+  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
+#endif
+}
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans.h b/third_party/jpeg-xl/lib/jxl/enc_ans.h
new file mode 100644
index 000000000000..a120ad145b00
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_ans.h
@@ -0,0 +1,151 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_ANS_H_
+#define LIB_JXL_ENC_ANS_H_
+
+// Library to encode the ANS population counts to the bit-stream and encode
+// symbols based on the respective distributions.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <math.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "lib/jxl/ans_common.h"
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/enc_ans_params.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/huffman_table.h"
+
+namespace jxl {
+
+#define USE_MULT_BY_RECIPROCAL
+
+// precision must be equal to: #bits(state_) + #bits(freq)
+#define RECIPROCAL_PRECISION (32 + ANS_LOG_TAB_SIZE)
+
+// Data structure representing one element of the encoding table built
+// from a distribution.
+// TODO(veluca): split this up, or use a union.
+struct ANSEncSymbolInfo {
+  // ANS
+  uint16_t freq_;
+  std::vector<uint16_t> reverse_map_;
+#ifdef USE_MULT_BY_RECIPROCAL
+  uint64_t ifreq_;
+#endif
+  // Prefix coding.
+  uint8_t depth;
+  uint16_t bits;
+};
+
+class ANSCoder {
+ public:
+  ANSCoder() : state_(ANS_SIGNATURE << 16) {}
+
+  uint32_t PutSymbol(const ANSEncSymbolInfo& t, uint8_t* nbits) {
+    uint32_t bits = 0;
+    *nbits = 0;
+    if ((state_ >> (32 - ANS_LOG_TAB_SIZE)) >= t.freq_) {
+      bits = state_ & 0xffff;
+      state_ >>= 16;
+      *nbits = 16;
+    }
+#ifdef USE_MULT_BY_RECIPROCAL
+    // We use the mult-by-reciprocal trick, but that requires 64b calc.
+    const uint32_t v = (state_ * t.ifreq_) >> RECIPROCAL_PRECISION;
+    const uint32_t offset = t.reverse_map_[state_ - v * t.freq_];
+    state_ = (v << ANS_LOG_TAB_SIZE) + offset;
+#else
+    state_ = ((state_ / t.freq_) << ANS_LOG_TAB_SIZE) +
+             t.reverse_map_[state_ % t.freq_];
+#endif
+    return bits;
+  }
+
+  uint32_t GetState() const { return state_; }
+
+ private:
+  uint32_t state_;
+};
+
+// RebalanceHistogram requires a signed type.
+using ANSHistBin = int32_t;
+
+struct EntropyEncodingData {
+  std::vector<std::vector<ANSEncSymbolInfo>> encoding_info;
+  bool use_prefix_code;
+  std::vector<HybridUintConfig> uint_config;
+  LZ77Params lz77;
+};
+
+// Integer to be encoded by an entropy coder, either ANS or Huffman.
+struct Token {
+  Token(uint32_t c, uint32_t value)
+      : is_lz77_length(false), context(c), value(value) {}
+  uint32_t is_lz77_length : 1;
+  uint32_t context : 31;
+  uint32_t value;
+};
+
+// Returns an estimate of the number of bits required to encode the given
+// histogram (header bits plus data bits).
+float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size);
+
+// Apply context clustering, compute histograms and encode them. Returns an
+// estimate of the total bits used for encoding the stream. If `writer` ==
+// nullptr, the bit estimate will not take into account the context map (which
+// does not get written if `num_contexts` == 1).
+size_t BuildAndEncodeHistograms(const HistogramParams& params,
+                                size_t num_contexts,
+                                std::vector<std::vector<Token>>& tokens,
+                                EntropyEncodingData* codes,
+                                std::vector<uint8_t>* context_map,
+                                BitWriter* writer, size_t layer,
+                                AuxOut* aux_out);
+
+// Write the tokens to a string.
+void WriteTokens(const std::vector<Token>& tokens,
+                 const EntropyEncodingData& codes,
+                 const std::vector<uint8_t>& context_map, BitWriter* writer,
+                 size_t layer, AuxOut* aux_out);
+
+// Same as above, but assumes allotment created by caller.
+size_t WriteTokens(const std::vector<Token>& tokens,
+                   const EntropyEncodingData& codes,
+                   const std::vector<uint8_t>& context_map,
+                   BitWriter* writer);
+
+// Exposed for tests; to be used with Writer=BitWriter only.
+template <typename Writer>
+void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
+                       Writer* writer, size_t log_alpha_size);
+extern template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
+                                       BitWriter*, size_t);
+
+// Globally set the option to create fuzzer-friendly ANS streams. Negatively
+// impacts compression. Not thread-safe.
+void SetANSFuzzerFriendly(bool ans_fuzzer_friendly);
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_ANS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans_params.h b/third_party/jpeg-xl/lib/jxl/enc_ans_params.h
new file mode 100644
index 000000000000..c88c84088c12
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_ans_params.h
@@ -0,0 +1,84 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_ANS_PARAMS_H_
+#define LIB_JXL_ENC_ANS_PARAMS_H_
+
+// Encoder-only parameters needed for ANS entropy encoding methods.
+
+#include <stddef.h>
+#include <vector>
+
+#include "lib/jxl/enc_params.h"
+
+namespace jxl {
+
+struct HistogramParams {
+  enum class ClusteringType {
+    kFastest,  // Only 4 clusters.
+    kFast,
+    kBest,
+  };
+
+  enum class HybridUintMethod {
+    kNone,        // just use kHybridUint420Config.
+    kFast,        // just try a couple of options.
+    kContextMap,  // fast choice for ctx map.
+    kBest,
+  };
+
+  enum class LZ77Method {
+    kNone,     // do not try lz77.
+    kRLE,      // only try doing RLE.
+    kLZ77,     // try lz77 with backward references.
+    kOptimal,  // optimal-matching LZ77 parsing.
+  };
+
+  enum class ANSHistogramStrategy {
+    kFast,         // Only try some methods, early exit.
+    kApproximate,  // Only try some methods.
+    kPrecise,      // Try all methods.
+  };
+
+  HistogramParams() = default;
+
+  HistogramParams(SpeedTier tier, size_t num_ctx) {
+    if (tier == SpeedTier::kFalcon) {
+      clustering = ClusteringType::kFastest;
+      lz77_method = LZ77Method::kNone;
+    } else if (tier > SpeedTier::kTortoise) {
+      clustering = ClusteringType::kFast;
+    } else {
+      clustering = ClusteringType::kBest;
+    }
+    if (tier > SpeedTier::kTortoise) {
+      uint_method = HybridUintMethod::kNone;
+    }
+    if (tier >= SpeedTier::kSquirrel) {
+      ans_histogram_strategy = ANSHistogramStrategy::kApproximate;
+    }
+  }
+
+  ClusteringType clustering = ClusteringType::kBest;
+  HybridUintMethod uint_method = HybridUintMethod::kBest;
+  LZ77Method lz77_method = LZ77Method::kRLE;
+  ANSHistogramStrategy ans_histogram_strategy = ANSHistogramStrategy::kPrecise;
+  std::vector<size_t> image_widths;
+  size_t max_histograms = ~0;
+  bool force_huffman = false;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_ANS_PARAMS_H_
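Together, `HistogramParams` above and the `BuildAndEncodeHistograms` declaration in enc_ans.h admit a writer-less cost-estimation mode. The following usage sketch is not part of the patch (the helper name is ours); it only exercises declarations shown above and would compile inside the vendored tree:

```cpp
#include <vector>

#include "lib/jxl/enc_ans.h"
#include "lib/jxl/enc_ans_params.h"

namespace jxl {

// Hypothetical helper: estimate the encoded size in bits of a
// single-context token stream without emitting any output, by passing
// writer == nullptr as documented in enc_ans.h.
size_t EstimateTokenCostInBits(std::vector<std::vector<Token>>& tokens) {
  HistogramParams params(SpeedTier::kSquirrel, /*num_ctx=*/1);
  EntropyEncodingData codes;
  std::vector<uint8_t> context_map;
  return BuildAndEncodeHistograms(params, /*num_contexts=*/1, tokens, &codes,
                                  &context_map, /*writer=*/nullptr,
                                  /*layer=*/0, /*aux_out=*/nullptr);
}

}  // namespace jxl
```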
+ +#include "lib/jxl/enc_ar_control_field.h" + +#include +#include + +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_ar_control_field.cc" +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +void ProcessTile(const Image3F& opsin, PassesEncoderState* enc_state, + const Rect& rect, + ArControlFieldHeuristics::TempImages* temp_image) { + constexpr size_t N = kBlockDim; + ImageB* JXL_RESTRICT epf_sharpness = &enc_state->shared.epf_sharpness; + ImageF* JXL_RESTRICT quant = &enc_state->initial_quant_field; + JXL_ASSERT( + epf_sharpness->xsize() == enc_state->shared.frame_dim.xsize_blocks && + epf_sharpness->ysize() == enc_state->shared.frame_dim.ysize_blocks); + + if (enc_state->cparams.butteraugli_distance < kMinButteraugliForDynamicAR || + enc_state->cparams.speed_tier > SpeedTier::kWombat || + enc_state->shared.frame_header.loop_filter.epf_iters == 0) { + FillPlane(static_cast(4), epf_sharpness, rect); + return; + } + + // Likely better to have a higher X weight, like: + // const float kChannelWeights[3] = {47.0f, 4.35f, 0.287f}; + const float kChannelWeights[3] = {4.35f, 4.35f, 0.287f}; + const float kChannelWeightsLapNeg[3] = {-0.125f * kChannelWeights[0], + -0.125f * kChannelWeights[1], + -0.125f * kChannelWeights[2]}; + const size_t sharpness_stride = + static_cast(epf_sharpness->PixelsPerRow()); + + size_t by0 = rect.y0(); + size_t by1 = rect.y0() + rect.ysize(); + size_t bx0 = rect.x0(); + size_t bx1 = rect.x0() + rect.xsize(); + temp_image->InitOnce(); + ImageF& laplacian_sqrsum = temp_image->laplacian_sqrsum; + // Calculate the L2 of the 3x3 Laplacian in an integral transform + // (for example 32x32 dct). This relates to transforms ability + // to propagate artefacts. + size_t y0 = by0 == 0 ? 2 : 0; + size_t y1 = by1 * N + 4 <= opsin.ysize() + 2 ? (by1 - by0) * N + 4 + : opsin.ysize() + 2 - by0 * N; + size_t x0 = bx0 == 0 ? 2 : 0; + size_t x1 = bx1 * N + 4 <= opsin.xsize() + 2 ? (bx1 - bx0) * N + 4 + : opsin.xsize() + 2 - bx0 * N; + HWY_FULL(float) df; + for (size_t y = y0; y < y1; y++) { + float* JXL_RESTRICT laplacian_sqrsum_row = laplacian_sqrsum.Row(y); + size_t cy = y + by0 * N - 2; + const float* JXL_RESTRICT in_row_t[3]; + const float* JXL_RESTRICT in_row[3]; + const float* JXL_RESTRICT in_row_b[3]; + for (size_t c = 0; c < 3; c++) { + in_row_t[c] = opsin.PlaneRow(c, cy > 0 ? cy - 1 : cy); + in_row[c] = opsin.PlaneRow(c, cy); + in_row_b[c] = opsin.PlaneRow(c, cy + 1 < opsin.ysize() ? cy + 1 : cy); + } + auto compute_laplacian_scalar = [&](size_t x) { + size_t cx = x + bx0 * N - 2; + const size_t prevX = cx >= 1 ? cx - 1 : cx; + const size_t nextX = cx + 1 < opsin.xsize() ? 
cx + 1 : cx; + float sumsqr = 0; + for (size_t c = 0; c < 3; c++) { + float laplacian = + kChannelWeights[c] * in_row[c][cx] + + kChannelWeightsLapNeg[c] * + (in_row[c][prevX] + in_row[c][nextX] + in_row_b[c][prevX] + + in_row_b[c][cx] + in_row_b[c][nextX] + in_row_t[c][prevX] + + in_row_t[c][cx] + in_row_t[c][nextX]); + sumsqr += laplacian * laplacian; + } + laplacian_sqrsum_row[x] = sumsqr; + }; + size_t x = x0; + for (; x + bx0 * N < 3; x++) { + compute_laplacian_scalar(x); + } + // Interior. One extra pixel of border as the last pixel is special. + for (; x + Lanes(df) <= x1 && x + Lanes(df) + bx0 * N - 1 <= opsin.xsize(); + x += Lanes(df)) { + size_t cx = x + bx0 * N - 2; + auto sumsqr = Zero(df); + for (size_t c = 0; c < 3; c++) { + auto laplacian = + LoadU(df, in_row[c] + cx) * Set(df, kChannelWeights[c]); + auto sum_oth0 = LoadU(df, in_row[c] + cx - 1); + auto sum_oth1 = LoadU(df, in_row[c] + cx + 1); + auto sum_oth2 = LoadU(df, in_row_t[c] + cx - 1); + auto sum_oth3 = LoadU(df, in_row_t[c] + cx); + sum_oth0 += LoadU(df, in_row_t[c] + cx + 1); + sum_oth1 += LoadU(df, in_row_b[c] + cx - 1); + sum_oth2 += LoadU(df, in_row_b[c] + cx); + sum_oth3 += LoadU(df, in_row_b[c] + cx + 1); + sum_oth0 += sum_oth1; + sum_oth2 += sum_oth3; + sum_oth0 += sum_oth2; + laplacian = + MulAdd(Set(df, kChannelWeightsLapNeg[c]), sum_oth0, laplacian); + sumsqr = MulAdd(laplacian, laplacian, sumsqr); + } + StoreU(sumsqr, df, laplacian_sqrsum_row + x); + } + for (; x < x1; x++) { + compute_laplacian_scalar(x); + } + } + HWY_CAPPED(float, 4) df4; + // Calculate the L2 of the 3x3 Laplacian in 4x4 blocks within the area + // of the integral transform. Sample them within the integral transform + // with two offsets (0,0) and (-2, -2) pixels (sqrsum_00 and sqrsum_22, + // respectively). + ImageF& sqrsum_00 = temp_image->sqrsum_00; + size_t sqrsum_00_stride = sqrsum_00.PixelsPerRow(); + float* JXL_RESTRICT sqrsum_00_row = sqrsum_00.Row(0); + for (size_t y = 0; y < (by1 - by0) * 2; y++) { + const float* JXL_RESTRICT rows_in[4]; + for (size_t iy = 0; iy < 4; iy++) { + rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy + 2); + } + float* JXL_RESTRICT row_out = sqrsum_00_row + y * sqrsum_00_stride; + for (size_t x = 0; x < (bx1 - bx0) * 2; x++) { + auto sum = Zero(df4); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix += Lanes(df4)) { + sum += LoadU(df4, rows_in[iy] + x * 4 + ix + 2); + } + } + row_out[x] = GetLane(Sqrt(SumOfLanes(sum))) * (1.0f / 4.0f); + } + } + // Indexing iy and ix is a bit tricky as we include a 2 pixel border + // around the block for evenness calculations. This is similar to what + // we did in guetzli for the observability of artefacts, except there + // the element is a sliding 5x5, not sparsely sampled 4x4 box like here. + ImageF& sqrsum_22 = temp_image->sqrsum_22; + size_t sqrsum_22_stride = sqrsum_22.PixelsPerRow(); + float* JXL_RESTRICT sqrsum_22_row = sqrsum_22.Row(0); + for (size_t y = 0; y < (by1 - by0) * 2 + 1; y++) { + const float* JXL_RESTRICT rows_in[4]; + for (size_t iy = 0; iy < 4; iy++) { + rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy); + } + float* JXL_RESTRICT row_out = sqrsum_22_row + y * sqrsum_22_stride; + // ignore pixels outside the image. + // Y coordinates are relative to by0*8+y*4. + size_t sy = y * 4 + by0 * 8 > 0 ? 0 : 2; + size_t ey = y * 4 + by0 * 8 + 4 <= opsin.ysize() + 2 + ? 4 + : opsin.ysize() - y * 4 - by0 * 8 + 2; + for (size_t x = 0; x < (bx1 - bx0) * 2 + 1; x++) { + // ignore pixels outside the image. 
+ // X coordinates are relative to bx0*8. + size_t sx = x * 4 + bx0 * 8 > 0 ? x * 4 : x * 4 + 2; + size_t ex = x * 4 + bx0 * 8 + 4 <= opsin.xsize() + 2 + ? x * 4 + 4 + : opsin.xsize() - bx0 * 8 + 2; + if (ex - sx == 4 && ey - sy == 4) { + auto sum = Zero(df4); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix += Lanes(df4)) { + sum += Load(df4, rows_in[iy] + sx + ix); + } + } + row_out[x] = GetLane(Sqrt(SumOfLanes(sum))) * (1.0f / 4.0f); + } else { + float sum = 0; + for (size_t iy = sy; iy < ey; iy++) { + for (size_t ix = sx; ix < ex; ix++) { + sum += rows_in[iy][ix]; + } + } + row_out[x] = std::sqrt(sum / ((ex - sx) * (ey - sy))); + } + } + } + for (size_t by = by0; by < by1; by++) { + AcStrategyRow acs_row = enc_state->shared.ac_strategy.ConstRow(by); + uint8_t* JXL_RESTRICT out_row = epf_sharpness->Row(by); + float* JXL_RESTRICT quant_row = quant->Row(by); + for (size_t bx = bx0; bx < bx1; bx++) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + // The errors are going to be linear to the quantization value in this + // locality. We only have access to the initial quant field here. + float quant_val = 1.0f / quant_row[bx]; + + const auto sq00 = [&](size_t y, size_t x) { + return sqrsum_00_row[((by - by0) * 2 + y) * sqrsum_00_stride + + (bx - bx0) * 2 + x]; + }; + const auto sq22 = [&](size_t y, size_t x) { + return sqrsum_22_row[((by - by0) * 2 + y) * sqrsum_22_stride + + (bx - bx0) * 2 + x]; + }; + float sqrsum_integral_transform = 0; + for (size_t iy = 0; iy < acs.covered_blocks_y() * 2; iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x() * 2; ix++) { + sqrsum_integral_transform += sq00(iy, ix) * sq00(iy, ix); + } + } + sqrsum_integral_transform /= + 4 * acs.covered_blocks_x() * acs.covered_blocks_y(); + sqrsum_integral_transform = std::sqrt(sqrsum_integral_transform); + // If masking is high or amplitude of the artefacts is low, then no + // smoothing is needed. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + // Five 4x4 blocks for masking estimation, all within the + // 8x8 area. + float minval_1 = std::min(sq00(2 * iy + 0, 2 * ix + 0), + sq00(2 * iy + 0, 2 * ix + 1)); + float minval_2 = std::min(sq00(2 * iy + 1, 2 * ix + 0), + sq00(2 * iy + 1, 2 * ix + 1)); + float minval = std::min(minval_1, minval_2); + minval = std::min(minval, sq22(2 * iy + 1, 2 * ix + 1)); + // Nine more 4x4 blocks for masking estimation, includes + // the 2 pixel area around the 8x8 block being controlled. + float minval2_1 = std::min(sq22(2 * iy + 0, 2 * ix + 0), + sq22(2 * iy + 0, 2 * ix + 1)); + float minval2_2 = std::min(sq22(2 * iy + 0, 2 * ix + 2), + sq22(2 * iy + 1, 2 * ix + 0)); + float minval2_3 = std::min(sq22(2 * iy + 1, 2 * ix + 1), + sq22(2 * iy + 1, 2 * ix + 2)); + float minval2_4 = std::min(sq22(2 * iy + 2, 2 * ix + 0), + sq22(2 * iy + 2, 2 * ix + 1)); + float minval2_5 = std::min(minval2_1, minval2_2); + float minval2_6 = std::min(minval2_3, minval2_4); + float minval2 = std::min(minval2_5, minval2_6); + minval2 = std::min(minval2, sq22(2 * iy + 2, 2 * ix + 2)); + float minval3 = std::min(minval, minval2); + minval *= 0.125f; + minval += 0.625f * minval3; + minval += + 0.125f * std::min(1.5f * minval3, sq22(2 * iy + 1, 2 * ix + 1)); + minval += 0.125f * minval2; + // Larger kBias, less smoothing for low intensity changes. 
+ float kDeltaLimit = 3.2; + float bias = 0.0625f * quant_val; + float delta = + (sqrsum_integral_transform + (kDeltaLimit + 0.05) * bias) / + (minval + bias); + int out = 4; + if (delta > kDeltaLimit) { + out = 4; // smooth + } else { + out = 0; + } + // 'threshold' is separate from 'bias' for easier tuning of these + // heuristics. + float threshold = 0.0625f * quant_val; + const float kSmoothLimit = 0.085f; + float smooth = 0.20f * (sq00(2 * iy + 0, 2 * ix + 0) + + sq00(2 * iy + 0, 2 * ix + 1) + + sq00(2 * iy + 1, 2 * ix + 0) + + sq00(2 * iy + 1, 2 * ix + 1) + minval); + if (smooth < kSmoothLimit * threshold) { + out = 4; + } + out_row[bx + sharpness_stride * iy + ix] = out; + } + } + } + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ProcessTile); + +void ArControlFieldHeuristics::RunRect(const Rect& block_rect, + const Image3F& opsin, + PassesEncoderState* enc_state, + size_t thread) { + HWY_DYNAMIC_DISPATCH(ProcessTile) + (opsin, enc_state, block_rect, &temp_images[thread]); +} + +} // namespace jxl + +#endif diff --git a/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.h b/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.h new file mode 100644 index 000000000000..35520f216040 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.h @@ -0,0 +1,58 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
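The HWY_ONCE epilogue that closes enc_ar_control_field.cc above is highway's standard runtime-dispatch idiom: foreach_target.h re-includes the translation unit once per SIMD target, HWY_EXPORT collects the per-target ProcessTile instantiations into a table, and HWY_DYNAMIC_DISPATCH picks the best entry for the running CPU. A minimal self-contained sketch of the same pattern follows; the file and function names (scale_demo.cc, ScaleInPlace, demo) are illustrative, not part of this patch:

#include <stddef.h>
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "scale_demo.cc"  // this file, re-included per target
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {
// Compiled once per SIMD target; assumes n is a multiple of Lanes(d).
void ScaleInPlace(float* data, size_t n, float mul) {
  const HWY_FULL(float) d;
  const auto m = Set(d, mul);
  for (size_t i = 0; i < n; i += Lanes(d)) {
    StoreU(LoadU(d, data + i) * m, d, data + i);
  }
}
}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace demo {
HWY_EXPORT(ScaleInPlace);  // table of per-target function pointers
void ScaleInPlace(float* data, size_t n, float mul) {
  HWY_DYNAMIC_DISPATCH(ScaleInPlace)(data, n, mul);
}
}  // namespace demo
#endif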
+ +#ifndef LIB_JXL_ENC_AR_CONTROL_FIELD_H_ +#define LIB_JXL_ENC_AR_CONTROL_FIELD_H_ + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +struct ArControlFieldHeuristics { + struct TempImages { + void InitOnce() { + if (laplacian_sqrsum.xsize() != 0) return; + laplacian_sqrsum = ImageF(kEncTileDim + 4, kEncTileDim + 4); + sqrsum_00 = ImageF(kEncTileDim / 4, kEncTileDim / 4); + sqrsum_22 = ImageF(kEncTileDim / 4 + 1, kEncTileDim / 4 + 1); + } + + ImageF laplacian_sqrsum; + ImageF sqrsum_00; + ImageF sqrsum_22; + }; + + void PrepareForThreads(size_t num_threads) { + temp_images.resize(num_threads); + } + + void RunRect(const Rect& block_rect, const Image3F& opsin, + PassesEncoderState* enc_state, size_t thread); + + std::vector<TempImages> temp_images; + ImageB* epf_sharpness; + ImageF* quant; + bool all_default; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_AR_CONTROL_FIELD_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc new file mode 100644 index 000000000000..acc3173a564d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc @@ -0,0 +1,259 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_bit_writer.h" + +#include <string.h> // memcpy + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/dec_bit_reader.h" + +namespace jxl { + +BitWriter::Allotment::Allotment(BitWriter* JXL_RESTRICT writer, size_t max_bits) + : max_bits_(max_bits) { + if (writer == nullptr) return; + prev_bits_written_ = writer->BitsWritten(); + const size_t prev_bytes = writer->storage_.size(); + const size_t next_bytes = DivCeil(max_bits, kBitsPerByte); + writer->storage_.resize(prev_bytes + next_bytes); + parent_ = writer->current_allotment_; + writer->current_allotment_ = this; +} + +BitWriter::Allotment::~Allotment() { + if (!called_) { + // Not calling is a bug - unused storage will not be reclaimed.
+ JXL_ABORT("Did not call Allotment::ReclaimUnused"); + } +} + +void BitWriter::Allotment::FinishedHistogram(BitWriter* JXL_RESTRICT writer) { + if (writer == nullptr) return; + JXL_ASSERT(!called_); // Call before ReclaimUnused + JXL_ASSERT(histogram_bits_ == 0); // Do not call twice + JXL_ASSERT(writer->BitsWritten() >= prev_bits_written_); + histogram_bits_ = writer->BitsWritten() - prev_bits_written_; +} + +void BitWriter::Allotment::PrivateReclaim(BitWriter* JXL_RESTRICT writer, + size_t* JXL_RESTRICT used_bits, + size_t* JXL_RESTRICT unused_bits) { + JXL_ASSERT(!called_); // Do not call twice + called_ = true; + if (writer == nullptr) return; + + JXL_ASSERT(writer->BitsWritten() >= prev_bits_written_); + *used_bits = writer->BitsWritten() - prev_bits_written_; + JXL_ASSERT(*used_bits <= max_bits_); + *unused_bits = max_bits_ - *used_bits; + + // Reclaim unused whole bytes from writer's allotment. + const size_t unused_bytes = *unused_bits / kBitsPerByte; // truncate + JXL_ASSERT(writer->storage_.size() >= unused_bytes); + writer->storage_.resize(writer->storage_.size() - unused_bytes); + writer->current_allotment_ = parent_; + // Ensure we don't also charge the parent for these bits. + auto parent = parent_; + while (parent != nullptr) { + parent->prev_bits_written_ += *used_bits; + parent = parent->parent_; + } +} + +void BitWriter::AppendByteAligned(const Span<const uint8_t>& span) { + if (!span.size()) return; + storage_.resize(storage_.size() + span.size() + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are byte-aligned. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += span.size() * kBitsPerByte; +} + +void BitWriter::AppendByteAligned(const BitWriter& other) { + JXL_ASSERT(other.BitsWritten() % kBitsPerByte == 0); + JXL_ASSERT(other.BitsWritten() / kBitsPerByte != 0); + + AppendByteAligned(other.GetSpan()); +} + +void BitWriter::AppendByteAligned(const std::vector<BitWriter>& others) { + // Total size to add so we can preallocate + size_t other_bytes = 0; + for (const BitWriter& writer : others) { + JXL_ASSERT(writer.BitsWritten() % kBitsPerByte == 0); + other_bytes += writer.BitsWritten() / kBitsPerByte; + } + if (other_bytes == 0) { + // No bytes to append: this happens for example when creating per-group + // storage for groups, but not writing anything in them for e.g. lossless + // images with no alpha. Do nothing. + return; + } + storage_.resize(storage_.size() + other_bytes + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are byte-aligned.
+ JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + for (const BitWriter& writer : others) { + const Span<const uint8_t> span = writer.GetSpan(); + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + } + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += other_bytes * kBitsPerByte; +} + +// TODO(lode): avoid code duplication +void BitWriter::AppendByteAligned( + const std::vector<std::unique_ptr<BitWriter>>& others) { + // Total size to add so we can preallocate + size_t other_bytes = 0; + for (const auto& writer : others) { + JXL_ASSERT(writer->BitsWritten() % kBitsPerByte == 0); + other_bytes += writer->BitsWritten() / kBitsPerByte; + } + if (other_bytes == 0) { + // No bytes to append: this happens for example when creating per-group + // storage for groups, but not writing anything in them for e.g. lossless + // images with no alpha. Do nothing. + return; + } + storage_.resize(storage_.size() + other_bytes + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are byte-aligned. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + for (const auto& writer : others) { + const Span<const uint8_t> span = writer->GetSpan(); + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + } + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += other_bytes * kBitsPerByte; +} + +BitWriter& BitWriter::operator+=(const BitWriter& other) { + // Required for correctness, otherwise owned[bits_written_] is out of bounds. + if (other.bits_written_ == 0) return *this; + const size_t other_bytes = DivCeil(other.bits_written_, kBitsPerByte); + const size_t prev_bytes = storage_.size(); + storage_.resize(prev_bytes + other_bytes + 1); // extra zero padding + + if (bits_written_ % kBitsPerByte == 0) { + // Only copy fully-initialized bytes. + const size_t full_bytes = other.bits_written_ / kBitsPerByte; // truncated + memcpy(&storage_[bits_written_ / kBitsPerByte], other.storage_.data(), + full_bytes); + storage_[bits_written_ / kBitsPerByte + full_bytes] = 0; // for next Write + bits_written_ += full_bytes * kBitsPerByte; + + const size_t leftovers = other.bits_written_ % kBitsPerByte; + if (leftovers != 0) { + BitReader reader(Span<const uint8_t>(other.storage_.data() + full_bytes, + other_bytes - full_bytes)); + Write(leftovers, reader.ReadBits(leftovers)); + JXL_CHECK(reader.Close()); + } + return *this; + } + + constexpr size_t N = kMaxBitsPerCall < BitReader::kMaxBitsPerCall + ? kMaxBitsPerCall + : BitReader::kMaxBitsPerCall; + + // Do not use GetSpan because other may not be byte-aligned. + BitReader reader(other.storage_); + size_t i = 0; + for (; i + N <= other.bits_written_; i += N) { + Write(N, reader.ReadFixedBits<N>()); + } + const size_t leftovers = other.bits_written_ - i; + if (leftovers != 0) { + Write(leftovers, reader.ReadBits(leftovers)); + } + JXL_CHECK(reader.Close()); + return *this; +} + +// Example: let's assume that 3 bits (Rs below) have been written already: +// BYTE+0 BYTE+1 BYTE+2 +// 0000 0RRR ???? ???? ???? ???? +// +// Now, we could write up to 5 bits by just shifting them left by 3 bits and +// OR'ing to BYTE-0. +// +// For n > 5 bits, we write the lowest 5 bits as above, then write the next +// lowest bits into BYTE+1 starting from its lower bits and so on.
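+// Worked example (illustrative): with 3 bits already written, Write(11, 0x5A5)
+// shifts the value left by 3 and ORs it into the 64-bit word at BYTE+0, so
+// stream bits 3..13 receive 0x5A5 least-significant-bit first and
+// bits_written_ advances from 3 to 14.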
+void BitWriter::Write(size_t n_bits, uint64_t bits) { + JXL_DASSERT((bits >> n_bits) == 0); + JXL_DASSERT(n_bits <= kMaxBitsPerCall); + uint8_t* p = &storage_[bits_written_ / kBitsPerByte]; + const size_t bits_in_first_byte = bits_written_ % kBitsPerByte; + bits <<= bits_in_first_byte; +#if JXL_BYTE_ORDER_LITTLE + uint64_t v = *p; + // Last (partial) or next byte to write must be zero-initialized! + // PaddedBytes initializes the first, and Write/Append maintain this. + JXL_DASSERT(v >> bits_in_first_byte == 0); + v |= bits; + memcpy(p, &v, sizeof(v)); // Write bytes: possibly more than n_bits/8 +#else + *p++ |= static_cast<uint8_t>(bits & 0xFF); + for (size_t bits_left_to_write = n_bits + bits_in_first_byte; + bits_left_to_write >= 9; bits_left_to_write -= 8) { + bits >>= 8; + *p++ = static_cast<uint8_t>(bits & 0xFF); + } + *p = 0; +#endif + bits_written_ += n_bits; +} + +BitWriter& BitWriter::operator+=(const PaddedBytes& other) { + const size_t other_bytes = other.size(); + // Required for correctness, otherwise owned[bits_written_] is out of bounds. + if (other_bytes == 0) return *this; + const size_t other_bits = other_bytes * kBitsPerByte; + + storage_.resize(storage_.size() + other_bytes + 1); + if (bits_written_ % kBitsPerByte == 0) { + memcpy(&storage_[bits_written_ / kBitsPerByte], other.data(), other_bytes); + storage_[bits_written_ / kBitsPerByte + other_bytes] = 0; // for next Write + bits_written_ += other_bits; + return *this; + } + constexpr size_t N = kMaxBitsPerCall < BitReader::kMaxBitsPerCall + ? kMaxBitsPerCall + : BitReader::kMaxBitsPerCall; + + BitReader reader(other); + size_t i = 0; + for (; i + N <= other_bits; i += N) { + Write(N, reader.ReadFixedBits<N>()); + } + const size_t leftovers = other_bits - i; + Write(leftovers, reader.ReadBits(leftovers)); + JXL_CHECK(reader.Close()); + return *this; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_bit_writer.h b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.h new file mode 100644 index 000000000000..97e803fc967f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.h @@ -0,0 +1,153 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_BIT_WRITER_H_ +#define LIB_JXL_ENC_BIT_WRITER_H_ + +// BitWriter class: unbuffered writes using unaligned 64-bit stores. + +#include <stddef.h> +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" + +namespace jxl { + +struct BitWriter { + // Upper bound on `n_bits` in each call to Write. We shift a 64-bit word by + // 7 bits (max already valid bits in the last byte) and at least 1 bit is + // needed to zero-initialize the bit-stream ahead (i.e. if 7 bits are valid + // and we write 57 bits, then the next write will access a byte that was not + // yet zero-initialized).
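+  // Arithmetic check: 7 carried-over bits + 56 payload bits = 63 <= 64, so
+  // the unaligned 64-bit store in Write() stays within one word and at least
+  // one zero-initialized bit always remains ahead of the stream.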
+ static constexpr size_t kMaxBitsPerCall = 56; + + BitWriter() : bits_written_(0) {} + + // Disallow copying - may lead to bugs. + BitWriter(const BitWriter&) = delete; + BitWriter& operator=(const BitWriter&) = delete; + BitWriter(BitWriter&&) = default; + BitWriter& operator=(BitWriter&&) = default; + + explicit BitWriter(PaddedBytes&& donor) + : bits_written_(donor.size() * kBitsPerByte), + storage_(std::move(donor)) {} + + size_t BitsWritten() const { return bits_written_; } + + Span<const uint8_t> GetSpan() const { + // Callers must ensure byte alignment to avoid uninitialized bits. + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + return Span<const uint8_t>(storage_.data(), bits_written_ / kBitsPerByte); + } + + // Example usage: bytes = std::move(writer).TakeBytes(); Useful for the + // top-level encoder which returns PaddedBytes, not a BitWriter. + // *this must be an rvalue reference and is invalid afterwards. + PaddedBytes&& TakeBytes() && { + // Callers must ensure byte alignment to avoid uninitialized bits. + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + storage_.resize(bits_written_ / kBitsPerByte); + return std::move(storage_); + } + + // Must be byte-aligned before calling. + void AppendByteAligned(const Span<const uint8_t>& span); + // NOTE: no allotment needed, the other BitWriters have already been charged. + void AppendByteAligned(const BitWriter& other); + void AppendByteAligned(const std::vector<std::unique_ptr<BitWriter>>& others); + void AppendByteAligned(const std::vector<BitWriter>& others); + + class Allotment { + public: + // Expands a BitWriter's storage. Must happen before calling Write or + // ZeroPadToByte. Must call ReclaimUnused after writing to reclaim the + // unused storage so that BitWriter memory use remains tightly bounded. + Allotment(BitWriter* JXL_RESTRICT writer, size_t max_bits); + ~Allotment(); + + size_t MaxBits() const { return max_bits_; } + + // Call after writing a histogram, but before ReclaimUnused. + void FinishedHistogram(BitWriter* JXL_RESTRICT writer); + + size_t HistogramBits() const { + JXL_ASSERT(called_); + return histogram_bits_; + } + + // Do not call directly - use ::ReclaimAndCharge instead, which ensures + // the bits are charged to a layer. + void PrivateReclaim(BitWriter* JXL_RESTRICT writer, + size_t* JXL_RESTRICT used_bits, + size_t* JXL_RESTRICT unused_bits); + + private: + size_t prev_bits_written_; + const size_t max_bits_; + size_t histogram_bits_ = 0; + bool called_ = false; + Allotment* parent_; + }; + + // WARNING: think twice before using this. Concatenating two BitWriters that + // pad to bytes is NOT the same as one contiguous BitWriter. + BitWriter& operator+=(const BitWriter& other); + + // TODO(janwas): remove once all callers use BitWriter + BitWriter& operator+=(const PaddedBytes& other); + + // Writes bits into bytes in increasing addresses, and within a byte + // least-significant-bit first. + // + // The function can write up to 56 bits in one go. + void Write(size_t n_bits, uint64_t bits); + + // This should only rarely be used - e.g. when the current location will be + // referenced via byte offset (TOCs point to groups), or byte-aligned reading + // is required for speed. WARNING: this interacts badly with operator+=, + // see above. + void ZeroPadToByte() { + const size_t remainder_bits = + RoundUpBitsToByteMultiple(bits_written_) - bits_written_; + if (remainder_bits == 0) return; + Write(remainder_bits, 0); + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + } + + // TODO(janwas): remove?
only called from ANS + void RewindStorage(const size_t pos0) { + JXL_ASSERT(pos0 <= bits_written_); + bits_written_ = pos0; + static const uint8_t kRewindMasks[8] = {0x0, 0x1, 0x3, 0x7, + 0xf, 0x1f, 0x3f, 0x7f}; + storage_[pos0 >> 3] &= kRewindMasks[pos0 & 7]; + } + + private: + size_t bits_written_; + PaddedBytes storage_; + Allotment* current_allotment_ = nullptr; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_BIT_WRITER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.cc b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.cc new file mode 100644 index 000000000000..91b7a4b589be --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.cc @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_butteraugli_comparator.h" + +#include <algorithm> +#include <memory> + +#include "lib/jxl/color_management.h" + +namespace jxl { + +JxlButteraugliComparator::JxlButteraugliComparator( + const ButteraugliParams& params) + : params_(params) {} + +Status JxlButteraugliComparator::SetReferenceImage(const ImageBundle& ref) { + const ImageBundle* ref_linear_srgb; + ImageMetadata metadata = *ref.metadata(); + ImageBundle store(&metadata); + if (!TransformIfNeeded(ref, ColorEncoding::LinearSRGB(ref.IsGray()), + /*pool=*/nullptr, &store, &ref_linear_srgb)) { + return false; + } + + comparator_.reset( + new ButteraugliComparator(ref_linear_srgb->color(), params_)); + xsize_ = ref.xsize(); + ysize_ = ref.ysize(); + return true; +} + +Status JxlButteraugliComparator::CompareWith(const ImageBundle& actual, + ImageF* diffmap, float* score) { + if (!comparator_) { + return JXL_FAILURE("Must set reference image first"); + } + if (xsize_ != actual.xsize() || ysize_ != actual.ysize()) { + return JXL_FAILURE("Images must have same size"); + } + + const ImageBundle* actual_linear_srgb; + ImageMetadata metadata = *actual.metadata(); + ImageBundle store(&metadata); + if (!TransformIfNeeded(actual, ColorEncoding::LinearSRGB(actual.IsGray()), + /*pool=*/nullptr, &store, &actual_linear_srgb)) { + return false; + } + + ImageF temp_diffmap(xsize_, ysize_); + comparator_->Diffmap(actual_linear_srgb->color(), temp_diffmap); + + if (score != nullptr) { + *score = ButteraugliScoreFromDiffmap(temp_diffmap, &params_); + } + if (diffmap != nullptr) { + diffmap->Swap(temp_diffmap); + } + + return true; +} + +float JxlButteraugliComparator::GoodQualityScore() const { + return ButteraugliFuzzyInverse(1.5); +} + +float JxlButteraugliComparator::BadQualityScore() const { + return ButteraugliFuzzyInverse(0.5); +} + +float ButteraugliDistance(const ImageBundle& rgb0, const ImageBundle& rgb1, + const ButteraugliParams& params, ImageF* distmap, + ThreadPool* pool) { + JxlButteraugliComparator comparator(params); + return ComputeScore(rgb0, rgb1, &comparator, distmap, pool); +} + +float ButteraugliDistance(const CodecInOut& rgb0, const CodecInOut& rgb1, + const ButteraugliParams& params, ImageF* distmap, + ThreadPool* pool) { 
JxlButteraugliComparator comparator(params); + JXL_ASSERT(rgb0.frames.size() == rgb1.frames.size()); + float max_dist = 0.0f; + for (size_t i = 0; i < rgb0.frames.size(); ++i) { + max_dist = std::max(max_dist, ComputeScore(rgb0.frames[i], rgb1.frames[i], + &comparator, distmap, pool)); + } + return max_dist; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.h b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.h new file mode 100644 index 000000000000..b9093ac79e47 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.h @@ -0,0 +1,65 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ +#define LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ + +#include <stddef.h> + +#include <memory> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_comparator.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +class JxlButteraugliComparator : public Comparator { + public: + explicit JxlButteraugliComparator(const ButteraugliParams& params); + + Status SetReferenceImage(const ImageBundle& ref) override; + + Status CompareWith(const ImageBundle& actual, ImageF* diffmap, + float* score) override; + + float GoodQualityScore() const override; + float BadQualityScore() const override; + + private: + ButteraugliParams params_; + std::unique_ptr<ButteraugliComparator> comparator_; + size_t xsize_ = 0; + size_t ysize_ = 0; +}; + +// Returns the butteraugli distance between rgb0 and rgb1. +// If distmap is not null, it must be the same size as rgb0 and rgb1. +float ButteraugliDistance(const ImageBundle& rgb0, const ImageBundle& rgb1, + const ButteraugliParams& params, + ImageF* distmap = nullptr, + ThreadPool* pool = nullptr); + +float ButteraugliDistance(const CodecInOut& rgb0, const CodecInOut& rgb1, + const ButteraugliParams& params, + ImageF* distmap = nullptr, + ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_butteraugli_pnorm.cc b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_pnorm.cc new file mode 100644 index 000000000000..f1430d4d89dc --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_pnorm.cc @@ -0,0 +1,221 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
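For orientation, the comparator defined in the two files above is driven in two steps; a minimal usage sketch follows (ref and test stand for pre-loaded, same-sized jxl::ImageBundle objects; the surrounding setup is assumed, not shown in this patch, and JXL_CHECK is used here in place of real Status propagation):

jxl::ButteraugliParams params;
jxl::JxlButteraugliComparator cmp(params);
JXL_CHECK(cmp.SetReferenceImage(ref));  // converted to linear sRGB internally
jxl::ImageF diffmap;
float score = 0.0f;
JXL_CHECK(cmp.CompareWith(test, &diffmap, &score));
// Lower scores are better; GoodQualityScore()/BadQualityScore() bracket the
// useful range of the scale.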
+ +#include "lib/jxl/enc_butteraugli_pnorm.h" + +#include <math.h> +#include <stdlib.h> + +#include <atomic> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_butteraugli_pnorm.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Rebind; + +double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params, + double p) { + PROFILER_FUNC; + // In approximate-border mode, skip pixels on the border likely to be affected + // by FastGauss' zero-valued-boundary behavior. The border is less than half + // the largest-diameter kernel (37x37 pixels), and 0 if the image is tiny. + // NOTE: chosen such that it is vector-aligned. + size_t border = (params.approximate_border) ? 8 : 0; + if (distmap.xsize() <= 2 * border || distmap.ysize() <= 2 * border) { + border = 0; + } + + const double onePerPixels = 1.0 / (distmap.ysize() * distmap.xsize()); + if (std::abs(p - 3.0) < 1E-6) { + double sum1[3] = {0.0}; + +// Prefer double if possible, but otherwise use float rather than scalar. +#if HWY_CAP_FLOAT64 + using T = double; + const Rebind<float, HWY_FULL(double)> df; +#else + using T = float; +#endif + const HWY_FULL(T) d; + constexpr size_t N = MaxLanes(HWY_FULL(T)()); + // Manually aligned storage to avoid asan crash on clang-7 due to + // unaligned spill. + HWY_ALIGN T sum_totals0[N] = {0}; + HWY_ALIGN T sum_totals1[N] = {0}; + HWY_ALIGN T sum_totals2[N] = {0}; + + for (size_t y = border; y < distmap.ysize() - border; ++y) { + const float* JXL_RESTRICT row = distmap.ConstRow(y); + + auto sums0 = Zero(d); + auto sums1 = Zero(d); + auto sums2 = Zero(d); + + size_t x = border; + for (; x + Lanes(d) <= distmap.xsize() - border; x += Lanes(d)) { +#if HWY_CAP_FLOAT64 + const auto d1 = PromoteTo(d, Load(df, row + x)); +#else + const auto d1 = Load(d, row + x); +#endif + const auto d2 = d1 * d1 * d1; + sums0 += d2; + const auto d3 = d2 * d2; + sums1 += d3; + const auto d4 = d3 * d3; + sums2 += d4; + } + + Store(sums0 + Load(d, sum_totals0), d, sum_totals0); + Store(sums1 + Load(d, sum_totals1), d, sum_totals1); + Store(sums2 + Load(d, sum_totals2), d, sum_totals2); + + for (; x < distmap.xsize() - border; ++x) { + const double d1 = row[x]; + double d2 = d1 * d1 * d1; + sum1[0] += d2; + d2 *= d2; + sum1[1] += d2; + d2 *= d2; + sum1[2] += d2; + } + } + double v = 0; + v += pow( + onePerPixels * (sum1[0] + GetLane(SumOfLanes(Load(d, sum_totals0)))), + 1.0 / (p * 1.0)); + v += pow( + onePerPixels * (sum1[1] + GetLane(SumOfLanes(Load(d, sum_totals1)))), + 1.0 / (p * 2.0)); + v += pow( + onePerPixels * (sum1[2] + GetLane(SumOfLanes(Load(d, sum_totals2)))), + 1.0 / (p * 4.0)); + v /= 3.0; + return v; + } else { + static std::atomic<int> once{0}; + if (once.fetch_add(1, std::memory_order_relaxed) == 0) { + JXL_WARNING("WARNING: using slow ComputeDistanceP"); + } + double sum1[3] = {0.0}; + for (size_t y = border; y < distmap.ysize() - border; ++y) { + const float* JXL_RESTRICT row = distmap.ConstRow(y); + for (size_t x = border; x < distmap.xsize() - border; ++x) { + double d2 = std::pow(row[x], p); + sum1[0] += d2; + d2 *= d2; + sum1[1] += d2; + d2 *= d2; + sum1[2] += d2; + } + } + double v = 0; + for (int i = 0; i < 3; ++i) { + v += pow(onePerPixels * (sum1[i]), 1.0 / (p * (1 << i))); + } + v /= 3.0; + return v; + } +} + +// TODO(lode): take alpha into account when needed
+double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2) { + PROFILER_FUNC; + // Convert to sRGB - closer to perception than linear. + const Image3F* srgb1 = &ib1.color(); + Image3F copy1; + if (!ib1.IsSRGB()) { + JXL_CHECK(ib1.CopyTo(Rect(ib1), ColorEncoding::SRGB(ib1.IsGray()), &copy1)); + srgb1 = &copy1; + } + const Image3F* srgb2 = &ib2.color(); + Image3F copy2; + if (!ib2.IsSRGB()) { + JXL_CHECK(ib2.CopyTo(Rect(ib2), ColorEncoding::SRGB(ib2.IsGray()), &copy2)); + srgb2 = &copy2; + } + + JXL_CHECK(SameSize(*srgb1, *srgb2)); + + // TODO(veluca): SIMD. + float yuvmatrix[3][3] = {{0.299, 0.587, 0.114}, + {-0.14713, -0.28886, 0.436}, + {0.615, -0.51499, -0.10001}}; + double sum_of_squares[3] = {}; + for (size_t y = 0; y < srgb1->ysize(); ++y) { + const float* JXL_RESTRICT row1[3]; + const float* JXL_RESTRICT row2[3]; + for (size_t j = 0; j < 3; j++) { + row1[j] = srgb1->ConstPlaneRow(j, y); + row2[j] = srgb2->ConstPlaneRow(j, y); + } + for (size_t x = 0; x < srgb1->xsize(); ++x) { + float cdiff[3] = {}; + // YUV conversion is linear, so we can run it on the difference. + for (size_t j = 0; j < 3; j++) { + cdiff[j] = row1[j][x] - row2[j][x]; + } + float yuvdiff[3] = {}; + for (size_t j = 0; j < 3; j++) { + for (size_t k = 0; k < 3; k++) { + yuvdiff[j] += yuvmatrix[j][k] * cdiff[k]; + } + } + for (size_t j = 0; j < 3; j++) { + sum_of_squares[j] += yuvdiff[j] * yuvdiff[j]; + } + } + } + // Weighted PSNR as in JPEG-XL: chroma counts 1/8. + const float weights[3] = {6.0f / 8, 1.0f / 8, 1.0f / 8}; + // Avoid squaring the weight - 1/64 is too extreme. + double norm = 0; + for (size_t i = 0; i < 3; i++) { + norm += std::sqrt(sum_of_squares[i]) * weights[i]; + } + // This function returns distance *squared*. + return norm * norm; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ComputeDistanceP); +double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params, + double p) { + return HWY_DYNAMIC_DISPATCH(ComputeDistanceP)(distmap, params, p); +} + +HWY_EXPORT(ComputeDistance2); +double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2) { + return HWY_DYNAMIC_DISPATCH(ComputeDistance2)(ib1, ib2); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/enc_butteraugli_pnorm.h b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_pnorm.h new file mode 100644 index 000000000000..345ba975584e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_pnorm.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_BUTTERAUGLI_PNORM_H_ +#define LIB_JXL_ENC_BUTTERAUGLI_PNORM_H_ + +#include <stddef.h> + +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Computes p-norm given the butteraugli distmap.
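+// More precisely (a summary of the implementation in enc_butteraugli_pnorm.cc):
+// the returned value is (1/3) * sum over k in {1, 2, 4} of
+//   ( mean(d^(k*p)) )^(1/(k*p))
+// over distmap values d, i.e. an average of three mixed p-norms.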
+double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params, + double p); + +double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2); + +} // namespace jxl + +#endif // LIB_JXL_ENC_BUTTERAUGLI_PNORM_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_cache.cc b/third_party/jpeg-xl/lib/jxl/enc_cache.cc new file mode 100644 index 000000000000..42c986601b06 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_cache.cc @@ -0,0 +1,204 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_cache.h" + +#include <stddef.h> +#include <stdint.h> + +#include <cmath> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +void InitializePassesEncoder(const Image3F& opsin, ThreadPool* pool, + PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + AuxOut* aux_out) { + PROFILER_FUNC; + + PassesSharedState& JXL_RESTRICT shared = enc_state->shared; + + enc_state->histogram_idx.resize(shared.frame_dim.num_groups); + + enc_state->x_qm_multiplier = + std::pow(1.25f, shared.frame_header.x_qm_scale - 2.0f); + enc_state->b_qm_multiplier = + std::pow(1.25f, shared.frame_header.b_qm_scale - 2.0f); + + if (enc_state->coeffs.size() < shared.frame_header.passes.num_passes) { + enc_state->coeffs.reserve(shared.frame_header.passes.num_passes); + for (size_t i = enc_state->coeffs.size(); + i < shared.frame_header.passes.num_passes; i++) { + // Allocate enough coefficients for each group on every row. + enc_state->coeffs.emplace_back(make_unique<ACImageT<int32_t>>( + kGroupDim * kGroupDim, shared.frame_dim.num_groups)); + } + } + while (enc_state->coeffs.size() > shared.frame_header.passes.num_passes) { + enc_state->coeffs.pop_back(); + } + + Image3F dc(shared.frame_dim.xsize_blocks, shared.frame_dim.ysize_blocks); + RunOnPool( + pool, 0, shared.frame_dim.num_groups, ThreadPool::SkipInit(), + [&](size_t group_idx, size_t _) { + ComputeCoefficients(group_idx, enc_state, opsin, &dc); + }, + "Compute coeffs"); + + if (shared.frame_header.flags & FrameHeader::kUseDcFrame) { + CompressParams cparams = enc_state->cparams; + // Guess a distance that produces good initial results.
+ cparams.butteraugli_distance = + std::max(kMinButteraugliDistance, + enc_state->cparams.butteraugli_distance * 0.1f); + cparams.dots = Override::kOff; + cparams.noise = Override::kOff; + cparams.patches = Override::kOff; + cparams.gaborish = Override::kOff; + cparams.epf = 0; + cparams.max_error_mode = true; + for (size_t c = 0; c < 3; c++) { + cparams.max_error[c] = shared.quantizer.MulDC()[c]; + } + JXL_ASSERT(cparams.progressive_dc > 0); + cparams.progressive_dc--; + // The DC frame will have alpha=0. Don't erase its contents. + cparams.keep_invisible = true; + // No EPF or Gaborish in DC frames. + cparams.epf = 0; + cparams.gaborish = Override::kOff; + // Use kVarDCT in max_error_mode for intermediate progressive DC, + // and kModular for the smallest DC (first in the bitstream) + if (cparams.progressive_dc == 0) { + cparams.modular_mode = true; + cparams.quality_pair.first = cparams.quality_pair.second = + 99.f - enc_state->cparams.butteraugli_distance * 0.2f; + } + ImageBundle ib(&shared.metadata->m); + // This is a lie - dc is in XYB + // (but EncodeFrame will skip RGB->XYB conversion anyway) + ib.SetFromImage( + std::move(dc), + ColorEncoding::LinearSRGB(shared.metadata->m.color_encoding.IsGray())); + if (!ib.metadata()->extra_channel_info.empty()) { + // Add dummy extra channels to the patch image: dc_level frames do not yet + // support extra channels, but the codec expects that the number of extra + // channels in frames matches that in the metadata of the codestream. + std::vector<ImageF> extra_channels; + extra_channels.reserve(ib.metadata()->extra_channel_info.size()); + for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) { + const auto& eci = ib.metadata()->extra_channel_info[i]; + extra_channels.emplace_back(eci.Size(ib.xsize()), eci.Size(ib.ysize())); + // Must initialize the image with data to not affect blending with + // uninitialized memory. + // TODO(lode): dc_level must copy and use the real extra channels + // instead. + ZeroFillImage(&extra_channels.back()); + } + ib.SetExtraChannels(std::move(extra_channels)); + } + std::unique_ptr<PassesEncoderState> state = + jxl::make_unique<PassesEncoderState>(); + + auto special_frame = std::unique_ptr<BitWriter>(new BitWriter()); + FrameInfo dc_frame_info; + dc_frame_info.frame_type = FrameType::kDCFrame; + dc_frame_info.dc_level = shared.frame_header.dc_level + 1; + dc_frame_info.ib_needs_color_transform = false; + dc_frame_info.save_before_color_transform = true; // Implicitly true + // TODO(lode): the EncodeFrame / DecodeFrame pair here is likely broken in + // case of dc_level >= 3, since EncodeFrame may output multiple frames + // to the bitwriter, while DecodeFrame reads only one. + JXL_CHECK(EncodeFrame(cparams, dc_frame_info, shared.metadata, ib, + state.get(), pool, special_frame.get(), nullptr)); + const Span<const uint8_t> encoded = special_frame->GetSpan(); + enc_state->special_frames.emplace_back(std::move(special_frame)); + + BitReader br(encoded); + ImageBundle decoded(&shared.metadata->m); + std::unique_ptr<PassesDecoderState> dec_state = + jxl::make_unique<PassesDecoderState>(); + JXL_CHECK(dec_state->output_encoding_info.Set(shared.metadata->m)); + JXL_CHECK(DecodeFrame({}, dec_state.get(), pool, &br, &decoded, + *shared.metadata, /*constraints=*/nullptr)); + // TODO(lode): shared.frame_header.dc_level should be equal to + // dec_state.shared->frame_header.dc_level - 1 here, since above we set + // dc_frame_info.dc_level = shared.frame_header.dc_level + 1, and + // dc_frame_info.dc_level is used by EncodeFrame. However, if EncodeFrame + // outputs multiple frames, this assumption could be wrong.
+ shared.dc_storage = + CopyImage(dec_state->shared->dc_frames[shared.frame_header.dc_level]); + ZeroFillImage(&shared.quant_dc); + shared.dc = &shared.dc_storage; + JXL_CHECK(br.Close()); + } else { + auto compute_dc_coeffs = [&](int group_index, int /* thread */) { + modular_frame_encoder->AddVarDCTDC( + dc, group_index, + enc_state->cparams.butteraugli_distance >= 2.0f && + enc_state->cparams.speed_tier != SpeedTier::kFalcon, + enc_state); + }; + RunOnPool(pool, 0, shared.frame_dim.num_dc_groups, ThreadPool::SkipInit(), + compute_dc_coeffs, "Compute DC coeffs"); + // TODO(veluca): this is only useful in tests and if inspection is enabled. + if (!(shared.frame_header.flags & FrameHeader::kSkipAdaptiveDCSmoothing)) { + AdaptiveDCSmoothing(shared.quantizer.MulDC(), &shared.dc_storage, pool); + } + } + auto compute_ac_meta = [&](int group_index, int /* thread */) { + modular_frame_encoder->AddACMetadata(group_index, /*jpeg_transcode=*/false, + enc_state); + }; + RunOnPool(pool, 0, shared.frame_dim.num_dc_groups, ThreadPool::SkipInit(), + compute_ac_meta, "Compute AC Metadata"); + + if (aux_out != nullptr) { + aux_out->InspectImage3F("compressed_image:InitializeFrameEncCache:dc_dec", + shared.dc_storage); + } +} + +void EncCache::InitOnce() { + PROFILER_FUNC; + + if (num_nzeroes.xsize() == 0) { + num_nzeroes = Image3I(kGroupDimInBlocks, kGroupDimInBlocks); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_cache.h b/third_party/jpeg-xl/lib/jxl/enc_cache.h new file mode 100644 index 000000000000..ba4710c48bde --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_cache.h @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_CACHE_H_ +#define LIB_JXL_ENC_CACHE_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_heuristics.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/progressive_split.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +// Contains encoder state. +struct PassesEncoderState { + PassesSharedState shared; + + ImageF initial_quant_field; // Invalid in Falcon mode. + ImageF initial_quant_masking; // Invalid in Falcon mode. + + // Per-pass DCT coefficients for the image. One row per group. + std::vector<std::unique_ptr<ACImage>> coeffs; + + // Raw data for special (reference+DC) frames. + std::vector<std::unique_ptr<BitWriter>> special_frames; + + // For splitting into passes.
+ ProgressiveSplitter progressive_splitter; + + CompressParams cparams; + + struct PassData { + std::vector<std::vector<Token>> ac_tokens; + std::vector<uint8_t> context_map; + EntropyEncodingData codes; + }; + + std::vector<PassData> passes; + std::vector<size_t> histogram_idx; + + // Coefficient orders that are non-default. + std::vector<uint32_t> used_orders; + + // Multiplier to be applied to the quant matrices of the x channel. + float x_qm_multiplier = 1.0f; + float b_qm_multiplier = 1.0f; + + // Heuristics to be used by the encoder. + std::unique_ptr<EncoderHeuristics> heuristics = + make_unique<DefaultEncoderHeuristics>(); +}; + +// Initialize per-frame information. +class ModularFrameEncoder; +void InitializePassesEncoder(const Image3F& opsin, ThreadPool* pool, + PassesEncoderState* passes_enc_state, + ModularFrameEncoder* modular_frame_encoder, + AuxOut* aux_out); + +// Working area for ComputeCoefficients (per-group!) +struct EncCache { + // Allocates memory when first called, shrinks images to current group size. + void InitOnce(); + + // TokenizeCoefficients + Image3I num_nzeroes; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_CACHE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.cc b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.cc new file mode 100644 index 000000000000..4db68afd6098 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.cc @@ -0,0 +1,384 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_chroma_from_luma.h" + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <cmath> +#include <limits> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_chroma_from_luma.cc" +#include <hwy/aligned_allocator.h> +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/quantizer.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +static HWY_FULL(float) df; + +struct CFLFunction { + static constexpr float kCoeff = 1.f / 3; + static constexpr float kThres = 100.0f; + static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; + CFLFunction(const float* values_m, const float* values_s, size_t num, + float base, float distance_mul) + : values_m(values_m), + values_s(values_s), + num(num), + base(base), + distance_mul(distance_mul) {} + + // Returns f'(x), where f is 1/3 * sum ((|color residual| + 1)^2-1) + + // distance_mul * x^2 * num.
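+  // Spelled out: with residual r_i = a_i*x + b_i, f(x) =
+  // kCoeff * sum_i ((|r_i| + 1)^2 - 1) + distance_mul * num * x^2, where
+  // residuals with |r_i| >= kThres are dropped from the sum. The fpeps/fmeps
+  // outputs are f'(x +/- eps), letting the caller form a numeric second
+  // derivative.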
+ float Compute(float x, float eps, float* fpeps, float* fmeps) const { + float first_derivative = 2 * distance_mul * num * x; + float first_derivative_peps = 2 * distance_mul * num * (x + eps); + float first_derivative_meps = 2 * distance_mul * num * (x - eps); + + const auto inv_color_factor = Set(df, kInvColorFactor); + const auto thres = Set(df, kThres); + const auto coeffx2 = Set(df, kCoeff * 2.0f); + const auto one = Set(df, 1.0f); + const auto zero = Set(df, 0.0f); + const auto base_v = Set(df, base); + const auto x_v = Set(df, x); + const auto xpe_v = Set(df, x + eps); + const auto xme_v = Set(df, x - eps); + auto fd_v = Zero(df); + auto fdpe_v = Zero(df); + auto fdme_v = Zero(df); + JXL_ASSERT(num % Lanes(df) == 0); + + for (size_t i = 0; i < num; i += Lanes(df)) { + // color residual = ax + b + const auto a = inv_color_factor * Load(df, values_m + i); + const auto b = base_v * Load(df, values_m + i) - Load(df, values_s + i); + const auto v = a * x_v + b; + const auto vpe = a * xpe_v + b; + const auto vme = a * xme_v + b; + const auto av = Abs(v); + const auto avpe = Abs(vpe); + const auto avme = Abs(vme); + auto d = coeffx2 * (av + one) * a; + auto dpe = coeffx2 * (avpe + one) * a; + auto dme = coeffx2 * (avme + one) * a; + d = IfThenElse(v < zero, zero - d, d); + dpe = IfThenElse(vpe < zero, zero - dpe, dpe); + dme = IfThenElse(vme < zero, zero - dme, dme); + fd_v += IfThenElse(av >= thres, zero, d); + fdpe_v += IfThenElse(av >= thres, zero, dpe); + fdme_v += IfThenElse(av >= thres, zero, dme); + } + + *fpeps = first_derivative_peps + GetLane(SumOfLanes(fdpe_v)); + *fmeps = first_derivative_meps + GetLane(SumOfLanes(fdme_v)); + return first_derivative + GetLane(SumOfLanes(fd_v)); + } + + const float* JXL_RESTRICT values_m; + const float* JXL_RESTRICT values_s; + size_t num; + float base; + float distance_mul; +}; + +int32_t FindBestMultiplier(const float* values_m, const float* values_s, + size_t num, float base, float distance_mul, + bool fast) { + if (num == 0) { + return 0; + } + float x; + if (fast) { + static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; + auto ca = Zero(df); + auto cb = Zero(df); + const auto inv_color_factor = Set(df, kInvColorFactor); + const auto base_v = Set(df, base); + for (size_t i = 0; i < num; i += Lanes(df)) { + // color residual = ax + b + const auto a = inv_color_factor * Load(df, values_m + i); + const auto b = base_v * Load(df, values_m + i) - Load(df, values_s + i); + ca = MulAdd(a, a, ca); + cb = MulAdd(a, b, cb); + } + // + distance_mul * x^2 * num + x = -GetLane(SumOfLanes(cb)) / + (GetLane(SumOfLanes(ca)) + num * distance_mul * 0.5f); + } else { + constexpr float eps = 1; + constexpr float kClamp = 20.0f; + CFLFunction fn(values_m, values_s, num, base, distance_mul); + x = 0; + // Up to 20 Newton iterations, with approximate derivatives. + // Derivatives are approximate due to the high amount of noise in the exact + // derivatives. 
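+    // Each iteration below takes a clamped Newton step,
+    //   x <- x - clamp(f'(x) / f''(x), -kClamp, +kClamp),
+    // with f''(x) ~= (f'(x + eps) - f'(x - eps)) / (2 * eps).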
+ for (size_t i = 0; i < 20; i++) { + float dfpeps, dfmeps; + float df = fn.Compute(x, eps, &dfpeps, &dfmeps); + float ddf = (dfpeps - dfmeps) / (2 * eps); + float step = df / ddf; + x -= std::min(kClamp, std::max(-kClamp, step)); + if (std::abs(step) < 3e-3) break; + } + } + return std::max(-128.0f, std::min(127.0f, roundf(x))); +} + +void InitDCStorage(size_t num_blocks, ImageF* dc_values) { + // First row: Y channel + // Second row: X channel + // Third row: Y channel + // Fourth row: B channel + *dc_values = ImageF(RoundUpTo(num_blocks, Lanes(df)), 4); + + JXL_ASSERT(dc_values->xsize() != 0); + // Zero-fill the last lanes + for (size_t y = 0; y < 4; y++) { + for (size_t x = dc_values->xsize() - Lanes(df); x < dc_values->xsize(); + x++) { + dc_values->Row(y)[x] = 0; + } + } +} + +void ComputeDC(const ImageF& dc_values, bool fast, int* dc_x, int* dc_b) { + constexpr float kDistanceMultiplierDC = 1e-5f; + const float* JXL_RESTRICT dc_values_yx = dc_values.Row(0); + const float* JXL_RESTRICT dc_values_x = dc_values.Row(1); + const float* JXL_RESTRICT dc_values_yb = dc_values.Row(2); + const float* JXL_RESTRICT dc_values_b = dc_values.Row(3); + *dc_x = FindBestMultiplier(dc_values_yx, dc_values_x, dc_values.xsize(), 0.0f, + kDistanceMultiplierDC, fast); + *dc_b = FindBestMultiplier(dc_values_yb, dc_values_b, dc_values.xsize(), + kYToBRatio, kDistanceMultiplierDC, fast); +} + +void ComputeTile(const Image3F& opsin, const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, const Quantizer* quantizer, + const Rect& r, bool fast, bool use_dct8, ImageSB* map_x, + ImageSB* map_b, ImageF* dc_values, float* mem) { + static_assert(kEncTileDimInBlocks == kColorTileDimInBlocks, + "Invalid color tile dim"); + size_t xsize_blocks = opsin.xsize() / kBlockDim; + constexpr float kDistanceMultiplierAC = 1e-3f; + + const size_t y0 = r.y0(); + const size_t x0 = r.x0(); + const size_t x1 = r.x0() + r.xsize(); + const size_t y1 = r.y0() + r.ysize(); + + int ty = y0 / kColorTileDimInBlocks; + int tx = x0 / kColorTileDimInBlocks; + + int8_t* JXL_RESTRICT row_out_x = map_x->Row(ty); + int8_t* JXL_RESTRICT row_out_b = map_b->Row(ty); + + float* JXL_RESTRICT dc_values_yx = dc_values->Row(0); + float* JXL_RESTRICT dc_values_x = dc_values->Row(1); + float* JXL_RESTRICT dc_values_yb = dc_values->Row(2); + float* JXL_RESTRICT dc_values_b = dc_values->Row(3); + + // All are aligned. 
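+  // The per-thread slab `mem` is carved up below as 3 * kMaxCoeffArea floats
+  // (block_y/x/b), 4 * kColorTileDim^2 floats (coeffs_yx/x/yb/b) and
+  // 2 * kMaxCoeffArea scratch floats, matching CfLHeuristics::kItemsPerThread.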
+ float* HWY_RESTRICT block_y = mem; + float* HWY_RESTRICT block_x = block_y + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT block_b = block_x + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT coeffs_yx = block_b + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT coeffs_x = coeffs_yx + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT coeffs_yb = coeffs_x + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT coeffs_b = coeffs_yb + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT scratch_space = coeffs_b + kColorTileDim * kColorTileDim; + JXL_DASSERT(scratch_space + 2 * AcStrategy::kMaxCoeffArea == + block_y + CfLHeuristics::kItemsPerThread); + + // Small (~256 bytes each) + HWY_ALIGN_MAX float + dc_y[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + HWY_ALIGN_MAX float + dc_x[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + HWY_ALIGN_MAX float + dc_b[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + size_t num_ac = 0; + + for (size_t y = y0; y < y1; ++y) { + const float* JXL_RESTRICT row_y = opsin.ConstPlaneRow(1, y * kBlockDim); + const float* JXL_RESTRICT row_x = opsin.ConstPlaneRow(0, y * kBlockDim); + const float* JXL_RESTRICT row_b = opsin.ConstPlaneRow(2, y * kBlockDim); + size_t stride = opsin.PixelsPerRow(); + + for (size_t x = x0; x < x1; x++) { + AcStrategy acs = use_dct8 + ? AcStrategy::FromRawStrategy(AcStrategy::Type::DCT) + : ac_strategy->ConstRow(y)[x]; + if (!acs.IsFirstBlock()) continue; + size_t xs = acs.covered_blocks_x(); + TransformFromPixels(acs.Strategy(), row_y + x * kBlockDim, stride, + block_y, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_y, dc_y, xs); + TransformFromPixels(acs.Strategy(), row_x + x * kBlockDim, stride, + block_x, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_x, dc_x, xs); + TransformFromPixels(acs.Strategy(), row_b + x * kBlockDim, stride, + block_b, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_b, dc_b, xs); + const float* const JXL_RESTRICT qm_x = + dequant.InvMatrix(acs.Strategy(), 0); + const float* const JXL_RESTRICT qm_b = + dequant.InvMatrix(acs.Strategy(), 2); + // Why does a constant seem to work better than + // raw_quant_field->Row(y)[x] ? + float q = use_dct8 ? 1 : quantizer->Scale() * 400.0f; + float q_dc_x = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(0); + float q_dc_b = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(2); + + // Copy DCs in dc_values. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < xs; ix++) { + dc_values_yx[(iy + y) * xsize_blocks + ix + x] = + dc_y[iy * xs + ix] * q_dc_x; + dc_values_x[(iy + y) * xsize_blocks + ix + x] = + dc_x[iy * xs + ix] * q_dc_x; + dc_values_yb[(iy + y) * xsize_blocks + ix + x] = + dc_y[iy * xs + ix] * q_dc_b; + dc_values_b[(iy + y) * xsize_blocks + ix + x] = + dc_b[iy * xs + ix] * q_dc_b; + } + } + + // Do not use this block for computing AC CfL. + if (acs.covered_blocks_x() + x0 > x1 || + acs.covered_blocks_y() + y0 > y1) { + continue; + } + + // Copy AC coefficients in the local block. The order in which + // coefficients get stored does not matter. + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + // Zero out LFs. This introduces terms in the optimization loop that + // don't affect the result, as they are all 0, but allow for simpler + // SIMDfication. 
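+        // (The zeroed lanes flow through as a_i = b_i = 0 pairs, which add
+        // nothing to the sums in FindBestMultiplier; they only pad the data
+        // to whole vectors.)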
+ for (size_t iy = 0; iy < cy; iy++) { + for (size_t ix = 0; ix < cx; ix++) { + block_y[cx * kBlockDim * iy + ix] = 0; + block_x[cx * kBlockDim * iy + ix] = 0; + block_b[cx * kBlockDim * iy + ix] = 0; + } + } + const auto qv = Set(df, q); + for (size_t i = 0; i < cx * cy * 64; i += Lanes(df)) { + const auto b_y = Load(df, block_y + i); + const auto b_x = Load(df, block_x + i); + const auto b_b = Load(df, block_b + i); + const auto qqm_x = qv * Load(df, qm_x + i); + const auto qqm_b = qv * Load(df, qm_b + i); + Store(b_y * qqm_x, df, coeffs_yx + num_ac); + Store(b_x * qqm_x, df, coeffs_x + num_ac); + Store(b_y * qqm_b, df, coeffs_yb + num_ac); + Store(b_b * qqm_b, df, coeffs_b + num_ac); + num_ac += Lanes(df); + } + } + } + JXL_CHECK(num_ac % Lanes(df) == 0); + row_out_x[tx] = FindBestMultiplier(coeffs_yx, coeffs_x, num_ac, 0.0f, + kDistanceMultiplierAC, fast); + row_out_b[tx] = FindBestMultiplier(coeffs_yb, coeffs_b, num_ac, kYToBRatio, + kDistanceMultiplierAC, fast); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(InitDCStorage); +HWY_EXPORT(ComputeDC); +HWY_EXPORT(ComputeTile); + +void CfLHeuristics::Init(const Image3F& opsin) { + size_t xsize_blocks = opsin.xsize() / kBlockDim; + size_t ysize_blocks = opsin.ysize() / kBlockDim; + HWY_DYNAMIC_DISPATCH(InitDCStorage) + (xsize_blocks * ysize_blocks, &dc_values); +} + +void CfLHeuristics::ComputeTile(const Rect& r, const Image3F& opsin, + const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, + const Quantizer* quantizer, bool fast, + size_t thread, ColorCorrelationMap* cmap) { + bool use_dct8 = ac_strategy == nullptr; + HWY_DYNAMIC_DISPATCH(ComputeTile) + (opsin, dequant, ac_strategy, quantizer, r, fast, use_dct8, &cmap->ytox_map, + &cmap->ytob_map, &dc_values, mem.get() + thread * kItemsPerThread); +} + +void CfLHeuristics::ComputeDC(bool fast, ColorCorrelationMap* cmap) { + int32_t ytob_dc = 0; + int32_t ytox_dc = 0; + HWY_DYNAMIC_DISPATCH(ComputeDC)(dc_values, fast, &ytox_dc, &ytob_dc); + cmap->SetYToBDC(ytob_dc); + cmap->SetYToXDC(ytox_dc); +} + +void ColorCorrelationMapEncodeDC(ColorCorrelationMap* map, BitWriter* writer, + size_t layer, AuxOut* aux_out) { + float color_factor = map->GetColorFactor(); + float base_correlation_x = map->GetBaseCorrelationX(); + float base_correlation_b = map->GetBaseCorrelationB(); + int32_t ytox_dc = map->GetYToXDC(); + int32_t ytob_dc = map->GetYToBDC(); + + BitWriter::Allotment allotment(writer, 1 + 2 * kBitsPerByte + 12 + 32); + if (ytox_dc == 0 && ytob_dc == 0 && color_factor == kDefaultColorFactor && + base_correlation_x == 0.0f && base_correlation_b == kYToBRatio) { + writer->Write(1, 1); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + return; + } + writer->Write(1, 0); + JXL_CHECK(U32Coder::Write(kColorFactorDist, color_factor, writer)); + JXL_CHECK(F16Coder::Write(base_correlation_x, writer)); + JXL_CHECK(F16Coder::Write(base_correlation_b, writer)); + writer->Write(kBitsPerByte, ytox_dc - std::numeric_limits::min()); + writer->Write(kBitsPerByte, ytob_dc - std::numeric_limits::min()); + ReclaimAndCharge(writer, &allotment, layer, aux_out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.h b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.h new file mode 100644 index 000000000000..ad17f484a5d4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.h @@ -0,0 +1,76 
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_CHROMA_FROM_LUMA_H_
+#define LIB_JXL_ENC_CHROMA_FROM_LUMA_H_
+
+// Chroma-from-luma, computed using heuristics to determine the best linear
+// model for the X and B channels from the Y channel.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <hwy/aligned_allocator.h>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/field_encodings.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/quant_weights.h"
+
+namespace jxl {
+
+void ColorCorrelationMapEncodeDC(ColorCorrelationMap* map, BitWriter* writer,
+                                 size_t layer, AuxOut* aux_out);
+
+struct CfLHeuristics {
+  void Init(const Image3F& opsin);
+
+  void PrepareForThreads(size_t num_threads) {
+    mem = hwy::AllocateAligned<float>(num_threads * kItemsPerThread);
+  }
+
+  void ComputeTile(const Rect& r, const Image3F& opsin,
+                   const DequantMatrices& dequant,
+                   const AcStrategyImage* ac_strategy,
+                   const Quantizer* quantizer, bool fast, size_t thread,
+                   ColorCorrelationMap* cmap);
+
+  void ComputeDC(bool fast, ColorCorrelationMap* cmap);
+
+  ImageF dc_values;
+  hwy::AlignedFreeUniquePtr<float[]> mem;
+
+  // Working set is too large for stack; allocate dynamically.
+  constexpr static size_t kItemsPerThread =
+      AcStrategy::kMaxCoeffArea * 3        // Blocks
+      + kColorTileDim * kColorTileDim * 4  // AC coeff storage
+      + AcStrategy::kMaxCoeffArea * 2;     // Scratch space
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_CHROMA_FROM_LUMA_H_
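A minimal standalone sketch of the per-thread scratch scheme declared above
(one aligned allocation, carved into equal slices so that worker `thread`
only touches `mem.get() + thread * kItemsPerThread`). This is illustrative
only and not part of the patch; the kItemsPerThread value here is a stand-in
for the real constant.

    #include <cstddef>

    #include <hwy/aligned_allocator.h>

    int main() {
      constexpr size_t kItemsPerThread = 1024;  // stand-in for the real size
      const size_t num_threads = 4;
      // One allocation, aligned for SIMD loads/stores.
      hwy::AlignedFreeUniquePtr<float[]> mem =
          hwy::AllocateAligned<float>(num_threads * kItemsPerThread);
      for (size_t thread = 0; thread < num_threads; ++thread) {
        float* scratch = mem.get() + thread * kItemsPerThread;
        scratch[0] = 0.0f;  // each worker writes only inside its own slice
      }
      return 0;
    }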
diff --git a/third_party/jpeg-xl/lib/jxl/enc_cluster.cc b/third_party/jpeg-xl/lib/jxl/enc_cluster.cc
new file mode 100644
index 000000000000..eca6014305ec
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_cluster.cc
@@ -0,0 +1,297 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_cluster.h"
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <map>
+#include <memory>
+#include <numeric>
+#include <queue>
+#include <tuple>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_context.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/fast_math-inl.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+template <class V>
+V Entropy(V count, V inv_total, V total) {
+  const HWY_CAPPED(float, Histogram::kRounding) d;
+  const auto zero = Set(d, 0.0f);
+  return IfThenZeroElse(count == total,
+                        zero - count * FastLog2f(d, inv_total * count));
+}
+
+void HistogramEntropy(const Histogram& a) {
+  a.entropy_ = 0.0f;
+  if (a.total_count_ == 0) return;
+
+  const HWY_CAPPED(float, Histogram::kRounding) df;
+  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
+
+  const auto inv_tot = Set(df, 1.0f / a.total_count_);
+  auto entropy_lanes = Zero(df);
+  auto total = Set(df, a.total_count_);
+
+  for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
+    const auto counts = LoadU(di, &a.data_[i]);
+    entropy_lanes += Entropy(ConvertTo(df, counts), inv_tot, total);
+  }
+  a.entropy_ += GetLane(SumOfLanes(entropy_lanes));
+}
+
+float HistogramDistance(const Histogram& a, const Histogram& b) {
+  if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
+
+  const HWY_CAPPED(float, Histogram::kRounding) df;
+  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
+
+  const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
+  auto distance_lanes = Zero(df);
+  auto total = Set(df, a.total_count_ + b.total_count_);
+
+  for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
+       i += Lanes(di)) {
+    const auto a_counts =
+        a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
+    const auto b_counts =
+        b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
+    const auto counts = ConvertTo(df, a_counts + b_counts);
+    distance_lanes += Entropy(counts, inv_tot, total);
+  }
+  const float total_distance = GetLane(SumOfLanes(distance_lanes));
+  return total_distance - a.entropy_ - b.entropy_;
+}
+
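As a plain-C++ reference for the SIMD code above (illustrative only, not part
of the patch): HistogramEntropy computes the Shannon entropy of one histogram
in bits, and HistogramDistance is the extra cost of coding two histograms with
a single merged table, i.e. entropy(a + b) - entropy(a) - entropy(b) >= 0.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Shannon entropy of `counts` in bits: -sum_i c_i * log2(c_i / total).
    double ShannonEntropyBits(const std::vector<int>& counts) {
      double total = 0;
      for (int c : counts) total += c;
      if (total == 0) return 0.0;
      double bits = 0;
      for (int c : counts) {
        if (c > 0) bits -= c * std::log2(c / total);
      }
      return bits;
    }

    // Merge cost as used for clustering: always non-negative, and zero only
    // when the two distributions are proportional.
    double MergeCostBits(std::vector<int> a, const std::vector<int>& b) {
      const double separate = ShannonEntropyBits(a) + ShannonEntropyBits(b);
      if (a.size() < b.size()) a.resize(b.size(), 0);
      for (size_t i = 0; i < b.size(); ++i) a[i] += b[i];
      return ShannonEntropyBits(a) - separate;
    }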
+// First step of a k-means clustering with a fancy distance metric.
+void FastClusterHistograms(const std::vector<Histogram>& in,
+                           const size_t num_contexts, size_t max_histograms,
+                           float min_distance, std::vector<Histogram>* out,
+                           std::vector<uint32_t>* histogram_symbols) {
+  PROFILER_FUNC;
+  size_t largest_idx = 0;
+  for (size_t i = 0; i < num_contexts; i++) {
+    HistogramEntropy(in[i]);
+    if (in[i].total_count_ > in[largest_idx].total_count_) {
+      largest_idx = i;
+    }
+  }
+  out->clear();
+  out->reserve(max_histograms);
+  std::vector<float> dists(num_contexts, std::numeric_limits<float>::max());
+  histogram_symbols->clear();
+  histogram_symbols->resize(num_contexts, max_histograms);
+
+  while (out->size() < max_histograms && out->size() < num_contexts) {
+    (*histogram_symbols)[largest_idx] = out->size();
+    out->push_back(in[largest_idx]);
+    largest_idx = 0;
+    for (size_t i = 0; i < num_contexts; i++) {
+      dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]);
+      // Avoid repeating histograms
+      if ((*histogram_symbols)[i] != max_histograms) continue;
+      if (dists[i] > dists[largest_idx]) largest_idx = i;
+    }
+    if (dists[largest_idx] < min_distance) break;
+  }
+
+  for (size_t i = 0; i < num_contexts; i++) {
+    if ((*histogram_symbols)[i] != max_histograms) continue;
+    size_t best = 0;
+    float best_dist = HistogramDistance(in[i], (*out)[best]);
+    for (size_t j = 1; j < out->size(); j++) {
+      float dist = HistogramDistance(in[i], (*out)[j]);
+      if (dist < best_dist) {
+        best = j;
+        best_dist = dist;
+      }
+    }
+    (*out)[best].AddHistogram(in[i]);
+    HistogramEntropy((*out)[best]);
+    (*histogram_symbols)[i] = best;
+  }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(FastClusterHistograms);  // Local function
+HWY_EXPORT(HistogramEntropy);       // Local function
+
+float Histogram::ShannonEntropy() const {
+  HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
+  return entropy_;
+}
+
+namespace {
+// -----------------------------------------------------------------------------
+// Histogram refinement
+
+// Reorder histograms in *out so that the new symbols in *symbols come in
+// increasing order.
+void HistogramReindex(std::vector<Histogram>* out,
+                      std::vector<uint32_t>* symbols) {
+  std::vector<Histogram> tmp(*out);
+  std::map<uint32_t, uint32_t> new_index;
+  int next_index = 0;
+  for (uint32_t symbol : *symbols) {
+    if (new_index.find(symbol) == new_index.end()) {
+      new_index[symbol] = next_index;
+      (*out)[next_index] = tmp[symbol];
+      ++next_index;
+    }
+  }
+  out->resize(next_index);
+  for (uint32_t& symbol : *symbols) {
+    symbol = new_index[symbol];
+  }
+}
+
+}  // namespace
+
+// Clusters similar histograms in 'in' together, the selected histograms are
+// placed in 'out', and for each index in 'in', *histogram_symbols will
+// indicate which of the 'out' histograms is the best approximation.
+void ClusterHistograms(const HistogramParams params,
+                       const std::vector<Histogram>& in,
+                       const size_t num_contexts, size_t max_histograms,
+                       std::vector<Histogram>* out,
+                       std::vector<uint32_t>* histogram_symbols) {
+  constexpr float kMinDistanceForDistinctFast = 64.0f;
+  constexpr float kMinDistanceForDistinctBest = 16.0f;
+  max_histograms = std::min(max_histograms, params.max_histograms);
+  if (params.clustering == HistogramParams::ClusteringType::kFastest) {
+    HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
+    (in, num_contexts, 4, kMinDistanceForDistinctFast, out, histogram_symbols);
+  } else if (params.clustering == HistogramParams::ClusteringType::kFast) {
+    HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
+    (in, num_contexts, max_histograms, kMinDistanceForDistinctFast, out,
+     histogram_symbols);
+  } else {
+    PROFILER_FUNC;
+    HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
+    (in, num_contexts, max_histograms, kMinDistanceForDistinctBest, out,
+     histogram_symbols);
+    for (size_t i = 0; i < out->size(); i++) {
+      (*out)[i].entropy_ =
+          ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size());
+    }
+    uint32_t next_version = 2;
+    std::vector<uint32_t> version(out->size(), 1);
+    std::vector<uint32_t> renumbering(out->size());
+    std::iota(renumbering.begin(), renumbering.end(), 0);
+
+    // Try to pair up clusters if doing so reduces the total cost.
+
+    struct HistogramPair {
+      // validity of a pair: p.version == max(version[i], version[j])
+      float cost;
+      uint32_t first;
+      uint32_t second;
+      uint32_t version;
+      // We use > because priority queues sort in *decreasing* order, but we
+      // want lower cost elements to appear first.
+      bool operator<(const HistogramPair& other) const {
+        return std::make_tuple(cost, first, second, version) >
+               std::make_tuple(other.cost, other.first, other.second,
+                               other.version);
+      }
+    };
+
+    // Create list of all pairs by increasing merging cost.
+    std::priority_queue<HistogramPair> pairs_to_merge;
+    for (uint32_t i = 0; i < out->size(); i++) {
+      for (uint32_t j = i + 1; j < out->size(); j++) {
+        Histogram histo;
+        histo.AddHistogram((*out)[i]);
+        histo.AddHistogram((*out)[j]);
+        float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
+                     (*out)[i].entropy_ - (*out)[j].entropy_;
+        // Avoid enqueueing pairs that are not advantageous to merge.
+        if (cost >= 0) continue;
+        pairs_to_merge.push(
+            HistogramPair{cost, i, j, std::max(version[i], version[j])});
+      }
+    }
+
+    // Merge the best pair to merge, add new pairs that get formed as a
+    // consequence.
+    while (!pairs_to_merge.empty()) {
+      uint32_t first = pairs_to_merge.top().first;
+      uint32_t second = pairs_to_merge.top().second;
+      uint32_t ver = pairs_to_merge.top().version;
+      pairs_to_merge.pop();
+      if (ver != std::max(version[first], version[second]) ||
+          version[first] == 0 || version[second] == 0) {
+        continue;
+      }
+      (*out)[first].AddHistogram((*out)[second]);
+      (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(),
+                                                 (*out)[first].data_.size());
+      for (size_t i = 0; i < renumbering.size(); i++) {
+        if (renumbering[i] == second) {
+          renumbering[i] = first;
+        }
+      }
+      version[second] = 0;
+      version[first] = next_version++;
+      for (uint32_t j = 0; j < out->size(); j++) {
+        if (j == first) continue;
+        if (version[j] == 0) continue;
+        Histogram histo;
+        histo.AddHistogram((*out)[first]);
+        histo.AddHistogram((*out)[j]);
+        float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
+                     (*out)[first].entropy_ - (*out)[j].entropy_;
+        // Avoid enqueueing pairs that are not advantageous to merge.
+        if (cost >= 0) continue;
+        pairs_to_merge.push(
+            HistogramPair{cost, std::min(first, j), std::max(first, j),
+                          std::max(version[first], version[j])});
+      }
+    }
+    std::vector<uint32_t> reverse_renumbering(out->size(), -1);
+    size_t num_alive = 0;
+    for (size_t i = 0; i < out->size(); i++) {
+      if (version[i] == 0) continue;
+      (*out)[num_alive++] = (*out)[i];
+      reverse_renumbering[i] = num_alive - 1;
+    }
+    out->resize(num_alive);
+    for (size_t i = 0; i < histogram_symbols->size(); i++) {
+      (*histogram_symbols)[i] =
+          reverse_renumbering[renumbering[(*histogram_symbols)[i]]];
+    }
+  }
+
+  // Convert the context map to a canonical form.
+  HistogramReindex(out, histogram_symbols);
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/enc_cluster.h b/third_party/jpeg-xl/lib/jxl/enc_cluster.h
new file mode 100644
index 000000000000..06937d9938be
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_cluster.h
@@ -0,0 +1,70 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Functions for clustering similar histograms together.
+
+#ifndef LIB_JXL_ENC_CLUSTER_H_
+#define LIB_JXL_ENC_CLUSTER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/enc_ans.h"
+
+namespace jxl {
+
+struct Histogram {
+  Histogram() { total_count_ = 0; }
+  void Clear() {
+    data_.clear();
+    total_count_ = 0;
+  }
+  void Add(size_t symbol) {
+    if (data_.size() <= symbol) {
+      data_.resize(DivCeil(symbol + 1, kRounding) * kRounding);
+    }
+    ++data_[symbol];
+    ++total_count_;
+  }
+  void AddHistogram(const Histogram& other) {
+    if (other.data_.size() > data_.size()) {
+      data_.resize(other.data_.size());
+    }
+    for (size_t i = 0; i < other.data_.size(); ++i) {
+      data_[i] += other.data_[i];
+    }
+    total_count_ += other.total_count_;
+  }
+  float PopulationCost() const {
+    return ANSPopulationCost(data_.data(), data_.size());
+  }
+  float ShannonEntropy() const;
+
+  std::vector<ANSHistBin> data_;
+  size_t total_count_;
+  mutable float entropy_;  // WARNING: not kept up-to-date.
+  static constexpr size_t kRounding = 8;
+};
+
+void ClusterHistograms(HistogramParams params,
+                       const std::vector<Histogram>& in, size_t num_contexts,
+                       size_t max_histograms, std::vector<Histogram>* out,
+                       std::vector<uint32_t>* histogram_symbols);
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_CLUSTER_H_
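Hypothetical usage of the API above (not part of the patch), assuming the
in-tree lib/jxl headers and a default-constructible HistogramParams: one
Histogram per context goes in, and clustering returns a small set of merged
histograms plus a context map into that set.

    #include <cstdint>
    #include <vector>

    #include "lib/jxl/enc_cluster.h"

    namespace jxl {

    void ClusterExample() {
      std::vector<Histogram> in(4);
      in[0].Add(0); in[0].Add(0); in[0].Add(1);  // context 0
      in[1].Add(0); in[1].Add(1);                // context 1: similar to 0
      in[2].Add(7); in[2].Add(7);                // context 2: disjoint symbols
      in[3].Add(7);                              // context 3: similar to 2
      std::vector<Histogram> out;
      std::vector<uint32_t> context_map;
      ClusterHistograms(HistogramParams(), in, /*num_contexts=*/in.size(),
                        /*max_histograms=*/2, &out, &context_map);
      // Expected outcome: out.size() <= 2, and context_map maps contexts
      // {0, 1} and {2, 3} onto the same entries.
    }

    }  // namespace jxl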
diff --git a/third_party/jpeg-xl/lib/jxl/enc_coeff_order.cc b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.cc
new file mode 100644
index 000000000000..9399ba253a56
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.cc
@@ -0,0 +1,283 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/lehmer_code.h"
+#include "lib/jxl/modular/encoding/encoding.h"
+#include "lib/jxl/modular/modular_image.h"
+
+namespace jxl {
+
+uint32_t ComputeUsedOrders(const SpeedTier speed,
+                           const AcStrategyImage& ac_strategy,
+                           const Rect& rect) {
+  // Use default orders for small images.
+  if (ac_strategy.xsize() < 5 && ac_strategy.ysize() < 5) return 0;
+
+  // Only uses DCT8 = 0, so bitfield = 1.
+  if (speed == SpeedTier::kFalcon) return 1;
+
+  uint32_t ret = 0;
+  size_t xsize_blocks = rect.xsize();
+  size_t ysize_blocks = rect.ysize();
+  // TODO(veluca): precompute when doing DCT.
+  for (size_t by = 0; by < ysize_blocks; ++by) {
+    AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by);
+    for (size_t bx = 0; bx < xsize_blocks; ++bx) {
+      int ord = kStrategyOrder[acs_row[bx].RawStrategy()];
+      // Do not customize coefficient orders for blocks bigger than 32x32.
+      if (ord > 6) {
+        continue;
+      }
+      ret |= 1u << ord;
+    }
+  }
+  return ret;
+}
+
+void ComputeCoeffOrder(SpeedTier speed, const ACImage& acs,
+                       const AcStrategyImage& ac_strategy,
+                       const FrameDimensions& frame_dim, uint32_t& used_orders,
+                       coeff_order_t* JXL_RESTRICT order) {
+  std::vector<int32_t> num_zeros(kCoeffOrderMaxSize);
+  // If compressing at high speed and only using 8x8 DCTs, only consider a
+  // subset of blocks.
+  double block_fraction = 1.0f;
+  // TODO(veluca): figure out why sampling blocks if non-8x8s are used makes
+  // encoding significantly less dense.
+  if (speed >= SpeedTier::kSquirrel && used_orders == 1) {
+    block_fraction = 0.5f;
+  }
+  // No need to compute number of zero coefficients if all orders are the
+  // default.
+  if (used_orders != 0) {
+    uint64_t threshold =
+        (std::numeric_limits<uint64_t>::max() >> 32) * block_fraction;
+    uint64_t s[2] = {0x94D049BB133111EBull, 0xBF58476D1CE4E5B9ull};
+    // Xorshift128+ adapted from xorshift128+-inl.h
+    auto use_sample = [&]() {
+      auto s1 = s[0];
+      const auto s0 = s[1];
+      const auto bits = s1 + s0;  // b, c
+      s[0] = s0;
+      s1 ^= s1 << 23;
+      s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
+      s[1] = s1;
+      return (bits >> 32) <= threshold;
+    };
+
+    // Count number of zero coefficients, separately for each DCT band.
+    // TODO(veluca): precompute when doing DCT.
+    for (size_t group_index = 0; group_index < frame_dim.num_groups;
+         group_index++) {
+      const size_t gx = group_index % frame_dim.xsize_groups;
+      const size_t gy = group_index / frame_dim.xsize_groups;
+      const Rect rect(gx * kGroupDimInBlocks, gy * kGroupDimInBlocks,
+                      kGroupDimInBlocks, kGroupDimInBlocks,
+                      frame_dim.xsize_blocks, frame_dim.ysize_blocks);
+      ConstACPtr rows[3];
+      ACType type = acs.Type();
+      for (size_t c = 0; c < 3; c++) {
+        rows[c] = acs.PlaneRow(c, group_index, 0);
+      }
+      size_t ac_offset = 0;
+
+      // TODO(veluca): SIMDfy.
+      for (size_t by = 0; by < rect.ysize(); ++by) {
+        AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by);
+        for (size_t bx = 0; bx < rect.xsize(); ++bx) {
+          AcStrategy acs = acs_row[bx];
+          if (!acs.IsFirstBlock()) continue;
+          if (!use_sample()) continue;
+          size_t size = kDCTBlockSize << acs.log2_covered_blocks();
+          for (size_t c = 0; c < 3; ++c) {
+            const size_t order_offset =
+                CoeffOrderOffset(kStrategyOrder[acs.RawStrategy()], c);
+            if (type == ACType::k16) {
+              for (size_t k = 0; k < size; k++) {
+                bool is_zero = rows[c].ptr16[ac_offset + k] == 0;
+                num_zeros[order_offset + k] += is_zero ? 1 : 0;
+              }
+            } else {
+              for (size_t k = 0; k < size; k++) {
+                bool is_zero = rows[c].ptr32[ac_offset + k] == 0;
+                num_zeros[order_offset + k] += is_zero ? 1 : 0;
+              }
+            }
+            // Ensure LLFs are first in the order.
+            size_t cx = acs.covered_blocks_x();
+            size_t cy = acs.covered_blocks_y();
+            CoefficientLayout(&cy, &cx);
+            for (size_t iy = 0; iy < cy; iy++) {
+              for (size_t ix = 0; ix < cx; ix++) {
+                num_zeros[order_offset + iy * kBlockDim * cx + ix] = -1;
+              }
+            }
+          }
+          ac_offset += size;
+        }
+      }
+    }
+  }
+  struct PosAndCount {
+    uint32_t pos;
+    uint32_t count;
+  };
+  auto mem = hwy::AllocateAligned<PosAndCount>(AcStrategy::kMaxCoeffArea);
+
+  uint16_t computed = 0;
+  for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) {
+    uint8_t ord = kStrategyOrder[o];
+    if (computed & (1 << ord)) continue;
+    computed |= 1 << ord;
+    AcStrategy acs = AcStrategy::FromRawStrategy(o);
+    size_t sz = kDCTBlockSize * acs.covered_blocks_x() * acs.covered_blocks_y();
+    // Ensure natural coefficient order is not permuted if the order is
+    // not transmitted.
+    if ((1 << ord) & ~used_orders) {
+      for (size_t c = 0; c < 3; c++) {
+        size_t offset = CoeffOrderOffset(ord, c);
+        JXL_DASSERT(CoeffOrderOffset(ord, c + 1) - offset == sz);
+        SetDefaultOrder(AcStrategy::FromRawStrategy(o), &order[offset]);
+      }
+      continue;
+    }
+    const coeff_order_t* natural_coeff_order = acs.NaturalCoeffOrder();
+
+    bool is_nondefault = false;
+    for (uint8_t c = 0; c < 3; c++) {
+      // Apply zig-zag order.
+      PosAndCount* pos_and_val = mem.get();
+      size_t offset = CoeffOrderOffset(ord, c);
+      JXL_DASSERT(CoeffOrderOffset(ord, c + 1) - offset == sz);
+      float inv_sqrt_sz = 1.0f / std::sqrt(sz);
+      for (size_t i = 0; i < sz; ++i) {
+        size_t pos = natural_coeff_order[i];
+        pos_and_val[i].pos = pos;
+        // We don't care for the exact number -> quantize number of zeros,
+        // to get less permuted order.
+        pos_and_val[i].count = num_zeros[offset + pos] * inv_sqrt_sz + 0.1f;
+      }
+
+      // Stable-sort -> elements with same number of zeros will preserve their
+      // order.
+      auto comparator = [](const PosAndCount& a, const PosAndCount& b) -> bool {
+        return a.count < b.count;
+      };
+      std::stable_sort(pos_and_val, pos_and_val + sz, comparator);
+
+      // Grab indices.
+      for (size_t i = 0; i < sz; ++i) {
+        order[offset + i] = pos_and_val[i].pos;
+        is_nondefault |= natural_coeff_order[i] != pos_and_val[i].pos;
+      }
+    }
+    if (!is_nondefault) {
+      used_orders &= ~(1 << ord);
+    }
+  }
+}
+
+namespace {
+
+void TokenizePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip,
+                         size_t size, std::vector<Token>* tokens) {
+  std::vector<LehmerT> lehmer(size);
+  std::vector<uint32_t> temp(size + 1);
+  ComputeLehmerCode(order, temp.data(), size, lehmer.data());
+  size_t end = size;
+  while (end > skip && lehmer[end - 1] == 0) {
+    --end;
+  }
+  tokens->emplace_back(CoeffOrderContext(size), end - skip);
+  uint32_t last = 0;
+  for (size_t i = skip; i < end; ++i) {
+    tokens->emplace_back(CoeffOrderContext(last), lehmer[i]);
+    last = lehmer[i];
+  }
+}
+
+}  // namespace
+
+void EncodePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip,
+                       size_t size, BitWriter* writer, int layer,
+                       AuxOut* aux_out) {
+  std::vector<std::vector<Token>> tokens(1);
+  TokenizePermutation(order, skip, size, &tokens[0]);
+  std::vector<uint8_t> context_map;
+  EntropyEncodingData codes;
+  BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens,
+                           &codes, &context_map, writer, layer, aux_out);
+  WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out);
+}
+
+namespace {
+void EncodeCoeffOrder(const coeff_order_t* JXL_RESTRICT order, AcStrategy acs,
+                      std::vector<Token>* tokens, coeff_order_t* order_zigzag) {
+  const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y();
+  const size_t size = kDCTBlockSize * llf;
+  const coeff_order_t* natural_coeff_order_lut = acs.NaturalCoeffOrderLut();
+  for (size_t i = 0; i < size; ++i) {
+    order_zigzag[i] = natural_coeff_order_lut[order[i]];
+  }
+  TokenizePermutation(order_zigzag, llf, size, tokens);
+}
+}  // namespace
+
+void EncodeCoeffOrders(uint16_t used_orders,
+                       const coeff_order_t* JXL_RESTRICT order,
+                       BitWriter* writer, size_t layer,
+                       AuxOut* JXL_RESTRICT aux_out) {
+  auto mem = hwy::AllocateAligned<coeff_order_t>(AcStrategy::kMaxCoeffArea);
+  uint16_t computed = 0;
+  std::vector<std::vector<Token>> tokens(1);
+  for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) {
+    uint8_t ord = kStrategyOrder[o];
+    if (computed & (1 << ord)) continue;
+    computed |= 1 << ord;
+    if ((used_orders & (1 << ord)) == 0) continue;
+    AcStrategy acs = AcStrategy::FromRawStrategy(o);
+    for (size_t c = 0; c < 3; c++) {
+      EncodeCoeffOrder(&order[CoeffOrderOffset(ord, c)], acs, &tokens[0],
+                       mem.get());
+    }
+  }
+  // Do not write anything if no order is used.
+  if (used_orders != 0) {
+    std::vector<uint8_t> context_map;
+    EntropyEncodingData codes;
+    BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens,
+                             &codes, &context_map, writer, layer, aux_out);
+    WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out);
+  }
+}
+
+}  // namespace jxl
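For reference (not part of the patch), the Lehmer code emitted by
TokenizePermutation above assigns to each position the number of smaller
elements remaining to its right; the identity permutation is all zeros, which
is why trailing zeros are trimmed before tokenization. A naive O(n^2) sketch
of the quantity that lib/jxl/lehmer_code.h computes more efficiently:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::vector<uint32_t> LehmerCode(const std::vector<uint32_t>& perm) {
      std::vector<uint32_t> lehmer(perm.size());
      for (size_t i = 0; i < perm.size(); ++i) {
        uint32_t smaller_later = 0;
        for (size_t j = i + 1; j < perm.size(); ++j) {
          if (perm[j] < perm[i]) ++smaller_later;
        }
        lehmer[i] = smaller_later;  // e.g. {2, 0, 1, 3} -> {2, 0, 0, 0}
      }
      return lehmer;
    }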
diff --git a/third_party/jpeg-xl/lib/jxl/enc_coeff_order.h b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.h
new file mode 100644
index 000000000000..67692b4f9ead
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.h
@@ -0,0 +1,61 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_COEFF_ORDER_H_
+#define LIB_JXL_ENC_COEFF_ORDER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct_util.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_params.h"
+
+namespace jxl {
+
+// Orders that are actually used in part of image. `rect` is in block units.
+uint32_t ComputeUsedOrders(SpeedTier speed, const AcStrategyImage& ac_strategy,
+                           const Rect& rect);
+
+// Modify zig-zag order, so that DCT bands with more zeros go later.
+// Order of DCT bands with same number of zeros is untouched, so
+// permutation will be cheaper to encode.
+void ComputeCoeffOrder(SpeedTier speed, const ACImage& acs,
+                       const AcStrategyImage& ac_strategy,
+                       const FrameDimensions& frame_dim, uint32_t& used_orders,
+                       coeff_order_t* JXL_RESTRICT order);
+
+void EncodeCoeffOrders(uint16_t used_orders,
+                       const coeff_order_t* JXL_RESTRICT order,
+                       BitWriter* writer, size_t layer,
+                       AuxOut* JXL_RESTRICT aux_out);
+
+// Encoding/decoding of a single permutation. `size`: number of elements in the
+// permutation. `skip`: number of elements to skip from the *beginning* of the
+// permutation.
+void EncodePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip,
+                       size_t size, BitWriter* writer, int layer,
+                       AuxOut* aux_out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_COEFF_ORDER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_color_management.cc b/third_party/jpeg-xl/lib/jxl/enc_color_management.cc
new file mode 100644
index 000000000000..73f89ecd7a26
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_color_management.cc
@@ -0,0 +1,907 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Defined by build system; this avoids IDE warnings. Must come before
+// color_management.h (affects header definitions).
+#ifndef JPEGXL_ENABLE_SKCMS
+#define JPEGXL_ENABLE_SKCMS 0
+#endif
+
+#include "lib/jxl/enc_color_management.h"
+
+#include <math.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <utility>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_color_management.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/field_encodings.h"
+#include "lib/jxl/linalg.h"
+#include "lib/jxl/transfer_functions-inl.h"
+#if JPEGXL_ENABLE_SKCMS
+#include "skcms.h"
+#else  // JPEGXL_ENABLE_SKCMS
+#include "lcms2.h"
+#include "lcms2_plugin.h"
+#endif  // JPEGXL_ENABLE_SKCMS
+
+#define JXL_CMS_VERBOSE 0
+
+// Define these only once. We can't use HWY_ONCE here because it is defined as
+// 1 only on the last pass.
+#ifndef LIB_JXL_ENC_COLOR_MANAGEMENT_CC_
+#define LIB_JXL_ENC_COLOR_MANAGEMENT_CC_
+
+namespace jxl {
+#if JPEGXL_ENABLE_SKCMS
+struct ColorSpaceTransform::SkcmsICC {
+  // Parsed skcms_ICCProfiles retain pointers to the original data.
+  PaddedBytes icc_src_, icc_dst_;
+  skcms_ICCProfile profile_src_, profile_dst_;
+};
+#endif  // JPEGXL_ENABLE_SKCMS
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_COLOR_MANAGEMENT_CC_
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+#if JXL_CMS_VERBOSE >= 2
+const size_t kX = 0;  // pixel index, multiplied by 3 for RGB
+#endif
+
+// xform_src = UndoGammaCompression(buf_src).
+void BeforeTransform(ColorSpaceTransform* t, const float* buf_src,
+                     float* xform_src) {
+  switch (t->preprocess_) {
+    case ExtraTF::kNone:
+      JXL_DASSERT(false);  // unreachable
+      break;
+
+    case ExtraTF::kPQ: {
+      // By default, PQ content has an intensity target of 10000, stored
+      // exactly.
+      HWY_FULL(float) df;
+      const auto multiplier = Set(df, t->intensity_target_ == 10000.f
+                                          ? 1.0f
+                                          : 10000.f / t->intensity_target_);
+      for (size_t i = 0; i < t->buf_src_.xsize(); i += Lanes(df)) {
+        const auto val = Load(df, buf_src + i);
+        const auto result = multiplier * TF_PQ().DisplayFromEncoded(df, val);
+        Store(result, df, xform_src + i);
+      }
+#if JXL_CMS_VERBOSE >= 2
+      printf("pre in %.4f %.4f %.4f undoPQ %.4f %.4f %.4f\n", buf_src[3 * kX],
+             buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX],
+             xform_src[3 * kX + 1], xform_src[3 * kX + 2]);
+#endif
+      break;
+    }
+
+    case ExtraTF::kHLG:
+      for (size_t i = 0; i < t->buf_src_.xsize(); ++i) {
+        xform_src[i] = static_cast<float>(
+            TF_HLG().DisplayFromEncoded(static_cast<double>(buf_src[i])));
+      }
+#if JXL_CMS_VERBOSE >= 2
+      printf("pre in %.4f %.4f %.4f undoHLG %.4f %.4f %.4f\n", buf_src[3 * kX],
+             buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX],
+             xform_src[3 * kX + 1], xform_src[3 * kX + 2]);
+#endif
+      break;
+
+    case ExtraTF::kSRGB:
+      HWY_FULL(float) df;
+      for (size_t i = 0; i < t->buf_src_.xsize(); i += Lanes(df)) {
+        const auto val = Load(df, buf_src + i);
+        const auto result = TF_SRGB().DisplayFromEncoded(val);
+        Store(result, df, xform_src + i);
+      }
+#if JXL_CMS_VERBOSE >= 2
+      printf("pre in %.4f %.4f %.4f undoSRGB %.4f %.4f %.4f\n", buf_src[3 * kX],
+             buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX],
+             xform_src[3 * kX + 1], xform_src[3 * kX + 2]);
+#endif
+      break;
+  }
+}
+
+// Applies gamma compression in-place.
+void AfterTransform(ColorSpaceTransform* t, float* JXL_RESTRICT buf_dst) {
+  switch (t->postprocess_) {
+    case ExtraTF::kNone:
+      JXL_DASSERT(false);  // unreachable
+      break;
+    case ExtraTF::kPQ: {
+      HWY_FULL(float) df;
+      const auto multiplier = Set(df, t->intensity_target_ == 10000.f
+                                          ? 1.0f
+                                          : t->intensity_target_ * 1e-4f);
+      for (size_t i = 0; i < t->buf_dst_.xsize(); i += Lanes(df)) {
+        const auto val = Load(df, buf_dst + i);
+        const auto result = TF_PQ().EncodedFromDisplay(df, multiplier * val);
+        Store(result, df, buf_dst + i);
+      }
+#if JXL_CMS_VERBOSE >= 2
+      printf("after PQ enc %.4f %.4f %.4f\n", buf_dst[3 * kX],
+             buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]);
+#endif
+      break;
+    }
+    case ExtraTF::kHLG:
+      for (size_t i = 0; i < t->buf_dst_.xsize(); ++i) {
+        buf_dst[i] = static_cast<float>(
+            TF_HLG().EncodedFromDisplay(static_cast<double>(buf_dst[i])));
+      }
+#if JXL_CMS_VERBOSE >= 2
+      printf("after HLG enc %.4f %.4f %.4f\n", buf_dst[3 * kX],
+             buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]);
+#endif
+      break;
+    case ExtraTF::kSRGB:
+      HWY_FULL(float) df;
+      for (size_t i = 0; i < t->buf_dst_.xsize(); i += Lanes(df)) {
+        const auto val = Load(df, buf_dst + i);
+        const auto result =
+            TF_SRGB().EncodedFromDisplay(HWY_FULL(float)(), val);
+        Store(result, df, buf_dst + i);
+      }
+#if JXL_CMS_VERBOSE >= 2
+      printf("after SRGB enc %.4f %.4f %.4f\n", buf_dst[3 * kX],
+             buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]);
+#endif
+      break;
+  }
+}
+
+void DoColorSpaceTransform(ColorSpaceTransform* t, const size_t thread,
+                           const float* buf_src, float* buf_dst) {
+  // No lock needed.
+
+  float* xform_src = const_cast<float*>(buf_src);  // Read-only.
+  if (t->preprocess_ != ExtraTF::kNone) {
+    xform_src = t->buf_src_.Row(thread);  // Writable buffer.
+    BeforeTransform(t, buf_src, xform_src);
+  }
+
+#if JXL_CMS_VERBOSE >= 2
+  // Save inputs for printing before in-place transforms overwrite them.
+  const float in0 = xform_src[3 * kX + 0];
+  const float in1 = xform_src[3 * kX + 1];
+  const float in2 = xform_src[3 * kX + 2];
+#endif
+
+  if (t->skip_lcms_) {
+    if (buf_dst != xform_src) {
+      memcpy(buf_dst, xform_src, t->buf_dst_.xsize() * sizeof(*buf_dst));
+    }  // else: in-place, no need to copy
+  } else {
+#if JPEGXL_ENABLE_SKCMS
+    JXL_CHECK(skcms_Transform(
+        xform_src, skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque,
+        &t->skcms_icc_->profile_src_, buf_dst, skcms_PixelFormat_RGB_fff,
+        skcms_AlphaFormat_Opaque, &t->skcms_icc_->profile_dst_, t->xsize_));
+#else   // JPEGXL_ENABLE_SKCMS
+    JXL_DASSERT(thread < t->transforms_.size());
+    cmsHTRANSFORM xform = t->transforms_[thread];
+    cmsDoTransform(xform, xform_src, buf_dst,
+                   static_cast<cmsUInt32Number>(t->xsize_));
+#endif  // JPEGXL_ENABLE_SKCMS
+  }
+#if JXL_CMS_VERBOSE >= 2
+  printf("xform skip%d: %.4f %.4f %.4f (%p) -> (%p) %.4f %.4f %.4f\n",
+         t->skip_lcms_, in0, in1, in2, xform_src, buf_dst, buf_dst[3 * kX],
+         buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]);
+#endif
+
+  if (t->postprocess_ != ExtraTF::kNone) {
+    AfterTransform(t, buf_dst);
+  }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(DoColorSpaceTransform);
+void DoColorSpaceTransform(ColorSpaceTransform* t, size_t thread,
+                           const float* buf_src, float* buf_dst) {
+  return HWY_DYNAMIC_DISPATCH(DoColorSpaceTransform)(t, thread, buf_src,
+                                                     buf_dst);
+}
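A hypothetical call sequence (not part of the patch) for the dispatcher above
together with the ColorSpaceTransform class from enc_color_management.h: Init
once, then feed interleaved rows through the per-thread buffers. The 255.0f
intensity target is an arbitrary SDR-ish choice for this sketch.

    #include <cstring>

    #include "lib/jxl/enc_color_management.h"

    namespace jxl {

    Status ConvertRow(const ColorEncoding& c_src, const ColorEncoding& c_dst,
                      const float* interleaved_src, float* interleaved_dst,
                      size_t xsize) {
      ColorSpaceTransform t;
      JXL_RETURN_IF_ERROR(t.Init(c_src, c_dst, /*intensity_target=*/255.0f,
                                 xsize, /*num_threads=*/1));
      float* src = t.BufSrc(0);
      memcpy(src, interleaved_src, xsize * c_src.Channels() * sizeof(float));
      DoColorSpaceTransform(&t, /*thread=*/0, src, t.BufDst(0));
      memcpy(interleaved_dst, t.BufDst(0),
             xsize * c_dst.Channels() * sizeof(float));
      return true;
    }

    }  // namespace jxl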
+
+namespace {
+
+// Define to 1 on OS X as a workaround for older LCMS lacking MD5.
+#define JXL_CMS_OLD_VERSION 0
+
+// cms functions (even *THR) are not thread-safe, except cmsDoTransform.
+// To ensure all functions are covered without frequent lock-taking nor risk of
+// recursive lock, we lock in the top-level APIs.
+static std::mutex& LcmsMutex() {
+  static std::mutex m;
+  return m;
+}
+
+#if JPEGXL_ENABLE_SKCMS
+
+JXL_MUST_USE_RESULT CIExy CIExyFromXYZ(const float XYZ[3]) {
+  const float factor = 1.f / (XYZ[0] + XYZ[1] + XYZ[2]);
+  CIExy xy;
+  xy.x = XYZ[0] * factor;
+  xy.y = XYZ[1] * factor;
+  return xy;
+}
+
+#else  // JPEGXL_ENABLE_SKCMS
+// (LCMS interface requires xyY but we omit the Y for white points/primaries.)
+
+JXL_MUST_USE_RESULT CIExy CIExyFromxyY(const cmsCIExyY& xyY) {
+  CIExy xy;
+  xy.x = xyY.x;
+  xy.y = xyY.y;
+  return xy;
+}
+
+JXL_MUST_USE_RESULT CIExy CIExyFromXYZ(const cmsCIEXYZ& XYZ) {
+  cmsCIExyY xyY;
+  cmsXYZ2xyY(/*Dest=*/&xyY, /*Source=*/&XYZ);
+  return CIExyFromxyY(xyY);
+}
+
+JXL_MUST_USE_RESULT cmsCIEXYZ D50_XYZ() {
+  // Quantized D50 as stored in ICC profiles.
+  return {0.96420288, 1.0, 0.82490540};
+}
+
+JXL_MUST_USE_RESULT cmsCIExyY xyYFromCIExy(const CIExy& xy) {
+  const cmsCIExyY xyY = {xy.x, xy.y, 1.0};
+  return xyY;
+}
+
+// RAII
+
+struct ProfileDeleter {
+  void operator()(void* p) { cmsCloseProfile(p); }
+};
+using Profile = std::unique_ptr<void, ProfileDeleter>;
+
+struct TransformDeleter {
+  void operator()(void* p) { cmsDeleteTransform(p); }
+};
+using Transform = std::unique_ptr<void, TransformDeleter>;
+
+struct CurveDeleter {
+  void operator()(cmsToneCurve* p) { cmsFreeToneCurve(p); }
+};
+using Curve = std::unique_ptr<cmsToneCurve, CurveDeleter>;
+
+Status CreateProfileXYZ(const cmsContext context,
+                        Profile* JXL_RESTRICT profile) {
+  profile->reset(cmsCreateXYZProfileTHR(context));
+  if (profile->get() == nullptr) return JXL_FAILURE("Failed to create XYZ");
+  return true;
+}
+
+#endif  // !JPEGXL_ENABLE_SKCMS
+
+#if JPEGXL_ENABLE_SKCMS
+// IMPORTANT: icc must outlive profile.
+Status DecodeProfile(const PaddedBytes& icc, skcms_ICCProfile* const profile) {
+  if (!skcms_Parse(icc.data(), icc.size(), profile)) {
+    return JXL_FAILURE("Failed to parse ICC profile with %zu bytes",
+                       icc.size());
+  }
+  return true;
+}
+#else  // JPEGXL_ENABLE_SKCMS
+Status DecodeProfile(const cmsContext context, const PaddedBytes& icc,
+                     Profile* profile) {
+  profile->reset(cmsOpenProfileFromMemTHR(context, icc.data(), icc.size()));
+  if (profile->get() == nullptr) {
+    return JXL_FAILURE("Failed to decode profile");
+  }
+
+  // WARNING: due to the LCMS MD5 issue mentioned above, many existing
+  // profiles have incorrect MD5, so do not even bother checking them nor
+  // generating warning clutter.
+
+  return true;
+}
+#endif  // JPEGXL_ENABLE_SKCMS
+
+#if JPEGXL_ENABLE_SKCMS
+
+ColorSpace ColorSpaceFromProfile(const skcms_ICCProfile& profile) {
+  switch (profile.data_color_space) {
+    case skcms_Signature_RGB:
+      return ColorSpace::kRGB;
+    case skcms_Signature_Gray:
+      return ColorSpace::kGray;
+    default:
+      return ColorSpace::kUnknown;
+  }
+}
+
+// "profile1" is pre-decoded to save time in DetectTransferFunction.
+Status ProfileEquivalentToICC(const skcms_ICCProfile& profile1,
+                              const PaddedBytes& icc) {
+  skcms_ICCProfile profile2;
+  JXL_RETURN_IF_ERROR(skcms_Parse(icc.data(), icc.size(), &profile2));
+  return skcms_ApproximatelyEqualProfiles(&profile1, &profile2);
+}
+
+// vector_out := matmul(matrix, vector_in)
+void MatrixProduct(const skcms_Matrix3x3& matrix, const float vector_in[3],
+                   float vector_out[3]) {
+  for (int i = 0; i < 3; ++i) {
+    vector_out[i] = 0;
+    for (int j = 0; j < 3; ++j) {
+      vector_out[i] += matrix.vals[i][j] * vector_in[j];
+    }
+  }
+}
+
+// Returns white point that was specified when creating the profile.
+JXL_MUST_USE_RESULT Status UnadaptedWhitePoint(const skcms_ICCProfile& profile,
+                                               CIExy* out) {
+  float media_white_point_XYZ[3];
+  if (!skcms_GetWTPT(&profile, media_white_point_XYZ)) {
+    return JXL_FAILURE("ICC profile does not contain WhitePoint tag");
+  }
+  skcms_Matrix3x3 CHAD;
+  if (!skcms_GetCHAD(&profile, &CHAD)) {
+    // If there is no chromatic adaptation matrix, it means that the white
+    // point is already unadapted.
+    *out = CIExyFromXYZ(media_white_point_XYZ);
+    return true;
+  }
+  // Otherwise, it has been adapted to the PCS white point using said matrix,
+  // and the adaptation needs to be undone.
+  skcms_Matrix3x3 inverse_CHAD;
+  if (!skcms_Matrix3x3_invert(&CHAD, &inverse_CHAD)) {
+    return JXL_FAILURE("Non-invertible ChromaticAdaptation matrix");
+  }
+  float unadapted_white_point_XYZ[3];
+  MatrixProduct(inverse_CHAD, media_white_point_XYZ,
+                unadapted_white_point_XYZ);
+  *out = CIExyFromXYZ(unadapted_white_point_XYZ);
+  return true;
+}
+
+Status IdentifyPrimaries(const skcms_ICCProfile& profile,
+                         const CIExy& wp_unadapted, ColorEncoding* c) {
+  if (!c->HasPrimaries()) return true;
+
+  skcms_Matrix3x3 CHAD, inverse_CHAD;
+  if (skcms_GetCHAD(&profile, &CHAD)) {
+    JXL_RETURN_IF_ERROR(skcms_Matrix3x3_invert(&CHAD, &inverse_CHAD));
+  } else {
+    static constexpr skcms_Matrix3x3 kLMSFromXYZ = {
+        {{0.8951, 0.2664, -0.1614},
+         {-0.7502, 1.7135, 0.0367},
+         {0.0389, -0.0685, 1.0296}}};
+    static constexpr skcms_Matrix3x3 kXYZFromLMS = {
+        {{0.9869929, -0.1470543, 0.1599627},
+         {0.4323053, 0.5183603, 0.0492912},
+         {-0.0085287, 0.0400428, 0.9684867}}};
+    static constexpr float kWpD50XYZ[3] = {0.96420288, 1.0, 0.82490540};
+    float wp_unadapted_XYZ[3];
+    JXL_RETURN_IF_ERROR(CIEXYZFromWhiteCIExy(wp_unadapted, wp_unadapted_XYZ));
+    float wp_D50_LMS[3], wp_unadapted_LMS[3];
+    MatrixProduct(kLMSFromXYZ, kWpD50XYZ, wp_D50_LMS);
+    MatrixProduct(kLMSFromXYZ, wp_unadapted_XYZ, wp_unadapted_LMS);
+    inverse_CHAD = {{{wp_unadapted_LMS[0] / wp_D50_LMS[0], 0, 0},
+                     {0, wp_unadapted_LMS[1] / wp_D50_LMS[1], 0},
+                     {0, 0, wp_unadapted_LMS[2] / wp_D50_LMS[2]}}};
+    inverse_CHAD = skcms_Matrix3x3_concat(&kXYZFromLMS, &inverse_CHAD);
+    inverse_CHAD = skcms_Matrix3x3_concat(&inverse_CHAD, &kLMSFromXYZ);
+  }
+
+  float XYZ[3];
+  PrimariesCIExy primaries;
+  CIExy* const chromaticities[] = {&primaries.r, &primaries.g, &primaries.b};
+  for (int i = 0; i < 3; ++i) {
+    float RGB[3] = {};
+    RGB[i] = 1;
+    skcms_Transform(RGB, skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque,
+                    &profile, XYZ, skcms_PixelFormat_RGB_fff,
+                    skcms_AlphaFormat_Opaque, skcms_XYZD50_profile(), 1);
+    float unadapted_XYZ[3];
+    MatrixProduct(inverse_CHAD, XYZ, unadapted_XYZ);
+    *chromaticities[i] = CIExyFromXYZ(unadapted_XYZ);
+  }
+  return c->SetPrimaries(primaries);
+}
+
+void DetectTransferFunction(const skcms_ICCProfile& profile,
+                            ColorEncoding* JXL_RESTRICT c) {
+  if (c->tf.SetImplicit()) return;
+
+  for (TransferFunction tf : Values<TransferFunction>()) {
+    // Can only create profile from known transfer function.
+    if (tf == TransferFunction::kUnknown) continue;
+
+    c->tf.SetTransferFunction(tf);
+
+    skcms_ICCProfile profile_test;
+    PaddedBytes bytes;
+    if (MaybeCreateProfile(*c, &bytes) &&
+        DecodeProfile(bytes, &profile_test) &&
+        skcms_ApproximatelyEqualProfiles(&profile, &profile_test)) {
+      return;
+    }
+  }
+
+  c->tf.SetTransferFunction(TransferFunction::kUnknown);
+}
+
+#else  // JPEGXL_ENABLE_SKCMS
+
+uint32_t Type32(const ColorEncoding& c) {
+  if (c.IsGray()) return TYPE_GRAY_FLT;
+  return TYPE_RGB_FLT;
+}
+
+uint32_t Type64(const ColorEncoding& c) {
+  if (c.IsGray()) return TYPE_GRAY_DBL;
+  return TYPE_RGB_DBL;
+}
+
+ColorSpace ColorSpaceFromProfile(const Profile& profile) {
+  switch (cmsGetColorSpace(profile.get())) {
+    case cmsSigRgbData:
+      return ColorSpace::kRGB;
+    case cmsSigGrayData:
+      return ColorSpace::kGray;
+    default:
+      return ColorSpace::kUnknown;
+  }
+}
+
+// "profile1" is pre-decoded to save time in DetectTransferFunction.
+Status ProfileEquivalentToICC(const cmsContext context,
+                              const Profile& profile1, const PaddedBytes& icc,
+                              const ColorEncoding& c) {
+  const uint32_t type_src = Type64(c);
+
+  Profile profile2;
+  JXL_RETURN_IF_ERROR(DecodeProfile(context, icc, &profile2));
+
+  Profile profile_xyz;
+  JXL_RETURN_IF_ERROR(CreateProfileXYZ(context, &profile_xyz));
+
+  const uint32_t intent = INTENT_RELATIVE_COLORIMETRIC;
+  const uint32_t flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_BLACKPOINTCOMPENSATION |
+                         cmsFLAGS_HIGHRESPRECALC;
+  Transform xform1(cmsCreateTransformTHR(context, profile1.get(), type_src,
+                                         profile_xyz.get(), TYPE_XYZ_DBL,
+                                         intent, flags));
+  Transform xform2(cmsCreateTransformTHR(context, profile2.get(), type_src,
+                                         profile_xyz.get(), TYPE_XYZ_DBL,
+                                         intent, flags));
+  if (xform1 == nullptr || xform2 == nullptr) {
+    return JXL_FAILURE("Failed to create transform");
+  }
+
+  double in[3];
+  double out1[3];
+  double out2[3];
+
+  // Uniformly spaced samples from very dark to almost fully bright.
+  const double init = 1E-3;
+  const double step = 0.2;
+
+  if (c.IsGray()) {
+    // Finer sampling and replicate each component.
+    for (in[0] = init; in[0] < 1.0; in[0] += step / 8) {
+      cmsDoTransform(xform1.get(), in, out1, 1);
+      cmsDoTransform(xform2.get(), in, out2, 1);
+      if (!ApproxEq(out1[0], out2[0], 2E-4)) {
+        return false;
+      }
+    }
+  } else {
+    for (in[0] = init; in[0] < 1.0; in[0] += step) {
+      for (in[1] = init; in[1] < 1.0; in[1] += step) {
+        for (in[2] = init; in[2] < 1.0; in[2] += step) {
+          cmsDoTransform(xform1.get(), in, out1, 1);
+          cmsDoTransform(xform2.get(), in, out2, 1);
+          for (size_t i = 0; i < 3; ++i) {
+            if (!ApproxEq(out1[i], out2[i], 2E-4)) {
+              return false;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
+// Returns white point that was specified when creating the profile.
+// NOTE: we can't just use cmsSigMediaWhitePointTag because its interpretation
+// differs between ICC versions.
+JXL_MUST_USE_RESULT cmsCIEXYZ UnadaptedWhitePoint(const cmsContext context,
+                                                  const Profile& profile,
+                                                  const ColorEncoding& c) {
+  cmsCIEXYZ XYZ = {1.0, 1.0, 1.0};
+
+  Profile profile_xyz;
+  if (!CreateProfileXYZ(context, &profile_xyz)) return XYZ;
+  // Array arguments are one per profile.
+  cmsHPROFILE profiles[2] = {profile.get(), profile_xyz.get()};
+  // Leave white point unchanged - that is what we're trying to extract.
+  cmsUInt32Number intents[2] = {INTENT_ABSOLUTE_COLORIMETRIC,
+                                INTENT_ABSOLUTE_COLORIMETRIC};
+  cmsBool black_compensation[2] = {0, 0};
+  cmsFloat64Number adaption[2] = {0.0, 0.0};
+  // Only transforming a single pixel, so skip expensive optimizations.
+  cmsUInt32Number flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_HIGHRESPRECALC;
+  Transform xform(cmsCreateExtendedTransform(
+      context, 2, profiles, black_compensation, intents, adaption, nullptr, 0,
+      Type64(c), TYPE_XYZ_DBL, flags));
+  if (!xform) return XYZ;  // TODO(lode): return error
+
+  // xy are relative, so magnitude does not matter if we ignore output Y.
+  const cmsFloat64Number in[3] = {1.0, 1.0, 1.0};
+  cmsDoTransform(xform.get(), in, &XYZ.X, 1);
+  return XYZ;
+}
+
+Status IdentifyPrimaries(const Profile& profile, const cmsCIEXYZ& wp_unadapted,
+                         ColorEncoding* c) {
+  if (!c->HasPrimaries()) return true;
+  if (ColorSpaceFromProfile(profile) == ColorSpace::kUnknown) return true;
+
+  // These were adapted to the profile illuminant before storing in the
+  // profile.
+  const cmsCIEXYZ* adapted_r = static_cast<const cmsCIEXYZ*>(
+      cmsReadTag(profile.get(), cmsSigRedColorantTag));
+  const cmsCIEXYZ* adapted_g = static_cast<const cmsCIEXYZ*>(
+      cmsReadTag(profile.get(), cmsSigGreenColorantTag));
+  const cmsCIEXYZ* adapted_b = static_cast<const cmsCIEXYZ*>(
+      cmsReadTag(profile.get(), cmsSigBlueColorantTag));
+  if (adapted_r == nullptr || adapted_g == nullptr || adapted_b == nullptr) {
+    return JXL_FAILURE("Failed to retrieve colorants");
+  }
+
+  // TODO(janwas): no longer assume Bradford and D50.
+  // Undo the chromatic adaptation.
+  const cmsCIEXYZ d50 = D50_XYZ();
+
+  cmsCIEXYZ r, g, b;
+  cmsAdaptToIlluminant(&r, &d50, &wp_unadapted, adapted_r);
+  cmsAdaptToIlluminant(&g, &d50, &wp_unadapted, adapted_g);
+  cmsAdaptToIlluminant(&b, &d50, &wp_unadapted, adapted_b);
+
+  const PrimariesCIExy rgb = {CIExyFromXYZ(r), CIExyFromXYZ(g),
+                              CIExyFromXYZ(b)};
+  return c->SetPrimaries(rgb);
+}
+
+void DetectTransferFunction(const cmsContext context, const Profile& profile,
+                            ColorEncoding* JXL_RESTRICT c) {
+  if (c->tf.SetImplicit()) return;
+
+  for (TransferFunction tf : Values<TransferFunction>()) {
+    // Can only create profile from known transfer function.
+    if (tf == TransferFunction::kUnknown) continue;
+
+    c->tf.SetTransferFunction(tf);
+
+    PaddedBytes icc_test;
+    if (MaybeCreateProfile(*c, &icc_test) &&
+        ProfileEquivalentToICC(context, profile, icc_test, *c)) {
+      return;
+    }
+  }
+
+  c->tf.SetTransferFunction(TransferFunction::kUnknown);
+}
+
+void ErrorHandler(cmsContext context, cmsUInt32Number code, const char* text) {
+  JXL_WARNING("LCMS error %u: %s", code, text);
+}
+
+// Returns a context for the current thread, creating it if necessary.
+cmsContext GetContext() {
+  static thread_local void* context_;
+  if (context_ == nullptr) {
+    context_ = cmsCreateContext(nullptr, nullptr);
+    JXL_ASSERT(context_ != nullptr);
+
+    cmsSetLogErrorHandlerTHR(static_cast<cmsContext>(context_),
+                             &ErrorHandler);
+  }
+  return static_cast<cmsContext>(context_);
+}
+
+#endif  // JPEGXL_ENABLE_SKCMS
+
+}  // namespace
+
+// All functions that call lcms directly (except ColorSpaceTransform::Run) must
+// lock LcmsMutex().
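A standalone sketch (not part of the patch) of the header peek that
SetFieldsFromICC performs below: the ICC rendering intent is stored as a
big-endian 32-bit integer at byte offset 64 of the profile header, and only
values 0..3 are valid.

    #include <cstddef>
    #include <cstdint>

    bool ReadRenderingIntent(const uint8_t* icc, size_t size,
                             uint32_t* intent) {
      if (size < 68) return false;  // the field spans bytes 64..67
      *intent = (uint32_t(icc[64]) << 24) | (uint32_t(icc[65]) << 16) |
                (uint32_t(icc[66]) << 8) | uint32_t(icc[67]);
      return *intent <= 3;  // perceptual, relative, saturation, absolute
    }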
+
+Status ColorEncoding::SetFieldsFromICC() {
+  // In case parsing fails, mark the ColorEncoding as invalid.
+  SetColorSpace(ColorSpace::kUnknown);
+  tf.SetTransferFunction(TransferFunction::kUnknown);
+
+  if (icc_.empty()) return JXL_FAILURE("Empty ICC profile");
+
+#if JPEGXL_ENABLE_SKCMS
+  if (icc_.size() < 128) {
+    return JXL_FAILURE("ICC file too small");
+  }
+
+  skcms_ICCProfile profile;
+  JXL_RETURN_IF_ERROR(skcms_Parse(icc_.data(), icc_.size(), &profile));
+
+  // skcms does not return the rendering intent, so get it from the file. It
+  // is encoded as a big-endian 32-bit integer in bytes 64..67.
+  uint32_t rendering_intent32 = icc_[67];
+  if (rendering_intent32 > 3 || icc_[64] != 0 || icc_[65] != 0 ||
+      icc_[66] != 0) {
+    return JXL_FAILURE("Invalid rendering intent %u\n", rendering_intent32);
+  }
+
+  SetColorSpace(ColorSpaceFromProfile(profile));
+
+  CIExy wp_unadapted;
+  JXL_RETURN_IF_ERROR(UnadaptedWhitePoint(profile, &wp_unadapted));
+  JXL_RETURN_IF_ERROR(SetWhitePoint(wp_unadapted));
+
+  // Relies on color_space.
+  JXL_RETURN_IF_ERROR(IdentifyPrimaries(profile, wp_unadapted, this));
+
+  // Relies on color_space/white point/primaries being set already.
+  DetectTransferFunction(profile, this);
+  // ICC and RenderingIntent have the same values (0..3).
+  rendering_intent = static_cast<RenderingIntent>(rendering_intent32);
+#else  // JPEGXL_ENABLE_SKCMS
+
+  std::lock_guard<std::mutex> guard(LcmsMutex());
+  const cmsContext context = GetContext();
+
+  Profile profile;
+  JXL_RETURN_IF_ERROR(DecodeProfile(context, icc_, &profile));
+
+  const cmsUInt32Number rendering_intent32 =
+      cmsGetHeaderRenderingIntent(profile.get());
+  if (rendering_intent32 > 3) {
+    return JXL_FAILURE("Invalid rendering intent %u\n", rendering_intent32);
+  }
+
+  SetColorSpace(ColorSpaceFromProfile(profile));
+
+  const cmsCIEXYZ wp_unadapted = UnadaptedWhitePoint(context, profile, *this);
+  JXL_RETURN_IF_ERROR(SetWhitePoint(CIExyFromXYZ(wp_unadapted)));
+
+  // Relies on color_space.
+  JXL_RETURN_IF_ERROR(IdentifyPrimaries(profile, wp_unadapted, this));
+
+  // Relies on color_space/white point/primaries being set already.
+  DetectTransferFunction(context, profile, this);
+
+  // ICC and RenderingIntent have the same values (0..3).
+  rendering_intent = static_cast<RenderingIntent>(rendering_intent32);
+#endif  // JPEGXL_ENABLE_SKCMS
+
+  return true;
+}
+
+void ColorEncoding::DecideIfWantICC() {
+  PaddedBytes icc_new;
+  bool equivalent;
+#if JPEGXL_ENABLE_SKCMS
+  skcms_ICCProfile profile;
+  if (!DecodeProfile(ICC(), &profile)) return;
+  if (!MaybeCreateProfile(*this, &icc_new)) return;
+  equivalent = ProfileEquivalentToICC(profile, icc_new);
+#else   // JPEGXL_ENABLE_SKCMS
+  const cmsContext context = GetContext();
+  Profile profile;
+  if (!DecodeProfile(context, ICC(), &profile)) return;
+  if (!MaybeCreateProfile(*this, &icc_new)) return;
+  equivalent = ProfileEquivalentToICC(context, profile, icc_new, *this);
+#endif  // JPEGXL_ENABLE_SKCMS
+
+  // Successfully created a profile => reconstruction should be equivalent.
+  JXL_ASSERT(equivalent);
+  want_icc_ = false;
+}
+
+ColorSpaceTransform::~ColorSpaceTransform() {
+#if !JPEGXL_ENABLE_SKCMS
+  std::lock_guard<std::mutex> guard(LcmsMutex());
+  for (void* p : transforms_) {
+    TransformDeleter()(p);
+  }
+#endif
+}
+
+ColorSpaceTransform::ColorSpaceTransform()
+#if JPEGXL_ENABLE_SKCMS
+    : skcms_icc_(new SkcmsICC())
+#endif  // JPEGXL_ENABLE_SKCMS
+{
+}
+
+Status ColorSpaceTransform::Init(const ColorEncoding& c_src,
+                                 const ColorEncoding& c_dst,
+                                 float intensity_target, size_t xsize,
+                                 const size_t num_threads) {
+  std::lock_guard<std::mutex> guard(LcmsMutex());
+#if JXL_CMS_VERBOSE
+  printf("%s -> %s\n", Description(c_src).c_str(), Description(c_dst).c_str());
+#endif
+
+#if JPEGXL_ENABLE_SKCMS
+  skcms_icc_->icc_src_ = c_src.ICC();
+  skcms_icc_->icc_dst_ = c_dst.ICC();
+  JXL_RETURN_IF_ERROR(
+      DecodeProfile(skcms_icc_->icc_src_, &skcms_icc_->profile_src_));
+  JXL_RETURN_IF_ERROR(
+      DecodeProfile(skcms_icc_->icc_dst_, &skcms_icc_->profile_dst_));
+#else   // JPEGXL_ENABLE_SKCMS
+  const cmsContext context = GetContext();
+  Profile profile_src, profile_dst;
+  JXL_RETURN_IF_ERROR(DecodeProfile(context, c_src.ICC(), &profile_src));
+  JXL_RETURN_IF_ERROR(DecodeProfile(context, c_dst.ICC(), &profile_dst));
+#endif  // JPEGXL_ENABLE_SKCMS
+
+  skip_lcms_ = false;
+  if (c_src.SameColorEncoding(c_dst)) {
+    skip_lcms_ = true;
+#if JXL_CMS_VERBOSE
+    printf("Skip CMS\n");
+#endif
+  }
+
+  // Special-case for BT.2100 HLG/PQ and SRGB <=> linear:
+  const bool src_linear = c_src.tf.IsLinear();
+  const bool dst_linear = c_dst.tf.IsLinear();
+  if (((c_src.tf.IsPQ() || c_src.tf.IsHLG()) && dst_linear) ||
+      ((c_dst.tf.IsPQ() || c_dst.tf.IsHLG()) && src_linear) ||
+      ((c_src.tf.IsPQ() != c_dst.tf.IsPQ()) && intensity_target_ != 10000) ||
+      (c_src.tf.IsSRGB() && dst_linear) || (c_dst.tf.IsSRGB() && src_linear)) {
+    // Construct new profiles as if the data were already/still linear.
+    ColorEncoding c_linear_src = c_src;
+    ColorEncoding c_linear_dst = c_dst;
+    c_linear_src.tf.SetTransferFunction(TransferFunction::kLinear);
+    c_linear_dst.tf.SetTransferFunction(TransferFunction::kLinear);
+    PaddedBytes icc_src, icc_dst;
+#if JPEGXL_ENABLE_SKCMS
+    skcms_ICCProfile new_src, new_dst;
+#else   // JPEGXL_ENABLE_SKCMS
+    Profile new_src, new_dst;
+#endif  // JPEGXL_ENABLE_SKCMS
+    // Only enable ExtraTF if profile creation succeeded.
+    if (MaybeCreateProfile(c_linear_src, &icc_src) &&
+        MaybeCreateProfile(c_linear_dst, &icc_dst) &&
+#if JPEGXL_ENABLE_SKCMS
+        DecodeProfile(icc_src, &new_src) && DecodeProfile(icc_dst, &new_dst)) {
+#else   // JPEGXL_ENABLE_SKCMS
+        DecodeProfile(context, icc_src, &new_src) &&
+        DecodeProfile(context, icc_dst, &new_dst)) {
+#endif  // JPEGXL_ENABLE_SKCMS
+      if (c_src.SameColorSpace(c_dst)) {
+        skip_lcms_ = true;
+      }
+#if JXL_CMS_VERBOSE
+      printf("Special linear <-> HLG/PQ/sRGB; skip=%d\n", skip_lcms_);
+#endif
+#if JPEGXL_ENABLE_SKCMS
+      skcms_icc_->icc_src_ = PaddedBytes();
+      skcms_icc_->profile_src_ = new_src;
+      skcms_icc_->icc_dst_ = PaddedBytes();
+      skcms_icc_->profile_dst_ = new_dst;
+#else   // JPEGXL_ENABLE_SKCMS
+      profile_src.swap(new_src);
+      profile_dst.swap(new_dst);
+#endif  // JPEGXL_ENABLE_SKCMS
+      if (!c_src.tf.IsLinear()) {
+        preprocess_ = c_src.tf.IsSRGB()
+                          ? ExtraTF::kSRGB
+                          : (c_src.tf.IsPQ() ? ExtraTF::kPQ : ExtraTF::kHLG);
+      }
+      if (!c_dst.tf.IsLinear()) {
+        postprocess_ = c_dst.tf.IsSRGB()
+                           ? ExtraTF::kSRGB
+                           : (c_dst.tf.IsPQ() ? ExtraTF::kPQ : ExtraTF::kHLG);
+      }
+    } else {
+      JXL_WARNING("Failed to create extra linear profiles");
+    }
+  }
+
+#if JPEGXL_ENABLE_SKCMS
+  if (!skcms_MakeUsableAsDestination(&skcms_icc_->profile_dst_)) {
+    return JXL_FAILURE(
+        "Failed to make %s usable as a color transform destination",
+        Description(c_dst).c_str());
+  }
+#endif  // JPEGXL_ENABLE_SKCMS
+
+  // Not including alpha channel (copied separately).
+  const size_t channels_src = c_src.Channels();
+  const size_t channels_dst = c_dst.Channels();
+  JXL_CHECK(channels_src == channels_dst);
+#if JXL_CMS_VERBOSE
+  printf("Channels: %zu; Threads: %zu\n", channels_src, num_threads);
+#endif
+
+#if !JPEGXL_ENABLE_SKCMS
+  // Type includes color space (XYZ vs RGB), so can be different.
+  const uint32_t type_src = Type32(c_src);
+  const uint32_t type_dst = Type32(c_dst);
+  transforms_.clear();
+  for (size_t i = 0; i < num_threads; ++i) {
+    const uint32_t intent = static_cast<uint32_t>(c_dst.rendering_intent);
+    const uint32_t flags =
+        cmsFLAGS_BLACKPOINTCOMPENSATION | cmsFLAGS_HIGHRESPRECALC;
+    // NOTE: we're using the current thread's context and assuming all state
+    // modified by cmsDoTransform resides in the transform, not the context.
+    transforms_.emplace_back(cmsCreateTransformTHR(context, profile_src.get(),
+                                                   type_src, profile_dst.get(),
+                                                   type_dst, intent, flags));
+    if (transforms_.back() == nullptr) {
+      return JXL_FAILURE("Failed to create transform");
+    }
+  }
+#endif  // !JPEGXL_ENABLE_SKCMS
+
+  // Ideally LCMS would convert directly from External to Image3. However,
+  // cmsDoTransformLineStride only accepts 32-bit BytesPerPlaneIn, whereas our
+  // planes can be more than 4 GiB apart. Hence, transform inputs/outputs must
+  // be interleaved. Calling cmsDoTransform for each pixel is expensive
+  // (indirect call). We therefore transform rows, which requires per-thread
+  // buffers. To avoid separate allocations, we use the rows of an image.
+  // Because LCMS apparently also cannot handle <= 16 bit inputs and 32-bit
+  // outputs (or vice versa), we use floating point input/output.
+#if JPEGXL_ENABLE_SKCMS
+  // SkiaCMS doesn't support grayscale float buffers, so we create space for
+  // RGB float buffers anyway.
+  buf_src_ = ImageF(xsize * 3, num_threads);
+  buf_dst_ = ImageF(xsize * 3, num_threads);
+#else
+  buf_src_ = ImageF(xsize * channels_src, num_threads);
+  buf_dst_ = ImageF(xsize * channels_dst, num_threads);
+#endif
+  intensity_target_ = intensity_target;
+  xsize_ = xsize;
+  return true;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/enc_color_management.h b/third_party/jpeg-xl/lib/jxl/enc_color_management.h
new file mode 100644
index 000000000000..8c367f0440c7
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_color_management.h
@@ -0,0 +1,80 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_COLOR_MANAGEMENT_H_
+#define LIB_JXL_ENC_COLOR_MANAGEMENT_H_
+
+// ICC profiles and color space conversions.
+ +#include <stddef.h> +#include <stdint.h> + +#include <memory> + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Run is thread-safe. +class ColorSpaceTransform { + public: + ColorSpaceTransform(); + ~ColorSpaceTransform(); + + // Cannot copy (transforms_ holds pointers). + ColorSpaceTransform(const ColorSpaceTransform&) = delete; + ColorSpaceTransform& operator=(const ColorSpaceTransform&) = delete; + + // "Constructor"; allocates for up to `num_threads`, or returns false. + // `intensity_target` is used for conversion to and from PQ, which is absolute + // (1 always represents 10000 cd/m²) and thus needs scaling in linear space if + // 1 is to represent another luminance level instead. + Status Init(const ColorEncoding& c_src, const ColorEncoding& c_dst, + float intensity_target, size_t xsize, size_t num_threads); + + float* BufSrc(const size_t thread) { return buf_src_.Row(thread); } + + float* BufDst(const size_t thread) { return buf_dst_.Row(thread); } + +#if JPEGXL_ENABLE_SKCMS + struct SkcmsICC; + std::unique_ptr<SkcmsICC> skcms_icc_; +#else + // One per thread - cannot share because of caching. + std::vector<void*> transforms_; +#endif + + ImageF buf_src_; + ImageF buf_dst_; + float intensity_target_; + size_t xsize_; + bool skip_lcms_ = false; + ExtraTF preprocess_ = ExtraTF::kNone; + ExtraTF postprocess_ = ExtraTF::kNone; +}; + +// buf_X can either be from BufX() or caller-allocated, interleaved storage. +// `thread` must be less than the `num_threads` passed to Init. +// `t` is non-const because buf_* may be modified. +void DoColorSpaceTransform(ColorSpaceTransform* t, size_t thread, + const float* buf_src, float* buf_dst); + +} // namespace jxl + +#endif // LIB_JXL_ENC_COLOR_MANAGEMENT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_comparator.cc b/third_party/jpeg-xl/lib/jxl/enc_comparator.cc new file mode 100644 index 000000000000..7fd944321369 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_comparator.cc @@ -0,0 +1,149 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_comparator.h" + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/enc_gamma_correct.h" + +namespace jxl { +namespace { + +// color is linear, but blending happens in gamma-compressed space using +// (gamma-compressed) grayscale background color, alpha image represents +// weights of the sRGB colors in the [0 .. (1 << bit_depth) - 1] interval, +// output image is in linear space.
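// Per-sample sketch of the blend below (same helpers as this file; the
// a <= 0 and a >= 1 cases are short-circuited to background/foreground):
//   out = Srgb8ToLinearDirect(a * LinearToSrgb8Direct(fg_linear) +
//                             (1 - a) * LinearToSrgb8Direct(bg_linear))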
+void AlphaBlend(const Image3F& in, const size_t c, float background_linear, + const ImageF& alpha, Image3F* out) { + const float background = LinearToSrgb8Direct(background_linear); + + for (size_t y = 0; y < out->ysize(); ++y) { + const float* JXL_RESTRICT row_a = alpha.ConstRow(y); + const float* JXL_RESTRICT row_i = in.ConstPlaneRow(c, y); + float* JXL_RESTRICT row_o = out->PlaneRow(c, y); + for (size_t x = 0; x < out->xsize(); ++x) { + const float a = row_a[x]; + if (a <= 0.f) { + row_o[x] = background_linear; + } else if (a >= 1.f) { + row_o[x] = row_i[x]; + } else { + const float w_fg = a; + const float w_bg = 1.0f - w_fg; + const float fg = w_fg * LinearToSrgb8Direct(row_i[x]); + const float bg = w_bg * background; + row_o[x] = Srgb8ToLinearDirect(fg + bg); + } + } + } +} + +const Image3F* AlphaBlend(const ImageBundle& ib, const Image3F& linear, + float background_linear, Image3F* copy) { + // No alpha => all opaque. + if (!ib.HasAlpha()) return &linear; + + *copy = Image3F(linear.xsize(), linear.ysize()); + for (size_t c = 0; c < 3; ++c) { + AlphaBlend(linear, c, background_linear, ib.alpha(), copy); + } + return copy; +} + +void AlphaBlend(float background_linear, ImageBundle* io_linear_srgb) { + // No alpha => all opaque. + if (!io_linear_srgb->HasAlpha()) return; + + for (size_t c = 0; c < 3; ++c) { + AlphaBlend(*io_linear_srgb->color(), c, background_linear, + *io_linear_srgb->alpha(), io_linear_srgb->color()); + } +} + +float ComputeScoreImpl(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, ImageF* distmap) { + JXL_CHECK(comparator->SetReferenceImage(rgb0)); + float score; + JXL_CHECK(comparator->CompareWith(rgb1, distmap, &score)); + return score; +} + +} // namespace + +float ComputeScore(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, ImageF* diffmap, ThreadPool* pool) { + PROFILER_FUNC; + // Convert to linear sRGB (unless already in that space) + ImageMetadata metadata0 = *rgb0.metadata(); + ImageBundle store0(&metadata0); + const ImageBundle* linear_srgb0; + JXL_CHECK(TransformIfNeeded(rgb0, ColorEncoding::LinearSRGB(rgb0.IsGray()), + pool, &store0, &linear_srgb0)); + ImageMetadata metadata1 = *rgb1.metadata(); + ImageBundle store1(&metadata1); + const ImageBundle* linear_srgb1; + JXL_CHECK(TransformIfNeeded(rgb1, ColorEncoding::LinearSRGB(rgb1.IsGray()), + pool, &store1, &linear_srgb1)); + + // No alpha: skip blending, only need a single call to Butteraugli. + if (!rgb0.HasAlpha() && !rgb1.HasAlpha()) { + return ComputeScoreImpl(*linear_srgb0, *linear_srgb1, comparator, diffmap); + } + + // Blend on black and white backgrounds + + const float black = 0.0f; + ImageBundle blended_black0 = linear_srgb0->Copy(); + ImageBundle blended_black1 = linear_srgb1->Copy(); + AlphaBlend(black, &blended_black0); + AlphaBlend(black, &blended_black1); + + const float white = 1.0f; + ImageBundle blended_white0 = linear_srgb0->Copy(); + ImageBundle blended_white1 = linear_srgb1->Copy(); + + AlphaBlend(white, &blended_white0); + AlphaBlend(white, &blended_white1); + + ImageF diffmap_black, diffmap_white; + const float dist_black = ComputeScoreImpl(blended_black0, blended_black1, + comparator, &diffmap_black); + const float dist_white = ComputeScoreImpl(blended_white0, blended_white1, + comparator, &diffmap_white); + + // diffmap and return values are the max of diffmap_black/white. 
+ if (diffmap != nullptr) { + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + *diffmap = ImageF(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const float* JXL_RESTRICT row_black = diffmap_black.ConstRow(y); + const float* JXL_RESTRICT row_white = diffmap_white.ConstRow(y); + float* JXL_RESTRICT row_out = diffmap->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = std::max(row_black[x], row_white[x]); + } + } + } + return std::max(dist_black, dist_white); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_comparator.h b/third_party/jpeg-xl/lib/jxl/enc_comparator.h new file mode 100644 index 000000000000..91437f65a7a5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_comparator.h @@ -0,0 +1,61 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_COMPARATOR_H_ +#define LIB_JXL_ENC_COMPARATOR_H_ + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +class Comparator { + public: + virtual ~Comparator() = default; + + // Sets the reference image, the first to compare + // Image must be in linear sRGB (gamma expanded) in range 0.0f-1.0f as + // the range from standard black point to standard white point, but values + // outside permitted. + virtual Status SetReferenceImage(const ImageBundle& ref) = 0; + + // Sets the actual image (with loss), the second to compare + // Image must be in linear sRGB (gamma expanded) in range 0.0f-1.0f as + // the range from standard black point to standard white point, but values + // outside permitted. + // In diffmap it outputs the local score per pixel, while in score it outputs + // a single score. Any one may be set to nullptr to not compute it. + virtual Status CompareWith(const ImageBundle& actual, ImageF* diffmap, + float* score) = 0; + + // Quality thresholds for diffmap and score values. + // The good score must represent a value where the images are considered to + // be perceptually indistinguishable (but not identical) + // The bad value must be larger than good to indicate "lower means better" + // and smaller than good to indicate "higher means better" + virtual float GoodQualityScore() const = 0; + virtual float BadQualityScore() const = 0; +}; + +// Computes the score given images in any RGB color model, optionally with +// alpha channel. 
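// Hedged caller-side sketch; `comparator` is any Comparator implementation
// (e.g. a Butteraugli-based one) and `ib0`/`ib1` are assumed valid bundles:
//   ImageF diffmap;
//   float score = ComputeScore(ib0, ib1, &comparator, &diffmap);
//   bool good_enough = score <= comparator.GoodQualityScore();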
+float ComputeScore(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, ImageF* diffmap = nullptr, + ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_JXL_ENC_COMPARATOR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_context_map.cc b/third_party/jpeg-xl/lib/jxl/enc_context_map.cc new file mode 100644 index 000000000000..532cd1303380 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_context_map.cc @@ -0,0 +1,148 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Library to encode the context map. + +#include "lib/jxl/enc_context_map.h" + +#include <stdint.h> + +#include <algorithm> +#include <cmath> +#include <vector> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" + +namespace jxl { + +namespace { + +size_t IndexOf(const std::vector<uint8_t>& v, uint8_t value) { + size_t i = 0; + for (; i < v.size(); ++i) { + if (v[i] == value) return i; + } + return i; +} + +void MoveToFront(std::vector<uint8_t>* v, size_t index) { + uint8_t value = (*v)[index]; + for (size_t i = index; i != 0; --i) { + (*v)[i] = (*v)[i - 1]; + } + (*v)[0] = value; +} + +std::vector<uint8_t> MoveToFrontTransform(const std::vector<uint8_t>& v) { + if (v.empty()) return v; + uint8_t max_value = *std::max_element(v.begin(), v.end()); + std::vector<uint8_t> mtf(max_value + 1); + for (size_t i = 0; i <= max_value; ++i) mtf[i] = i; + std::vector<uint8_t> result(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + size_t index = IndexOf(mtf, v[i]); + JXL_ASSERT(index < mtf.size()); + result[i] = static_cast<uint8_t>(index); + MoveToFront(&mtf, index); + } + return result; +} + +} // namespace + +void EncodeContextMap(const std::vector<uint8_t>& context_map, + size_t num_histograms, BitWriter* writer) { + if (num_histograms == 1) { + // Simple code + writer->Write(1, 1); + // 0 bits per entry. + writer->Write(2, 0); + return; + } + + std::vector<uint8_t> transformed_symbols = MoveToFrontTransform(context_map); + std::vector<std::vector<Token>> tokens(1), mtf_tokens(1); + EntropyEncodingData codes; + std::vector<uint8_t> dummy_context_map; + for (size_t i = 0; i < context_map.size(); i++) { + tokens[0].emplace_back(0, context_map[i]); + } + for (size_t i = 0; i < transformed_symbols.size(); i++) { + mtf_tokens[0].emplace_back(0, transformed_symbols[i]); + } + HistogramParams params; + params.uint_method = HistogramParams::HybridUintMethod::kContextMap; + size_t ans_cost = BuildAndEncodeHistograms( + params, 1, tokens, &codes, &dummy_context_map, nullptr, 0, nullptr); + size_t mtf_cost = BuildAndEncodeHistograms( + params, 1, mtf_tokens, &codes, &dummy_context_map, nullptr, 0, nullptr); + bool use_mtf = mtf_cost < ans_cost; + // Rebuild token list. + tokens[0].clear(); + for (size_t i = 0; i < transformed_symbols.size(); i++) { + tokens[0].emplace_back(0, + use_mtf ?
transformed_symbols[i] : context_map[i]); + } + size_t entry_bits = CeilLog2Nonzero(num_histograms); + size_t simple_cost = entry_bits * context_map.size(); + if (entry_bits < 4 && simple_cost < ans_cost && simple_cost < mtf_cost) { + writer->Write(1, 1); + writer->Write(2, entry_bits); + for (size_t i = 0; i < context_map.size(); i++) { + writer->Write(entry_bits, context_map[i]); + } + } else { + writer->Write(1, 0); + writer->Write(1, use_mtf); // Use/don't use MTF. + BuildAndEncodeHistograms(params, 1, tokens, &codes, &dummy_context_map, + writer, 0, nullptr); + WriteTokens(tokens[0], codes, dummy_context_map, writer); + } +} + +void EncodeBlockCtxMap(const BlockCtxMap& block_ctx_map, BitWriter* writer, + AuxOut* aux_out) { + auto& dct = block_ctx_map.dc_thresholds; + auto& qft = block_ctx_map.qf_thresholds; + auto& ctx_map = block_ctx_map.ctx_map; + BitWriter::Allotment allotment( + writer, + (dct[0].size() + dct[1].size() + dct[2].size() + qft.size()) * 34 + 1 + + 4 + 4 + ctx_map.size() * 10 + 1024); + if (dct[0].empty() && dct[1].empty() && dct[2].empty() && qft.empty() && + ctx_map.size() == 21 && + std::equal(ctx_map.begin(), ctx_map.end(), BlockCtxMap::kDefaultCtxMap)) { + writer->Write(1, 1); // default + ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out); + return; + } + writer->Write(1, 0); + for (int j : {0, 1, 2}) { + writer->Write(4, dct[j].size()); + for (int i : dct[j]) { + JXL_CHECK(U32Coder::Write(kDCThresholdDist, PackSigned(i), writer)); + } + } + writer->Write(4, qft.size()); + for (uint32_t i : qft) { + JXL_CHECK(U32Coder::Write(kQFThresholdDist, i - 1, writer)); + } + EncodeContextMap(ctx_map, block_ctx_map.num_ctxs, writer); + ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_context_map.h b/third_party/jpeg-xl/lib/jxl/enc_context_map.h new file mode 100644 index 000000000000..155e8c204bd2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_context_map.h @@ -0,0 +1,42 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_CONTEXT_MAP_H_ +#define LIB_JXL_ENC_CONTEXT_MAP_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Max limit is 255 because encoding assumes numbers < 255 +// More clusters can help compression, but makes encode/decode somewhat slower +static const size_t kClustersLimit = 128; + +// Encodes the given context map to the bit stream. The number of different +// histogram ids is given by num_histograms.
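// Bitstream layout sketch, as written by EncodeContextMap above:
//   1 bit: simple code?
//     yes: 2 bits entry_bits, then entry_bits bits per context_map entry
//          (num_histograms == 1 uses this path with entry_bits == 0)
//     no:  1 bit use_mtf, then the histograms plus the ANS-coded tokens of
//          the (optionally move-to-front-transformed) map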
+void EncodeContextMap(const std::vector<uint8_t>& context_map, + size_t num_histograms, BitWriter* writer); + +void EncodeBlockCtxMap(const BlockCtxMap& block_ctx_map, BitWriter* writer, + AuxOut* aux_out); +} // namespace jxl + +#endif // LIB_JXL_ENC_CONTEXT_MAP_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_detect_dots.cc b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.cc new file mode 100644 index 000000000000..07da72d0654a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.cc @@ -0,0 +1,676 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_detect_dots.h" + +#include <stddef.h> + +#include <algorithm> +#include <array> +#include <cmath> +#include <cstdio> +#include <utility> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_detect_dots.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/common.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/linalg.h" +#include "lib/jxl/optimize.h" + +// Set JXL_DEBUG_DOT_DETECT to 1 to enable debugging. +#ifndef JXL_DEBUG_DOT_DETECT +#define JXL_DEBUG_DOT_DETECT 0 +#endif + +#if JXL_DEBUG_DOT_DETECT +#include "lib/jxl/aux_out.h" +#endif + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +ImageF SumOfSquareDifferences(const Image3F& forig, const Image3F& smooth, + ThreadPool* pool) { + const HWY_FULL(float) d; + const auto color_coef0 = Set(d, 0.0f); + const auto color_coef1 = Set(d, 10.0f); + const auto color_coef2 = Set(d, 0.0f); + + ImageF sum_of_squares(forig.xsize(), forig.ysize()); + RunOnPool( + pool, 0, forig.ysize(), ThreadPool::SkipInit(), + [&](const int task, const int thread) { + const size_t y = static_cast<size_t>(task); + const float* JXL_RESTRICT orig_row0 = forig.Plane(0).ConstRow(y); + const float* JXL_RESTRICT orig_row1 = forig.Plane(1).ConstRow(y); + const float* JXL_RESTRICT orig_row2 = forig.Plane(2).ConstRow(y); + const float* JXL_RESTRICT smooth_row0 = smooth.Plane(0).ConstRow(y); + const float* JXL_RESTRICT smooth_row1 = smooth.Plane(1).ConstRow(y); + const float* JXL_RESTRICT smooth_row2 = smooth.Plane(2).ConstRow(y); + float* JXL_RESTRICT sos_row = sum_of_squares.Row(y); + + for (size_t x = 0; x < forig.xsize(); x += Lanes(d)) { + auto v0 = Load(d, orig_row0 + x) - Load(d, smooth_row0 + x); + auto v1 = Load(d, orig_row1 + x) - Load(d, smooth_row1 + x); + auto v2 = Load(d, orig_row2 + x) - Load(d, smooth_row2 + x); + v0 *= v0; + v1 *= v1; + v2 *= v2; + v0 *= color_coef0; // FMA doesn't help here.
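// With color_coef = (0, 10, 0), only plane 1 contributes, so the energy
// image is effectively 10 * (orig - smooth)^2 on that plane; the other two
// multiplies are kept so the per-channel weights stay tunable.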
+ v1 *= color_coef1; + v2 *= color_coef2; + const auto sos = v0 + v1 + v2; // weighted sum of square diffs + Store(sos, d, sos_row + x); + } + }, + "ComputeEnergyImage"); + return sum_of_squares; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(SumOfSquareDifferences); // Local function + +const int kEllipseWindowSize = 5; + +namespace { +struct GaussianEllipse { + double x; // position in x + double y; // position in y + double sigma_x; // scale in x + double sigma_y; // scale in y + double angle; // ellipse rotation in radians + std::array<double, 3> intensity; // intensity in each channel + + // The following variables do not need to be encoded + double l2_loss; // error after the Gaussian was fit + double l1_loss; + double ridge_loss; // the l2_loss plus regularization term + double custom_loss; // experimental custom loss + std::array<double, 3> bgColor; // best background color + size_t neg_pixels; // number of negative pixels when subtracting dot + std::array<double, 3> neg_value; // debt due to channel truncation +}; +double DotGaussianModel(double dx, double dy, double ct, double st, + double sigma_x, double sigma_y, double intensity) { + double rx = ct * dx + st * dy; + double ry = -st * dx + ct * dy; + double md = (rx * rx / sigma_x) + (ry * ry / sigma_y); + double value = intensity * exp(-0.5 * md); + return value; +} + +constexpr bool kOptimizeBackground = true; + +// Gaussian that smooths noise but preserves dots +const WeightsSeparable5& WeightsSeparable5Gaussian0_65() { + constexpr float w0 = 0.558311f; + constexpr float w1 = 0.210395f; + constexpr float w2 = 0.010449f; + static constexpr WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}}; + return weights; +} + +// (Iterated) Gaussian that removes dots. +const WeightsSeparable5& WeightsSeparable5Gaussian3() { + constexpr float w0 = 0.222338f; + constexpr float w1 = 0.210431f; + constexpr float w2 = 0.1784f; + static constexpr WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}}; + return weights; +} + +ImageF ComputeEnergyImage(const Image3F& orig, Image3F* smooth, + ThreadPool* pool) { + PROFILER_FUNC; + + // Prepare guidance images for dot selection.
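// Sketch of the pipeline below: forig is the input filtered by the narrow
// Gaussian (noise suppressed, dots kept); smooth applies the wide Gaussian
// twice (dots removed); the returned energy image is the weighted squared
// difference of the two, so isolated dots stand out as high energy.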
+ Image3F forig(orig.xsize(), orig.ysize()); + Image3F tmp(orig.xsize(), orig.ysize()); + *smooth = Image3F(orig.xsize(), orig.ysize()); + + const auto& weights1 = WeightsSeparable5Gaussian0_65(); + const auto& weights3 = WeightsSeparable5Gaussian3(); + + Separable5_3(orig, Rect(orig), weights1, pool, &forig); + + Separable5_3(orig, Rect(orig), weights3, pool, &tmp); + Separable5_3(tmp, Rect(tmp), weights3, pool, smooth); + +#if JXL_DEBUG_DOT_DETECT + AuxOut aux; + aux.debug_prefix = "/tmp/sebastian/"; + aux.DumpImage("filtered", forig); + aux.DumpImage("sm", *smooth); +#endif + + return HWY_DYNAMIC_DISPATCH(SumOfSquareDifferences)(forig, *smooth, pool); +} + +struct Pixel { + int x; + int y; +}; + +Pixel operator+(const Pixel& a, const Pixel& b) { + return Pixel{a.x + b.x, a.y + b.y}; +} + +// Maximum area in pixels of a ellipse +const size_t kMaxCCSize = 1000; + +// Extracts a connected component from a Binary image where seed is part +// of the component +bool ExtractComponent(ImageF* img, std::vector<Pixel>* pixels, + const Pixel& seed, double threshold) { + PROFILER_FUNC; + static const std::vector<Pixel> neighbors{{1, -1}, {1, 0}, {1, 1}, {0, -1}, + {0, 1}, {-1, -1}, {-1, 1}, {-1, 0}}; + std::vector<Pixel> q{seed}; + while (!q.empty()) { + Pixel current = q.back(); + q.pop_back(); + pixels->push_back(current); + if (pixels->size() > kMaxCCSize) return false; + for (const Pixel& delta : neighbors) { + Pixel child = current + delta; + if (child.x >= 0 && static_cast<size_t>(child.x) < img->xsize() && + child.y >= 0 && static_cast<size_t>(child.y) < img->ysize()) { + float* value = &img->Row(child.y)[child.x]; + if (*value > threshold) { + *value = 0.0; + q.push_back(child); + } + } + } + } + return true; +} + +inline bool PointInRect(const Rect& r, const Pixel& p) { + return (static_cast<size_t>(p.x) >= r.x0() && + static_cast<size_t>(p.x) < (r.x0() + r.xsize()) && + static_cast<size_t>(p.y) >= r.y0() && + static_cast<size_t>(p.y) < (r.y0() + r.ysize())); +} + +struct ConnectedComponent { + ConnectedComponent(const Rect& bounds, const std::vector<Pixel>&& pixels) + : bounds(bounds), pixels(pixels) {} + Rect bounds; + std::vector<Pixel> pixels; + float maxEnergy; + float meanEnergy; + float varEnergy; + float meanBg; + float varBg; + float score; + Pixel mode; + + void CompStats(const ImageF& energy, int extra) { + PROFILER_FUNC; + maxEnergy = 0.0; + meanEnergy = 0.0; + varEnergy = 0.0; + meanBg = 0.0; + varBg = 0.0; + int nIn = 0; + int nOut = 0; + mode.x = 0; + mode.y = 0; + for (int sy = -extra; sy < (static_cast<int>(bounds.ysize()) + extra); + sy++) { + int y = sy + static_cast<int>(bounds.y0()); + if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue; + const float* JXL_RESTRICT erow = energy.ConstRow(y); + for (int sx = -extra; sx < (static_cast<int>(bounds.xsize()) + extra); + sx++) { + int x = sx + static_cast<int>(bounds.x0()); + if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue; + if (erow[x] > maxEnergy) { + maxEnergy = erow[x]; + mode.x = x; + mode.y = y; + } + if (PointInRect(bounds, Pixel{x, y})) { + meanEnergy += erow[x]; + varEnergy += erow[x] * erow[x]; + nIn++; + } else { + meanBg += erow[x]; + varBg += erow[x] * erow[x]; + nOut++; + } + } + } + meanEnergy = meanEnergy / nIn; + meanBg = meanBg / nOut; + varEnergy = (varEnergy / nIn) - meanEnergy * meanEnergy; + varBg = (varBg / nOut) - meanBg * meanBg; + score = (meanEnergy - meanBg) / std::sqrt(varBg); + } +}; + +Rect BoundingRectangle(const std::vector<Pixel>& pixels) { + PROFILER_FUNC; + JXL_ASSERT(!pixels.empty()); + int low_x, high_x, low_y, high_y; + low_x = high_x = pixels[0].x; + low_y = high_y = pixels[0].y; +
for (const Pixel& p : pixels) { + low_x = std::min(low_x, p.x); + high_x = std::max(high_x, p.x); + low_y = std::min(low_y, p.y); + high_y = std::max(high_y, p.y); + } + return Rect(low_x, low_y, high_x - low_x + 1, high_y - low_y + 1); +} + +std::vector<ConnectedComponent> FindCC(const ImageF& energy, double t_low, + double t_high, uint32_t maxWindow, + double minScore) { + PROFILER_FUNC; + const int kExtraRect = 4; + ImageF img = CopyImage(energy); + std::vector<ConnectedComponent> ans; + for (size_t y = 0; y < img.ysize(); y++) { + float* JXL_RESTRICT row = img.Row(y); + for (size_t x = 0; x < img.xsize(); x++) { + if (row[x] > t_high) { + std::vector<Pixel> pixels; + row[x] = 0.0; + bool success = ExtractComponent( + &img, &pixels, Pixel{static_cast<int>(x), static_cast<int>(y)}, + t_low); + if (!success) continue; +#if JXL_DEBUG_DOT_DETECT + for (size_t i = 0; i < pixels.size(); i++) { + fprintf(stderr, "(%d,%d) ", pixels[i].x, pixels[i].y); + } + fprintf(stderr, "\n"); +#endif // JXL_DEBUG_DOT_DETECT + Rect bounds = BoundingRectangle(pixels); + if (bounds.xsize() < maxWindow && bounds.ysize() < maxWindow) { + ConnectedComponent cc{bounds, std::move(pixels)}; + cc.CompStats(energy, kExtraRect); + if (cc.score < minScore) continue; + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "cc mode: (%d,%d), max: %f, bgMean: %f bgVar: " + "%f bound:(%zu,%zu,%zu,%zu)\n", + cc.mode.x, cc.mode.y, cc.maxEnergy, cc.meanEnergy, + cc.varEnergy, cc.bounds.x0(), cc.bounds.y0(), + cc.bounds.xsize(), cc.bounds.ysize()); + ans.push_back(cc); + } + } + } + } + return ans; +} + +// TODO(sggonzalez): Adapt this function for the different color spaces or +// remove it if the color space with the best performance does not need it +void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc, + const Image3F& img, const Image3F& background) { + PROFILER_FUNC; + const int rectBounds = 2; + const double kIntensityR = 0.0; // 0.015; + const double kSigmaR = 0.0; // 0.01; + const double kZeroEpsilon = 0.1; // Tolerance to consider a value negative + double ct = cos(ellipse->angle), st = sin(ellipse->angle); + const std::array<double, 3> channelGains{1.0, 1.0, 1.0}; + int N = 0; + ellipse->l1_loss = 0.0; + ellipse->l2_loss = 0.0; + ellipse->neg_pixels = 0; + ellipse->neg_value.fill(0.0); + double distMeanModeSq = (cc.mode.x - ellipse->x) * (cc.mode.x - ellipse->x) + + (cc.mode.y - ellipse->y) * (cc.mode.y - ellipse->y); + ellipse->custom_loss = 0.0; + for (int c = 0; c < 3; c++) { + for (int sy = -rectBounds; + sy < (static_cast<int>(cc.bounds.ysize()) + rectBounds); sy++) { + int y = sy + cc.bounds.y0(); + if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y); + // bgrow is only used if kOptimizeBackground is false. + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y); + for (int sx = -rectBounds; + sx < (static_cast<int>(cc.bounds.xsize()) + rectBounds); sx++) { + int x = sx + cc.bounds.x0(); + if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue; + double target = row[x]; + double dotDelta = DotGaussianModel( + x - ellipse->x, y - ellipse->y, ct, st, ellipse->sigma_x, + ellipse->sigma_y, ellipse->intensity[c]); + if (dotDelta > target + kZeroEpsilon) { + ellipse->neg_pixels++; + ellipse->neg_value[c] += dotDelta - target; + } + double bkg = kOptimizeBackground ?
ellipse->bgColor[c] : bgrow[x]; + double pred = bkg + dotDelta; + double diff = target - pred; + double l2 = channelGains[c] * diff * diff; + double l1 = channelGains[c] * std::fabs(diff); + ellipse->l2_loss += l2; + ellipse->l1_loss += l1; + double w = DotGaussianModel(x - cc.mode.x, y - cc.mode.y, 1.0, 0.0, + 1.0 + ellipse->sigma_x, + 1.0 + ellipse->sigma_y, 1.0); + ellipse->custom_loss += w * l2; + N++; + } + } + } + ellipse->l2_loss /= N; + ellipse->custom_loss /= N; + ellipse->custom_loss += 20.0 * distMeanModeSq + ellipse->neg_value[1]; + ellipse->l1_loss /= N; + double ridgeTerm = kSigmaR * ellipse->sigma_x + kSigmaR * ellipse->sigma_y; + for (int c = 0; c < 3; c++) { + ridgeTerm += kIntensityR * ellipse->intensity[c] * ellipse->intensity[c]; + } + ellipse->ridge_loss = ellipse->l2_loss + ridgeTerm; +} + +GaussianEllipse FitGaussianFast(const ConnectedComponent& cc, + const ImageF& energy, const Image3F& img, + const Image3F& background) { + PROFILER_FUNC; + constexpr bool leastSqIntensity = true; + constexpr double kEpsilon = 1e-6; + GaussianEllipse ans; + constexpr int kRectBounds = (kEllipseWindowSize >> 1); + + // Compute the 1st and 2nd moments of the CC + double sum = 0.0; + int N = 0; + std::array<double, 3> m1{0.0, 0.0, 0.0}; + std::array<double, 3> m2{0.0, 0.0, 0.0}; + std::array<double, 3> color{0.0, 0.0, 0.0}; + std::array<double, 3> bgColor{0.0, 0.0, 0.0}; + + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, "%zu %zu %zu %zu\n", cc.bounds.x0(), + cc.bounds.y0(), cc.bounds.xsize(), cc.bounds.ysize()); + for (int c = 0; c < 3; c++) { + color[c] = img.ConstPlaneRow(c, cc.mode.y)[cc.mode.x] - + background.ConstPlaneRow(c, cc.mode.y)[cc.mode.x]; + } + double sign = (color[1] > 0) ? 1 : -1; + for (int sy = -kRectBounds; sy <= kRectBounds; sy++) { + int y = sy + cc.mode.y; + if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(1, y); + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(1, y); + for (int sx = -kRectBounds; sx <= kRectBounds; sx++) { + int x = sx + cc.mode.x; + if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue; + double w = std::max(kEpsilon, sign * (row[x] - bgrow[x])); + sum += w; + + m1[0] += w * x; + m1[1] += w * y; + m2[0] += w * x * x; + m2[1] += w * x * y; + m2[2] += w * y * y; + for (int c = 0; c < 3; c++) { + bgColor[c] += background.ConstPlaneRow(c, y)[x]; + } + N++; + } + } + JXL_CHECK(N > 0); + + for (int i = 0; i < 3; i++) { + m1[i] /= sum; + m2[i] /= sum; + bgColor[i] /= N; + } + + // Some magic constants + constexpr double kSigmaMult = 1.0; + constexpr std::array<double, 3> kScaleMult{1.1, 1.1, 1.1}; + + // Now set the parameters of the Gaussian + ans.x = m1[0]; + ans.y = m1[1]; + for (int j = 0; j < 3; j++) { + ans.intensity[j] = kScaleMult[j] * color[j]; + } + + ImageD Sigma(2, 2), D(1, 2), U(2, 2); + Sigma.Row(0)[0] = m2[0] - m1[0] * m1[0]; + Sigma.Row(1)[1] = m2[2] - m1[1] * m1[1]; + Sigma.Row(0)[1] = Sigma.Row(1)[0] = m2[1] - m1[0] * m1[1]; + ConvertToDiagonal(Sigma, &D, &U); + const double* JXL_RESTRICT d = D.ConstRow(0); + const double* JXL_RESTRICT u = U.ConstRow(1); + int p1 = 0, p2 = 1; + if (d[0] < d[1]) std::swap(p1, p2); + ans.sigma_x = kSigmaMult * d[p1]; + ans.sigma_y = kSigmaMult * d[p2]; + ans.angle = std::atan2(u[p1], u[p2]); + ans.l2_loss = 0.0; + ans.bgColor = bgColor; + if (leastSqIntensity) { + GaussianEllipse* ellipse = &ans; + double ct = cos(ans.angle), st = sin(ans.angle); + // Estimate intensity with least squares (fixed background) + for (int c = 0; c < 3; c++) { + double gg = 0.0; + double gd = 0.0; + int yc =
static_cast<int>(cc.mode.y); + int xc = static_cast<int>(cc.mode.x); + for (int y = yc - kRectBounds; y <= yc + kRectBounds; y++) { + if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y); + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y); + for (int x = xc - kRectBounds; x <= xc + kRectBounds; x++) { + if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue; + double target = row[x] - bgrow[x]; + double gaussian = + DotGaussianModel(x - ellipse->x, y - ellipse->y, ct, st, + ellipse->sigma_x, ellipse->sigma_y, 1.0); + gg += gaussian * gaussian; + gd += gaussian * target; + } + } + ans.intensity[c] = gd / (gg + 1e-6); // Regularized least squares + } + } + ComputeDotLosses(&ans, cc, img, background); + return ans; +} + +// Probably slower, but optimizes the right thing +GaussianEllipse FitGaussianOpt(const ConnectedComponent& cc, + const ImageF& energy, const Image3F& img, + const Image3F& background) { + const std::array<double, 3> colorFactor{1.0, 1.0, 1.0}; + GaussianEllipse ans = FitGaussianFast(cc, energy, img, background); + auto l2_loss = [&img, &cc, &background, + &colorFactor](const std::vector<double>& params) -> double { + GaussianEllipse ellipse; + ellipse.x = img.xsize() * params[0]; + ellipse.y = img.ysize() * params[1]; + ellipse.sigma_x = exp(params[2]); // strictly positive + ellipse.sigma_y = exp(params[3]); // strictly positive + ellipse.angle = params[4]; + for (int c = 0; c < 3; c++) { + ellipse.intensity[c] = colorFactor[c] * params[5 + c]; + ellipse.bgColor[c] = colorFactor[c] * params[8 + c]; + } + ComputeDotLosses(&ellipse, cc, img, background); + return ellipse.l2_loss; + }; + std::vector<double> init{ans.x / img.xsize(), + ans.y / img.ysize(), + log(ans.sigma_x), + log(ans.sigma_y), + ans.angle, + ans.intensity[0] / colorFactor[0], + ans.intensity[1] / colorFactor[1], + ans.intensity[2] / colorFactor[2], + ans.bgColor[0] / colorFactor[0], + ans.bgColor[1] / colorFactor[1], + ans.bgColor[2] / colorFactor[2]}; + auto p = optimize::RunSimplex(11, 0.01, 77, init, l2_loss); + + ans.l2_loss = p[0]; + ans.x = img.xsize() * p[1]; + ans.y = img.ysize() * p[2]; + ans.sigma_x = exp(p[3]); + ans.sigma_y = exp(p[4]); + ans.angle = p[5]; + for (int c = 0; c < 3; c++) { + ans.intensity[c] = colorFactor[c] * p[6 + c]; + ans.bgColor[c] = colorFactor[c] * p[9 + c]; + } + return ans; +} + +GaussianEllipse FitGaussian(const ConnectedComponent& cc, const ImageF& energy, + const Image3F& img, const Image3F& background) { + auto ellipse = FitGaussianFast(cc, energy, img, background); + if (ellipse.sigma_x < ellipse.sigma_y) { + std::swap(ellipse.sigma_x, ellipse.sigma_y); + ellipse.angle += kPi / 2.0; + } + ellipse.angle -= kPi * std::floor(ellipse.angle / kPi); + if (fabs(ellipse.angle - kPi) < 1e-6 || fabs(ellipse.angle) < 1e-6) { + ellipse.angle = 0.0; + } + JXL_CHECK(ellipse.angle >= 0 && ellipse.angle <= kPi && + ellipse.sigma_x >= ellipse.sigma_y); + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "Ellipse mu=(%lf,%lf) sigma=(%lf,%lf) angle=%lf " + "intensity=(%lf,%lf,%lf) bg=(%lf,%lf,%lf) l2_loss=%lf " + "custom_loss=%lf, neg_pix=%zu, neg_v=(%lf,%lf,%lf)\n", + ellipse.x, ellipse.y, ellipse.sigma_x, ellipse.sigma_y, + ellipse.angle, ellipse.intensity[0], ellipse.intensity[1], + ellipse.intensity[2], ellipse.bgColor[0], ellipse.bgColor[1], + ellipse.bgColor[2], ellipse.l2_loss, ellipse.custom_loss, + ellipse.neg_pixels, ellipse.neg_value[0], ellipse.neg_value[1], + ellipse.neg_value[2]); + return ellipse; +} + +} // namespace + +std::vector<PatchInfo>
DetectGaussianEllipses( + const Image3F& opsin, const GaussianDetectParams& params, + const EllipseQuantParams& qParams, ThreadPool* pool) { + PROFILER_FUNC; + std::vector<PatchInfo> dots; + Image3F smooth(opsin.xsize(), opsin.ysize()); + ImageF energy = ComputeEnergyImage(opsin, &smooth, pool); +#if JXL_DEBUG_DOT_DETECT + AuxOut aux; + aux.debug_prefix = "/tmp/sebastian/"; + aux.DumpXybImage("smooth", smooth); + aux.DumpPlaneNormalized("energy", energy); +#endif // JXL_DEBUG_DOT_DETECT + std::vector<ConnectedComponent> components = FindCC( + energy, params.t_low, params.t_high, params.maxWinSize, params.minScore); + size_t numCC = + std::min(params.maxCC, (components.size() * params.percCC) / 100); + if (components.size() > numCC) { + std::sort( + components.begin(), components.end(), + [](const ConnectedComponent& a, const ConnectedComponent& b) -> bool { + return a.score > b.score; + }); + components.erase(components.begin() + numCC, components.end()); + } + for (const auto& cc : components) { + GaussianEllipse ellipse = FitGaussian(cc, energy, opsin, smooth); + if (ellipse.x < 0.0 || + std::ceil(ellipse.x) >= static_cast<double>(opsin.xsize()) || + ellipse.y < 0.0 || + std::ceil(ellipse.y) >= static_cast<double>(opsin.ysize())) { + continue; + } + if (ellipse.neg_pixels > params.maxNegPixels) continue; + double intensity = 0.21 * ellipse.intensity[0] + + 0.72 * ellipse.intensity[1] + + 0.07 * ellipse.intensity[2]; + double intensitySq = intensity * intensity; + // for (int c = 0; c < 3; c++) { + // intensitySq += ellipse.intensity[c] * ellipse.intensity[c]; + //} + double sqDistMeanMode = (ellipse.x - cc.mode.x) * (ellipse.x - cc.mode.x) + + (ellipse.y - cc.mode.y) * (ellipse.y - cc.mode.y); + if (ellipse.l2_loss < params.maxL2Loss && + ellipse.custom_loss < params.maxCustomLoss && + intensitySq > (params.minIntensity * params.minIntensity) && + sqDistMeanMode < params.maxDistMeanMode * params.maxDistMeanMode) { + size_t x0 = cc.bounds.x0(); + size_t y0 = cc.bounds.y0(); + dots.emplace_back(); + dots.back().second.emplace_back(x0, y0); + QuantizedPatch& patch = dots.back().first; + patch.xsize = cc.bounds.xsize(); + patch.ysize = cc.bounds.ysize(); + for (size_t y = 0; y < patch.ysize; y++) { + for (size_t x = 0; x < patch.xsize; x++) { + for (size_t c = 0; c < 3; c++) { + patch.fpixels[c][y * patch.xsize + x] = + opsin.ConstPlaneRow(c, y0 + y)[x0 + x] - + smooth.ConstPlaneRow(c, y0 + y)[x0 + x]; + } + } + } + } + } +#if JXL_DEBUG_DOT_DETECT + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, "Candidates: %zu, Dots: %zu\n", + components.size(), dots.size()); + ApplyGaussianEllipses(&smooth, dots, 1.0); + aux.DumpXybImage("draw", smooth); + ApplyGaussianEllipses(&smooth, dots, -1.0); + + auto qdots = QuantizeGaussianEllipses(dots, qParams); + auto deq = DequantizeGaussianEllipses(qdots, qParams); + ApplyGaussianEllipses(&smooth, deq, 1.0); + aux.DumpXybImage("qdraw", smooth); + ApplyGaussianEllipses(&smooth, deq, -1.0); +#endif // JXL_DEBUG_DOT_DETECT + return dots; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_detect_dots.h b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.h new file mode 100644 index 000000000000..37efdf93dead --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.h @@ -0,0 +1,75 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// We attempt to remove dots, or speckle, from images using Gaussian blur. +#ifndef LIB_JXL_ENC_DETECT_DOTS_H_ +#define LIB_JXL_ENC_DETECT_DOTS_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/image.h" + +namespace jxl { + +struct GaussianDetectParams { + double t_high = 0; // at least one pixel must have larger energy than t_high + double t_low = 0; // all pixels must have a larger energy than t_low + uint32_t maxWinSize = 0; // discard dots larger than this containing window + double maxL2Loss = 0; + double maxCustomLoss = 0; + double minIntensity = 0; // If the intensity is too low, discard it + double maxDistMeanMode = 0; // The mean and the mode must be close + size_t maxNegPixels = 0; // Maximum number of negative pixels + size_t minScore = 0; + size_t maxCC = 50; // Maximum number of CC to keep + size_t percCC = 15; // Percentage in [0,100] of CC to keep +}; + +// Ellipse Quantization Params +struct EllipseQuantParams { + size_t xsize; // Image size in x + size_t ysize; // Image size in y + size_t qPosition; // Position quantization delta + // Quantization for the Gaussian sigma parameters + double minSigma; + double maxSigma; + size_t qSigma; // number of quantization levels + // Quantization for the rotation angle (between -pi and pi) + size_t qAngle; + // Quantization for the intensity + std::array<double, 3> minIntensity; + std::array<double, 3> maxIntensity; + std::array<size_t, 3> qIntensity; // number of quantization levels + // Extra parameters for the encoding + bool subtractQuantized; // Should we subtract quantized or detected dots? + float ytox; + float ytob; + + void QuantPositionSize(size_t* xsize, size_t* ysize) const; +}; + +// Detects dots in XYB image. +std::vector<PatchInfo> DetectGaussianEllipses( + const Image3F& opsin, const GaussianDetectParams& params, + const EllipseQuantParams& qParams, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_ENC_DETECT_DOTS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.cc b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.cc new file mode 100644 index 000000000000..189f84cbb388 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.cc @@ -0,0 +1,81 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "lib/jxl/enc_dot_dictionary.h" + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <utility> + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_detect_dots.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Private implementation of Dictionary Encode/Decode +namespace { + +/* Quantization constants for Ellipse dots */ +const size_t kEllipsePosQ = 2; // Quantization level for the position +const double kEllipseMinSigma = 0.1; // Minimum sigma value +const double kEllipseMaxSigma = 3.1; // Maximum Sigma value +const size_t kEllipseSigmaQ = 16; // Number of quantization levels for sigma +const size_t kEllipseAngleQ = 8; // Quantization level for the angle +// TODO: fix these values. +const std::array<double, 3> kEllipseMinIntensity{-0.05, 0.0, -0.5}; +const std::array<double, 3> kEllipseMaxIntensity{0.05, 1.0, 0.4}; +const std::array<size_t, 3> kEllipseIntensityQ{10, 36, 10}; +} // namespace + +std::vector<PatchInfo> FindDotDictionary(const CompressParams& cparams, + const Image3F& opsin, + const ColorCorrelationMap& cmap, + ThreadPool* pool) { + if (ApplyOverride(cparams.dots, + cparams.butteraugli_distance >= kMinButteraugliForDots)) { + GaussianDetectParams ellipse_params; + ellipse_params.t_high = 0.04; + ellipse_params.t_low = 0.02; + ellipse_params.maxWinSize = 5; + ellipse_params.maxL2Loss = 0.005; + ellipse_params.maxCustomLoss = 300; + ellipse_params.minIntensity = 0.12; + ellipse_params.maxDistMeanMode = 1.0; + ellipse_params.maxNegPixels = 0; + ellipse_params.minScore = 12.0; + ellipse_params.maxCC = 100; + ellipse_params.percCC = 100; + EllipseQuantParams qParams{ + opsin.xsize(), opsin.ysize(), kEllipsePosQ, + kEllipseMinSigma, kEllipseMaxSigma, kEllipseSigmaQ, + kEllipseAngleQ, kEllipseMinIntensity, kEllipseMaxIntensity, + kEllipseIntensityQ, kEllipsePosQ <= 5, cmap.YtoXRatio(0), + cmap.YtoBRatio(0)}; + + return DetectGaussianEllipses(opsin, ellipse_params, qParams, pool); + } + return {}; +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.h b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.h new file mode 100644 index 000000000000..06171fdb23da --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.h @@ -0,0 +1,44 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_DOT_DICTIONARY_H_ +#define LIB_JXL_ENC_DOT_DICTIONARY_H_ + +// Dots are stored in a dictionary to avoid storing similar dots multiple +// times.
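A short caller-side sketch, not part of the patch; `PatchInfo` is the element type assumed from dec_patch_dictionary.h, and the encoder state (`cparams`, `opsin`, `cmap`, `pool`) is taken as given:

    // An empty result simply means "encode without a dot dictionary".
    std::vector<jxl::PatchInfo> dots =
        jxl::FindDotDictionary(cparams, opsin, cmap, pool);
    if (!dots.empty()) {
      // hand the detected dots to the patch dictionary encoder
    }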
+ +#include <stddef.h> + +#include <vector> + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" + +namespace jxl { + +std::vector<PatchInfo> FindDotDictionary(const CompressParams& cparams, + const Image3F& opsin, + const ColorCorrelationMap& cmap, + ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_ENC_DOT_DICTIONARY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.cc b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.cc new file mode 100644 index 000000000000..20168c6ed307 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.cc @@ -0,0 +1,277 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_entropy_coder.h" + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_entropy_coder.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// Returns number of non-zero coefficients (but skip LLF). +// We cannot rely on block[] being all-zero bits, so first truncate to integer. +// Also writes the per-8x8 block nzeros starting at nzeros_pos. +int32_t NumNonZeroExceptLLF(const size_t cx, const size_t cy, + const AcStrategy acs, const size_t covered_blocks, + const size_t log2_covered_blocks, + const int32_t* JXL_RESTRICT block, + const size_t nzeros_stride, + int32_t* JXL_RESTRICT nzeros_pos) { + const HWY_CAPPED(int32_t, kBlockDim) di; + + const auto zero = Zero(di); + // Add FF..FF for every zero coefficient, negate to get #zeros. + auto neg_sum_zero = zero; + + { + // Mask sufficient for one row of coefficients. + HWY_ALIGN const int32_t + llf_mask_lanes[AcStrategy::kMaxCoeffBlocks * (1 + kBlockDim)] = { + -1, -1, -1, -1}; + // First cx=1,2,4 elements are FF..FF, others 0. + const int32_t* llf_mask_pos = + llf_mask_lanes + AcStrategy::kMaxCoeffBlocks - cx; + + // Rows with LLF: mask out the LLF + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx * kBlockDim; x += Lanes(di)) { + const auto llf_mask = LoadU(di, llf_mask_pos + x); + + // LLF counts as zero so we don't include it in nzeros.
+ const auto coef = + AndNot(llf_mask, Load(di, &block[y * cx * kBlockDim + x])); + + neg_sum_zero += VecFromMask(coef == zero); + } + } + } + + // Remaining rows: no mask + for (size_t y = cy; y < cy * kBlockDim; y++) { + for (size_t x = 0; x < cx * kBlockDim; x += Lanes(di)) { + const auto coef = Load(di, &block[y * cx * kBlockDim + x]); + neg_sum_zero += VecFromMask(coef == zero); + } + } + + // We want area - sum_zero, add because neg_sum_zero is already negated. + const int32_t nzeros = + int32_t(cx * cy * kDCTBlockSize) + GetLane(SumOfLanes(neg_sum_zero)); + + const int32_t shifted_nzeros = static_cast<int32_t>( + (nzeros + covered_blocks - 1) >> log2_covered_blocks); + // Need non-canonicalized dimensions! + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + nzeros_pos[x + y * nzeros_stride] = shifted_nzeros; + } + } + + return nzeros; +} + +// Specialization for 8x8, where only top-left is LLF/DC. +// About 1% overall speedup vs. NumNonZeroExceptLLF. +int32_t NumNonZero8x8ExceptDC(const int32_t* JXL_RESTRICT block, + int32_t* JXL_RESTRICT nzeros_pos) { + const HWY_CAPPED(int32_t, kBlockDim) di; + + const auto zero = Zero(di); + // Add FF..FF for every zero coefficient, negate to get #zeros. + auto neg_sum_zero = zero; + + { + // First row has DC, so mask + const size_t y = 0; + HWY_ALIGN const int32_t dc_mask_lanes[kBlockDim] = {-1}; + + for (size_t x = 0; x < kBlockDim; x += Lanes(di)) { + const auto dc_mask = Load(di, dc_mask_lanes + x); + + // DC counts as zero so we don't include it in nzeros. + const auto coef = AndNot(dc_mask, Load(di, &block[y * kBlockDim + x])); + + neg_sum_zero += VecFromMask(coef == zero); + } + } + + // Remaining rows: no mask + for (size_t y = 1; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x += Lanes(di)) { + const auto coef = Load(di, &block[y * kBlockDim + x]); + neg_sum_zero += VecFromMask(coef == zero); + } + } + + // We want 64 - sum_zero, add because neg_sum_zero is already negated. + const int32_t nzeros = + int32_t(kDCTBlockSize) + GetLane(SumOfLanes(neg_sum_zero)); + + *nzeros_pos = nzeros; + + return nzeros; +} + +// The number of nonzeros of each block is predicted from the top and the left +// blocks, with opportune scaling to take into account the number of blocks of +// each strategy. The predicted number of nonzeros divided by two is used as a +// context; if this number is above 63, a specific context is used. If the +// number of nonzeros of a strategy is above 63, it is written directly using a +// fixed number of bits (that depends on the size of the strategy). +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector<Token>* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map) { + const size_t xsize_blocks = rect.xsize(); + const size_t ysize_blocks = rect.ysize(); + + // TODO(user): update the estimate: usually fewer coefficients are used.
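// Worked example of the bound being reserved below: a 256x256 frame rect
// has 32x32 8x8 blocks, so output->reserve grows by 3 * 32 * 32 * 64 =
// 196608 tokens, one per coefficient per channel, before any early outs.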
+ output->reserve(output->size() + + 3 * xsize_blocks * ysize_blocks * kDCTBlockSize); + + size_t offset[3] = {}; + const size_t nzeros_stride = tmp_num_nzeroes->PixelsPerRow(); + for (size_t by = 0; by < ysize_blocks; ++by) { + size_t sby[3] = {by >> cs.VShift(0), by >> cs.VShift(1), + by >> cs.VShift(2)}; + int32_t* JXL_RESTRICT row_nzeros[3] = { + tmp_num_nzeroes->PlaneRow(0, sby[0]), + tmp_num_nzeroes->PlaneRow(1, sby[1]), + tmp_num_nzeroes->PlaneRow(2, sby[2]), + }; + const int32_t* JXL_RESTRICT row_nzeros_top[3] = { + sby[0] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(0, sby[0] - 1), + sby[1] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(1, sby[1] - 1), + sby[2] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(2, sby[2] - 1), + }; + const uint8_t* JXL_RESTRICT row_qdc = + qdc.ConstRow(rect.y0() + by) + rect.x0(); + const int32_t* JXL_RESTRICT row_qf = rect.ConstRow(qf, by); + AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by); + for (size_t bx = 0; bx < xsize_blocks; ++bx) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + size_t sbx[3] = {bx >> cs.HShift(0), bx >> cs.HShift(1), + bx >> cs.HShift(2)}; + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + const size_t covered_blocks = cx * cy; // = #LLF coefficients + const size_t log2_covered_blocks = + Num0BitsBelowLS1Bit_Nonzero(covered_blocks); + const size_t size = covered_blocks * kDCTBlockSize; + + CoefficientLayout(&cy, &cx); // swap cx/cy to canonical order + + for (int c : {1, 0, 2}) { + if (sbx[c] << cs.HShift(c) != bx) continue; + if (sby[c] << cs.VShift(c) != by) continue; + const int32_t* JXL_RESTRICT block = ac_rows[c] + offset[c]; + + int32_t nzeros = + (covered_blocks == 1) + ? NumNonZero8x8ExceptDC(block, row_nzeros[c] + sbx[c]) + : NumNonZeroExceptLLF(cx, cy, acs, covered_blocks, + log2_covered_blocks, block, nzeros_stride, + row_nzeros[c] + sbx[c]); + + int ord = kStrategyOrder[acs.RawStrategy()]; + const coeff_order_t* JXL_RESTRICT order = + &orders[CoeffOrderOffset(ord, c)]; + + int32_t predicted_nzeros = + PredictFromTopAndLeft(row_nzeros_top[c], row_nzeros[c], sbx[c], 32); + size_t block_ctx = + block_ctx_map.Context(row_qdc[bx], row_qf[sbx[c]], ord, c); + const int32_t nzero_ctx = + block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx); + + output->emplace_back(nzero_ctx, nzeros); + const size_t histo_offset = + block_ctx_map.ZeroDensityContextsOffset(block_ctx); + // Skip LLF. + size_t prev = (nzeros > static_cast<int32_t>(size / 16) ?
0 : 1); + for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) { + int32_t coeff = block[order[k]]; + size_t ctx = + histo_offset + ZeroDensityContext(nzeros, k, covered_blocks, + log2_covered_blocks, prev); + uint32_t u_coeff = PackSigned(coeff); + output->emplace_back(ctx, u_coeff); + prev = coeff != 0; + nzeros -= prev; + } + JXL_DASSERT(nzeros == 0); + offset[c] += size; + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(TokenizeCoefficients); +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector<Token>* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map) { + return HWY_DYNAMIC_DISPATCH(TokenizeCoefficients)( + orders, rect, ac_rows, ac_strategy, cs, tmp_num_nzeroes, output, qdc, qf, + block_ctx_map); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.h b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.h new file mode 100644 index 000000000000..cf546e022d01 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.h @@ -0,0 +1,55 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_ENTROPY_CODER_H_ +#define LIB_JXL_ENC_ENTROPY_CODER_H_ + +#include <stddef.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> +#include <sys/types.h> + +#include <memory> +#include <utility> +#include <vector> + +#include "lib/jxl/ac_context.h" // BlockCtxMap +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/frame_header.h" // YCbCrChromaSubsampling +#include "lib/jxl/image.h" + +// Entropy coding and context modeling of DC and AC coefficients, as well as AC +// strategy and quantization field. + +namespace jxl { + +// Generate DCT NxN quantized AC values tokens. +// Only the subset "rect" [in units of blocks] within all images. +// See also DecodeACVarBlock.
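// Context sketch (see the long comment in enc_entropy_coder.cc): each
// block's nzeros is predicted from the top/left neighbors, roughly their
// average; e.g. top = 10 and left = 14 predict 12, i.e. nonzero-count
// context 12 / 2 = 6 within the current block context.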
+void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders,
+                          const Rect& rect,
+                          const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows,
+                          const AcStrategyImage& ac_strategy,
+                          YCbCrChromaSubsampling cs,
+                          Image3I* JXL_RESTRICT tmp_num_nzeroes,
+                          std::vector<Token>* JXL_RESTRICT output,
+                          const ImageB& qdc, const ImageI& qf,
+                          const BlockCtxMap& block_ctx_map);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_ENTROPY_CODER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image.cc b/third_party/jpeg-xl/lib/jxl/enc_external_image.cc
new file mode 100644
index 000000000000..48706652be87
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_external_image.cc
@@ -0,0 +1,292 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_external_image.h"
+
+#include <string.h>
+
+#include <algorithm>
+#include <array>
+#include <functional>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+
+namespace jxl {
+namespace {
+
+// Loads a float in big endian
+float LoadBEFloat(const uint8_t* p) {
+  float value;
+  const uint32_t u = LoadBE32(p);
+  memcpy(&value, &u, 4);
+  return value;
+}
+
+// Loads a float in little endian
+float LoadLEFloat(const uint8_t* p) {
+  float value;
+  const uint32_t u = LoadLE32(p);
+  memcpy(&value, &u, 4);
+  return value;
+}
+
+typedef uint32_t(LoadFuncType)(const uint8_t* p);
+template <LoadFuncType LoadFunc>
+void JXL_INLINE LoadFloatRow(float* JXL_RESTRICT row_out, const uint8_t* in,
+                             float mul, size_t xsize, size_t bytes_per_pixel) {
+  size_t i = 0;
+  for (size_t x = 0; x < xsize; ++x) {
+    row_out[x] = mul * LoadFunc(in + i);
+    i += bytes_per_pixel;
+  }
+}
+
+uint32_t JXL_INLINE Load8(const uint8_t* p) { return *p; }
+
+}  // namespace
+
+Status ConvertFromExternal(Span<const uint8_t> bytes, size_t xsize,
+                           size_t ysize, const ColorEncoding& c_current,
+                           bool has_alpha, bool alpha_is_premultiplied,
+                           size_t bits_per_sample, JxlEndianness endianness,
+                           bool flipped_y, ThreadPool* pool, ImageBundle* ib) {
+  if (bits_per_sample < 1 || bits_per_sample > 32) {
+    return JXL_FAILURE("Invalid bits_per_sample value.");
+  }
+  // TODO(deymo): Implement 1-bit per sample as 8 samples per byte. In
+  // any other case we use DivCeil(bits_per_sample, 8) bytes per pixel per
+  // channel.
+  if (bits_per_sample == 1) {
+    return JXL_FAILURE("packed 1-bit per sample is not yet supported");
+  }
+
+  const size_t color_channels = c_current.Channels();
+  const size_t channels = color_channels + has_alpha;
+
+  // bytes_per_channel and bytes_per_pixel are only valid for
+  // bits_per_sample > 1.
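+  // For instance, 16-bit RGBA gives bytes_per_channel = 2,
+  // bytes_per_pixel = 4 * 2 = 8, and hence row_size = 8 * xsize below.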
+  const size_t bytes_per_channel = DivCeil(bits_per_sample, jxl::kBitsPerByte);
+  const size_t bytes_per_pixel = channels * bytes_per_channel;
+
+  const size_t row_size = xsize * bytes_per_pixel;
+  if (ysize && bytes.size() / ysize < row_size) {
+    return JXL_FAILURE("Buffer size is too small");
+  }
+
+  const bool little_endian =
+      endianness == JXL_LITTLE_ENDIAN ||
+      (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian());
+
+  const uint8_t* const in = bytes.data();
+
+  Image3F color(xsize, ysize);
+  ImageF alpha;
+  if (has_alpha) {
+    alpha = ImageF(xsize, ysize);
+  }
+
+  // Matches the old behavior of PackedImage.
+  // TODO(sboukortt): make this a parameter.
+  const bool float_in = bits_per_sample == 32;
+
+  const auto get_y = [flipped_y, ysize](const size_t y) {
+    return flipped_y ? ysize - 1 - y : y;
+  };
+
+  if (float_in) {
+    if (bits_per_sample != 32) {
+      return JXL_FAILURE("non-32-bit float not supported");
+    }
+    for (size_t c = 0; c < color_channels; ++c) {
+      RunOnPool(
+          pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+          [&](const int task, int /*thread*/) {
+            const size_t y = get_y(task);
+            size_t i =
+                row_size * task + (c * bits_per_sample / jxl::kBitsPerByte);
+            float* JXL_RESTRICT row_out = color.PlaneRow(c, y);
+            if (little_endian) {
+              for (size_t x = 0; x < xsize; ++x) {
+                row_out[x] = LoadLEFloat(in + i);
+                i += bytes_per_pixel;
+              }
+            } else {
+              for (size_t x = 0; x < xsize; ++x) {
+                row_out[x] = LoadBEFloat(in + i);
+                i += bytes_per_pixel;
+              }
+            }
+          },
+          "ConvertRGBFloat");
+    }
+  } else {
+    // Multiplier to convert from the integer range to the floating point
+    // 0-1 range.
+    float mul = 1. / ((1ull << bits_per_sample) - 1);
+    for (size_t c = 0; c < color_channels; ++c) {
+      RunOnPool(
+          pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+          [&](const int task, int /*thread*/) {
+            const size_t y = get_y(task);
+            size_t i = row_size * task + c * bytes_per_channel;
+            float* JXL_RESTRICT row_out = color.PlaneRow(c, y);
+            // TODO(deymo): add bits_per_sample == 1 case here. Also maybe
+            // implement masking if bits_per_sample is not a multiple of 8.
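+            // The sample loader below is chosen from bits_per_sample and
+            // endianness; each variant reads bytes_per_channel bytes per
+            // sample and mul maps the result into the [0, 1] float range.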
+            if (bits_per_sample <= 8) {
+              LoadFloatRow<Load8>(row_out, in + i, mul, xsize,
+                                  bytes_per_pixel);
+            } else if (bits_per_sample <= 16) {
+              if (little_endian) {
+                LoadFloatRow<LoadLE16>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              } else {
+                LoadFloatRow<LoadBE16>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              }
+            } else if (bits_per_sample <= 24) {
+              if (little_endian) {
+                LoadFloatRow<LoadLE24>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              } else {
+                LoadFloatRow<LoadBE24>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              }
+            } else {
+              if (little_endian) {
+                LoadFloatRow<LoadLE32>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              } else {
+                LoadFloatRow<LoadBE32>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              }
+            }
+          },
+          "ConvertRGBUint");
+    }
+  }
+
+  if (color_channels == 1) {
+    CopyImageTo(color.Plane(0), &color.Plane(1));
+    CopyImageTo(color.Plane(0), &color.Plane(2));
+  }
+
+  ib->SetFromImage(std::move(color), c_current);
+
+  if (has_alpha) {
+    if (float_in) {
+      if (bits_per_sample != 32) {
+        return JXL_FAILURE("non-32-bit float not supported");
+      }
+      RunOnPool(
+          pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+          [&](const int task, int /*thread*/) {
+            const size_t y = get_y(task);
+            size_t i = row_size * task +
+                       (color_channels * bits_per_sample / jxl::kBitsPerByte);
+            float* JXL_RESTRICT row_out = alpha.Row(y);
+            if (little_endian) {
+              for (size_t x = 0; x < xsize; ++x) {
+                row_out[x] = LoadLEFloat(in + i);
+                i += bytes_per_pixel;
+              }
+            } else {
+              for (size_t x = 0; x < xsize; ++x) {
+                row_out[x] = LoadBEFloat(in + i);
+                i += bytes_per_pixel;
+              }
+            }
+          },
+          "ConvertAlphaFloat");
+    } else {
+      float mul = 1. / ((1ull << bits_per_sample) - 1);
+      RunOnPool(
+          pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
+          [&](const int task, int /*thread*/) {
+            const size_t y = get_y(task);
+            size_t i = row_size * task + color_channels * bytes_per_channel;
+            float* JXL_RESTRICT row_out = alpha.Row(y);
+            // TODO(deymo): add bits_per_sample == 1 case here. Also maybe
+            // implement masking if bits_per_sample is not a multiple of 8.
+            if (bits_per_sample <= 8) {
+              LoadFloatRow<Load8>(row_out, in + i, mul, xsize,
+                                  bytes_per_pixel);
+            } else if (bits_per_sample <= 16) {
+              if (little_endian) {
+                LoadFloatRow<LoadLE16>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              } else {
+                LoadFloatRow<LoadBE16>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              }
+            } else if (bits_per_sample <= 24) {
+              if (little_endian) {
+                LoadFloatRow<LoadLE24>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              } else {
+                LoadFloatRow<LoadBE24>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              }
+            } else {
+              if (little_endian) {
+                LoadFloatRow<LoadLE32>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              } else {
+                LoadFloatRow<LoadBE32>(row_out, in + i, mul, xsize,
+                                       bytes_per_pixel);
+              }
+            }
+          },
+          "ConvertAlphaUint");
+    }
+
+    ib->SetAlpha(std::move(alpha), alpha_is_premultiplied);
+  }
+
+  return true;
+}
+
+Status BufferToImageBundle(const JxlPixelFormat& pixel_format, uint32_t xsize,
+                           uint32_t ysize, const void* buffer, size_t size,
+                           jxl::ThreadPool* pool,
+                           const jxl::ColorEncoding& c_current,
+                           jxl::ImageBundle* ib) {
+  size_t bitdepth;
+
+  // TODO(zond): Make this accept more than float and uint8/16.
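+  // Only three pixel formats are accepted here: JXL_TYPE_FLOAT implies
+  // 32 bits per sample, JXL_TYPE_UINT8 implies 8 and JXL_TYPE_UINT16
+  // implies 16; everything else fails below.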
+  if (pixel_format.data_type == JXL_TYPE_FLOAT) {
+    bitdepth = 32;
+  } else if (pixel_format.data_type == JXL_TYPE_UINT8) {
+    bitdepth = 8;
+  } else if (pixel_format.data_type == JXL_TYPE_UINT16) {
+    bitdepth = 16;
+  } else {
+    return JXL_FAILURE("unsupported bitdepth");
+  }
+
+  JXL_RETURN_IF_ERROR(ConvertFromExternal(
+      jxl::Span<const uint8_t>(
+          static_cast<uint8_t*>(const_cast<void*>(buffer)), size),
+      xsize, ysize, c_current,
+      /*has_alpha=*/pixel_format.num_channels == 2 ||
+          pixel_format.num_channels == 4,
+      /*alpha_is_premultiplied=*/false, bitdepth, pixel_format.endianness,
+      /*flipped_y=*/false, pool, ib));
+  ib->VerifyMetadata();
+
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image.h b/third_party/jpeg-xl/lib/jxl/enc_external_image.h
new file mode 100644
index 000000000000..6ea321412cee
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_external_image.h
@@ -0,0 +1,59 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_EXTERNAL_IMAGE_H_
+#define LIB_JXL_ENC_EXTERNAL_IMAGE_H_
+
+// Interleaved image for color transforms and Codec.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "jxl/types.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+
+// Return the size in bytes of a given xsize, channels and bits_per_sample
+// interleaved image.
+constexpr size_t RowSize(size_t xsize, size_t channels,
+                         size_t bits_per_sample) {
+  return bits_per_sample == 1
+             ? DivCeil(xsize, kBitsPerByte)
+             : xsize * channels * DivCeil(bits_per_sample, kBitsPerByte);
+}
+
+// Convert an interleaved pixel buffer to the internal ImageBundle
+// representation. This is the opposite of ConvertToExternal().
+Status ConvertFromExternal(Span<const uint8_t> bytes, size_t xsize,
+                           size_t ysize, const ColorEncoding& c_current,
+                           bool has_alpha, bool alpha_is_premultiplied,
+                           size_t bits_per_sample, JxlEndianness endianness,
+                           bool flipped_y, ThreadPool* pool, ImageBundle* ib);
+
+Status BufferToImageBundle(const JxlPixelFormat& pixel_format, uint32_t xsize,
+                           uint32_t ysize, const void* buffer, size_t size,
+                           jxl::ThreadPool* pool,
+                           const jxl::ColorEncoding& c_current,
+                           jxl::ImageBundle* ib);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_EXTERNAL_IMAGE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image_gbench.cc b/third_party/jpeg-xl/lib/jxl/enc_external_image_gbench.cc
new file mode 100644
index 000000000000..72c8c77588d7
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_external_image_gbench.cc
@@ -0,0 +1,58 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "benchmark/benchmark.h"
+#include "lib/jxl/enc_external_image.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+namespace {
+
+// Encoder case, deinterleaves a buffer.
+void BM_EncExternalImage_ConvertImageRGBA(benchmark::State& state) {
+  const size_t kNumIter = 5;
+  size_t xsize = state.range();
+  size_t ysize = state.range();
+
+  ImageMetadata im;
+  im.SetAlphaBits(8);
+  ImageBundle ib(&im);
+
+  std::vector<uint8_t> interleaved(xsize * ysize * 4);
+
+  for (auto _ : state) {
+    for (size_t i = 0; i < kNumIter; ++i) {
+      JXL_CHECK(ConvertFromExternal(
+          Span<const uint8_t>(interleaved.data(), interleaved.size()), xsize,
+          ysize,
+          /*c_current=*/ColorEncoding::SRGB(),
+          /*has_alpha=*/true,
+          /*alpha_is_premultiplied=*/false,
+          /*bits_per_sample=*/8, JXL_NATIVE_ENDIAN,
+          /*flipped_y=*/false,
+          /*pool=*/nullptr, &ib));
+    }
+  }
+
+  // Pixels per second.
+  state.SetItemsProcessed(kNumIter * state.iterations() * xsize * ysize);
+  state.SetBytesProcessed(kNumIter * state.iterations() * interleaved.size());
+}
+
+BENCHMARK(BM_EncExternalImage_ConvertImageRGBA)
+    ->RangeMultiplier(2)
+    ->Range(256, 2048);
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image_test.cc b/third_party/jpeg-xl/lib/jxl/enc_external_image_test.cc
new file mode 100644
index 000000000000..13c0debb1cc3
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_external_image_test.cc
@@ -0,0 +1,58 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_external_image.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/image_test_utils.h"
+
+namespace jxl {
+namespace {
+
+#if !defined(JXL_CRASH_ON_ERROR)
+TEST(ExternalImageTest, InvalidSize) {
+  ImageMetadata im;
+  im.SetAlphaBits(8);
+  ImageBundle ib(&im);
+
+  const uint8_t buf[10 * 100 * 8] = {};
+  EXPECT_FALSE(ConvertFromExternal(
+      Span<const uint8_t>(buf, 10), /*xsize=*/10, /*ysize=*/100,
+      /*c_current=*/ColorEncoding::SRGB(), /*has_alpha=*/true,
+      /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, JXL_BIG_ENDIAN,
+      /*flipped_y=*/false, nullptr, &ib));
+  EXPECT_FALSE(ConvertFromExternal(
+      Span<const uint8_t>(buf, sizeof(buf) - 1), /*xsize=*/10, /*ysize=*/100,
+      /*c_current=*/ColorEncoding::SRGB(), /*has_alpha=*/true,
+      /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, JXL_BIG_ENDIAN,
+      /*flipped_y=*/false, nullptr, &ib));
+  EXPECT_TRUE(ConvertFromExternal(
+      Span<const uint8_t>(buf, sizeof(buf)), /*xsize=*/10,
+      /*ysize=*/100, /*c_current=*/ColorEncoding::SRGB(),
+      /*has_alpha=*/true, /*alpha_is_premultiplied=*/false,
+      /*bits_per_sample=*/16, JXL_BIG_ENDIAN,
+      /*flipped_y=*/false, nullptr, &ib));
+}
+#endif
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_fast_heuristics.cc b/third_party/jpeg-xl/lib/jxl/enc_fast_heuristics.cc
new file mode 100644
index 000000000000..517947fb3e7b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_fast_heuristics.cc
@@ -0,0 +1,370 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/enc_ac_strategy.h"
+#include "lib/jxl/enc_adaptive_quantization.h"
+#include "lib/jxl/enc_ar_control_field.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_heuristics.h"
+#include "lib/jxl/enc_noise.h"
+#include "lib/jxl/gaborish.h"
+#include "lib/jxl/gauss_blur.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_fast_heuristics.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+using DF4 = HWY_CAPPED(float, 4);
+DF4 df4;
+HWY_FULL(float) df;
+
+Status Heuristics(PassesEncoderState* enc_state,
+                  ModularFrameEncoder* modular_frame_encoder,
+                  const ImageBundle* linear, Image3F* opsin, ThreadPool* pool,
+                  AuxOut* aux_out) {
+  PROFILER_ZONE("JxlLossyFrameHeuristics uninstrumented");
+  CompressParams& cparams = enc_state->cparams;
+  PassesSharedState& shared = enc_state->shared;
+  const FrameDimensions& frame_dim = enc_state->shared.frame_dim;
+  JXL_CHECK(cparams.butteraugli_distance > 0);
+
+  // TODO(veluca): make this tiled.
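+  // Gaborish will be applied by the decoder as part of the loop filter, so
+  // the encoder starts from an (approximately) inverse-filtered image.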
+  if (shared.frame_header.loop_filter.gab) {
+    GaborishInverse(opsin, 0.9908511000000001f, pool);
+  }
+  // Compute image of high frequencies by removing a blurred version.
+  // TODO(veluca): certainly can be made faster, and use less memory...
+  constexpr size_t pad = 16;
+  Image3F padded = PadImageMirror(*opsin, pad, pad);
+  // Make the image (X, Y, B-Y)
+  // TODO(veluca): SubtractFrom is not parallel *and* not SIMD-fied.
+  SubtractFrom(padded.Plane(1), &padded.Plane(2));
+  // Ensure that OOB access for CfL does nothing. Not necessary if doing
+  // things properly...
+  Image3F hf(padded.xsize() + 64, padded.ysize());
+  ZeroFillImage(&hf);
+  hf.ShrinkTo(padded.xsize(), padded.ysize());
+  ImageF temp(padded.xsize(), padded.ysize());
+  // TODO(veluca): consider some faster blurring method.
+  auto g = CreateRecursiveGaussian(11.415258091746161);
+  for (size_t c = 0; c < 3; c++) {
+    FastGaussian(g, padded.Plane(c), pool, &temp, &hf.Plane(c));
+    SubtractFrom(padded.Plane(c), &hf.Plane(c));
+  }
+  // TODO(veluca): DC CfL?
+  size_t xcolortiles = DivCeil(frame_dim.xsize_blocks, kColorTileDimInBlocks);
+  size_t ycolortiles = DivCeil(frame_dim.ysize_blocks, kColorTileDimInBlocks);
+  RunOnPool(
+      pool, 0, xcolortiles * ycolortiles, ThreadPool::SkipInit(),
+      [&](size_t tile_id, size_t _) {
+        size_t tx = tile_id % xcolortiles;
+        size_t ty = tile_id / xcolortiles;
+        size_t x0 = tx * kColorTileDim;
+        size_t x1 = std::min(x0 + kColorTileDim, hf.xsize());
+        size_t y0 = ty * kColorTileDim;
+        size_t y1 = std::min(y0 + kColorTileDim, hf.ysize());
+        for (size_t c : {0, 2}) {
+          static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor;
+          auto ca = Zero(df);
+          auto cb = Zero(df);
+          const auto inv_color_factor = Set(df, kInvColorFactor);
+          for (size_t y = y0; y < y1; y++) {
+            const float* row_m = hf.PlaneRow(1, y);
+            const float* row_s = hf.PlaneRow(c, y);
+            for (size_t x = x0; x < x1; x += Lanes(df)) {
+              // color residual = ax + b
+              const auto a = inv_color_factor * Load(df, row_m + x);
+              const auto b = Zero(df) - Load(df, row_s + x);
+              ca = MulAdd(a, a, ca);
+              cb = MulAdd(a, b, cb);
+            }
+          }
+          float best =
+              -GetLane(SumOfLanes(cb)) / (GetLane(SumOfLanes(ca)) + 1e-9f);
+          int8_t& res = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map)
+                            .Row(ty)[tx];
+          res = std::max(-128.0f, std::min(127.0f, roundf(best)));
+        }
+      },
+      "CfL");
+  Image3F pooled(frame_dim.xsize_padded / 4, frame_dim.ysize_padded / 4);
+  Image3F summed(frame_dim.xsize_padded / 4, frame_dim.ysize_padded / 4);
+  RunOnPool(
+      pool, 0, frame_dim.ysize_padded / 4, ThreadPool::SkipInit(),
+      [&](size_t y, size_t _) {
+        for (size_t c = 0; c < 3; c++) {
+          float* JXL_RESTRICT row_out = pooled.PlaneRow(c, y);
+          float* JXL_RESTRICT row_out_avg = summed.PlaneRow(c, y);
+          const float* JXL_RESTRICT row_in[4];
+          for (size_t iy = 0; iy < 4; iy++) {
+            row_in[iy] = hf.PlaneRow(c, 4 * y + pad + iy);
+          }
+          for (size_t x = 0; x < frame_dim.xsize_padded / 4; x++) {
+            auto max = Zero(df4);
+            auto sum = Zero(df4);
+            for (size_t iy = 0; iy < 4; iy++) {
+              for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
+                const auto nn = Abs(Load(df4, row_in[iy] + x * 4 + ix + pad));
+                sum += nn;
+                max = IfThenElse(max > nn, max, nn);
+              }
+            }
+            row_out_avg[x] = GetLane(SumOfLanes(sum));
+            row_out[x] = GetLane(MaxOfLanes(max));
+          }
+        }
+      },
+      "MaxPool");
+  // TODO(veluca): better handling of the border
+  // TODO(veluca): consider some faster blurring method.
+  // TODO(veluca): parallelize.
+  // Remove noise from the resulting image.
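+  // The "noise removal" below blurs the summed/pooled maps and keeps, per
+  // pixel, the maximum of the blurred value and half of the unblurred one.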
+ auto g2 = CreateRecursiveGaussian(2.0849544429861884); + constexpr size_t pad2 = 16; + Image3F summed_pad = PadImageMirror(summed, pad2, pad2); + ImageF tmp_out(summed_pad.xsize(), summed_pad.ysize()); + ImageF tmp2(summed_pad.xsize(), summed_pad.ysize()); + Image3F pooled_pad = PadImageMirror(pooled, pad2, pad2); + for (size_t c = 0; c < 3; c++) { + FastGaussian(g2, summed_pad.Plane(c), pool, &tmp2, &tmp_out); + const auto unblurred_multiplier = Set(df, 0.5f); + for (size_t y = 0; y < summed.ysize(); y++) { + float* row = summed.PlaneRow(c, y); + const float* row_blur = tmp_out.Row(y + pad2); + for (size_t x = 0; x < summed.xsize(); x += Lanes(df)) { + const auto b = Load(df, row_blur + x + pad2); + const auto o = Load(df, row + x) * unblurred_multiplier; + const auto m = IfThenElse(b > o, b, o); + Store(m, df, row + x); + } + } + } + for (size_t c = 0; c < 3; c++) { + FastGaussian(g2, pooled_pad.Plane(c), pool, &tmp2, &tmp_out); + const auto unblurred_multiplier = Set(df, 0.5f); + for (size_t y = 0; y < pooled.ysize(); y++) { + float* row = pooled.PlaneRow(c, y); + const float* row_blur = tmp_out.Row(y + pad2); + for (size_t x = 0; x < pooled.xsize(); x += Lanes(df)) { + const auto b = Load(df, row_blur + x + pad2); + const auto o = Load(df, row + x) * unblurred_multiplier; + const auto m = IfThenElse(b > o, b, o); + Store(m, df, row + x); + } + } + } + const static float kChannelMul[3] = { + 7.9644294909680253f, + 0.5700000183257159f, + 0.20267448837597055f, + }; + ImageF pooledhf44(pooled.xsize(), pooled.ysize()); + for (size_t y = 0; y < pooled.ysize(); y++) { + const float* row_in_x = pooled.ConstPlaneRow(0, y); + const float* row_in_y = pooled.ConstPlaneRow(1, y); + const float* row_in_b = pooled.ConstPlaneRow(2, y); + float* row_out = pooledhf44.Row(y); + for (size_t x = 0; x < pooled.xsize(); x += Lanes(df)) { + auto v = Set(df, kChannelMul[0]) * Load(df, row_in_x + x); + v = MulAdd(Set(df, kChannelMul[1]), Load(df, row_in_y + x), v); + v = MulAdd(Set(df, kChannelMul[2]), Load(df, row_in_b + x), v); + Store(v, df, row_out + x); + } + } + ImageF summedhf44(summed.xsize(), summed.ysize()); + for (size_t y = 0; y < summed.ysize(); y++) { + const float* row_in_x = summed.ConstPlaneRow(0, y); + const float* row_in_y = summed.ConstPlaneRow(1, y); + const float* row_in_b = summed.ConstPlaneRow(2, y); + float* row_out = summedhf44.Row(y); + for (size_t x = 0; x < summed.xsize(); x += Lanes(df)) { + auto v = Set(df, kChannelMul[0]) * Load(df, row_in_x + x); + v = MulAdd(Set(df, kChannelMul[1]), Load(df, row_in_y + x), v); + v = MulAdd(Set(df, kChannelMul[2]), Load(df, row_in_b + x), v); + Store(v, df, row_out + x); + } + } + aux_out->DumpPlaneNormalized("pooledhf44", pooledhf44); + aux_out->DumpPlaneNormalized("summedhf44", summedhf44); + + static const float kDcQuantMul = 0.88170190420916206; + static const float kAcQuantMul = 2.5165738934721524; + + float dc_quant = kDcQuantMul * InitialQuantDC(cparams.butteraugli_distance); + float ac_quant_base = kAcQuantMul / cparams.butteraugli_distance; + ImageF quant_field(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + + static_assert(kColorTileDim == 64, "Fix the code below"); + auto mmacs = [&](size_t bx, size_t by, AcStrategy acs, float& min, + float& max) { + min = 1e10; + max = 0; + for (size_t y = 2 * by; y < 2 * (by + acs.covered_blocks_y()); y++) { + const float* row = summedhf44.Row(y); + for (size_t x = 2 * bx; x < 2 * (bx + acs.covered_blocks_x()); x++) { + min = std::min(min, row[x]); + max = std::max(max, row[x]); + } + } + }; + 
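+  // mmacs computes the min/max of summedhf44 over the area an AC strategy
+  // candidate would cover; summedhf44 is at 1/4 resolution, i.e. two entries
+  // per 8x8 block in each direction.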
// Multipliers for allowed range of summedhf44.
+  std::pair<AcStrategy::Type, float> candidates[] = {
+      // The order is such that, in case of ties, 8x8 is favoured over 4x4
+      // which is favoured over 2x2. Similarly, we prefer square transforms
+      // over same-area rectangular ones.
+      {AcStrategy::Type::DCT2X2, 1.5f},
+      {AcStrategy::Type::DCT4X4, 1.4f},
+      {AcStrategy::Type::DCT4X8, 1.2f},
+      {AcStrategy::Type::DCT8X4, 1.2f},
+      {AcStrategy::Type::AFV0,
+       1.15f},  // doesn't really work with these heuristics
+      {AcStrategy::Type::AFV1, 1.15f},
+      {AcStrategy::Type::AFV2, 1.15f},
+      {AcStrategy::Type::AFV3, 1.15f},
+      {AcStrategy::Type::DCT, 1.0f},
+      {AcStrategy::Type::DCT16X8, 0.8f},
+      {AcStrategy::Type::DCT8X16, 0.8f},
+      {AcStrategy::Type::DCT16X16, 0.2f},
+      {AcStrategy::Type::DCT16X32, 0.2f},
+      {AcStrategy::Type::DCT32X16, 0.2f},
+      {AcStrategy::Type::DCT32X32, 0.2f},
+      {AcStrategy::Type::DCT32X64, 0.1f},
+      {AcStrategy::Type::DCT64X32, 0.1f},
+      {AcStrategy::Type::DCT64X64, 0.04f},
+
+#if 0
+      {AcStrategy::Type::DCT2X2, 1e+10}, {AcStrategy::Type::DCT4X4, 2.0f},
+      {AcStrategy::Type::DCT, 1.0f}, {AcStrategy::Type::DCT16X8, 1.0f},
+      {AcStrategy::Type::DCT8X16, 1.0f}, {AcStrategy::Type::DCT32X8, 1.0f},
+      {AcStrategy::Type::DCT8X32, 1.0f}, {AcStrategy::Type::DCT32X16, 1.0f},
+      {AcStrategy::Type::DCT16X32, 1.0f}, {AcStrategy::Type::DCT64X32, 1.0f},
+      {AcStrategy::Type::DCT32X64, 1.0f}, {AcStrategy::Type::DCT16X16, 1.0f},
+      {AcStrategy::Type::DCT32X32, 1.0f}, {AcStrategy::Type::DCT64X64, 1.0f},
+#endif
+      // TODO(veluca): figure out if we want 4x8 and/or AFV.
+  };
+  float max_range =
+      1e-8f + 0.5f * std::pow(cparams.butteraugli_distance, 0.5f);
+  // Change quant field and sharpness amounts based on (pooled|summed)hf44,
+  // and compute block sizes.
+  // TODO(veluca): maybe this could be done per group: it would allow choosing
+  // floating blocks better.
+  RunOnPool(
+      pool, 0, xcolortiles * ycolortiles, ThreadPool::SkipInit(),
+      [&](size_t tile_id, size_t _) {
+        size_t tx = tile_id % xcolortiles;
+        size_t ty = tile_id / xcolortiles;
+        size_t x0 = tx * kColorTileDim / kBlockDim;
+        size_t x1 = std::min(x0 + kColorTileDimInBlocks, quant_field.xsize());
+        size_t y0 = ty * kColorTileDim / kBlockDim;
+        size_t y1 = std::min(y0 + kColorTileDimInBlocks, quant_field.ysize());
+        size_t qf_stride = quant_field.PixelsPerRow();
+        size_t epf_stride = shared.epf_sharpness.PixelsPerRow();
+        bool chosen_mask[64] = {};
+        for (size_t y = y0; y < y1; y++) {
+          uint8_t* epf_row = shared.epf_sharpness.Row(y);
+          float* qf_row = quant_field.Row(y);
+          for (size_t x = x0; x < x1; x++) {
+            if (chosen_mask[(y - y0) * 8 + (x - x0)]) continue;
+            // Default to DCT8 just in case something funny happens in the
+            // loop below.
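+            // A candidate transform is accepted only if it fits inside the
+            // tile, covers no already-chosen block, and its mmacs range stays
+            // below max_range times its multiplier; among accepted candidates
+            // the one covering the most blocks wins.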
+ AcStrategy::Type best = AcStrategy::DCT; + size_t best_covered = 1; + float qf = ac_quant_base; + for (size_t i = 0; i < sizeof(candidates) / sizeof(*candidates); + i++) { + AcStrategy acs = AcStrategy::FromRawStrategy(candidates[i].first); + if (y + acs.covered_blocks_y() > y1) continue; + if (x + acs.covered_blocks_x() > x1) continue; + bool fits = true; + for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) { + for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) { + if (chosen_mask[(iy - y0) * 8 + (ix - x0)]) { + fits = false; + break; + } + } + } + if (!fits) continue; + float min, max; + mmacs(x, y, acs, min, max); + if (max - min > max_range * candidates[i].second) continue; + size_t cb = acs.covered_blocks_x() * acs.covered_blocks_y(); + if (cb >= best_covered) { + best_covered = cb; + best = candidates[i].first; + // TODO(veluca): make this better. + qf = ac_quant_base / + (3.9312946339134007f + 2.6011435675118082f * min); + } + } + shared.ac_strategy.Set(x, y, best); + AcStrategy acs = AcStrategy::FromRawStrategy(best); + for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) { + for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) { + chosen_mask[(iy - y0) * 8 + (ix - x0)] = 1; + qf_row[ix + (iy - y) * qf_stride] = qf; + } + } + // TODO + for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) { + for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) { + epf_row[ix + (iy - y) * epf_stride] = 4; + } + } + } + } + }, + "QF+ACS+EPF"); + aux_out->DumpPlaneNormalized("qf", quant_field); + aux_out->DumpPlaneNormalized("epf", shared.epf_sharpness); + DumpAcStrategy(shared.ac_strategy, frame_dim.xsize_padded, + frame_dim.ysize_padded, "acs", aux_out); + + shared.quantizer.SetQuantField(dc_quant, quant_field, + &shared.raw_quant_field); + + return true; +} +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(Heuristics); +Status FastEncoderHeuristics::LossyFrameHeuristics( + PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder, + const ImageBundle* linear, Image3F* opsin, ThreadPool* pool, + AuxOut* aux_out) { + return HWY_DYNAMIC_DISPATCH(Heuristics)(enc_state, modular_frame_encoder, + linear, opsin, pool, aux_out); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/enc_file.cc b/third_party/jpeg-xl/lib/jxl/enc_file.cc new file mode 100644 index 000000000000..a7396462c5f2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_file.cc @@ -0,0 +1,283 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lib/jxl/enc_file.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_frame.h"
+#include "lib/jxl/enc_icc_codec.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+
+namespace {
+
+// DC + 'Very Low Frequency'
+PassDefinition progressive_passes_dc_vlf[] = {
+    {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/4}};
+
+PassDefinition progressive_passes_dc_lf[] = {
+    {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/4},
+    {/*num_coefficients=*/3, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/2}};
+
+PassDefinition progressive_passes_dc_lf_salient_ac[] = {
+    {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/4},
+    {/*num_coefficients=*/3, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/2},
+    {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/true,
+     /*suitable_for_downsampling_of_at_least=*/0}};
+
+PassDefinition progressive_passes_dc_lf_salient_ac_other_ac[] = {
+    {/*num_coefficients=*/2, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/4},
+    {/*num_coefficients=*/3, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/2},
+    {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/true,
+     /*suitable_for_downsampling_of_at_least=*/0},
+    {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/0}};
+
+PassDefinition progressive_passes_dc_quant_ac_full_ac[] = {
+    {/*num_coefficients=*/8, /*shift=*/2, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/4},
+    {/*num_coefficients=*/8, /*shift=*/1, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/2},
+    {/*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/false,
+     /*suitable_for_downsampling_of_at_least=*/0},
+};
+
+constexpr uint16_t kExifOrientationTag = 274;
+
+// Parses the Exif data just enough to extract any render-impacting info.
+// If the Exif data is invalid or could not be parsed, then it is treated
+// as a no-op.
+// TODO(jon): tag 1 can be used to represent Adobe RGB 1998 if it has value
+// "R03"
+// TODO(jon): set intrinsic dimensions according to
+// https://discourse.wicg.io/t/proposal-exif-image-resolution-auto-and-from-image/4326/24
+void InterpretExif(const PaddedBytes& exif, CodecMetadata* metadata) {
+  if (exif.size() < 12) return;  // not enough bytes for a valid exif blob
+  const uint8_t* t = exif.data();
+  bool bigendian = false;
+  if (LoadLE32(t) == 0x2A004D4D) {
+    bigendian = true;
+  } else if (LoadLE32(t) != 0x002A4949) {
+    return;  // not a valid tiff header
+  }
+  t += 4;
+  uint32_t offset = (bigendian ? LoadBE32(t) : LoadLE32(t));
+  if (exif.size() < 12 + offset + 2 || offset < 8) return;
+  t += offset - 4;
+  uint16_t nb_tags = (bigendian ? LoadBE16(t) : LoadLE16(t));
+  t += 2;
+  while (nb_tags > 0) {
+    if (t + 12 >= exif.data() + exif.size()) return;
+    uint16_t tag = (bigendian ? LoadBE16(t) : LoadLE16(t));
+    t += 2;
+    uint16_t type = (bigendian ? LoadBE16(t) : LoadLE16(t));
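+    // (An IFD entry is 12 bytes: tag(2), type(2), count(4) and a 4-byte
+    // value/offset field; orientation is a SHORT (type 3) stored inline.)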
+    t += 2;
+    uint32_t count = (bigendian ? LoadBE32(t) : LoadLE32(t));
+    t += 4;
+    uint16_t value = (bigendian ? LoadBE16(t) : LoadLE16(t));
+    t += 4;
+    if (tag == kExifOrientationTag) {
+      if (type == 3 && count == 1) {
+        if (value >= 1 && value <= 8) {
+          metadata->m.orientation = value;
+        }
+      }
+    }
+    nb_tags--;
+  }
+}
+
+Status PrepareCodecMetadataFromIO(const CompressParams& cparams,
+                                  const CodecInOut* io,
+                                  CodecMetadata* metadata) {
+  *metadata = io->metadata;
+
+  JXL_RETURN_IF_ERROR(metadata->size.Set(io->xsize(), io->ysize()));
+
+  // Keep ICC profile in lossless modes because a reconstructed profile may be
+  // slightly different (quantization).
+  // Also keep ICC in JPEG reconstruction mode as we need byte-exact profiles.
+  const bool lossless_modular =
+      cparams.modular_mode && cparams.quality_pair.first == 100.0f;
+  if (!lossless_modular && !io->Main().IsJPEG()) {
+    metadata->m.color_encoding.DecideIfWantICC();
+  }
+
+  metadata->m.xyb_encoded =
+      cparams.color_transform == ColorTransform::kXYB ? true : false;
+
+  InterpretExif(io->blobs.exif, metadata);
+
+  return true;
+}
+
+}  // namespace
+
+Status EncodePreview(const CompressParams& cparams, const ImageBundle& ib,
+                     const CodecMetadata* metadata, ThreadPool* pool,
+                     BitWriter* JXL_RESTRICT writer) {
+  BitWriter preview_writer;
+  // TODO(janwas): also support generating preview by downsampling
+  if (ib.HasColor()) {
+    AuxOut aux_out;
+    PassesEncoderState passes_enc_state;
+    // TODO(lode): check if we want all extra channels and matching
+    // xyb_encoded for the preview, such that using the main ImageMetadata
+    // object for encoding this frame is warranted.
+    FrameInfo frame_info;
+    frame_info.is_preview = true;
+    JXL_RETURN_IF_ERROR(EncodeFrame(cparams, frame_info, metadata, ib,
+                                    &passes_enc_state, pool, &preview_writer,
+                                    &aux_out));
+    preview_writer.ZeroPadToByte();
+  }
+
+  if (preview_writer.BitsWritten() != 0) {
+    writer->ZeroPadToByte();
+    writer->AppendByteAligned(preview_writer);
+  }
+
+  return true;
+}
+
+Status WriteHeaders(CodecMetadata* metadata, BitWriter* writer,
+                    AuxOut* aux_out) {
+  // Marker/signature
+  BitWriter::Allotment allotment(writer, 16);
+  writer->Write(8, 0xFF);
+  writer->Write(8, kCodestreamMarker);
+  ReclaimAndCharge(writer, &allotment, kLayerHeader, aux_out);
+
+  JXL_RETURN_IF_ERROR(
+      WriteSizeHeader(metadata->size, writer, kLayerHeader, aux_out));
+
+  JXL_RETURN_IF_ERROR(
+      WriteImageMetadata(metadata->m, writer, kLayerHeader, aux_out));
+
+  metadata->transform_data.nonserialized_xyb_encoded = metadata->m.xyb_encoded;
+  JXL_RETURN_IF_ERROR(
+      Bundle::Write(metadata->transform_data, writer, kLayerHeader, aux_out));
+
+  return true;
+}
+
+Status EncodeFile(const CompressParams& cparams, const CodecInOut* io,
+                  PassesEncoderState* passes_enc_state, PaddedBytes* compressed,
+                  AuxOut* aux_out, ThreadPool* pool) {
+  io->CheckMetadata();
+  BitWriter writer;
+
+  std::unique_ptr<CodecMetadata> metadata = jxl::make_unique<CodecMetadata>();
+  JXL_RETURN_IF_ERROR(PrepareCodecMetadataFromIO(cparams, io, metadata.get()));
+  JXL_RETURN_IF_ERROR(WriteHeaders(metadata.get(), &writer, aux_out));
+
+  // Only send ICC (at least several hundred bytes) if fields aren't enough.
+  if (metadata->m.color_encoding.WantICC()) {
+    JXL_RETURN_IF_ERROR(WriteICC(metadata->m.color_encoding.ICC(), &writer,
+                                 kLayerHeader, aux_out));
+  }
+
+  if (metadata->m.have_preview) {
+    JXL_RETURN_IF_ERROR(EncodePreview(cparams, io->preview_frame,
+                                      metadata.get(), pool, &writer));
+  }
+
+  // Each frame should start on byte boundaries.
+  writer.ZeroPadToByte();
+
+  if (cparams.progressive_mode || cparams.qprogressive_mode) {
+    if (cparams.saliency_map != nullptr) {
+      passes_enc_state->progressive_splitter.SetSaliencyMap(
+          cparams.saliency_map);
+    }
+    passes_enc_state->progressive_splitter.SetSaliencyThreshold(
+        cparams.saliency_threshold);
+    if (cparams.qprogressive_mode) {
+      passes_enc_state->progressive_splitter.SetProgressiveMode(
+          ProgressiveMode{progressive_passes_dc_quant_ac_full_ac});
+    } else {
+      switch (cparams.saliency_num_progressive_steps) {
+        case 1:
+          passes_enc_state->progressive_splitter.SetProgressiveMode(
+              ProgressiveMode{progressive_passes_dc_vlf});
+          break;
+        case 2:
+          passes_enc_state->progressive_splitter.SetProgressiveMode(
+              ProgressiveMode{progressive_passes_dc_lf});
+          break;
+        case 3:
+          passes_enc_state->progressive_splitter.SetProgressiveMode(
+              ProgressiveMode{progressive_passes_dc_lf_salient_ac});
+          break;
+        case 4:
+          if (cparams.saliency_threshold == 0.0f) {
+            // No need for a 4th pass if saliency-threshold regards
+            // everything as salient.
+            passes_enc_state->progressive_splitter.SetProgressiveMode(
+                ProgressiveMode{progressive_passes_dc_lf_salient_ac});
+          } else {
+            passes_enc_state->progressive_splitter.SetProgressiveMode(
+                ProgressiveMode{progressive_passes_dc_lf_salient_ac_other_ac});
+          }
+          break;
+        default:
+          return JXL_FAILURE("Invalid saliency_num_progressive_steps.");
+      }
+    }
+  }
+
+  for (size_t i = 0; i < io->frames.size(); i++) {
+    FrameInfo info;
+    info.is_last = i == io->frames.size() - 1;
+    if (io->frames[i].use_for_next_frame) {
+      info.save_as_reference = 1;
+    }
+    JXL_RETURN_IF_ERROR(EncodeFrame(cparams, info, metadata.get(),
+                                    io->frames[i], passes_enc_state, pool,
+                                    &writer, aux_out));
+  }
+
+  // Clean up passes_enc_state in case it gets reused.
+  for (size_t i = 0; i < 4; i++) {
+    passes_enc_state->shared.dc_frames[i] = Image3F();
+    passes_enc_state->shared.reference_frames[i].storage = ImageBundle();
+  }
+
+  *compressed = std::move(writer).TakeBytes();
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_file.h b/third_party/jpeg-xl/lib/jxl/enc_file.h
new file mode 100644
index 000000000000..90f9007209b8
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_file.h
@@ -0,0 +1,60 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_FILE_H_
+#define LIB_JXL_ENC_FILE_H_
+
+// Facade for JXL encoding.
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_params.h"
+
+namespace jxl {
+
+// Write preview from `io`.
+Status EncodePreview(const CompressParams& cparams, const ImageBundle& ib,
+                     const CodecMetadata* metadata, ThreadPool* pool,
+                     BitWriter* JXL_RESTRICT writer);
+
+// Write headers from the CodecMetadata. Also may modify nonserialized_...
+// fields of the metadata. +Status WriteHeaders(CodecMetadata* metadata, BitWriter* writer, + AuxOut* aux_out); + +// Compresses pixels from `io` (given in any ColorEncoding). +// `io->metadata.m.original` must be set. +Status EncodeFile(const CompressParams& params, const CodecInOut* io, + PassesEncoderState* passes_enc_state, PaddedBytes* compressed, + AuxOut* aux_out = nullptr, ThreadPool* pool = nullptr); + +// Backwards-compatible interface. Don't use in new code. +// TODO(deymo): Remove this function once we migrate users to C encoder API. +struct FrameEncCache {}; +JXL_INLINE Status EncodeFile(const CompressParams& params, const CodecInOut* io, + FrameEncCache* /* unused */, + PaddedBytes* compressed, AuxOut* aux_out = nullptr, + ThreadPool* pool = nullptr) { + PassesEncoderState passes_enc_state; + return EncodeFile(params, io, &passes_enc_state, compressed, aux_out, pool); +} + +} // namespace jxl + +#endif // LIB_JXL_ENC_FILE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_frame.cc b/third_party/jpeg-xl/lib/jxl/enc_frame.cc new file mode 100644 index 000000000000..08b063905868 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_frame.cc @@ -0,0 +1,1366 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lib/jxl/enc_frame.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cmath>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+#include "lib/jxl/ac_context.h"
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/compressed_dc.h"
+#include "lib/jxl/dct_util.h"
+#include "lib/jxl/enc_adaptive_quantization.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_chroma_from_luma.h"
+#include "lib/jxl/enc_coeff_order.h"
+#include "lib/jxl/enc_context_map.h"
+#include "lib/jxl/enc_entropy_coder.h"
+#include "lib/jxl/enc_group.h"
+#include "lib/jxl/enc_modular.h"
+#include "lib/jxl/enc_noise.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/enc_patch_dictionary.h"
+#include "lib/jxl/enc_quant_weights.h"
+#include "lib/jxl/enc_splines.h"
+#include "lib/jxl/enc_toc.h"
+#include "lib/jxl/enc_xyb.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/gaborish.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/loop_filter.h"
+#include "lib/jxl/quant_weights.h"
+#include "lib/jxl/quantizer.h"
+#include "lib/jxl/splines.h"
+#include "lib/jxl/toc.h"
+
+namespace jxl {
+namespace {
+
+void ClusterGroups(PassesEncoderState* enc_state) {
+  if (enc_state->shared.frame_header.passes.num_passes > 1) {
+    // TODO(veluca): implement this for progressive modes.
+    return;
+  }
+  // This only considers pass 0 for now.
+  std::vector<uint8_t> context_map;
+  EntropyEncodingData codes;
+  auto& ac = enc_state->passes[0].ac_tokens;
+  size_t limit = std::ceil(std::sqrt(ac.size()));
+  if (limit == 1) return;
+  size_t num_contexts = enc_state->shared.block_ctx_map.NumACContexts();
+  std::vector<float> costs(ac.size());
+  HistogramParams params;
+  params.uint_method = HistogramParams::HybridUintMethod::kNone;
+  params.lz77_method = HistogramParams::LZ77Method::kNone;
+  params.ans_histogram_strategy =
+      HistogramParams::ANSHistogramStrategy::kApproximate;
+  size_t max = 0;
+  float total_cost = 0;
+  auto token_cost = [&](std::vector<std::vector<Token>>& tokens,
+                        size_t num_ctx, bool estimate = true) {
+    // TODO(veluca): not estimating is very expensive.
+    BitWriter writer;
+    size_t c = BuildAndEncodeHistograms(
+        params, num_ctx, tokens, &codes, &context_map,
+        estimate ? nullptr : &writer, 0, /*aux_out=*/0);
+    if (estimate) return c;
+    for (size_t i = 0; i < tokens.size(); i++) {
+      WriteTokens(tokens[i], codes, context_map, &writer, 0, nullptr);
+    }
+    return writer.BitsWritten();
+  };
+  for (size_t i = 0; i < ac.size(); i++) {
+    std::vector<std::vector<Token>> tokens{ac[i]};
+    costs[i] =
+        token_cost(tokens, enc_state->shared.block_ctx_map.NumACContexts());
+    if (costs[i] > costs[max]) {
+      max = i;
+    }
+    total_cost += costs[i];
+  }
+  auto dist = [&](int i, int j) {
+    std::vector<std::vector<Token>> tokens{ac[i], ac[j]};
+    return token_cost(tokens, num_contexts) - costs[i] - costs[j];
+  };
+  std::vector<size_t> out{max};
+  std::vector<size_t> old_map(ac.size());
+  std::vector<float> dists(ac.size());
+  size_t farthest = 0;
+  for (size_t i = 0; i < ac.size(); i++) {
+    if (i == max) continue;
+    dists[i] = dist(max, i);
+    if (dists[i] > dists[farthest]) {
+      farthest = i;
+    }
+  }
+
+  while (dists[farthest] > 0 && out.size() < limit) {
+    out.push_back(farthest);
+    dists[farthest] = 0;
+    enc_state->histogram_idx[farthest] = out.size() - 1;
+    for (size_t i = 0; i < ac.size(); i++) {
+      float d = dist(out.back(), i);
+      if (d < dists[i]) {
+        dists[i] = d;
+        old_map[i] = enc_state->histogram_idx[i];
+        enc_state->histogram_idx[i] = out.size() - 1;
+      }
+      if (dists[i] > dists[farthest]) {
+        farthest = i;
+      }
+    }
+  }
+
+  std::vector<size_t> remap(out.size());
+  std::iota(remap.begin(), remap.end(), 0);
+  for (size_t i = 0; i < enc_state->histogram_idx.size(); i++) {
+    enc_state->histogram_idx[i] = remap[enc_state->histogram_idx[i]];
+  }
+  auto remap_cost = [&](std::vector<size_t> remap) {
+    std::vector<size_t> re_remap(remap.size(), remap.size());
+    size_t r = 0;
+    for (size_t i = 0; i < remap.size(); i++) {
+      if (re_remap[remap[i]] == remap.size()) {
+        re_remap[remap[i]] = r++;
+      }
+      remap[i] = re_remap[remap[i]];
+    }
+    auto tokens = ac;
+    size_t max_hist = 0;
+    for (size_t i = 0; i < tokens.size(); i++) {
+      for (size_t j = 0; j < tokens[i].size(); j++) {
+        size_t hist = remap[enc_state->histogram_idx[i]];
+        tokens[i][j].context += hist * num_contexts;
+        max_hist = std::max(hist + 1, max_hist);
+      }
+    }
+    return token_cost(tokens, max_hist * num_contexts, /*estimate=*/false);
+  };
+
+  for (size_t src = 0; src < out.size(); src++) {
+    float cost = remap_cost(remap);
+    size_t best = src;
+    for (size_t j = src + 1; j < out.size(); j++) {
+      if (remap[src] == remap[j]) continue;
+      auto remap_c = remap;
+      std::replace(remap_c.begin(), remap_c.end(), remap[src], remap[j]);
+      float c = remap_cost(remap_c);
+      if (c < cost) {
+        best = j;
+        cost = c;
+      }
+    }
+    if (src != best) {
+      std::replace(remap.begin(), remap.end(), remap[src], remap[best]);
+    }
+  }
+  std::vector<size_t> re_remap(remap.size(), remap.size());
+  size_t r = 0;
+  for (size_t i = 0; i < remap.size(); i++) {
+    if (re_remap[remap[i]] == remap.size()) {
+      re_remap[remap[i]] = r++;
+    }
+    remap[i] = re_remap[remap[i]];
+  }
+
+  enc_state->shared.num_histograms =
+      *std::max_element(remap.begin(), remap.end()) + 1;
+  for (size_t i = 0; i < enc_state->histogram_idx.size(); i++) {
+    enc_state->histogram_idx[i] = remap[enc_state->histogram_idx[i]];
+  }
+  for (size_t i = 0; i < ac.size(); i++) {
+    for (size_t j = 0; j < ac[i].size(); j++) {
+      ac[i][j].context += enc_state->histogram_idx[i] * num_contexts;
+    }
+  }
+}
+
+uint64_t FrameFlagsFromParams(const CompressParams& cparams) {
+  uint64_t flags = 0;
+
+  const float dist = cparams.butteraugli_distance;
+
+  // We don't add noise at low butteraugli distances because the original
+  // noise is stored within the compressed image and adding noise makes
+  // things worse.
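+  // (ApplyOverride returns the explicit cparams.noise setting when one was
+  // given, and otherwise falls back to the distance-based default.)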
+  if (ApplyOverride(cparams.noise, dist >= kMinButteraugliForNoise)) {
+    flags |= FrameHeader::kNoise;
+  }
+
+  if (cparams.progressive_dc > 0 && cparams.modular_mode == false) {
+    flags |= FrameHeader::kUseDcFrame;
+  }
+
+  return flags;
+}
+
+Status LoopFilterFromParams(const CompressParams& cparams,
+                            FrameHeader* JXL_RESTRICT frame_header) {
+  LoopFilter* loop_filter = &frame_header->loop_filter;
+
+  // Gaborish defaults to enabled in Hare or slower.
+  loop_filter->gab = ApplyOverride(
+      cparams.gaborish, cparams.speed_tier <= SpeedTier::kHare &&
+                            frame_header->encoding == FrameEncoding::kVarDCT &&
+                            cparams.decoding_speed_tier < 4);
+
+  if (cparams.epf != -1) {
+    loop_filter->epf_iters = cparams.epf;
+  } else {
+    if (frame_header->encoding == FrameEncoding::kModular) {
+      loop_filter->epf_iters = 0;
+    } else {
+      constexpr float kThresholds[3] = {0.7, 1.5, 4.0};
+      loop_filter->epf_iters = 0;
+      if (cparams.decoding_speed_tier < 3) {
+        for (size_t i = cparams.decoding_speed_tier == 2 ? 1 : 0; i < 3; i++) {
+          if (cparams.butteraugli_distance >= kThresholds[i]) {
+            loop_filter->epf_iters++;
+          }
+        }
+      }
+    }
+  }
+  // Strength of EPF in modular mode.
+  if (frame_header->encoding == FrameEncoding::kModular &&
+      cparams.quality_pair.first < 100) {
+    // TODO(veluca): this formula is nonsense.
+    loop_filter->epf_sigma_for_modular =
+        20.0f * (1.0f - cparams.quality_pair.first / 100);
+  }
+  if (frame_header->encoding == FrameEncoding::kModular &&
+      cparams.lossy_palette) {
+    loop_filter->epf_sigma_for_modular = 1.0f;
+  }
+
+  return true;
+}
+
+Status MakeFrameHeader(const CompressParams& cparams,
+                       const ProgressiveSplitter& progressive_splitter,
+                       const FrameInfo& frame_info, const ImageBundle& ib,
+                       FrameHeader* JXL_RESTRICT frame_header) {
+  frame_header->nonserialized_is_preview = frame_info.is_preview;
+  frame_header->is_last = frame_info.is_last;
+  frame_header->save_before_color_transform =
+      frame_info.save_before_color_transform;
+  frame_header->frame_type = frame_info.frame_type;
+  frame_header->name = ib.name;
+
+  progressive_splitter.InitPasses(&frame_header->passes);
+
+  if (cparams.modular_mode) {
+    frame_header->encoding = FrameEncoding::kModular;
+    frame_header->group_size_shift = cparams.modular_group_size_shift;
+  }
+
+  if (ib.IsJPEG()) {
+    // we are transcoding a JPEG, so we don't get to choose
+    frame_header->encoding = FrameEncoding::kVarDCT;
+    frame_header->color_transform = ib.color_transform;
+    frame_header->chroma_subsampling = ib.chroma_subsampling;
+  } else {
+    frame_header->color_transform = cparams.color_transform;
+    if (cparams.chroma_subsampling.MaxHShift() != 0 ||
+        cparams.chroma_subsampling.MaxVShift() != 0) {
+      // TODO(veluca): properly pad the input image to support this.
+      return JXL_FAILURE(
+          "Chroma subsampling is not supported when not recompressing JPEGs");
+    }
+    frame_header->chroma_subsampling = cparams.chroma_subsampling;
+  }
+
+  frame_header->flags = FrameFlagsFromParams(cparams);
+  // Noise is not supported in the Modular encoder for now.
+  if (frame_header->encoding != FrameEncoding::kVarDCT) {
+    frame_header->UpdateFlag(false, FrameHeader::Flags::kNoise);
+  }
+
+  JXL_RETURN_IF_ERROR(LoopFilterFromParams(cparams, frame_header));
+
+  frame_header->dc_level = frame_info.dc_level;
+  if (frame_header->dc_level > 2) {
+    // With 3 or more progressive_dc frames, the implementation does not yet
+    // work, see enc_cache.cc.
+ return JXL_FAILURE("progressive_dc > 2 is not yet supported"); + } + if (cparams.progressive_dc > 0 && cparams.resampling != 1) { + return JXL_FAILURE("Resampling not supported with DC frames"); + } + if (cparams.resampling != 1 && cparams.resampling != 2 && + cparams.resampling != 4 && cparams.resampling != 8) { + return JXL_FAILURE("Invalid resampling factor"); + } + frame_header->upsampling = cparams.resampling; + frame_header->save_as_reference = frame_info.save_as_reference; + + // Set blend_channel to the first alpha channel, the only implemented mode for + // now and what the encoder currently uses. These values are only encoded in + // case a blend mode involving alpha is used and there are more than one extra + // channels. + const std::vector& extra_channels = + frame_header->nonserialized_metadata->m.extra_channel_info; + // Resized frames. + if (frame_info.frame_type != FrameType::kDCFrame) { + frame_header->frame_origin = ib.origin; + frame_header->frame_size.xsize = ib.xsize(); + frame_header->frame_size.ysize = ib.ysize(); + if (ib.origin.x0 != 0 || ib.origin.y0 != 0 || + ib.xsize() != frame_header->default_xsize() || + ib.ysize() != frame_header->default_ysize()) { + frame_header->custom_size_or_origin = true; + } + } + + // Set blending-related information. TODO(veluca): only supports kReplace or + // kBlend for now. + if (ib.blend || frame_header->custom_size_or_origin) { + size_t index = 0; + if (extra_channels.size() > 1) { + for (size_t i = 0; i < extra_channels.size(); i++) { + if (extra_channels[i].type == ExtraChannel::kAlpha) { + index = i; + break; + } + } + } + frame_header->blending_info.alpha_channel = index; + frame_header->blending_info.mode = + ib.blend ? BlendMode::kBlend : BlendMode::kReplace; + // previous frames are saved with ID 1. + frame_header->blending_info.source = 1; + for (size_t i = 0; i < extra_channels.size(); i++) { + frame_header->extra_channel_blending_info[i].alpha_channel = index; + BlendMode default_blend = BlendMode::kBlend; + if (extra_channels[i].type != ExtraChannel::kBlack && i != index) { + // K needs to be blended, spot colors and other stuff gets added + default_blend = BlendMode::kAdd; + } + frame_header->extra_channel_blending_info[i].mode = + ib.blend ? default_blend : BlendMode::kReplace; + frame_header->extra_channel_blending_info[i].source = 1; + } + } + + frame_header->animation_frame.duration = ib.duration; + + // TODO(veluca): timecode. + + return true; +} + +// Invisible (alpha = 0) pixels tend to be a mess in optimized PNGs. +// Since they have no visual impact whatsoever, we can replace them with +// something that compresses better and reduces artifacts near the edges. This +// does some kind of smooth stuff that seems to work. +// Replace invisible pixels with a weighted average of the pixel to the left, +// the pixel to the topright, and non-invisible neighbours. +// Produces downward-blurry smears, with in the upwards direction only a 1px +// edge duplication but not more. It would probably be better to smear in all +// directions. That requires an alpha-weighed convolution with a large enough +// kernel though, which might be overkill... +void SimplifyInvisible(Image3F* image, const ImageF& alpha) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + float* JXL_RESTRICT row = image->PlaneRow(c, y); + const float* JXL_RESTRICT prow = + (y > 0 ? image->PlaneRow(c, y - 1) : nullptr); + const float* JXL_RESTRICT nrow = + (y + 1 < image->ysize() ? 
image->PlaneRow(c, y + 1) : nullptr); + const float* JXL_RESTRICT a = alpha.Row(y); + const float* JXL_RESTRICT pa = (y > 0 ? alpha.Row(y - 1) : nullptr); + const float* JXL_RESTRICT na = + (y + 1 < image->ysize() ? alpha.Row(y + 1) : nullptr); + for (size_t x = 0; x < image->xsize(); ++x) { + if (a[x] == 0) { + float d = 0.f; + row[x] = 0; + if (x > 0) { + row[x] += row[x - 1]; + d++; + if (a[x - 1] > 0.f) { + row[x] += row[x - 1]; + d++; + } + } + if (x + 1 < image->xsize()) { + if (y > 0) { + row[x] += prow[x + 1]; + d++; + } + if (a[x + 1] > 0.f) { + row[x] += 2.f * row[x + 1]; + d += 2.f; + } + if (y > 0 && pa[x + 1] > 0.f) { + row[x] += 2.f * prow[x + 1]; + d += 2.f; + } + if (y + 1 < image->ysize() && na[x + 1] > 0.f) { + row[x] += 2.f * nrow[x + 1]; + d += 2.f; + } + } + if (y > 0 && pa[x] > 0.f) { + row[x] += 2.f * prow[x]; + d += 2.f; + } + if (y + 1 < image->ysize() && na[x] > 0.f) { + row[x] += 2.f * nrow[x]; + d += 2.f; + } + if (d > 1.f) row[x] /= d; + } + } + } + } +} + +} // namespace + +class LossyFrameEncoder { + public: + LossyFrameEncoder(const CompressParams& cparams, + const FrameHeader& frame_header, + PassesEncoderState* JXL_RESTRICT enc_state, + ThreadPool* pool, AuxOut* aux_out) + : enc_state_(enc_state), pool_(pool), aux_out_(aux_out) { + JXL_CHECK(InitializePassesSharedState(frame_header, &enc_state_->shared, + /*encoder=*/true)); + enc_state_->cparams = cparams; + enc_state_->passes.clear(); + } + + Status ComputeEncodingData(const ImageBundle* linear, + Image3F* JXL_RESTRICT opsin, ThreadPool* pool, + ModularFrameEncoder* modular_frame_encoder, + BitWriter* JXL_RESTRICT writer, + FrameHeader* frame_header) { + PROFILER_ZONE("ComputeEncodingData uninstrumented"); + JXL_ASSERT((opsin->xsize() % kBlockDim) == 0 && + (opsin->ysize() % kBlockDim) == 0); + PassesSharedState& shared = enc_state_->shared; + + if (!enc_state_->cparams.max_error_mode) { + float x_qm_scale_steps[3] = {0.65f, 1.25f, 9.0f}; + shared.frame_header.x_qm_scale = 1; + for (float x_qm_scale_step : x_qm_scale_steps) { + if (enc_state_->cparams.butteraugli_distance > x_qm_scale_step) { + shared.frame_header.x_qm_scale++; + } + } + } + + JXL_RETURN_IF_ERROR(enc_state_->heuristics->LossyFrameHeuristics( + enc_state_, modular_frame_encoder, linear, opsin, pool_, aux_out_)); + + InitializePassesEncoder(*opsin, pool_, enc_state_, modular_frame_encoder, + aux_out_); + + enc_state_->passes.resize(enc_state_->progressive_splitter.GetNumPasses()); + for (PassesEncoderState::PassData& pass : enc_state_->passes) { + pass.ac_tokens.resize(shared.frame_dim.num_groups); + } + + ComputeAllCoeffOrders(shared.frame_dim); + shared.num_histograms = 1; + + const auto tokenize_group_init = [&](const size_t num_threads) { + group_caches_.resize(num_threads); + return true; + }; + const auto tokenize_group = [&](const int group_index, const int thread) { + // Tokenize coefficients. + const Rect rect = shared.BlockGroupRect(group_index); + for (size_t idx_pass = 0; idx_pass < enc_state_->passes.size(); + idx_pass++) { + JXL_ASSERT(enc_state_->coeffs[idx_pass]->Type() == ACType::k32); + const int32_t* JXL_RESTRICT ac_rows[3] = { + enc_state_->coeffs[idx_pass]->PlaneRow(0, group_index, 0).ptr32, + enc_state_->coeffs[idx_pass]->PlaneRow(1, group_index, 0).ptr32, + enc_state_->coeffs[idx_pass]->PlaneRow(2, group_index, 0).ptr32, + }; + // Ensure group cache is initialized. 
+        group_caches_[thread].InitOnce();
+        TokenizeCoefficients(
+            &shared.coeff_orders[idx_pass * shared.coeff_order_size], rect,
+            ac_rows, shared.ac_strategy, frame_header->chroma_subsampling,
+            &group_caches_[thread].num_nzeroes,
+            &enc_state_->passes[idx_pass].ac_tokens[group_index],
+            enc_state_->shared.quant_dc, enc_state_->shared.raw_quant_field,
+            enc_state_->shared.block_ctx_map);
+      }
+    };
+    RunOnPool(pool_, 0, shared.frame_dim.num_groups, tokenize_group_init,
+              tokenize_group, "TokenizeGroup");
+
+    *frame_header = shared.frame_header;
+    return true;
+  }
+
+  Status ComputeJPEGTranscodingData(const jpeg::JPEGData& jpeg_data,
+                                    ModularFrameEncoder* modular_frame_encoder,
+                                    FrameHeader* frame_header) {
+    PROFILER_ZONE("ComputeJPEGTranscodingData uninstrumented");
+    PassesSharedState& shared = enc_state_->shared;
+
+    frame_header->x_qm_scale = 2;
+    frame_header->b_qm_scale = 2;
+
+    FrameDimensions frame_dim = frame_header->ToFrameDimensions();
+
+    const size_t xsize = frame_dim.xsize_padded;
+    const size_t ysize = frame_dim.ysize_padded;
+    const size_t xsize_blocks = frame_dim.xsize_blocks;
+    const size_t ysize_blocks = frame_dim.ysize_blocks;
+
+    // No-op chroma from luma.
+    shared.cmap = ColorCorrelationMap(xsize, ysize, false);
+    shared.ac_strategy.FillDCT8();
+    FillImage(uint8_t(0), &shared.epf_sharpness);
+
+    enc_state_->coeffs.clear();
+    enc_state_->coeffs.emplace_back(make_unique<ACImageT<int32_t>>(
+        kGroupDim * kGroupDim, frame_dim.num_groups));
+
+    // Convert the JPEG quantization table to a Quantizer object.
+    float dcquantization[3];
+    std::vector<QuantEncoding> qe(DequantMatrices::kNum,
+                                  QuantEncoding::Library(0));
+
+    auto jpeg_c_map = JpegOrder(frame_header->color_transform,
+                                jpeg_data.components.size() == 1);
+
+    std::vector<int> qt(192);
+    for (size_t c = 0; c < 3; c++) {
+      size_t jpeg_c = jpeg_c_map[c];
+      const int* quant =
+          jpeg_data.quant[jpeg_data.components[jpeg_c].quant_idx].values.data();
+
+      dcquantization[c] = 255 * 8.0f / quant[0];
+      for (size_t y = 0; y < 8; y++) {
+        for (size_t x = 0; x < 8; x++) {
+          // JPEG XL transposes the DCT; JPEG doesn't.
+          qt[c * 64 + 8 * x + y] = quant[8 * y + x];
+        }
+      }
+    }
+    DequantMatricesSetCustomDC(&shared.matrices, dcquantization);
+    float dcquantization_r[3] = {1.0f / dcquantization[0],
+                                 1.0f / dcquantization[1],
+                                 1.0f / dcquantization[2]};
+
+    qe[AcStrategy::Type::DCT] = QuantEncoding::RAW(qt);
+    DequantMatricesSetCustom(&shared.matrices, qe, modular_frame_encoder);
+
+    // Ensure that InvGlobalScale() is 1.
+    shared.quantizer = Quantizer(&shared.matrices, 1, kGlobalScaleDenom);
+    // Recompute MulDC() and InvMulDC().
+    shared.quantizer.RecomputeFromGlobalScale();
+
+    // Per-block dequant scaling should be 1.
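+    // (The quantizer was just constructed so that InvGlobalScale() == 1, so
+    // the raw quant field is filled with the constant 1 and the JPEG
+    // quantization tables are applied verbatim, without per-block scaling.)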
+    FillImage(static_cast<int32_t>(shared.quantizer.InvGlobalScale()),
+              &shared.raw_quant_field);
+
+    std::vector<int32_t> scaled_qtable(192);
+    for (size_t c = 0; c < 3; c++) {
+      for (size_t i = 0; i < 64; i++) {
+        scaled_qtable[64 * c + i] =
+            (1 << kCFLFixedPointPrecision) * qt[64 + i] / qt[64 * c + i];
+      }
+    }
+
+    auto jpeg_row = [&](size_t c, size_t y) {
+      return jpeg_data.components[jpeg_c_map[c]].coeffs.data() +
+             jpeg_data.components[jpeg_c_map[c]].width_in_blocks *
+                 kDCTBlockSize * y;
+    };
+
+    Image3F dc = Image3F(xsize_blocks, ysize_blocks);
+    bool DCzero =
+        (shared.frame_header.color_transform == ColorTransform::kYCbCr);
+    // Compute chroma-from-luma for AC (doesn't seem to be useful for DC).
+    if (frame_header->chroma_subsampling.Is444() &&
+        enc_state_->cparams.force_cfl_jpeg_recompression &&
+        jpeg_data.components.size() == 3) {
+      for (size_t c : {0, 2}) {
+        ImageSB* map = (c == 0 ? &shared.cmap.ytox_map : &shared.cmap.ytob_map);
+        const float kScale = kDefaultColorFactor;
+        const int kOffset = 127;
+        const float kBase =
+            c == 0 ? shared.cmap.YtoXRatio(0) : shared.cmap.YtoBRatio(0);
+        const float kZeroThresh =
+            kScale * kZeroBiasDefault[c] *
+            0.9999f;  // just epsilon less for better rounding
+
+        auto process_row = [&](int task, int thread) {
+          size_t ty = task;
+          int8_t* JXL_RESTRICT row_out = map->Row(ty);
+          for (size_t tx = 0; tx < map->xsize(); ++tx) {
+            const size_t y0 = ty * kColorTileDimInBlocks;
+            const size_t x0 = tx * kColorTileDimInBlocks;
+            const size_t y1 = std::min(frame_dim.ysize_blocks,
+                                       (ty + 1) * kColorTileDimInBlocks);
+            const size_t x1 = std::min(frame_dim.xsize_blocks,
+                                       (tx + 1) * kColorTileDimInBlocks);
+            int32_t d_num_zeros[257] = {0};
+            // TODO(veluca): this needs SIMD + fixed point adaptation, and/or
+            // conversion to the new CfL algorithm.
+            for (size_t y = y0; y < y1; ++y) {
+              const int16_t* JXL_RESTRICT row_m = jpeg_row(1, y);
+              const int16_t* JXL_RESTRICT row_s = jpeg_row(c, y);
+              for (size_t x = x0; x < x1; ++x) {
+                for (size_t coeffpos = 1; coeffpos < kDCTBlockSize;
+                     coeffpos++) {
+                  const float scaled_m =
+                      row_m[x * kDCTBlockSize + coeffpos] *
+                      scaled_qtable[64 * c + coeffpos] *
+                      (1.0f / (1 << kCFLFixedPointPrecision));
+                  const float scaled_s =
+                      kScale * row_s[x * kDCTBlockSize + coeffpos] +
+                      (kOffset - kBase * kScale) * scaled_m;
+                  if (std::abs(scaled_m) > 1e-8f) {
+                    float from, to;
+                    if (scaled_m > 0) {
+                      from = (scaled_s - kZeroThresh) / scaled_m;
+                      to = (scaled_s + kZeroThresh) / scaled_m;
+                    } else {
+                      from = (scaled_s + kZeroThresh) / scaled_m;
+                      to = (scaled_s - kZeroThresh) / scaled_m;
+                    }
+                    if (from < 0.0f) {
+                      from = 0.0f;
+                    }
+                    if (to > 255.0f) {
+                      to = 255.0f;
+                    }
+                    // Instead of clamping both values, we just check that the
+                    // range is sane.
+                    if (from <= to) {
+                      d_num_zeros[static_cast<int>(std::ceil(from))]++;
+                      d_num_zeros[static_cast<int>(std::floor(to + 1))]--;
+                    }
+                  }
+                }
+              }
+            }
+            int best = 0;
+            int32_t best_sum = 0;
+            FindIndexOfSumMaximum(d_num_zeros, 256, &best, &best_sum);
+            int32_t offset_sum = 0;
+            for (int i = 0; i < 256; ++i) {
+              if (i <= kOffset) {
+                offset_sum += d_num_zeros[i];
+              }
+            }
+            row_out[tx] = 0;
+            if (best_sum > offset_sum + 1) {
+              row_out[tx] = best - kOffset;
+            }
+          }
+        };
+
+        RunOnPool(pool_, 0, map->ysize(), ThreadPool::SkipInit(), process_row,
+                  "FindCorrelation");
+      }
+    }
+    if (!frame_header->chroma_subsampling.Is444()) {
+      ZeroFillImage(&dc);
+      enc_state_->coeffs[0]->ZeroFill();
+    }
+    // JPEG DC is from -1024 to 1023.
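+    // The shifted values (DC + 1024) are histogrammed into 2048 bins below;
+    // these histograms are then used to pick up to two DC context thresholds
+    // per channel.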
+    std::vector<int32_t> dc_counts[3] = {};
+    dc_counts[0].resize(2048);
+    dc_counts[1].resize(2048);
+    dc_counts[2].resize(2048);
+    size_t total_dc[3] = {};
+    for (size_t c : {1, 0, 2}) {
+      if (jpeg_data.components.size() == 1 && c != 1) {
+        enc_state_->coeffs[0]->ZeroFillPlane(c);
+        ZeroFillImage(&dc.Plane(c));
+        // Ensure no division by 0.
+        dc_counts[c][1024] = 1;
+        total_dc[c] = 1;
+        continue;
+      }
+      size_t hshift = frame_header->chroma_subsampling.HShift(c);
+      size_t vshift = frame_header->chroma_subsampling.VShift(c);
+      ImageSB& map = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map);
+      for (size_t group_index = 0; group_index < frame_dim.num_groups;
+           group_index++) {
+        const size_t gx = group_index % frame_dim.xsize_groups;
+        const size_t gy = group_index / frame_dim.xsize_groups;
+        size_t offset = 0;
+        int32_t* JXL_RESTRICT ac =
+            enc_state_->coeffs[0]->PlaneRow(c, group_index, 0).ptr32;
+        for (size_t by = gy * kGroupDimInBlocks;
+             by < ysize_blocks && by < (gy + 1) * kGroupDimInBlocks; ++by) {
+          if ((by >> vshift) << vshift != by) continue;
+          const int16_t* JXL_RESTRICT inputjpeg = jpeg_row(c, by >> vshift);
+          const int16_t* JXL_RESTRICT inputjpegY = jpeg_row(1, by);
+          float* JXL_RESTRICT fdc = dc.PlaneRow(c, by >> vshift);
+          const int8_t* JXL_RESTRICT cm =
+              map.ConstRow(by / kColorTileDimInBlocks);
+          for (size_t bx = gx * kGroupDimInBlocks;
+               bx < xsize_blocks && bx < (gx + 1) * kGroupDimInBlocks; ++bx) {
+            if ((bx >> hshift) << hshift != bx) continue;
+            size_t base = (bx >> hshift) * kDCTBlockSize;
+            int idc;
+            if (DCzero) {
+              idc = inputjpeg[base];
+            } else {
+              idc = inputjpeg[base] + 1024 / qt[c * 64];
+            }
+            dc_counts[c][std::min(static_cast<uint32_t>(idc + 1024),
+                                  uint32_t(2047))]++;
+            total_dc[c]++;
+            fdc[bx >> hshift] = idc * dcquantization_r[c];
+            if (c == 1 || !enc_state_->cparams.force_cfl_jpeg_recompression ||
+                !frame_header->chroma_subsampling.Is444()) {
+              for (size_t y = 0; y < 8; y++) {
+                for (size_t x = 0; x < 8; x++) {
+                  ac[offset + y * 8 + x] = inputjpeg[base + x * 8 + y];
+                }
+              }
+            } else {
+              const int32_t scale =
+                  shared.cmap.RatioJPEG(cm[bx / kColorTileDimInBlocks]);
+
+              for (size_t y = 0; y < 8; y++) {
+                for (size_t x = 0; x < 8; x++) {
+                  int Y = inputjpegY[kDCTBlockSize * bx + x * 8 + y];
+                  int QChroma = inputjpeg[kDCTBlockSize * bx + x * 8 + y];
+                  // Fixed-point multiply of CfL scale with quant table ratio
+                  // first, and Y value second.
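+                  // With p = kCFLFixedPointPrecision, the two steps below
+                  // compute
+                  //   coeff_scale = round(scale * qtable_ratio / 2^p)
+                  //   cfl_factor  = round(Y * coeff_scale / 2^p),
+                  // where round(v / 2^p) is implemented as (v + 2^(p-1)) >> p.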
+                  int coeff_scale = (scale * scaled_qtable[64 * c + y * 8 + x] +
+                                     (1 << (kCFLFixedPointPrecision - 1))) >>
+                                    kCFLFixedPointPrecision;
+                  int cfl_factor = (Y * coeff_scale +
+                                    (1 << (kCFLFixedPointPrecision - 1))) >>
+                                   kCFLFixedPointPrecision;
+                  int QCR = QChroma - cfl_factor;
+                  ac[offset + y * 8 + x] = QCR;
+                }
+              }
+            }
+            offset += 64;
+          }
+        }
+      }
+    }
+
+    auto& dct = enc_state_->shared.block_ctx_map.dc_thresholds;
+    auto& num_dc_ctxs = enc_state_->shared.block_ctx_map.num_dc_ctxs;
+    enc_state_->shared.block_ctx_map.num_dc_ctxs = 1;
+    for (size_t i = 0; i < 3; i++) {
+      dct[i].clear();
+      int num_thresholds = (CeilLog2Nonzero(total_dc[i]) - 10) / 2;
+      // Up to 3 buckets per channel:
+      // dark/medium/bright, yellow/unsat/blue, green/unsat/red.
+      num_thresholds = std::min(std::max(num_thresholds, 0), 2);
+      size_t cumsum = 0;
+      size_t cut = total_dc[i] / (num_thresholds + 1);
+      for (int j = 0; j < 2048; j++) {
+        cumsum += dc_counts[i][j];
+        if (cumsum > cut) {
+          dct[i].push_back(j - 1025);
+          cut = total_dc[i] * (dct[i].size() + 1) / (num_thresholds + 1);
+        }
+      }
+      num_dc_ctxs *= dct[i].size() + 1;
+    }
+
+    auto& ctx_map = enc_state_->shared.block_ctx_map.ctx_map;
+    ctx_map.clear();
+    ctx_map.resize(3 * kNumOrders * num_dc_ctxs, 0);
+
+    int lbuckets = (dct[1].size() + 1);
+    for (size_t i = 0; i < num_dc_ctxs; i++) {
+      // Up to 9 contexts for luma.
+      ctx_map[i] = i / lbuckets;
+      // Up to 3 contexts for chroma.
+      ctx_map[kNumOrders * num_dc_ctxs + i] =
+          num_dc_ctxs / lbuckets + (i % lbuckets);
+      ctx_map[2 * kNumOrders * num_dc_ctxs + i] =
+          num_dc_ctxs / lbuckets + (i % lbuckets);
+    }
+    enc_state_->shared.block_ctx_map.num_ctxs =
+        *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
+
+    enc_state_->histogram_idx.resize(shared.frame_dim.num_groups);
+
+    // Disable DC frame for now.
+    shared.frame_header.UpdateFlag(false, FrameHeader::kUseDcFrame);
+    auto compute_dc_coeffs = [&](int group_index, int /* thread */) {
+      modular_frame_encoder->AddVarDCTDC(dc, group_index, /*nl_dc=*/false,
+                                         enc_state_);
+      modular_frame_encoder->AddACMetadata(group_index, /*jpeg_transcode=*/true,
+                                           enc_state_);
+    };
+    RunOnPool(pool_, 0, shared.frame_dim.num_dc_groups, ThreadPool::SkipInit(),
+              compute_dc_coeffs, "Compute DC coeffs");
+
+    // Must happen before WriteFrameHeader!
+    shared.frame_header.UpdateFlag(true, FrameHeader::kSkipAdaptiveDCSmoothing);
+
+    enc_state_->passes.resize(enc_state_->progressive_splitter.GetNumPasses());
+    for (PassesEncoderState::PassData& pass : enc_state_->passes) {
+      pass.ac_tokens.resize(shared.frame_dim.num_groups);
+    }
+
+    // We skip coefficient splitting here, so there must be exactly one pass.
+    JXL_CHECK(enc_state_->passes.size() == 1);
+
+    ComputeAllCoeffOrders(frame_dim);
+    shared.num_histograms = 1;
+
+    const auto tokenize_group_init = [&](const size_t num_threads) {
+      group_caches_.resize(num_threads);
+      return true;
+    };
+    const auto tokenize_group = [&](const int group_index, const int thread) {
+      // Tokenize coefficients.
+      const Rect rect = shared.BlockGroupRect(group_index);
+      for (size_t idx_pass = 0; idx_pass < enc_state_->passes.size();
+           idx_pass++) {
+        JXL_ASSERT(enc_state_->coeffs[idx_pass]->Type() == ACType::k32);
+        const int32_t* JXL_RESTRICT ac_rows[3] = {
+            enc_state_->coeffs[idx_pass]->PlaneRow(0, group_index, 0).ptr32,
+            enc_state_->coeffs[idx_pass]->PlaneRow(1, group_index, 0).ptr32,
+            enc_state_->coeffs[idx_pass]->PlaneRow(2, group_index, 0).ptr32,
+        };
+        // Ensure group cache is initialized.
+        group_caches_[thread].InitOnce();
+        TokenizeCoefficients(
+            &shared.coeff_orders[idx_pass * shared.coeff_order_size], rect,
+            ac_rows, shared.ac_strategy, frame_header->chroma_subsampling,
+            &group_caches_[thread].num_nzeroes,
+            &enc_state_->passes[idx_pass].ac_tokens[group_index],
+            enc_state_->shared.quant_dc, enc_state_->shared.raw_quant_field,
+            enc_state_->shared.block_ctx_map);
+      }
+    };
+    RunOnPool(pool_, 0, shared.frame_dim.num_groups, tokenize_group_init,
+              tokenize_group, "TokenizeGroup");
+    *frame_header = shared.frame_header;
+    return true;
+  }
+
+  Status EncodeGlobalDCInfo(const FrameHeader& frame_header,
+                            BitWriter* writer) const {
+    // Encode quantizer DC and global scale.
+    JXL_RETURN_IF_ERROR(
+        enc_state_->shared.quantizer.Encode(writer, kLayerQuant, aux_out_));
+    EncodeBlockCtxMap(enc_state_->shared.block_ctx_map, writer, aux_out_);
+    ColorCorrelationMapEncodeDC(&enc_state_->shared.cmap, writer, kLayerDC,
+                                aux_out_);
+    return true;
+  }
+
+  Status EncodeGlobalACInfo(BitWriter* writer,
+                            ModularFrameEncoder* modular_frame_encoder) {
+    JXL_RETURN_IF_ERROR(DequantMatricesEncode(&enc_state_->shared.matrices,
+                                              writer, kLayerDequantTables,
+                                              aux_out_, modular_frame_encoder));
+    if (enc_state_->cparams.speed_tier <= SpeedTier::kTortoise) {
+      ClusterGroups(enc_state_);
+    }
+    size_t num_histo_bits =
+        CeilLog2Nonzero(enc_state_->shared.frame_dim.num_groups);
+    if (num_histo_bits != 0) {
+      BitWriter::Allotment allotment(writer, num_histo_bits);
+      writer->Write(num_histo_bits, enc_state_->shared.num_histograms - 1);
+      ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out_);
+    }
+
+    for (size_t i = 0; i < enc_state_->progressive_splitter.GetNumPasses();
+         i++) {
+      // Encode coefficient orders.
+      size_t order_bits = 0;
+      JXL_RETURN_IF_ERROR(U32Coder::CanEncode(
+          kOrderEnc, enc_state_->used_orders[i], &order_bits));
+      BitWriter::Allotment allotment(writer, order_bits);
+      JXL_CHECK(U32Coder::Write(kOrderEnc, enc_state_->used_orders[i], writer));
+      ReclaimAndCharge(writer, &allotment, kLayerOrder, aux_out_);
+      EncodeCoeffOrders(
+          enc_state_->used_orders[i],
+          &enc_state_->shared
+               .coeff_orders[i * enc_state_->shared.coeff_order_size],
+          writer, kLayerOrder, aux_out_);
+
+      // Encode histograms.
+      HistogramParams hist_params(
+          enc_state_->cparams.speed_tier,
+          enc_state_->shared.block_ctx_map.NumACContexts());
+      if (enc_state_->cparams.speed_tier > SpeedTier::kTortoise) {
+        hist_params.lz77_method = HistogramParams::LZ77Method::kNone;
+      }
+      if (enc_state_->cparams.decoding_speed_tier >= 1) {
+        hist_params.max_histograms = 6;
+      }
+      BuildAndEncodeHistograms(
+          hist_params,
+          enc_state_->shared.num_histograms *
+              enc_state_->shared.block_ctx_map.NumACContexts(),
+          enc_state_->passes[i].ac_tokens, &enc_state_->passes[i].codes,
+          &enc_state_->passes[i].context_map, writer, kLayerAC, aux_out_);
+    }
+
+    return true;
+  }
+
+  Status EncodeACGroup(size_t pass, size_t group_index, BitWriter* group_code,
+                       AuxOut* local_aux_out) {
+    return EncodeGroupTokenizedCoefficients(
+        group_index, pass, enc_state_->histogram_idx[group_index], *enc_state_,
+        group_code, local_aux_out);
+  }
+
+  PassesEncoderState* State() { return enc_state_; }
+
+ private:
+  void ComputeAllCoeffOrders(const FrameDimensions& frame_dim) {
+    PROFILER_FUNC;
+    enc_state_->used_orders.resize(
+        enc_state_->progressive_splitter.GetNumPasses());
+    for (size_t i = 0; i < enc_state_->progressive_splitter.GetNumPasses();
+         i++) {
+      // No coefficient reordering in Falcon mode.
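+      // (used_orders[i] is left at its zero-initialized value in that case,
+      // so the default coefficient orders are kept.)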
+      if (enc_state_->cparams.speed_tier != SpeedTier::kFalcon) {
+        enc_state_->used_orders[i] = ComputeUsedOrders(
+            enc_state_->cparams.speed_tier, enc_state_->shared.ac_strategy,
+            Rect(enc_state_->shared.raw_quant_field));
+      }
+      ComputeCoeffOrder(
+          enc_state_->cparams.speed_tier, *enc_state_->coeffs[i],
+          enc_state_->shared.ac_strategy, frame_dim, enc_state_->used_orders[i],
+          &enc_state_->shared
+               .coeff_orders[i * enc_state_->shared.coeff_order_size]);
+    }
+  }
+
+  template <typename V, typename R>
+  static inline void FindIndexOfSumMaximum(const V* array, const size_t len,
+                                           R* idx, V* sum) {
+    JXL_ASSERT(len > 0);
+    V maxval = 0;
+    V val = 0;
+    R maxidx = 0;
+    for (size_t i = 0; i < len; ++i) {
+      val += array[i];
+      if (val > maxval) {
+        maxval = val;
+        maxidx = i;
+      }
+    }
+    *idx = maxidx;
+    *sum = maxval;
+  }
+
+  PassesEncoderState* JXL_RESTRICT enc_state_;
+  ThreadPool* pool_;
+  AuxOut* aux_out_;
+  std::vector<EncCache> group_caches_;
+};
+
+Status EncodeFrame(const CompressParams& cparams_orig,
+                   const FrameInfo& frame_info, const CodecMetadata* metadata,
+                   const ImageBundle& ib, PassesEncoderState* passes_enc_state,
+                   ThreadPool* pool, BitWriter* writer, AuxOut* aux_out) {
+  ib.VerifyMetadata();
+
+  CompressParams cparams = cparams_orig;
+  if (cparams.progressive_dc < 0) {
+    if (cparams.progressive_dc != -1) {
+      return JXL_FAILURE("Invalid progressive DC setting value (%d)",
+                         cparams.progressive_dc);
+    }
+    cparams.progressive_dc = 0;
+    // Enable progressive_dc for lower qualities.
+    if (cparams.butteraugli_distance >=
+        kMinButteraugliDistanceForProgressiveDc) {
+      cparams.progressive_dc = 1;
+    }
+  }
+
+  if (frame_info.dc_level + cparams.progressive_dc > 4) {
+    return JXL_FAILURE("Too many levels of progressive DC");
+  }
+
+  if (cparams.butteraugli_distance != 0 &&
+      cparams.butteraugli_distance < kMinButteraugliDistance) {
+    return JXL_FAILURE("Butteraugli distance is too low (%f)",
+                       cparams.butteraugli_distance);
+  }
+  if (cparams.butteraugli_distance > 0.9f && cparams.modular_mode == false &&
+      cparams.quality_pair.first == 100) {
+    // If the color image is lossy, make the alpha slightly lossy too.
+    cparams.quality_pair.first =
+        std::max(90.f, 99.f - 0.3f * cparams.butteraugli_distance);
+  }
+
+  if (ib.IsJPEG()) {
+    cparams.gaborish = Override::kOff;
+    cparams.epf = 0;
+    cparams.modular_mode = false;
+  }
+
+  const size_t xsize = ib.xsize();
+  const size_t ysize = ib.ysize();
+  if (xsize == 0 || ysize == 0) return JXL_FAILURE("Empty image");
+
+  // Assert that this metadata is correctly set up for the compression params;
+  // this should have been done by enc_file.cc.
+  JXL_ASSERT(metadata->m.xyb_encoded ==
+             (cparams.color_transform == ColorTransform::kXYB));
+  std::unique_ptr<FrameHeader> frame_header =
+      jxl::make_unique<FrameHeader>(metadata);
+  JXL_RETURN_IF_ERROR(MakeFrameHeader(cparams,
+                                      passes_enc_state->progressive_splitter,
+                                      frame_info, ib, frame_header.get()));
+  // Check that if the codestream header says xyb_encoded, the color_transform
+  // matches that requirement. This is checked against the cparams here;
+  // ideally we would check it against what has actually been written to the
+  // main codestream header, but ib is a const object and the data written to
+  // the main codestream header is (in modified form) in ib, so the encoder
+  // cannot record this fact in ib's metadata.
+  if (cparams_orig.color_transform == ColorTransform::kXYB) {
+    if (frame_header->color_transform != ColorTransform::kXYB) {
+      return JXL_FAILURE(
+          "The color transform of frames must be xyb if the codestream is xyb "
+          "encoded");
+    }
+  } else {
+    if (frame_header->color_transform == ColorTransform::kXYB) {
+      return JXL_FAILURE(
+          "The color transform of frames cannot be xyb if the codestream is "
+          "not xyb encoded");
+    }
+  }
+
+  FrameDimensions frame_dim = frame_header->ToFrameDimensions();
+
+  const size_t num_groups = frame_dim.num_groups;
+
+  Image3F opsin;
+  const ColorEncoding& c_linear = ColorEncoding::LinearSRGB(ib.IsGray());
+  std::unique_ptr<ImageMetadata> metadata_linear =
+      jxl::make_unique<ImageMetadata>();
+  metadata_linear->xyb_encoded =
+      (cparams.color_transform == ColorTransform::kXYB);
+  metadata_linear->color_encoding = c_linear;
+  ImageBundle linear_storage(metadata_linear.get());
+
+  std::vector<AuxOut> aux_outs;
+  // LossyFrameEncoder stores a reference to a std::function<Status(size_t)>,
+  // so we need to keep the std::function being referenced alive while
+  // lossy_frame_encoder is used. We could make resize_aux_outs a lambda type
+  // by making LossyFrameEncoder a template instead, but this is simpler.
+  const std::function<Status(size_t)> resize_aux_outs =
+      [&aux_outs, aux_out](size_t num_threads) -> Status {
+    if (aux_out != nullptr) {
+      size_t old_size = aux_outs.size();
+      for (size_t i = num_threads; i < old_size; i++) {
+        aux_out->Assimilate(aux_outs[i]);
+      }
+      aux_outs.resize(num_threads);
+      // Each thread needs these INPUTS. Don't copy the entire AuxOut
+      // because it may contain stats which would be Assimilated multiple
+      // times below.
+      for (size_t i = old_size; i < aux_outs.size(); i++) {
+        aux_outs[i].dump_image = aux_out->dump_image;
+        aux_outs[i].debug_prefix = aux_out->debug_prefix;
+      }
+    }
+    return true;
+  };
+
+  LossyFrameEncoder lossy_frame_encoder(cparams, *frame_header,
+                                        passes_enc_state, pool, aux_out);
+  std::unique_ptr<ModularFrameEncoder> modular_frame_encoder =
+      jxl::make_unique<ModularFrameEncoder>(*frame_header, cparams);
+
+  if (ib.IsJPEG()) {
+    JXL_RETURN_IF_ERROR(lossy_frame_encoder.ComputeJPEGTranscodingData(
+        *ib.jpeg_data, modular_frame_encoder.get(), frame_header.get()));
+  } else if (!lossy_frame_encoder.State()->heuristics->HandlesColorConversion(
+                 cparams, ib) ||
+             frame_header->encoding != FrameEncoding::kVarDCT) {
+    // Allocating a large enough image avoids a copy when padding.
+    opsin =
+        Image3F(RoundUpToBlockDim(ib.xsize()), RoundUpToBlockDim(ib.ysize()));
+    opsin.ShrinkTo(ib.xsize(), ib.ysize());
+
+    const bool want_linear = frame_header->encoding == FrameEncoding::kVarDCT &&
+                             cparams.speed_tier <= SpeedTier::kKitten;
+    const ImageBundle* JXL_RESTRICT ib_or_linear = &ib;
+
+    if (frame_header->color_transform == ColorTransform::kXYB &&
+        frame_info.ib_needs_color_transform) {
+      // linear_storage would only be used by the Butteraugli loop (passing
+      // linear sRGB avoids a color conversion there). Otherwise, don't
+      // fill it to reduce memory usage.
+      ib_or_linear =
+          ToXYB(ib, pool, &opsin, want_linear ? &linear_storage : nullptr);
+    } else {  // RGB or YCbCr: don't do anything (forward YCbCr is not
+              // implemented, this is only used when the input is already in
+              // YCbCr).
+              // If encoding a special DC or reference frame, don't do
+              // anything: input is already in XYB.
+      CopyImageTo(ib.color(), &opsin);
+    }
+    if (ib.HasAlpha() && !ib.AlphaIsPremultiplied() &&
+        (frame_header->encoding == FrameEncoding::kVarDCT ||
+         cparams.quality_pair.first < 100) &&
+        !cparams.keep_invisible) {
+      // If lossy, simplify invisible pixels.
+      SimplifyInvisible(&opsin, ib.alpha());
+      if (want_linear) {
+        SimplifyInvisible(const_cast<Image3F*>(&ib_or_linear->color()),
+                          ib.alpha());
+      }
+    }
+    if (aux_out != nullptr) {
+      JXL_RETURN_IF_ERROR(
+          aux_out->InspectImage3F("enc_frame:OpsinDynamicsImage", opsin));
+    }
+    if (frame_header->encoding == FrameEncoding::kVarDCT) {
+      PadImageToBlockMultipleInPlace(&opsin);
+      JXL_RETURN_IF_ERROR(lossy_frame_encoder.ComputeEncodingData(
+          ib_or_linear, &opsin, pool, modular_frame_encoder.get(), writer,
+          frame_header.get()));
+    } else if (cparams.resampling != 1) {
+      // In VarDCT mode, LossyFrameHeuristics takes care of running
+      // downsampling after noise, if necessary.
+      DownsampleImage(&opsin, cparams.resampling);
+    }
+  } else {
+    JXL_RETURN_IF_ERROR(lossy_frame_encoder.ComputeEncodingData(
+        &ib, &opsin, pool, modular_frame_encoder.get(), writer,
+        frame_header.get()));
+  }
+  // Needs to happen *AFTER* VarDCT-ComputeEncodingData.
+  JXL_RETURN_IF_ERROR(modular_frame_encoder->ComputeEncodingData(
+      *frame_header, ib, &opsin, lossy_frame_encoder.State(), pool, aux_out,
+      /*do_color=*/frame_header->encoding == FrameEncoding::kModular));
+
+  writer->AppendByteAligned(lossy_frame_encoder.State()->special_frames);
+  frame_header->UpdateFlag(
+      lossy_frame_encoder.State()->shared.image_features.patches.HasAny(),
+      FrameHeader::kPatches);
+  frame_header->UpdateFlag(
+      lossy_frame_encoder.State()->shared.image_features.splines.HasAny(),
+      FrameHeader::kSplines);
+  JXL_RETURN_IF_ERROR(WriteFrameHeader(*frame_header, writer, aux_out));
+
+  const size_t num_passes =
+      passes_enc_state->progressive_splitter.GetNumPasses();
+
+  // DC global info + DC groups + AC global info + AC groups * num_passes.
+  const bool has_ac_global = true;
+  std::vector<BitWriter> group_codes(NumTocEntries(frame_dim.num_groups,
+                                                   frame_dim.num_dc_groups,
+                                                   num_passes, has_ac_global));
+  const size_t global_ac_index = frame_dim.num_dc_groups + 1;
+  const bool is_small_image = frame_dim.num_groups == 1 && num_passes == 1;
+  const auto get_output = [&](const size_t index) {
+    return &group_codes[is_small_image ? 0 : index];
+  };
+  auto ac_group_code = [&](size_t pass, size_t group) {
+    return get_output(AcGroupIndex(pass, group, frame_dim.num_groups,
+                                   frame_dim.num_dc_groups, has_ac_global));
+  };
+
+  if (frame_header->flags & FrameHeader::kPatches) {
+    PatchDictionaryEncoder::Encode(
+        lossy_frame_encoder.State()->shared.image_features.patches,
+        get_output(0), kLayerDictionary, aux_out);
+  }
+
+  if (frame_header->flags & FrameHeader::kSplines) {
+    EncodeSplines(lossy_frame_encoder.State()->shared.image_features.splines,
+                  get_output(0), kLayerSplines, HistogramParams(), aux_out);
+  }
+
+  if (frame_header->flags & FrameHeader::kNoise) {
+    EncodeNoise(lossy_frame_encoder.State()->shared.image_features.noise_params,
+                get_output(0), kLayerNoise, aux_out);
+  }
+
+  JXL_RETURN_IF_ERROR(
+      DequantMatricesEncodeDC(&lossy_frame_encoder.State()->shared.matrices,
+                              get_output(0), kLayerDequantTables, aux_out));
+  if (frame_header->encoding == FrameEncoding::kVarDCT) {
+    JXL_RETURN_IF_ERROR(
+        lossy_frame_encoder.EncodeGlobalDCInfo(*frame_header, get_output(0)));
+  }
+  JXL_RETURN_IF_ERROR(
+      modular_frame_encoder->EncodeGlobalInfo(get_output(0), aux_out));
+  JXL_RETURN_IF_ERROR(modular_frame_encoder->EncodeStream(
+      get_output(0), aux_out, kLayerModularGlobal, ModularStreamId::Global()));
+
+  const auto process_dc_group = [&](const int group_index, const int thread) {
+    AuxOut* my_aux_out = aux_out ? &aux_outs[thread] : nullptr;
+    BitWriter* output = get_output(group_index + 1);
+    if (frame_header->encoding == FrameEncoding::kVarDCT &&
+        !(frame_header->flags & FrameHeader::kUseDcFrame)) {
+      BitWriter::Allotment allotment(output, 2);
+      output->Write(2, modular_frame_encoder->extra_dc_precision[group_index]);
+      ReclaimAndCharge(output, &allotment, kLayerDC, my_aux_out);
+      JXL_CHECK(modular_frame_encoder->EncodeStream(
+          output, my_aux_out, kLayerDC,
+          ModularStreamId::VarDCTDC(group_index)));
+    }
+    JXL_CHECK(modular_frame_encoder->EncodeStream(
+        output, my_aux_out, kLayerModularDcGroup,
+        ModularStreamId::ModularDC(group_index)));
+    if (frame_header->encoding == FrameEncoding::kVarDCT) {
+      const Rect& rect =
+          lossy_frame_encoder.State()->shared.DCGroupRect(group_index);
+      size_t nb_bits = CeilLog2Nonzero(rect.xsize() * rect.ysize());
+      if (nb_bits != 0) {
+        BitWriter::Allotment allotment(output, nb_bits);
+        output->Write(nb_bits,
+                      modular_frame_encoder->ac_metadata_size[group_index] - 1);
+        ReclaimAndCharge(output, &allotment, kLayerControlFields, my_aux_out);
+      }
+      JXL_CHECK(modular_frame_encoder->EncodeStream(
+          output, my_aux_out, kLayerControlFields,
+          ModularStreamId::ACMetadata(group_index)));
+    }
+  };
+  RunOnPool(pool, 0, frame_dim.num_dc_groups, resize_aux_outs,
+            process_dc_group, "EncodeDCGroup");
+
+  if (frame_header->encoding == FrameEncoding::kVarDCT) {
+    JXL_RETURN_IF_ERROR(lossy_frame_encoder.EncodeGlobalACInfo(
+        get_output(global_ac_index), modular_frame_encoder.get()));
+  }
+
+  std::atomic<size_t> num_errors{0};
+  const auto process_group = [&](const int group_index, const int thread) {
+    AuxOut* my_aux_out = aux_out ? &aux_outs[thread] : nullptr;
+
+    for (size_t i = 0; i < num_passes; i++) {
+      if (frame_header->encoding == FrameEncoding::kVarDCT) {
+        if (!lossy_frame_encoder.EncodeACGroup(
+                i, group_index, ac_group_code(i, group_index), my_aux_out)) {
+          num_errors.fetch_add(1, std::memory_order_relaxed);
+          return;
+        }
+      }
+      // Write all modular encoded data (color?, alpha, depth, extra channels).
+      if (!modular_frame_encoder->EncodeStream(
+              ac_group_code(i, group_index), my_aux_out, kLayerModularAcGroup,
+              ModularStreamId::ModularAC(group_index, i))) {
+        num_errors.fetch_add(1, std::memory_order_relaxed);
+        return;
+      }
+    }
+  };
+  RunOnPool(pool, 0, num_groups, resize_aux_outs, process_group,
+            "EncodeGroupCoefficients");
+
+  // Resizing aux_outs to 0 also Assimilates the array.
+  static_cast<void>(resize_aux_outs(0));
+  JXL_RETURN_IF_ERROR(num_errors.load(std::memory_order_relaxed) == 0);
+
+  for (BitWriter& bw : group_codes) {
+    bw.ZeroPadToByte();  // end of group.
+  }
+
+  std::vector<coeff_order_t>* permutation_ptr = nullptr;
+  std::vector<coeff_order_t> permutation;
+  if (cparams.middleout && !(num_passes == 1 && num_groups == 1)) {
+    permutation_ptr = &permutation;
+    // Don't permute global DC/AC or DC.
+    permutation.resize(global_ac_index + 1);
+    std::iota(permutation.begin(), permutation.end(), 0);
+    std::vector<coeff_order_t> ac_group_order(num_groups);
+    std::iota(ac_group_order.begin(), ac_group_order.end(), 0);
+    int64_t cx = ib.xsize() / 2;
+    int64_t cy = ib.ysize() / 2;
+    auto get_distance_from_center = [&](size_t gid) {
+      Rect r = passes_enc_state->shared.GroupRect(gid);
+      int64_t gcx = r.x0() + r.xsize() / 2;
+      int64_t gcy = r.y0() + r.ysize() / 2;
+      int64_t dx = gcx - cx;
+      int64_t dy = gcy - cy;
+      // Concentric squares in counterclockwise order.
+      return std::make_pair(std::max(std::abs(dx), std::abs(dy)),
+                            std::atan2(dy, dx));
+    };
+    std::sort(ac_group_order.begin(), ac_group_order.end(),
+              [&](coeff_order_t a, coeff_order_t b) {
+                return get_distance_from_center(a) <
+                       get_distance_from_center(b);
+              });
+    std::vector<coeff_order_t> inv_ac_group_order(ac_group_order.size(), 0);
+    for (size_t i = 0; i < ac_group_order.size(); i++) {
+      inv_ac_group_order[ac_group_order[i]] = i;
+    }
+    for (size_t i = 0; i < num_passes; i++) {
+      size_t pass_start = permutation.size();
+      for (coeff_order_t v : inv_ac_group_order) {
+        permutation.push_back(pass_start + v);
+      }
+    }
+    std::vector<BitWriter> new_group_codes(group_codes.size());
+    for (size_t i = 0; i < permutation.size(); i++) {
+      new_group_codes[permutation[i]] = std::move(group_codes[i]);
+    }
+    group_codes = std::move(new_group_codes);
+  }
+
+  JXL_RETURN_IF_ERROR(
+      WriteGroupOffsets(group_codes, permutation_ptr, writer, aux_out));
+  writer->AppendByteAligned(group_codes);
+  writer->ZeroPadToByte();  // end of frame.
+
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_frame.h b/third_party/jpeg-xl/lib/jxl/enc_frame.h
new file mode 100644
index 000000000000..851a2c0a1102
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_frame.h
@@ -0,0 +1,60 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_FRAME_H_
+#define LIB_JXL_ENC_FRAME_H_
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+
+// Information needed for encoding a frame that is not contained elsewhere and
+// does not belong to `cparams`.
+struct FrameInfo {
+  // TODO(veluca): consider adding more parameters, such as custom patches.
+  bool save_before_color_transform = false;
+  // Whether or not the input image bundle is already in the codestream
+  // colorspace (as deduced by cparams).
+  // TODO(veluca): this is a hack - ImageBundle doesn't have a simple way to say
+  // "this is already in XYB".
+  bool ib_needs_color_transform = true;
+  FrameType frame_type = FrameType::kRegularFrame;
+  size_t dc_level = 0;
+  // Only used for kRegularFrame.
+  bool is_last = true;
+  bool is_preview = false;
+  // Information for storing this frame for future use (only for non-DC
+  // frames).
+  size_t save_as_reference = 0;
+};
+
+// Encodes a single frame (including its header) into a byte stream. Groups may
+// be processed in parallel by `pool`. `metadata` is the ImageMetadata encoded
+// in the codestream and must be used for the FrameHeaders; do not use
+// ib.metadata.
+Status EncodeFrame(const CompressParams& cparams_orig,
+                   const FrameInfo& frame_info, const CodecMetadata* metadata,
+                   const ImageBundle& ib, PassesEncoderState* passes_enc_state,
+                   ThreadPool* pool, BitWriter* writer, AuxOut* aux_out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_FRAME_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_gamma_correct.h b/third_party/jpeg-xl/lib/jxl/enc_gamma_correct.h
new file mode 100644
index 000000000000..3458be9ce6f7
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_gamma_correct.h
@@ -0,0 +1,45 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_GAMMA_CORRECT_H_
+#define LIB_JXL_ENC_GAMMA_CORRECT_H_
+
+// Deprecated: sRGB transfer function. Use color_management.h instead.
+
+#include <cmath>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/transfer_functions-inl.h"
+
+namespace jxl {
+
+// Values are in [0, 1].
+static JXL_INLINE double Srgb8ToLinearDirect(double srgb) {
+  if (srgb <= 0.0) return 0.0;
+  if (srgb <= 0.04045) return srgb / 12.92;
+  if (srgb >= 1.0) return 1.0;
+  return std::pow((srgb + 0.055) / 1.055, 2.4);
+}
+
+// Values are in [0, 1].
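+// (Piecewise sRGB transfer function: linear <= 0.0031308 maps to
+// 12.92 * linear, larger values to 1.055 * linear^(1/2.4) - 0.055.)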
+static JXL_INLINE double LinearToSrgb8Direct(double linear) {
+  if (linear <= 0.0) return 0.0;
+  if (linear >= 1.0) return 1.0;
+  if (linear <= 0.0031308) return linear * 12.92;
+  return std::pow(linear, 1.0 / 2.4) * 1.055 - 0.055;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_GAMMA_CORRECT_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_group.cc b/third_party/jpeg-xl/lib/jxl/enc_group.cc
new file mode 100644
index 000000000000..7f2e56235907
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_group.cc
@@ -0,0 +1,351 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_group.h"
+
+#include <utility>
+
+#include "hwy/aligned_allocator.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_group.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct_util.h"
+#include "lib/jxl/dec_transforms-inl.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/enc_transforms-inl.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/quantizer-inl.h"
+#include "lib/jxl/quantizer.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// NOTE: caller takes care of extracting quant from rect of RawQuantField.
+void QuantizeBlockAC(const Quantizer& quantizer, const bool error_diffusion,
+                     size_t c, int32_t quant, float qm_multiplier,
+                     size_t quant_kind, size_t xsize, size_t ysize,
+                     const float* JXL_RESTRICT block_in,
+                     int32_t* JXL_RESTRICT block_out) {
+  PROFILER_FUNC;
+  const float* JXL_RESTRICT qm = quantizer.InvDequantMatrix(quant_kind, c);
+  const float qac = quantizer.Scale() * quant;
+  // Not SIMD-fied for now.
+  float thres[4] = {0.5f, 0.6f, 0.6f, 0.65f};
+  if (c != 1) {
+    for (int i = 1; i < 4; ++i) {
+      thres[i] = 0.75f;
+    }
+  }
+
+  if (!error_diffusion) {
+    HWY_CAPPED(float, kBlockDim) df;
+    HWY_CAPPED(int32_t, kBlockDim) di;
+    HWY_CAPPED(uint32_t, kBlockDim) du;
+    const auto quant = Set(df, qac * qm_multiplier);
+
+    for (size_t y = 0; y < ysize * kBlockDim; y++) {
+      size_t yfix = static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2;
+      const size_t off = y * kBlockDim * xsize;
+      for (size_t x = 0; x < xsize * kBlockDim; x += Lanes(df)) {
+        auto thr = Zero(df);
+        if (xsize == 1) {
+          HWY_ALIGN uint32_t kMask[kBlockDim] = {0,   0,   0,   0,
+                                                 ~0u, ~0u, ~0u, ~0u};
+          const auto mask = MaskFromVec(BitCast(df, Load(du, kMask + x)));
+          thr =
+              IfThenElse(mask, Set(df, thres[yfix + 1]), Set(df, thres[yfix]));
+        } else {
+          // Same for all lanes in the vector.
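+          // The block is split into four frequency quadrants: yfix selects
+          // the vertical half and the x comparison the horizontal half, so
+          // higher-frequency quadrants get the larger zeroing thresholds.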
+          thr = Set(
+              df,
+              thres[yfix + static_cast<size_t>(x >= xsize * kBlockDim / 2)]);
+        }
+
+        const auto q = Load(df, qm + off + x) * quant;
+        const auto in = Load(df, block_in + off + x);
+        const auto val = q * in;
+        const auto nzero_mask = Abs(val) >= thr;
+        const auto v = ConvertTo(di, IfThenElseZero(nzero_mask, Round(val)));
+        Store(v, di, block_out + off + x);
+      }
+    }
+    return;
+  }
+
+retry:
+  int hfNonZeros[4] = {};
+  float hfError[4] = {};
+  float hfMaxError[4] = {};
+  size_t hfMaxErrorIx[4] = {};
+  for (size_t y = 0; y < ysize * kBlockDim; y++) {
+    for (size_t x = 0; x < xsize * kBlockDim; x++) {
+      const size_t pos = y * kBlockDim * xsize + x;
+      if (x < xsize && y < ysize) {
+        // Ensure block is initialized.
+        block_out[pos] = 0;
+        continue;
+      }
+      const size_t hfix = (static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2 +
+                           static_cast<size_t>(x >= xsize * kBlockDim / 2));
+      const float val = block_in[pos] * (qm[pos] * qac * qm_multiplier);
+      float v = (std::abs(val) < thres[hfix]) ? 0 : rintf(val);
+      const float error = std::abs(val) - std::abs(v);
+      hfError[hfix] += error;
+      if (hfMaxError[hfix] < error) {
+        hfMaxError[hfix] = error;
+        hfMaxErrorIx[hfix] = pos;
+      }
+      if (v != 0.0f) {
+        hfNonZeros[hfix] += std::abs(v);
+      }
+      block_out[pos] = static_cast<int32_t>(rintf(v));
+    }
+  }
+  if (c != 1) return;
+  // TODO(veluca): include AFV?
+  const size_t kPartialBlockKinds =
+      (1 << AcStrategy::Type::IDENTITY) | (1 << AcStrategy::Type::DCT2X2) |
+      (1 << AcStrategy::Type::DCT4X4) | (1 << AcStrategy::Type::DCT4X8) |
+      (1 << AcStrategy::Type::DCT8X4);
+  if ((1 << quant_kind) & kPartialBlockKinds) return;
+  float hfErrorLimit = 0.1f * (xsize * ysize) * kDCTBlockSize * 0.25f;
+  bool goretry = false;
+  for (int i = 1; i < 4; ++i) {
+    if (hfError[i] >= hfErrorLimit &&
+        hfNonZeros[i] <= (xsize + ysize) * 0.25f) {
+      if (thres[i] >= 0.4f) {
+        thres[i] -= 0.01f;
+        goretry = true;
+      }
+    }
+  }
+  if (goretry) goto retry;
+  for (int i = 1; i < 4; ++i) {
+    if (hfError[i] >= hfErrorLimit && hfNonZeros[i] == 0) {
+      const size_t pos = hfMaxErrorIx[i];
+      if (hfMaxError[i] >= 0.4f) {
+        block_out[pos] = block_in[pos] > 0.0f ? 1 : -1;
+      }
+    }
+  }
+}
+
+// NOTE: caller takes care of extracting quant from rect of RawQuantField.
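+// Quantizes the Y block and immediately dequantizes it back into `inout`
+// (including the AdjustQuantBias correction), so that the chroma-from-luma
+// subtraction for X/B sees the same Y values the decoder will reconstruct
+// rather than the exact encoder input.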
+void QuantizeRoundtripYBlockAC(const Quantizer& quantizer,
+                               const bool error_diffusion, int32_t quant,
+                               size_t quant_kind, size_t xsize, size_t ysize,
+                               const float* JXL_RESTRICT biases,
+                               float* JXL_RESTRICT inout,
+                               int32_t* JXL_RESTRICT quantized) {
+  QuantizeBlockAC(quantizer, error_diffusion, 1, quant, 1.0f, quant_kind, xsize,
+                  ysize, inout, quantized);
+
+  PROFILER_ZONE("enc quant adjust bias");
+  const float* JXL_RESTRICT dequant_matrix =
+      quantizer.DequantMatrix(quant_kind, 1);
+
+  HWY_CAPPED(float, kDCTBlockSize) df;
+  HWY_CAPPED(int32_t, kDCTBlockSize) di;
+  const auto inv_qac = Set(df, quantizer.inv_quant_ac(quant));
+  for (size_t k = 0; k < kDCTBlockSize * xsize * ysize; k += Lanes(df)) {
+    const auto quant = Load(di, quantized + k);
+    const auto adj_quant = AdjustQuantBias(di, 1, quant, biases);
+    const auto dequantm = Load(df, dequant_matrix + k);
+    Store(adj_quant * dequantm * inv_qac, df, inout + k);
+  }
+}
+
+void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state,
+                         const Image3F& opsin, Image3F* dc) {
+  PROFILER_FUNC;
+  const Rect block_group_rect = enc_state->shared.BlockGroupRect(group_idx);
+  const Rect group_rect = enc_state->shared.GroupRect(group_idx);
+  const Rect cmap_rect(
+      block_group_rect.x0() / kColorTileDimInBlocks,
+      block_group_rect.y0() / kColorTileDimInBlocks,
+      DivCeil(block_group_rect.xsize(), kColorTileDimInBlocks),
+      DivCeil(block_group_rect.ysize(), kColorTileDimInBlocks));
+
+  const size_t xsize_blocks = block_group_rect.xsize();
+  const size_t ysize_blocks = block_group_rect.ysize();
+
+  const size_t dc_stride = static_cast<size_t>(dc->PixelsPerRow());
+  const size_t opsin_stride = static_cast<size_t>(opsin.PixelsPerRow());
+
+  const ImageI& full_quant_field = enc_state->shared.raw_quant_field;
+  const CompressParams& cparams = enc_state->cparams;
+
+  // TODO(veluca): consider strategies to reduce this memory.
+  auto mem = hwy::AllocateAligned<int32_t>(3 * AcStrategy::kMaxCoeffArea);
+  auto fmem = hwy::AllocateAligned<float>(5 * AcStrategy::kMaxCoeffArea);
+  float* JXL_RESTRICT scratch_space =
+      fmem.get() + 3 * AcStrategy::kMaxCoeffArea;
+  {
+    // Only use error diffusion in Squirrel mode or slower.
+    const bool error_diffusion = cparams.speed_tier <= SpeedTier::kSquirrel;
+    constexpr HWY_CAPPED(float, kDCTBlockSize) d;
+
+    int32_t* JXL_RESTRICT coeffs[kMaxNumPasses][3] = {};
+    size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
+    JXL_DASSERT(num_passes > 0);
+    for (size_t i = 0; i < num_passes; i++) {
+      // TODO(veluca): 16-bit quantized coeffs are not implemented yet.
+      JXL_ASSERT(enc_state->coeffs[i]->Type() == ACType::k32);
+      for (size_t c = 0; c < 3; c++) {
+        coeffs[i][c] = enc_state->coeffs[i]->PlaneRow(c, group_idx, 0).ptr32;
+      }
+    }
+
+    HWY_ALIGN float* coeffs_in = fmem.get();
+    HWY_ALIGN int32_t* quantized = mem.get();
+
+    size_t offset = 0;
+
+    for (size_t by = 0; by < ysize_blocks; ++by) {
+      const int32_t* JXL_RESTRICT row_quant_ac =
+          block_group_rect.ConstRow(full_quant_field, by);
+      size_t ty = by / kColorTileDimInBlocks;
+      const int8_t* JXL_RESTRICT row_cmap[3] = {
+          cmap_rect.ConstRow(enc_state->shared.cmap.ytox_map, ty),
+          nullptr,
+          cmap_rect.ConstRow(enc_state->shared.cmap.ytob_map, ty),
+      };
+      const float* JXL_RESTRICT opsin_rows[3] = {
+          group_rect.ConstPlaneRow(opsin, 0, by * kBlockDim),
+          group_rect.ConstPlaneRow(opsin, 1, by * kBlockDim),
+          group_rect.ConstPlaneRow(opsin, 2, by * kBlockDim),
+      };
+      float* JXL_RESTRICT dc_rows[3] = {
+          block_group_rect.PlaneRow(dc, 0, by),
+          block_group_rect.PlaneRow(dc, 1, by),
+          block_group_rect.PlaneRow(dc, 2, by),
+      };
+      AcStrategyRow ac_strategy_row =
+          enc_state->shared.ac_strategy.ConstRow(block_group_rect, by);
+      for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
+           tx++) {
+        const auto x_factor =
+            Set(d, enc_state->shared.cmap.YtoXRatio(row_cmap[0][tx]));
+        const auto b_factor =
+            Set(d, enc_state->shared.cmap.YtoBRatio(row_cmap[2][tx]));
+        for (size_t bx = tx * kColorTileDimInBlocks;
+             bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks; ++bx) {
+          const AcStrategy acs = ac_strategy_row[bx];
+          if (!acs.IsFirstBlock()) continue;
+
+          size_t xblocks = acs.covered_blocks_x();
+          size_t yblocks = acs.covered_blocks_y();
+
+          CoefficientLayout(&yblocks, &xblocks);
+
+          size_t size = kDCTBlockSize * xblocks * yblocks;
+
+          // DCT Y channel, roundtrip-quantize it and set DC.
+          const int32_t quant_ac = row_quant_ac[bx];
+          TransformFromPixels(acs.Strategy(), opsin_rows[1] + bx * kBlockDim,
+                              opsin_stride, coeffs_in + size, scratch_space);
+          DCFromLowestFrequencies(acs.Strategy(), coeffs_in + size,
+                                  dc_rows[1] + bx, dc_stride);
+          QuantizeRoundtripYBlockAC(
+              enc_state->shared.quantizer, error_diffusion, quant_ac,
+              acs.RawStrategy(), xblocks, yblocks, kDefaultQuantBias,
+              coeffs_in + size, quantized + size);
+
+          // DCT X and B channels.
+          for (size_t c : {0, 2}) {
+            TransformFromPixels(acs.Strategy(), opsin_rows[c] + bx * kBlockDim,
+                                opsin_stride, coeffs_in + c * size,
+                                scratch_space);
+          }
+
+          // Unapply color correlation.
+          for (size_t k = 0; k < size; k += Lanes(d)) {
+            const auto in_x = Load(d, coeffs_in + k);
+            const auto in_y = Load(d, coeffs_in + size + k);
+            const auto in_b = Load(d, coeffs_in + 2 * size + k);
+            const auto out_x = in_x - x_factor * in_y;
+            const auto out_b = in_b - b_factor * in_y;
+            Store(out_x, d, coeffs_in + k);
+            Store(out_b, d, coeffs_in + 2 * size + k);
+          }
+
+          // Quantize X and B channels and set DC.
+          for (size_t c : {0, 2}) {
+            QuantizeBlockAC(enc_state->shared.quantizer, error_diffusion, c,
+                            quant_ac,
+                            c == 0 ? enc_state->x_qm_multiplier
+                                   : enc_state->b_qm_multiplier,
+                            acs.RawStrategy(), xblocks, yblocks,
+                            coeffs_in + c * size, quantized + c * size);
+            DCFromLowestFrequencies(acs.Strategy(), coeffs_in + c * size,
+                                    dc_rows[c] + bx, dc_stride);
+          }
+          enc_state->progressive_splitter.SplitACCoefficients(
+              quantized, size, acs, bx, by, offset, coeffs);
+          offset += size;
+        }
+      }
+    }
+  }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(ComputeCoefficients);
+void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state,
+                         const Image3F& opsin, Image3F* dc) {
+  return HWY_DYNAMIC_DISPATCH(ComputeCoefficients)(group_idx, enc_state, opsin,
+                                                   dc);
+}
+
+Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx,
+                                        size_t histogram_idx,
+                                        const PassesEncoderState& enc_state,
+                                        BitWriter* writer, AuxOut* aux_out) {
+  // Select which histogram to use among those of the current pass.
+  const size_t num_histograms = enc_state.shared.num_histograms;
+  // num_histograms is 0 only for lossless.
+  JXL_ASSERT(num_histograms == 0 || histogram_idx < num_histograms);
+  size_t histo_selector_bits = CeilLog2Nonzero(num_histograms);
+
+  if (histo_selector_bits != 0) {
+    BitWriter::Allotment allotment(writer, histo_selector_bits);
+    writer->Write(histo_selector_bits, histogram_idx);
+    ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out);
+  }
+  WriteTokens(enc_state.passes[pass_idx].ac_tokens[group_idx],
+              enc_state.passes[pass_idx].codes,
+              enc_state.passes[pass_idx].context_map, writer, kLayerACTokens,
+              aux_out);
+
+  return true;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/enc_group.h b/third_party/jpeg-xl/lib/jxl/enc_group.h
new file mode 100644
index 000000000000..aa905a027a0e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_group.h
@@ -0,0 +1,39 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_GROUP_H_
+#define LIB_JXL_ENC_GROUP_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_cache.h"
+
+namespace jxl {
+
+// Fills DC
+void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state,
+                         const Image3F& opsin, Image3F* dc);
+
+Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx,
+                                        size_t histogram_idx,
+                                        const PassesEncoderState& enc_state,
+                                        BitWriter* writer, AuxOut* aux_out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_GROUP_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc b/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc
new file mode 100644
index 000000000000..4396d01b44e2
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc
@@ -0,0 +1,437 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_heuristics.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm>
+#include <numeric>
+#include <string>
+
+#include "lib/jxl/enc_ac_strategy.h"
+#include "lib/jxl/enc_adaptive_quantization.h"
+#include "lib/jxl/enc_ar_control_field.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_chroma_from_luma.h"
+#include "lib/jxl/enc_modular.h"
+#include "lib/jxl/enc_noise.h"
+#include "lib/jxl/enc_patch_dictionary.h"
+#include "lib/jxl/enc_quant_weights.h"
+#include "lib/jxl/enc_splines.h"
+#include "lib/jxl/enc_xyb.h"
+#include "lib/jxl/gaborish.h"
+
+namespace jxl {
+namespace {
+void FindBestBlockEntropyModel(PassesEncoderState& enc_state) {
+  if (enc_state.cparams.decoding_speed_tier >= 1) {
+    static constexpr uint8_t kSimpleCtxMap[] = {
+        // Cluster all blocks together
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  //
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
+    };
+    static_assert(
+        3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap,
+        "Update simple context map");
+
+    auto& bcm = enc_state.shared.block_ctx_map;
+    bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap));
+    bcm.num_ctxs = 2;
+    bcm.num_dc_ctxs = 1;
+    return;
+  }
+  if (enc_state.cparams.speed_tier == SpeedTier::kFalcon) {
+    return;
+  }
+  const ImageI& rqf = enc_state.shared.raw_quant_field;
+  // No need to change context modeling for small images.
+  size_t tot = rqf.xsize() * rqf.ysize();
+  size_t size_for_ctx_model =
+      (1 << 10) * enc_state.cparams.butteraugli_distance;
+  if (tot < size_for_ctx_model) return;
+
+  struct OccCounters {
+    // Count the occurrences of each qf value and each strategy type.
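+    // qf_counts histograms the quant field, ord_counts the strategy order
+    // classes, and qf_ord_counts their joint distribution; the clustering
+    // below merges (order, qf-segment) pairs that cover few blocks.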
+    OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) {
+      for (size_t y = 0; y < rqf.ysize(); y++) {
+        const int32_t* qf_row = rqf.Row(y);
+        AcStrategyRow acs_row = ac_strategy.ConstRow(y);
+        for (size_t x = 0; x < rqf.xsize(); x++) {
+          int ord = kStrategyOrder[acs_row[x].RawStrategy()];
+          int qf = qf_row[x] - 1;
+          qf_counts[qf]++;
+          qf_ord_counts[ord][qf]++;
+          ord_counts[ord]++;
+        }
+      }
+    }
+
+    size_t qf_counts[256] = {};
+    size_t qf_ord_counts[kNumOrders][256] = {};
+    size_t ord_counts[kNumOrders] = {};
+  };
+  // The OccCounters struct is too big to allocate on the stack.
+  std::unique_ptr<OccCounters> counters(
+      new OccCounters(rqf, enc_state.shared.ac_strategy));
+
+  // Splitting the context model according to the quantization field seems to
+  // mostly benefit only large images.
+  size_t size_for_qf_split = (1 << 13) * enc_state.cparams.butteraugli_distance;
+  size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2;
+  std::vector<uint32_t>& qft = enc_state.shared.block_ctx_map.qf_thresholds;
+  qft.clear();
+  // Divide the quant field in up to num_qf_segments segments.
+  size_t cumsum = 0;
+  size_t next = 1;
+  size_t last_cut = 256;
+  size_t cut = tot * next / num_qf_segments;
+  for (uint32_t j = 0; j < 256; j++) {
+    cumsum += counters->qf_counts[j];
+    if (cumsum > cut) {
+      if (j != 0) {
+        qft.push_back(j);
+      }
+      last_cut = j;
+      while (cumsum > cut) {
+        next++;
+        cut = tot * next / num_qf_segments;
+      }
+    } else if (next > qft.size() + 1) {
+      if (j - 1 == last_cut && j != 0) {
+        qft.push_back(j);
+      }
+    }
+  }
+
+  // Count the occurrences of each segment.
+  std::vector<size_t> counts(kNumOrders * (qft.size() + 1));
+  size_t qft_pos = 0;
+  for (size_t j = 0; j < 256; j++) {
+    if (qft_pos < qft.size() && j == qft[qft_pos]) {
+      qft_pos++;
+    }
+    for (size_t i = 0; i < kNumOrders; i++) {
+      counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j];
+    }
+  }
+
+  // Repeatedly merge the lowest-count pair.
+  std::vector<uint8_t> remap((qft.size() + 1) * kNumOrders);
+  std::iota(remap.begin(), remap.end(), 0);
+  std::vector<uint8_t> clusters(remap);
+  // This is O(n^2 log n), but n <= 14.
+  while (clusters.size() > 5) {
+    std::sort(clusters.begin(), clusters.end(),
+              [&](int a, int b) { return counts[a] > counts[b]; });
+    counts[clusters[clusters.size() - 2]] += counts[clusters.back()];
+    counts[clusters.back()] = 0;
+    remap[clusters.back()] = clusters[clusters.size() - 2];
+    clusters.pop_back();
+  }
+  for (size_t i = 0; i < remap.size(); i++) {
+    while (remap[remap[i]] != remap[i]) {
+      remap[i] = remap[remap[i]];
+    }
+  }
+  // Relabel starting from 0.
+  std::vector<uint8_t> remap_remap(remap.size(), remap.size());
+  size_t num = 0;
+  for (size_t i = 0; i < remap.size(); i++) {
+    if (remap_remap[remap[i]] == remap.size()) {
+      remap_remap[remap[i]] = num++;
+    }
+    remap[i] = remap_remap[remap[i]];
+  }
+  // Write the block context map.
+  auto& ctx_map = enc_state.shared.block_ctx_map.ctx_map;
+  ctx_map = remap;
+  ctx_map.resize(remap.size() * 3);
+  for (size_t i = remap.size(); i < remap.size() * 3; i++) {
+    ctx_map[i] = remap[i % remap.size()] + num;
+  }
+  enc_state.shared.block_ctx_map.num_ctxs =
+      *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
+}
+
+// Returns the target size based on whether bitrate or direct targetsize is
+// given.
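+// (For a bitrate target this is bytes = bits_per_pixel * xsize * ysize /
+// kBitsPerByte, rounded to the nearest integer via the + 0.5.)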
+size_t TargetSize(const CompressParams& cparams,
+                  const FrameDimensions& frame_dim) {
+  if (cparams.target_size > 0) {
+    return cparams.target_size;
+  }
+  if (cparams.target_bitrate > 0.0) {
+    return 0.5 + cparams.target_bitrate * frame_dim.xsize * frame_dim.ysize /
+                     kBitsPerByte;
+  }
+  return 0;
+}
+}  // namespace
+
+void FindBestDequantMatrices(const CompressParams& cparams,
+                             const Image3F& opsin,
+                             ModularFrameEncoder* modular_frame_encoder,
+                             DequantMatrices* dequant_matrices) {
+  // TODO(veluca): quant matrices for no-gaborish.
+  // TODO(veluca): heuristics for in-bitstream quant tables.
+  *dequant_matrices = DequantMatrices();
+  if (cparams.max_error_mode) {
+    // Set numerators of all quantization matrices to constant values.
+    float weights[3][1] = {{1.0f / cparams.max_error[0]},
+                           {1.0f / cparams.max_error[1]},
+                           {1.0f / cparams.max_error[2]}};
+    DctQuantWeightParams dct_params(weights);
+    std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
+                                         QuantEncoding::DCT(dct_params));
+    DequantMatricesSetCustom(dequant_matrices, encodings,
+                             modular_frame_encoder);
+    float dc_weights[3] = {1.0f / cparams.max_error[0],
+                           1.0f / cparams.max_error[1],
+                           1.0f / cparams.max_error[2]};
+    DequantMatricesSetCustomDC(dequant_matrices, dc_weights);
+  }
+}
+
+bool DefaultEncoderHeuristics::HandlesColorConversion(
+    const CompressParams& cparams, const ImageBundle& ib) {
+  return cparams.noise != Override::kOn && cparams.patches != Override::kOn &&
+         cparams.speed_tier >= SpeedTier::kWombat && cparams.resampling == 1 &&
+         cparams.color_transform == ColorTransform::kXYB &&
+         !cparams.modular_mode && !ib.HasAlpha();
+}
+
+Status DefaultEncoderHeuristics::LossyFrameHeuristics(
+    PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder,
+    const ImageBundle* original_pixels, Image3F* opsin, ThreadPool* pool,
+    AuxOut* aux_out) {
+  PROFILER_ZONE("JxlLossyFrameHeuristics uninstrumented");
+
+  CompressParams& cparams = enc_state->cparams;
+  PassesSharedState& shared = enc_state->shared;
+
+  // Compute parameters for noise synthesis.
+  if (shared.frame_header.flags & FrameHeader::kNoise) {
+    PROFILER_ZONE("enc GetNoiseParam");
+    // Don't start at zero amplitude since adding noise is expensive -- it
+    // significantly slows down decoding, and this is unlikely to
+    // completely go away even with advanced optimizations. After the
+    // kNoiseModelingRampUpDistanceRange we have reached the full level,
+    // i.e. noise is no longer represented by the compressed image, so we
+    // can add full noise by the noise modeling itself.
+    static const float kNoiseModelingRampUpDistanceRange = 0.6;
+    static const float kNoiseLevelAtStartOfRampUp = 0.25;
+    static const float kNoiseRampupStart = 1.0;
+    // TODO(user): test and properly select quality_coef with smooth filter.
+    float quality_coef = 1.0f;
+    const float rampup = (cparams.butteraugli_distance - kNoiseRampupStart) /
+                         kNoiseModelingRampUpDistanceRange;
+    if (rampup < 1.0f) {
+      quality_coef = kNoiseLevelAtStartOfRampUp +
+                     (1.0f - kNoiseLevelAtStartOfRampUp) * rampup;
+    }
+    if (rampup < 0.0f) {
+      quality_coef = kNoiseRampupStart;
+    }
+    if (!GetNoiseParameter(*opsin, &shared.image_features.noise_params,
+                           quality_coef)) {
+      shared.frame_header.flags &= ~FrameHeader::kNoise;
+    }
+  }
+  if (cparams.resampling != 1) {
+    // In VarDCT mode, LossyFrameHeuristics takes care of running downsampling
+    // after noise, if necessary.
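+    // Noise parameters were estimated on the full-resolution image above, so
+    // the image is only downsampled afterwards and re-padded to a block
+    // multiple.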
+ DownsampleImage(opsin, cparams.resampling); + PadImageToBlockMultipleInPlace(opsin); + } + + const FrameDimensions& frame_dim = enc_state->shared.frame_dim; + size_t target_size = TargetSize(cparams, frame_dim); + size_t opsin_target_size = target_size; + if (cparams.target_size > 0 || cparams.target_bitrate > 0.0) { + cparams.target_size = opsin_target_size; + } else if (cparams.butteraugli_distance < 0) { + return JXL_FAILURE("Expected non-negative distance"); + } + + // Find and subtract splines. + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + shared.image_features.splines = FindSplines(*opsin); + JXL_RETURN_IF_ERROR( + shared.image_features.splines.SubtractFrom(opsin, shared.cmap)); + } + + // Find and subtract patches/dots. + if (ApplyOverride(cparams.patches, + cparams.speed_tier <= SpeedTier::kSquirrel)) { + FindBestPatchDictionary(*opsin, enc_state, pool, aux_out); + PatchDictionaryEncoder::SubtractFrom(shared.image_features.patches, opsin); + } + + static const float kAcQuant = 0.79f; + const float quant_dc = InitialQuantDC(cparams.butteraugli_distance); + Quantizer& quantizer = enc_state->shared.quantizer; + // We don't know the quant field yet, but for computing the global scale + // assuming that it will be the same as for Falcon mode is good enough. + quantizer.ComputeGlobalScaleAndQuant( + quant_dc, kAcQuant / cparams.butteraugli_distance, 0); + + // TODO(veluca): we can now run all the code from here to FindBestQuantizer + // (excluded) one rect at a time. Do that. + + // Dependency graph: + // + // input: either XYB or input image + // + // input image -> XYB [optional] + // XYB -> initial quant field + // XYB -> Gaborished XYB + // Gaborished XYB -> CfL1 + // initial quant field, Gaborished XYB, CfL1 -> ACS + // initial quant field, ACS, Gaborished XYB -> EPF control field + // initial quant field -> adjusted initial quant field + // adjusted initial quant field, ACS -> raw quant field + // raw quant field, ACS, Gaborished XYB -> CfL2 + // + // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field. + + ArControlFieldHeuristics ar_heuristics; + AcStrategyHeuristics acs_heuristics; + CfLHeuristics cfl_heuristics; + + if (!opsin->xsize()) { + JXL_ASSERT(HandlesColorConversion(cparams, *original_pixels)); + *opsin = Image3F(RoundUpToBlockDim(original_pixels->xsize()), + RoundUpToBlockDim(original_pixels->ysize())); + opsin->ShrinkTo(original_pixels->xsize(), original_pixels->ysize()); + ToXYB(*original_pixels, pool, opsin, /*linear=*/nullptr); + PadImageToBlockMultipleInPlace(opsin); + } + + // Compute an initial estimate of the quantization field. + // Call InitialQuantField only in Hare mode or slower. Otherwise, rely + // on simple heuristics in FindBestAcStrategy, or set a constant for Falcon + // mode. + if (cparams.speed_tier > SpeedTier::kHare || cparams.uniform_quant > 0) { + enc_state->initial_quant_field = + ImageF(shared.frame_dim.xsize_blocks, shared.frame_dim.ysize_blocks); + if (cparams.speed_tier == SpeedTier::kFalcon || cparams.uniform_quant > 0) { + float q = cparams.uniform_quant > 0 + ? cparams.uniform_quant + : kAcQuant / cparams.butteraugli_distance; + FillImage(q, &enc_state->initial_quant_field); + } + } else { + // Call this here, as it relies on pre-gaborish values. 
+ float butteraugli_distance_for_iqf = cparams.butteraugli_distance; + if (!shared.frame_header.loop_filter.gab) { + butteraugli_distance_for_iqf *= 0.73f; + } + enc_state->initial_quant_field = InitialQuantField( + butteraugli_distance_for_iqf, *opsin, shared.frame_dim, pool, 1.0f, + &enc_state->initial_quant_masking); + } + + // TODO(veluca): do something about animations. + + // Apply inverse-gaborish. + if (shared.frame_header.loop_filter.gab) { + GaborishInverse(opsin, 0.9908511000000001f, pool); + } + + cfl_heuristics.Init(*opsin); + acs_heuristics.Init(*opsin, enc_state); + + auto process_tile = [&](size_t tid, size_t thread) { + size_t n_enc_tiles = + DivCeil(enc_state->shared.frame_dim.xsize_blocks, kEncTileDimInBlocks); + size_t tx = tid % n_enc_tiles; + size_t ty = tid / n_enc_tiles; + size_t by0 = ty * kEncTileDimInBlocks; + size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, + enc_state->shared.frame_dim.ysize_blocks); + size_t bx0 = tx * kEncTileDimInBlocks; + size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, + enc_state->shared.frame_dim.xsize_blocks); + Rect r(bx0, by0, bx1 - bx0, by1 - by0); + + // For speeds up to Wombat, we only compute the color correlation map + // once we know the transform type and the quantization map. + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + cfl_heuristics.ComputeTile(r, *opsin, enc_state->shared.matrices, + /*ac_strategy=*/nullptr, + /*quantizer=*/nullptr, /*fast=*/false, thread, + &enc_state->shared.cmap); + } + + // Choose block sizes. + acs_heuristics.ProcessRect(r); + + // Choose amount of post-processing smoothing. + // TODO(veluca): should this go *after* AdjustQuantField? + ar_heuristics.RunRect(r, *opsin, enc_state, thread); + + // Always set the initial quant field, so we can compute the CfL map with + // more accuracy. The initial quant field might change in slower modes, but + // adjusting the quant field with butteraugli when all the other encoding + // parameters are fixed is likely a more reliable choice anyway. + AdjustQuantField(enc_state->shared.ac_strategy, r, + &enc_state->initial_quant_field); + quantizer.SetQuantFieldRect(enc_state->initial_quant_field, r, + &enc_state->shared.raw_quant_field); + + // Compute a non-default CfL map if we are at Hare speed, or slower. + if (cparams.speed_tier <= SpeedTier::kHare) { + cfl_heuristics.ComputeTile( + r, *opsin, enc_state->shared.matrices, &enc_state->shared.ac_strategy, + &enc_state->shared.quantizer, + /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread, + &enc_state->shared.cmap); + } + }; + RunOnPool( + pool, 0, + DivCeil(enc_state->shared.frame_dim.xsize_blocks, kEncTileDimInBlocks) * + DivCeil(enc_state->shared.frame_dim.ysize_blocks, + kEncTileDimInBlocks), + [&](const size_t num_threads) { + ar_heuristics.PrepareForThreads(num_threads); + cfl_heuristics.PrepareForThreads(num_threads); + return true; + }, + process_tile, "Enc Heuristics"); + + acs_heuristics.Finalize(aux_out); + if (cparams.speed_tier <= SpeedTier::kHare) { + cfl_heuristics.ComputeDC(/*fast=*/cparams.speed_tier >= SpeedTier::kWombat, + &enc_state->shared.cmap); + } + + FindBestDequantMatrices(cparams, *opsin, modular_frame_encoder, + &enc_state->shared.matrices); + + // Refine quantization levels. + FindBestQuantizer(original_pixels, *opsin, enc_state, pool, aux_out); + + // Choose a context model that depends on the amount of quantization for AC. 
+ if (cparams.speed_tier != SpeedTier::kFalcon) { + FindBestBlockEntropyModel(*enc_state); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_heuristics.h b/third_party/jpeg-xl/lib/jxl/enc_heuristics.h new file mode 100644 index 000000000000..d65fb9b5187b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_heuristics.h @@ -0,0 +1,96 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_HEURISTICS_H_ +#define LIB_JXL_ENC_HEURISTICS_H_ + +// Hook for custom encoder heuristics (VarDCT only for now). + +#include <stddef.h> +#include <stdint.h> + +#include <string> + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/enc_ma.h" + +namespace jxl { + +struct PassesEncoderState; +class ImageBundle; +class ModularFrameEncoder; + +class EncoderHeuristics { + public: + virtual ~EncoderHeuristics() = default; + // Initializes encoder structures in `enc_state` using the original image data + // in `original_pixels` and the XYB image data in `opsin`. Also modifies the + // `opsin` image, applying Gaborish and other transformations as + // necessary. `pool` is used for running the computations on multiple threads. + // `aux_out` collects statistics and can be used to print debug images. + virtual Status LossyFrameHeuristics( + PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder, + const ImageBundle* original_pixels, Image3F* opsin, ThreadPool* pool, + AuxOut* aux_out) = 0; + + // Custom fixed tree for lossless mode. Must set `tree` to a valid tree if + // the function returns true. + virtual bool CustomFixedTreeLossless(const FrameDimensions& frame_dim, + Tree* tree) { + return false; + } + + // If this method returns `true`, the `opsin` parameter to + // LossyFrameHeuristics will not be initialized, and should be initialized + // during the call. Moreover, `original_pixels` may not be in a linear + // colorspace (but will be the same as the `ib` value passed to this + // function). + virtual bool HandlesColorConversion(const CompressParams& cparams, + const ImageBundle& ib) { + return false; + } +}; + +class DefaultEncoderHeuristics : public EncoderHeuristics { + public: + Status LossyFrameHeuristics(PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + const ImageBundle* original_pixels, + Image3F* opsin, ThreadPool* pool, + AuxOut* aux_out) override; + bool HandlesColorConversion(const CompressParams& cparams, + const ImageBundle& ib) override; +}; + +class FastEncoderHeuristics : public EncoderHeuristics { + public: + Status LossyFrameHeuristics(PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + const ImageBundle* linear, Image3F* opsin, + ThreadPool* pool, AuxOut* aux_out) override; +}; + +// Exposed here since it may be used by other EncoderHeuristics implementations +// outside this project.
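For instance, an out-of-tree implementation might look like the following minimal sketch (editorial, not part of this patch; `NoopHeuristics` is hypothetical, and a real implementation would also have to fill in the quantizer, AC strategy, and the other state LossyFrameHeuristics is responsible for):

class NoopHeuristics : public jxl::EncoderHeuristics {
 public:
  jxl::Status LossyFrameHeuristics(jxl::PassesEncoderState* enc_state,
                                   jxl::ModularFrameEncoder* modular_frame_encoder,
                                   const jxl::ImageBundle* original_pixels,
                                   jxl::Image3F* opsin, jxl::ThreadPool* pool,
                                   jxl::AuxOut* aux_out) override {
    // Reuse the exported helper declared below instead of a custom
    // quant-matrix selection; everything else is left at its defaults.
    jxl::FindBestDequantMatrices(enc_state->cparams, *opsin,
                                 modular_frame_encoder,
                                 &enc_state->shared.matrices);
    return true;
  }
};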
+void FindBestDequantMatrices(const CompressParams& cparams, + const Image3F& opsin, + ModularFrameEncoder* modular_frame_encoder, + DequantMatrices* dequant_matrices); + +} // namespace jxl + +#endif // LIB_JXL_ENC_HEURISTICS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_huffman.cc b/third_party/jpeg-xl/lib/jxl/enc_huffman.cc new file mode 100644 index 000000000000..2910445ee50b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_huffman.cc @@ -0,0 +1,223 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_huffman.h" + +#include <memory> +#include <utility> + +#include "lib/jxl/huffman_tree.h" + +namespace jxl { + +namespace { + +constexpr int kCodeLengthCodes = 18; + +void StoreHuffmanTreeOfHuffmanTreeToBitMask(const int num_codes, + const uint8_t* code_length_bitdepth, + BitWriter* writer) { + static const uint8_t kStorageOrder[kCodeLengthCodes] = { + 1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + // The bit lengths of the Huffman code over the code length alphabet + // are compressed with the following static Huffman code: + // Symbol Code + // ------ ---- + // 0 00 + // 1 1110 + // 2 110 + // 3 01 + // 4 10 + // 5 1111 + static const uint8_t kHuffmanBitLengthHuffmanCodeSymbols[6] = {0, 7, 3, + 2, 1, 15}; + static const uint8_t kHuffmanBitLengthHuffmanCodeBitLengths[6] = {2, 4, 3, + 2, 2, 4}; + + // Throw away trailing zeros: + size_t codes_to_store = kCodeLengthCodes; + if (num_codes > 1) { + for (; codes_to_store > 0; --codes_to_store) { + if (code_length_bitdepth[kStorageOrder[codes_to_store - 1]] != 0) { + break; + } + } + } + size_t skip_some = 0; // skips none. + if (code_length_bitdepth[kStorageOrder[0]] == 0 && + code_length_bitdepth[kStorageOrder[1]] == 0) { + skip_some = 2; // skips two. + if (code_length_bitdepth[kStorageOrder[2]] == 0) { + skip_some = 3; // skips three.
+ } + } + writer->Write(2, skip_some); + for (size_t i = skip_some; i < codes_to_store; ++i) { + size_t l = code_length_bitdepth[kStorageOrder[i]]; + writer->Write(kHuffmanBitLengthHuffmanCodeBitLengths[l], + kHuffmanBitLengthHuffmanCodeSymbols[l]); + } +} + +void StoreHuffmanTreeToBitMask(const size_t huffman_tree_size, + const uint8_t* huffman_tree, + const uint8_t* huffman_tree_extra_bits, + const uint8_t* code_length_bitdepth, + const uint16_t* code_length_bitdepth_symbols, + BitWriter* writer) { + for (size_t i = 0; i < huffman_tree_size; ++i) { + size_t ix = huffman_tree[i]; + writer->Write(code_length_bitdepth[ix], code_length_bitdepth_symbols[ix]); + // Extra bits + switch (ix) { + case 16: + writer->Write(2, huffman_tree_extra_bits[i]); + break; + case 17: + writer->Write(3, huffman_tree_extra_bits[i]); + break; + } + } +} + +void StoreSimpleHuffmanTree(const uint8_t* depths, size_t symbols[4], + size_t num_symbols, size_t max_bits, + BitWriter* writer) { + // value of 1 indicates a simple Huffman code + writer->Write(2, 1); + writer->Write(2, num_symbols - 1); // NSYM - 1 + + // Sort + for (size_t i = 0; i < num_symbols; i++) { + for (size_t j = i + 1; j < num_symbols; j++) { + if (depths[symbols[j]] < depths[symbols[i]]) { + std::swap(symbols[j], symbols[i]); + } + } + } + + if (num_symbols == 2) { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + } else if (num_symbols == 3) { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + writer->Write(max_bits, symbols[2]); + } else { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + writer->Write(max_bits, symbols[2]); + writer->Write(max_bits, symbols[3]); + // tree-select + writer->Write(1, depths[symbols[0]] == 1 ? 1 : 0); + } +} + +// num = alphabet size +// depths = symbol depths +void StoreHuffmanTree(const uint8_t* depths, size_t num, BitWriter* writer) { + // Write the Huffman tree into the compact representation. + std::unique_ptr<uint8_t[]> arena(new uint8_t[2 * num]); + uint8_t* huffman_tree = arena.get(); + uint8_t* huffman_tree_extra_bits = arena.get() + num; + size_t huffman_tree_size = 0; + WriteHuffmanTree(depths, num, &huffman_tree_size, huffman_tree, + huffman_tree_extra_bits); + + // Calculate the statistics of the Huffman tree in the compact representation. + uint32_t huffman_tree_histogram[kCodeLengthCodes] = {0}; + for (size_t i = 0; i < huffman_tree_size; ++i) { + ++huffman_tree_histogram[huffman_tree[i]]; + } + + int num_codes = 0; + int code = 0; + for (int i = 0; i < kCodeLengthCodes; ++i) { + if (huffman_tree_histogram[i]) { + if (num_codes == 0) { + code = i; + num_codes = 1; + } else if (num_codes == 1) { + num_codes = 2; + break; + } + } + } + + // Calculate another Huffman tree to use for compressing the compact + // representation of the earlier Huffman tree. + uint8_t code_length_bitdepth[kCodeLengthCodes] = {0}; + uint16_t code_length_bitdepth_symbols[kCodeLengthCodes] = {0}; + CreateHuffmanTree(&huffman_tree_histogram[0], kCodeLengthCodes, 5, + &code_length_bitdepth[0]); + ConvertBitDepthsToSymbols(code_length_bitdepth, kCodeLengthCodes, + &code_length_bitdepth_symbols[0]); + + // Now that we have all the data, let's start storing it. + StoreHuffmanTreeOfHuffmanTreeToBitMask(num_codes, code_length_bitdepth, + writer); + + if (num_codes == 1) { + code_length_bitdepth[code] = 0; + } + + // Store the real Huffman tree now.
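A brief editorial note on the compact representation being written here (the repeat-code semantics follow Brotli-style code-length coding, from which huffman_tree.h derives): symbols 0-15 are literal code lengths, symbol 16 repeats the previous length and carries 2 extra bits, and symbol 17 repeats a run of zeros and carries 3 extra bits, which is why StoreHuffmanTreeToBitMask above writes 2 or 3 extra bits for exactly those two symbols.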
+ StoreHuffmanTreeToBitMask(huffman_tree_size, huffman_tree, + huffman_tree_extra_bits, &code_length_bitdepth[0], + code_length_bitdepth_symbols, writer); +} + +} // namespace + +void BuildAndStoreHuffmanTree(const uint32_t* histogram, const size_t length, + uint8_t* depth, uint16_t* bits, + BitWriter* writer) { + size_t count = 0; + size_t s4[4] = {0}; + for (size_t i = 0; i < length; i++) { + if (histogram[i]) { + if (count < 4) { + s4[count] = i; + } else if (count > 4) { + break; + } + count++; + } + } + + size_t max_bits_counter = length - 1; + size_t max_bits = 0; + while (max_bits_counter) { + max_bits_counter >>= 1; + ++max_bits; + } + + if (count <= 1) { + // Output symbol bits and depths are initialized with 0, nothing to do. + writer->Write(4, 1); + writer->Write(max_bits, s4[0]); + return; + } + + CreateHuffmanTree(histogram, length, 15, depth); + ConvertBitDepthsToSymbols(depth, length, bits); + + if (count <= 4) { + StoreSimpleHuffmanTree(depth, s4, count, max_bits, writer); + } else { + StoreHuffmanTree(depth, length, writer); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_huffman.h b/third_party/jpeg-xl/lib/jxl/enc_huffman.h new file mode 100644 index 000000000000..6ffda5a89687 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_huffman.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_HUFFMAN_H_ +#define LIB_JXL_ENC_HUFFMAN_H_ + +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Builds a Huffman tree for the given histogram, and encodes it into writer +// in a format that can be read by HuffmanDecodingData::ReadFromBitstream. +// An allotment for `writer` must already have been created by the caller. +void BuildAndStoreHuffmanTree(const uint32_t* histogram, size_t length, + uint8_t* depth, uint16_t* bits, + BitWriter* writer); + +} // namespace jxl + +#endif // LIB_JXL_ENC_HUFFMAN_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_icc_codec.cc b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.cc new file mode 100644 index 000000000000..80b5241594ae --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.cc @@ -0,0 +1,438 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
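Before the ICC codec implementation, a minimal usage sketch for the enc_huffman API above (editorial, not part of the patch; `EmitFourSymbols` is hypothetical, and per enc_huffman.h an allotment must be live on the writer):

#include "lib/jxl/enc_huffman.h"

void EmitFourSymbols(jxl::BitWriter* writer) {
  const uint32_t histogram[4] = {10, 1, 1, 4};
  uint8_t depth[4] = {};
  uint16_t bits[4] = {};
  jxl::BitWriter::Allotment allotment(writer, 1024);
  jxl::BuildAndStoreHuffmanTree(histogram, 4, depth, bits, writer);
  writer->Write(depth[0], bits[0]);  // Emit one occurrence of symbol 0.
  // Layer 0 and a null AuxOut keep the sketch self-contained.
  jxl::ReclaimAndCharge(writer, &allotment, 0, nullptr);
}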
+ +#include "lib/jxl/enc_icc_codec.h" + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/icc_codec_common.h" + +namespace jxl { +namespace { + +bool EncodeVarInt(uint64_t value, size_t output_size, size_t* output_pos, + uint8_t* output) { + // While more than 7 bits of data are left, + // store 7 bits and set the next byte flag + while (value > 127) { + if (*output_pos > output_size) return false; + // |128: Set the next byte flag + output[(*output_pos)++] = ((uint8_t)(value & 127)) | 128; + // Remove the seven bits we just wrote + value >>= 7; + } + if (*output_pos > output_size) return false; + output[(*output_pos)++] = ((uint8_t)value) & 127; + return true; +} + +void EncodeVarInt(uint64_t value, PaddedBytes* data) { + size_t pos = data->size(); + data->resize(data->size() + 9); + JXL_CHECK(EncodeVarInt(value, data->size(), &pos, data->data())); + data->resize(pos); +} + +// Unshuffles or de-interleaves bytes, for example with width 2, turns +// "AaBbCcDc" into "ABCDabcd", this for example de-interleaves UTF-16 bytes into +// first all the high order bytes, then all the low order bytes. +// Transposes a matrix of width columns and ceil(size / width) rows. There are +// size elements, size may be < width * height, if so the +// last elements of the bottom row are missing, the missing spots are +// transposed along with the filled spots, and the result has the missing +// elements at the bottom of the rightmost column. The input is the input matrix +// in scanline order, the output is the result matrix in scanline order, with +// missing elements skipped over (this may occur at multiple positions). +void Unshuffle(uint8_t* data, size_t size, size_t width) { + size_t height = (size + width - 1) / width; // amount of rows of input + PaddedBytes result(size); + // i = input index, j output index + size_t s = 0, j = 0; + for (size_t i = 0; i < size; i++) { + result[j] = data[i]; + j += height; + if (j >= size) j = ++s; + } + + for (size_t i = 0; i < size; i++) { + data[i] = result[i]; + } +} + +// This is performed by the encoder, the encoder must be able to encode any +// random byte stream (not just byte streams that are a valid ICC profile), so +// an error returned by this function is an implementation error. +Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num, + const uint8_t* data, size_t size, size_t* pos, + PaddedBytes* result) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size)); + // Required by the specification, see decoder. stride * 4 must be < *pos. + if (!*pos || ((*pos - 1u) >> 2u) < stride) { + return JXL_FAILURE("Invalid stride"); + } + if (*pos < stride * 4) return JXL_FAILURE("Too large stride"); + size_t start = result->size(); + for (size_t i = 0; i < num; i++) { + uint8_t predicted = + LinearPredictICCValue(data, *pos, i, stride, width, order); + result->push_back(data[*pos + i] - predicted); + } + *pos += num; + if (width > 1) Unshuffle(result->data() + start, num, width); + return true; +} +} // namespace + +// Outputs a transformed form of the given icc profile. The result itself is +// not particularly smaller than the input data in bytes, but it will be in a +// form that is easier to compress (more zeroes, ...) and will compress better +// with brotli. 
+Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { + PaddedBytes commands; + PaddedBytes data; + + EncodeVarInt(size, result); + + // Header + PaddedBytes header = ICCInitialHeaderPrediction(); + EncodeUint32(0, size, &header); + for (size_t i = 0; i < kICCHeaderSize && i < size; i++) { + ICCPredictHeader(icc, size, header.data(), i); + data.push_back(icc[i] - header[i]); + } + if (size <= kICCHeaderSize) { + EncodeVarInt(0, result); // 0 commands + for (size_t i = 0; i < data.size(); i++) { + result->push_back(data[i]); + } + return true; + } + + std::vector<Tag> tags; + std::vector<size_t> tagstarts; + std::vector<size_t> tagsizes; + std::map<size_t, size_t> tagmap; + + // Tag list + size_t pos = kICCHeaderSize; + if (pos + 4 <= size) { + uint64_t numtags = DecodeUint32(icc, size, pos); + pos += 4; + EncodeVarInt(numtags + 1, &commands); + uint64_t prevtagstart = kICCHeaderSize + numtags * 12; + uint32_t prevtagsize = 0; + for (size_t i = 0; i < numtags; i++) { + if (pos + 12 > size) break; + + Tag tag = DecodeKeyword(icc, size, pos + 0); + uint32_t tagstart = DecodeUint32(icc, size, pos + 4); + uint32_t tagsize = DecodeUint32(icc, size, pos + 8); + pos += 12; + + tags.push_back(tag); + tagstarts.push_back(tagstart); + tagsizes.push_back(tagsize); + tagmap[tagstart] = tags.size() - 1; + + uint8_t tagcode = kCommandTagUnknown; + for (size_t j = 0; j < kNumTagStrings; j++) { + if (tag == *kTagStrings[j]) { + tagcode = j + kCommandTagStringFirst; + break; + } + } + + if (tag == kRtrcTag && pos + 24 < size) { + bool ok = true; + ok &= DecodeKeyword(icc, size, pos + 0) == kGtrcTag; + ok &= DecodeKeyword(icc, size, pos + 12) == kBtrcTag; + if (ok) { + for (size_t i = 0; i < 8; i++) { + if (icc[pos - 8 + i] != icc[pos + 4 + i]) ok = false; + if (icc[pos - 8 + i] != icc[pos + 16 + i]) ok = false; + } + } + if (ok) { + tagcode = kCommandTagTRC; + pos += 24; + i += 2; + } + } + + if (tag == kRxyzTag && pos + 24 < size) { + bool ok = true; + ok &= DecodeKeyword(icc, size, pos + 0) == kGxyzTag; + ok &= DecodeKeyword(icc, size, pos + 12) == kBxyzTag; + uint32_t offsetr = tagstart; + uint32_t offsetg = DecodeUint32(icc, size, pos + 4); + uint32_t offsetb = DecodeUint32(icc, size, pos + 16); + uint32_t sizer = tagsize; + uint32_t sizeg = DecodeUint32(icc, size, pos + 8); + uint32_t sizeb = DecodeUint32(icc, size, pos + 20); + ok &= sizer == 20; + ok &= sizeg == 20; + ok &= sizeb == 20; + ok &= (offsetg == offsetr + 20); + ok &= (offsetb == offsetr + 40); + if (ok) { + tagcode = kCommandTagXYZ; + pos += 24; + i += 2; + } + } + + uint8_t command = tagcode; + uint64_t predicted_tagstart = prevtagstart + prevtagsize; + if (predicted_tagstart != tagstart) command |= kFlagBitOffset; + size_t predicted_tagsize = prevtagsize; + if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag || + tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag || + tag == kLumiTag) { + predicted_tagsize = 20; + } + if (predicted_tagsize != tagsize) command |= kFlagBitSize; + commands.push_back(command); + if (tagcode == 1) { + AppendKeyword(tag, &data); + } + if (command & kFlagBitOffset) EncodeVarInt(tagstart, &commands); + if (command & kFlagBitSize) EncodeVarInt(tagsize, &commands); + + prevtagstart = tagstart; + prevtagsize = tagsize; + } + } + // Indicate end of tag list or varint indicating there's none + commands.push_back(0); + + // Main content + // The main content in a valid ICC profile contains tagged elements, with the + // tag types (4 letter names) given by the tag list above, and the tag list + // pointing to the start and
indicating the size of each tagged element. It is + // allowed for tagged elements to overlap, e.g. the curve for R, G and B could + // all point to the same one. + Tag tag; + size_t tagstart = 0, tagsize = 0, clutstart = 0; + + size_t last0 = pos; + // This loop appends commands to the output, processing some sub-section of a + // current tagged element each time. We need to keep track of the tagtype of + // the current element, and update it when we encounter the boundary of a + // next one. + // It is not required that the input data is a valid ICC profile, if the + // encoder does not recognize the data it will still be able to output bytes + // but will not predict as well. + while (pos <= size) { + size_t last1 = pos; + PaddedBytes commands_add; + PaddedBytes data_add; + + // This means the loop brought the position beyond the tag end. + if (pos > tagstart + tagsize) { + tag = {0, 0, 0, 0}; // nonsensical value + } + + if (commands_add.empty() && data_add.empty() && tagmap.count(pos) && + pos + 4 <= size) { + size_t index = tagmap[pos]; + tag = DecodeKeyword(icc, size, pos); + tagstart = tagstarts[index]; + tagsize = tagsizes[index]; + + if (tag == kMlucTag && pos + tagsize <= size && tagsize > 8 && + icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 && + icc[pos + 7] == 0) { + size_t num = tagsize - 8; + commands_add.push_back(kCommandTypeStartFirst + 3); + pos += 8; + commands_add.push_back(kCommandShuffle2); + EncodeVarInt(num, &commands_add); + size_t start = data_add.size(); + for (size_t i = 0; i < num; i++) { + data_add.push_back(icc[pos]); + pos++; + } + Unshuffle(data_add.data() + start, num, 2); + } + + if (tag == kCurvTag && pos + tagsize <= size && tagsize > 8 && + icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 && + icc[pos + 7] == 0) { + size_t num = tagsize - 8; + if (num > 16 && num < (1 << 28) && pos + num <= size && pos > 0) { + commands_add.push_back(kCommandTypeStartFirst + 5); + pos += 8; + commands_add.push_back(kCommandPredict); + int order = 1, width = 2, stride = width; + commands_add.push_back((order << 2) | (width - 1)); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + } + + if (tag == kMab_Tag || tag == kMba_Tag) { + Tag subTag = DecodeKeyword(icc, size, pos); + if (pos + 12 < size && (subTag == kCurvTag || subTag == kVcgtTag) && + DecodeUint32(icc, size, pos + 4) == 0) { + uint32_t num = DecodeUint32(icc, size, pos + 8) * 2; + if (num > 16 && num < (1 << 28) && pos + 12 + num <= size) { + pos += 12; + last1 = pos; + commands_add.push_back(kCommandPredict); + int order = 1, width = 2, stride = width; + commands_add.push_back((order << 2) | (width - 1)); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + + if (pos == tagstart + 24 && pos + 4 < size) { + // Note that this value can be remembered for next iterations of the + // loop, so the "pos == clutstart" if below can trigger during a later + // iteration. 
+ clutstart = tagstart + DecodeUint32(icc, size, pos); + } + + if (pos == clutstart && clutstart + 16 < size) { + size_t numi = icc[tagstart + 8]; + size_t numo = icc[tagstart + 9]; + size_t width = icc[clutstart + 16]; + size_t stride = width * numo; + size_t num = width * numo; + for (size_t i = 0; i < numi && clutstart + i < size; i++) { + num *= icc[clutstart + i]; + } + if ((width == 1 || width == 2) && num > 64 && num < (1 << 28) && + pos + num <= size && pos > stride * 4) { + commands_add.push_back(kCommandPredict); + int order = 1; + uint8_t flags = + (order << 2) | (width - 1) | (stride == width ? 0 : 16); + commands_add.push_back(flags); + if (flags & 16) EncodeVarInt(stride, &commands_add); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + } + + if (commands_add.empty() && data_add.empty() && tag == kGbd_Tag && + pos == tagstart + 8 && pos + tagsize - 8 <= size && pos > 16 && + tagsize > 8) { + size_t width = 4, order = 0, stride = width; + size_t num = tagsize - 8; + uint8_t flags = (order << 2) | (width - 1) | (stride == width ? 0 : 16); + commands_add.push_back(kCommandPredict); + commands_add.push_back(flags); + if (flags & 16) EncodeVarInt(stride, &commands_add); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + + if (commands_add.empty() && data_add.empty() && pos + 20 <= size) { + Tag subTag = DecodeKeyword(icc, size, pos); + if (subTag == kXyz_Tag && DecodeUint32(icc, size, pos + 4) == 0) { + commands_add.push_back(kCommandXYZ); + pos += 8; + for (size_t j = 0; j < 12; j++) data_add.push_back(icc[pos++]); + } + } + + if (commands_add.empty() && data_add.empty() && pos + 8 <= size) { + if (DecodeUint32(icc, size, pos + 4) == 0) { + Tag subTag = DecodeKeyword(icc, size, pos); + for (size_t i = 0; i < kNumTypeStrings; i++) { + if (subTag == *kTypeStrings[i]) { + commands_add.push_back(kCommandTypeStartFirst + i); + pos += 8; + break; + } + } + } + } + + if (!(commands_add.empty() && data_add.empty()) || pos == size) { + if (last0 < last1) { + commands.push_back(kCommandInsert); + EncodeVarInt(last1 - last0, &commands); + while (last0 < last1) { + data.push_back(icc[last0++]); + } + } + for (size_t i = 0; i < commands_add.size(); i++) { + commands.push_back(commands_add[i]); + } + for (size_t i = 0; i < data_add.size(); i++) { + data.push_back(data_add[i]); + } + last0 = pos; + } + if (commands_add.empty() && data_add.empty()) { + pos++; + } + } + + EncodeVarInt(commands.size(), result); + for (size_t i = 0; i < commands.size(); i++) { + result->push_back(commands[i]); + } + for (size_t i = 0; i < data.size(); i++) { + result->push_back(data[i]); + } + + return true; +} + +Status WriteICC(const PaddedBytes& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out) { + if (icc.empty()) return JXL_FAILURE("ICC must be non-empty"); + PaddedBytes enc; + JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc)); + std::vector<std::vector<Token>> tokens(1); + BitWriter::Allotment allotment(writer, 128); + JXL_RETURN_IF_ERROR(U64Coder::Write(enc.size(), writer)); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + + for (size_t i = 0; i < enc.size(); i++) { + tokens[0].emplace_back( + ICCANSContext(i, i > 0 ? enc[i - 1] : 0, i > 1 ?
enc[i - 2] : 0), + enc[i]); + } + HistogramParams params; + params.lz77_method = HistogramParams::LZ77Method::kOptimal; + EntropyEncodingData code; + std::vector<uint8_t> context_map; + params.force_huffman = true; + BuildAndEncodeHistograms(params, kNumICCContexts, tokens, &code, &context_map, + writer, layer, aux_out); + WriteTokens(tokens[0], code, context_map, writer, layer, aux_out); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_icc_codec.h b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.h new file mode 100644 index 000000000000..f449953697dd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.h @@ -0,0 +1,42 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_ICC_CODEC_H_ +#define LIB_JXL_ENC_ICC_CODEC_H_ + +// Compressed representation of ICC profiles. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Should still be called if `icc.empty()` - if so, writes only 1 bit. +Status WriteICC(const PaddedBytes& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out); + +// Exposed only for testing +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result); + +} // namespace jxl + +#endif // LIB_JXL_ENC_ICC_CODEC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc new file mode 100644 index 000000000000..5455edb902c0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc @@ -0,0 +1,179 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_image_bundle.h" + +#include <limits> +#include <utility> + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_color_management.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/luminance.h" + +namespace jxl { + +namespace { + +// Copies ib:rect, converts, and copies into out.
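An editorial note on the integer conversion path in the function below: samples are scaled by `max` (e.g. 255 for uint8_t), clamped to [0, max], and then rounded via the +/-0.5 trick in `cvt`; for example an input of 0.4 maps to 0.4 * 255 = 102.0 and is stored as 102.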
+template <typename T> +Status CopyToT(const ImageMetadata* metadata, const ImageBundle* ib, + const Rect& rect, const ColorEncoding& c_desired, + ThreadPool* pool, Image3<T>* out) { + PROFILER_FUNC; + static_assert( + std::is_same<T, float>::value || std::numeric_limits<T>::min() == 0, + "CopyToT implemented only for float and unsigned types"); + ColorSpaceTransform c_transform; + // Changing IsGray is probably a bug. + JXL_CHECK(ib->IsGray() == c_desired.IsGray()); +#if JPEGXL_ENABLE_SKCMS + bool is_gray = false; +#else + bool is_gray = ib->IsGray(); +#endif + if (out->xsize() < rect.xsize() || out->ysize() < rect.ysize()) { + *out = Image3<T>(rect.xsize(), rect.ysize()); + } else { + out->ShrinkTo(rect.xsize(), rect.ysize()); + } + RunOnPool( + pool, 0, rect.ysize(), + [&](size_t num_threads) { + return c_transform.Init(ib->c_current(), c_desired, + metadata->IntensityTarget(), rect.xsize(), + num_threads); + }, + [&](const int y, const int thread) { + float* mutable_src_buf = c_transform.BufSrc(thread); + const float* src_buf = mutable_src_buf; + // Interleave input. + if (is_gray) { + src_buf = rect.ConstPlaneRow(ib->color(), 0, y); + } else { + const float* JXL_RESTRICT row_in0 = + rect.ConstPlaneRow(ib->color(), 0, y); + const float* JXL_RESTRICT row_in1 = + rect.ConstPlaneRow(ib->color(), 1, y); + const float* JXL_RESTRICT row_in2 = + rect.ConstPlaneRow(ib->color(), 2, y); + for (size_t x = 0; x < rect.xsize(); x++) { + mutable_src_buf[3 * x + 0] = row_in0[x]; + mutable_src_buf[3 * x + 1] = row_in1[x]; + mutable_src_buf[3 * x + 2] = row_in2[x]; + } + } + float* JXL_RESTRICT dst_buf = c_transform.BufDst(thread); + DoColorSpaceTransform(&c_transform, thread, src_buf, dst_buf); + T* JXL_RESTRICT row_out0 = out->PlaneRow(0, y); + T* JXL_RESTRICT row_out1 = out->PlaneRow(1, y); + T* JXL_RESTRICT row_out2 = out->PlaneRow(2, y); + // De-interleave output and convert type. + if (std::is_same<T, float>::value) { // deinterleave to float. + if (is_gray) { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = dst_buf[x]; + row_out1[x] = dst_buf[x]; + row_out2[x] = dst_buf[x]; + } + } else { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = dst_buf[3 * x + 0]; + row_out1[x] = dst_buf[3 * x + 1]; + row_out2[x] = dst_buf[3 * x + 2]; + } + } + } else { + // Convert to T, doing clamping. + float max = std::numeric_limits<T>::max(); + auto cvt = [max](float in) { + float v = std::max(0.0f, std::min(max, in * max)); + return static_cast<T>(v < 0 ?
v - 0.5f : v + 0.5f); + }; + if (is_gray) { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = cvt(dst_buf[x]); + row_out1[x] = cvt(dst_buf[x]); + row_out2[x] = cvt(dst_buf[x]); + } + } else { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = cvt(dst_buf[3 * x + 0]); + row_out1[x] = cvt(dst_buf[3 * x + 1]); + row_out2[x] = cvt(dst_buf[3 * x + 2]); + } + } + } + }, + "Colorspace transform"); + return true; +} + +} // namespace + +Status ImageBundle::TransformTo(const ColorEncoding& c_desired, + ThreadPool* pool) { + PROFILER_FUNC; + JXL_RETURN_IF_ERROR(CopyTo(Rect(color_), c_desired, &color_, pool)); + c_current_ = c_desired; + return true; +} + +Status ImageBundle::CopyTo(const Rect& rect, const ColorEncoding& c_desired, + Image3B* out, ThreadPool* pool) const { + return CopyToT(metadata_, this, rect, c_desired, pool, out); +} +Status ImageBundle::CopyTo(const Rect& rect, const ColorEncoding& c_desired, + Image3F* out, ThreadPool* pool) const { + return CopyToT(metadata_, this, rect, c_desired, pool, out); +} + +Status ImageBundle::CopyToSRGB(const Rect& rect, Image3B* out, + ThreadPool* pool) const { + return CopyTo(rect, ColorEncoding::SRGB(IsGray()), out, pool); +} + +Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired, + ThreadPool* pool, ImageBundle* store, + const ImageBundle** out) { + if (in.c_current().SameColorEncoding(c_desired)) { + *out = &in; + return true; + } + // TODO(janwas): avoid copying via createExternal+copyBackToIO + // instead of copy+createExternal+copyBackToIO + store->SetFromImage(CopyImage(in.color()), in.c_current()); + + // Must at least copy the alpha channel for use by external_image. + if (in.HasExtraChannels()) { + std::vector<ImageF> extra_channels; + for (const ImageF& extra_channel : in.extra_channels()) { + extra_channels.emplace_back(CopyImage(extra_channel)); + } + store->SetExtraChannels(std::move(extra_channels)); + } + + if (!store->TransformTo(c_desired, pool)) { + return false; + } + *out = store; + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_image_bundle.h b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.h new file mode 100644 index 000000000000..5cad34443f5e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_IMAGE_BUNDLE_H_ +#define LIB_JXL_ENC_IMAGE_BUNDLE_H_ + +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Does color transformation from in.c_current() to c_desired if the color +// encodings are different, or nothing if they are already the same. +// If a color transformation is done, stores the transformed values into `store` +// and sets the `out` pointer to `store`; otherwise it leaves `store` untouched +// and sets the `out` pointer to `&in`. +// Returns false if the color transform fails.
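A minimal usage sketch (editorial, not part of the patch; `ExampleToSRGB` is hypothetical):

jxl::Status ExampleToSRGB(const jxl::ImageBundle& in, jxl::ThreadPool* pool,
                          jxl::ImageBundle* store,
                          const jxl::ImageBundle** view) {
  const jxl::ColorEncoding srgb = jxl::ColorEncoding::SRGB(in.IsGray());
  JXL_RETURN_IF_ERROR(jxl::TransformIfNeeded(in, srgb, pool, store, view));
  // *view now aliases either `in` (already sRGB) or `store` (converted),
  // so `store` must outlive the view.
  return true;
}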
+Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired, + ThreadPool* pool, ImageBundle* store, + const ImageBundle** out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_IMAGE_BUNDLE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_modular.cc b/third_party/jpeg-xl/lib/jxl/enc_modular.cc new file mode 100644 index 000000000000..c029d97f7b6d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_modular.cc @@ -0,0 +1,1631 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/enc_modular.h" + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <limits> +#include <queue> +#include <utility> +#include <vector> + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cluster.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/gaborish.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_encoding.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// Squeeze default quantization factors +// these quantization factors are for -Q 50 (other qualities simply scale the +// factors; things are rounded down and obviously cannot get below 1) +static const float squeeze_quality_factor = + 0.35; // for easy tweaking of the quality range (decrease this number for + // higher quality) +static const float squeeze_luma_factor = + 1.1; // for easy tweaking of the balance between luma (or anything + // non-chroma) and chroma (decrease this number for higher quality + // luma) +static const float squeeze_quality_factor_xyb = 2.4f; +static const float squeeze_xyb_qtable[3][16] = { + {163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56, 1.28, 0.64, 0.32, 0.16, + 0.08, 0.04, 0.02, 0.01, 0.005}, // Y + {1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, + 0.5}, // X + {2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, + 0.5}, // B-Y +}; + +static const float squeeze_luma_qtable[16] = { + 163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56, 1.28, + 0.64, 0.32, 0.16, 0.08, 0.04, 0.02, 0.01, 0.005}; +// for 8-bit input, the range of YCoCg chroma is -255..255 so basically this +// does 4:2:0 subsampling (two most fine grained layers get quantized away) +static const float squeeze_chroma_qtable[16] = { + 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, 0.5}; + +// `cutoffs` must be sorted.
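An editorial illustration of the helper that follows (property index and pixel count hypothetical): MakeFixedTree(15, {-1, 0, 1}, Predictor::Weighted, 1 << 20) builds a balanced binary search over the cutoffs: the root splits at property > 0, its children split at > 1 and > -1, and all four leaves use the Weighted predictor. For small images, `min_gap` prunes these splits and the tree collapses toward a single leaf.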
+Tree MakeFixedTree(int property, const std::vector<int32_t>& cutoffs, + Predictor pred, size_t num_pixels) { + size_t log_px = CeilLog2Nonzero(num_pixels); + size_t min_gap = 0; + // Reduce fixed tree height when encoding small images. + if (log_px < 14) { + min_gap = 8 * (14 - log_px); + } + Tree tree; + struct NodeInfo { + size_t begin, end, pos; + }; + std::queue<NodeInfo> q; + // Leaf IDs will be set by roundtrip decoding the tree. + tree.push_back(PropertyDecisionNode::Leaf(pred)); + q.push(NodeInfo{0, cutoffs.size(), 0}); + while (!q.empty()) { + NodeInfo info = q.front(); + q.pop(); + if (info.begin + min_gap >= info.end) continue; + uint32_t split = (info.begin + info.end) / 2; + tree[info.pos] = + PropertyDecisionNode::Split(property, cutoffs[split], tree.size()); + q.push(NodeInfo{split + 1, info.end, tree.size()}); + tree.push_back(PropertyDecisionNode::Leaf(pred)); + q.push(NodeInfo{info.begin, split, tree.size()}); + tree.push_back(PropertyDecisionNode::Leaf(pred)); + } + return tree; +} + +Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels) { + if (tree_kind == ModularOptions::TreeKind::kJpegTranscodeACMeta) { + // All the data is 0, so no need for a fancy tree. + return {PropertyDecisionNode::Leaf(Predictor::Zero)}; + } + if (tree_kind == ModularOptions::TreeKind::kFalconACMeta) { + // All the data is 0 except the quant field. TODO(veluca): make that 0 too. + return {PropertyDecisionNode::Leaf(Predictor::Left)}; + } + if (tree_kind == ModularOptions::TreeKind::kACMeta) { + // Small image. + if (total_pixels < 1024) { + return {PropertyDecisionNode::Leaf(Predictor::Left)}; + } + Tree tree; + // 0: c > 1 + tree.push_back(PropertyDecisionNode::Split(0, 1, 1)); + // 1: c > 2 + tree.push_back(PropertyDecisionNode::Split(0, 2, 3)); + // 2: c > 0 + tree.push_back(PropertyDecisionNode::Split(0, 0, 5)); + // 3: EPF control field (all 0 or 4), top > 0 + tree.push_back(PropertyDecisionNode::Split(6, 0, 21)); + // 4: ACS+QF, y > 0 + tree.push_back(PropertyDecisionNode::Split(2, 0, 7)); + // 5: CfL x + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); + // 6: CfL b + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); + // 7: QF: split according to the left quant value. + tree.push_back(PropertyDecisionNode::Split(7, 5, 9)); + // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large + // rectangular 6-11, 8x8 12+), according to previous ACS value.
+ tree.push_back(PropertyDecisionNode::Split(7, 5, 15)); + // QF + tree.push_back(PropertyDecisionNode::Split(7, 11, 11)); + tree.push_back(PropertyDecisionNode::Split(7, 3, 13)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + // ACS + tree.push_back(PropertyDecisionNode::Split(7, 11, 17)); + tree.push_back(PropertyDecisionNode::Split(7, 3, 19)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + // EPF, left > 0 + tree.push_back(PropertyDecisionNode::Split(7, 0, 23)); + tree.push_back(PropertyDecisionNode::Split(7, 0, 25)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + return tree; + } + if (tree_kind == ModularOptions::TreeKind::kWPFixedDC) { + std::vector<int32_t> cutoffs = { + -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, + -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, + 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; + return MakeFixedTree(kNumNonrefProperties - weighted::kNumProperties, + cutoffs, Predictor::Weighted, total_pixels); + } + if (tree_kind == ModularOptions::TreeKind::kGradientFixedDC) { + std::vector<int32_t> cutoffs = { + -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, + -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, + 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; + return MakeFixedTree(kGradientProp, cutoffs, Predictor::Gradient, + total_pixels); + } + JXL_ABORT("Unreachable"); + return {}; +} + +// Merges the trees in `trees` using nodes that decide on stream_id, as defined +// by `tree_splits`. +void MergeTrees(const std::vector<Tree>& trees, + const std::vector<size_t>& tree_splits, size_t begin, + size_t end, Tree* tree) { + JXL_ASSERT(trees.size() + 1 == tree_splits.size()); + JXL_ASSERT(end > begin); + JXL_ASSERT(end <= trees.size()); + if (end == begin + 1) { + // Insert the tree, adding the appropriate offset to all child nodes. + // This will make the leaf IDs wrong, but subsequent roundtripping will fix + // them.
+ size_t sz = tree->size(); + tree->insert(tree->end(), trees[begin].begin(), trees[begin].end()); + for (size_t i = sz; i < tree->size(); i++) { + (*tree)[i].lchild += sz; + (*tree)[i].rchild += sz; + } + return; + } + size_t mid = (begin + end) / 2; + size_t splitval = tree_splits[mid] - 1; + size_t cur = tree->size(); + tree->emplace_back(1 /*stream_id*/, splitval, 0, 0, Predictor::Zero, 0, 1); + (*tree)[cur].lchild = tree->size(); + MergeTrees(trees, tree_splits, mid, end, tree); + (*tree)[cur].rchild = tree->size(); + MergeTrees(trees, tree_splits, begin, mid, tree); +} + +void QuantizeChannel(Channel& ch, const int q) { + if (q == 1) return; + for (size_t y = 0; y < ch.plane.ysize(); y++) { + pixel_type* row = ch.plane.Row(y); + for (size_t x = 0; x < ch.plane.xsize(); x++) { + if (row[x] < 0) { + row[x] = -((-row[x] + q / 2) / q) * q; + } else { + row[x] = ((row[x] + q / 2) / q) * q; + } + } + } +} + +// convert binary32 float that corresponds to custom [bits]-bit float (with +// [exp_bits] exponent bits) to a [bits]-bit integer representation that should +// fit in pixel_type +Status float_to_int(const float* const row_in, pixel_type* const row_out, + size_t xsize, unsigned int bits, unsigned int exp_bits, + bool fp, float factor) { + JXL_ASSERT(sizeof(pixel_type) * 8 >= bits); + if (!fp) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * factor + 0.5f; + } + return true; + } + if (bits == 32 && fp) { + JXL_ASSERT(exp_bits == 8); + memcpy((void*)row_out, (const void*)row_in, 4 * xsize); + return true; + } + + int exp_bias = (1 << (exp_bits - 1)) - 1; + int max_exp = (1 << exp_bits) - 1; + uint32_t sign = (1u << (bits - 1)); + int mant_bits = bits - exp_bits - 1; + int mant_shift = 23 - mant_bits; + for (size_t x = 0; x < xsize; ++x) { + uint32_t f; + memcpy(&f, &row_in[x], 4); + int signbit = (f >> 31); + f &= 0x7fffffff; + if (f == 0) { + row_out[x] = (signbit ? sign : 0); + continue; + } + int exp = (f >> 23) - 127; + if (exp == 128) return JXL_FAILURE("Inf/NaN not allowed"); + int mantissa = (f & 0x007fffff); + // broke up the binary32 into its parts, now reassemble into + // arbitrary float + exp += exp_bias; + if (exp < 0) { // will become a subnormal number + // add implicit leading 1 to mantissa + mantissa |= 0x00800000; + if (exp < -mant_bits) { + return JXL_FAILURE( + "Invalid float number: %g cannot be represented with %i " + "exp_bits and %i mant_bits (exp %i)", + row_in[x], exp_bits, mant_bits, exp); + } + mantissa >>= 1 - exp; + exp = 0; + } + // exp should be representable in exp_bits, otherwise input was + // invalid + if (exp > max_exp) return JXL_FAILURE("Invalid float exponent"); + if (mantissa & ((1 << mant_shift) - 1)) { + return JXL_FAILURE("%g is losing precision (mant: %x)", row_in[x], + mantissa); + } + mantissa >>= mant_shift; + f = (signbit ? 
sign : 0); + f |= (exp << mant_bits); + f |= mantissa; + row_out[x] = (pixel_type)f; + } + return true; +} +} // namespace + +ModularFrameEncoder::ModularFrameEncoder(const FrameHeader& frame_header, + const CompressParams& cparams_orig) + : frame_dim(frame_header.ToFrameDimensions()), cparams(cparams_orig) { + size_t num_streams = + ModularStreamId::Num(frame_dim, frame_header.passes.num_passes); + stream_images.resize(num_streams); + if (cquality > 100) cquality = quality; + + // use a sensible default if nothing explicit is specified: + // Squeeze for lossy, no squeeze for lossless + if (cparams.responsive < 0) { + if (quality == 100) { + cparams.responsive = 0; + } else { + cparams.responsive = 1; + } + } + + if (cparams.speed_tier > SpeedTier::kWombat) { + cparams.options.splitting_heuristics_node_threshold = 192; + } else { + cparams.options.splitting_heuristics_node_threshold = 96; + } + { + // Set properties. + std::vector<uint32_t> prop_order; + if (cparams.responsive) { + // Properties in order of their likelihood of being useful for Squeeze + // residuals. + prop_order = {0, 1, 4, 5, 6, 7, 8, 15, 9, 10, 11, 12, 13, 14, 2, 3}; + } else { + // Same, but for the non-Squeeze case. + prop_order = {0, 1, 15, 9, 10, 11, 12, 13, 14, 2, 3, 4, 5, 6, 7, 8}; + } + switch (cparams.speed_tier) { + case SpeedTier::kSquirrel: + cparams.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 8); + cparams.options.max_property_values = 32; + break; + case SpeedTier::kKitten: + cparams.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 10); + cparams.options.max_property_values = 64; + break; + case SpeedTier::kTortoise: + cparams.options.splitting_heuristics_properties = prop_order; + cparams.options.max_property_values = 256; + break; + default: + cparams.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 6); + cparams.options.max_property_values = 16; + break; + } + if (cparams.speed_tier > SpeedTier::kTortoise) { + // Gradient in previous channels. + for (int i = 0; i < cparams.options.max_properties; i++) { + cparams.options.splitting_heuristics_properties.push_back( + kNumNonrefProperties + i * 4 + 3); + } + } else { + // All the extra properties in Tortoise mode. + for (int i = 0; i < cparams.options.max_properties * 4; i++) { + cparams.options.splitting_heuristics_properties.push_back( + kNumNonrefProperties + i); + } + } + } + + if (cparams.options.predictor == static_cast<Predictor>(-1)) { + // no explicit predictor(s) given, set a good default + if ((cparams.speed_tier <= SpeedTier::kTortoise || + cparams.modular_mode == false) && + quality == 100 && cparams.near_lossless == false && + cparams.responsive == false) { + // TODO(veluca): allow all predictors that don't break residual + // multipliers in lossy mode. + cparams.options.predictor = Predictor::Variable; + } else if (cparams.near_lossless) { + // weighted predictor for near_lossless + cparams.options.predictor = Predictor::Weighted; + } else if (cparams.responsive) { + // zero predictor for Squeeze residues + cparams.options.predictor = Predictor::Zero; + } else if (quality < 100) { + // If not responsive and lossy. TODO(veluca): use near_lossless instead?
+ cparams.options.predictor = Predictor::Gradient; + } else if (cparams.speed_tier < SpeedTier::kFalcon) { + // try median and weighted predictor for anything else + cparams.options.predictor = Predictor::Best; + } else { + // just weighted predictor in fastest mode + cparams.options.predictor = Predictor::Weighted; + } + } + tree_splits.push_back(0); + if (cparams.modular_mode == false) { + cparams.options.fast_decode_multiplier = 1.0f; + tree_splits.push_back(ModularStreamId::VarDCTDC(0).ID(frame_dim)); + tree_splits.push_back(ModularStreamId::ModularDC(0).ID(frame_dim)); + tree_splits.push_back(ModularStreamId::ACMetadata(0).ID(frame_dim)); + tree_splits.push_back(ModularStreamId::QuantTable(0).ID(frame_dim)); + tree_splits.push_back(ModularStreamId::ModularAC(0, 0).ID(frame_dim)); + ac_metadata_size.resize(frame_dim.num_dc_groups); + extra_dc_precision.resize(frame_dim.num_dc_groups); + } + tree_splits.push_back(num_streams); + cparams.options.max_chan_size = frame_dim.group_dim; + + // TODO(veluca): figure out how to use different predictor sets per channel. + stream_options.resize(num_streams, cparams.options); +} + +Status ModularFrameEncoder::ComputeEncodingData( + const FrameHeader& frame_header, const ImageBundle& ib, + Image3F* JXL_RESTRICT color, PassesEncoderState* JXL_RESTRICT enc_state, + ThreadPool* pool, AuxOut* aux_out, bool do_color) { + const FrameDimensions& frame_dim = enc_state->shared.frame_dim; + + if (do_color && frame_header.loop_filter.gab) { + GaborishInverse(color, 0.9908511000000001f, pool); + } + + if (do_color && cparams.speed_tier < SpeedTier::kCheetah) { + FindBestPatchDictionary(*color, enc_state, nullptr, nullptr, + cparams.color_transform == ColorTransform::kXYB); + PatchDictionaryEncoder::SubtractFrom( + enc_state->shared.image_features.patches, color); + } + + // Convert ImageBundle to modular Image object + const size_t xsize = std::min(color->xsize(), ib.xsize()); + const size_t ysize = std::min(color->ysize(), ib.ysize()); + + int nb_chans = 3; + if (ib.IsGray() && cparams.color_transform == ColorTransform::kNone) { + nb_chans = 1; + } + if (!do_color) nb_chans = 0; + + if (ib.HasExtraChannels()) { + nb_chans += ib.extra_channels().size(); + } + + bool fp = ib.metadata()->bit_depth.floating_point_sample; + + // bits_per_sample is just metadata for XYB images. + if (ib.metadata()->bit_depth.bits_per_sample >= 32 && do_color && + cparams.color_transform != ColorTransform::kXYB) { + if (ib.metadata()->bit_depth.bits_per_sample == 32 && fp == false) { + return JXL_FAILURE("uint32_t not supported in enc_modular"); + } else if (ib.metadata()->bit_depth.bits_per_sample > 32) { + return JXL_FAILURE("bits_per_sample > 32 not supported"); + } + } + + int maxval = (fp ? 1 + : (1u << static_cast<uint32_t>( + ib.metadata()->bit_depth.bits_per_sample)) - + 1); + + Image& gi = stream_images[0]; + gi = Image(xsize, ysize, maxval, nb_chans); + int c = 0; + if (cparams.color_transform == ColorTransform::kXYB && + cparams.modular_mode == true) { + static const float enc_factors[3] = {32768.0f, 2048.0f, 2048.0f}; + DequantMatricesSetCustomDC(&enc_state->shared.matrices, enc_factors); + } + if (do_color) { + for (; c < 3; c++) { + if (ib.IsGray() && cparams.color_transform == ColorTransform::kNone && + c != (cparams.color_transform == ColorTransform::kXYB ?
1 : 0)) + continue; + int c_out = c; + // XYB is encoded as YX(B-Y) + if (cparams.color_transform == ColorTransform::kXYB && c < 2) + c_out = 1 - c_out; + float factor = maxval; + if (cparams.color_transform == ColorTransform::kXYB) + factor = enc_state->shared.matrices.InvDCQuant(c); + if (c == 2 && cparams.color_transform == ColorTransform::kXYB) { + JXL_ASSERT(!fp); + for (size_t y = 0; y < ysize; ++y) { + const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y); + pixel_type* const JXL_RESTRICT row_Y = gi.channel[0].Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * factor + 0.5f; + row_out[x] -= row_Y[x]; + } + } + } else { + int bits = ib.metadata()->bit_depth.bits_per_sample; + int exp_bits = ib.metadata()->bit_depth.exponent_bits_per_sample; + for (size_t y = 0; y < ysize; ++y) { + const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y); + JXL_RETURN_IF_ERROR( + float_to_int(row_in, row_out, xsize, bits, exp_bits, fp, factor)); + } + } + } + if (ib.IsGray() && cparams.color_transform == ColorTransform::kNone) c = 1; + } + if (ib.HasExtraChannels()) { + for (size_t ec = 0; ec < ib.extra_channels().size(); ec++, c++) { + const ExtraChannelInfo& eci = ib.metadata()->extra_channel_info[ec]; + gi.channel[c].resize(eci.Size(ib.xsize()), eci.Size(ib.ysize())); + gi.channel[c].hshift = gi.channel[c].vshift = eci.dim_shift; + + int bits = eci.bit_depth.bits_per_sample; + int exp_bits = eci.bit_depth.exponent_bits_per_sample; + bool fp = eci.bit_depth.floating_point_sample; + float factor = (fp ? 1 : ((1u << eci.bit_depth.bits_per_sample) - 1)); + for (size_t y = 0; y < ysize; ++y) { + const float* const JXL_RESTRICT row_in = ib.extra_channels()[ec].Row(y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c].Row(y); + JXL_RETURN_IF_ERROR( + float_to_int(row_in, row_out, xsize, bits, exp_bits, fp, factor)); + } + } + } + JXL_ASSERT(c == nb_chans); + + // Set options and apply transformations + + if (quality < 100 || cparams.near_lossless) { + if (cparams.palette_colors != 0) { + JXL_DEBUG_V(3, "Lossy encode, not doing palette transforms"); + } + if (cparams.color_transform == ColorTransform::kXYB) { + cparams.channel_colors_pre_transform_percent = 0; + } + cparams.channel_colors_percent = 0; + cparams.palette_colors = 0; + } + + // if few colors, do all-channel palette before trying channel palette + // Logic is as follows: + // - if you can make a palette with few colors (arbitrary threshold: 200), + // then you can also make channel palettes, but they will just be extra + // signaling cost for almost no benefit + // - if the palette needs more colors, then channel palette might help to + // reduce palette signaling cost + if (cparams.palette_colors != 0 && cparams.speed_tier < SpeedTier::kFalcon) { + // all-channel palette (e.g. 
RGBA) + if (gi.nb_channels > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.nb_channels; + maybe_palette.nb_colors = + std::min(std::min(200, (int)(xsize * ysize / 8)), + std::abs(cparams.palette_colors) / 16); + maybe_palette.ordered_palette = cparams.palette_colors >= 0; + maybe_palette.lossy_palette = false; + gi.do_transform(maybe_palette, weighted::Header()); + } + } + + // Global channel palette + if (cparams.channel_colors_pre_transform_percent > 0 && + !cparams.lossy_palette) { + // single channel palette (like FLIF's ChannelCompact) + for (size_t i = 0; i < gi.nb_channels; i++) { + int min, max; + gi.channel[gi.nb_meta_channels + i].compute_minmax(&min, &max); + int64_t colors = max - min + 1; + JXL_DEBUG_V(10, "Channel %zu: range=%i..%i", i, min, max); + Transform maybe_palette_1(TransformId::kPalette); + maybe_palette_1.begin_c = i + gi.nb_meta_channels; + maybe_palette_1.num_c = 1; + // simple heuristic: if less than X percent of the values in the range + // actually occur, it is probably worth it to do a compaction + // (but only if the channel palette is less than 6% the size of the + // image itself) + maybe_palette_1.nb_colors = std::min( + (int)(xsize * ysize / 16), + (int)(cparams.channel_colors_pre_transform_percent / 100. * colors)); + if (gi.do_transform(maybe_palette_1, weighted::Header())) { + // effective bit depth is lower, adjust quantization accordingly + gi.channel[gi.nb_meta_channels + i].compute_minmax(&min, &max); + if (max < maxval) maxval = max; + } + } + } + + // Global palette + if ((cparams.palette_colors != 0 || cparams.lossy_palette) && + cparams.speed_tier < SpeedTier::kFalcon) { + // all-channel palette (e.g. RGBA) + if (gi.nb_channels > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.nb_channels; + maybe_palette.nb_colors = + std::min((int)(xsize * ysize / 8), std::abs(cparams.palette_colors)); + maybe_palette.ordered_palette = cparams.palette_colors >= 0; + maybe_palette.lossy_palette = + (cparams.lossy_palette && gi.nb_channels == 3); + if (maybe_palette.lossy_palette) { + maybe_palette.predictor = Predictor::Average4; + } + // TODO(veluca): use a custom weighted header if using the weighted + // predictor. 
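// Illustrative sketch (hypothetical helper, not from the patch): the
// color-count cap applied to the global palette above, factored out with a
// worked number. The formula is the one used in the code; only the name and
// the example values are invented.
#include <algorithm>
#include <cstdlib>

// nb_colors = min(xsize * ysize / 8, abs(palette_colors)): never spend more
// than one palette entry per 8 pixels, and respect the configured budget
// (a negative palette_colors keeps the same magnitude but means "unordered",
// per `ordered_palette = cparams.palette_colors >= 0` above).
static int GlobalPaletteColorCap(int xsize, int ysize, int palette_colors) {
  return std::min(xsize * ysize / 8, std::abs(palette_colors));
}
// e.g. a 64x64 group with palette_colors = 1024 is capped at 64*64/8 = 512.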
+ gi.do_transform(maybe_palette, weighted::Header()); + } + // all-minus-one-channel palette (RGB with separate alpha, or CMY with + // separate K) + if (gi.nb_channels > 3) { + Transform maybe_palette_3(TransformId::kPalette); + maybe_palette_3.begin_c = gi.nb_meta_channels; + maybe_palette_3.num_c = gi.nb_channels - 1; + maybe_palette_3.nb_colors = + std::min((int)(xsize * ysize / 8), std::abs(cparams.palette_colors)); + maybe_palette_3.ordered_palette = cparams.palette_colors >= 0; + maybe_palette_3.lossy_palette = cparams.lossy_palette; + if (maybe_palette_3.lossy_palette) { + maybe_palette_3.predictor = Predictor::Average4; + } + gi.do_transform(maybe_palette_3, weighted::Header()); + } + } + + if (cparams.color_transform == ColorTransform::kNone && do_color && !fp) { + if (cparams.colorspace == 1 || + (cparams.colorspace < 0 && (quality < 100 || cparams.near_lossless || + cparams.speed_tier > SpeedTier::kHare))) { + Transform ycocg{TransformId::kRCT}; + ycocg.rct_type = 6; + ycocg.begin_c = gi.nb_meta_channels; + gi.do_transform(ycocg, weighted::Header()); + } else if (cparams.colorspace >= 2) { + Transform sg(TransformId::kRCT); + sg.begin_c = gi.nb_meta_channels; + sg.rct_type = cparams.colorspace - 2; + gi.do_transform(sg, weighted::Header()); + } + } + + if (cparams.responsive && gi.nb_channels != 0) { + gi.do_transform(Transform(TransformId::kSqueeze), + weighted::Header()); // use default squeezing + } + + std::vector quants; + + if (quality < 100 || cquality < 100) { + quants.resize(gi.channel.size(), 1); + JXL_DEBUG_V( + 2, + "Adding quantization constants corresponding to luma quality %.2f " + "and chroma quality %.2f", + quality, cquality); + if (!cparams.responsive) { + JXL_DEBUG_V(1, + "Warning: lossy compression without Squeeze " + "transform is just color quantization."); + quality = (400 + quality) / 5; + cquality = (400 + cquality) / 5; + } + + // convert 'quality' to quantization scaling factor + if (quality > 50) { + quality = 200.0 - quality * 2.0; + } else { + quality = 900.0 - quality * 16.0; + } + if (cquality > 50) { + cquality = 200.0 - cquality * 2.0; + } else { + cquality = 900.0 - cquality * 16.0; + } + if (cparams.color_transform != ColorTransform::kXYB) { + quality *= 0.01f * maxval / 255.f; + cquality *= 0.01f * maxval / 255.f; + } else { + quality *= 0.01f; + cquality *= 0.01f; + } + + if (cparams.options.nb_repeats == 0) { + return JXL_FAILURE("nb_repeats = 0 not supported with modular lossy!"); + } + for (uint32_t i = gi.nb_meta_channels; i < gi.channel.size(); i++) { + Channel& ch = gi.channel[i]; + int shift = ch.hcshift + ch.vcshift; // number of pixel halvings + if (shift > 15) shift = 15; + int q; + // assuming default Squeeze here + int component = ((i - gi.nb_meta_channels) % gi.real_nb_channels); + // last 4 channels are final chroma residuals + if (gi.real_nb_channels > 2 && i >= gi.channel.size() - 4) { + component = 1; + } + + if (cparams.color_transform == ColorTransform::kXYB && component < 3) { + q = (component == 0 ? quality : cquality) * squeeze_quality_factor_xyb * + squeeze_xyb_qtable[component][shift]; + } else { + if (cparams.colorspace != 0 && component > 0 && component < 3) { + q = cquality * squeeze_quality_factor * squeeze_chroma_qtable[shift]; + } else { + q = quality * squeeze_quality_factor * squeeze_luma_factor * + squeeze_luma_qtable[shift]; + } + } + if (q < 1) q = 1; + QuantizeChannel(gi.channel[i], q); + quants[i] = q; + } + } + + // Fill other groups. 
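// Illustrative sketch (hypothetical helper, not from the patch): the
// piecewise-linear quality-to-quantization-scale mapping used a few lines
// above, with worked values. The two branches agree at the quality = 50
// boundary.
static float QualityToQuantScale(float quality) {
  // Above 50, a gentle slope: 100 -> 0, 90 -> 20, 50 -> 100.
  // At or below 50, a steep slope: 50 -> 100, 25 -> 500, 0 -> 900.
  return quality > 50 ? 200.0f - quality * 2.0f : 900.0f - quality * 16.0f;
}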
+ struct GroupParams { + Rect rect; + int minShift; + int maxShift; + ModularStreamId id; + }; + std::vector stream_params; + + stream_options[0] = cparams.options; + + // DC + for (size_t group_id = 0; group_id < frame_dim.num_dc_groups; group_id++) { + const size_t gx = group_id % frame_dim.xsize_dc_groups; + const size_t gy = group_id / frame_dim.xsize_dc_groups; + const Rect rect(gx * frame_dim.group_dim << 3, + gy * frame_dim.group_dim << 3, frame_dim.group_dim << 3, + frame_dim.group_dim << 3); + // minShift==3 because kDcGroupDim>>3 == frame_dim.group_dim + // maxShift==1000 is infinity + stream_params.push_back( + GroupParams{rect, 3, 1000, ModularStreamId::ModularDC(group_id)}); + } + // AC global -> nothing. + // AC + for (size_t group_id = 0; group_id < frame_dim.num_groups; group_id++) { + const size_t gx = group_id % frame_dim.xsize_groups; + const size_t gy = group_id / frame_dim.xsize_groups; + const Rect mrect(gx * frame_dim.group_dim, gy * frame_dim.group_dim, + frame_dim.group_dim, frame_dim.group_dim); + for (size_t i = 0; i < enc_state->progressive_splitter.GetNumPasses(); + i++) { + int maxShift, minShift; + frame_header.passes.GetDownsamplingBracket(i, minShift, maxShift); + stream_params.push_back(GroupParams{ + mrect, minShift, maxShift, ModularStreamId::ModularAC(group_id, i)}); + } + } + gi_channel.resize(stream_images.size()); + + RunOnPool( + pool, 0, stream_params.size(), ThreadPool::SkipInit(), + [&](size_t i, size_t _) { + stream_options[stream_params[i].id.ID(frame_dim)] = cparams.options; + JXL_CHECK(PrepareStreamParams( + stream_params[i].rect, cparams, stream_params[i].minShift, + stream_params[i].maxShift, stream_params[i].id, do_color)); + }, + "ChooseParams"); + { + // Clear out channels that have been copied to groups. + Image& full_image = stream_images[0]; + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break; + } + for (; c < full_image.channel.size(); c++) { + full_image.channel[c].plane = ImageI(); + } + } + + if (!quants.empty()) { + for (uint32_t stream_id = 0; stream_id < stream_images.size(); + stream_id++) { + // skip non-modular stream_ids + if (stream_id > 0 && gi_channel[stream_id].empty()) continue; + Image& image = stream_images[stream_id]; + const ModularOptions& options = stream_options[stream_id]; + for (uint32_t i = image.nb_meta_channels; i < image.channel.size(); i++) { + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + continue; + } + if (stream_id > 0 && gi_channel[stream_id].empty()) continue; + size_t ch_id = stream_id == 0 + ? i + : gi_channel[stream_id][i - image.nb_meta_channels]; + uint32_t q = quants[ch_id]; + // Inform the tree splitting heuristics that each channel in each group + // used this quantization factor. This will produce a tree with the + // given multipliers. + if (multiplier_info.empty() || + multiplier_info.back().range[1][0] != stream_id || + multiplier_info.back().multiplier != q) { + StaticPropRange range; + range[0] = {i, i + 1}; + range[1] = {stream_id, stream_id + 1}; + multiplier_info.push_back({range, (uint32_t)q}); + } else { + // Previous channel in the same group had the same quantization + // factor. Don't provide two different ranges, as that creates + // unnecessary nodes. 
+ multiplier_info.back().range[0][1] = i + 1; + } + } + } + // Merge group+channel settings that have the same channels and quantization + // factors, to avoid unnecessary nodes. + std::sort(multiplier_info.begin(), multiplier_info.end(), + [](ModularMultiplierInfo a, ModularMultiplierInfo b) { + return std::make_tuple(a.range, a.multiplier) < + std::make_tuple(b.range, b.multiplier); + }); + size_t new_num = 1; + for (size_t i = 1; i < multiplier_info.size(); i++) { + ModularMultiplierInfo& prev = multiplier_info[new_num - 1]; + ModularMultiplierInfo& cur = multiplier_info[i]; + if (prev.range[0] == cur.range[0] && prev.multiplier == cur.multiplier && + prev.range[1][1] == cur.range[1][0]) { + prev.range[1][1] = cur.range[1][1]; + } else { + multiplier_info[new_num++] = multiplier_info[i]; + } + } + multiplier_info.resize(new_num); + } + + return PrepareEncoding(pool, enc_state->shared.frame_dim, + enc_state->heuristics.get(), aux_out); +} + +Status ModularFrameEncoder::PrepareEncoding(ThreadPool* pool, + const FrameDimensions& frame_dim, + EncoderHeuristics* heuristics, + AuxOut* aux_out) { + if (!tree.empty()) return true; + + // Compute tree. + size_t num_streams = stream_images.size(); + stream_headers.resize(num_streams); + tokens.resize(num_streams); + + if (heuristics->CustomFixedTreeLossless(frame_dim, &tree)) { + // Using a fixed tree. + } else if (cparams.speed_tier != SpeedTier::kFalcon || quality != 100 || + !cparams.modular_mode) { + // Avoid creating a tree with leaves that don't correspond to any pixels. + std::vector useful_splits; + useful_splits.reserve(tree_splits.size()); + for (size_t chunk = 0; chunk < tree_splits.size() - 1; chunk++) { + bool has_pixels = false; + size_t start = tree_splits[chunk]; + size_t stop = tree_splits[chunk + 1]; + for (size_t i = start; i < stop; i++) { + for (const Channel& c : stream_images[i].channel) { + if (c.w && c.h) has_pixels = true; + } + } + if (has_pixels) { + useful_splits.push_back(tree_splits[chunk]); + } + } + // Don't do anything if modular mode does not have any pixels in this image + if (useful_splits.empty()) return true; + useful_splits.push_back(tree_splits.back()); + + std::atomic_flag invalid_force_wp = ATOMIC_FLAG_INIT; + + std::vector trees(useful_splits.size() - 1); + RunOnPool( + pool, 0, useful_splits.size() - 1, ThreadPool::SkipInit(), + [&](size_t chunk, size_t _) { + // TODO(veluca): parallelize more. 
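// Illustrative sketch (hypothetical types and helper, not from the patch):
// the multiplier_info merge step above, in isolation. Adjacent entries whose
// payloads match and whose half-open ranges touch are coalesced into one
// entry; the real code additionally requires the channel range to match,
// which this sketch folds into a single payload for brevity.
#include <cstddef>
#include <vector>

struct Span {
  unsigned lo, hi;  // half-open stream range [lo, hi)
  unsigned multiplier;
};

static void CoalesceSpans(std::vector<Span>* v) {
  if (v->empty()) return;
  size_t n = 1;
  for (size_t i = 1; i < v->size(); i++) {
    Span& prev = (*v)[n - 1];
    const Span& cur = (*v)[i];
    if (prev.multiplier == cur.multiplier && prev.hi == cur.lo) {
      prev.hi = cur.hi;  // extend the previous entry instead of adding a node
    } else {
      (*v)[n++] = cur;
    }
  }
  v->resize(n);
}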
+ size_t total_pixels = 0; + uint32_t start = useful_splits[chunk]; + uint32_t stop = useful_splits[chunk + 1]; + uint32_t max_c = 0; + if (stream_options[start].tree_kind != + ModularOptions::TreeKind::kLearn) { + for (size_t i = start; i < stop; i++) { + for (const Channel& ch : stream_images[i].channel) { + total_pixels += ch.w * ch.h; + } + } + trees[chunk] = + PredefinedTree(stream_options[start].tree_kind, total_pixels); + return; + } + TreeSamples tree_samples; + if (!tree_samples.SetPredictor(stream_options[start].predictor, + stream_options[start].wp_tree_mode)) { + invalid_force_wp.test_and_set(std::memory_order_acq_rel); + return; + } + if (!tree_samples.SetProperties( + stream_options[start].splitting_heuristics_properties, + stream_options[start].wp_tree_mode)) { + invalid_force_wp.test_and_set(std::memory_order_acq_rel); + return; + } + std::vector pixel_samples; + std::vector diff_samples; + std::vector group_pixel_count; + std::vector channel_pixel_count; + for (size_t i = start; i < stop; i++) { + max_c = std::max(stream_images[i].channel.size(), max_c); + CollectPixelSamples(stream_images[i], stream_options[i], i, + group_pixel_count, channel_pixel_count, + pixel_samples, diff_samples); + } + StaticPropRange range; + range[0] = {0, max_c}; + range[1] = {start, stop}; + auto local_multiplier_info = multiplier_info; + + tree_samples.PreQuantizeProperties( + range, local_multiplier_info, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples, + stream_options[start].max_property_values); + for (size_t i = start; i < stop; i++) { + JXL_CHECK(ModularGenericCompress( + stream_images[i], stream_options[i], /*writer=*/nullptr, + /*aux_out=*/nullptr, 0, i, &tree_samples, &total_pixels)); + } + + // TODO(veluca): parallelize more. + trees[chunk] = + LearnTree(std::move(tree_samples), total_pixels, + stream_options[start], local_multiplier_info, range); + }, + "LearnTrees"); + if (invalid_force_wp.test_and_set(std::memory_order_acq_rel)) { + return JXL_FAILURE("PrepareEncoding: force_no_wp with {Weighted}"); + } + tree.clear(); + MergeTrees(trees, useful_splits, 0, useful_splits.size() - 1, &tree); + } else { + // Fixed tree. + // TODO(veluca): determine cutoffs? + std::vector cutoffs = {-255, -191, -127, -95, -63, -47, -31, -23, + -15, -11, -7, -5, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, + 63, 95, 127, 191, 255}; + size_t total_pixels = 0; + for (const Image& img : stream_images) { + for (const Channel& ch : img.channel) { + total_pixels += ch.w * ch.h; + } + } + tree = MakeFixedTree(kNumNonrefProperties - weighted::kNumProperties, + cutoffs, Predictor::Weighted, total_pixels); + } + // TODO(veluca): do this somewhere else. 
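// Illustrative sketch (hypothetical helper, not from the patch): how the
// sorted cutoff lists above (and in the cost estimators further down) map a
// value to a context bucket, mirroring the `ctx += c >= value` counting
// pattern of EstimateWPCost (EstimateCost uses a strict `>` instead).
#include <cstddef>

static size_t ContextBucket(int value, const int* cutoffs, size_t n) {
  size_t ctx = 0;
  for (size_t i = 0; i < n; i++) ctx += (cutoffs[i] >= value);
  return ctx;  // one of n + 1 buckets
}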
+ if (cparams.near_lossless) { + for (size_t i = 0; i < tree.size(); i++) { + tree[i].predictor_offset = 0; + } + } + tree_tokens.resize(1); + tree_tokens[0].clear(); + Tree decoded_tree; + TokenizeTree(tree, &tree_tokens[0], &decoded_tree); + JXL_ASSERT(tree.size() == decoded_tree.size()); + tree = std::move(decoded_tree); + + if (WantDebugOutput(aux_out)) { + PrintTree(tree, aux_out->debug_prefix + "/global_tree"); + } + + image_widths.resize(num_streams); + RunOnPool( + pool, 0, num_streams, ThreadPool::SkipInit(), + [&](size_t stream_id, size_t _) { + AuxOut my_aux_out; + if (aux_out) { + my_aux_out.dump_image = aux_out->dump_image; + my_aux_out.debug_prefix = aux_out->debug_prefix; + } + tokens[stream_id].clear(); + JXL_CHECK(ModularGenericCompress( + stream_images[stream_id], stream_options[stream_id], + /*writer=*/nullptr, &my_aux_out, 0, stream_id, + /*tree_samples=*/nullptr, + /*total_pixels=*/nullptr, + /*tree=*/&tree, /*header=*/&stream_headers[stream_id], + /*tokens=*/&tokens[stream_id], + /*widths=*/&image_widths[stream_id])); + }, + "ComputeTokens"); + return true; +} + +Status ModularFrameEncoder::EncodeGlobalInfo(BitWriter* writer, + AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, 1); + // If we are using brotli, or not using modular mode. + if (tree_tokens.empty() || tree_tokens[0].empty()) { + writer->Write(1, 0); + ReclaimAndCharge(writer, &allotment, kLayerModularTree, aux_out); + return true; + } + writer->Write(1, 1); + ReclaimAndCharge(writer, &allotment, kLayerModularTree, aux_out); + + // Write tree + HistogramParams params; + if (cparams.speed_tier > SpeedTier::kKitten) { + params.clustering = HistogramParams::ClusteringType::kFast; + params.ans_histogram_strategy = + HistogramParams::ANSHistogramStrategy::kApproximate; + params.lz77_method = HistogramParams::LZ77Method::kNone; + // Near-lossless DC, as well as modular mode, require choosing hybrid uint + // more carefully. + if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) || + (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) { + params.uint_method = HistogramParams::HybridUintMethod::kFast; + } else { + params.uint_method = HistogramParams::HybridUintMethod::kNone; + } + } else if (cparams.speed_tier <= SpeedTier::kTortoise) { + params.lz77_method = HistogramParams::LZ77Method::kOptimal; + } else { + params.lz77_method = HistogramParams::LZ77Method::kLZ77; + } + if (cparams.decoding_speed_tier >= 1) { + params.max_histograms = 12; + } + BuildAndEncodeHistograms(params, kNumTreeContexts, tree_tokens, &code, + &context_map, writer, kLayerModularTree, aux_out); + WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree, + aux_out); + params.image_widths = image_widths; + // Write histograms. + BuildAndEncodeHistograms(params, (tree.size() + 1) / 2, tokens, &code, + &context_map, writer, kLayerModularGlobal, aux_out); + return true; +} + +Status ModularFrameEncoder::EncodeStream(BitWriter* writer, AuxOut* aux_out, + size_t layer, + const ModularStreamId& stream) { + size_t stream_id = stream.ID(frame_dim); + if (stream_images[stream_id].real_nb_channels < 1) { + return true; // Image with no channels, header never gets decoded. 
+ } + JXL_RETURN_IF_ERROR( + Bundle::Write(stream_headers[stream_id], writer, layer, aux_out)); + WriteTokens(tokens[stream_id], code, context_map, writer, layer, aux_out); + return true; +} + +namespace { +float EstimateWPCost(const Image& img, size_t i) { + size_t extra_bits = 0; + float histo_cost = 0; + HybridUintConfig config; + int32_t cutoffs[] = {-500, -392, -255, -191, -127, -95, -63, -47, -31, + -23, -15, -11, -7, -4, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, 63, + 95, 127, 191, 255, 392, 500}; + constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1; + Histogram histo[nc] = {}; + weighted::Header wp_header; + PredictorMode(i, &wp_header); + for (const Channel& ch : img.channel) { + const intptr_t onerow = ch.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, ch.w, ch.h); + Properties properties(1); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type* JXL_RESTRICT r = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < ch.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + pixel_type guess = wp_state.Predict( + x, y, ch.w, top, left, topright, topleft, toptop, &properties, + offset); + size_t ctx = 0; + for (int c : cutoffs) { + ctx += c >= properties[0]; + } + pixel_type res = r[x] - guess; + uint32_t token, nbits, bits; + config.Encode(PackSigned(res), &token, &nbits, &bits); + histo[ctx].Add(token); + extra_bits += nbits; + wp_state.UpdateErrors(r[x], x, y, ch.w); + } + } + for (size_t h = 0; h < nc; h++) { + histo_cost += histo[h].ShannonEntropy(); + histo[h].Clear(); + } + } + return histo_cost + extra_bits; +} + +float EstimateCost(const Image& img) { + // TODO(veluca): consider SIMDfication of this code. + size_t extra_bits = 0; + float histo_cost = 0; + HybridUintConfig config; + uint32_t cutoffs[] = {0, 1, 3, 5, 7, 11, 15, 23, 31, + 47, 63, 95, 127, 191, 255, 392, 500}; + constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1; + Histogram histo[nc] = {}; + for (const Channel& ch : img.channel) { + const intptr_t onerow = ch.plane.PixelsPerRow(); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type* JXL_RESTRICT r = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? 
*(r + x - 1 - onerow) : left); + size_t maxdiff = std::max(std::max(left, top), topleft) - + std::min(std::min(left, top), topleft); + size_t ctx = 0; + for (uint32_t c : cutoffs) { + ctx += c > maxdiff; + } + pixel_type res = r[x] - ClampedGradient(top, left, topleft); + uint32_t token, nbits, bits; + config.Encode(PackSigned(res), &token, &nbits, &bits); + histo[ctx].Add(token); + extra_bits += nbits; + } + } + for (size_t h = 0; h < nc; h++) { + histo_cost += histo[h].ShannonEntropy(); + histo[h].Clear(); + } + } + return histo_cost + extra_bits; +} + +} // namespace + +Status ModularFrameEncoder::PrepareStreamParams(const Rect& rect, + const CompressParams& cparams, + int minShift, int maxShift, + const ModularStreamId& stream, + bool do_color) { + size_t stream_id = stream.ID(frame_dim); + JXL_ASSERT(stream_id != 0); + Image& full_image = stream_images[0]; + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + int maxval = full_image.maxval; + Image& gi = stream_images[stream_id]; + gi = Image(xsize, ysize, maxval, 0); + // start at the first bigger-than-frame_dim.group_dim non-metachannel + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break; + } + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + gi_channel[stream_id].push_back(c); + Channel gc(r.xsize(), r.ysize()); + gc.hshift = fc.hshift; + gc.vshift = fc.vshift; + for (size_t y = 0; y < r.ysize(); ++y) { + const pixel_type* const JXL_RESTRICT row_in = r.ConstRow(fc.plane, y); + pixel_type* const JXL_RESTRICT row_out = gc.Row(y); + for (size_t x = 0; x < r.xsize(); ++x) { + row_out[x] = row_in[x]; + } + } + gi.channel.emplace_back(std::move(gc)); + } + gi.nb_channels = gi.channel.size(); + gi.real_nb_channels = gi.nb_channels; + + // Do some per-group transforms + + float quality = cparams.quality_pair.first; + + // Local palette + // TODO(veluca): make this work with quantize-after-prediction in lossy mode. + if (quality == 100 && cparams.palette_colors != 0 && + cparams.speed_tier < SpeedTier::kCheetah) { + // all-channel palette (e.g. 
RGBA) + if (gi.nb_channels > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.nb_channels; + maybe_palette.nb_colors = std::abs(cparams.palette_colors); + maybe_palette.ordered_palette = cparams.palette_colors >= 0; + gi.do_transform(maybe_palette, weighted::Header()); + } + // all-minus-one-channel palette (RGB with separate alpha, or CMY with + // separate K) + if (gi.nb_channels > 3) { + Transform maybe_palette_3(TransformId::kPalette); + maybe_palette_3.begin_c = gi.nb_meta_channels; + maybe_palette_3.num_c = gi.nb_channels - 1; + maybe_palette_3.nb_colors = std::abs(cparams.palette_colors); + maybe_palette_3.ordered_palette = cparams.palette_colors >= 0; + maybe_palette_3.lossy_palette = cparams.lossy_palette; + if (maybe_palette_3.lossy_palette) { + maybe_palette_3.predictor = Predictor::Weighted; + } + gi.do_transform(maybe_palette_3, weighted::Header()); + } + } + + // Local channel palette + if (cparams.channel_colors_percent > 0 && quality == 100 && + !cparams.lossy_palette && cparams.speed_tier < SpeedTier::kCheetah) { + // single channel palette (like FLIF's ChannelCompact) + for (size_t i = 0; i < gi.nb_channels; i++) { + int min, max; + gi.channel[gi.nb_meta_channels + i].compute_minmax(&min, &max); + int colors = max - min + 1; + JXL_DEBUG_V(10, "Channel %zu: range=%i..%i", i, min, max); + Transform maybe_palette_1(TransformId::kPalette); + maybe_palette_1.begin_c = i + gi.nb_meta_channels; + maybe_palette_1.num_c = 1; + // simple heuristic: if less than X percent of the values in the range + // actually occur, it is probably worth it to do a compaction + // (but only if the channel palette is less than 80% the size of the + // image itself) + maybe_palette_1.nb_colors = + std::min((int)(xsize * ysize * 0.8), + (int)(cparams.channel_colors_percent / 100. 
* colors)); + gi.do_transform(maybe_palette_1, weighted::Header()); + } + } + if (cparams.near_lossless > 0 && gi.nb_channels != 0) { + Transform nl(TransformId::kNearLossless); + nl.predictor = cparams.options.predictor; + JXL_RETURN_IF_ERROR(nl.predictor != Predictor::Best); + JXL_RETURN_IF_ERROR(nl.predictor != Predictor::Variable); + nl.begin_c = gi.nb_meta_channels; + if (cparams.colorspace == 0) { + nl.num_c = gi.nb_channels; + nl.max_delta_error = cparams.near_lossless; + gi.do_transform(nl, weighted::Header()); + } else { + nl.num_c = 1; + nl.max_delta_error = cparams.near_lossless; + gi.do_transform(nl, weighted::Header()); + nl.begin_c += 1; + nl.num_c = gi.nb_channels - 1; + nl.max_delta_error++; // more loss for chroma + gi.do_transform(nl, weighted::Header()); + } + } + + // lossless and no specific color transform specified: try Nothing, YCoCg, + // and 17 RCTs + if (cparams.color_transform == ColorTransform::kNone && quality == 100 && + cparams.colorspace < 0 && gi.nb_channels > 2 && !cparams.near_lossless && + cparams.responsive == false && do_color && + cparams.speed_tier <= SpeedTier::kHare) { + Transform sg(TransformId::kRCT); + sg.begin_c = gi.nb_meta_channels; + + size_t nb_rcts_to_try = 0; + switch (cparams.speed_tier) { + case SpeedTier::kFalcon: + nb_rcts_to_try = 0; // Just do global YCoCg + break; + case SpeedTier::kCheetah: + nb_rcts_to_try = 0; // Just do global YCoCg + break; + case SpeedTier::kHare: + nb_rcts_to_try = 4; + break; + case SpeedTier::kWombat: + nb_rcts_to_try = 5; + break; + case SpeedTier::kSquirrel: + nb_rcts_to_try = 7; + break; + case SpeedTier::kKitten: + nb_rcts_to_try = 9; + break; + case SpeedTier::kTortoise: + nb_rcts_to_try = 19; + break; + } + float best_cost = std::numeric_limits::max(); + size_t best_rct = 0; + // These should be 19 actually different transforms; the remaining ones + // are equivalent to one of these (note that the first two are do-nothing + // and YCoCg) modulo channel reordering (which only matters in the case of + // MA-with-prev-channels-properties) and/or sign (e.g. RmG vs GmR) + for (int i : {0 * 7 + 0, 0 * 7 + 6, 0 * 7 + 5, 1 * 7 + 3, 3 * 7 + 5, + 5 * 7 + 5, 1 * 7 + 5, 2 * 7 + 5, 1 * 7 + 1, 0 * 7 + 4, + 1 * 7 + 2, 2 * 7 + 1, 2 * 7 + 2, 2 * 7 + 3, 4 * 7 + 4, + 4 * 7 + 5, 0 * 7 + 2, 0 * 7 + 1, 0 * 7 + 3}) { + if (nb_rcts_to_try == 0) break; + int num_transforms_to_keep = gi.transform.size(); + sg.rct_type = i; + gi.do_transform(sg, weighted::Header()); + float cost = EstimateCost(gi); + if (cost < best_cost) { + best_rct = i; + best_cost = cost; + } + nb_rcts_to_try--; + // Ensure we do not clamp channels to their supposed range, as this + // otherwise breaks in the presence of patches. + gi.undo_transforms(weighted::Header(), num_transforms_to_keep == 0 + ? -1 + : num_transforms_to_keep); + } + // Apply the best RCT to the image for future encoding. + sg.rct_type = best_rct; + gi.do_transform(sg, weighted::Header()); + } else { + // No need to try anything, just use the default options. 
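// Illustrative sketch (hypothetical helper, not from the patch): the
// candidate list above writes each RCT as rct_type = permutation * 7 + type,
// where type 0 is the identity and type 6 is YCoCg (hence 0 * 7 + 0 and
// 0 * 7 + 6 being the "do-nothing and YCoCg" entries). Unpacking is a
// div/mod pair:
static inline void UnpackRctType(int rct_type, int* permutation, int* type) {
  *permutation = rct_type / 7;
  *type = rct_type % 7;
}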
+ } + size_t nb_wp_modes = 0; + switch (cparams.speed_tier) { + case SpeedTier::kFalcon: + nb_wp_modes = 1; + break; + case SpeedTier::kCheetah: + nb_wp_modes = 1; + break; + case SpeedTier::kHare: + nb_wp_modes = 1; + break; + case SpeedTier::kWombat: + nb_wp_modes = 1; + break; + case SpeedTier::kSquirrel: + nb_wp_modes = 1; + break; + case SpeedTier::kKitten: + nb_wp_modes = 2; + break; + case SpeedTier::kTortoise: + nb_wp_modes = 5; + break; + } + if (nb_wp_modes > 1 && + (stream_options[stream_id].predictor == Predictor::Weighted || + stream_options[stream_id].predictor == Predictor::Best || + stream_options[stream_id].predictor == Predictor::Variable)) { + float best_cost = std::numeric_limits::max(); + stream_options[stream_id].wp_mode = 0; + for (size_t i = 0; i < nb_wp_modes; i++) { + float cost = EstimateWPCost(gi, i); + if (cost < best_cost) { + best_cost = cost; + stream_options[stream_id].wp_mode = i; + } + } + } + return true; +} + +int QuantizeWP(const int32_t* qrow, size_t onerow, size_t c, size_t x, size_t y, + size_t w, weighted::State* wp_state, float value, + float inv_factor) { + float svalue = value * inv_factor; + PredictionResult pred = + PredictNoTreeWP(w, qrow + x, onerow, x, y, Predictor::Weighted, wp_state); + svalue -= pred.guess; + int residual = roundf(svalue); + if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2; + return residual + pred.guess; +} + +int QuantizeGradient(const int32_t* qrow, size_t onerow, size_t c, size_t x, + size_t y, size_t w, float value, float inv_factor) { + float svalue = value * inv_factor; + PredictionResult pred = + PredictNoTreeNoWP(w, qrow + x, onerow, x, y, Predictor::Gradient); + svalue -= pred.guess; + int residual = roundf(svalue); + if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2; + return residual + pred.guess; +} + +void ModularFrameEncoder::AddVarDCTDC(const Image3F& dc, size_t group_index, + bool nl_dc, + PassesEncoderState* enc_state) { + const Rect r = enc_state->shared.DCGroupRect(group_index); + extra_dc_precision[group_index] = nl_dc ? 1 : 0; + float mul = 1 << extra_dc_precision[group_index]; + + size_t stream_id = ModularStreamId::VarDCTDC(group_index).ID(frame_dim); + stream_options[stream_id].max_chan_size = 0xFFFFFF; + stream_options[stream_id].predictor = Predictor::Weighted; + stream_options[stream_id].wp_tree_mode = ModularOptions::WPTreeMode::kWPOnly; + if (cparams.speed_tier >= SpeedTier::kSquirrel) { + stream_options[stream_id].tree_kind = ModularOptions::TreeKind::kWPFixedDC; + } + if (cparams.decoding_speed_tier >= 1) { + stream_options[stream_id].tree_kind = + ModularOptions::TreeKind::kGradientFixedDC; + } + + stream_images[stream_id] = Image(r.xsize(), r.ysize(), 255, 3); + if (nl_dc && stream_options[stream_id].tree_kind == + ModularOptions::TreeKind::kGradientFixedDC) { + JXL_ASSERT(enc_state->shared.frame_header.chroma_subsampling.Is444()); + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + size_t stride = stream_images[stream_id] + .channel[c < 2 ? 
c ^ 1 : c] + .plane.PixelsPerRow(); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeGradient(quant_row, stride, c, x, y, + r.xsize(), row[x], inv_factor); + } + } else { + int32_t* quant_row_y = + stream_images[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeGradient( + quant_row, stride, c, x, y, r.xsize(), + row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor); + } + } + } + } + } else if (nl_dc) { + JXL_ASSERT(enc_state->shared.frame_header.chroma_subsampling.Is444()); + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + weighted::Header header; + weighted::State wp_state(header, r.xsize(), r.ysize()); + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + size_t stride = stream_images[stream_id] + .channel[c < 2 ? c ^ 1 : c] + .plane.PixelsPerRow(); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeWP(quant_row, stride, c, x, y, r.xsize(), + &wp_state, row[x], inv_factor); + wp_state.UpdateErrors(quant_row[x], x, y, r.xsize()); + } + } else { + int32_t* quant_row_y = + stream_images[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeWP( + quant_row, stride, c, x, y, r.xsize(), &wp_state, + row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor); + wp_state.UpdateErrors(quant_row[x], x, y, r.xsize()); + } + } + } + } + } else if (enc_state->shared.frame_header.chroma_subsampling.Is444()) { + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = roundf(row[x] * inv_factor); + } + } else { + int32_t* quant_row_y = + stream_images[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = + roundf((row[x] - quant_row_y[x] * (y_factor * cfl_factor)) * + inv_factor); + } + } + } + } + } else { + for (size_t c : {1, 0, 2}) { + Rect rect( + r.x0() >> enc_state->shared.frame_header.chroma_subsampling.HShift(c), + r.y0() >> enc_state->shared.frame_header.chroma_subsampling.VShift(c), + r.xsize() >> + enc_state->shared.frame_header.chroma_subsampling.HShift(c), + r.ysize() >> + enc_state->shared.frame_header.chroma_subsampling.VShift(c)); + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + size_t ys = rect.ysize(); + size_t xs = rect.xsize(); + Channel& ch = stream_images[stream_id].channel[c < 2 ? 
c ^ 1 : c]; + ch.w = xs; + ch.h = ys; + ch.resize(); + for (size_t y = 0; y < ys; y++) { + int32_t* quant_row = ch.plane.Row(y); + const float* row = rect.ConstPlaneRow(dc, c, y); + for (size_t x = 0; x < xs; x++) { + quant_row[x] = roundf(row[x] * inv_factor); + } + } + } + } + + DequantDC(r, &enc_state->shared.dc_storage, &enc_state->shared.quant_dc, + stream_images[stream_id], enc_state->shared.quantizer.MulDC(), + 1.0 / mul, enc_state->shared.cmap.DCFactors(), + enc_state->shared.frame_header.chroma_subsampling, + enc_state->shared.block_ctx_map); +} + +void ModularFrameEncoder::AddACMetadata(size_t group_index, bool jpeg_transcode, + PassesEncoderState* enc_state) { + const Rect r = enc_state->shared.DCGroupRect(group_index); + size_t stream_id = ModularStreamId::ACMetadata(group_index).ID(frame_dim); + stream_options[stream_id].max_chan_size = 0xFFFFFF; + stream_options[stream_id].wp_tree_mode = ModularOptions::WPTreeMode::kNoWP; + if (jpeg_transcode) { + stream_options[stream_id].tree_kind = + ModularOptions::TreeKind::kJpegTranscodeACMeta; + } else if (cparams.speed_tier == SpeedTier::kFalcon) { + stream_options[stream_id].tree_kind = + ModularOptions::TreeKind::kFalconACMeta; + } else if (cparams.speed_tier > SpeedTier::kKitten) { + stream_options[stream_id].tree_kind = ModularOptions::TreeKind::kACMeta; + } + // If we are using a non-constant CfL field, and are in a slow enough mode, + // re-enable tree computation for it. + if (cparams.speed_tier < SpeedTier::kSquirrel && + cparams.force_cfl_jpeg_recompression) { + stream_options[stream_id].tree_kind = ModularOptions::TreeKind::kLearn; + } + // YToX, YToB, ACS + QF, EPF + Image& image = stream_images[stream_id]; + image = Image(r.xsize(), r.ysize(), 255, 4); + static_assert(kColorTileDimInBlocks == 8, "Color tile size changed"); + Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3); + image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[2] = Channel(r.xsize() * r.ysize(), 2, 0, 0); + ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytox_map, + Rect(image.channel[0].plane), &image.channel[0].plane); + ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytob_map, + Rect(image.channel[1].plane), &image.channel[1].plane); + size_t num = 0; + for (size_t y = 0; y < r.ysize(); y++) { + AcStrategyRow row_acs = enc_state->shared.ac_strategy.ConstRow(r, y); + const int* row_qf = r.ConstRow(enc_state->shared.raw_quant_field, y); + const uint8_t* row_epf = r.ConstRow(enc_state->shared.epf_sharpness, y); + int* out_acs = image.channel[2].plane.Row(0); + int* out_qf = image.channel[2].plane.Row(1); + int* row_out_epf = image.channel[3].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + row_out_epf[x] = row_epf[x]; + if (!row_acs[x].IsFirstBlock()) continue; + out_acs[num] = row_acs[x].RawStrategy(); + out_qf[num] = row_qf[x] - 1; + num++; + } + } + image.channel[2].w = num; + image.channel[2].resize(); + ac_metadata_size[group_index] = num; +} + +void ModularFrameEncoder::EncodeQuantTable( + size_t size_x, size_t size_y, BitWriter* writer, + const QuantEncoding& encoding, size_t idx, + ModularFrameEncoder* modular_frame_encoder) { + JXL_ASSERT(encoding.qraw.qtable != nullptr); + JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size()); + JXL_CHECK(F16Coder::Write(encoding.qraw.qtable_den, writer)); + if (modular_frame_encoder) { + JXL_CHECK(modular_frame_encoder->EncodeStream( + writer, nullptr, 0, 
ModularStreamId::QuantTable(idx))); + return; + } + Image image(size_x, size_y, 255, 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < size_y; y++) { + int* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < size_x; x++) { + row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x]; + } + } + } + ModularOptions cfopts; + JXL_CHECK(ModularGenericCompress(image, cfopts, writer)); +} + +void ModularFrameEncoder::AddQuantTable(size_t size_x, size_t size_y, + const QuantEncoding& encoding, + size_t idx) { + size_t stream_id = ModularStreamId::QuantTable(idx).ID(frame_dim); + JXL_ASSERT(encoding.qraw.qtable != nullptr); + JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size()); + Image& image = stream_images[stream_id]; + image = Image(size_x, size_y, 255, 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < size_y; y++) { + int* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < size_x; x++) { + row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x]; + } + } + } +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_modular.h b/third_party/jpeg-xl/lib/jxl/enc_modular.h new file mode 100644 index 000000000000..e447d26fa547 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_modular.h @@ -0,0 +1,99 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENC_MODULAR_H_ +#define LIB_JXL_ENC_MODULAR_H_ + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +class ModularFrameEncoder { + public: + ModularFrameEncoder(const FrameHeader& frame_header, + const CompressParams& cparams_orig); + Status ComputeEncodingData(const FrameHeader& frame_header, + const ImageBundle& ib, Image3F* JXL_RESTRICT color, + PassesEncoderState* JXL_RESTRICT enc_state, + ThreadPool* pool, AuxOut* aux_out, bool do_color); + // Encodes global info (tree + histograms) in the `writer`. + Status EncodeGlobalInfo(BitWriter* writer, AuxOut* aux_out); + // Encodes a specific modular image (identified by `stream`) in the `writer`, + // assigning bits to the provided `layer`. + Status EncodeStream(BitWriter* writer, AuxOut* aux_out, size_t layer, + const ModularStreamId& stream); + // Creates a modular image for a given DC group of VarDCT mode. `dc` is the + // input DC image, not quantized; the group is specified by `group_index`, and + // `nl_dc` decides whether to apply a near-lossless processing to the DC or + // not. 
+ void AddVarDCTDC(const Image3F& dc, size_t group_index, bool nl_dc, + PassesEncoderState* enc_state); + // Creates a modular image for the AC metadata of the given group + // (`group_index`). + void AddACMetadata(size_t group_index, bool jpeg_transcode, + PassesEncoderState* enc_state); + // Encodes a RAW quantization table in `writer`. If `modular_frame_encoder` is + // null, the quantization table in `encoding` is used, with dimensions `size_x + // x size_y`. Otherwise, the table with ID `idx` is encoded from the given + // `modular_frame_encoder`. + static void EncodeQuantTable(size_t size_x, size_t size_y, BitWriter* writer, + const QuantEncoding& encoding, size_t idx, + ModularFrameEncoder* modular_frame_encoder); + // Stores a quantization table for future usage with `EncodeQuantTable`. + void AddQuantTable(size_t size_x, size_t size_y, + const QuantEncoding& encoding, size_t idx); + + std::vector ac_metadata_size; + std::vector extra_dc_precision; + + private: + Status PrepareEncoding(ThreadPool* pool, const FrameDimensions& frame_dim, + EncoderHeuristics* heuristics, + AuxOut* aux_out = nullptr); + Status PrepareStreamParams(const Rect& rect, const CompressParams& cparams, + int minShift, int maxShift, + const ModularStreamId& stream, bool do_color); + std::vector stream_images; + std::vector stream_options; + + Tree tree; + std::vector> tree_tokens; + std::vector stream_headers; + std::vector> tokens; + EntropyEncodingData code; + std::vector context_map; + FrameDimensions frame_dim; + CompressParams cparams; + float quality = cparams.quality_pair.first; + float cquality = cparams.quality_pair.second; + std::vector tree_splits; + std::vector multiplier_info; + std::vector> gi_channel; + std::vector image_widths; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_MODULAR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_noise.cc b/third_party/jpeg-xl/lib/jxl/enc_noise.cc new file mode 100644 index 000000000000..7eccda623808 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_noise.cc @@ -0,0 +1,424 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
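// Illustrative usage sketch (not from the patch): the call order implied by
// the ModularFrameEncoder declaration in enc_modular.h above and its
// comments. The setup objects (frame_header, cparams, ib, color, enc_state,
// pool, writer, layer) are assumed to exist; error handling is elided.
//
//   ModularFrameEncoder enc(frame_header, cparams);
//   JXL_RETURN_IF_ERROR(enc.ComputeEncodingData(frame_header, ib, &color,
//                                               &enc_state, pool, aux_out,
//                                               /*do_color=*/true));
//   // Tree + histograms first, then the individual modular streams:
//   JXL_RETURN_IF_ERROR(enc.EncodeGlobalInfo(writer, aux_out));
//   JXL_RETURN_IF_ERROR(enc.EncodeStream(writer, aux_out, layer,
//                                        ModularStreamId::ModularDC(0)));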
+
+#include "lib/jxl/enc_noise.h"
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/robust_statistics.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/optimize.h"
+
+namespace jxl {
+namespace {
+
+using OptimizeArray = optimize::Array<double, NoiseParams::kNumNoisePoints>;
+
+float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x,
+                                        const int y, const int block_size) {
+  const int small_bl_size_x = 3;
+  const int small_bl_size_y = 4;
+  const int kNumSAD =
+      (block_size - small_bl_size_x) * (block_size - small_bl_size_y);
+  // block_size x block_size reference pixels
+  int counter = 0;
+  const int offset = 2;
+
+  std::vector<float> sad(kNumSAD, 0);
+  for (int y_bl = 0; y_bl + small_bl_size_y < block_size; ++y_bl) {
+    for (int x_bl = 0; x_bl + small_bl_size_x < block_size; ++x_bl) {
+      float sad_sum = 0;
+      // Size of the center patch; we compare all the patches inside the
+      // window with the center one.
+      for (int cy = 0; cy < small_bl_size_y; ++cy) {
+        for (int cx = 0; cx < small_bl_size_x; ++cx) {
+          float wnd = 0.5f * (opsin.PlaneRow(1, y + y_bl + cy)[x + x_bl + cx] +
+                              opsin.PlaneRow(0, y + y_bl + cy)[x + x_bl + cx]);
+          float center =
+              0.5f * (opsin.PlaneRow(1, y + offset + cy)[x + offset + cx] +
+                      opsin.PlaneRow(0, y + offset + cy)[x + offset + cx]);
+          sad_sum += std::abs(center - wnd);
+        }
+      }
+      sad[counter++] = sad_sum;
+    }
+  }
+  const int kSamples = (kNumSAD) / 2;
+  // As with ROAD (rank order absolute distance), we keep the smallest half of
+  // the values in SAD (we use here the more robust patch SAD instead of
+  // absolute single-pixel differences).
+  std::sort(sad.begin(), sad.end());
+  const float total_sad_sum =
+      std::accumulate(sad.begin(), sad.begin() + kSamples, 0.0f);
+  return total_sad_sum / kSamples;
+}
+
+class NoiseHistogram {
+ public:
+  static constexpr int kBins = 256;
+
+  NoiseHistogram() { std::fill(bins, bins + kBins, 0); }
+
+  void Increment(const float x) { bins[Index(x)] += 1; }
+  int Get(const float x) const { return bins[Index(x)]; }
+  int Bin(const size_t bin) const { return bins[bin]; }
+
+  void Print() const {
+    for (unsigned int bin : bins) {
+      printf("%d\n", bin);
+    }
+  }
+
+  int Mode() const {
+    uint32_t cdf[kBins];
+    std::partial_sum(bins, bins + kBins, cdf);
+    return HalfRangeMode()(cdf, kBins);
+  }
+
+  double Quantile(double q01) const {
+    const int64_t total = std::accumulate(bins, bins + kBins, int64_t{1});
+    const int64_t target = static_cast<int64_t>(q01 * total);
+    // Until sum >= target:
+    int64_t sum = 0;
+    size_t i = 0;
+    for (; i < kBins; ++i) {
+      sum += bins[i];
+      // Exact match: assume middle of bin i
+      if (sum == target) {
+        return i + 0.5;
+      }
+      if (sum > target) break;
+    }
+
+    // Next non-empty bin (in case histogram is sparsely filled)
+    size_t next = i + 1;
+    while (next < kBins && bins[next] == 0) {
+      ++next;
+    }
+
+    // Linear interpolation according to how far into next we went
+    const double excess = target - sum;
+    const double weight_next = bins[Index(next)] / excess;
+    return ClampX(next * weight_next + i * (1.0 - weight_next));
+  }
+
+  // Inter-quartile range
+  double IQR() const { return Quantile(0.75) - Quantile(0.25); }
+
+ private:
+  template <typename T>
+  T ClampX(const T x) const {
+    return std::min(std::max(T(0), x), T(kBins - 1));
+  }
+  size_t Index(const float x) const { return ClampX(static_cast<int>(x)); }
+
+  uint32_t bins[kBins];
+};
+
+std::vector<float> GetSADScoresForPatches(const Image3F& opsin,
+                                          const size_t block_s,
+                                          const size_t num_bin,
+                                          NoiseHistogram* sad_histogram) {
+  std::vector<float> sad_scores(
+      (opsin.ysize() / block_s) * (opsin.xsize() / block_s), 0.0f);
+
+  int block_index = 0;
+
+  for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
+    for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
+      float sad_sc = GetScoreSumsOfAbsoluteDifferences(opsin, x, y, block_s);
+      sad_scores[block_index++] = sad_sc;
+      sad_histogram->Increment(sad_sc * num_bin);
+    }
+  }
+  return sad_scores;
+}
+
+float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) {
+  // Here we assume that the bin containing the most patches with similar SAD
+  // values corresponds to "flat" patches. However, some images contain
+  // regular textured regions, which can generate a second strong peak in the
+  // histogram.
+  // TODO(user) handle bimodal and heavy-tailed case
+  const int mode = histogram.Mode();
+  return static_cast<float>(mode) / NoiseHistogram::kBins;
+}
+
+// loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2
+// where asym = 1 if F(x) < nl, kAsym if F(x) > nl.
+struct LossFunction {
+  explicit LossFunction(std::vector<NoiseLevel> nl0) : nl(std::move(nl0)) {}
+
+  double Compute(const OptimizeArray& w, OptimizeArray* df,
+                 bool skip_regularization = false) const {
+    constexpr double kReg = 0.005;
+    constexpr double kAsym = 1.1;
+    double loss_function = 0;
+    for (size_t i = 0; i < w.size(); i++) {
+      (*df)[i] = 0;
+    }
+    for (auto ind : nl) {
+      std::pair<int, float> pos = IndexAndFrac(ind.intensity);
+      JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) <
+                                        NoiseParams::kNumNoisePoints - 1);
+      double low = w[pos.first];
+      double hi = w[pos.first + 1];
+      double val = low * (1.0f - pos.second) + hi * pos.second;
+      double dist = val - ind.noise_level;
+      if (dist > 0) {
+        loss_function += kAsym * dist * dist;
+        (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist;
+        (*df)[pos.first + 1] -= kAsym * pos.second * dist;
+      } else {
+        loss_function += dist * dist;
+        (*df)[pos.first] -= (1.0f - pos.second) * dist;
+        (*df)[pos.first + 1] -= pos.second * dist;
+      }
+    }
+    if (skip_regularization) return loss_function;
+    for (size_t i = 0; i + 1 < w.size(); i++) {
+      double diff = w[i] - w[i + 1];
+      loss_function += kReg * nl.size() * diff * diff;
+      (*df)[i] -= kReg * diff * nl.size();
+      (*df)[i + 1] += kReg * diff * nl.size();
+    }
+    return loss_function;
+  }
+
+  std::vector<NoiseLevel> nl;
+};
+
+void OptimizeNoiseParameters(const std::vector<NoiseLevel>& noise_level,
+                             NoiseParams* noise_params) {
+  constexpr double kMaxError = 1e-3;
+  static const double kPrecision = 1e-8;
+  static const int kMaxIter = 40;
+
+  float avg = 0;
+  for (const NoiseLevel& nl : noise_level) {
+    avg += nl.noise_level;
+  }
+  avg /= noise_level.size();
+
+  LossFunction loss_function(noise_level);
+  OptimizeArray parameter_vector;
+  for (size_t i = 0; i < parameter_vector.size(); i++) {
+    parameter_vector[i] = avg;
+  }
+
+  parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod(
+      loss_function, parameter_vector, kPrecision, kMaxIter);
+
+  OptimizeArray df = parameter_vector;
+  float loss = loss_function.Compute(parameter_vector, &df,
+                                     /*skip_regularization=*/true) /
+               noise_level.size();
+
+  // Approximation went too badly: escape with no noise at all.
+  if (loss > kMaxError) {
+    noise_params->Clear();
+    return;
+  }
+
+  for (size_t i = 0; i < parameter_vector.size(); i++) {
+    noise_params->lut[i] = std::max(parameter_vector[i], 0.0);
+  }
+}
+
+std::vector<float> GetTextureStrength(const Image3F& opsin,
+                                      const size_t block_s) {
+  std::vector<float> texture_strength_index((opsin.ysize() / block_s) *
+                                            (opsin.xsize() / block_s));
+  size_t block_index = 0;
+
+  for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
+    for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
+      float texture_strength = 0;
+      for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
+        for (size_t x_bl = 0; x_bl + 1 < block_s; ++x_bl) {
+          float diff = opsin.PlaneRow(1, y)[x + x_bl + 1] -
+                       opsin.PlaneRow(1, y)[x + x_bl];
+          texture_strength += diff * diff;
+        }
+      }
+      for (size_t y_bl = 0; y_bl + 1 < block_s; ++y_bl) {
+        for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
+          float diff = opsin.PlaneRow(1, y + 1)[x + x_bl] -
+                       opsin.PlaneRow(1, y)[x + x_bl];
+          texture_strength += diff * diff;
+        }
+      }
+      texture_strength_index[block_index] = texture_strength;
+      ++block_index;
+    }
+  }
+  return texture_strength_index;
+}
+
+float GetThresholdFlatIndices(const std::vector<float>& texture_strength,
+                              const int n_patches) {
+  std::vector<float> kth_statistic = texture_strength;
+  std::stable_sort(kth_statistic.begin(), kth_statistic.end());
+  return kth_statistic[n_patches];
+}
+
+std::vector<NoiseLevel> GetNoiseLevel(
+    const Image3F& opsin, const std::vector<float>& texture_strength,
+    const float threshold, const size_t block_s) {
+  std::vector<NoiseLevel> noise_level_per_intensity;
+
+  const int filt_size = 1;
+  static const float kLaplFilter[filt_size * 2 + 1][filt_size * 2 + 1] = {
+      {-0.25f, -1.0f, -0.25f},
+      {-1.0f, 5.0f, -1.0f},
+      {-0.25f, -1.0f, -0.25f},
+  };
+
+  // The noise model is built based on channel 0.5 * (X+Y) as we notice that it
+  // is similar to the model 0.5 * (Y-X)
+  size_t patch_index = 0;
+
+  for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
+    for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
+      if (texture_strength[patch_index] <= threshold) {
+        // Calculate mean value
+        float mean_int = 0;
+        for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
+          for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
+            mean_int += 0.5f * (opsin.PlaneRow(1, y + y_bl)[x + x_bl] +
+                                opsin.PlaneRow(0, y + y_bl)[x + x_bl]);
+          }
+        }
+        mean_int /= block_s * block_s;
+
+        // Calculate Noise level
+        float noise_level = 0;
+        size_t count = 0;
+        for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
+          for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
+            float filtered_value = 0;
+            for (int y_f = -1 * filt_size; y_f <= filt_size; ++y_f) {
+              if ((static_cast<ssize_t>(y_bl) + y_f) >= 0 &&
+                  (y_bl + y_f) < block_s) {
+                for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
+                  if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
+                      (x_bl + x_f) < block_s) {
+                    filtered_value +=
+                        0.5f *
+                        (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl + x_f] +
+                         opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl + x_f]) *
+                        kLaplFilter[y_f + filt_size][x_f + filt_size];
+                  } else {
+                    filtered_value +=
+                        0.5f *
+                        (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl - x_f] +
+                         opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl - x_f]) *
+                        kLaplFilter[y_f + filt_size][x_f + filt_size];
+                  }
+                }
+              } else {
+                for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
+                  if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
+                      (x_bl + x_f) < block_s) {
+                    filtered_value +=
+                        0.5f *
+                        (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl + x_f] +
+                         opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl + x_f]) *
+                        kLaplFilter[y_f + filt_size][x_f + filt_size];
+                  } else {
+                    filtered_value +=
+                        0.5f *
+                        (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl - x_f] +
+                         opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl - x_f]) *
+                        kLaplFilter[y_f + filt_size][x_f + filt_size];
+                  }
+                }
+              }
+            }
+            noise_level += std::abs(filtered_value);
+            ++count;
+          }
+        }
+        noise_level /= count;
+        NoiseLevel nl;
+        nl.intensity = mean_int;
+        nl.noise_level = noise_level;
+        noise_level_per_intensity.push_back(nl);
+      }
+      ++patch_index;
+    }
+  }
+  return noise_level_per_intensity;
+}
+
+void EncodeFloatParam(float val, float precision, BitWriter* writer) {
+  JXL_ASSERT(val >= 0);
+  const int absval_quant = static_cast<int>(val * precision + 0.5f);
+  JXL_ASSERT(absval_quant < (1 << 10));
+  writer->Write(10, absval_quant);
+}
+
+}  // namespace
+
+Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params,
+                         float quality_coef) {
+  // The size of a patch in decoder might be different from encoder's patch
+  // size.
+  // For encoder: the patch size should be big enough to estimate
+  //              noise level, but, at the same time, it should be not too big
+  //              to be able to estimate intensity value of the patch
+  const size_t block_s = 8;
+  const size_t kNumBin = 256;
+  NoiseHistogram sad_histogram;
+  std::vector<float> sad_scores =
+      GetSADScoresForPatches(opsin, block_s, kNumBin, &sad_histogram);
+  float sad_threshold = GetSADThreshold(sad_histogram, kNumBin);
+  // If threshold is too large, the image has a strong pattern. This pattern
+  // fools our model and it will add too much noise. Therefore, we do not add
+  // noise for such images
+  if (sad_threshold > 0.15f || sad_threshold <= 0.0f) {
+    noise_params->Clear();
+    return false;
+  }
+  std::vector<NoiseLevel> nl =
+      GetNoiseLevel(opsin, sad_scores, sad_threshold, block_s);
+
+  OptimizeNoiseParameters(nl, noise_params);
+  for (float& i : noise_params->lut) {
+    i *= quality_coef * 1.4;
+  }
+  return noise_params->HasAny();
+}
+
+void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer,
+                 size_t layer, AuxOut* aux_out) {
+  JXL_ASSERT(noise_params.HasAny());
+
+  BitWriter::Allotment allotment(writer, NoiseParams::kNumNoisePoints * 16);
+  for (float i : noise_params.lut) {
+    EncodeFloatParam(i, kNoisePrecision, writer);
+  }
+  ReclaimAndCharge(writer, &allotment, layer, aux_out);
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_noise.h b/third_party/jpeg-xl/lib/jxl/enc_noise.h
new file mode 100644
index 000000000000..9471478894f0
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_noise.h
@@ -0,0 +1,42 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_NOISE_H_
+#define LIB_JXL_ENC_NOISE_H_
+
+// Noise parameter estimation.
+
+#include <stddef.h>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/noise.h"
+
+namespace jxl {
+
+// Get parameters of the noise for NoiseParams model
+// Returns whether a valid noise model (with HasAny()) is set.
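+//
+// A minimal usage sketch, assuming the two declarations below; `opsin`,
+// `writer`, `layer` and `aux_out` are placeholder names, not upstream
+// documentation:
+//
+//   NoiseParams noise_params;
+//   if (GetNoiseParameter(opsin, &noise_params, /*quality_coef=*/1.0f)) {
+//     EncodeNoise(noise_params, writer, layer, aux_out);
+//   }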
+Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params,
+                         float quality_coef);
+
+// Does not write anything if `noise_params` are empty. Otherwise, caller must
+// set FrameHeader.flags.kNoise.
+void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer,
+                 size_t layer, AuxOut* aux_out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_NOISE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_params.h b/third_party/jpeg-xl/lib/jxl/enc_params.h
new file mode 100644
index 000000000000..a7c168e908de
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_params.h
@@ -0,0 +1,255 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_PARAMS_H_
+#define LIB_JXL_ENC_PARAMS_H_
+
+// Parameters and flags that govern JXL compression.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/butteraugli/butteraugli.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+enum class SpeedTier {
+  // Turns on FindBestQuantizationHQ loop. Equivalent to "guetzli" mode.
+  kTortoise = 1,
+  // Turns on FindBestQuantization butteraugli loop.
+  kKitten = 2,
+  // Turns on dots, patches, and spline detection by default, as well as full
+  // context clustering. Default.
+  kSquirrel = 3,
+  // Turns on error diffusion and full AC strategy heuristics. Equivalent to
+  // "fast" mode.
+  kWombat = 4,
+  // Turns on gaborish by default, non-default cmap, initial quant field.
+  kHare = 5,
+  // Turns on simple heuristics for AC strategy, quant field, and clustering;
+  // also enables coefficient reordering.
+  kCheetah = 6,
+  // Turns off most encoder features, for the fastest possible encoding time.
+  kFalcon = 7,
+};
+
+inline bool ParseSpeedTier(const std::string& s, SpeedTier* out) {
+  if (s == "falcon") {
+    *out = SpeedTier::kFalcon;
+    return true;
+  } else if (s == "cheetah") {
+    *out = SpeedTier::kCheetah;
+    return true;
+  } else if (s == "hare") {
+    *out = SpeedTier::kHare;
+    return true;
+  } else if (s == "fast" || s == "wombat") {
+    *out = SpeedTier::kWombat;
+    return true;
+  } else if (s == "squirrel") {
+    *out = SpeedTier::kSquirrel;
+    return true;
+  } else if (s == "kitten") {
+    *out = SpeedTier::kKitten;
+    return true;
+  } else if (s == "guetzli" || s == "tortoise") {
+    *out = SpeedTier::kTortoise;
+    return true;
+  }
+  size_t st = 10 - static_cast<size_t>(strtoull(s.c_str(), nullptr, 0));
+  if (st <= static_cast<size_t>(SpeedTier::kFalcon) &&
+      st >= static_cast<size_t>(SpeedTier::kTortoise)) {
+    *out = SpeedTier(st);
+    return true;
+  }
+  return false;
+}
+
+inline const char* SpeedTierName(SpeedTier speed_tier) {
+  switch (speed_tier) {
+    case SpeedTier::kFalcon:
+      return "falcon";
+    case SpeedTier::kCheetah:
+      return "cheetah";
+    case SpeedTier::kHare:
+      return "hare";
+    case SpeedTier::kWombat:
+      return "wombat";
+    case SpeedTier::kSquirrel:
+      return "squirrel";
+    case SpeedTier::kKitten:
+      return "kitten";
+    case SpeedTier::kTortoise:
+      return "tortoise";
+  }
+  return "INVALID";
+}
+
+// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding)
+struct CompressParams {
+  float butteraugli_distance = 1.0f;
+  size_t target_size = 0;
+  float target_bitrate = 0.0f;
+
+  // 0.0 means search for the adaptive quantization map that matches the
+  // butteraugli distance, positive values mean quantize everywhere with that
+  // value.
+  float uniform_quant = 0.0f;
+  float quant_border_bias = 0.0f;
+
+  // Try to achieve a maximum pixel-by-pixel error on each channel.
+  bool max_error_mode = false;
+  float max_error[3] = {0.0, 0.0, 0.0};
+
+  SpeedTier speed_tier = SpeedTier::kSquirrel;
+
+  // 0 = default.
+  // 1 = slightly worse quality.
+  // 4 = fastest speed, lowest quality
+  // TODO(veluca): hook this up to the C API.
+  size_t decoding_speed_tier = 0;
+
+  int max_butteraugli_iters = 4;
+
+  int max_butteraugli_iters_guetzli_mode = 100;
+
+  ColorTransform color_transform = ColorTransform::kXYB;
+  YCbCrChromaSubsampling chroma_subsampling;
+
+  // If true, the "modular mode options" members below are used.
+  bool modular_mode = false;
+
+  // Change group size in modular mode (0=128, 1=256, 2=512, 3=1024).
+  size_t modular_group_size_shift = 1;
+
+  Override preview = Override::kDefault;
+  Override noise = Override::kDefault;
+  Override dots = Override::kDefault;
+  Override patches = Override::kDefault;
+  Override gaborish = Override::kDefault;
+  int epf = -1;
+
+  // TODO(deymo): Remove "gradient" once all clients stop setting this value.
+  // This flag is already deprecated and is unused in the encoder.
+  Override gradient = Override::kOff;
+
+  // Progressive mode.
+  bool progressive_mode = false;
+
+  // Quantized-progressive mode.
+  bool qprogressive_mode = false;
+
+  // Put center groups first in the bitstream.
+  bool middleout = false;
+
+  int progressive_dc = -1;
+
+  // Ensure invisible pixels are not set to 0.
+  bool keep_invisible = false;
+
+  // Progressive-mode saliency.
+  //
+  // How many progressive saliency-encoding steps to perform.
+  // - 1: Encode only DC and lowest-frequency AC. Does not need a saliency-map.
+  // - 2: Encode only DC+LF, dropping all HF AC data.
+  //      Does not need a saliency-map.
+  // - 3: Encode DC+LF+{salient HF}, dropping all non-salient HF data.
+  // - 4: Encode DC+LF+{salient HF}+{other HF}.
+  // - 5: Encode DC+LF+{quantized HF}+{low HF bits}.
+  size_t saliency_num_progressive_steps = 3;
+  // Every saliency-heatmap cell with saliency >= threshold will be considered
+  // as 'salient'. The default value of 0.0 will consider every AC-block
+  // as salient, hence not require a saliency-map, and not actually generate
+  // a 4th progressive step.
+  float saliency_threshold = 0.0f;
+  // Saliency-map (owned by caller).
+  ImageF* saliency_map = nullptr;
+
+  // Input and output file name. Will be used to provide pluggable saliency
+  // extractor with paths.
+  const char* file_in = nullptr;
+  const char* file_out = nullptr;
+
+  // Currently unused as of 2020-01.
+  bool clear_metadata = false;
+
+  // Prints extra information during/after encoding.
+  bool verbose = false;
+
+  ButteraugliParams ba_params;
+
+  // Force usage of CfL when doing JPEG recompression. This can have unexpected
+  // effects on the decoded pixels, while still being JPEG-compliant and
+  // allowing reconstruction of the original JPEG.
+  bool force_cfl_jpeg_recompression = true;
+
+  // modular mode options below
+  ModularOptions options;
+  int responsive = -1;
+  // A pair of <quality, cquality>.
+  std::pair<float, float> quality_pair{100.f, 100.f};
+  int colorspace = -1;
+  // Use Global channel palette if #colors < this percentage of range
+  float channel_colors_pre_transform_percent = 95.f;
+  // Use Local channel palette if #colors < this percentage of range
+  float channel_colors_percent = 80.f;
+  int near_lossless = 0;
+  int palette_colors = 1 << 10;  // up to 10-bit palette is probably worthwhile
+  bool lossy_palette = false;
+
+  // Returns whether these params are lossless as defined by SetLossless();
+  bool IsLossless() const {
+    return modular_mode && quality_pair.first == 100 &&
+           quality_pair.second == 100 &&
+           color_transform == jxl::ColorTransform::kNone;
+  }
+
+  // Sets the parameters required to make the codec lossless.
+  void SetLossless() {
+    modular_mode = true;
+    quality_pair.first = 100;
+    quality_pair.second = 100;
+    color_transform = jxl::ColorTransform::kNone;
+  }
+
+  bool use_new_heuristics = false;
+
+  // Down/upsample the image before encoding / after decoding by this factor.
+  size_t resampling = 1;
+};
+
+static constexpr float kMinButteraugliForDynamicAR = 0.5f;
+static constexpr float kMinButteraugliForDots = 3.0f;
+static constexpr float kMinButteraugliToSubtractOriginalPatches = 3.0f;
+static constexpr float kMinButteraugliDistanceForProgressiveDc = 4.5f;
+
+// Always off
+static constexpr float kMinButteraugliForNoise = 99.0f;
+
+// Minimum butteraugli distance the encoder accepts.
+static constexpr float kMinButteraugliDistance = 0.01f;
+
+// Tile size for encoder-side processing. Must be equal to color tile dim in
+// the current implementation.
+static constexpr size_t kEncTileDim = 64;
+static constexpr size_t kEncTileDimInBlocks = kEncTileDim / kBlockDim;
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_PARAMS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.cc b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.cc
new file mode 100644
index 000000000000..83ded9f97e69
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.cc
@@ -0,0 +1,792 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_patch_dictionary.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <sys/types.h>
+
+#include <algorithm>
+#include <atomic>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/dec_frame.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_dot_dictionary.h"
+#include "lib/jxl/enc_frame.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/patch_dictionary_internal.h"
+
+namespace jxl {
+
+// static
+void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic,
+                                    BitWriter* writer, size_t layer,
+                                    AuxOut* aux_out) {
+  JXL_ASSERT(pdic.HasAny());
+  std::vector<std::vector<Token>> tokens(1);
+
+  auto add_num = [&](int context, size_t num) {
+    tokens[0].emplace_back(context, num);
+  };
+  size_t num_ref_patch = 0;
+  for (size_t i = 0; i < pdic.positions_.size();) {
+    size_t i_start = i;
+    while (i < pdic.positions_.size() &&
+           pdic.positions_[i].ref_pos == pdic.positions_[i_start].ref_pos) {
+      i++;
+    }
+    num_ref_patch++;
+  }
+  add_num(kNumRefPatchContext, num_ref_patch);
+  for (size_t i = 0; i < pdic.positions_.size();) {
+    size_t i_start = i;
+    while (i < pdic.positions_.size() &&
+           pdic.positions_[i].ref_pos == pdic.positions_[i_start].ref_pos) {
+      i++;
+    }
+    size_t num = i - i_start;
+    JXL_ASSERT(num > 0);
+    add_num(kReferenceFrameContext, pdic.positions_[i_start].ref_pos.ref);
+    add_num(kPatchReferencePositionContext,
+            pdic.positions_[i_start].ref_pos.x0);
+    add_num(kPatchReferencePositionContext,
+            pdic.positions_[i_start].ref_pos.y0);
+    add_num(kPatchSizeContext, pdic.positions_[i_start].ref_pos.xsize - 1);
+    add_num(kPatchSizeContext, pdic.positions_[i_start].ref_pos.ysize - 1);
+    add_num(kPatchCountContext, num - 1);
+    for (size_t j = i_start; j < i; j++) {
+      const PatchPosition& pos = pdic.positions_[j];
+      if (j == i_start) {
+        add_num(kPatchPositionContext, pos.x);
+        add_num(kPatchPositionContext, pos.y);
+      } else {
+        add_num(kPatchOffsetContext,
+                PackSigned(pos.x - pdic.positions_[j - 1].x));
+        add_num(kPatchOffsetContext,
+                PackSigned(pos.y - pdic.positions_[j - 1].y));
+      }
+      JXL_ASSERT(pdic.shared_->metadata->m.extra_channel_info.size() + 1 ==
+                 pos.blending.size());
+      for (size_t i = 0;
+           i < pdic.shared_->metadata->m.extra_channel_info.size() + 1; i++) {
+        const PatchBlending& info = pos.blending[i];
+        add_num(kPatchBlendModeContext, static_cast<uint32_t>(info.mode));
+        if (UsesAlpha(info.mode) &&
+            pdic.shared_->metadata->m.extra_channel_info.size() > 1) {
+          add_num(kPatchAlphaChannelContext, info.alpha_channel);
+        }
+        if (UsesClamp(info.mode)) {
+          add_num(kPatchClampContext, info.clamp);
+        }
+      }
+    }
+  }
+
+  EntropyEncodingData codes;
+  std::vector<uint8_t> context_map;
+  BuildAndEncodeHistograms(HistogramParams(), kNumPatchDictionaryContexts,
+                           tokens, &codes, &context_map, writer, layer,
+                           aux_out);
+  WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out);
+}
+
+// static
+void PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic,
+                                          Image3F* opsin) {
+  pdic.Apply(opsin, Rect(*opsin), Rect(*opsin));
+}
+
+namespace {
+
+struct PatchColorspaceInfo {
+  float kChannelDequant[3];
+  float kChannelWeights[3];
+
+  explicit PatchColorspaceInfo(bool is_xyb) {
+    if (is_xyb) {
+      kChannelDequant[0] = 0.01615;
+      kChannelDequant[1] = 0.08875;
+      kChannelDequant[2] = 0.1922;
+      kChannelWeights[0] = 30.0;
+      kChannelWeights[1] = 3.0;
+      kChannelWeights[2] = 1.0;
+    } else {
+      kChannelDequant[0] = 20;
+      kChannelDequant[1] = 22;
+      kChannelDequant[2] = 20;
+      kChannelWeights[0] = 0.017;
+      kChannelWeights[1] = 0.02;
+      kChannelWeights[2] = 0.017;
+    }
+  }
+
+  float ScaleForQuantization(float val, size_t c) {
+    return val / kChannelDequant[c];
+  }
+
+  int Quantize(float val, size_t c) {
+    return truncf(ScaleForQuantization(val, c));
+  }
+
+  bool is_similar_v(const float v1[3], const float v2[3], float threshold) {
+    float distance = 0;
+    for (size_t c = 0; c < 3; c++) {
+      distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c];
+    }
+    return distance <= threshold;
+  }
+};
+
+std::vector<PatchInfo> FindTextLikePatches(
+    const Image3F& opsin, const PassesEncoderState* JXL_RESTRICT state,
+    ThreadPool* pool, AuxOut* aux_out, bool is_xyb) {
+  if (state->cparams.patches == Override::kOff) return {};
+
+  PatchColorspaceInfo pci(is_xyb);
+  float kSimilarThreshold = 0.8f;
+
+  auto is_similar_impl = [&pci](std::pair<size_t, size_t> p1,
+                                std::pair<size_t, size_t> p2,
+                                const float* JXL_RESTRICT rows[3],
+                                size_t stride, float threshold) {
+    float v1[3], v2[3];
+    for (size_t c = 0; c < 3; c++) {
+      v1[c] = rows[c][p1.second * stride + p1.first];
+      v2[c] = rows[c][p2.second * stride + p2.first];
+    }
+    return pci.is_similar_v(v1, v2, threshold);
+  };
+
+  std::atomic<bool> has_screenshot_areas{false};
+  const size_t opsin_stride = opsin.PixelsPerRow();
+  const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0),
+                                             opsin.ConstPlaneRow(1, 0),
+                                             opsin.ConstPlaneRow(2, 0)};
+
+  auto is_same = [&opsin_rows, opsin_stride](std::pair<size_t, size_t> p1,
+                                             std::pair<size_t, size_t> p2) {
+    for (size_t c = 0; c < 3; c++) {
+      float v1 = opsin_rows[c][p1.second * opsin_stride + p1.first];
+      float v2 = opsin_rows[c][p2.second * opsin_stride + p2.first];
+      if (std::fabs(v1 - v2) > 1e-4) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  auto is_similar = [&](std::pair<size_t, size_t> p1,
+                        std::pair<size_t, size_t> p2) {
+    return is_similar_impl(p1, p2, opsin_rows, opsin_stride,
+                           kSimilarThreshold);
+  };
+
+  constexpr int64_t kPatchSide = 4;
+  constexpr int64_t kExtraSide = 4;
+
+  // Look for kPatchSide size squares, naturally aligned, that all have the
+  // same pixel values.
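+  // In concrete numbers (derived from the constants and checks below): with
+  // kPatchSide = 4 and kExtraSide = 4, each aligned 4x4 block is compared
+  // against its surrounding 12x12 neighborhood (up to 144 pixels). A block is
+  // marked screenshot-like only if all 16 of its pixels are identical and
+  // num_same * 8 >= num * 7 holds, i.e. at least 87.5% of the in-bounds
+  // neighborhood matches the block's top-left pixel exactly.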
+  ImageB is_screenshot_like(DivCeil(opsin.xsize(), kPatchSide),
+                            DivCeil(opsin.ysize(), kPatchSide));
+  ZeroFillImage(&is_screenshot_like);
+  uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0);
+  const size_t screenshot_stride = is_screenshot_like.PixelsPerRow();
+  const auto process_row = [&](uint64_t y, int _) {
+    for (uint64_t x = 0; x < opsin.xsize() / kPatchSide; x++) {
+      bool all_same = true;
+      for (size_t iy = 0; iy < static_cast<size_t>(kPatchSide); iy++) {
+        for (size_t ix = 0; ix < static_cast<size_t>(kPatchSide); ix++) {
+          size_t cx = x * kPatchSide + ix;
+          size_t cy = y * kPatchSide + iy;
+          if (!is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) {
+            all_same = false;
+            break;
+          }
+        }
+      }
+      if (!all_same) continue;
+      size_t num = 0;
+      size_t num_same = 0;
+      for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) {
+        for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) {
+          int64_t cx = x * kPatchSide + ix;
+          int64_t cy = y * kPatchSide + iy;
+          if (cx < 0 || static_cast<size_t>(cx) >= opsin.xsize() ||  //
+              cy < 0 || static_cast<size_t>(cy) >= opsin.ysize()) {
+            continue;
+          }
+          num++;
+          if (is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) num_same++;
+        }
+      }
+      // Too few equal pixels nearby.
+      if (num_same * 8 < num * 7) continue;
+      screenshot_row[y * screenshot_stride + x] = 1;
+      has_screenshot_areas = true;
+    }
+  };
+  RunOnPool(pool, 0, opsin.ysize() / kPatchSide, ThreadPool::SkipInit(),
+            process_row, "IsScreenshotLike");
+
+  // TODO(veluca): also parallelize the rest of this function.
+  if (WantDebugOutput(aux_out)) {
+    aux_out->DumpPlaneNormalized("screenshot_like", is_screenshot_like);
+  }
+
+  constexpr int kSearchRadius = 1;
+
+  if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) {
+    return {};
+  }
+  // Search for "similar enough" pixels near the screenshot-like areas.
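+  // The loop below is a breadth-first flood fill: every pixel inside a
+  // screenshot-like block seeds the queue as its own source, and the fill
+  // expands to 8-connected neighbors (kSearchRadius = 1) whose color stays
+  // within kSimilarThreshold of the source, stopping once the Manhattan
+  // distance from the source exceeds kDistanceLimit. Reached pixels are
+  // marked as background and the source color is recorded for them.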
+  ImageB is_background(opsin.xsize(), opsin.ysize());
+  ZeroFillImage(&is_background);
+  Image3F background(opsin.xsize(), opsin.ysize());
+  ZeroFillImage(&background);
+  constexpr size_t kDistanceLimit = 50;
+  float* JXL_RESTRICT background_rows[3] = {
+      background.PlaneRow(0, 0),
+      background.PlaneRow(1, 0),
+      background.PlaneRow(2, 0),
+  };
+  const size_t background_stride = background.PixelsPerRow();
+  uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0);
+  const size_t is_background_stride = is_background.PixelsPerRow();
+  std::vector<
+      std::pair<std::pair<size_t, size_t>, std::pair<size_t, size_t>>>
+      queue;
+  size_t queue_front = 0;
+  for (size_t y = 0; y < opsin.ysize(); y++) {
+    for (size_t x = 0; x < opsin.xsize(); x++) {
+      if (!screenshot_row[screenshot_stride * (y / kPatchSide) +
+                          (x / kPatchSide)])
+        continue;
+      queue.push_back({{x, y}, {x, y}});
+    }
+  }
+  while (queue.size() != queue_front) {
+    std::pair<size_t, size_t> cur = queue[queue_front].first;
+    std::pair<size_t, size_t> src = queue[queue_front].second;
+    queue_front++;
+    if (is_background_row[cur.second * is_background_stride + cur.first])
+      continue;
+    is_background_row[cur.second * is_background_stride + cur.first] = 1;
+    for (size_t c = 0; c < 3; c++) {
+      background_rows[c][cur.second * background_stride + cur.first] =
+          opsin_rows[c][src.second * opsin_stride + src.first];
+    }
+    for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
+      for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
+        if (dx == 0 && dy == 0) continue;
+        int next_first = cur.first + dx;
+        int next_second = cur.second + dy;
+        if (next_first < 0 || next_second < 0 ||
+            static_cast<size_t>(next_first) >= opsin.xsize() ||
+            static_cast<size_t>(next_second) >= opsin.ysize()) {
+          continue;
+        }
+        if (static_cast<size_t>(
+                std::abs(next_first - static_cast<int>(src.first)) +
+                std::abs(next_second - static_cast<int>(src.second))) >
+            kDistanceLimit) {
+          continue;
+        }
+        std::pair<size_t, size_t> next{next_first, next_second};
+        if (is_similar(src, next)) {
+          if (!screenshot_row[next.second / kPatchSide * screenshot_stride +
+                              next.first / kPatchSide] ||
+              is_same(src, next)) {
+            if (!is_background_row[next.second * is_background_stride +
+                                   next.first])
+              queue.emplace_back(next, src);
+          }
+        }
+      }
+    }
+  }
+  queue.clear();
+
+  ImageF ccs;
+  std::mt19937 rng;
+  std::uniform_real_distribution<float> dist(0.5, 1.0);
+  bool paint_ccs = false;
+  if (WantDebugOutput(aux_out)) {
+    aux_out->DumpPlaneNormalized("is_background", is_background);
+    aux_out->DumpXybImage("background", background);
+    ccs = ImageF(opsin.xsize(), opsin.ysize());
+    ZeroFillImage(&ccs);
+    paint_ccs = true;
+  }
+
+  constexpr float kVerySimilarThreshold = 0.03f;
+  constexpr float kHasSimilarThreshold = 0.03f;
+
+  const float* JXL_RESTRICT const_background_rows[3] = {
+      background_rows[0], background_rows[1], background_rows[2]};
+  auto is_similar_b = [&](std::pair<size_t, size_t> p1,
+                          std::pair<size_t, size_t> p2) {
+    return is_similar_impl(p1, p2, const_background_rows, background_stride,
+                           kVerySimilarThreshold);
+  };
+
+  constexpr int kMinPeak = 2;
+  constexpr int kHasSimilarRadius = 2;
+
+  std::vector<PatchInfo> info;
+
+  // Find small CC outside the "similar enough" areas, compute bounding boxes,
+  // and run heuristics to exclude some patches.
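+  // Concretely, as implemented below: a stack-based fill collects each
+  // connected component of non-background pixels; a component becomes a patch
+  // candidate only if (a) it touches the background and all bordering
+  // background pixels have near-identical color, (b) its bounding box fits
+  // within kMaxPatchSize, (c) a similar color occurs nearby in the original
+  // image, and (d) its largest quantized sample reaches kMinPeak.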
+  ImageB visited(opsin.xsize(), opsin.ysize());
+  ZeroFillImage(&visited);
+  uint8_t* JXL_RESTRICT visited_row = visited.Row(0);
+  const size_t visited_stride = visited.PixelsPerRow();
+  std::vector<std::pair<size_t, size_t>> cc;
+  std::vector<std::pair<size_t, size_t>> stack;
+  for (size_t y = 0; y < opsin.ysize(); y++) {
+    for (size_t x = 0; x < opsin.xsize(); x++) {
+      if (is_background_row[y * is_background_stride + x]) continue;
+      cc.clear();
+      stack.clear();
+      stack.emplace_back(x, y);
+      size_t min_x = x;
+      size_t max_x = x;
+      size_t min_y = y;
+      size_t max_y = y;
+      std::pair<size_t, size_t> reference;
+      bool found_border = false;
+      bool all_similar = true;
+      while (!stack.empty()) {
+        std::pair<size_t, size_t> cur = stack.back();
+        stack.pop_back();
+        if (visited_row[cur.second * visited_stride + cur.first]) continue;
+        visited_row[cur.second * visited_stride + cur.first] = 1;
+        if (cur.first < min_x) min_x = cur.first;
+        if (cur.first > max_x) max_x = cur.first;
+        if (cur.second < min_y) min_y = cur.second;
+        if (cur.second > max_y) max_y = cur.second;
+        if (paint_ccs) {
+          cc.push_back(cur);
+        }
+        for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
+          for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
+            if (dx == 0 && dy == 0) continue;
+            int next_first = static_cast<int>(cur.first) + dx;
+            int next_second = static_cast<int>(cur.second) + dy;
+            if (next_first < 0 || next_second < 0 ||
+                static_cast<size_t>(next_first) >= opsin.xsize() ||
+                static_cast<size_t>(next_second) >= opsin.ysize()) {
+              continue;
+            }
+            std::pair<size_t, size_t> next{next_first, next_second};
+            if (!is_background_row[next.second * is_background_stride +
+                                   next.first]) {
+              stack.push_back(next);
+            } else {
+              if (!found_border) {
+                reference = next;
+                found_border = true;
+              } else {
+                if (!is_similar_b(next, reference)) all_similar = false;
+              }
+            }
+          }
+        }
+      }
+      if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize ||
+          max_y - min_y >= kMaxPatchSize) {
+        continue;
+      }
+      size_t bpos = background_stride * reference.second + reference.first;
+      float ref[3] = {background_rows[0][bpos], background_rows[1][bpos],
+                      background_rows[2][bpos]};
+      bool has_similar = false;
+      for (size_t iy = std::max(
+               static_cast<int>(min_y) - kHasSimilarRadius, 0);
+           iy < std::min(max_y + kHasSimilarRadius + 1, opsin.ysize()); iy++) {
+        for (size_t ix = std::max(
+                 static_cast<int>(min_x) - kHasSimilarRadius, 0);
+             ix < std::min(max_x + kHasSimilarRadius + 1, opsin.xsize());
+             ix++) {
+          size_t opos = opsin_stride * iy + ix;
+          float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos],
+                         opsin_rows[2][opos]};
+          if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) {
+            has_similar = true;
+          }
+        }
+      }
+      if (!has_similar) continue;
+      info.emplace_back();
+      info.back().second.emplace_back(min_x, min_y);
+      QuantizedPatch& patch = info.back().first;
+      patch.xsize = max_x - min_x + 1;
+      patch.ysize = max_y - min_y + 1;
+      int max_value = 0;
+      for (size_t c : {1, 0, 2}) {
+        for (size_t iy = min_y; iy <= max_y; iy++) {
+          for (size_t ix = min_x; ix <= max_x; ix++) {
+            size_t offset = (iy - min_y) * patch.xsize + ix - min_x;
+            patch.fpixels[c][offset] =
+                opsin_rows[c][iy * opsin_stride + ix] - ref[c];
+            int val = pci.Quantize(patch.fpixels[c][offset], c);
+            patch.pixels[c][offset] = val;
+            if (std::abs(val) > max_value) max_value = std::abs(val);
+          }
+        }
+      }
+      if (max_value < kMinPeak) {
+        info.pop_back();
+        continue;
+      }
+      if (paint_ccs) {
+        float cc_color = dist(rng);
+        for (std::pair<size_t, size_t> p : cc) {
+          ccs.Row(p.second)[p.first] = cc_color;
+        }
+      }
+    }
+  }
+
+  if (paint_ccs) {
+    JXL_ASSERT(WantDebugOutput(aux_out));
+    aux_out->DumpPlaneNormalized("ccs", ccs);
+  }
+  if (info.empty()) {
+    return {};
+  }
+
+  // Remove duplicates.
+  constexpr size_t kMinPatchOccurences = 2;
+  std::sort(info.begin(), info.end());
+  size_t unique = 0;
+  for (size_t i = 1; i < info.size(); i++) {
+    if (info[i].first == info[unique].first) {
+      info[unique].second.insert(info[unique].second.end(),
+                                 info[i].second.begin(), info[i].second.end());
+    } else {
+      if (info[unique].second.size() >= kMinPatchOccurences) {
+        unique++;
+      }
+      info[unique] = info[i];
+    }
+  }
+  if (info[unique].second.size() >= kMinPatchOccurences) {
+    unique++;
+  }
+  info.resize(unique);
+
+  size_t max_patch_size = 0;
+
+  for (size_t i = 0; i < info.size(); i++) {
+    size_t pixels = info[i].first.xsize * info[i].first.ysize;
+    if (pixels > max_patch_size) max_patch_size = pixels;
+  }
+
+  // don't use patches if all patches are smaller than this
+  constexpr size_t kMinMaxPatchSize = 20;
+  if (max_patch_size < kMinMaxPatchSize) return {};
+
+  // Ensure that the specified set of patches doesn't produce out-of-bounds
+  // pixels.
+  // TODO(veluca): figure out why this is still necessary even with RCTs that
+  // don't depend on bit depth.
+  if (state->cparams.modular_mode &&
+      state->cparams.quality_pair.first >= 100) {
+    constexpr size_t kMaxPatchArea = kMaxPatchSize * kMaxPatchSize;
+    std::vector<float> min_then_max_px(2 * kMaxPatchArea);
+    for (size_t i = 0; i < info.size(); i++) {
+      for (size_t c = 0; c < 3; c++) {
+        float* JXL_RESTRICT min_px = min_then_max_px.data();
+        float* JXL_RESTRICT max_px = min_px + kMaxPatchArea;
+        std::fill(min_px, min_px + kMaxPatchArea, 1);
+        std::fill(max_px, max_px + kMaxPatchArea, 0);
+        size_t xsize = info[i].first.xsize;
+        for (size_t j = 0; j < info[i].second.size(); j++) {
+          size_t bx = info[i].second[j].first;
+          size_t by = info[i].second[j].second;
+          for (size_t iy = 0; iy < info[i].first.ysize; iy++) {
+            for (size_t ix = 0; ix < xsize; ix++) {
+              float v = opsin_rows[c][(by + iy) * opsin_stride + bx + ix];
+              if (v < min_px[iy * xsize + ix]) min_px[iy * xsize + ix] = v;
+              if (v > max_px[iy * xsize + ix]) max_px[iy * xsize + ix] = v;
+            }
+          }
+        }
+        for (size_t iy = 0; iy < info[i].first.ysize; iy++) {
+          for (size_t ix = 0; ix < xsize; ix++) {
+            float smallest = min_px[iy * xsize + ix];
+            float biggest = max_px[iy * xsize + ix];
+            JXL_ASSERT(smallest <= biggest);
+            float& out = info[i].first.fpixels[c][iy * xsize + ix];
+            // Clamp fpixels so that subtracting the patch never creates a
+            // negative value, or a value above 1.
+            JXL_ASSERT(biggest - 1 <= smallest);
+            out = std::max(smallest, out);
+            out = std::min(biggest - 1.f, out);
+          }
+        }
+      }
+    }
+  }
+  return info;
+}
+
+}  // namespace
+
+void FindBestPatchDictionary(const Image3F& opsin,
+                             PassesEncoderState* JXL_RESTRICT state,
+                             ThreadPool* pool, AuxOut* aux_out, bool is_xyb) {
+  state->shared.image_features.patches = PatchDictionary();
+  state->shared.image_features.patches.SetPassesSharedState(&state->shared);
+
+  std::vector<PatchInfo> info =
+      FindTextLikePatches(opsin, state, pool, aux_out, is_xyb);
+
+  // TODO(veluca): this doesn't work if both dots and patches are enabled.
+  // For now, since dots and patches are not likely to occur in the same kind
+  // of images, disable dots if some patches were found.
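+  // (That is: dots are only searched for when no text-like patches were
+  // found, the speed tier is kSquirrel or slower, and the butteraugli
+  // distance is at least kMinButteraugliForDots.)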
+  if (info.empty() &&
+      ApplyOverride(
+          state->cparams.dots,
+          state->cparams.speed_tier <= SpeedTier::kSquirrel &&
+              state->cparams.butteraugli_distance >= kMinButteraugliForDots)) {
+    info = FindDotDictionary(state->cparams, opsin, state->shared.cmap, pool);
+  }
+
+  if (info.empty()) return;
+
+  std::sort(
+      info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) {
+        return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize;
+      });
+
+  size_t max_x_size = 0;
+  size_t max_y_size = 0;
+  size_t total_pixels = 0;
+
+  for (size_t i = 0; i < info.size(); i++) {
+    size_t pixels = info[i].first.xsize * info[i].first.ysize;
+    if (max_x_size < info[i].first.xsize) max_x_size = info[i].first.xsize;
+    if (max_y_size < info[i].first.ysize) max_y_size = info[i].first.ysize;
+    total_pixels += pixels;
+  }
+
+  // Bin-packing & conversion of patches.
+  constexpr float kBinPackingSlackness = 1.05f;
+  size_t ref_xsize = std::max<float>(max_x_size, std::sqrt(total_pixels));
+  size_t ref_ysize = std::max<float>(max_y_size, std::sqrt(total_pixels));
+  std::vector<std::pair<size_t, size_t>> ref_positions(info.size());
+  // TODO(veluca): allow partial overlaps of patches that have the same pixels.
+  size_t max_y = 0;
+  do {
+    max_y = 0;
+    // Increase packed image size.
+    ref_xsize = ref_xsize * kBinPackingSlackness + 1;
+    ref_ysize = ref_ysize * kBinPackingSlackness + 1;
+
+    ImageB occupied(ref_xsize, ref_ysize);
+    ZeroFillImage(&occupied);
+    uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0);
+    size_t occupied_stride = occupied.PixelsPerRow();
+
+    bool success = true;
+    // For every patch...
+    for (size_t patch = 0; patch < info.size(); patch++) {
+      size_t x0 = 0;
+      size_t y0 = 0;
+      size_t xsize = info[patch].first.xsize;
+      size_t ysize = info[patch].first.ysize;
+      bool found = false;
+      // For every possible start position ...
+      for (; y0 + ysize <= ref_ysize; y0++) {
+        x0 = 0;
+        for (; x0 + xsize <= ref_xsize; x0++) {
+          bool has_occupied_pixel = false;
+          size_t x = x0;
+          // Check if it is possible to place the patch in this position in
+          // the reference frame.
+          for (size_t y = y0; y < y0 + ysize; y++) {
+            x = x0;
+            for (; x < x0 + xsize; x++) {
+              if (occupied_rows[y * occupied_stride + x]) {
+                has_occupied_pixel = true;
+                break;
+              }
+            }
+          }  // end of positioning check
+          if (!has_occupied_pixel) {
+            found = true;
+            break;
+          }
+          x0 = x;  // Jump to next pixel after the occupied one.
+        }
+        if (found) break;
+      }  // end of start position checking
+
+      // We didn't find a possible position: repeat from the beginning with a
+      // larger reference frame size.
+      if (!found) {
+        success = false;
+        break;
+      }
+
+      // We found a position: mark the corresponding positions in the
+      // reference image as used.
+      ref_positions[patch] = {x0, y0};
+      for (size_t y = y0; y < y0 + ysize; y++) {
+        for (size_t x = x0; x < x0 + xsize; x++) {
+          occupied_rows[y * occupied_stride + x] = true;
+        }
+      }
+      max_y = std::max(max_y, y0 + ysize);
+    }
+
+    if (success) break;
+  } while (true);
+
+  JXL_ASSERT(ref_ysize >= max_y);
+
+  ref_ysize = max_y;
+
+  Image3F reference_frame(ref_xsize, ref_ysize);
+  // TODO(veluca): figure out a better way to fill the image.
+  ZeroFillImage(&reference_frame);
+  std::vector<PatchPosition> positions;
+  float* JXL_RESTRICT ref_rows[3] = {
+      reference_frame.PlaneRow(0, 0),
+      reference_frame.PlaneRow(1, 0),
+      reference_frame.PlaneRow(2, 0),
+  };
+  size_t ref_stride = reference_frame.PixelsPerRow();
+
+  for (size_t i = 0; i < info.size(); i++) {
+    PatchReferencePosition ref_pos;
+    ref_pos.xsize = info[i].first.xsize;
+    ref_pos.ysize = info[i].first.ysize;
+    ref_pos.x0 = ref_positions[i].first;
+    ref_pos.y0 = ref_positions[i].second;
+    ref_pos.ref = 0;
+    for (size_t y = 0; y < ref_pos.ysize; y++) {
+      for (size_t x = 0; x < ref_pos.xsize; x++) {
+        for (size_t c = 0; c < 3; c++) {
+          ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] =
+              info[i].first.fpixels[c][y * ref_pos.xsize + x];
+        }
+      }
+    }
+    // Add color channels, ignore other channels.
+    std::vector<PatchBlending> blending_info(
+        state->shared.metadata->m.extra_channel_info.size() + 1,
+        PatchBlending{PatchBlendMode::kNone, 0, false});
+    blending_info[0].mode = PatchBlendMode::kAdd;
+    for (const auto& pos : info[i].second) {
+      positions.emplace_back(
+          PatchPosition{pos.first, pos.second, blending_info, ref_pos});
+    }
+  }
+
+  CompressParams cparams = state->cparams;
+  cparams.resampling = 1;
+  // Recursive application of patches could create very weird issues.
+  cparams.patches = Override::kOff;
+  cparams.dots = Override::kOff;
+  cparams.noise = Override::kOff;
+  cparams.modular_mode = true;
+  cparams.responsive = 0;
+  cparams.progressive_dc = 0;
+  cparams.progressive_mode = false;
+  cparams.qprogressive_mode = false;
+  // Use gradient predictor and not Predictor::Best.
+  cparams.options.predictor = Predictor::Gradient;
+  // TODO(veluca): possibly change heuristics here.
+  if (!cparams.modular_mode) {
+    cparams.quality_pair.first = cparams.quality_pair.second =
+        80 - cparams.butteraugli_distance * 12;
+  } else {
+    cparams.quality_pair.first =
+        (100 + 3 * cparams.quality_pair.first) * 0.25f;
+    cparams.quality_pair.second =
+        (100 + 3 * cparams.quality_pair.second) * 0.25f;
+  }
+  FrameInfo patch_frame_info;
+  patch_frame_info.save_as_reference = 0;  // always saved.
+  patch_frame_info.frame_type = FrameType::kReferenceOnly;
+  patch_frame_info.save_before_color_transform = true;
+
+  ImageBundle ib(&state->shared.metadata->m);
+  // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there
+  // is no simple way to express that yet.
+  patch_frame_info.ib_needs_color_transform = false;
+  patch_frame_info.save_as_reference = 0;
+  ib.SetFromImage(std::move(reference_frame),
+                  state->shared.metadata->m.color_encoding);
+  if (!ib.metadata()->extra_channel_info.empty()) {
+    // Add dummy extra channels to the patch image: patches do not yet support
+    // extra channels, but the codec expects that the amount of extra channels
+    // in frames matches that in the metadata of the codestream.
+    std::vector<ImageF> extra_channels;
+    extra_channels.reserve(ib.metadata()->extra_channel_info.size());
+    for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) {
+      extra_channels.emplace_back(ib.xsize(), ib.ysize());
+      // Must initialize the image with data to not affect blending with
+      // uninitialized memory.
+      // TODO(lode): patches must copy and use the real extra channels instead.
+      FillImage(1.0f, &extra_channels.back());
+    }
+    ib.SetExtraChannels(std::move(extra_channels));
+  }
+
+  PassesEncoderState roundtrip_state;
+  auto special_frame = std::unique_ptr<BitWriter>(new BitWriter());
+  JXL_CHECK(EncodeFrame(cparams, patch_frame_info, state->shared.metadata, ib,
+                        &roundtrip_state, pool, special_frame.get(), nullptr));
+  const Span<const uint8_t> encoded = special_frame->GetSpan();
+  state->special_frames.emplace_back(std::move(special_frame));
+  if (cparams.butteraugli_distance < kMinButteraugliToSubtractOriginalPatches) {
+    BitReader br(encoded);
+    ImageBundle decoded(&state->shared.metadata->m);
+    PassesDecoderState dec_state;
+    JXL_CHECK(dec_state.output_encoding_info.Set(state->shared.metadata->m));
+    JXL_CHECK(DecodeFrame({}, &dec_state, pool, &br, &decoded,
+                          *state->shared.metadata, /*constraints=*/nullptr));
+    JXL_CHECK(br.Close());
+    state->shared.reference_frames[0] =
+        std::move(dec_state.shared_storage.reference_frames[0]);
+  } else {
+    state->shared.reference_frames[0].storage = std::move(ib);
+  }
+  state->shared.reference_frames[0].frame =
+      &state->shared.reference_frames[0].storage;
+  // TODO(veluca): this assumes that applying patches is commutative, which is
+  // not true for all blending modes. This code only produces kAdd patches, so
+  // this works out.
+  std::sort(positions.begin(), positions.end());
+  PatchDictionaryEncoder::SetPositions(&state->shared.image_features.patches,
+                                       std::move(positions));
+}
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.h b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.h
new file mode 100644
index 000000000000..13ecb6ebdbd3
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.h
@@ -0,0 +1,65 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_PATCH_DICTIONARY_H_
+#define LIB_JXL_ENC_PATCH_DICTIONARY_H_
+
+// Chooses reference patches, and avoids encoding them once per occurrence.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <tuple>
+#include <vector>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/dec_patch_dictionary.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/opsin_params.h"
+
+namespace jxl {
+
+// Friend class of PatchDictionary.
+class PatchDictionaryEncoder {
+ public:
+  // Only call if HasAny().
+  static void Encode(const PatchDictionary& pdic, BitWriter* writer,
+                     size_t layer, AuxOut* aux_out);
+
+  static void SetPositions(PatchDictionary* pdic,
+                           std::vector<PatchPosition> positions) {
+    pdic->positions_ = std::move(positions);
+    pdic->ComputePatchCache();
+  }
+
+  static void SubtractFrom(const PatchDictionary& pdic, Image3F* opsin);
+};
+
+void FindBestPatchDictionary(const Image3F& opsin,
+                             PassesEncoderState* JXL_RESTRICT state,
+                             ThreadPool* pool, AuxOut* aux_out,
+                             bool is_xyb = true);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_PATCH_DICTIONARY_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_quant_weights.cc b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.cc
new file mode 100644
index 000000000000..11d01b42001f
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.cc
@@ -0,0 +1,212 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_quant_weights.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <utility>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/enc_modular.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/modular/encoding/encoding.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+namespace {
+
+Status EncodeDctParams(const DctQuantWeightParams& params, BitWriter* writer) {
+  JXL_ASSERT(params.num_distance_bands >= 1);
+  writer->Write(DctQuantWeightParams::kLog2MaxDistanceBands,
+                params.num_distance_bands - 1);
+  for (size_t c = 0; c < 3; c++) {
+    for (size_t i = 0; i < params.num_distance_bands; i++) {
+      JXL_RETURN_IF_ERROR(F16Coder::Write(
+          params.distance_bands[c][i] * (i == 0 ? (1 / 64.0f) : 1.0f),
+          writer));
+    }
+  }
+  return true;
+}
+
+Status EncodeQuant(const QuantEncoding& encoding, size_t idx, size_t size_x,
+                   size_t size_y, BitWriter* writer,
+                   ModularFrameEncoder* modular_frame_encoder) {
+  writer->Write(kLog2NumQuantModes, encoding.mode);
+  size_x *= kBlockDim;
+  size_y *= kBlockDim;
+  switch (encoding.mode) {
+    case QuantEncoding::kQuantModeLibrary: {
+      writer->Write(kCeilLog2NumPredefinedTables, encoding.predefined);
+      break;
+    }
+    case QuantEncoding::kQuantModeID: {
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 3; i++) {
+          JXL_RETURN_IF_ERROR(
+              F16Coder::Write(encoding.idweights[c][i] * (1.0f / 64), writer));
+        }
+      }
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT2: {
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 6; i++) {
+          JXL_RETURN_IF_ERROR(F16Coder::Write(
+              encoding.dct2weights[c][i] * (1.0f / 64), writer));
+        }
+      }
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT4X8: {
+      for (size_t c = 0; c < 3; c++) {
+        JXL_RETURN_IF_ERROR(
+            F16Coder::Write(encoding.dct4x8multipliers[c], writer));
+      }
+      JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT4: {
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 2; i++) {
+          JXL_RETURN_IF_ERROR(
+              F16Coder::Write(encoding.dct4multipliers[c][i], writer));
+        }
+      }
+      JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT: {
+      JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
+      break;
+    }
+    case QuantEncoding::kQuantModeRAW: {
+      ModularFrameEncoder::EncodeQuantTable(size_x, size_y, writer, encoding,
+                                            idx, modular_frame_encoder);
+      break;
+    }
+    case QuantEncoding::kQuantModeAFV: {
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 9; i++) {
+          JXL_RETURN_IF_ERROR(F16Coder::Write(
+              encoding.afv_weights[c][i] * (i < 6 ? 1.0f / 64 : 1.0f),
+              writer));
+        }
+        JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
+        JXL_RETURN_IF_ERROR(
+            EncodeDctParams(encoding.dct_params_afv_4x4, writer));
+      }
+      break;
+    }
+  }
+  return true;
+}
+
+}  // namespace
+
+Status DequantMatricesEncode(const DequantMatrices* matrices,
+                             BitWriter* writer, size_t layer, AuxOut* aux_out,
+                             ModularFrameEncoder* modular_frame_encoder) {
+  bool all_default = true;
+  const std::vector<QuantEncoding>& encodings = matrices->encodings();
+
+  for (size_t i = 0; i < encodings.size(); i++) {
+    if (encodings[i].mode != QuantEncoding::kQuantModeLibrary ||
+        encodings[i].predefined != 0) {
+      all_default = false;
+    }
+  }
+  // TODO(janwas): better bound
+  BitWriter::Allotment allotment(writer, 512 * 1024);
+  writer->Write(1, all_default);
+  if (!all_default) {
+    for (size_t i = 0; i < encodings.size(); i++) {
+      JXL_RETURN_IF_ERROR(EncodeQuant(
+          encodings[i], i, DequantMatrices::required_size_x[i],
+          DequantMatrices::required_size_y[i], writer, modular_frame_encoder));
+    }
+  }
+  ReclaimAndCharge(writer, &allotment, layer, aux_out);
+  return true;
+}
+
+Status DequantMatricesEncodeDC(const DequantMatrices* matrices,
+                               BitWriter* writer, size_t layer,
+                               AuxOut* aux_out) {
+  bool all_default = true;
+  const float* dc_quant = matrices->DCQuants();
+  for (size_t c = 0; c < 3; c++) {
+    if (dc_quant[c] != kDCQuant[c]) {
+      all_default = false;
+    }
+  }
+  BitWriter::Allotment allotment(writer, 1 + sizeof(float) * kBitsPerByte * 3);
+  writer->Write(1, all_default);
+  if (!all_default) {
+    for (size_t c = 0; c < 3; c++) {
+      JXL_RETURN_IF_ERROR(F16Coder::Write(dc_quant[c] * 128.0f, writer));
+    }
+  }
+  ReclaimAndCharge(writer, &allotment, layer, aux_out);
+  return true;
+}
+
+void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc) {
+  matrices->SetDCQuant(dc);
+  // Roundtrip encode/decode DC to ensure same values as decoder.
+  BitWriter writer;
+  JXL_CHECK(DequantMatricesEncodeDC(matrices, &writer, 0, nullptr));
+  writer.ZeroPadToByte();
+  BitReader br(writer.GetSpan());
+  // Called only in the encoder: should fail only for programmer errors.
+  JXL_CHECK(matrices->DecodeDC(&br));
+  JXL_CHECK(br.Close());
+}
+
+void DequantMatricesSetCustom(DequantMatrices* matrices,
+                              const std::vector<QuantEncoding>& encodings,
+                              ModularFrameEncoder* encoder) {
+  JXL_ASSERT(encodings.size() == DequantMatrices::kNum);
+  matrices->SetEncodings(encodings);
+  for (size_t i = 0; i < encodings.size(); i++) {
+    if (encodings[i].mode == QuantEncodingInternal::kQuantModeRAW) {
+      encoder->AddQuantTable(DequantMatrices::required_size_x[i] * kBlockDim,
+                             DequantMatrices::required_size_y[i] * kBlockDim,
+                             encodings[i], i);
+    }
+  }
+  // Roundtrip encode/decode the matrices to ensure same values as decoder.
+  // Do not pass modular en/decoder, as they only change entropy and not
+  // values.
+  BitWriter writer;
+  JXL_CHECK(DequantMatricesEncode(matrices, &writer, 0, nullptr));
+  writer.ZeroPadToByte();
+  BitReader br(writer.GetSpan());
+  // Called only in the encoder: should fail only for programmer errors.
+  JXL_CHECK(matrices->Decode(&br));
+  JXL_CHECK(br.Close());
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_quant_weights.h b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.h
new file mode 100644
index 000000000000..019a2859b4ea
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.h
@@ -0,0 +1,38 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_QUANT_WEIGHTS_H_
+#define LIB_JXL_ENC_QUANT_WEIGHTS_H_
+
+#include "lib/jxl/quant_weights.h"
+
+namespace jxl {
+
+Status DequantMatricesEncode(
+    const DequantMatrices* matrices, BitWriter* writer, size_t layer,
+    AuxOut* aux_out, ModularFrameEncoder* modular_frame_encoder = nullptr);
+Status DequantMatricesEncodeDC(const DequantMatrices* matrices,
+                               BitWriter* writer, size_t layer,
+                               AuxOut* aux_out);
+// For consistency with QuantEncoding, higher values correspond to more
+// precision.
+void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc);
+
+void DequantMatricesSetCustom(DequantMatrices* matrices,
+                              const std::vector<QuantEncoding>& encodings,
+                              ModularFrameEncoder* encoder);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_QUANT_WEIGHTS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_splines.cc b/third_party/jpeg-xl/lib/jxl/enc_splines.cc
new file mode 100644
index 000000000000..0c657c5bb8dc
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_splines.cc
@@ -0,0 +1,105 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cmath>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/splines.h"
+
+namespace jxl {
+
+class QuantizedSplineEncoder {
+ public:
+  // Only call if HasAny().
+  static void Tokenize(const QuantizedSpline& spline,
+                       std::vector<Token>* const tokens) {
+    tokens->emplace_back(kNumControlPointsContext,
+                         spline.control_points_.size());
+    for (const auto& point : spline.control_points_) {
+      tokens->emplace_back(kControlPointsContext, PackSigned(point.first));
+      tokens->emplace_back(kControlPointsContext, PackSigned(point.second));
+    }
+    const auto encode_dct = [tokens](const int dct[32]) {
+      for (int i = 0; i < 32; ++i) {
+        tokens->emplace_back(kDCTContext, PackSigned(dct[i]));
+      }
+    };
+    for (int c = 0; c < 3; ++c) {
+      encode_dct(spline.color_dct_[c]);
+    }
+    encode_dct(spline.sigma_dct_);
+  }
+};
+
+namespace {
+
+void EncodeAllStartingPoints(const std::vector<Spline::Point>& points,
+                             std::vector<Token>* tokens) {
+  int64_t last_x = 0;
+  int64_t last_y = 0;
+  for (size_t i = 0; i < points.size(); i++) {
+    const int64_t x = lroundf(points[i].x);
+    const int64_t y = lroundf(points[i].y);
+    if (i == 0) {
+      tokens->emplace_back(kStartingPositionContext, x);
+      tokens->emplace_back(kStartingPositionContext, y);
+    } else {
+      tokens->emplace_back(kStartingPositionContext, PackSigned(x - last_x));
+      tokens->emplace_back(kStartingPositionContext, PackSigned(y - last_y));
+    }
+    last_x = x;
+    last_y = y;
+  }
+}
+
+}  // namespace
+
+void EncodeSplines(const Splines& splines, BitWriter* writer,
+                   const size_t layer, const HistogramParams& histogram_params,
+                   AuxOut* aux_out) {
+  JXL_ASSERT(splines.HasAny());
+
+  const std::vector<QuantizedSpline>& quantized_splines =
+      splines.QuantizedSplines();
+  std::vector<std::vector<Token>> tokens(1);
+  tokens[0].emplace_back(kNumSplinesContext, quantized_splines.size() - 1);
+  EncodeAllStartingPoints(splines.StartingPoints(), &tokens[0]);
+
+  tokens[0].emplace_back(kQuantizationAdjustmentContext,
+                         PackSigned(splines.GetQuantizationAdjustment()));
+
+  for (const QuantizedSpline& spline : quantized_splines) {
+    QuantizedSplineEncoder::Tokenize(spline, &tokens[0]);
+  }
+
+  EntropyEncodingData codes;
+  std::vector<uint8_t> context_map;
+  BuildAndEncodeHistograms(histogram_params, kNumSplineContexts, tokens,
+                           &codes, &context_map, writer, layer, aux_out);
+  WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out);
+}
+
+Splines FindSplines(const Image3F& opsin) {
+  // TODO: implement spline detection.
+  return {};
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_splines.h b/third_party/jpeg-xl/lib/jxl/enc_splines.h
new file mode 100644
index 000000000000..ea5e3c5568d8
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_splines.h
@@ -0,0 +1,48 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_SPLINES_H_
+#define LIB_JXL_ENC_SPLINES_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/splines.h"
+
+namespace jxl {
+
+// Only call if splines.HasAny().
+void EncodeSplines(const Splines& splines, BitWriter* writer,
+                   const size_t layer, const HistogramParams& histogram_params,
+                   AuxOut* aux_out);
+
+Splines FindSplines(const Image3F& opsin);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_SPLINES_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_toc.cc b/third_party/jpeg-xl/lib/jxl/enc_toc.cc
new file mode 100644
index 000000000000..6729dfa83e93
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_toc.cc
@@ -0,0 +1,55 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_toc.h"
+
+#include <stdint.h>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/enc_coeff_order.h"
+#include "lib/jxl/field_encodings.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/toc.h"
+
+namespace jxl {
+Status WriteGroupOffsets(const std::vector<BitWriter>& group_codes,
+                         const std::vector<coeff_order_t>* permutation,
+                         BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) {
+  BitWriter::Allotment allotment(writer, MaxBits(group_codes.size()));
+  if (permutation && !group_codes.empty()) {
+    // Don't write a permutation at all for an empty group_codes.
+    writer->Write(1, 1);  // permutation
+    JXL_DASSERT(permutation->size() == group_codes.size());
+    EncodePermutation(permutation->data(), /*skip=*/0, permutation->size(),
+                      writer, /* layer= */ 0, aux_out);
+
+  } else {
+    writer->Write(1, 0);  // no permutation
+  }
+  writer->ZeroPadToByte();  // before TOC entries
+
+  for (size_t i = 0; i < group_codes.size(); i++) {
+    JXL_ASSERT(group_codes[i].BitsWritten() % kBitsPerByte == 0);
+    const size_t group_size = group_codes[i].BitsWritten() / kBitsPerByte;
+    JXL_RETURN_IF_ERROR(U32Coder::Write(kTocDist, group_size, writer));
+  }
+  writer->ZeroPadToByte();  // before first group
+  ReclaimAndCharge(writer, &allotment, kLayerTOC, aux_out);
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_toc.h b/third_party/jpeg-xl/lib/jxl/enc_toc.h
new file mode 100644
index 000000000000..f3133eef41b9
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_toc.h
@@ -0,0 +1,38 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_TOC_H_
+#define LIB_JXL_ENC_TOC_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/enc_bit_writer.h"
+
+namespace jxl {
+
+// Writes the group offsets. If the permutation vector is nullptr, the identity
+// permutation will be used.
+Status WriteGroupOffsets(const std::vector<BitWriter>& group_codes,
+                         const std::vector<coeff_order_t>* permutation,
+                         BitWriter* JXL_RESTRICT writer, AuxOut* aux_out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_TOC_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_transforms-inl.h b/third_party/jpeg-xl/lib/jxl/enc_transforms-inl.h
new file mode 100644
index 000000000000..42b3d5b3f431
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_transforms-inl.h
@@ -0,0 +1,853 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if defined(LIB_JXL_ENC_TRANSFORMS_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_ENC_TRANSFORMS_INL_H_
+#undef LIB_JXL_ENC_TRANSFORMS_INL_H_
+#else
+#define LIB_JXL_ENC_TRANSFORMS_INL_H_
+#endif
+
+#include <stddef.h>
+
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/dct-inl.h"
+#include "lib/jxl/dct_scales.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+template <size_t ROWS, size_t COLS>
+struct DoIDCT {
+  template <typename To>
+  void operator()(float* JXL_RESTRICT from, const To& to,
+                  float* JXL_RESTRICT scratch_space) {
+    ComputeScaledIDCT<ROWS, COLS>()(from, to, scratch_space);
+  }
+};
+
+template <size_t N>
+struct DoIDCT<N, N> {
+  template <typename To>
+  void operator()(float* JXL_RESTRICT from, const To& to,
+                  float* JXL_RESTRICT scratch_space) const {
+    ComputeTransposedScaledIDCT<N>()(from, to, scratch_space);
+  }
+};
+
+// Inverse of ReinterpretingDCT.
+template <size_t DCT_ROWS, size_t DCT_COLS, size_t LF_ROWS, size_t LF_COLS,
+          size_t ROWS, size_t COLS>
+HWY_INLINE void ReinterpretingIDCT(const float* input,
+                                   const size_t input_stride, float* output,
+                                   const size_t output_stride) {
+  HWY_ALIGN float block[ROWS * COLS] = {};
+  if (ROWS < COLS) {
+    for (size_t y = 0; y < LF_ROWS; y++) {
+      for (size_t x = 0; x < LF_COLS; x++) {
+        block[y * COLS + x] = input[y * input_stride + x] *
+                              DCTTotalResampleScale<DCT_ROWS, ROWS>(y) *
+                              DCTTotalResampleScale<DCT_COLS, COLS>(x);
+      }
+    }
+  } else {
+    for (size_t y = 0; y < LF_COLS; y++) {
+      for (size_t x = 0; x < LF_ROWS; x++) {
+        block[y * ROWS + x] = input[y * input_stride + x] *
+                              DCTTotalResampleScale<DCT_COLS, COLS>(y) *
+                              DCTTotalResampleScale<DCT_ROWS, ROWS>(x);
+      }
+    }
+  }
+
+  // ROWS, COLS <= 8, so we can put scratch space on the stack.
+  HWY_ALIGN float scratch_space[ROWS * COLS];
+  DoIDCT<ROWS, COLS>()(block, DCTTo(output, output_stride), scratch_space);
+}
+
+template <size_t S>
+void DCT2TopBlock(const float* block, size_t stride, float* out) {
+  static_assert(kBlockDim % S == 0, "S should be a divisor of kBlockDim");
+  static_assert(S % 2 == 0, "S should be even");
+  float temp[kDCTBlockSize];
+  constexpr size_t num_2x2 = S / 2;
+  for (size_t y = 0; y < num_2x2; y++) {
+    for (size_t x = 0; x < num_2x2; x++) {
+      float c00 = block[y * 2 * stride + x * 2];
+      float c01 = block[y * 2 * stride + x * 2 + 1];
+      float c10 = block[(y * 2 + 1) * stride + x * 2];
+      float c11 = block[(y * 2 + 1) * stride + x * 2 + 1];
+      float r00 = c00 + c01 + c10 + c11;
+      float r01 = c00 + c01 - c10 - c11;
+      float r10 = c00 - c01 + c10 - c11;
+      float r11 = c00 - c01 - c10 + c11;
+      r00 *= 0.25f;
+      r01 *= 0.25f;
+      r10 *= 0.25f;
+      r11 *= 0.25f;
+      temp[y * kBlockDim + x] = r00;
+      temp[y * kBlockDim + num_2x2 + x] = r01;
+      temp[(y + num_2x2) * kBlockDim + x] = r10;
+      temp[(y + num_2x2) * kBlockDim + num_2x2 + x] = r11;
+    }
+  }
+  for (size_t y = 0; y < S; y++) {
+    for (size_t x = 0; x < S; x++) {
+      out[y * kBlockDim + x] = temp[y * kBlockDim + x];
+    }
+  }
+}
+
+void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs) {
+  HWY_ALIGN static constexpr float k4x4AFVBasisTranspose[16][16] = {
+      {
+          0.2500000000000000,
+          0.8769029297991420f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          -0.4105377591765233f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+      },
+      {
+          0.2500000000000000,
+          0.2206518106944235f,
+          0.0000000000000000,
+          0.0000000000000000,
+          -0.7071067811865474f,
+          0.6235485373547691f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375376f,
+          0.4067007583026075f,
+          -0.2125574805828875f,
+          0.0000000000000000,
+          -0.0643507165794627f,
+          -0.4517556589999482f,
+          -0.3046847507248690f,
+          0.3017929516615495f,
+          0.4082482904638627f,
+          0.1747866975480809f,
+          -0.2110560104933578f,
+          -0.1426608480880726f,
+          -0.1381354035075859f,
+          -0.1743760259965107f,
+          0.1135498731499434f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375375f,
+          0.4444481661973445f,
+          0.3085497062849767f,
+          0.0000000000000000f,
+          -0.0643507165794627f,
+          0.1585450355184006f,
+          0.5112616136591823f,
+          0.2579236279634118f,
+          0.0000000000000000,
+          0.0812611176717539f,
+          0.1856718091610980f,
+          -0.3416446842253372f,
+          0.3302282550303788f,
+          0.0702790691196284f,
+          -0.0741750459581035f,
+      },
+      {
+          0.2500000000000000,
+          0.2206518106944236f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.7071067811865476f,
+          0.6235485373547694f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375378f,
+          0.0000000000000000,
+          0.4706702258572536f,
+          0.0000000000000000,
+          -0.0643507165794628f,
+          -0.0403851516082220f,
+          0.0000000000000000,
+          0.1627234014286620f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.7367497537172237f,
+          0.0875511500058708f,
+          -0.2921026642334881f,
+          0.1940289303259434f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375377f,
+          0.1957439937204294f,
+          -0.1621205195722993f,
+          0.0000000000000000,
+          -0.0643507165794628f,
+          0.0074182263792424f,
+          -0.2904801297289980f,
+          0.0952002265347504f,
+          0.0000000000000000,
+          -0.3675398009862027f,
+          0.4921585901373873f,
+          0.2462710772207515f,
+          -0.0794670660590957f,
+          0.3623817333531167f,
+          -0.4351904965232280f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375376f,
+          0.2929100136981264f,
+          0.0000000000000000,
+          0.0000000000000000,
+          -0.0643507165794627f,
+          0.3935103426921017f,
+          -0.0657870154914280f,
+          0.0000000000000000,
+          -0.4082482904638628f,
+          -0.3078822139579090f,
+          -0.3852501370925192f,
+          -0.0857401903551931f,
+          -0.4613374887461511f,
+          0.0000000000000000,
+          0.2191868483885747f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375376f,
+          -0.4067007583026072f,
+          -0.2125574805828705f,
+          0.0000000000000000,
+          -0.0643507165794627f,
+          -0.4517556589999464f,
+          0.3046847507248840f,
+          0.3017929516615503f,
+          -0.4082482904638635f,
+          -0.1747866975480813f,
+          0.2110560104933581f,
+          -0.1426608480880734f,
+          -0.1381354035075829f,
+          -0.1743760259965108f,
+          0.1135498731499426f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375377f,
+          -0.1957439937204287f,
+          -0.1621205195722833f,
+          0.0000000000000000,
+          -0.0643507165794628f,
+          0.0074182263792444f,
+          0.2904801297290076f,
+          0.0952002265347505f,
+          0.0000000000000000,
+          0.3675398009862011f,
+          -0.4921585901373891f,
+          0.2462710772207514f,
+          -0.0794670660591026f,
+          0.3623817333531165f,
+          -0.4351904965232251f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375375f,
+          0.0000000000000000,
+          -0.4706702258572528f,
+          0.0000000000000000,
+          -0.0643507165794627f,
+          0.1107416575309343f,
+          0.0000000000000000,
+          -0.1627234014286617f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.1488339922711357f,
+          0.4972464710953509f,
+          0.2921026642334879f,
+          0.5550443808910661f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375377f,
+          0.1137907446044809f,
+          -0.1464291867126764f,
+          0.0000000000000000,
+          -0.0643507165794628f,
+          0.0829816309488205f,
+          -0.2388977352334460f,
+          -0.3531238544981630f,
+          -0.4082482904638630f,
+          0.4826689115059883f,
+          0.1741941265991622f,
+          -0.0476868035022925f,
+          0.1253805944856366f,
+          -0.4326608024727445f,
+          -0.2546827712406646f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375377f,
+          -0.4444481661973438f,
+          0.3085497062849487f,
+          0.0000000000000000,
+          -0.0643507165794628f,
+          0.1585450355183970f,
+          -0.5112616136592012f,
+          0.2579236279634129f,
+          0.0000000000000000,
+          -0.0812611176717504f,
+          -0.1856718091610990f,
+          -0.3416446842253373f,
+          0.3302282550303805f,
+          0.0702790691196282f,
+          -0.0741750459581023f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375376f,
+          -0.2929100136981264f,
+          0.0000000000000000,
+          0.0000000000000000,
+          -0.0643507165794627f,
+          0.3935103426921022f,
+          0.0657870154914254f,
+          0.0000000000000000,
+          0.4082482904638634f,
+          0.3078822139579031f,
+          0.3852501370925211f,
+          -0.0857401903551927f,
+          -0.4613374887461554f,
+          0.0000000000000000,
+          0.2191868483885728f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375376f,
+          -0.1137907446044814f,
+          -0.1464291867126654f,
+          0.0000000000000000,
+          -0.0643507165794627f,
+          0.0829816309488214f,
+          0.2388977352334547f,
+          -0.3531238544981624f,
+          0.4082482904638630f,
+          -0.4826689115059858f,
+          -0.1741941265991621f,
+          -0.0476868035022928f,
+          0.1253805944856431f,
+          -0.4326608024727457f,
+          -0.2546827712406641f,
+      },
+      {
+          0.2500000000000000,
+          -0.1014005039375374f,
+          0.0000000000000000,
+          0.4251149611657548f,
+          0.0000000000000000,
+          -0.0643507165794626f,
+          -0.4517556589999480f,
+          0.0000000000000000,
+          -0.6035859033230976f,
+          0.0000000000000000,
+          0.0000000000000000,
+          0.0000000000000000,
+          -0.1426608480880724f,
+          -0.1381354035075845f,
+          0.3487520519930227f,
+          0.1135498731499429f,
+      },
+  };
+
+  const HWY_CAPPED(float, 16) d;
+  for (size_t i = 0; i < 16; i += Lanes(d)) {
+    auto scalar = Zero(d);
+    for (size_t j = 0; j < 16; j++) {
+      auto px = Set(d, pixels[j]);
+      auto basis = Load(d, k4x4AFVBasisTranspose[j] + i);
+      scalar = MulAdd(px, basis, scalar);
+    }
+    Store(scalar, d, coeffs + i);
+  }
+}
+
+// Coefficient layout:
+// - (even, even) positions hold AFV coefficients
+// - (odd, even) positions hold DCT4x4 coefficients
+// - (any, odd) positions hold DCT4x8 coefficients
+template <size_t afv_kind>
+void AFVTransformFromPixels(const float* JXL_RESTRICT pixels,
+                            size_t pixels_stride,
+                            float* JXL_RESTRICT coefficients) {
+  HWY_ALIGN float scratch_space[4 * 8 * 2];
+  size_t afv_x = afv_kind & 1;
+  size_t afv_y = afv_kind / 2;
+  HWY_ALIGN float block[4 * 8];
+  for (size_t iy = 0; iy < 4; iy++) {
+    for (size_t ix = 0; ix < 4; ix++) {
+      block[(afv_y == 1 ? 3 - iy : iy) * 4 + (afv_x == 1 ? 3 - ix : ix)] =
+          pixels[(iy + 4 * afv_y) * pixels_stride + ix + 4 * afv_x];
+    }
+  }
+  // AFV coefficients in (even, even) positions.
+  HWY_ALIGN float coeff[4 * 4];
+  AFVDCT4x4(block, coeff);
+  for (size_t iy = 0; iy < 4; iy++) {
+    for (size_t ix = 0; ix < 4; ix++) {
+      coefficients[iy * 2 * 8 + ix * 2] = coeff[iy * 4 + ix];
+    }
+  }
+  // 4x4 DCT of the block with same y and different x.
+  ComputeTransposedScaledDCT<4>()(
+      DCTFrom(pixels + afv_y * 4 * pixels_stride + (afv_x == 1 ? 0 : 4),
+              pixels_stride),
+      block, scratch_space);
+  // ... in (odd, even) positions.
+  for (size_t iy = 0; iy < 4; iy++) {
+    for (size_t ix = 0; ix < 8; ix++) {
+      coefficients[iy * 2 * 8 + ix * 2 + 1] = block[iy * 4 + ix];
+    }
+  }
+  // 4x8 DCT of the other half of the block.
+  ComputeScaledDCT<4, 8>()(
+      DCTFrom(pixels + (afv_y == 1 ? 0 : 4) * pixels_stride, pixels_stride),
+      block, scratch_space);
+  for (size_t iy = 0; iy < 4; iy++) {
+    for (size_t ix = 0; ix < 8; ix++) {
+      coefficients[(1 + iy * 2) * 8 + ix] = block[iy * 8 + ix];
+    }
+  }
+  float block00 = coefficients[0] * 0.25f;
+  float block01 = coefficients[1];
+  float block10 = coefficients[8];
+  coefficients[0] = (block00 + block01 + 2 * block10) * 0.25f;
+  coefficients[1] = (block00 - block01) * 0.5f;
+  coefficients[8] = (block00 + block01 - 2 * block10) * 0.25f;
+}
+
+HWY_MAYBE_UNUSED void TransformFromPixels(const AcStrategy::Type strategy,
+                                          const float* JXL_RESTRICT pixels,
+                                          size_t pixels_stride,
+                                          float* JXL_RESTRICT coefficients,
+                                          float* JXL_RESTRICT scratch_space) {
+  using Type = AcStrategy::Type;
+  switch (strategy) {
+    case Type::IDENTITY: {
+      PROFILER_ZONE("DCT Identity");
+      for (size_t y = 0; y < 2; y++) {
+        for (size_t x = 0; x < 2; x++) {
+          float block_dc = 0;
+          for (size_t iy = 0; iy < 4; iy++) {
+            for (size_t ix = 0; ix < 4; ix++) {
+              block_dc += pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix];
+            }
+          }
+          block_dc *= 1.0f / 16;
+          for (size_t iy = 0; iy < 4; iy++) {
+            for (size_t ix = 0; ix < 4; ix++) {
+              if (ix == 1 && iy == 1) continue;
+              coefficients[(y + iy * 2) * 8 + x + ix * 2] =
+                  pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix] -
+                  pixels[(y * 4 + 1) * pixels_stride + x * 4 + 1];
+            }
+          }
+          coefficients[(y + 2) * 8 + x + 2] = coefficients[y * 8 + x];
+          coefficients[y * 8 + x] = block_dc;
+        }
+      }
+      float block00 = coefficients[0];
+      float block01 = coefficients[1];
+      float block10 = coefficients[8];
+      float block11 = coefficients[9];
+      coefficients[0] = (block00 + block01 + block10 + block11) * 0.25f;
+      coefficients[1] = (block00 + block01 - block10 - block11) * 0.25f;
+      coefficients[8] = (block00 - block01 + block10 - block11) * 0.25f;
+      coefficients[9] = (block00 - block01 - block10 + block11) * 0.25f;
+      break;
+    }
+    case Type::DCT8X4: {
+      PROFILER_ZONE("DCT 8x4");
+      for (size_t x = 0; x < 2; x++) {
+        HWY_ALIGN float block[4 * 8];
+        ComputeScaledDCT<8, 4>()(DCTFrom(pixels + x * 4, pixels_stride), block,
+                                 scratch_space);
+        for (size_t iy = 0; iy < 4; iy++) {
+          for (size_t ix = 0; ix < 8; ix++) {
+            // Store transposed.
+            coefficients[(x + iy * 2) * 8 + ix] = block[iy * 8 + ix];
+          }
+        }
+      }
+      float block0 = coefficients[0];
+      float block1 = coefficients[8];
+      coefficients[0] = (block0 + block1) * 0.5f;
+      coefficients[8] = (block0 - block1) * 0.5f;
+      break;
+    }
+    case Type::DCT4X8: {
+      PROFILER_ZONE("DCT 4x8");
+      for (size_t y = 0; y < 2; y++) {
+        HWY_ALIGN float block[4 * 8];
+        ComputeScaledDCT<4, 8>()(
+            DCTFrom(pixels + y * 4 * pixels_stride, pixels_stride), block,
+            scratch_space);
+        for (size_t iy = 0; iy < 4; iy++) {
+          for (size_t ix = 0; ix < 8; ix++) {
+            coefficients[(y + iy * 2) * 8 + ix] = block[iy * 8 + ix];
+          }
+        }
+      }
+      float block0 = coefficients[0];
+      float block1 = coefficients[8];
+      coefficients[0] = (block0 + block1) * 0.5f;
+      coefficients[8] = (block0 - block1) * 0.5f;
+      break;
+    }
+    case Type::DCT4X4: {
+      PROFILER_ZONE("DCT 4");
+      for (size_t y = 0; y < 2; y++) {
+        for (size_t x = 0; x < 2; x++) {
+          HWY_ALIGN float block[4 * 4];
+          ComputeTransposedScaledDCT<4>()(
+              DCTFrom(pixels + y * 4 * pixels_stride + x * 4, pixels_stride),
+              block, scratch_space);
+          for (size_t iy = 0; iy < 4; iy++) {
+            for (size_t ix = 0; ix < 4; ix++) {
+              coefficients[(y + iy * 2) * 8 + x + ix * 2] = block[iy * 4 + ix];
+            }
+          }
+        }
+      }
+      float block00 = coefficients[0];
+      float block01 = coefficients[1];
+      float block10 = coefficients[8];
+      float block11 = coefficients[9];
+      coefficients[0] = (block00 + block01 + block10 + block11) * 0.25f;
+      coefficients[1] = (block00 + block01 - block10 - block11) * 0.25f;
+      coefficients[8] = (block00 - block01 + block10 - block11) * 0.25f;
+      coefficients[9] = (block00 - block01 - block10 + block11) * 0.25f;
+      break;
+    }
+    case Type::DCT2X2: {
+      PROFILER_ZONE("DCT 2");
+      DCT2TopBlock<8>(pixels, pixels_stride, coefficients);
+      DCT2TopBlock<4>(coefficients, kBlockDim, coefficients);
+      DCT2TopBlock<2>(coefficients, kBlockDim, coefficients);
+      break;
+    }
+    case Type::DCT16X16: {
+      PROFILER_ZONE("DCT 16");
+      ComputeTransposedScaledDCT<16>()(DCTFrom(pixels, pixels_stride),
+                                       coefficients, scratch_space);
+      break;
+    }
+    case Type::DCT16X8: {
+      PROFILER_ZONE("DCT 16x8");
+      ComputeScaledDCT<16, 8>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                scratch_space);
+      break;
+    }
+    case Type::DCT8X16: {
+      PROFILER_ZONE("DCT 8x16");
+      ComputeScaledDCT<8, 16>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                scratch_space);
+      break;
+    }
+    case Type::DCT32X8: {
+      PROFILER_ZONE("DCT 32x8");
+      ComputeScaledDCT<32, 8>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                scratch_space);
+      break;
+    }
+    case Type::DCT8X32: {
+      PROFILER_ZONE("DCT 8x32");
+      ComputeScaledDCT<8, 32>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                scratch_space);
+      break;
+    }
+    case Type::DCT32X16: {
+      PROFILER_ZONE("DCT 32x16");
+      ComputeScaledDCT<32, 16>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                 scratch_space);
+      break;
+    }
+    case Type::DCT16X32: {
+      PROFILER_ZONE("DCT 16x32");
+      ComputeScaledDCT<16, 32>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                 scratch_space);
+      break;
+    }
+    case Type::DCT32X32: {
+      PROFILER_ZONE("DCT 32");
+      ComputeTransposedScaledDCT<32>()(DCTFrom(pixels, pixels_stride),
+                                       coefficients, scratch_space);
+      break;
+    }
+    case Type::DCT: {
+      PROFILER_ZONE("DCT 8");
+      ComputeTransposedScaledDCT<8>()(DCTFrom(pixels, pixels_stride),
+                                      coefficients, scratch_space);
+      break;
+    }
+    case Type::AFV0: {
+      PROFILER_ZONE("AFV0");
+      AFVTransformFromPixels<0>(pixels, pixels_stride, coefficients);
+      break;
+    }
+    case Type::AFV1: {
+      PROFILER_ZONE("AFV1");
+      AFVTransformFromPixels<1>(pixels, pixels_stride, coefficients);
+      break;
+    }
+    case Type::AFV2: {
+      PROFILER_ZONE("AFV2");
+      AFVTransformFromPixels<2>(pixels, pixels_stride, coefficients);
+      break;
+    }
+    case Type::AFV3: {
+      PROFILER_ZONE("AFV3");
+      AFVTransformFromPixels<3>(pixels, pixels_stride, coefficients);
+      break;
+    }
+    case Type::DCT64X64: {
+      PROFILER_ZONE("DCT 64x64");
+      ComputeTransposedScaledDCT<64>()(DCTFrom(pixels, pixels_stride),
+                                       coefficients, scratch_space);
+      break;
+    }
+    case Type::DCT64X32: {
+      PROFILER_ZONE("DCT 64x32");
+      ComputeScaledDCT<64, 32>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                 scratch_space);
+      break;
+    }
+    case Type::DCT32X64: {
+      PROFILER_ZONE("DCT 32x64");
+      ComputeScaledDCT<32, 64>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                 scratch_space);
+      break;
+    }
+    case Type::DCT128X128: {
+      PROFILER_ZONE("DCT 128x128");
+      ComputeTransposedScaledDCT<128>()(DCTFrom(pixels, pixels_stride),
+                                        coefficients, scratch_space);
+      break;
+    }
+    case Type::DCT128X64: {
+      PROFILER_ZONE("DCT 128x64");
+      ComputeScaledDCT<128, 64>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                  scratch_space);
+      break;
+    }
+    case Type::DCT64X128: {
+      PROFILER_ZONE("DCT 64x128");
+      ComputeScaledDCT<64, 128>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                  scratch_space);
+      break;
+    }
+    case Type::DCT256X256: {
+      PROFILER_ZONE("DCT 256x256");
+      ComputeTransposedScaledDCT<256>()(DCTFrom(pixels, pixels_stride),
+                                        coefficients, scratch_space);
+      break;
+    }
+    case Type::DCT256X128: {
+      PROFILER_ZONE("DCT 256x128");
+      ComputeScaledDCT<256, 128>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                   scratch_space);
+      break;
+    }
+    case Type::DCT128X256: {
+      PROFILER_ZONE("DCT 128x256");
+      ComputeScaledDCT<128, 256>()(DCTFrom(pixels, pixels_stride), coefficients,
+                                   scratch_space);
+      break;
+    }
+    case Type::kNumValidStrategies:
+      JXL_ABORT("Invalid strategy");
+  }
+}
+
+HWY_MAYBE_UNUSED void DCFromLowestFrequencies(const AcStrategy::Type strategy,
+                                              const float* block, float* dc,
+                                              size_t dc_stride) {
+  using Type = AcStrategy::Type;
+  switch (strategy) {
+    case Type::DCT16X8: {
+      ReinterpretingIDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/kBlockDim,
+                         /*LF_ROWS=*/2, /*LF_COLS=*/1, /*ROWS=*/2, /*COLS=*/1>(
+          block, 2 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT8X16: {
+      ReinterpretingIDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/2 * kBlockDim,
+                         /*LF_ROWS=*/1, /*LF_COLS=*/2, /*ROWS=*/1, /*COLS=*/2>(
+          block, 2 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT16X16: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim,
+          /*LF_ROWS=*/2, /*LF_COLS=*/2, /*ROWS=*/2, /*COLS=*/2>(
+          block, 2 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT32X8: {
+      ReinterpretingIDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/kBlockDim,
+                         /*LF_ROWS=*/4, /*LF_COLS=*/1, /*ROWS=*/4, /*COLS=*/1>(
+          block, 4 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT8X32: {
+      ReinterpretingIDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+                         /*LF_ROWS=*/1, /*LF_COLS=*/4, /*ROWS=*/1, /*COLS=*/4>(
+          block, 4 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT32X16: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim,
+          /*LF_ROWS=*/4, /*LF_COLS=*/2, /*ROWS=*/4, /*COLS=*/2>(
+          block, 4 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT16X32: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+          /*LF_ROWS=*/2, /*LF_COLS=*/4, /*ROWS=*/2, /*COLS=*/4>(
+          block, 4 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT32X32: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+          /*LF_ROWS=*/4, /*LF_COLS=*/4, /*ROWS=*/4, /*COLS=*/4>(
+          block, 4 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT64X32: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim,
+          /*LF_ROWS=*/8, /*LF_COLS=*/4, /*ROWS=*/8, /*COLS=*/4>(
+          block, 8 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT32X64: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim,
+          /*LF_ROWS=*/4, /*LF_COLS=*/8, /*ROWS=*/4, /*COLS=*/8>(
+          block, 8 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT64X64: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim,
+          /*LF_ROWS=*/8, /*LF_COLS=*/8, /*ROWS=*/8, /*COLS=*/8>(
+          block, 8 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT128X64: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim,
+          /*LF_ROWS=*/16, /*LF_COLS=*/8, /*ROWS=*/16, /*COLS=*/8>(
+          block, 16 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT64X128: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim,
+          /*LF_ROWS=*/8, /*LF_COLS=*/16, /*ROWS=*/8, /*COLS=*/16>(
+          block, 16 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT128X128: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim,
+          /*LF_ROWS=*/16, /*LF_COLS=*/16, /*ROWS=*/16, /*COLS=*/16>(
+          block, 16 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT256X128: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim,
+          /*LF_ROWS=*/32, /*LF_COLS=*/16, /*ROWS=*/32, /*COLS=*/16>(
+          block, 32 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT128X256: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim,
+          /*LF_ROWS=*/16, /*LF_COLS=*/32, /*ROWS=*/16, /*COLS=*/32>(
+          block, 32 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT256X256: {
+      ReinterpretingIDCT<
+          /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim,
+          /*LF_ROWS=*/32, /*LF_COLS=*/32, /*ROWS=*/32, /*COLS=*/32>(
+          block, 32 * kBlockDim, dc, dc_stride);
+      break;
+    }
+    case Type::DCT:
+    case Type::DCT2X2:
+    case Type::DCT4X4:
+    case Type::DCT4X8:
+    case Type::DCT8X4:
+    case Type::AFV0:
+    case Type::AFV1:
+    case Type::AFV2:
+    case Type::AFV3:
+    case Type::IDENTITY:
+      dc[0] = block[0];
+      break;
+    case Type::kNumValidStrategies:
+      JXL_ABORT("Invalid strategy");
+  }
+}
+
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // LIB_JXL_ENC_TRANSFORMS_INL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_transforms.cc b/third_party/jpeg-xl/lib/jxl/enc_transforms.cc
new file mode 100644
index 000000000000..6b7c302a4e18
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_transforms.cc
@@ -0,0 +1,50 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_transforms.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_transforms.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/enc_transforms-inl.h"
+
+namespace jxl {
+
+#if HWY_ONCE
+HWY_EXPORT(TransformFromPixels);
+void TransformFromPixels(const AcStrategy::Type strategy,
+                         const float* JXL_RESTRICT pixels, size_t pixels_stride,
+                         float* JXL_RESTRICT coefficients,
+                         float* scratch_space) {
+  return HWY_DYNAMIC_DISPATCH(TransformFromPixels)(
+      strategy, pixels, pixels_stride, coefficients, scratch_space);
+}
+
+HWY_EXPORT(DCFromLowestFrequencies);
+void DCFromLowestFrequencies(AcStrategy::Type strategy, const float* block,
+                             float* dc, size_t dc_stride) {
+  return HWY_DYNAMIC_DISPATCH(DCFromLowestFrequencies)(strategy, block, dc,
+                                                       dc_stride);
+}
+
+HWY_EXPORT(AFVDCT4x4);
+void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs) {
+  return HWY_DYNAMIC_DISPATCH(AFVDCT4x4)(pixels, coeffs);
+}
+#endif  // HWY_ONCE
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/enc_transforms.h b/third_party/jpeg-xl/lib/jxl/enc_transforms.h
new file mode 100644
index 000000000000..1cf0bec828e4
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_transforms.h
@@ -0,0 +1,41 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_TRANSFORMS_H_
+#define LIB_JXL_ENC_TRANSFORMS_H_
+
+// Facade for (non-inlined) integral transforms.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/base/compiler_specific.h"
+
+namespace jxl {
+
+void TransformFromPixels(const AcStrategy::Type strategy,
+                         const float* JXL_RESTRICT pixels, size_t pixels_stride,
+                         float* JXL_RESTRICT coefficients,
+                         float* JXL_RESTRICT scratch_space);
+
+// Equivalent of the above for DC image.
+void DCFromLowestFrequencies(AcStrategy::Type strategy, const float* block,
+                             float* dc, size_t dc_stride);
+
+void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_TRANSFORMS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/enc_xyb.cc b/third_party/jpeg-xl/lib/jxl/enc_xyb.cc
new file mode 100644
index 000000000000..0c3e2fe378ab
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_xyb.cc
@@ -0,0 +1,446 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/enc_xyb.h"
+
+#include <algorithm>
+#include <cstdlib>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_xyb.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/transfer_functions-inl.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::ShiftRight;
+
+// Returns cbrt(x) + add with 6 ulp max error.
+// Modified from vectormath_exp.h, Apache 2 license.
+// https://www.agner.org/optimize/vectorclass.zip
+template <class V>
+V CubeRootAndAdd(const V x, const V add) {
+  const HWY_FULL(float) df;
+  const HWY_FULL(int32_t) di;
+
+  const auto kExpBias = Set(di, 0x54800000);  // cast(1.) + cast(1.) / 3
+  const auto kExpMul = Set(di, 0x002AAAAA);   // shifted 1/3
+  const auto k1_3 = Set(df, 1.0f / 3);
+  const auto k4_3 = Set(df, 4.0f / 3);
+
+  const auto xa = x;  // assume inputs never negative
+  const auto xa_3 = k1_3 * xa;
+
+  // Multiply exponent by -1/3
+  const auto m1 = BitCast(di, xa);
+  // Special case for 0. 0 is represented with an exponent of 0, so the
+  // "kExpBias - 1/3 * exp" below gives the wrong result. The IfThenZeroElse()
+  // sets those values as 0, which prevents having NaNs in the computations
+  // below.
+  const auto m2 =
+      IfThenZeroElse(m1 == Zero(di), kExpBias - (ShiftRight<23>(m1)) * kExpMul);
+  auto r = BitCast(df, m2);
+
+  // Newton-Raphson iterations
+  for (int i = 0; i < 3; i++) {
+    const auto r2 = r * r;
+    r = NegMulAdd(xa_3, r2 * r2, k4_3 * r);
+  }
+  // Final iteration
+  auto r2 = r * r;
+  r = MulAdd(k1_3, NegMulAdd(xa, r2 * r2, r), r);
+  r2 = r * r;
+  r = MulAdd(r2, x, add);
+
+  return r;
+}
+
+// Ensures infinity norm is bounded.
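+// Sweeps x over [0, 20) in 1E-5 steps, checks that all SIMD lanes agree, and
+// asserts that the worst-case error versus cbrtf stays below 8E-7.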
+void TestCubeRoot() {
+  const HWY_FULL(float) d;
+  float max_err = 0.0f;
+  for (uint64_t x5 = 0; x5 < 2000000; x5++) {
+    const float x = x5 * 1E-5f;
+    const float expected = cbrtf(x);
+    HWY_ALIGN float approx[MaxLanes(d)];
+    Store(CubeRootAndAdd(Set(d, x), Zero(d)), d, approx);
+
+    // All lanes are same
+    for (size_t i = 1; i < Lanes(d); ++i) {
+      JXL_ASSERT(std::abs(approx[0] - approx[i]) <= 1.2E-7f);
+    }
+
+    const float err = std::abs(approx[0] - expected);
+    max_err = std::max(max_err, err);
+  }
+  // printf("max err %e\n", max_err);
+  JXL_ASSERT(max_err < 8E-7f);
+}
+
+// 4x3 matrix * 3x1 SIMD vectors
+template <class V>
+JXL_INLINE void OpsinAbsorbance(const V r, const V g, const V b,
+                                const float* JXL_RESTRICT premul_absorb,
+                                V* JXL_RESTRICT mixed0, V* JXL_RESTRICT mixed1,
+                                V* JXL_RESTRICT mixed2) {
+  const float* bias = &kOpsinAbsorbanceBias[0];
+  const HWY_FULL(float) d;
+  const size_t N = Lanes(d);
+  const auto m0 = Load(d, premul_absorb + 0 * N);
+  const auto m1 = Load(d, premul_absorb + 1 * N);
+  const auto m2 = Load(d, premul_absorb + 2 * N);
+  const auto m3 = Load(d, premul_absorb + 3 * N);
+  const auto m4 = Load(d, premul_absorb + 4 * N);
+  const auto m5 = Load(d, premul_absorb + 5 * N);
+  const auto m6 = Load(d, premul_absorb + 6 * N);
+  const auto m7 = Load(d, premul_absorb + 7 * N);
+  const auto m8 = Load(d, premul_absorb + 8 * N);
+  *mixed0 = MulAdd(m0, r, MulAdd(m1, g, MulAdd(m2, b, Set(d, bias[0]))));
+  *mixed1 = MulAdd(m3, r, MulAdd(m4, g, MulAdd(m5, b, Set(d, bias[1]))));
+  *mixed2 = MulAdd(m6, r, MulAdd(m7, g, MulAdd(m8, b, Set(d, bias[2]))));
+}
+
+template <class V>
+void StoreXYB(const V r, V g, const V b, float* JXL_RESTRICT valx,
+              float* JXL_RESTRICT valy, float* JXL_RESTRICT valz) {
+  const HWY_FULL(float) d;
+  const V half = Set(d, 0.5f);
+  Store(half * (r - g), d, valx);
+  Store(half * (r + g), d, valy);
+  Store(b, d, valz);
+}
+
+// Converts one RGB vector to XYB.
+template <class V>
+void LinearRGBToXYB(const V r, const V g, const V b,
+                    const float* JXL_RESTRICT premul_absorb,
+                    float* JXL_RESTRICT valx, float* JXL_RESTRICT valy,
+                    float* JXL_RESTRICT valz) {
+  V mixed0, mixed1, mixed2;
+  OpsinAbsorbance(r, g, b, premul_absorb, &mixed0, &mixed1, &mixed2);
+
+  // mixed* should be non-negative even for wide-gamut, so clamp to zero.
+  mixed0 = ZeroIfNegative(mixed0);
+  mixed1 = ZeroIfNegative(mixed1);
+  mixed2 = ZeroIfNegative(mixed2);
+
+  const HWY_FULL(float) d;
+  const size_t N = Lanes(d);
+  mixed0 = CubeRootAndAdd(mixed0, Load(d, premul_absorb + 9 * N));
+  mixed1 = CubeRootAndAdd(mixed1, Load(d, premul_absorb + 10 * N));
+  mixed2 = CubeRootAndAdd(mixed2, Load(d, premul_absorb + 11 * N));
+  StoreXYB(mixed0, mixed1, mixed2, valx, valy, valz);
+
+  // For wide-gamut inputs, r/g/b and valx (but not y/z) are often negative.
+}
+
+// Input/output uses the codec.h scaling: nominally 0-1 if in-gamut.
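+// Decodes sRGB-encoded samples back to linear light via the shared TF_SRGB
+// SIMD helper (the inverse of the encoding transfer function).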
+template <class V>
+V LinearFromSRGB(V encoded) {
+  return TF_SRGB().DisplayFromEncoded(encoded);
+}
+
+void LinearSRGBToXYB(const Image3F& linear,
+                     const float* JXL_RESTRICT premul_absorb, ThreadPool* pool,
+                     Image3F* JXL_RESTRICT xyb) {
+  const size_t xsize = linear.xsize();
+
+  const HWY_FULL(float) d;
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(linear.ysize()), ThreadPool::SkipInit(),
+      [&](const int task, const int /*thread*/) {
+        const size_t y = static_cast<size_t>(task);
+        const float* JXL_RESTRICT row_in0 = linear.ConstPlaneRow(0, y);
+        const float* JXL_RESTRICT row_in1 = linear.ConstPlaneRow(1, y);
+        const float* JXL_RESTRICT row_in2 = linear.ConstPlaneRow(2, y);
+        float* JXL_RESTRICT row_xyb0 = xyb->PlaneRow(0, y);
+        float* JXL_RESTRICT row_xyb1 = xyb->PlaneRow(1, y);
+        float* JXL_RESTRICT row_xyb2 = xyb->PlaneRow(2, y);
+
+        for (size_t x = 0; x < xsize; x += Lanes(d)) {
+          const auto in_r = Load(d, row_in0 + x);
+          const auto in_g = Load(d, row_in1 + x);
+          const auto in_b = Load(d, row_in2 + x);
+          LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row_xyb0 + x,
+                         row_xyb1 + x, row_xyb2 + x);
+        }
+      },
+      "LinearToXYB");
+}
+
+void SRGBToXYB(const Image3F& srgb, const float* JXL_RESTRICT premul_absorb,
+               ThreadPool* pool, Image3F* JXL_RESTRICT xyb) {
+  const size_t xsize = srgb.xsize();
+
+  const HWY_FULL(float) d;
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(srgb.ysize()), ThreadPool::SkipInit(),
+      [&](const int task, const int /*thread*/) {
+        const size_t y = static_cast<size_t>(task);
+        const float* JXL_RESTRICT row_srgb0 = srgb.ConstPlaneRow(0, y);
+        const float* JXL_RESTRICT row_srgb1 = srgb.ConstPlaneRow(1, y);
+        const float* JXL_RESTRICT row_srgb2 = srgb.ConstPlaneRow(2, y);
+        float* JXL_RESTRICT row_xyb0 = xyb->PlaneRow(0, y);
+        float* JXL_RESTRICT row_xyb1 = xyb->PlaneRow(1, y);
+        float* JXL_RESTRICT row_xyb2 = xyb->PlaneRow(2, y);
+
+        for (size_t x = 0; x < xsize; x += Lanes(d)) {
+          const auto in_r = LinearFromSRGB(Load(d, row_srgb0 + x));
+          const auto in_g = LinearFromSRGB(Load(d, row_srgb1 + x));
+          const auto in_b = LinearFromSRGB(Load(d, row_srgb2 + x));
+          LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row_xyb0 + x,
+                         row_xyb1 + x, row_xyb2 + x);
+        }
+      },
+      "SRGBToXYB");
+}
+
+void SRGBToXYBAndLinear(const Image3F& srgb,
+                        const float* JXL_RESTRICT premul_absorb,
+                        ThreadPool* pool, Image3F* JXL_RESTRICT xyb,
+                        Image3F* JXL_RESTRICT linear) {
+  const size_t xsize = srgb.xsize();
+
+  const HWY_FULL(float) d;
+  RunOnPool(
+      pool, 0, static_cast<uint32_t>(srgb.ysize()), ThreadPool::SkipInit(),
+      [&](const int task, const int /*thread*/) {
+        const size_t y = static_cast<size_t>(task);
+        const float* JXL_RESTRICT row_srgb0 = srgb.ConstPlaneRow(0, y);
+        const float* JXL_RESTRICT row_srgb1 = srgb.ConstPlaneRow(1, y);
+        const float* JXL_RESTRICT row_srgb2 = srgb.ConstPlaneRow(2, y);
+
+        float* JXL_RESTRICT row_linear0 = linear->PlaneRow(0, y);
+        float* JXL_RESTRICT row_linear1 = linear->PlaneRow(1, y);
+        float* JXL_RESTRICT row_linear2 = linear->PlaneRow(2, y);
+
+        float* JXL_RESTRICT row_xyb0 = xyb->PlaneRow(0, y);
+        float* JXL_RESTRICT row_xyb1 = xyb->PlaneRow(1, y);
+        float* JXL_RESTRICT row_xyb2 = xyb->PlaneRow(2, y);
+
+        for (size_t x = 0; x < xsize; x += Lanes(d)) {
+          const auto in_r = LinearFromSRGB(Load(d, row_srgb0 + x));
+          const auto in_g = LinearFromSRGB(Load(d, row_srgb1 + x));
+          const auto in_b = LinearFromSRGB(Load(d, row_srgb2 + x));
+
+          Store(in_r, d, row_linear0 + x);
+          Store(in_g, d, row_linear1 + x);
+          Store(in_b, d, row_linear2 + x);
+
+          LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row_xyb0 + x,
+                         row_xyb1 + x, row_xyb2 + x);
+        }
+      },
+      "SRGBToXYBAndLinear");
+}
+
+// This is different from Butteraugli's OpsinDynamicsImage() in the sense that
+// it does not contain a sensitivity multiplier based on the blurred image.
+const ImageBundle* ToXYB(const ImageBundle& in, ThreadPool* pool,
+                         Image3F* JXL_RESTRICT xyb,
+                         ImageBundle* const JXL_RESTRICT linear) {
+  PROFILER_FUNC;
+
+  const size_t xsize = in.xsize();
+  const size_t ysize = in.ysize();
+  JXL_ASSERT(SameSize(in, *xyb));
+
+  const HWY_FULL(float) d;
+  // Pre-broadcasted constants
+  HWY_ALIGN float premul_absorb[MaxLanes(d) * 12];
+  const size_t N = Lanes(d);
+  for (size_t i = 0; i < 9; ++i) {
+    const auto absorb = Set(d, kOpsinAbsorbanceMatrix[i] *
+                                   (in.metadata()->IntensityTarget() / 255.0f));
+    Store(absorb, d, premul_absorb + i * N);
+  }
+  for (size_t i = 0; i < 3; ++i) {
+    const auto neg_bias_cbrt = Set(d, -cbrtf(kOpsinAbsorbanceBias[i]));
+    Store(neg_bias_cbrt, d, premul_absorb + (9 + i) * N);
+  }
+
+  const bool want_linear = linear != nullptr;
+
+  const ColorEncoding& c_linear_srgb = ColorEncoding::LinearSRGB(in.IsGray());
+  // Linear sRGB inputs are rare but can be useful for the fastest encoders,
+  // for which undoing the sRGB transfer function would be a large part of
+  // the cost.
+  if (c_linear_srgb.SameColorEncoding(in.c_current())) {
+    LinearSRGBToXYB(in.color(), premul_absorb, pool, xyb);
+    // This only happens if kitten or slower, moving ImageBundle might be
+    // possible but the encoder is much slower than this copy.
+    if (want_linear) {
+      *linear = in.Copy();
+      return linear;
+    }
+    return &in;
+  }
+
+  // Common case: already sRGB, can avoid the color transform
+  if (in.IsSRGB()) {
+    // Common case: can avoid allocating/copying
+    if (!want_linear) {
+      SRGBToXYB(in.color(), premul_absorb, pool, xyb);
+      return &in;
+    }
+
+    // Slow encoder also wants linear sRGB.
+    linear->SetFromImage(Image3F(xsize, ysize), c_linear_srgb);
+    SRGBToXYBAndLinear(in.color(), premul_absorb, pool, xyb, linear->color());
+    return linear;
+  }
+
+  // General case: not sRGB, need color transform.
+  ImageBundle linear_storage;  // Local storage only used if !want_linear.
+
+  ImageBundle* linear_storage_ptr;
+  if (want_linear) {
+    // Caller asked for linear, use that storage directly.
+    linear_storage_ptr = linear;
+  } else {
+    // Caller didn't ask for linear, create our own local storage
+    // OK to reuse metadata, it will not be changed.
+    linear_storage = ImageBundle(const_cast<ImageMetadata*>(in.metadata()));
+    linear_storage_ptr = &linear_storage;
+  }
+
+  const ImageBundle* ptr;
+  JXL_CHECK(
+      TransformIfNeeded(in, c_linear_srgb, pool, linear_storage_ptr, &ptr));
+  // If no transform was necessary, should have taken the above codepath.
+  JXL_ASSERT(ptr == linear_storage_ptr);
+
+  LinearSRGBToXYB(*linear_storage_ptr->color(), premul_absorb, pool, xyb);
+  return want_linear ? linear : &in;
+}
+
+// Transform RGB to YCbCr.
+// Could be performed in-place (i.e. Y, Cb and Cr could alias R, G and B).
+void RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane,
+                const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane,
+                ImageF* cr_plane, ThreadPool* pool) {
+  const HWY_FULL(float) df;
+  const size_t S = Lanes(df);  // Step.
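+  // Note: kAmpR + kR == 1 and kAmpB + kB == 1 below, so r * kDiffR - y_base
+  // is R - Y and b * kDiffB - y_base is B - Y; kNormR and kNormB then rescale
+  // those differences to the nominal full-range Cr and Cb intervals.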
+
+  const size_t xsize = r_plane.xsize();
+  const size_t ysize = r_plane.ysize();
+  if ((xsize == 0) || (ysize == 0)) return;
+
+  // Full-range BT.601 as defined by JFIF Clause 7:
+  // https://www.itu.int/rec/T-REC-T.871-201105-I/en
+  const auto k128 = Set(df, 128.0f / 255);
+  const auto kR = Set(df, 0.299f);  // NTSC luma
+  const auto kG = Set(df, 0.587f);
+  const auto kB = Set(df, 0.114f);
+  const auto kAmpR = Set(df, 0.701f);
+  const auto kAmpB = Set(df, 0.886f);
+  const auto kDiffR = kAmpR + kR;
+  const auto kDiffB = kAmpB + kB;
+  const auto kNormR = Set(df, 1.0f) / (kAmpR + kG + kB);
+  const auto kNormB = Set(df, 1.0f) / (kR + kG + kAmpB);
+
+  constexpr size_t kGroupArea = kGroupDim * kGroupDim;
+  const size_t lines_per_group = DivCeil(kGroupArea, xsize);
+  const size_t num_stripes = DivCeil(ysize, lines_per_group);
+  const auto transform = [&](int idx, int /* thread*/) {
+    const size_t y0 = idx * lines_per_group;
+    const size_t y1 = std::min(y0 + lines_per_group, ysize);
+    for (size_t y = y0; y < y1; ++y) {
+      const float* r_row = r_plane.ConstRow(y);
+      const float* g_row = g_plane.ConstRow(y);
+      const float* b_row = b_plane.ConstRow(y);
+      float* y_row = y_plane->Row(y);
+      float* cb_row = cb_plane->Row(y);
+      float* cr_row = cr_plane->Row(y);
+      for (size_t x = 0; x < xsize; x += S) {
+        const auto r = Load(df, r_row + x);
+        const auto g = Load(df, g_row + x);
+        const auto b = Load(df, b_row + x);
+        const auto r_base = r * kR;
+        const auto r_diff = r * kDiffR;
+        const auto g_base = g * kG;
+        const auto b_base = b * kB;
+        const auto b_diff = b * kDiffB;
+        const auto y_base = r_base + g_base + b_base;
+        const auto y_vec = y_base - k128;
+        const auto cb_vec = (b_diff - y_base) * kNormB;
+        const auto cr_vec = (r_diff - y_base) * kNormR;
+        Store(y_vec, df, y_row + x);
+        Store(cb_vec, df, cb_row + x);
+        Store(cr_vec, df, cr_row + x);
+      }
+    }
+  };
+  RunOnPool(pool, 0, static_cast<uint32_t>(num_stripes),
+            ThreadPool::SkipInit(), transform, "RgbToYcbCr");
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(ToXYB);
+const ImageBundle* ToXYB(const ImageBundle& in, ThreadPool* pool,
+                         Image3F* JXL_RESTRICT xyb,
+                         ImageBundle* JXL_RESTRICT linear_storage) {
+  return HWY_DYNAMIC_DISPATCH(ToXYB)(in, pool, xyb, linear_storage);
+}
+
+HWY_EXPORT(RgbToYcbcr);
+void RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane,
+                const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane,
+                ImageF* cr_plane, ThreadPool* pool) {
+  return HWY_DYNAMIC_DISPATCH(RgbToYcbcr)(r_plane, g_plane, b_plane, y_plane,
+                                          cb_plane, cr_plane, pool);
+}
+
+HWY_EXPORT(TestCubeRoot);
+void TestCubeRoot() { return HWY_DYNAMIC_DISPATCH(TestCubeRoot)(); }
+
+// DEPRECATED
+Image3F OpsinDynamicsImage(const Image3B& srgb8) {
+  ImageMetadata metadata;
+  metadata.SetUintSamples(8);
+  metadata.color_encoding = ColorEncoding::SRGB();
+  ImageBundle ib(&metadata);
+  ib.SetFromImage(ConvertToFloat(srgb8), metadata.color_encoding);
+  JXL_CHECK(ib.TransformTo(ColorEncoding::LinearSRGB(ib.IsGray())));
+  ThreadPool* null_pool = nullptr;
+  Image3F xyb(srgb8.xsize(), srgb8.ysize());
+
+  ImageBundle linear_storage(&metadata);
+  (void)ToXYB(ib, null_pool, &xyb, &linear_storage);
+  return xyb;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/enc_xyb.h b/third_party/jpeg-xl/lib/jxl/enc_xyb.h
new file mode 100644
index 000000000000..711119ea6c44
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_xyb.h
@@ -0,0 +1,54 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_ENC_XYB_H_
+#define LIB_JXL_ENC_XYB_H_
+
+// Converts to XYB color space.
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+
+namespace jxl {
+
+// Converts any color space to XYB. If `linear` is not null, returns `linear`
+// after filling it with a linear sRGB copy of `in`. Otherwise, returns `&in`.
+//
+// NOTE this return value can avoid an extra color conversion if `in` would
+// later be passed to JxlButteraugliComparator.
+const ImageBundle* ToXYB(const ImageBundle& in, ThreadPool* pool,
+                         Image3F* JXL_RESTRICT xyb,
+                         ImageBundle* JXL_RESTRICT linear = nullptr);
+
+// Bt.601 to match JPEG/JFIF. Outputs _signed_ YCbCr values suitable for DCT,
+// see F.1.1.3 of T.81 (because our data type is float, there is no need to add
+// a bias to make the values unsigned).
+void RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane,
+                const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane,
+                ImageF* cr_plane, ThreadPool* pool);
+
+// DEPRECATED, used by opsin_image_wrapper.
+Image3F OpsinDynamicsImage(const Image3B& srgb8);
+
+// For opsin_image_test.
+void TestCubeRoot();
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENC_XYB_H_
diff --git a/third_party/jpeg-xl/lib/jxl/encode.cc b/third_party/jpeg-xl/lib/jxl/encode.cc
new file mode 100644
index 000000000000..857ea32c012c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/encode.cc
@@ -0,0 +1,497 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "jxl/encode.h"
+
+#include <algorithm>
+#include <cstring>
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/enc_external_image.h"
+#include "lib/jxl/enc_file.h"
+#include "lib/jxl/enc_icc_codec.h"
+#include "lib/jxl/encode_internal.h"
+#include "lib/jxl/jpeg/enc_jpeg_data.h"
+
+// Debug-printing failure macro similar to JXL_FAILURE, but for the status code
+// JXL_ENC_ERROR
+#ifdef JXL_CRASH_ON_ERROR
+#define JXL_API_ERROR(format, ...)                                           \
+  (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \
+   ::jxl::Abort(), JXL_ENC_ERROR)
+#else  // JXL_CRASH_ON_ERROR
+#define JXL_API_ERROR(format, ...)                                             \
+  (((JXL_DEBUG_ON_ERROR) &&                                                    \
+    ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__)), \
+   JXL_ENC_ERROR)
+#endif  // JXL_CRASH_ON_ERROR
+
+namespace jxl {
+
+Status ConvertExternalToInternalColorEncoding(const JxlColorEncoding& external,
+                                              ColorEncoding* internal) {
+  internal->SetColorSpace(static_cast<ColorSpace>(external.color_space));
+
+  CIExy wp;
+  wp.x = external.white_point_xy[0];
+  wp.y = external.white_point_xy[1];
+  JXL_RETURN_IF_ERROR(internal->SetWhitePoint(wp));
+
+  if (external.color_space == JXL_COLOR_SPACE_RGB ||
+      external.color_space == JXL_COLOR_SPACE_UNKNOWN) {
+    internal->primaries = static_cast<Primaries>(external.primaries);
+    PrimariesCIExy primaries;
+    primaries.r.x = external.primaries_red_xy[0];
+    primaries.r.y = external.primaries_red_xy[1];
+    primaries.g.x = external.primaries_green_xy[0];
+    primaries.g.y = external.primaries_green_xy[1];
+    primaries.b.x = external.primaries_blue_xy[0];
+    primaries.b.y = external.primaries_blue_xy[1];
+    JXL_RETURN_IF_ERROR(internal->SetPrimaries(primaries));
+  }
+  CustomTransferFunction tf;
+  if (external.transfer_function == JXL_TRANSFER_FUNCTION_GAMMA) {
+    JXL_RETURN_IF_ERROR(tf.SetGamma(external.gamma));
+  } else {
+    tf.SetTransferFunction(
+        static_cast<TransferFunction>(external.transfer_function));
+  }
+  internal->tf = tf;
+
+  internal->rendering_intent =
+      static_cast<RenderingIntent>(external.rendering_intent);
+
+  return true;
+}
+
+}  // namespace jxl
+
+uint32_t JxlEncoderVersion(void) {
+  return JPEGXL_MAJOR_VERSION * 1000000 + JPEGXL_MINOR_VERSION * 1000 +
+         JPEGXL_PATCH_VERSION;
+}
+
+JxlEncoderStatus JxlEncoderStruct::RefillOutputByteQueue() {
+  jxl::MemoryManagerUniquePtr<jxl::JxlEncoderQueuedFrame> input_frame =
+      std::move(input_frame_queue[0]);
+  input_frame_queue.erase(input_frame_queue.begin());
+
+  // TODO(zond): If the frame queue is empty and the input_closed is true,
+  // then mark this frame as the last.
+
+  jxl::BitWriter writer;
+
+  if (!wrote_bytes) {
+    if (use_container) {
+      output_byte_queue.insert(
+          output_byte_queue.end(), jxl::kContainerHeader,
+          jxl::kContainerHeader + sizeof(jxl::kContainerHeader));
+      if (store_jpeg_metadata && jpeg_metadata.size() > 0) {
+        jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_metadata.size(),
+                             false, &output_byte_queue);
+        output_byte_queue.insert(output_byte_queue.end(), jpeg_metadata.begin(),
+                                 jpeg_metadata.end());
+      }
+    }
+    if (!WriteHeaders(&metadata, &writer, nullptr)) {
+      return JXL_ENC_ERROR;
+    }
+    // Only send ICC (at least several hundred bytes) if fields aren't enough.
+    if (metadata.m.color_encoding.WantICC()) {
+      if (!jxl::WriteICC(metadata.m.color_encoding.ICC(), &writer,
+                         jxl::kLayerHeader, nullptr)) {
+        return JXL_ENC_ERROR;
+      }
+    }
+
+    // TODO(lode): preview should be added here if a preview image is added
+
+    // Each frame should start on byte boundaries.
+    writer.ZeroPadToByte();
+  }
+
+  // TODO(zond): Handle progressive mode like EncodeFile does it.
+  // TODO(zond): Handle animation like EncodeFile does it, by checking if
+  //             JxlEncoderCloseInput has been called and if the frame queue is
+  //             empty (to see if it's the last animation frame).
+
+  if (metadata.m.xyb_encoded) {
+    input_frame->option_values.cparams.color_transform =
+        jxl::ColorTransform::kXYB;
+  } else {
+    // TODO(zond): Figure out when to use kYCbCr instead.
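+    // (With xyb_encoded unset the caller requested uses_original_profile, so
+    // the samples are left in the signalled color space for now.)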
+    input_frame->option_values.cparams.color_transform =
+        jxl::ColorTransform::kNone;
+  }
+
+  jxl::PassesEncoderState enc_state;
+  if (!jxl::EncodeFrame(input_frame->option_values.cparams, jxl::FrameInfo{},
+                        &metadata, input_frame->frame, &enc_state,
+                        thread_pool.get(), &writer,
+                        /*aux_out=*/nullptr)) {
+    return JXL_ENC_ERROR;
+  }
+
+  jxl::PaddedBytes bytes = std::move(writer).TakeBytes();
+
+  if (use_container && !wrote_bytes) {
+    if (input_closed && input_frame_queue.empty()) {
+      jxl::AppendBoxHeader(jxl::MakeBoxType("jxlc"), bytes.size(),
+                           /*unbounded=*/false, &output_byte_queue);
+    } else {
+      jxl::AppendBoxHeader(jxl::MakeBoxType("jxlc"), 0, /*unbounded=*/true,
+                           &output_byte_queue);
+    }
+  }
+
+  output_byte_queue.insert(output_byte_queue.end(), bytes.data(),
+                           bytes.data() + bytes.size());
+  wrote_bytes = true;
+
+  last_used_cparams = input_frame->option_values.cparams;
+
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderSetColorEncoding(JxlEncoder* enc,
+                                            const JxlColorEncoding* color) {
+  if (enc->color_encoding_set) {
+    // Already set
+    return JXL_ENC_ERROR;
+  }
+  if (!jxl::ConvertExternalToInternalColorEncoding(
+          *color, &enc->metadata.m.color_encoding)) {
+    return JXL_ENC_ERROR;
+  }
+  enc->color_encoding_set = true;
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderSetICCProfile(JxlEncoder* enc,
+                                         const uint8_t* icc_profile,
+                                         size_t size) {
+  if (enc->color_encoding_set) {
+    // Already set
+    return JXL_ENC_ERROR;
+  }
+  jxl::PaddedBytes icc;
+  icc.assign(icc_profile, icc_profile + size);
+  if (!enc->metadata.m.color_encoding.SetICCRaw(std::move(icc))) {
+    return JXL_ENC_ERROR;
+  }
+  enc->color_encoding_set = true;
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderSetBasicInfo(JxlEncoder* enc,
+                                        const JxlBasicInfo* info) {
+  if (!enc->metadata.size.Set(info->xsize, info->ysize)) {
+    return JXL_ENC_ERROR;
+  }
+  if (info->exponent_bits_per_sample) {
+    if (info->exponent_bits_per_sample != 8) return JXL_ENC_NOT_SUPPORTED;
+    if (info->bits_per_sample == 32) {
+      enc->metadata.m.SetFloat32Samples();
+    } else {
+      return JXL_ENC_NOT_SUPPORTED;
+    }
+  } else {
+    switch (info->bits_per_sample) {
+      case 32:
+      case 16:
+      case 8:
+        enc->metadata.m.SetUintSamples(info->bits_per_sample);
+        break;
+      default:
+        return JXL_ENC_ERROR;
+        break;
+    }
+  }
+  if (info->alpha_bits > 0 && info->alpha_exponent_bits > 0) {
+    return JXL_ENC_NOT_SUPPORTED;
+  }
+  switch (info->alpha_bits) {
+    case 0:
+      break;
+    case 32:
+    case 16:
+      enc->metadata.m.SetAlphaBits(16);
+      break;
+    case 8:
+      enc->metadata.m.SetAlphaBits(info->alpha_bits);
+      break;
+    default:
+      return JXL_ENC_ERROR;
+      break;
+  }
+  enc->metadata.m.xyb_encoded = !info->uses_original_profile;
+  enc->basic_info_set = true;
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderOptions* JxlEncoderOptionsCreate(JxlEncoder* enc,
+                                           const JxlEncoderOptions* source) {
+  auto opts =
+      jxl::MemoryManagerMakeUnique<JxlEncoderOptions>(&enc->memory_manager);
+  if (!opts) return nullptr;
+  opts->enc = enc;
+  if (source != nullptr) {
+    opts->values = source->values;
+  } else {
+    opts->values.lossless = false;
+  }
+  JxlEncoderOptions* ret = opts.get();
+  enc->encoder_options.emplace_back(std::move(opts));
+  return ret;
+}
+
+JxlEncoderStatus JxlEncoderOptionsSetLossless(JxlEncoderOptions* options,
+                                              const JXL_BOOL lossless) {
+  options->values.lossless = lossless;
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderOptionsSetEffort(JxlEncoderOptions* options,
+                                            const int effort) {
+  if (effort < 3 || effort > 9) {
+    return JXL_ENC_ERROR;
+  }
+  options->values.cparams.speed_tier =
+      static_cast<jxl::SpeedTier>(10 - effort);
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderOptionsSetDistance(JxlEncoderOptions* options,
+                                              float distance) {
+  if (distance < 0 || distance > 15) {
+    return JXL_ENC_ERROR;
+  }
+  options->values.cparams.butteraugli_distance = distance;
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoder* JxlEncoderCreate(const JxlMemoryManager* memory_manager) {
+  JxlMemoryManager local_memory_manager;
+  if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) {
+    return nullptr;
+  }
+
+  void* alloc =
+      jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlEncoder));
+  if (!alloc) return nullptr;
+  JxlEncoder* enc = new (alloc) JxlEncoder();
+  enc->memory_manager = local_memory_manager;
+
+  return enc;
+}
+
+void JxlEncoderReset(JxlEncoder* enc) {
+  enc->thread_pool.reset();
+  enc->input_frame_queue.clear();
+  enc->encoder_options.clear();
+  enc->output_byte_queue.clear();
+  enc->wrote_bytes = false;
+  enc->metadata = jxl::CodecMetadata();
+  enc->last_used_cparams = jxl::CompressParams();
+  enc->input_closed = false;
+  enc->basic_info_set = false;
+  enc->color_encoding_set = false;
+}
+
+void JxlEncoderDestroy(JxlEncoder* enc) {
+  if (enc) {
+    // Call destructor directly since custom free function is used.
+    enc->~JxlEncoder();
+    jxl::MemoryManagerFree(&enc->memory_manager, enc);
+  }
+}
+
+JxlEncoderStatus JxlEncoderUseContainer(JxlEncoder* enc,
+                                        JXL_BOOL use_container) {
+  enc->use_container = static_cast<bool>(use_container);
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderStoreJPEGMetadata(JxlEncoder* enc,
+                                             JXL_BOOL store_jpeg_metadata) {
+  enc->store_jpeg_metadata = static_cast<bool>(store_jpeg_metadata);
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderSetParallelRunner(JxlEncoder* enc,
+                                             JxlParallelRunner parallel_runner,
+                                             void* parallel_runner_opaque) {
+  if (enc->thread_pool) return JXL_API_ERROR("parallel runner already set");
+  enc->thread_pool = jxl::MemoryManagerMakeUnique<jxl::ThreadPool>(
+      &enc->memory_manager, parallel_runner, parallel_runner_opaque);
+  if (!enc->thread_pool) {
+    return JXL_ENC_ERROR;
+  }
+  return JXL_ENC_SUCCESS;
+}
+
+JxlEncoderStatus JxlEncoderAddJPEGFrame(const JxlEncoderOptions* options,
+                                        const uint8_t* buffer, size_t size) {
+  if (!options->enc->basic_info_set || !options->enc->color_encoding_set) {
+    return JXL_ENC_ERROR;
+  }
+
+  if (options->enc->input_closed) {
+    return JXL_ENC_ERROR;
+  }
+
+  if (options->enc->metadata.m.xyb_encoded) {
+    // Can't XYB encode a lossless JPEG.
+    return JXL_ENC_ERROR;
+  }
+
+  jxl::CodecInOut io;
+  if (!jxl::jpeg::DecodeImageJPG(jxl::Span<const uint8_t>(buffer, size),
+                                 &io)) {
+    return JXL_ENC_ERROR;
+  }
+
+  if (options->enc->store_jpeg_metadata) {
+    jxl::jpeg::JPEGData data_in = *io.Main().jpeg_data;
+    jxl::PaddedBytes jpeg_data;
+    if (!EncodeJPEGData(data_in, &jpeg_data)) {
+      return JXL_ENC_ERROR;
+    }
+    options->enc->jpeg_metadata = std::vector<uint8_t>(
+        jpeg_data.data(), jpeg_data.data() + jpeg_data.size());
+  }
+
+  auto queued_frame = jxl::MemoryManagerMakeUnique<jxl::JxlEncoderQueuedFrame>(
+      &options->enc->memory_manager,
+      // JxlEncoderQueuedFrame is a struct with no constructors, so we use the
+      // default move constructor there.
+ jxl::JxlEncoderQueuedFrame{options->values, + jxl::ImageBundle(&options->enc->metadata.m)}); + if (!queued_frame) { + return JXL_ENC_ERROR; + } + queued_frame->frame.SetFromImage(std::move(*io.Main().color()), + io.Main().c_current()); + queued_frame->frame.jpeg_data = std::move(io.Main().jpeg_data); + queued_frame->frame.color_transform = io.Main().color_transform; + queued_frame->frame.chroma_subsampling = io.Main().chroma_subsampling; + + if (options->values.lossless) { + queued_frame->option_values.cparams.SetLossless(); + } + + options->enc->input_frame_queue.emplace_back(std::move(queued_frame)); + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderAddImageFrame(const JxlEncoderOptions* options, + const JxlPixelFormat* pixel_format, + const void* buffer, size_t size) { + if (!options->enc->basic_info_set || !options->enc->color_encoding_set) { + return JXL_ENC_ERROR; + } + + if (options->enc->input_closed) { + return JXL_ENC_ERROR; + } + + auto queued_frame = jxl::MemoryManagerMakeUnique( + &options->enc->memory_manager, + // JxlEncoderQueuedFrame is a struct with no constructors, so we use the + // default move constructor there. + jxl::JxlEncoderQueuedFrame{options->values, + jxl::ImageBundle(&options->enc->metadata.m)}); + if (!queued_frame) { + return JXL_ENC_ERROR; + } + + if (pixel_format->data_type == JXL_TYPE_FLOAT16) { + // float16 is currently only supported in the decoder + return JXL_ENC_ERROR; + } + + jxl::ColorEncoding c_current; + if (options->enc->metadata.m.xyb_encoded) { + if (pixel_format->data_type == JXL_TYPE_FLOAT) { + c_current = + jxl::ColorEncoding::LinearSRGB(pixel_format->num_channels < 3); + } else { + c_current = jxl::ColorEncoding::SRGB(pixel_format->num_channels < 3); + } + } else { + c_current = options->enc->metadata.m.color_encoding; + } + + if (!jxl::BufferToImageBundle(*pixel_format, options->enc->metadata.xsize(), + options->enc->metadata.ysize(), buffer, size, + options->enc->thread_pool.get(), c_current, + &(queued_frame->frame))) { + return JXL_ENC_ERROR; + } + + if (options->values.lossless) { + queued_frame->option_values.cparams.SetLossless(); + } + + options->enc->input_frame_queue.emplace_back(std::move(queued_frame)); + return JXL_ENC_SUCCESS; +} + +void JxlEncoderCloseInput(JxlEncoder* enc) { enc->input_closed = true; } + +JxlEncoderStatus JxlEncoderProcessOutput(JxlEncoder* enc, uint8_t** next_out, + size_t* avail_out) { + while (*avail_out > 0 && + (!enc->output_byte_queue.empty() || !enc->input_frame_queue.empty())) { + if (!enc->output_byte_queue.empty()) { + size_t to_copy = std::min(*avail_out, enc->output_byte_queue.size()); + memcpy(static_cast(*next_out), enc->output_byte_queue.data(), + to_copy); + *next_out += to_copy; + *avail_out -= to_copy; + enc->output_byte_queue.erase(enc->output_byte_queue.begin(), + enc->output_byte_queue.begin() + to_copy); + } else if (!enc->input_frame_queue.empty()) { + if (enc->RefillOutputByteQueue() != JXL_ENC_SUCCESS) { + return JXL_ENC_ERROR; + } + } + } + + if (!enc->output_byte_queue.empty() || !enc->input_frame_queue.empty()) { + return JXL_ENC_NEED_MORE_OUTPUT; + } + return JXL_ENC_SUCCESS; +} + +JxlEncoderStatus JxlEncoderOptionsSetDecodingSpeed(JxlEncoderOptions* options, + int tier) { + if (tier < 0 || tier > 4) { + return JXL_ENC_ERROR; + } + options->values.cparams.decoding_speed_tier = tier; + return JXL_ENC_SUCCESS; +} + +void JxlColorEncodingSetToSRGB(JxlColorEncoding* color_encoding, + JXL_BOOL is_gray) { + 
+  ConvertInternalToExternalColorEncoding(jxl::ColorEncoding::SRGB(is_gray),
+                                         color_encoding);
+}
+
+void JxlColorEncodingSetToLinearSRGB(JxlColorEncoding* color_encoding,
+                                     JXL_BOOL is_gray) {
+  ConvertInternalToExternalColorEncoding(
+      jxl::ColorEncoding::LinearSRGB(is_gray), color_encoding);
+}
diff --git a/third_party/jpeg-xl/lib/jxl/encode_internal.h b/third_party/jpeg-xl/lib/jxl/encode_internal.h
new file mode 100644
index 000000000000..ec9543a0cf0f
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/encode_internal.h
@@ -0,0 +1,132 @@
+/* Copyright (c) the JPEG XL Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIB_JXL_ENCODE_INTERNAL_H_
+#define LIB_JXL_ENCODE_INTERNAL_H_
+
+#include <vector>
+
+#include "jxl/encode.h"
+#include "jxl/memory_manager.h"
+#include "jxl/parallel_runner.h"
+#include "jxl/types.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/enc_frame.h"
+#include "lib/jxl/memory_manager_internal.h"
+
+namespace jxl {
+
+typedef struct JxlEncoderOptionsValuesStruct {
+  // lossless is a separate setting from cparams because it is a combination
+  // setting that overrides multiple settings inside of cparams.
+  bool lossless;
+  jxl::CompressParams cparams;
+} JxlEncoderOptionsValues;
+
+typedef struct JxlEncoderQueuedFrame {
+  JxlEncoderOptionsValues option_values;
+  jxl::ImageBundle frame;
+} JxlEncoderQueuedFrame;
+
+Status ConvertExternalToInternalColorEncoding(const JxlColorEncoding& external,
+                                              jxl::ColorEncoding* internal);
+
+typedef std::array<uint8_t, 4> BoxType;
+
+// Utility function that makes a BoxType from a null terminated string literal.
+constexpr BoxType MakeBoxType(const char (&type)[5]) {
+  return BoxType({static_cast<uint8_t>(type[0]), static_cast<uint8_t>(type[1]),
+                  static_cast<uint8_t>(type[2]),
+                  static_cast<uint8_t>(type[3])});
+}
+
+constexpr unsigned char kContainerHeader[] = {
+    0,   0,   0, 0xc, 'J',  'X', 'L', ' ', 0xd, 0xa, 0x87,
+    0xa, 0,   0, 0,   0x14, 'f', 't', 'y', 'p', 'j', 'x',
+    'l', ' ', 0, 0,   0,    0,   'j', 'x', 'l', ' '};
+
+namespace {
+template <typename T>
+uint8_t* Extend(T* vec, size_t size) {
+  vec->resize(vec->size() + size, 0);
+  return vec->data() + vec->size() - size;
+}
+}  // namespace
+
+// Appends a JXL container box header with given type, size, and unbounded
+// properties to output.
+template <typename T>
+void AppendBoxHeader(const jxl::BoxType& type, size_t size, bool unbounded,
+                     T* output) {
+  uint64_t box_size = 0;
+  bool large_size = false;
+  if (!unbounded) {
+    box_size = size + 8;
+    if (box_size >= 0x100000000ull) {
+      large_size = true;
+    }
+  }
+
+  StoreBE32(large_size ? 1 : box_size, Extend(output, 4));
+
+  for (size_t i = 0; i < 4; i++) {
+    output->push_back(*(type.data() + i));
+  }
+
+  if (large_size) {
+    StoreBE64(box_size, Extend(output, 8));
+  }
+}
+
+}  // namespace jxl
+
+struct JxlEncoderStruct {
+  JxlMemoryManager memory_manager;
+  jxl::MemoryManagerUniquePtr<jxl::ThreadPool> thread_pool{
+      nullptr, jxl::MemoryManagerDeleteHelper(&memory_manager)};
+  std::vector<jxl::MemoryManagerUniquePtr<JxlEncoderOptions>> encoder_options;
+
+  std::vector<jxl::MemoryManagerUniquePtr<jxl::JxlEncoderQueuedFrame>>
+      input_frame_queue;
+  std::vector<uint8_t> output_byte_queue;
+
+  bool use_container = false;
+  bool store_jpeg_metadata = false;
+  jxl::CodecMetadata metadata;
+  std::vector<uint8_t> jpeg_metadata;
+
+  bool wrote_bytes = false;
+  jxl::CompressParams last_used_cparams;
+
+  bool input_closed = false;
+  bool basic_info_set = false;
+  bool color_encoding_set = false;
+
+  // Takes the first frame in the input_frame_queue, encodes it, and appends the
+  // bytes to the output_byte_queue.
+  JxlEncoderStatus RefillOutputByteQueue();
+
+  // Appends the bytes of a JXL box header with the provided type and size to
+  // the end of the output_byte_queue. If unbounded is true, the size won't be
+  // added to the header and the box will be assumed to continue until EOF.
+  void AppendBoxHeader(const jxl::BoxType& type, size_t size, bool unbounded);
+};
+
+struct JxlEncoderOptionsStruct {
+  JxlEncoder* enc;
+  jxl::JxlEncoderOptionsValues values;
+};
+
+#endif  // LIB_JXL_ENCODE_INTERNAL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/encode_test.cc b/third_party/jpeg-xl/lib/jxl/encode_test.cc
new file mode 100644
index 000000000000..fc431521f4b5
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/encode_test.cc
@@ -0,0 +1,620 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
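+
+// Unit tests for the public encoder API declared in jxl/encode.h.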
+
+#include "jxl/encode.h"
+
+#include "gtest/gtest.h"
+#include "jxl/encode_cxx.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/dec_file.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/encode_internal.h"
+#include "lib/jxl/jpeg/dec_jpeg_data.h"
+#include "lib/jxl/jpeg/dec_jpeg_data_writer.h"
+#include "lib/jxl/test_utils.h"
+#include "lib/jxl/testdata.h"
+
+TEST(EncodeTest, AddFrameAfterCloseInputTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+
+  JxlEncoderCloseInput(enc.get());
+
+  size_t xsize = 64;
+  size_t ysize = 64;
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+  std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+
+  jxl::CodecInOut input_io =
+      jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize);
+
+  JxlBasicInfo basic_info;
+  jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format);
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  basic_info.uses_original_profile = false;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info));
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding,
+                            /*is_gray=*/pixel_format.num_channels < 3);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetColorEncoding(enc.get(), &color_encoding));
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+  EXPECT_EQ(JXL_ENC_ERROR,
+            JxlEncoderAddImageFrame(options, &pixel_format, pixels.data(),
+                                    pixels.size()));
+}
+
+TEST(EncodeTest, AddJPEGAfterCloseTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+
+  JxlEncoderCloseInput(enc.get());
+
+  const std::string jpeg_path =
+      "imagecompression.info/flower_foveon.png.im_q85_420.jpg";
+  const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path);
+  jxl::CodecInOut orig_io;
+  ASSERT_TRUE(
+      SetFromBytes(jxl::Span<const uint8_t>(orig), &orig_io, /*pool=*/nullptr));
+
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+
+  JxlBasicInfo basic_info;
+  basic_info.exponent_bits_per_sample = 0;
+  basic_info.bits_per_sample = 8;
+  basic_info.alpha_bits = 0;
+  basic_info.alpha_exponent_bits = 0;
+  basic_info.xsize = orig_io.xsize();
+  basic_info.ysize = orig_io.ysize();
+  basic_info.uses_original_profile = true;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info));
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetColorEncoding(enc.get(), &color_encoding));
+  EXPECT_EQ(JXL_ENC_ERROR,
+            JxlEncoderAddJPEGFrame(options, orig.data(), orig.size()));
+}
+
+TEST(EncodeTest, AddFrameBeforeColorEncodingTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+
+  size_t xsize = 64;
+  size_t ysize = 64;
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+  std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+
+  jxl::CodecInOut input_io =
+      jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize);
+
+  JxlBasicInfo basic_info;
+  jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format);
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  basic_info.uses_original_profile = false;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info));
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+  EXPECT_EQ(JXL_ENC_ERROR,
+            JxlEncoderAddImageFrame(options, &pixel_format, pixels.data(),
+                                    pixels.size()));
+}
+
+TEST(EncodeTest, AddFrameBeforeBasicInfoTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+
+  size_t xsize = 64;
+  size_t ysize = 64;
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+  std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+
+  jxl::CodecInOut input_io =
+      jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize);
+
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding,
+                            /*is_gray=*/pixel_format.num_channels < 3);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetColorEncoding(enc.get(), &color_encoding));
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+  EXPECT_EQ(JXL_ENC_ERROR,
+            JxlEncoderAddImageFrame(options, &pixel_format, pixels.data(),
+                                    pixels.size()));
+}
+
+TEST(EncodeTest, DefaultAllocTest) {
+  JxlEncoder* enc = JxlEncoderCreate(nullptr);
+  EXPECT_NE(nullptr, enc);
+  JxlEncoderDestroy(enc);
+}
+
+TEST(EncodeTest, CustomAllocTest) {
+  struct CalledCounters {
+    int allocs = 0;
+    int frees = 0;
+  } counters;
+
+  JxlMemoryManager mm;
+  mm.opaque = &counters;
+  mm.alloc = [](void* opaque, size_t size) {
+    reinterpret_cast<CalledCounters*>(opaque)->allocs++;
+    return malloc(size);
+  };
+  mm.free = [](void* opaque, void* address) {
+    reinterpret_cast<CalledCounters*>(opaque)->frees++;
+    free(address);
+  };
+
+  {
+    JxlEncoderPtr enc = JxlEncoderMake(&mm);
+    EXPECT_NE(nullptr, enc.get());
+    EXPECT_LE(1, counters.allocs);
+    EXPECT_EQ(0, counters.frees);
+  }
+  EXPECT_LE(1, counters.frees);
+}
+
+TEST(EncodeTest, DefaultParallelRunnerTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetParallelRunner(enc.get(), nullptr, nullptr));
+}
+
+void VerifyFrameEncoding(size_t xsize, size_t ysize, JxlEncoder* enc,
+                         const JxlEncoderOptions* options) {
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+  std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+
+  jxl::CodecInOut input_io =
+      jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize);
+
+  JxlBasicInfo basic_info;
+  jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format);
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  if (options->values.lossless) {
+    basic_info.uses_original_profile = true;
+  } else {
+    basic_info.uses_original_profile = false;
+  }
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info));
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding,
+                            /*is_gray=*/pixel_format.num_channels < 3);
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding));
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderAddImageFrame(options, &pixel_format, pixels.data(),
+                                    pixels.size()));
+  JxlEncoderCloseInput(enc);
+
+  std::vector<uint8_t> compressed = std::vector<uint8_t>(64);
+  uint8_t* next_out = compressed.data();
+  size_t avail_out = compressed.size() - (next_out - compressed.data());
+  JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT;
+  while (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+    process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out);
+    if (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+      size_t offset = next_out - compressed.data();
+      compressed.resize(compressed.size() * 2);
+      next_out = compressed.data() + offset;
+      avail_out = compressed.size() - offset;
+    }
+  }
+  compressed.resize(next_out - compressed.data());
+  EXPECT_EQ(JXL_ENC_SUCCESS, process_result);
+
+  jxl::DecompressParams dparams;
+  jxl::CodecInOut decoded_io;
+  EXPECT_TRUE(jxl::DecodeFile(
+      dparams, jxl::Span<const uint8_t>(compressed.data(), compressed.size()),
+      &decoded_io, /*pool=*/nullptr));
+
+  jxl::ButteraugliParams ba;
+  EXPECT_LE(ButteraugliDistance(input_io, decoded_io, ba,
+                                /*distmap=*/nullptr, nullptr),
+            3.0f);
+}
+
+void VerifyFrameEncoding(JxlEncoder* enc, const JxlEncoderOptions* options) {
+  VerifyFrameEncoding(63, 129, enc, options);
+}
+
+TEST(EncodeTest, FrameEncodingTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+  VerifyFrameEncoding(enc.get(), JxlEncoderOptionsCreate(enc.get(), nullptr));
+}
+
+TEST(EncodeTest, EncoderResetTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+  VerifyFrameEncoding(50, 200, enc.get(),
+                      JxlEncoderOptionsCreate(enc.get(), nullptr));
+  // Encoder should become reusable for a new image from scratch after using
+  // reset.
+  JxlEncoderReset(enc.get());
+  VerifyFrameEncoding(157, 77, enc.get(),
+                      JxlEncoderOptionsCreate(enc.get(), nullptr));
+}
+
+TEST(EncodeTest, OptionsTest) {
+  {
+    JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+    EXPECT_NE(nullptr, enc.get());
+    JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+    EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderOptionsSetEffort(options, 5));
+    VerifyFrameEncoding(enc.get(), options);
+    EXPECT_EQ(jxl::SpeedTier::kHare, enc->last_used_cparams.speed_tier);
+  }
+
+  {
+    JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+    EXPECT_NE(nullptr, enc.get());
+    JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+    // Lower than currently supported values
+    EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderOptionsSetEffort(options, 2));
+    // Higher than currently supported values
+    EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderOptionsSetEffort(options, 10));
+  }
+
+  {
+    JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+    EXPECT_NE(nullptr, enc.get());
+    JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+    EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderOptionsSetLossless(options, JXL_TRUE));
+    VerifyFrameEncoding(enc.get(), options);
+    EXPECT_EQ(true, enc->last_used_cparams.IsLossless());
+  }
+
+  {
+    JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+    EXPECT_NE(nullptr, enc.get());
+    JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+    EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderOptionsSetDistance(options, 0.5));
+    VerifyFrameEncoding(enc.get(), options);
+    EXPECT_EQ(0.5, enc->last_used_cparams.butteraugli_distance);
+  }
+
+  {
+    JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+    JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+    // Disallowed negative distance
+    EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderOptionsSetDistance(options, -1));
+  }
+
+  {
+    JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+    EXPECT_NE(nullptr, enc.get());
+    JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+    EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderOptionsSetDecodingSpeed(options, 2));
+    VerifyFrameEncoding(enc.get(), options);
+    EXPECT_EQ(2, enc->last_used_cparams.decoding_speed_tier);
+  }
+}
+
+namespace {
+// Returns a copy of buf from offset to offset+size, or a new zeroed vector if
+// the result would have been out of bounds taking integer overflow into
+// account.
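+// (Used by the Box parser below, so a truncated input degrades to zeroed
+// bytes instead of reading out of bounds.)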
+const std::vector<uint8_t> SliceSpan(const jxl::Span<const uint8_t>& buf,
+                                     size_t offset, size_t size) {
+  if (offset + size >= buf.size()) {
+    return std::vector<uint8_t>(size, 0);
+  }
+  if (offset + size < offset) {
+    return std::vector<uint8_t>(size, 0);
+  }
+  return std::vector<uint8_t>(buf.data() + offset, buf.data() + offset + size);
+}
+
+struct Box {
+  // The type of the box.
+  // If "uuid", use extended_type instead
+  char type[4] = {0, 0, 0, 0};
+
+  // The extended_type is only used when type == "uuid".
+  // Extended types are not used in JXL. However, the box format itself
+  // supports this so they are handled correctly.
+  char extended_type[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+  // Box data.
+  jxl::Span<const uint8_t> data = jxl::Span<const uint8_t>(nullptr, 0);
+
+  // If the size is not given, the data size extends to the end of the file.
+  // If this field is false, the size field is not encoded when the box is
+  // serialized.
+  bool data_size_given = true;
+
+  // If successful, returns true and sets `in` to be the rest of the data (if
+  // any).
+  // If `in` contains a box with a size larger than `in.size()`, will not
+  // modify `in`, and will return true but the data Span will
+  // remain set to nullptr.
+  // If unsuccessful, returns error and doesn't modify `in`.
+  jxl::Status Decode(jxl::Span<const uint8_t>* in) {
+    // Total box_size including this header itself.
+    uint64_t box_size = LoadBE32(SliceSpan(*in, 0, 4).data());
+    size_t pos = 4;
+
+    memcpy(type, SliceSpan(*in, pos, 4).data(), 4);
+    pos += 4;
+
+    if (box_size == 1) {
+      // If the size is 1, it indicates extended size read from 64-bit integer.
+      box_size = LoadBE64(SliceSpan(*in, pos, 8).data());
+      pos += 8;
+    }
+
+    if (!memcmp("uuid", type, 4)) {
+      memcpy(extended_type, SliceSpan(*in, pos, 16).data(), 16);
+      pos += 16;
+    }
+
+    // This is the end of the box header, the box data begins here. Handle
+    // the data size now.
+    const size_t header_size = pos;
+
+    if (box_size != 0) {
+      if (box_size < header_size) {
+        return JXL_FAILURE("Invalid box size");
+      }
+      if (box_size > in->size()) {
+        // The box is fine, but the input is too short.
+        return true;
+      }
+      data_size_given = true;
+      data = jxl::Span<const uint8_t>(in->data() + header_size,
+                                      box_size - header_size);
+    } else {
+      data_size_given = false;
+      data = jxl::Span<const uint8_t>(in->data() + header_size,
+                                      in->size() - header_size);
+    }
+
+    *in = jxl::Span<const uint8_t>(in->data() + header_size + data.size(),
+                                   in->size() - header_size - data.size());
+    return true;
+  }
+};
+
+struct Container {
+  std::vector<Box> boxes;
+
+  // If successful, returns true and sets `in` to be the rest of the data (if
+  // any).
+  // If unsuccessful, returns error and doesn't modify `in`.
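+  // The box stream matches what AppendBoxHeader emits: each box starts with
+  // a 32-bit big-endian size that includes the 8-byte header, followed by
+  // the 4-byte type. For example, a "jxlc" box with 10 payload bytes begins
+  // with 00 00 00 12 'j' 'x' 'l' 'c' (0x12 = 8 + 10).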
+  jxl::Status Decode(jxl::Span<const uint8_t>* in) {
+    boxes.clear();
+
+    Box signature_box;
+    JXL_RETURN_IF_ERROR(signature_box.Decode(in));
+    if (memcmp("JXL ", signature_box.type, 4) != 0) {
+      return JXL_FAILURE("Invalid magic signature");
+    }
+    if (signature_box.data.size() != 4)
+      return JXL_FAILURE("Invalid magic signature");
+    if (signature_box.data[0] != 0xd || signature_box.data[1] != 0xa ||
+        signature_box.data[2] != 0x87 || signature_box.data[3] != 0xa) {
+      return JXL_FAILURE("Invalid magic signature");
+    }
+
+    Box ftyp_box;
+    JXL_RETURN_IF_ERROR(ftyp_box.Decode(in));
+    if (memcmp("ftyp", ftyp_box.type, 4) != 0) {
+      return JXL_FAILURE("Invalid ftyp");
+    }
+    if (ftyp_box.data.size() != 12) return JXL_FAILURE("Invalid ftyp");
+    const char* expected = "jxl \0\0\0\0jxl ";
+    if (memcmp(expected, ftyp_box.data.data(), 12) != 0)
+      return JXL_FAILURE("Invalid ftyp");
+
+    while (in->size() > 0) {
+      Box box = {};
+      JXL_RETURN_IF_ERROR(box.Decode(in));
+      if (box.data.data() == nullptr) {
+        // The decoding encountered a box, but not enough data yet.
+        return true;
+      }
+      boxes.emplace_back(box);
+    }
+
+    return true;
+  }
+};
+
+}  // namespace
+
+TEST(EncodeTest, SingleFrameBoundedJXLCTest) {
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  EXPECT_NE(nullptr, enc.get());
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc.get(), true));
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+
+  size_t xsize = 71;
+  size_t ysize = 23;
+  JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
+  std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0);
+
+  JxlBasicInfo basic_info;
+  jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format);
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  basic_info.uses_original_profile = false;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info));
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding,
+                            /*is_gray=*/false);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetColorEncoding(enc.get(), &color_encoding));
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderAddImageFrame(options, &pixel_format, pixels.data(),
+                                    pixels.size()));
+  JxlEncoderCloseInput(enc.get());
+
+  std::vector<uint8_t> compressed = std::vector<uint8_t>(64);
+  uint8_t* next_out = compressed.data();
+  size_t avail_out = compressed.size() - (next_out - compressed.data());
+  JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT;
+  while (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+    process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out);
+    if (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+      size_t offset = next_out - compressed.data();
+      compressed.resize(compressed.size() * 2);
+      next_out = compressed.data() + offset;
+      avail_out = compressed.size() - offset;
+    }
+  }
+  compressed.resize(next_out - compressed.data());
+  EXPECT_EQ(JXL_ENC_SUCCESS, process_result);
+
+  Container container = {};
+  jxl::Span<const uint8_t> encoded_span =
+      jxl::Span<const uint8_t>(compressed.data(), compressed.size());
+  EXPECT_TRUE(container.Decode(&encoded_span));
+  EXPECT_EQ(0, encoded_span.size());
+  EXPECT_EQ(0, memcmp("jxlc", container.boxes[0].type, 4));
+  EXPECT_EQ(true, container.boxes[0].data_size_given);
+}
+
+TEST(EncodeTest, JPEGReconstructionTest) {
+  const std::string jpeg_path =
+      "imagecompression.info/flower_foveon.png.im_q85_420.jpg";
+  const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path);
+  jxl::CodecInOut orig_io;
+  ASSERT_TRUE(
+      SetFromBytes(jxl::Span<const uint8_t>(orig), &orig_io, /*pool=*/nullptr));
+
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+
+  JxlBasicInfo basic_info;
+  basic_info.exponent_bits_per_sample = 0;
+  basic_info.bits_per_sample = 8;
+  basic_info.alpha_bits = 0;
+  basic_info.alpha_exponent_bits = 0;
+  basic_info.xsize = orig_io.xsize();
+  basic_info.ysize = orig_io.ysize();
+  basic_info.uses_original_profile = true;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info));
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetColorEncoding(enc.get(), &color_encoding));
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc.get(), JXL_TRUE));
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderStoreJPEGMetadata(enc.get(), JXL_TRUE));
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderAddJPEGFrame(options, orig.data(), orig.size()));
+  JxlEncoderCloseInput(enc.get());
+
+  std::vector<uint8_t> compressed = std::vector<uint8_t>(64);
+  uint8_t* next_out = compressed.data();
+  size_t avail_out = compressed.size() - (next_out - compressed.data());
+  JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT;
+  while (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+    process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out);
+    if (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+      size_t offset = next_out - compressed.data();
+      compressed.resize(compressed.size() * 2);
+      next_out = compressed.data() + offset;
+      avail_out = compressed.size() - offset;
+    }
+  }
+  compressed.resize(next_out - compressed.data());
+  EXPECT_EQ(JXL_ENC_SUCCESS, process_result);
+
+  Container container = {};
+  jxl::Span<const uint8_t> encoded_span =
+      jxl::Span<const uint8_t>(compressed.data(), compressed.size());
+  EXPECT_TRUE(container.Decode(&encoded_span));
+  EXPECT_EQ(0, encoded_span.size());
+  EXPECT_EQ(0, memcmp("jbrd", container.boxes[0].type, 4));
+  EXPECT_EQ(0, memcmp("jxlc", container.boxes[1].type, 4));
+
+  jxl::CodecInOut decoded_io;
+  decoded_io.Main().jpeg_data = jxl::make_unique<jxl::jpeg::JPEGData>();
+  EXPECT_TRUE(jxl::jpeg::DecodeJPEGData(container.boxes[0].data,
+                                        decoded_io.Main().jpeg_data.get()));
+
+  jxl::DecompressParams dparams;
+  dparams.keep_dct = true;
+  EXPECT_TRUE(
+      jxl::DecodeFile(dparams, container.boxes[1].data, &decoded_io, nullptr));
+
+  std::vector<uint8_t> decoded_jpeg_bytes;
+  auto write = [&decoded_jpeg_bytes](const uint8_t* buf, size_t len) {
+    decoded_jpeg_bytes.insert(decoded_jpeg_bytes.end(), buf, buf + len);
+    return len;
+  };
+  EXPECT_TRUE(jxl::jpeg::WriteJpeg(*decoded_io.Main().jpeg_data, write));
+
+  EXPECT_EQ(decoded_jpeg_bytes.size(), orig.size());
+  EXPECT_EQ(0, memcmp(decoded_jpeg_bytes.data(), orig.data(), orig.size()));
+}
+
+TEST(EncodeTest, JPEGFrameTest) {
+  const std::string jpeg_path =
+      "imagecompression.info/flower_foveon.png.im_q85_420.jpg";
+  const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path);
+  jxl::CodecInOut orig_io;
+  ASSERT_TRUE(
+      SetFromBytes(jxl::Span<const uint8_t>(orig), &orig_io, /*pool=*/nullptr));
+
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+
+  JxlBasicInfo basic_info;
+  basic_info.exponent_bits_per_sample = 0;
+  basic_info.bits_per_sample = 8;
+  basic_info.alpha_bits = 0;
+  basic_info.alpha_exponent_bits = 0;
+  basic_info.xsize = orig_io.xsize();
+  basic_info.ysize = orig_io.ysize();
+  basic_info.uses_original_profile = true;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info));
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetColorEncoding(enc.get(), &color_encoding));
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderAddJPEGFrame(options, orig.data(), orig.size()));
+  JxlEncoderCloseInput(enc.get());
+
+  std::vector<uint8_t> compressed = std::vector<uint8_t>(64);
+  uint8_t* next_out = compressed.data();
+  size_t avail_out = compressed.size() - (next_out - compressed.data());
+  JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT;
+  while (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+    process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out);
+    if (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+      size_t offset = next_out - compressed.data();
+      compressed.resize(compressed.size() * 2);
+      next_out = compressed.data() + offset;
+      avail_out = compressed.size() - offset;
+    }
+  }
+  compressed.resize(next_out - compressed.data());
+  EXPECT_EQ(JXL_ENC_SUCCESS, process_result);
+
+  jxl::DecompressParams dparams;
+  jxl::CodecInOut decoded_io;
+  EXPECT_TRUE(jxl::DecodeFile(
+      dparams, jxl::Span<const uint8_t>(compressed.data(), compressed.size()),
+      &decoded_io, /*pool=*/nullptr));
+
+  jxl::ButteraugliParams ba;
+  EXPECT_LE(ButteraugliDistance(orig_io, decoded_io, ba,
+                                /*distmap=*/nullptr, nullptr),
+            2.5f);
+}
diff --git a/third_party/jpeg-xl/lib/jxl/entropy_coder.cc b/third_party/jpeg-xl/lib/jxl/entropy_coder.cc
new file mode 100644
index 000000000000..89eb526b5c2e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/entropy_coder.cc
@@ -0,0 +1,79 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
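+
+// Implements DecodeBlockCtxMap, declared in entropy_coder.h.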
+ +#include "lib/jxl/entropy_coder.h" + +#include +#include + +#include +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map) { + auto& dct = block_ctx_map->dc_thresholds; + auto& qft = block_ctx_map->qf_thresholds; + auto& ctx_map = block_ctx_map->ctx_map; + bool is_default = br->ReadFixedBits<1>(); + if (is_default) { + *block_ctx_map = BlockCtxMap(); + return true; + } + block_ctx_map->num_dc_ctxs = 1; + for (int j : {0, 1, 2}) { + dct[j].resize(br->ReadFixedBits<4>()); + block_ctx_map->num_dc_ctxs *= dct[j].size() + 1; + for (int& i : dct[j]) { + i = UnpackSigned(U32Coder::Read(kDCThresholdDist, br)); + } + } + qft.resize(br->ReadFixedBits<4>()); + for (uint32_t& i : qft) { + i = U32Coder::Read(kQFThresholdDist, br) + 1; + } + + if (block_ctx_map->num_dc_ctxs * (qft.size() + 1) > 64) { + return JXL_FAILURE("Invalid block context map: too big"); + } + + ctx_map.resize(3 * kNumOrders * block_ctx_map->num_dc_ctxs * + (qft.size() + 1)); + JXL_RETURN_IF_ERROR(DecodeContextMap(&ctx_map, &block_ctx_map->num_ctxs, br)); + if (block_ctx_map->num_ctxs > 16) { + return JXL_FAILURE("Invalid block context map: too many distinct contexts"); + } + return true; +} + +constexpr uint8_t BlockCtxMap::kDefaultCtxMap[]; // from ac_context.h + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/entropy_coder.h b/third_party/jpeg-xl/lib/jxl/entropy_coder.h new file mode 100644 index 000000000000..d1893464f3ae --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/entropy_coder.h @@ -0,0 +1,54 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ENTROPY_CODER_H_ +#define LIB_JXL_ENTROPY_CODER_H_ + +#include +#include + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +// Entropy coding and context modeling of DC and AC coefficients, as well as AC +// strategy and quantization field. + +namespace jxl { + +static JXL_INLINE int32_t PredictFromTopAndLeft( + const int32_t* const JXL_RESTRICT row_top, + const int32_t* const JXL_RESTRICT row, size_t x, int32_t default_val) { + if (x == 0) { + return row_top == nullptr ? 
+    return row_top == nullptr ? default_val : row_top[x];
+  }
+  if (row_top == nullptr) {
+    return row[x - 1];
+  }
+  return (row_top[x] + row[x - 1] + 1) / 2;
+}
+
+static constexpr U32Enc kDCThresholdDist(Bits(4), BitsOffset(8, 16),
+                                         BitsOffset(16, 272),
+                                         BitsOffset(32, 65808));
+
+static constexpr U32Enc kQFThresholdDist(Bits(2), BitsOffset(3, 4),
+                                         BitsOffset(5, 12), BitsOffset(8, 44));
+
+Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_ENTROPY_CODER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/entropy_coder_test.cc b/third_party/jpeg-xl/lib/jxl/entropy_coder_test.cc
new file mode 100644
index 000000000000..f0b688f77163
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/entropy_coder_test.cc
@@ -0,0 +1,79 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO(deymo): Move these tests to dec_ans.h and common.h
+
+#include <stdint.h>
+
+#include <random>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_ans.h"
+
+namespace jxl {
+namespace {
+
+TEST(EntropyCoderTest, PackUnpack) {
+  for (int32_t i = -31; i < 32; ++i) {
+    uint32_t packed = PackSigned(i);
+    EXPECT_LT(packed, 63);
+    int32_t unpacked = UnpackSigned(packed);
+    EXPECT_EQ(i, unpacked);
+  }
+}
+
+struct DummyBitReader {
+  uint32_t nbits, bits;
+  void Consume(uint32_t nbits) {}
+  uint32_t PeekBits(uint32_t n) {
+    EXPECT_EQ(n, nbits);
+    return bits;
+  }
+};
+
+void HybridUintRoundtrip(HybridUintConfig config, size_t limit = 1 << 24) {
+  std::mt19937 rng(0);
+  std::uniform_int_distribution<uint32_t> dist(0, limit);
+  constexpr size_t kNumIntegers = 1 << 20;
+  std::vector<uint32_t> integers(kNumIntegers);
+  std::vector<uint32_t> token(kNumIntegers);
+  std::vector<uint32_t> nbits(kNumIntegers);
+  std::vector<uint32_t> bits(kNumIntegers);
+  for (size_t i = 0; i < kNumIntegers; i++) {
+    integers[i] = dist(rng);
+    config.Encode(integers[i], &token[i], &nbits[i], &bits[i]);
+  }
+  for (size_t i = 0; i < kNumIntegers; i++) {
+    DummyBitReader br{nbits[i], bits[i]};
+    EXPECT_EQ(integers[i],
+              ANSSymbolReader::ReadHybridUintConfig(config, token[i], &br));
+  }
+}
+
+TEST(HybridUintTest, Test000) {
+  HybridUintRoundtrip(HybridUintConfig{0, 0, 0});
+}
+TEST(HybridUintTest, Test411) {
+  HybridUintRoundtrip(HybridUintConfig{4, 1, 1});
+}
+TEST(HybridUintTest, Test420) {
+  HybridUintRoundtrip(HybridUintConfig{4, 2, 0});
+}
+TEST(HybridUintTest, Test421) {
+  HybridUintRoundtrip(HybridUintConfig{4, 2, 1}, 256);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/epf.cc b/third_party/jpeg-xl/lib/jxl/epf.cc
new file mode 100644
index 000000000000..dae478b37afe
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/epf.cc
@@ -0,0 +1,685 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Edge-preserving smoothing: weighted average based on L1 patch similarity.
+
+#include "lib/jxl/epf.h"
+
+#include <math.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <algorithm>
+#include <atomic>
+#include <numeric>  // std::accumulate
+#include <vector>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/epf.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/filters.h"
+#include "lib/jxl/filters_internal.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/loop_filter.h"
+#include "lib/jxl/quant_weights.h"
+#include "lib/jxl/quantizer.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Vec;
+
+// The EPF logic treats 8x8 blocks as one unit, each with their own sigma.
+// It should be possible to do two blocks at a time in AVX3 vectors, at some
+// increase in complexity (broadcasting sigma0/1 to lanes 0..7 and 8..15).
+using DF = HWY_CAPPED(float, GroupBorderAssigner::kPaddingXRound);
+using DU = HWY_CAPPED(uint32_t, GroupBorderAssigner::kPaddingXRound);
+
+// kInvSigmaNum / 0.3
+constexpr float kMinSigma = -3.90524291751269967465540850526868f;
+
+DF df;
+
+JXL_INLINE Vec<DF> Weight(Vec<DF> sad, Vec<DF> inv_sigma, Vec<DF> thres) {
+  auto v = MulAdd(sad, inv_sigma, Set(DF(), 1.0f));
+  auto v2 = v * v;
+  return IfThenZeroElse(v <= thres, v2);
+}
+
+template <bool aligned>
+JXL_INLINE void AddPixelStep1(int row, const FilterRows& rows, size_t x,
+                              Vec<DF> sad, Vec<DF> inv_sigma,
+                              const LoopFilter& lf, Vec<DF>* JXL_RESTRICT X,
+                              Vec<DF>* JXL_RESTRICT Y, Vec<DF>* JXL_RESTRICT B,
+                              Vec<DF>* JXL_RESTRICT w) {
+  auto cx = aligned ? Load(DF(), rows.GetInputRow(row, 0) + x)
+                    : LoadU(DF(), rows.GetInputRow(row, 0) + x);
+  auto cy = aligned ? Load(DF(), rows.GetInputRow(row, 1) + x)
+                    : LoadU(DF(), rows.GetInputRow(row, 1) + x);
+  auto cb = aligned ? Load(DF(), rows.GetInputRow(row, 2) + x)
+                    : LoadU(DF(), rows.GetInputRow(row, 2) + x);
+
+  auto weight = Weight(sad, inv_sigma, Set(df, lf.epf_pass1_zeroflush));
+  *w += weight;
+  *X = MulAdd(weight, cx, *X);
+  *Y = MulAdd(weight, cy, *Y);
+  *B = MulAdd(weight, cb, *B);
+}
+
+template <bool aligned>
+JXL_INLINE void AddPixelStep2(int row, const FilterRows& rows, size_t x,
+                              Vec<DF> rx, Vec<DF> ry, Vec<DF> rb,
+                              Vec<DF> inv_sigma, const LoopFilter& lf,
+                              Vec<DF>* JXL_RESTRICT X, Vec<DF>* JXL_RESTRICT Y,
+                              Vec<DF>* JXL_RESTRICT B,
+                              Vec<DF>* JXL_RESTRICT w) {
+  auto cx = aligned ? Load(DF(), rows.GetInputRow(row, 0) + x)
+                    : LoadU(DF(), rows.GetInputRow(row, 0) + x);
+  auto cy = aligned ? Load(DF(), rows.GetInputRow(row, 1) + x)
+                    : LoadU(DF(), rows.GetInputRow(row, 1) + x);
+  auto cb = aligned ? Load(DF(), rows.GetInputRow(row, 2) + x)
+                    : LoadU(DF(), rows.GetInputRow(row, 2) + x);
+
+  auto sad = AbsDiff(cx, rx) * Set(df, lf.epf_channel_scale[0]);
+  sad = MulAdd(AbsDiff(cy, ry), Set(df, lf.epf_channel_scale[1]), sad);
+  sad = MulAdd(AbsDiff(cb, rb), Set(df, lf.epf_channel_scale[2]), sad);
+
+  auto weight = Weight(sad, inv_sigma, Set(df, lf.epf_pass2_zeroflush));
+
+  *w += weight;
+  *X = MulAdd(weight, cx, *X);
+  *Y = MulAdd(weight, cy, *Y);
+  *B = MulAdd(weight, cb, *B);
+}
+
+template <class D, class V>
+void GaborishVector(const D df, const float* JXL_RESTRICT row_t,
+                    const float* JXL_RESTRICT row_m,
+                    const float* JXL_RESTRICT row_b, const V w0, const V w1,
+                    const V w2, float* JXL_RESTRICT row_out) {
+// Filter x0 is only aligned to blocks (8 floats = 32 bytes). For larger
+// vectors, treat loads as unaligned (we manually align the Store).
+#undef LoadMaybeU
+#if HWY_CAP_GE512
+#define LoadMaybeU LoadU
+#else
+#define LoadMaybeU Load
+#endif
+
+  const auto t = LoadMaybeU(df, row_t);
+  const auto tl = LoadU(df, row_t - 1);
+  const auto tr = LoadU(df, row_t + 1);
+  const auto m = LoadMaybeU(df, row_m);
+  const auto l = LoadU(df, row_m - 1);
+  const auto r = LoadU(df, row_m + 1);
+  const auto b = LoadMaybeU(df, row_b);
+  const auto bl = LoadU(df, row_b - 1);
+  const auto br = LoadU(df, row_b + 1);
+  const auto sum0 = m;
+  const auto sum1 = (l + r) + (t + b);
+  const auto sum2 = (tl + tr) + (bl + br);
+  auto pixels = MulAdd(sum2, w2, MulAdd(sum1, w1, sum0 * w0));
+  Store(pixels, df, row_out);
+}
+
+void GaborishRow(const FilterRows& rows, const LoopFilter& /* lf */,
+                 const FilterWeights& filter_weights, size_t x0, size_t x1,
+                 size_t /*image_x_mod_8*/, size_t /* image_y_mod_8 */) {
+  JXL_DASSERT(x0 % Lanes(df) == 0);
+
+  const float* JXL_RESTRICT gab_weights = filter_weights.gab_weights;
+  for (size_t c = 0; c < 3; c++) {
+    const float* JXL_RESTRICT row_t = rows.GetInputRow(-1, c);
+    const float* JXL_RESTRICT row_m = rows.GetInputRow(0, c);
+    const float* JXL_RESTRICT row_b = rows.GetInputRow(1, c);
+    float* JXL_RESTRICT row_out = rows.GetOutputRow(c);
+
+    size_t ix = x0;
+
+#if HWY_CAP_GE512
+    const HWY_FULL(float) dfull;  // Gaborish is not block-dependent.
+
+    // For AVX3, x0 might only be aligned to 8, not 16; if so, do a capped
+    // vector first to ensure full (Store-only!) alignment, then full vectors.
+    const uintptr_t addr = reinterpret_cast<uintptr_t>(row_out + ix);
+    if ((addr % 64) != 0 && ix < x1) {
+      const auto w0 = Set(df, gab_weights[3 * c + 0]);
+      const auto w1 = Set(df, gab_weights[3 * c + 1]);
+      const auto w2 = Set(df, gab_weights[3 * c + 2]);
+      GaborishVector(df, row_t + ix, row_m + ix, row_b + ix, w0, w1, w2,
+                     row_out + ix);
+      ix += Lanes(df);
+    }
+
+    const auto wfull0 = Set(dfull, gab_weights[3 * c + 0]);
+    const auto wfull1 = Set(dfull, gab_weights[3 * c + 1]);
+    const auto wfull2 = Set(dfull, gab_weights[3 * c + 2]);
+    for (; ix + Lanes(dfull) <= x1; ix += Lanes(dfull)) {
+      GaborishVector(dfull, row_t + ix, row_m + ix, row_b + ix, wfull0, wfull1,
+                     wfull2, row_out + ix);
+    }
+#endif
+
+    // Non-AVX3 loop, or last capped vector for AVX3, if necessary
+    const auto w0 = Set(df, gab_weights[3 * c + 0]);
+    const auto w1 = Set(df, gab_weights[3 * c + 1]);
+    const auto w2 = Set(df, gab_weights[3 * c + 2]);
+    for (; ix < x1; ix += Lanes(df)) {
+      GaborishVector(df, row_t + ix, row_m + ix, row_b + ix, w0, w1, w2,
+                     row_out + ix);
+    }
+  }
+}
+
+// Step 0: 5x5 plus-shaped kernel with 5 SADs per pixel (3x3
+// plus-shaped). So this makes this filter a 7x7 filter.
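+// The 12 sads_off offsets below enumerate the radius-2 diamond (the 5x5 plus
+// shape) minus its center; each candidate's SAD accumulates the absolute
+// differences over the 3x3 plus shape in plus_off, across all three channels.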
+void Epf0Row(const FilterRows& rows, const LoopFilter& lf,
+             const FilterWeights& filter_weights, size_t x0, size_t x1,
+             size_t image_x_mod_8, size_t image_y_mod_8) {
+  JXL_DASSERT(x0 % Lanes(df) == 0);
+  const float* JXL_RESTRICT row_sigma = rows.GetSigmaRow();
+
+  float sm = lf.epf_pass0_sigma_scale;
+  float bsm = sm * lf.epf_border_sad_mul;
+
+  HWY_ALIGN float sad_mul[kBlockDim] = {bsm, sm, sm, sm, sm, sm, sm, bsm};
+
+  if (image_y_mod_8 == 0 || image_y_mod_8 == kBlockDim - 1) {
+    for (size_t i = 0; i < kBlockDim; i += Lanes(df)) {
+      Store(Set(df, bsm), df, sad_mul + i);
+    }
+  }
+
+  for (size_t x = x0; x < x1; x += Lanes(df)) {
+    size_t bx = (x + image_x_mod_8) / kBlockDim;
+    size_t ix = (x + image_x_mod_8) % kBlockDim;
+    if (row_sigma[bx] < kMinSigma) {
+      for (size_t c = 0; c < 3; c++) {
+        auto px = Load(df, rows.GetInputRow(0, c) + x);
+        Store(px, df, rows.GetOutputRow(c) + x);
+      }
+      continue;
+    }
+
+    const auto sm = Load(df, sad_mul + ix);
+    const auto inv_sigma = Set(DF(), row_sigma[bx]) * sm;
+
+    decltype(Zero(df)) sads[12];
+    for (size_t i = 0; i < 12; i++) sads[i] = Zero(df);
+    constexpr std::array<int, 2> sads_off[12] = {
+        {-2, 0}, {-1, -1}, {-1, 0}, {-1, 1}, {0, -2}, {0, -1},
+        {0, 1},  {0, 2},   {1, -1}, {1, 0},  {1, 1},  {2, 0},
+    };
+
+    // compute sads
+    // TODO(veluca): consider unrolling and optimizing this.
+    for (size_t c = 0; c < 3; c++) {
+      auto scale = Set(df, lf.epf_channel_scale[c]);
+      for (size_t i = 0; i < 12; i++) {
+        auto sad = Zero(df);
+        constexpr std::array<int, 2> plus_off[] = {
+            {0, 0}, {-1, 0}, {0, -1}, {1, 0}, {0, 1}};
+        for (size_t j = 0; j < 5; j++) {
+          const auto r11 = LoadU(
+              df, rows.GetInputRow(plus_off[j][0], c) + x + plus_off[j][1]);
+          const auto c11 =
+              LoadU(df, rows.GetInputRow(sads_off[i][0] + plus_off[j][0], c) +
+                        x + sads_off[i][1] + plus_off[j][1]);
+          sad += AbsDiff(r11, c11);
+        }
+        sads[i] = MulAdd(sad, scale, sads[i]);
+      }
+    }
+    const auto x_cc = LoadU(df, rows.GetInputRow(0, 0) + x);
+    const auto y_cc = LoadU(df, rows.GetInputRow(0, 1) + x);
+    const auto b_cc = LoadU(df, rows.GetInputRow(0, 2) + x);
+
+    auto w = Set(df, 1);
+    auto X = x_cc;
+    auto Y = y_cc;
+    auto B = b_cc;
+
+    for (size_t i = 0; i < 12; i++) {
+      AddPixelStep1</*aligned=*/false>(/*row=*/sads_off[i][0], rows,
+                                       x + sads_off[i][1], sads[i], inv_sigma,
+                                       lf, &X, &Y, &B, &w);
+    }
+
+#if JXL_HIGH_PRECISION
+    auto inv_w = Set(df, 1.0f) / w;
+#else
+    auto inv_w = ApproximateReciprocal(w);
+#endif
+    Store(X * inv_w, df, rows.GetOutputRow(0) + x);
+    Store(Y * inv_w, df, rows.GetOutputRow(1) + x);
+    Store(B * inv_w, df, rows.GetOutputRow(2) + x);
+  }
+}
+
+// Step 1: 3x3 plus-shaped kernel with 5 SADs per pixel (also 3x3
+// plus-shaped). So this makes this filter a 5x5 filter.
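+// The four neighbors' 3x3-plus SAD windows overlap heavily, so the loop below
+// accumulates sad0..sad3 incrementally, reusing shared absolute differences
+// instead of recomputing each window from scratch.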
+void Epf1Row(const FilterRows& rows, const LoopFilter& lf,
+             const FilterWeights& filter_weights, size_t x0, size_t x1,
+             size_t image_x_mod_8, size_t image_y_mod_8) {
+  JXL_DASSERT(x0 % Lanes(df) == 0);
+  const float* JXL_RESTRICT row_sigma = rows.GetSigmaRow();
+
+  float sm = 1.0f;
+  float bsm = sm * lf.epf_border_sad_mul;
+
+  HWY_ALIGN float sad_mul[kBlockDim] = {bsm, sm, sm, sm, sm, sm, sm, bsm};
+
+  if (image_y_mod_8 == 0 || image_y_mod_8 == kBlockDim - 1) {
+    for (size_t i = 0; i < kBlockDim; i += Lanes(df)) {
+      Store(Set(df, bsm), df, sad_mul + i);
+    }
+  }
+
+  for (size_t x = x0; x < x1; x += Lanes(df)) {
+    size_t bx = (x + image_x_mod_8) / kBlockDim;
+    size_t ix = (x + image_x_mod_8) % kBlockDim;
+    if (row_sigma[bx] < kMinSigma) {
+      for (size_t c = 0; c < 3; c++) {
+        auto px = Load(df, rows.GetInputRow(0, c) + x);
+        Store(px, df, rows.GetOutputRow(c) + x);
+      }
+      continue;
+    }
+
+    const auto sm = Load(df, sad_mul + ix);
+    const auto inv_sigma = Set(DF(), row_sigma[bx]) * sm;
+    auto sad0 = Zero(df);
+    auto sad1 = Zero(df);
+    auto sad2 = Zero(df);
+    auto sad3 = Zero(df);
+
+    // compute sads
+    for (size_t c = 0; c < 3; c++) {
+      // center px = 22, px above = 21
+      auto t = Undefined(df);
+
+      const auto p20 = Load(df, rows.GetInputRow(-2, c) + x);
+      const auto p21 = Load(df, rows.GetInputRow(-1, c) + x);
+      auto sad0c = AbsDiff(p20, p21);  // SAD 2, 1
+
+      const auto p11 = LoadU(df, rows.GetInputRow(-1, c) + x - 1);
+      auto sad1c = AbsDiff(p11, p21);  // SAD 1, 2
+
+      const auto p31 = LoadU(df, rows.GetInputRow(-1, c) + x + 1);
+      auto sad2c = AbsDiff(p31, p21);  // SAD 3, 2
+
+      const auto p02 = LoadU(df, rows.GetInputRow(0, c) + x - 2);
+      const auto p12 = LoadU(df, rows.GetInputRow(0, c) + x - 1);
+      sad1c += AbsDiff(p02, p12);  // SAD 1, 2
+      sad0c += AbsDiff(p11, p12);  // SAD 2, 1
+
+      const auto p22 = LoadU(df, rows.GetInputRow(0, c) + x);
+      t = AbsDiff(p12, p22);
+      sad1c += t;  // SAD 1, 2
+      sad2c += t;  // SAD 3, 2
+      t = AbsDiff(p22, p21);
+      auto sad3c = t;  // SAD 2, 3
+      sad0c += t;      // SAD 2, 1
+
+      const auto p32 = LoadU(df, rows.GetInputRow(0, c) + x + 1);
+      sad0c += AbsDiff(p31, p32);  // SAD 2, 1
+      t = AbsDiff(p22, p32);
+      sad1c += t;  // SAD 1, 2
+      sad2c += t;  // SAD 3, 2
+
+      const auto p42 = LoadU(df, rows.GetInputRow(0, c) + x + 2);
+      sad2c += AbsDiff(p42, p32);  // SAD 3, 2
+
+      const auto p13 = LoadU(df, rows.GetInputRow(1, c) + x - 1);
+      sad3c += AbsDiff(p13, p12);  // SAD 2, 3
+
+      const auto p23 = Load(df, rows.GetInputRow(1, c) + x);
+      t = AbsDiff(p22, p23);
+      sad0c += t;                  // SAD 2, 1
+      sad3c += t;                  // SAD 2, 3
+      sad1c += AbsDiff(p13, p23);  // SAD 1, 2
+
+      const auto p33 = LoadU(df, rows.GetInputRow(1, c) + x + 1);
+      sad2c += AbsDiff(p33, p23);  // SAD 3, 2
+      sad3c += AbsDiff(p33, p32);  // SAD 2, 3
+
+      const auto p24 = Load(df, rows.GetInputRow(2, c) + x);
+      sad3c += AbsDiff(p24, p23);  // SAD 2, 3
+
+      auto scale = Set(df, lf.epf_channel_scale[c]);
+      sad0 = MulAdd(sad0c, scale, sad0);
+      sad1 = MulAdd(sad1c, scale, sad1);
+      sad2 = MulAdd(sad2c, scale, sad2);
+      sad3 = MulAdd(sad3c, scale, sad3);
+    }
+    const auto x_cc = Load(df, rows.GetInputRow(0, 0) + x);
+    const auto y_cc = Load(df, rows.GetInputRow(0, 1) + x);
+    const auto b_cc = Load(df, rows.GetInputRow(0, 2) + x);
+
+    auto w = Set(df, 1);
+    auto X = x_cc;
+    auto Y = y_cc;
+    auto B = b_cc;
+
+    // Top row
+    AddPixelStep1</*aligned=*/true>(/*row=*/-1, rows, x, sad0, inv_sigma, lf,
+                                    &X, &Y, &B, &w);
+    // Center
+    AddPixelStep1</*aligned=*/false>(/*row=*/0, rows, x - 1, sad1, inv_sigma,
+                                     lf, &X, &Y, &B, &w);
+    AddPixelStep1</*aligned=*/false>(/*row=*/0, rows, x + 1, sad2, inv_sigma,
+                                     lf, &X, &Y, &B, &w);
+    // Bottom
+    AddPixelStep1</*aligned=*/true>(/*row=*/1, rows, x, sad3, inv_sigma, lf,
+                                    &X, &Y, &B, &w);
+#if JXL_HIGH_PRECISION
+    auto inv_w = Set(df, 1.0f) / w;
+#else
+    auto inv_w = ApproximateReciprocal(w);
+#endif
+    Store(X * inv_w, df, rows.GetOutputRow(0) + x);
+    Store(Y * inv_w, df, rows.GetOutputRow(1) + x);
+    Store(B * inv_w, df, rows.GetOutputRow(2) + x);
+  }
+}
+
+// Step 2: 3x3 plus-shaped kernel with a single reference pixel, run on
+// the output of the previous step.
+void Epf2Row(const FilterRows& rows, const LoopFilter& lf,
+             const FilterWeights& filter_weights, size_t x0, size_t x1,
+             size_t image_x_mod_8, size_t image_y_mod_8) {
+  JXL_DASSERT(x0 % Lanes(df) == 0);
+  const float* JXL_RESTRICT row_sigma = rows.GetSigmaRow();
+
+  float sm = lf.epf_pass2_sigma_scale;
+  float bsm = sm * lf.epf_border_sad_mul;
+
+  HWY_ALIGN float sad_mul[kBlockDim] = {bsm, sm, sm, sm, sm, sm, sm, bsm};
+
+  if (image_y_mod_8 == 0 || image_y_mod_8 == kBlockDim - 1) {
+    for (size_t i = 0; i < kBlockDim; i += Lanes(df)) {
+      Store(Set(df, bsm), df, sad_mul + i);
+    }
+  }
+
+  for (size_t x = x0; x < x1; x += Lanes(df)) {
+    size_t bx = (x + image_x_mod_8) / kBlockDim;
+    size_t ix = (x + image_x_mod_8) % kBlockDim;
+
+    if (row_sigma[bx] < kMinSigma) {
+      for (size_t c = 0; c < 3; c++) {
+        auto px = Load(df, rows.GetInputRow(0, c) + x);
+        Store(px, df, rows.GetOutputRow(c) + x);
+      }
+      continue;
+    }
+
+    const auto sm = Load(df, sad_mul + ix);
+    const auto inv_sigma = Set(DF(), row_sigma[bx]) * sm;
+
+    const auto x_cc = Load(df, rows.GetInputRow(0, 0) + x);
+    const auto y_cc = Load(df, rows.GetInputRow(0, 1) + x);
+    const auto b_cc = Load(df, rows.GetInputRow(0, 2) + x);
+
+    auto w = Set(df, 1);
+    auto X = x_cc;
+    auto Y = y_cc;
+    auto B = b_cc;
+
+    // Top row
+    AddPixelStep2</*aligned=*/true>(/*row=*/-1, rows, x, x_cc, y_cc, b_cc,
+                                    inv_sigma, lf, &X, &Y, &B, &w);
+    // Center
+    AddPixelStep2</*aligned=*/false>(/*row=*/0, rows, x - 1, x_cc, y_cc, b_cc,
+                                     inv_sigma, lf, &X, &Y, &B, &w);
+    AddPixelStep2</*aligned=*/false>(/*row=*/0, rows, x + 1, x_cc, y_cc, b_cc,
+                                     inv_sigma, lf, &X, &Y, &B, &w);
+    // Bottom
+    AddPixelStep2</*aligned=*/true>(/*row=*/1, rows, x, x_cc, y_cc, b_cc,
+                                    inv_sigma, lf, &X, &Y, &B, &w);
+
+#if JXL_HIGH_PRECISION
+    auto inv_w = Set(df, 1.0f) / w;
+#else
+    auto inv_w = ApproximateReciprocal(w);
+#endif
+    Store(X * inv_w, df, rows.GetOutputRow(0) + x);
+    Store(Y * inv_w, df, rows.GetOutputRow(1) + x);
+    Store(B * inv_w, df, rows.GetOutputRow(2) + x);
+  }
+}
+
+constexpr FilterDefinition kGaborishFilter{&GaborishRow, 1};
+constexpr FilterDefinition kEpf0Filter{&Epf0Row, 3};
+constexpr FilterDefinition kEpf1Filter{&Epf1Row, 2};
+constexpr FilterDefinition kEpf2Filter{&Epf2Row, 1};
+
+void FilterPipelineInit(FilterPipeline* fp, const LoopFilter& lf,
+                        const Image3F& in, const Rect& in_rect,
+                        const Rect& image_rect, size_t image_ysize,
+                        Image3F* out, const Rect& out_rect) {
+  JXL_DASSERT(lf.gab || lf.epf_iters > 0);
+  // All EPF filters use sigma so we need to compute it.
+  fp->compute_sigma = lf.epf_iters > 0;
+
+  fp->num_filters = 0;
+  fp->storage_rows_used = 0;
+  // First filter always uses the input image.
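+  // Steps added with AddStep() below are chained onto the previous output.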
+  fp->filters[0].SetInput(&in, in_rect, image_rect, image_ysize);
+
+  if (lf.gab) {
+    fp->AddStep(kGaborishFilter);
+  }
+
+  if (lf.epf_iters == 1) {
+    fp->AddStep(kEpf1Filter);
+  } else if (lf.epf_iters == 2) {
+    fp->AddStep(kEpf1Filter);
+    fp->AddStep(kEpf2Filter);
+  } else if (lf.epf_iters == 3) {
+    fp->AddStep(kEpf0Filter);
+    fp->AddStep(kEpf1Filter);
+    fp->AddStep(kEpf2Filter);
+  }
+
+  // At least one of the filters was enabled so "num_filters" must be non-zero.
+  JXL_DASSERT(fp->num_filters > 0);
+
+  // Set the output of the last filter as the output image.
+  fp->filters[fp->num_filters - 1].SetOutput(out, out_rect);
+
+  // Walk the list of filters backwards to compute how many rows are needed.
+  size_t col_border = 0;
+  for (int i = fp->num_filters - 1; i >= 0; i--) {
+    // The extra border needed for future filtering should be a multiple of
+    // Lanes(df). Rounding up in each step but not storing the rounded up
+    // value in col_border means that in a 3-step filter the first two filters
+    // may have the same output_col_border value but the second one would use
+    // uninitialized values from the previous one. It is fine to have this
+    // situation for pixels outside the col_border but inside the rounded up
+    // col_border.
+    fp->filters[i].output_col_border = RoundUpTo(col_border, Lanes(df));
+    col_border += fp->filters[i].filter_def.border;
+  }
+  fp->total_border = col_border;
+  JXL_ASSERT(fp->total_border == lf.Padding());
+  JXL_ASSERT(fp->total_border <= kMaxFilterBorder);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(FilterPipelineInit);  // Local function
+
+// Mirror n floats starting at *p and store them before p.
+JXL_INLINE void LeftMirror(float* p, size_t n) {
+  for (size_t i = 0; i < n; i++) {
+    *(p - 1 - i) = p[i];
+  }
+}
+
+// Mirror n floats starting at *(p - n) and store them at *p.
+JXL_INLINE void RightMirror(float* p, size_t n) {
+  for (size_t i = 0; i < n; i++) {
+    p[i] = *(p - 1 - i);
+  }
+}
+
+void ComputeSigma(const Rect& block_rect, PassesDecoderState* state) {
+  const LoopFilter& lf = state->shared->frame_header.loop_filter;
+  JXL_CHECK(lf.epf_iters > 0);
+  const AcStrategyImage& ac_strategy = state->shared->ac_strategy;
+  const float quant_scale = state->shared->quantizer.Scale();
+
+  const size_t sigma_stride = state->filter_weights.sigma.PixelsPerRow();
+  const size_t sharpness_stride = state->shared->epf_sharpness.PixelsPerRow();
+
+  for (size_t by = 0; by < block_rect.ysize(); ++by) {
+    float* JXL_RESTRICT sigma_row =
+        block_rect.Row(&state->filter_weights.sigma, by);
+    const uint8_t* JXL_RESTRICT sharpness_row =
+        block_rect.ConstRow(state->shared->epf_sharpness, by);
+    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
+    const int* const JXL_RESTRICT row_quant =
+        block_rect.ConstRow(state->shared->raw_quant_field, by);
+
+    for (size_t bx = 0; bx < block_rect.xsize(); bx++) {
+      AcStrategy acs = acs_row[bx];
+      size_t llf_x = acs.covered_blocks_x();
+      if (!acs.IsFirstBlock()) continue;
+      // quant_scale is smaller for low quality.
+      // quant_scale is roughly 0.08 / butteraugli score.
+      //
+      // row_quant is smaller for low quality.
+      // row_quant is a quantization multiplier of form 1.0 /
+      // row_quant[bx]
+      //
+      // lf.epf_quant_mul is a parameter in the format
+      // kInvSigmaNum is a constant
+      float sigma_quant =
+          lf.epf_quant_mul / (quant_scale * row_quant[bx] * kInvSigmaNum);
+      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
+        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
+          float sigma =
+              sigma_quant *
+              lf.epf_sharp_lut[sharpness_row[bx + ix + iy * sharpness_stride]];
+          // Avoid infinities.
+          sigma = std::min(-1e-4f, sigma);  // TODO(veluca): remove this.
+          sigma_row[bx + ix + kSigmaPadding +
+                    (iy + kSigmaPadding) * sigma_stride] = 1.0f / sigma;
+        }
+      }
+      // TODO(veluca): remove this padding.
+      // Left padding with mirroring.
+      if (bx + block_rect.x0() == 0) {
+        for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
+          LeftMirror(
+              sigma_row + kSigmaPadding + (iy + kSigmaPadding) * sigma_stride,
+              kSigmaBorder);
+        }
+      }
+      // Right padding with mirroring.
+      if (bx + block_rect.x0() + llf_x ==
+          state->shared->frame_dim.xsize_blocks) {
+        for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
+          RightMirror(sigma_row + kSigmaPadding + bx + llf_x +
+                          (iy + kSigmaPadding) * sigma_stride,
+                      kSigmaBorder);
+        }
+      }
+      // Offsets for row copying, in blocks.
+      size_t offset_before = bx + block_rect.x0() == 0 ? 1 : bx + kSigmaPadding;
+      size_t offset_after =
+          bx + block_rect.x0() + llf_x == state->shared->frame_dim.xsize_blocks
+              ? kSigmaPadding + llf_x + bx + kSigmaBorder
+              : kSigmaPadding + llf_x + bx;
+      size_t num = offset_after - offset_before;
+      // Above
+      if (by + block_rect.y0() == 0) {
+        for (size_t iy = 0; iy < kSigmaBorder; iy++) {
+          memcpy(
+              sigma_row + offset_before +
+                  (kSigmaPadding - 1 - iy) * sigma_stride,
+              sigma_row + offset_before + (kSigmaPadding + iy) * sigma_stride,
+              num * sizeof(*sigma_row));
+        }
+      }
+      // Below
+      if (by + block_rect.y0() + acs.covered_blocks_y() ==
+          state->shared->frame_dim.ysize_blocks) {
+        for (size_t iy = 0; iy < kSigmaBorder; iy++) {
+          memcpy(
+              sigma_row + offset_before +
+                  sigma_stride * (acs.covered_blocks_y() + kSigmaPadding + iy),
+              sigma_row + offset_before +
+                  sigma_stride *
+                      (acs.covered_blocks_y() + kSigmaPadding - 1 - iy),
+              num * sizeof(*sigma_row));
+        }
+      }
+    }
+  }
+}
+
+FilterPipeline* PrepareFilterPipeline(
+    PassesDecoderState* dec_state, const Rect& image_rect, const Image3F& input,
+    const Rect& input_rect, size_t image_ysize, size_t thread,
+    Image3F* JXL_RESTRICT out, const Rect& output_rect) {
+  const LoopFilter& lf = dec_state->shared->frame_header.loop_filter;
+  JXL_DASSERT(image_rect.x0() % GroupBorderAssigner::kPaddingXRound == 0);
+  JXL_DASSERT(input_rect.x0() % GroupBorderAssigner::kPaddingXRound == 0);
+  JXL_DASSERT(output_rect.x0() % GroupBorderAssigner::kPaddingXRound == 0);
+  JXL_DASSERT(input_rect.x0() >= lf.Padding());
+  JXL_DASSERT(image_rect.xsize() == input_rect.xsize());
+  JXL_DASSERT(image_rect.xsize() == output_rect.xsize());
+  FilterPipeline* fp = &(dec_state->filter_pipelines[thread]);
+  HWY_DYNAMIC_DISPATCH(FilterPipelineInit)
+  (fp, lf, input, input_rect, image_rect, image_ysize, out, output_rect);
+  return fp;
+}
+
+void ApplyFilters(PassesDecoderState* dec_state, const Rect& image_rect,
+                  const Image3F& input, const Rect& input_rect, size_t thread,
+                  Image3F* JXL_RESTRICT out, const Rect& output_rect) {
+  auto fp = PrepareFilterPipeline(dec_state, image_rect, input, input_rect,
+                                  input_rect.ysize(), thread, out, output_rect);
+  const LoopFilter& lf = dec_state->shared->frame_header.loop_filter;
(ssize_t y = -lf.Padding();
+       y < static_cast<ssize_t>(lf.Padding() + image_rect.ysize()); y++) {
+    fp->ApplyFiltersRow(lf, dec_state->filter_weights, image_rect, y);
+  }
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/epf.h b/third_party/jpeg-xl/lib/jxl/epf.h
new file mode 100644
index 000000000000..009bf3e8bd44
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/epf.h
@@ -0,0 +1,62 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_EPF_H_
+#define LIB_JXL_EPF_H_
+
+// Fast SIMD "in-loop" edge preserving filter (adaptive, nonlinear).
+
+#include <stddef.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/filters.h"
+#include "lib/jxl/passes_state.h"
+
+namespace jxl {
+
+// 4 * (sqrt(0.5)-1), so that Weight(sigma) = 0.5.
+static constexpr float kInvSigmaNum = -1.1715728752538099024f;
+
+// Fills the `state->filter_weights.sigma` image with the precomputed sigma
+// values in the area inside `block_rect`. Accesses the AC strategy, quant
+// field and epf_sharpness fields in the corresponding positions.
+void ComputeSigma(const Rect& block_rect, PassesDecoderState* state);
+
+// Applies Gaborish + EPF to the given `image_rect` part of the image (used to
+// select the sigma values). Input pixels are taken from `input:input_rect`,
+// and the filtering result is written to `out:output_rect`. `dec_state->sigma`
+// must be padded with `kMaxFilterPadding/kBlockDim` values along the x axis.
+// All rects must be aligned to a multiple of `kBlockDim` pixels.
+// `input_rect`, `output_rect` and `image_rect` must all have the same size.
+// At least `lf.Padding()` pixels must be accessible (and contain valid values)
+// outside of `image_rect` in `input`.
+// This function should only ever be called on full images. To do partial
+// processing, use PrepareFilterPipeline directly.
+void ApplyFilters(PassesDecoderState* dec_state, const Rect& image_rect,
+                  const Image3F& input, const Rect& input_rect, size_t thread,
+                  Image3F* JXL_RESTRICT out, const Rect& output_rect);
+
+// Same as ApplyFilters, but only prepares the pipeline (which is returned and
+// must be run by the caller on -lf.Padding() to image_rect.ysize() +
+// lf.Padding()).
+FilterPipeline* PrepareFilterPipeline(
+    PassesDecoderState* dec_state, const Rect& image_rect,
+    const Image3F& input, const Rect& input_rect, size_t image_ysize,
+    size_t thread, Image3F* JXL_RESTRICT out, const Rect& output_rect);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_EPF_H_
diff --git a/third_party/jpeg-xl/lib/jxl/fast_math-inl.h b/third_party/jpeg-xl/lib/jxl/fast_math-inl.h
new file mode 100644
index 000000000000..edb10766c015
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/fast_math-inl.h
@@ -0,0 +1,184 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Fast SIMD math ops (log2, encoder only, cos, erf for splines)
+
+#if defined(LIB_JXL_FAST_MATH_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_FAST_MATH_INL_H_
+#undef LIB_JXL_FAST_MATH_INL_H_
+#else
+#define LIB_JXL_FAST_MATH_INL_H_
+#endif
+
+#include <hwy/highway.h>
+
+#include "lib/jxl/common.h"
+#include "lib/jxl/rational_polynomial-inl.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Rebind;
+using hwy::HWY_NAMESPACE::ShiftLeft;
+using hwy::HWY_NAMESPACE::ShiftRight;
+
+// Computes base-2 logarithm like std::log2. Undefined if negative / NaN.
+// L1 error ~3.9E-6
+template <class DF, class V>
+V FastLog2f(const DF df, V x) {
+  // 2,2 rational polynomial approximation of std::log1p(x) / std::log(2).
+  HWY_ALIGN const float p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06f),
+                                          HWY_REP4(1.4287160470083755E+00f),
+                                          HWY_REP4(7.4245873327820566E-01f)};
+  HWY_ALIGN const float q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01f),
+                                          HWY_REP4(1.0096718572241148E+00f),
+                                          HWY_REP4(1.7409343003366853E-01f)};
+
+  const Rebind<int32_t, DF> di;
+  const auto x_bits = BitCast(di, x);
+
+  // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops
+  const auto exp_bits = x_bits - Set(di, 0x3f2aaaab);  // = 2/3
+  // Shifted exponent = log2; also used to clear mantissa.
+  const auto exp_shifted = ShiftRight<23>(exp_bits);
+  const auto mantissa = BitCast(df, x_bits - ShiftLeft<23>(exp_shifted));
+  const auto exp_val = ConvertTo(df, exp_shifted);
+  return EvalRationalPolynomial(df, mantissa - Set(df, 1.0f), p, q) + exp_val;
+}
+
+// max relative error ~3e-7
+template <class DF, class V>
+V FastPow2f(const DF df, V x) {
+  const Rebind<int32_t, DF> di;
+  auto floorx = Floor(x);
+  auto exp = BitCast(df, ShiftLeft<23>(ConvertTo(di, floorx) + Set(di, 127)));
+  auto frac = x - floorx;
+  auto num = frac + Set(df, 1.01749063e+01);
+  num = MulAdd(num, frac, Set(df, 4.88687798e+01));
+  num = MulAdd(num, frac, Set(df, 9.85506591e+01));
+  num = num * exp;
+  auto den = MulAdd(frac, Set(df, 2.10242958e-01), Set(df, -2.22328856e-02));
+  den = MulAdd(den, frac, Set(df, -1.94414990e+01));
+  den = MulAdd(den, frac, Set(df, 9.85506633e+01));
+  return num / den;
+}
+
+// max relative error ~3e-5
+template <class DF, class V>
+V FastPowf(const DF df, V base, V exponent) {
+  return FastPow2f(df, FastLog2f(df, base) * exponent);
+}
+
+// Computes cosine like std::cos.
+// L1 error 7e-5.
+template <class DF, class V>
+V FastCosf(const DF df, V x) {
+  // Step 1: range reduction to [0, 2pi)
+  const auto pi2 = Set(df, kPi * 2.0f);
+  const auto pi2_inv = Set(df, 0.5f / kPi);
+  const auto npi2 = Floor(x * pi2_inv) * pi2;
+  const auto xmodpi2 = x - npi2;
+  // Step 2: range reduction to [0, pi]
+  const auto x_pi = Min(xmodpi2, pi2 - xmodpi2);
+  // Step 3: range reduction to [0, pi/2]
+  const auto above_pihalf = x_pi >= Set(df, kPi / 2.0f);
+  const auto x_pihalf = IfThenElse(above_pihalf, Set(df, kPi) - x_pi, x_pi);
+  // Step 4: Taylor-like approximation, scaled by 2**0.75 to make angle
+  // duplication steps faster, on x/4.
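+  // Concretely: with c0 = 2^0.75 * cos(x/4), the double-angle identity
+  // cos(2a) = 2*cos(a)^2 - 1 gives c0^2 - sqrt(2) = sqrt(2)*cos(x/2), and
+  // (sqrt(2)*cos(x/2))^2 - 1 = cos(x), so each duplication below is a single
+  // MulAdd with a precomputed constant.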
+  const auto xs = x_pihalf * Set(df, 0.25f);
+  const auto x2 = xs * xs;
+  const auto x4 = x2 * x2;
+  const auto cosx_prescaling =
+      MulAdd(x4, Set(df, 0.06960438),
+             MulAdd(x2, Set(df, -0.84087373), Set(df, 1.68179268)));
+  // Step 5: angle duplication.
+  const auto cosx_scale1 =
+      MulAdd(cosx_prescaling, cosx_prescaling, Set(df, -1.414213562));
+  const auto cosx_scale2 = MulAdd(cosx_scale1, cosx_scale1, Set(df, -1));
+  // Step 6: change sign if needed.
+  const Rebind<uint32_t, DF> du;
+  auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(above_pihalf)));
+  return BitCast(df, signbit ^ BitCast(du, cosx_scale2));
+}
+
+// Computes the error function like std::erf.
+// L1 error 7e-4.
+template <class DF, class V>
+V FastErff(const DF df, V x) {
+  // Formula from
+  // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
+  // but constants have been recomputed.
+  const auto xle0 = x <= Zero(df);
+  const auto absx = Abs(x);
+  // Compute 1 - 1 / ((((x * a + b) * x + c) * x + d) * x + 1)**4
+  const auto denom1 =
+      MulAdd(absx, Set(df, 7.77394369e-02), Set(df, 2.05260015e-04));
+  const auto denom2 = MulAdd(denom1, absx, Set(df, 2.32120216e-01));
+  const auto denom3 = MulAdd(denom2, absx, Set(df, 2.77820801e-01));
+  const auto denom4 = MulAdd(denom3, absx, Set(df, 1.0f));
+  const auto denom5 = denom4 * denom4;
+  const auto inv_denom5 = Set(df, 1.0f) / denom5;
+  const auto result = NegMulAdd(inv_denom5, inv_denom5, Set(df, 1.0f));
+  // Change sign if needed.
+  const Rebind<uint32_t, DF> du;
+  auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(xle0)));
+  return BitCast(df, signbit ^ BitCast(du, result));
+}
+
+inline float FastLog2f(float f) {
+  HWY_CAPPED(float, 1) D;
+  return GetLane(FastLog2f(D, Set(D, f)));
+}
+
+inline float FastPow2f(float f) {
+  HWY_CAPPED(float, 1) D;
+  return GetLane(FastPow2f(D, Set(D, f)));
+}
+
+inline float FastPowf(float b, float e) {
+  HWY_CAPPED(float, 1) D;
+  return GetLane(FastPowf(D, Set(D, b), Set(D, e)));
+}
+
+inline float FastCosf(float f) {
+  HWY_CAPPED(float, 1) D;
+  return GetLane(FastCosf(D, Set(D, f)));
+}
+
+inline float FastErff(float f) {
+  HWY_CAPPED(float, 1) D;
+  return GetLane(FastErff(D, Set(D, f)));
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // LIB_JXL_FAST_MATH_INL_H_
+
+#if HWY_ONCE
+
+namespace jxl {
+inline float FastLog2f(float f) { return HWY_STATIC_DISPATCH(FastLog2f)(f); }
+inline float FastPow2f(float f) { return HWY_STATIC_DISPATCH(FastPow2f)(f); }
+inline float FastPowf(float b, float e) {
+  return HWY_STATIC_DISPATCH(FastPowf)(b, e);
+}
+inline float FastCosf(float f) { return HWY_STATIC_DISPATCH(FastCosf)(f); }
+inline float FastErff(float f) { return HWY_STATIC_DISPATCH(FastErff)(f); }
+}  // namespace jxl
+
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/fast_math_test.cc b/third_party/jpeg-xl/lib/jxl/fast_math_test.cc
new file mode 100644
index 000000000000..31866f56bd47
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/fast_math_test.cc
@@ -0,0 +1,252 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <random>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/fast_math_test.cc"
+#include <hwy/foreach_target.h>
+
+#include "lib/jxl/dec_xyb-inl.h"
+#include "lib/jxl/enc_xyb.h"
+#include "lib/jxl/fast_math-inl.h"
+#include "lib/jxl/transfer_functions-inl.h"
+
+// Test utils
+#include <hwy/highway.h>
+#include <hwy/tests/test_util-inl.h>
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+HWY_NOINLINE void TestFastLog2() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> dist(1e-7f, 1e3f);
+  float max_abs_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float f = dist(rng);
+    const auto actual_v = FastLog2f(d, Set(d, f));
+    const float actual = GetLane(actual_v);
+    const float abs_err = std::abs(std::log2(f) - actual);
+    EXPECT_LT(abs_err, 2.9E-6) << "f = " << f;
+    max_abs_err = std::max(max_abs_err, abs_err);
+  }
+  printf("max abs err %e\n", static_cast<double>(max_abs_err));
+}
+
+HWY_NOINLINE void TestFastPow2() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> dist(-100, 100);
+  float max_rel_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float f = dist(rng);
+    const auto actual_v = FastPow2f(d, Set(d, f));
+    const float actual = GetLane(actual_v);
+    const float expected = std::pow(2, f);
+    const float rel_err = std::abs(expected - actual) / expected;
+    EXPECT_LT(rel_err, 3.1E-7) << "f = " << f;
+    max_rel_err = std::max(max_rel_err, rel_err);
+  }
+  printf("max rel err %e\n", static_cast<double>(max_rel_err));
+}
+
+HWY_NOINLINE void TestFastPow() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> distb(1e-3f, 1e3f);
+  std::uniform_real_distribution<float> diste(-10, 10);
+  float max_rel_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float b = distb(rng);
+    const float e = diste(rng);
+    const auto actual_v = FastPowf(d, Set(d, b), Set(d, e));
+    const float actual = GetLane(actual_v);
+    const float expected = std::pow(b, e);
+    const float rel_err = std::abs(expected - actual) / expected;
+    EXPECT_LT(rel_err, 3E-5) << "b = " << b << " e = " << e;
+    max_rel_err = std::max(max_rel_err, rel_err);
+  }
+  printf("max rel err %e\n", static_cast<double>(max_rel_err));
+}
+
+HWY_NOINLINE void TestFastCos() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> dist(-1e3f, 1e3f);
+  float max_abs_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float f = dist(rng);
+    const auto actual_v = FastCosf(d, Set(d, f));
+    const float actual = GetLane(actual_v);
+    const float abs_err = std::abs(std::cos(f) - actual);
+    EXPECT_LT(abs_err, 7E-5) << "f = " << f;
+    max_abs_err = std::max(max_abs_err, abs_err);
+  }
+  printf("max abs err %e\n", static_cast<double>(max_abs_err));
+}
+
+HWY_NOINLINE void TestFastErf() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> dist(-5.f, 5.f);
+  float max_abs_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float f = dist(rng);
+    const auto actual_v = FastErff(d, Set(d, f));
+    const float actual = GetLane(actual_v);
+    const float abs_err = std::abs(std::erf(f) - actual);
+    EXPECT_LT(abs_err, 7E-4) << "f = " << f;
+    max_abs_err = std::max(max_abs_err, abs_err);
+  }
+  printf("max abs err %e\n", static_cast<double>(max_abs_err));
+}
+
+HWY_NOINLINE void TestFastSRGB() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+  float max_abs_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float f = dist(rng);
+    const auto actual_v = FastLinearToSRGB(d, Set(d, f));
+    const float actual = GetLane(actual_v);
+    const float expected = GetLane(TF_SRGB().EncodedFromDisplay(d, Set(d, f)));
+    const float abs_err = std::abs(expected - actual);
+    EXPECT_LT(abs_err, 1.2E-4) << "f = " << f;
+    max_abs_err = std::max(max_abs_err, abs_err);
+  }
+  printf("max abs err %e\n", static_cast<double>(max_abs_err));
+}
+
+HWY_NOINLINE void TestFastPQEFD() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+  float max_abs_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float f = dist(rng);
+    const float actual = GetLane(TF_PQ().EncodedFromDisplay(d, Set(d, f)));
+    const float expected = TF_PQ().EncodedFromDisplay(f);
+    const float abs_err = std::abs(expected - actual);
+    EXPECT_LT(abs_err, 7e-7) << "f = " << f;
+    max_abs_err = std::max(max_abs_err, abs_err);
+  }
+  printf("max abs err %e\n", static_cast<double>(max_abs_err));
+}
+
+HWY_NOINLINE void TestFastPQDFE() {
+  constexpr size_t kNumTrials = 1 << 23;
+  std::mt19937 rng(1);
+  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+  float max_abs_err = 0;
+  HWY_FULL(float) d;
+  for (size_t i = 0; i < kNumTrials; i++) {
+    const float f = dist(rng);
+    const float actual = GetLane(TF_PQ().DisplayFromEncoded(d, Set(d, f)));
+    const float expected = TF_PQ().DisplayFromEncoded(f);
+    const float abs_err = std::abs(expected - actual);
+    EXPECT_LT(abs_err, 3E-6) << "f = " << f;
+    max_abs_err = std::max(max_abs_err, abs_err);
+  }
+  printf("max abs err %e\n", static_cast<double>(max_abs_err));
+}
+
+HWY_NOINLINE void TestFastXYB() {
+  if (!HasFastXYBTosRGB8()) return;
+  ImageMetadata metadata;
+  ImageBundle ib(&metadata);
+  int scaling = 1;
+  int n = 256 * scaling;
+  float inv_scaling = 1.0f / scaling;
+  int kChunk = 32;
+  // The image is divided in chunks to reduce total memory usage.
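+  // Each chunk is kChunk^3 = 32768 RGB samples, stored as a
+  // (kChunk * kChunk) x kChunk Image3F, rather than materializing all n^3
+  // combinations at once.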
+  for (int cr = 0; cr < n; cr += kChunk) {
+    for (int cg = 0; cg < n; cg += kChunk) {
+      for (int cb = 0; cb < n; cb += kChunk) {
+        Image3F chunk(kChunk * kChunk, kChunk);
+        for (int ir = 0; ir < kChunk; ir++) {
+          for (int ig = 0; ig < kChunk; ig++) {
+            for (int ib = 0; ib < kChunk; ib++) {
+              float r = (cr + ir) * inv_scaling;
+              float g = (cg + ig) * inv_scaling;
+              float b = (cb + ib) * inv_scaling;
+              chunk.PlaneRow(0, ir)[ig * kChunk + ib] = r * (1.0f / 255);
+              chunk.PlaneRow(1, ir)[ig * kChunk + ib] = g * (1.0f / 255);
+              chunk.PlaneRow(2, ir)[ig * kChunk + ib] = b * (1.0f / 255);
+            }
+          }
+        }
+        ib.SetFromImage(std::move(chunk), ColorEncoding::SRGB());
+        Image3F xyb(kChunk * kChunk, kChunk);
+        std::vector<uint8_t> roundtrip(kChunk * kChunk * kChunk * 3);
+        ToXYB(ib, nullptr, &xyb);
+        jxl::HWY_NAMESPACE::FastXYBTosRGB8(xyb, Rect(xyb), Rect(xyb),
+                                           roundtrip.data(), xyb.xsize());
+        for (int ir = 0; ir < kChunk; ir++) {
+          for (int ig = 0; ig < kChunk; ig++) {
+            for (int ib = 0; ib < kChunk; ib++) {
+              float r = (cr + ir) * inv_scaling;
+              float g = (cg + ig) * inv_scaling;
+              float b = (cb + ib) * inv_scaling;
+              size_t idx = ir * kChunk * kChunk + ig * kChunk + ib;
+              int rr = roundtrip[3 * idx];
+              int rg = roundtrip[3 * idx + 1];
+              int rb = roundtrip[3 * idx + 2];
+              EXPECT_LT(abs(r - rr), 2) << "expected " << r << " got " << rr;
+              EXPECT_LT(abs(g - rg), 2) << "expected " << g << " got " << rg;
+              EXPECT_LT(abs(b - rb), 2) << "expected " << b << " got " << rb;
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+class FastMathTargetTest : public hwy::TestWithParamTarget {};
+HWY_TARGET_INSTANTIATE_TEST_SUITE_P(FastMathTargetTest);
+
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastLog2);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow2);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastCos);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastErf);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastSRGB);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPQDFE);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPQEFD);
+HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastXYB);
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/field_encodings.h b/third_party/jpeg-xl/lib/jxl/field_encodings.h
new file mode 100644
index 000000000000..0fcb8fa8b9a0
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/field_encodings.h
@@ -0,0 +1,132 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_FIELD_ENCODINGS_H_
+#define LIB_JXL_FIELD_ENCODINGS_H_
+
+// Constants needed to encode/decode fields; avoids including the full
+// fields.h.
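+//
+// Worked example of the bit packing defined below: BitsOffset(8, 8) stores
+// ((8 - 1) & 0x1F) + (8 << 5) = 263, so ExtraBits() = (263 & 0x1F) + 1 = 8
+// and Offset() = 263 >> 5 = 8, i.e. values 8..263 signaled in 8 extra bits.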
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "hwy/base.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+class Visitor;
+class Fields {
+ public:
+  virtual ~Fields() = default;
+  virtual const char* Name() const = 0;
+  virtual Status VisitFields(Visitor* JXL_RESTRICT visitor) = 0;
+};
+
+// Distribution of U32 values for one particular selector. Represents either a
+// power of two-sized range, or a single value. A separate type ensures this is
+// only passed to the U32Enc ctor.
+struct U32Distr {
+  // No need to validate - all `d` are legitimate.
+  constexpr explicit U32Distr(uint32_t d) : d(d) {}
+
+  static constexpr uint32_t kDirect = 0x80000000u;
+
+  constexpr bool IsDirect() const { return (d & kDirect) != 0; }
+
+  // Only call if IsDirect().
+  constexpr uint32_t Direct() const { return d & (kDirect - 1); }
+
+  // Only call if !IsDirect().
+  constexpr size_t ExtraBits() const { return (d & 0x1F) + 1; }
+  uint32_t Offset() const { return (d >> 5) & 0x3FFFFFF; }
+
+  uint32_t d;
+};
+
+// A direct-coded 31-bit value occupying 2 bits in the bitstream.
+constexpr U32Distr Val(uint32_t value) {
+  return U32Distr(value | U32Distr::kDirect);
+}
+
+// Value - `offset` will be signaled in `bits` extra bits.
+constexpr U32Distr BitsOffset(uint32_t bits, uint32_t offset) {
+  return U32Distr(((bits - 1) & 0x1F) + ((offset & 0x3FFFFFF) << 5));
+}
+
+// Value will be signaled in `bits` extra bits.
+constexpr U32Distr Bits(uint32_t bits) { return BitsOffset(bits, 0); }
+
+// See U32Coder documentation in fields.h.
+class U32Enc {
+ public:
+  constexpr U32Enc(const U32Distr d0, const U32Distr d1, const U32Distr d2,
+                   const U32Distr d3)
+      : d_{d0, d1, d2, d3} {}
+
+  // Returns the U32Distr at `selector` = 0..3, least-significant first.
+  U32Distr GetDistr(const uint32_t selector) const {
+    JXL_ASSERT(selector < 4);
+    return d_[selector];
+  }
+
+ private:
+  U32Distr d_[4];
+};
+
+// Returns bit with the given `index` (0 = least significant).
+template <typename T>
+static inline constexpr uint64_t MakeBit(T index) {
+  return 1ULL << static_cast<size_t>(index);
+}
+
+// Returns vector of all possible values of an Enum type. Relies on each Enum
+// providing an overload of EnumBits() that returns a bit array of its values,
+// which implies values must be in [0, 64).
+template <typename Enum>
+std::vector<Enum> Values() {
+  uint64_t bits = EnumBits(Enum());
+
+  std::vector<Enum> values;
+  values.reserve(hwy::PopCount(bits));
+
+  // For each 1-bit in bits: add its index as value
+  while (bits != 0) {
+    const int index = Num0BitsBelowLS1Bit_Nonzero(bits);
+    values.push_back(static_cast<Enum>(index));
+    bits &= bits - 1;  // clear least-significant bit
+  }
+  return values;
+}
+
+// Returns true if value is one of Values().
+template <class Enum>
+Status EnumValid(const Enum value) {
+  if (static_cast<uint32_t>(value) >= 64) {
+    return JXL_FAILURE("Value %u too large for %s\n",
+                       static_cast<uint32_t>(value), EnumName(Enum()));
+  }
+  const uint64_t bit = MakeBit(value);
+  if ((EnumBits(Enum()) & bit) == 0) {
+    return JXL_FAILURE("Invalid value %u for %s\n",
+                       static_cast<uint32_t>(value), EnumName(Enum()));
+  }
+  return true;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_FIELD_ENCODINGS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/fields.cc b/third_party/jpeg-xl/lib/jxl/fields.cc
new file mode 100644
index 000000000000..361db39eaf5d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/fields.cc
@@ -0,0 +1,994 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/fields.h"
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <cmath>
+
+#include "hwy/base.h"
+#include "lib/jxl/base/bits.h"
+
+namespace jxl {
+
+namespace {
+
+// A bundle can be in one of three states concerning extensions: not-begun,
+// active, ended. Bundles may be nested, so we need a stack of states.
+class ExtensionStates {
+ public:
+  void Push() {
+    // Initial state = not-begun.
+    begun_ <<= 1;
+    ended_ <<= 1;
+  }
+
+  // Clears current state; caller must check IsEnded beforehand.
+  void Pop() {
+    begun_ >>= 1;
+    ended_ >>= 1;
+  }
+
+  // Returns true if state == active || state == ended.
+  Status IsBegun() const { return (begun_ & 1) != 0; }
+  // Returns true if state != not-begun && state != active.
+  Status IsEnded() const { return (ended_ & 1) != 0; }
+
+  void Begin() {
+    JXL_ASSERT(!IsBegun());
+    JXL_ASSERT(!IsEnded());
+    begun_ += 1;
+  }
+
+  void End() {
+    JXL_ASSERT(IsBegun());
+    JXL_ASSERT(!IsEnded());
+    ended_ += 1;
+  }
+
+ private:
+  // Current state := least-significant bit of begun_ and ended_.
+  uint64_t begun_ = 0;
+  uint64_t ended_ = 0;
+};
+
+// Visitors generate Init/AllDefault/Read/Write logic for all fields. Each
+// bundle's VisitFields member function calls visitor->U32 etc. We do not
+// overload operator() because a function name is easier to search for.
+
+class VisitorBase : public Visitor {
+ public:
+  explicit VisitorBase(bool print_bundles = false)
+      : print_bundles_(print_bundles) {}
+  ~VisitorBase() override { JXL_ASSERT(depth_ == 0); }
+
+  // This is the only call site of Fields::VisitFields. Adds tracing and
+  // ensures EndExtensions was called.
+  Status Visit(Fields* fields, const char* visitor_name) override {
+    fputs(visitor_name, stdout);  // No newline; no effect if empty
+    if (print_bundles_) {
+      Trace("%s\n", fields->Name());
+    }
+
+    depth_ += 1;
+    JXL_ASSERT(depth_ <= Bundle::kMaxExtensions);
+    extension_states_.Push();
+
+    const Status ok = fields->VisitFields(this);
+
+    if (ok) {
+      // If VisitFields called BeginExtensions, must also call
+      // EndExtensions.
+      JXL_ASSERT(!extension_states_.IsBegun() || extension_states_.IsEnded());
+    } else {
+      // Failed, undefined state: don't care whether EndExtensions was
+      // called.
+    }
+
+    extension_states_.Pop();
+    JXL_ASSERT(depth_ != 0);
+    depth_ -= 1;
+
+    return ok;
+  }
+
+  // For visitors accepting a const Visitor, need to const-cast so we can call
+  // the non-const Visitor::VisitFields. NOTE: the bundle is not modified
+  // except the `all_default` field by CanEncodeVisitor.
+  Status VisitConst(const Fields& t, const char* message) {
+    return Visit(const_cast<Fields*>(&t), message);
+  }
+
+  // Derived types (overridden by InitVisitor because it is unsafe to read
+  // from *value there)
+
+  Status Bool(bool default_value, bool* JXL_RESTRICT value) override {
+    uint32_t bits = *value ? 1 : 0;
+    JXL_RETURN_IF_ERROR(Bits(1, static_cast<uint32_t>(default_value), &bits));
+    JXL_DASSERT(bits <= 1);
+    *value = bits == 1;
+    return true;
+  }
+
+  // Overridden by ReadVisitor and WriteVisitor.
+  // Called before any conditional visit based on "extensions".
+  // Overridden by ReadVisitor, CanEncodeVisitor and WriteVisitor.
+  Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override {
+    JXL_RETURN_IF_ERROR(U64(0, extensions));
+
+    extension_states_.Begin();
+    return true;
+  }
+
+  // Called after all extension fields (if any). Although non-extension
+  // fields could be visited afterward, we prefer the convention that
+  // extension fields are always the last to be visited. Overridden by
+  // ReadVisitor.
+  Status EndExtensions() override {
+    extension_states_.End();
+    return true;
+  }
+
+ protected:
+  // Prints indentation, then the formatted message.
+  JXL_FORMAT(2, 3)  // 1-based plus one because member function
+  void Trace(const char* format, ...) const {
+    // Indentation.
+    printf("%*s", static_cast<int>(2 * depth_), "");
+
+    va_list args;
+    va_start(args, format);
+    vfprintf(stdout, format, args);
+    va_end(args);
+  }
+
+ private:
+  size_t depth_ = 0;  // for indentation.
+  ExtensionStates extension_states_;
+  const bool print_bundles_;
+};
+
+struct InitVisitor : public VisitorBase {
+  Status Bits(const size_t /*unused*/, const uint32_t default_value,
+              uint32_t* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status U32(const U32Enc /*unused*/, const uint32_t default_value,
+             uint32_t* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status U64(const uint64_t default_value,
+             uint64_t* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status Bool(bool default_value, bool* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status F16(const float default_value, float* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  // Always visit conditional fields to ensure they are initialized.
+  Status Conditional(bool /*condition*/) override { return true; }
+
+  Status AllDefault(const Fields& /*fields*/,
+                    bool* JXL_RESTRICT all_default) override {
+    // Just initialize this field and don't skip initializing others.
+    JXL_RETURN_IF_ERROR(Bool(true, all_default));
+    return false;
+  }
+
+  Status VisitNested(Fields* /*fields*/) override {
+    // Avoid re-initializing nested bundles (their ctors already called
+    // Bundle::Init for their fields).
+    return true;
+  }
+
+  const char* VisitorName() override { return "InitVisitor"; }
+};
+
+// Similar to InitVisitor, but also initializes nested fields.
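+// (Unlike InitVisitor, it does not override VisitNested, so visiting recurses
+// into nested bundles.)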
+struct SetDefaultVisitor : public VisitorBase {
+  Status Bits(const size_t /*unused*/, const uint32_t default_value,
+              uint32_t* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status U32(const U32Enc /*unused*/, const uint32_t default_value,
+             uint32_t* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status U64(const uint64_t default_value,
+             uint64_t* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status Bool(bool default_value, bool* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  Status F16(const float default_value, float* JXL_RESTRICT value) override {
+    *value = default_value;
+    return true;
+  }
+
+  // Always visit conditional fields to ensure they are initialized.
+  Status Conditional(bool /*condition*/) override { return true; }
+
+  Status AllDefault(const Fields& /*fields*/,
+                    bool* JXL_RESTRICT all_default) override {
+    // Just initialize this field and don't skip initializing others.
+    JXL_RETURN_IF_ERROR(Bool(true, all_default));
+    return false;
+  }
+
+  const char* VisitorName() override { return "SetDefaultVisitor"; }
+};
+
+class AllDefaultVisitor : public VisitorBase {
+ public:
+  explicit AllDefaultVisitor(bool print_all_default)
+      : VisitorBase(print_all_default),
+        print_all_default_(print_all_default) {}
+
+  Status Bits(const size_t bits, const uint32_t default_value,
+              uint32_t* JXL_RESTRICT value) override {
+    if (print_all_default_) {
+      Trace(" u(%zu) = %u, default %u\n", bits, *value, default_value);
+    }
+
+    all_default_ &= *value == default_value;
+    return true;
+  }
+
+  Status U32(const U32Enc /*unused*/, const uint32_t default_value,
+             uint32_t* JXL_RESTRICT value) override {
+    if (print_all_default_) {
+      Trace(" U32 = %u, default %u\n", *value, default_value);
+    }
+
+    all_default_ &= *value == default_value;
+    return true;
+  }
+
+  Status U64(const uint64_t default_value,
+             uint64_t* JXL_RESTRICT value) override {
+    if (print_all_default_) {
+      Trace(" U64 = %" PRIu64 ", default %" PRIu64 "\n", *value,
+            default_value);
+    }
+
+    all_default_ &= *value == default_value;
+    return true;
+  }
+
+  Status F16(const float default_value, float* JXL_RESTRICT value) override {
+    if (print_all_default_) {
+      Trace(" F16 = %.6f, default %.6f\n", static_cast<double>(*value),
+            static_cast<double>(default_value));
+    }
+    all_default_ &= std::abs(*value - default_value) < 1E-6f;
+    return true;
+  }
+
+  Status AllDefault(const Fields& /*fields*/,
+                    bool* JXL_RESTRICT /*all_default*/) override {
+    // Visit all fields so we can compute the actual all_default_ value.
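+    // Returning false tells the caller not to skip the remaining fields.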
+    return false;
+  }
+
+  bool AllDefault() const { return all_default_; }
+
+  const char* VisitorName() override { return "AllDefaultVisitor"; }
+
+ private:
+  const bool print_all_default_;
+  bool all_default_ = true;
+};
+
+class ReadVisitor : public VisitorBase {
+ public:
+  ReadVisitor(BitReader* reader, bool print_read)
+      : VisitorBase(print_read), print_read_(print_read), reader_(reader) {}
+
+  Status Bits(const size_t bits, const uint32_t /*default_value*/,
+              uint32_t* JXL_RESTRICT value) override {
+    *value = BitsCoder::Read(bits, reader_);
+    if (!reader_->AllReadsWithinBounds()) {
+      return JXL_STATUS(StatusCode::kNotEnoughBytes,
+                        "Not enough bytes for header");
+    }
+    if (print_read_) Trace(" u(%zu) = %u\n", bits, *value);
+    return true;
+  }
+
+  Status U32(const U32Enc dist, const uint32_t /*default_value*/,
+             uint32_t* JXL_RESTRICT value) override {
+    *value = U32Coder::Read(dist, reader_);
+    if (!reader_->AllReadsWithinBounds()) {
+      return JXL_STATUS(StatusCode::kNotEnoughBytes,
+                        "Not enough bytes for header");
+    }
+    if (print_read_) Trace(" U32 = %u\n", *value);
+    return true;
+  }
+
+  Status U64(const uint64_t /*default_value*/,
+             uint64_t* JXL_RESTRICT value) override {
+    *value = U64Coder::Read(reader_);
+    if (!reader_->AllReadsWithinBounds()) {
+      return JXL_STATUS(StatusCode::kNotEnoughBytes,
+                        "Not enough bytes for header");
+    }
+    if (print_read_) Trace(" U64 = %" PRIu64 "\n", *value);
+    return true;
+  }
+
+  Status F16(const float /*default_value*/,
+             float* JXL_RESTRICT value) override {
+    ok_ &= F16Coder::Read(reader_, value);
+    if (!reader_->AllReadsWithinBounds()) {
+      return JXL_STATUS(StatusCode::kNotEnoughBytes,
+                        "Not enough bytes for header");
+    }
+    if (print_read_) Trace(" F16 = %f\n", static_cast<double>(*value));
+    return true;
+  }
+
+  void SetDefault(Fields* fields) override { Bundle::SetDefault(fields); }
+
+  bool IsReading() const override { return true; }
+
+  // This never fails because visitors are expected to keep reading until
+  // EndExtensions, see comment there.
+  Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override {
+    JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions));
+    if (*extensions == 0) return true;
+
+    // For each nonzero bit, i.e. extension that is present:
+    for (uint64_t remaining_extensions = *extensions;
+         remaining_extensions != 0;
+         remaining_extensions &= remaining_extensions - 1) {
+      const size_t idx_extension =
+          Num0BitsBelowLS1Bit_Nonzero(remaining_extensions);
+      // Read additional U64 (one per extension) indicating the number of bits
+      // (allows skipping individual extensions).
+      JXL_RETURN_IF_ERROR(U64(0, &extension_bits_[idx_extension]));
+      if (!SafeAdd(total_extension_bits_, extension_bits_[idx_extension],
+                   total_extension_bits_)) {
+        return JXL_FAILURE("Extension bits overflowed, invalid codestream");
+      }
+    }
+    // Used by EndExtensions to skip past any _remaining_ extensions.
+    pos_after_ext_size_ = reader_->TotalBitsConsumed();
+    JXL_ASSERT(pos_after_ext_size_ != 0);
+    return true;
+  }
+
+  Status EndExtensions() override {
+    JXL_QUIET_RETURN_IF_ERROR(VisitorBase::EndExtensions());
+    // Happens if extensions == 0: don't read size, done.
+    if (pos_after_ext_size_ == 0) return true;
+
+    // Not enough bytes as set by BeginExtensions or earlier. Do not return
+    // this as a JXL_FAILURE or false (which can also propagate to error
+    // through e.g. JXL_RETURN_IF_ERROR), since this may be used while
+    // silently checking whether there are enough bytes. If this case must be
+    // treated as an error, reader_->Close() will do this, just as is already
+    // done for non-extension fields.
+    if (!enough_bytes_) return true;
+
+    // Skip new fields this (old?) decoder didn't know about, if any.
+    const size_t bits_read = reader_->TotalBitsConsumed();
+    uint64_t end;
+    if (!SafeAdd(pos_after_ext_size_, total_extension_bits_, end)) {
+      return JXL_FAILURE("Invalid extension size, caused overflow");
+    }
+    if (bits_read > end) {
+      return JXL_FAILURE("Read more extension bits than budgeted");
+    }
+    const size_t remaining_bits = end - bits_read;
+    if (remaining_bits != 0) {
+      JXL_WARNING("Skipping %zu-bit extension(s)", remaining_bits);
+      reader_->SkipBits(remaining_bits);
+      if (!reader_->AllReadsWithinBounds()) {
+        return JXL_STATUS(StatusCode::kNotEnoughBytes,
+                          "Not enough bytes for header");
+      }
+    }
+    return true;
+  }
+
+  Status OK() const { return ok_; }
+
+  const char* VisitorName() override { return "ReadVisitor"; }
+
+ private:
+  const bool print_read_;
+
+  // Whether any error other than not enough bytes occurred.
+  bool ok_ = true;
+
+  // Whether there are enough input bytes to read from.
+  bool enough_bytes_ = true;
+  BitReader* const reader_;
+  // May be 0 even if the corresponding extension is present.
+  uint64_t extension_bits_[Bundle::kMaxExtensions] = {0};
+  uint64_t total_extension_bits_ = 0;
+  size_t pos_after_ext_size_ = 0;  // 0 iff extensions == 0.
+};
+
+class MaxBitsVisitor : public VisitorBase {
+ public:
+  Status Bits(const size_t bits, const uint32_t /*default_value*/,
+              uint32_t* JXL_RESTRICT /*value*/) override {
+    max_bits_ += BitsCoder::MaxEncodedBits(bits);
+    return true;
+  }
+
+  Status U32(const U32Enc enc, const uint32_t /*default_value*/,
+             uint32_t* JXL_RESTRICT /*value*/) override {
+    max_bits_ += U32Coder::MaxEncodedBits(enc);
+    return true;
+  }
+
+  Status U64(const uint64_t /*default_value*/,
+             uint64_t* JXL_RESTRICT /*value*/) override {
+    max_bits_ += U64Coder::MaxEncodedBits();
+    return true;
+  }
+
+  Status F16(const float /*default_value*/,
+             float* JXL_RESTRICT /*value*/) override {
+    max_bits_ += F16Coder::MaxEncodedBits();
+    return true;
+  }
+
+  Status AllDefault(const Fields& /*fields*/,
+                    bool* JXL_RESTRICT all_default) override {
+    JXL_RETURN_IF_ERROR(Bool(true, all_default));
+    return false;  // For max bits, assume nothing is default
+  }
+
+  // Always visit conditional fields to get a (loose) upper bound.
+  Status Conditional(bool /*condition*/) override { return true; }
+
+  Status BeginExtensions(uint64_t* JXL_RESTRICT /*extensions*/) override {
+    // Skip - extensions are not included in "MaxBits" because their length
+    // is potentially unbounded.
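+    // MaxBits() is therefore only an upper bound for the non-extension
+    // fields.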
+    return true;
+  }
+
+  Status EndExtensions() override { return true; }
+
+  size_t MaxBits() const { return max_bits_; }
+
+  const char* VisitorName() override { return "MaxBitsVisitor"; }
+
+ private:
+  size_t max_bits_ = 0;
+};
+
+class CanEncodeVisitor : public VisitorBase {
+ public:
+  explicit CanEncodeVisitor(bool print_sizes)
+      : VisitorBase(print_sizes), print_sizes_(print_sizes) {}
+
+  Status Bits(const size_t bits, const uint32_t /*default_value*/,
+              uint32_t* JXL_RESTRICT value) override {
+    size_t encoded_bits = 0;
+    ok_ &= BitsCoder::CanEncode(bits, *value, &encoded_bits);
+    if (print_sizes_) Trace("u(%zu) = %u\n", bits, *value);
+    encoded_bits_ += encoded_bits;
+    return true;
+  }
+
+  Status U32(const U32Enc enc, const uint32_t /*default_value*/,
+             uint32_t* JXL_RESTRICT value) override {
+    size_t encoded_bits = 0;
+    ok_ &= U32Coder::CanEncode(enc, *value, &encoded_bits);
+    if (print_sizes_) Trace("U32(%zu) = %u\n", encoded_bits, *value);
+    encoded_bits_ += encoded_bits;
+    return true;
+  }
+
+  Status U64(const uint64_t /*default_value*/,
+             uint64_t* JXL_RESTRICT value) override {
+    size_t encoded_bits = 0;
+    ok_ &= U64Coder::CanEncode(*value, &encoded_bits);
+    if (print_sizes_) {
+      Trace("U64(%zu) = %" PRIu64 "\n", encoded_bits, *value);
+    }
+    encoded_bits_ += encoded_bits;
+    return true;
+  }
+
+  Status F16(const float /*default_value*/,
+             float* JXL_RESTRICT value) override {
+    size_t encoded_bits = 0;
+    ok_ &= F16Coder::CanEncode(*value, &encoded_bits);
+    if (print_sizes_) {
+      Trace("F16(%zu) = %.6f\n", encoded_bits, static_cast<double>(*value));
+    }
+    encoded_bits_ += encoded_bits;
+    return true;
+  }
+
+  Status AllDefault(const Fields& fields,
+                    bool* JXL_RESTRICT all_default) override {
+    *all_default = Bundle::AllDefault(fields);
+    JXL_RETURN_IF_ERROR(Bool(true, all_default));
+    return *all_default;
+  }
+
+  Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override {
+    JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions));
+    extensions_ = *extensions;
+    if (*extensions != 0) {
+      JXL_ASSERT(pos_after_ext_ == 0);
+      pos_after_ext_ = encoded_bits_;
+      JXL_ASSERT(pos_after_ext_ != 0);  // visited "extensions"
+    }
+    return true;
+  }
+  // EndExtensions = default.
+
+  Status GetSizes(size_t* JXL_RESTRICT extension_bits,
+                  size_t* JXL_RESTRICT total_bits) {
+    JXL_RETURN_IF_ERROR(ok_);
+    *extension_bits = 0;
+    *total_bits = encoded_bits_;
+    // Only if extension field was nonzero will we encode their sizes.
+    if (pos_after_ext_ != 0) {
+      JXL_ASSERT(encoded_bits_ >= pos_after_ext_);
+      *extension_bits = encoded_bits_ - pos_after_ext_;
+      // Also need to encode *extension_bits and bill it to *total_bits.
+      size_t encoded_bits = 0;
+      ok_ &= U64Coder::CanEncode(*extension_bits, &encoded_bits);
+      *total_bits += encoded_bits;
+
+      // TODO(janwas): support encoding individual extension sizes. We
+      // currently ascribe all bits to the first and send zeros for the
+      // others.
+      for (size_t i = 1; i < hwy::PopCount(extensions_); ++i) {
+        encoded_bits = 0;
+        ok_ &= U64Coder::CanEncode(0, &encoded_bits);
+        *total_bits += encoded_bits;
+      }
+    }
+    return true;
+  }
+
+  const char* VisitorName() override { return "CanEncodeVisitor"; }
+
+ private:
+  const bool print_sizes_;
+  bool ok_ = true;
+  size_t encoded_bits_ = 0;
+  uint64_t extensions_ = 0;
+  // Snapshot of encoded_bits_ after visiting the extension field, but NOT
+  // including the hidden extension sizes.
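+  // (GetSizes later adds the U64-coded size fields themselves to
+  // *total_bits.)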
+ uint64_t pos_after_ext_ = 0; +}; + +class WriteVisitor : public VisitorBase { + public: + WriteVisitor(const size_t extension_bits, BitWriter* JXL_RESTRICT writer) + : extension_bits_(extension_bits), writer_(writer) {} + + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + ok_ &= BitsCoder::Write(bits, *value, writer_); + return true; + } + Status U32(const U32Enc enc, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + ok_ &= U32Coder::Write(enc, *value, writer_); + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT value) override { + ok_ &= U64Coder::Write(*value, writer_); + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT value) override { + ok_ &= F16Coder::Write(*value, writer_); + return true; + } + + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); + if (*extensions == 0) { + JXL_ASSERT(extension_bits_ == 0); + return true; + } + // TODO(janwas): extend API to pass in array of extension_bits, one per + // extension. We currently ascribe all bits to the first extension, but + // this is only an encoder limitation. NOTE: extension_bits_ can be zero + // if an extension does not require any additional fields. + ok_ &= U64Coder::Write(extension_bits_, writer_); + // For each nonzero bit except the lowest/first (already written): + for (uint64_t remaining_extensions = *extensions & (*extensions - 1); + remaining_extensions != 0; + remaining_extensions &= remaining_extensions - 1) { + ok_ &= U64Coder::Write(0, writer_); + } + return true; + } + // EndExtensions = default. + + Status OK() const { return ok_; } + + const char* VisitorName() override { return "WriteVisitor"; } + + private: + const size_t extension_bits_; + BitWriter* JXL_RESTRICT writer_; + bool ok_ = true; +}; + +} // namespace + +void Bundle::Init(Fields* fields) { + InitVisitor visitor; + if (!visitor.Visit(fields, PrintVisitors() ? "-- Init\n" : "")) { + JXL_ABORT("Init should never fail"); + } +} +void Bundle::SetDefault(Fields* fields) { + SetDefaultVisitor visitor; + if (!visitor.Visit(fields, PrintVisitors() ? "-- SetDefault\n" : "")) { + JXL_ABORT("SetDefault should never fail"); + } +} +bool Bundle::AllDefault(const Fields& fields) { + AllDefaultVisitor visitor(/*print_all_default=*/PrintAllDefault()); + const char* name = + (PrintVisitors() || PrintAllDefault()) ? "[[AllDefault\n" : ""; + if (!visitor.VisitConst(fields, name)) { + JXL_ABORT("AllDefault should never fail"); + } + + if (PrintAllDefault()) printf(" %d]]\n", visitor.AllDefault()); + return visitor.AllDefault(); +} +size_t Bundle::MaxBits(const Fields& fields) { + MaxBitsVisitor visitor; +#if JXL_ENABLE_ASSERT + Status ret = +#else + (void) +#endif // JXL_ENABLE_ASSERT + visitor.VisitConst(fields, PrintVisitors() ? "-- MaxBits\n" : ""); + JXL_ASSERT(ret); + return visitor.MaxBits(); +} +Status Bundle::CanEncode(const Fields& fields, size_t* extension_bits, + size_t* total_bits) { + CanEncodeVisitor visitor(/*print_sizes=*/PrintSizes()); + const char* name = (PrintVisitors() || PrintSizes()) ? 
"[[CanEncode\n" : ""; + JXL_QUIET_RETURN_IF_ERROR(visitor.VisitConst(fields, name)); + JXL_QUIET_RETURN_IF_ERROR(visitor.GetSizes(extension_bits, total_bits)); + if (PrintSizes()) printf(" %zu]]\n", *total_bits); + return true; +} +Status Bundle::Read(BitReader* reader, Fields* fields) { + ReadVisitor visitor(reader, /*print_read=*/PrintRead()); + JXL_RETURN_IF_ERROR( + visitor.Visit(fields, PrintVisitors() ? "-- Read\n" : "")); + return visitor.OK(); +} +bool Bundle::CanRead(BitReader* reader, Fields* fields) { + ReadVisitor visitor(reader, /*print_read=*/PrintRead()); + Status status = visitor.Visit(fields, PrintVisitors() ? "-- Read\n" : ""); + // We are only checking here whether there are enough bytes. We still return + // true for other errors because it means there are enough bytes to determine + // there's an error. Use Read() to determine which error it is. + return status.code() != StatusCode::kNotEnoughBytes; +} +Status Bundle::Write(const Fields& fields, BitWriter* writer, size_t layer, + AuxOut* aux_out) { + size_t extension_bits, total_bits; + JXL_RETURN_IF_ERROR(CanEncode(fields, &extension_bits, &total_bits)); + + BitWriter::Allotment allotment(writer, total_bits); + WriteVisitor visitor(extension_bits, writer); + JXL_RETURN_IF_ERROR( + visitor.VisitConst(fields, PrintVisitors() ? "-- Write\n" : "")); + JXL_RETURN_IF_ERROR(visitor.OK()); + ReclaimAndCharge(writer, &allotment, layer, aux_out); + return true; +} + +size_t U32Coder::MaxEncodedBits(const U32Enc enc) { + size_t extra_bits = 0; + for (uint32_t selector = 0; selector < 4; ++selector) { + const U32Distr d = enc.GetDistr(selector); + if (d.IsDirect()) { + continue; + } else { + extra_bits = std::max(extra_bits, d.ExtraBits()); + } + } + return 2 + extra_bits; +} + +Status U32Coder::CanEncode(const U32Enc enc, const uint32_t value, + size_t* JXL_RESTRICT encoded_bits) { + uint32_t selector; + size_t total_bits; + const Status ok = ChooseSelector(enc, value, &selector, &total_bits); + *encoded_bits = ok ? total_bits : 0; + return ok; +} + +uint32_t U32Coder::Read(const U32Enc enc, BitReader* JXL_RESTRICT reader) { + const uint32_t selector = reader->ReadFixedBits<2>(); + const U32Distr d = enc.GetDistr(selector); + if (d.IsDirect()) { + return d.Direct(); + } else { + return reader->ReadBits(d.ExtraBits()) + d.Offset(); + } +} + +// Returns false if the value is too large to encode. +Status U32Coder::Write(const U32Enc enc, const uint32_t value, + BitWriter* JXL_RESTRICT writer) { + uint32_t selector; + size_t total_bits; + JXL_RETURN_IF_ERROR(ChooseSelector(enc, value, &selector, &total_bits)); + + writer->Write(2, selector); + + const U32Distr d = enc.GetDistr(selector); + if (!d.IsDirect()) { // Nothing more to write for direct encoding + const uint32_t offset = d.Offset(); + JXL_ASSERT(value >= offset); + writer->Write(total_bits - 2, value - offset); + } + + return true; +} + +Status U32Coder::ChooseSelector(const U32Enc enc, const uint32_t value, + uint32_t* JXL_RESTRICT selector, + size_t* JXL_RESTRICT total_bits) { +#if JXL_ENABLE_ASSERT + const size_t bits_required = 32 - Num0BitsAboveMS1Bit(value); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(bits_required <= 32); + + *selector = 0; + *total_bits = 0; + + // It is difficult to verify whether Dist32Byte are sorted, so check all + // selectors and keep the one with the fewest total_bits. 
+  *total_bits = 64;  // more than any valid encoding
+  for (uint32_t s = 0; s < 4; ++s) {
+    const U32Distr d = enc.GetDistr(s);
+    if (d.IsDirect()) {
+      if (d.Direct() == value) {
+        *selector = s;
+        *total_bits = 2;
+        return true;  // Done, direct is always the best possible.
+      }
+      continue;
+    }
+    const size_t extra_bits = d.ExtraBits();
+    const uint32_t offset = d.Offset();
+    if (value < offset || value >= offset + (1ULL << extra_bits)) continue;
+
+    // Better than prior encoding, remember it:
+    if (2 + extra_bits < *total_bits) {
+      *selector = s;
+      *total_bits = 2 + extra_bits;
+    }
+  }
+
+  if (*total_bits == 64) {
+    return JXL_FAILURE("No feasible selector for %u", value);
+  }
+
+  return true;
+}
+
+uint64_t U64Coder::Read(BitReader* JXL_RESTRICT reader) {
+  uint64_t selector = reader->ReadFixedBits<2>();
+  if (selector == 0) {
+    return 0;
+  }
+  if (selector == 1) {
+    return 1 + reader->ReadFixedBits<4>();
+  }
+  if (selector == 2) {
+    return 17 + reader->ReadFixedBits<8>();
+  }
+
+  // selector 3, varint, groups have first 12, then 8, and last 4 bits.
+  uint64_t result = reader->ReadFixedBits<12>();
+
+  uint64_t shift = 12;
+  while (reader->ReadFixedBits<1>()) {
+    if (shift == 60) {
+      result |= static_cast<uint64_t>(reader->ReadFixedBits<4>()) << shift;
+      break;
+    }
+    result |= static_cast<uint64_t>(reader->ReadFixedBits<8>()) << shift;
+    shift += 8;
+  }
+
+  return result;
+}
+
+// Returns false if the value is too large to encode.
+Status U64Coder::Write(uint64_t value, BitWriter* JXL_RESTRICT writer) {
+  if (value == 0) {
+    // Selector: use 0 bits, value 0
+    writer->Write(2, 0);
+  } else if (value <= 16) {
+    // Selector: use 4 bits, value 1..16
+    writer->Write(2, 1);
+    writer->Write(4, value - 1);
+  } else if (value <= 272) {
+    // Selector: use 8 bits, value 17..272
+    writer->Write(2, 2);
+    writer->Write(8, value - 17);
+  } else {
+    // Selector: varint, first a 12-bit group, after that per 8-bit group.
+    writer->Write(2, 3);
+    writer->Write(12, value & 4095);
+    value >>= 12;
+    int shift = 12;
+    while (value > 0 && shift < 60) {
+      // Indicate varint not done
+      writer->Write(1, 1);
+      writer->Write(8, value & 255);
+      value >>= 8;
+      shift += 8;
+    }
+    if (value > 0) {
+      // This could only happen if shift == 60.
+      writer->Write(1, 1);
+      writer->Write(4, value & 15);
+      // Implicitly closed sequence, no extra stop bit is required.
+    } else {
+      // Indicate end of varint
+      writer->Write(1, 0);
+    }
+  }
+
+  return true;
+}
+
+// Can always encode, but useful because it also returns bit size.
+Status U64Coder::CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits) {
+  if (value == 0) {
+    *encoded_bits = 2;  // 2 selector bits
+  } else if (value <= 16) {
+    *encoded_bits = 2 + 4;  // 2 selector bits + 4 payload bits
+  } else if (value <= 272) {
+    *encoded_bits = 2 + 8;  // 2 selector bits + 8 payload bits
+  } else {
+    *encoded_bits = 2 + 12;  // 2 selector bits + 12 payload bits
+    value >>= 12;
+    int shift = 12;
+    while (value > 0 && shift < 60) {
+      *encoded_bits += 1 + 8;  // 1 continuation bit + 8 payload bits
+      value >>= 8;
+      shift += 8;
+    }
+    if (value > 0) {
+      // This could only happen if shift == 60.
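+      // At shift == 60, Read() consumes the final 4 bits unconditionally
+      // after the continuation bit, so there is no stop bit to count here.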
+      *encoded_bits += 1 + 4;  // 1 continuation bit + 4 payload bits
+    } else {
+      *encoded_bits += 1;  // 1 stop bit
+    }
+  }
+
+  return true;
+}
+
+Status F16Coder::Read(BitReader* JXL_RESTRICT reader,
+                      float* JXL_RESTRICT value) {
+  const uint32_t bits16 = reader->ReadFixedBits<16>();
+  const uint32_t sign = bits16 >> 15;
+  const uint32_t biased_exp = (bits16 >> 10) & 0x1F;
+  const uint32_t mantissa = bits16 & 0x3FF;
+
+  if (JXL_UNLIKELY(biased_exp == 31)) {
+    return JXL_FAILURE("F16 infinity or NaN are not supported");
+  }
+
+  // Subnormal or zero
+  if (JXL_UNLIKELY(biased_exp == 0)) {
+    *value = (1.0f / 16384) * (mantissa * (1.0f / 1024));
+    if (sign) *value = -*value;
+    return true;
+  }
+
+  // Normalized: convert the representation directly (faster than
+  // ldexp/tables).
+  const uint32_t biased_exp32 = biased_exp + (127 - 15);
+  const uint32_t mantissa32 = mantissa << (23 - 10);
+  const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32;
+  memcpy(value, &bits32, sizeof(bits32));
+  return true;
+}
+
+Status F16Coder::Write(float value, BitWriter* JXL_RESTRICT writer) {
+  uint32_t bits32;
+  memcpy(&bits32, &value, sizeof(bits32));
+  const uint32_t sign = bits32 >> 31;
+  const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF;
+  const uint32_t mantissa32 = bits32 & 0x7FFFFF;
+
+  const int32_t exp = static_cast<int32_t>(biased_exp32) - 127;
+  if (JXL_UNLIKELY(exp > 15)) {
+    return JXL_FAILURE("Too big to encode, CanEncode should return false");
+  }
+
+  // Tiny or zero => zero.
+  if (exp < -24) {
+    writer->Write(16, 0);
+    return true;
+  }
+
+  uint32_t biased_exp16, mantissa16;
+
+  // exp = [-24, -15] => subnormal
+  if (JXL_UNLIKELY(exp < -14)) {
+    biased_exp16 = 0;
+    const uint32_t sub_exp = static_cast<uint32_t>(-14 - exp);
+    JXL_ASSERT(1 <= sub_exp && sub_exp < 11);
+    mantissa16 = (1 << (10 - sub_exp)) + (mantissa32 >> (13 + sub_exp));
+  } else {
+    // exp = [-14, 15]
+    biased_exp16 = static_cast<uint32_t>(exp + 15);
+    JXL_ASSERT(1 <= biased_exp16 && biased_exp16 < 31);
+    mantissa16 = mantissa32 >> 13;
+  }
+
+  JXL_ASSERT(mantissa16 < 1024);
+  const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16;
+  JXL_ASSERT(bits16 < 0x10000);
+  writer->Write(16, bits16);
+  return true;
+}
+
+Status F16Coder::CanEncode(float value, size_t* JXL_RESTRICT encoded_bits) {
+  *encoded_bits = MaxEncodedBits();
+  if (std::isnan(value) || std::isinf(value)) {
+    return JXL_FAILURE("Should not attempt to store NaN and infinity");
+  }
+  return std::abs(value) <= 65504.0f;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/fields.h b/third_party/jpeg-xl/lib/jxl/fields.h
new file mode 100644
index 000000000000..80e4671afb61
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/fields.h
@@ -0,0 +1,309 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_FIELDS_H_
+#define LIB_JXL_FIELDS_H_
+
+// Forward/backward-compatible 'bundles' with auto-serialized 'fields'.
+
+#include <inttypes.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <cinttypes>
+#include <cstdlib>  // abs
+#include <vector>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/field_encodings.h"
+
+namespace jxl {
+
+// Integer coders: BitsCoder (raw), U32Coder (table), U64Coder (varint).
+
+// Reads/writes a given (fixed) number of bits <= 32.
+class BitsCoder {
+ public:
+  static size_t MaxEncodedBits(const size_t bits) { return bits; }
+
+  static Status CanEncode(const size_t bits, const uint32_t value,
+                          size_t* JXL_RESTRICT encoded_bits) {
+    *encoded_bits = bits;
+    if (value >= (1ULL << bits)) {
+      return JXL_FAILURE("Value %u too large for %zu bits", value, bits);
+    }
+    return true;
+  }
+
+  static uint32_t Read(const size_t bits, BitReader* JXL_RESTRICT reader) {
+    return reader->ReadBits(bits);
+  }
+
+  // Returns false if the value is too large to encode.
+  static Status Write(const size_t bits, const uint32_t value,
+                      BitWriter* JXL_RESTRICT writer) {
+    if (value >= (1ULL << bits)) {
+      return JXL_FAILURE("Value %d too large to encode in %zu bits", value,
+                         bits);
+    }
+    writer->Write(bits, value);
+    return true;
+  }
+};
+
+// Encodes u32 using a lookup table and/or extra bits, governed by a per-field
+// encoding `enc` which consists of four distributions `d` chosen via a 2-bit
+// selector (least significant = 0). Each d may have two modes:
+// - direct: if d.IsDirect(), the value is d.Direct();
+// - offset: the value is derived from d.ExtraBits() extra bits plus
+//   d.Offset();
+// This encoding is denser than Exp-Golomb or Gamma codes when both small and
+// large values occur.
+//
+// Examples:
+// Direct: U32Enc(Val(8), Val(16), Val(32), Bits(6)), value 32 => 10b.
+// Offset: U32Enc(Val(0), BitsOffset(1, 1), BitsOffset(2, 3), BitsOffset(8, 8))
+// defines the following prefix code:
+// 00 -> 0
+// 01x -> 1..2
+// 10xx -> 3..6
+// 11xxxxxxxx -> 8..263
+class U32Coder {
+ public:
+  static size_t MaxEncodedBits(U32Enc enc);
+  static Status CanEncode(U32Enc enc, uint32_t value,
+                          size_t* JXL_RESTRICT encoded_bits);
+  static uint32_t Read(U32Enc enc, BitReader* JXL_RESTRICT reader);
+
+  // Returns false if the value is too large to encode.
+  static Status Write(U32Enc enc, uint32_t value,
+                      BitWriter* JXL_RESTRICT writer);
+
+ private:
+  static Status ChooseSelector(U32Enc enc, uint32_t value,
+                               uint32_t* JXL_RESTRICT selector,
+                               size_t* JXL_RESTRICT total_bits);
+};
+
+// Encodes 64-bit unsigned integers with a fixed distribution, taking 2 bits
+// to encode 0, 6 bits to encode 1 to 16, 10 bits to encode 17 to 272, 15 bits
+// to encode up to 4095, and on the order of log2(value) * 1.125 bits for
+// larger values.
+class U64Coder {
+ public:
+  static constexpr size_t MaxEncodedBits() {
+    return 2 + 12 + 6 * (8 + 1) + (4 + 1);
+  }
+
+  static uint64_t Read(BitReader* JXL_RESTRICT reader);
+
+  // Returns false if the value is too large to encode.
+  static Status Write(uint64_t value, BitWriter* JXL_RESTRICT writer);
+
+  // Can always encode, but useful because it also returns bit size.
+  static Status CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits);
+};
+
+// IEEE 754 half-precision (binary16). Refuses to read/write NaN/Inf.
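+// Layout: 1 sign, 5 exponent (bias 15) and 10 mantissa bits; the largest
+// finite value is (2 - 2^-10) * 2^15 = 65504, which is the bound enforced by
+// CanEncode.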
+class F16Coder {
+ public:
+  static constexpr size_t MaxEncodedBits() { return 16; }
+
+  // Returns false if the bit representation is NaN or infinity
+  static Status Read(BitReader* JXL_RESTRICT reader,
+                     float* JXL_RESTRICT value);
+
+  // Returns false if the value is too large to encode.
+  static Status Write(float value, BitWriter* JXL_RESTRICT writer);
+  static Status CanEncode(float value, size_t* JXL_RESTRICT encoded_bits);
+};
+
+// A "bundle" is a forward- and backward-compatible collection of fields.
+// They are used for SizeHeader/FrameHeader/GroupHeader. Bundles can be
+// extended by appending(!) fields. Optional fields may be omitted from the
+// bitstream by conditionally visiting them. When reading new bitstreams with
+// old code, we skip unknown fields at the end of the bundle. This requires
+// storing the number of extra appended bits, and that fields are visited in
+// chronological order of being added to the format, because old decoders
+// cannot skip some future fields and resume reading old fields. Similarly,
+// new readers query bits in an "extensions" field to skip (groups of) fields
+// not present in old bitstreams. Note that each bundle must include an
+// "extensions" field prior to freezing the format, otherwise it cannot be
+// extended.
+//
+// To ensure interoperability, there will be no opaque fields.
+//
+// HOWTO:
+// - basic usage: define a struct with member variables ("fields") and a
+//   VisitFields(v) member function that calls v->U32/Bool etc. for each
+//   field, specifying their default values. The ctor must call
+//   Bundle::Init(this).
+//
+// - print a trace of visitors: ensure each bundle has a static Name() member
+//   function, and change Bundle::Print* to return true.
+//
+// - optional fields: in VisitFields, add if (v->Conditional(your_condition))
+//   { v->Bool(default, &field); }. This prevents reading/writing field
+//   if !your_condition, which is typically computed from a prior field.
+//   WARNING: to ensure all fields are initialized, do not add an else branch;
+//   instead add another if (v->Conditional(!your_condition)).
+//
+// - repeated fields: for dynamic sizes, use e.g. std::vector and in
+//   VisitFields, if (v->IsReading()) field.resize(size) before accessing
+//   field. For static or bounded sizes, use an array or std::array. In all
+//   cases, simply visit each array element as if it were a normal field.
+//
+// - nested bundles: add a bundle as a normal field and in VisitFields call
+//   JXL_RETURN_IF_ERROR(v->VisitNested(&nested));
+//
+// - allow future extensions: define a "uint64_t extensions" field and call
+//   v->BeginExtensions(&extensions) after visiting all non-extension fields,
+//   and `return v->EndExtensions();` after the last extension field.
+//
+// - encode an entire bundle in one bit if ALL its fields equal their default
+//   values: add a "mutable bool all_default" field and as the first visitor:
+//   if (v->AllDefault(*this, &all_default)) {
+//     // Overwrite all serialized fields, but not any nonserialized_*.
+//     v->SetDefault(this);
+//     return true;
+//   }
+//   Note: if extensions are present, AllDefault() == false.
+
+class Bundle {
+ public:
+  static constexpr size_t kMaxExtensions = 64;  // bits in u64
+
+  // Print the type of each visitor called.
+  static constexpr bool PrintVisitors() { return false; }
+  // Print default value for each field and AllDefault result.
+  static constexpr bool PrintAllDefault() { return false; }
+  // Print values decoded for each field in Read.
+  static constexpr bool PrintRead() { return false; }
+  // Print size for each field and CanEncode total_bits.
+  static constexpr bool PrintSizes() { return false; }
+
+  // Initializes fields to the default values. It is not recursive to nested
+  // fields; this function is intended to be called in the constructors, so
+  // each nested field will already have initialized itself.
+  static void Init(Fields* JXL_RESTRICT fields);
+
+  // Similar to Init, but recursive to nested fields.
+  static void SetDefault(Fields* JXL_RESTRICT fields);
+
+  // Returns whether ALL fields (including `extensions`, if present) are equal
+  // to their default value.
+  static bool AllDefault(const Fields& fields);
+
+  // Returns max number of bits required to encode a T.
+  static size_t MaxBits(const Fields& fields);
+
+  // Returns whether a header's fields can all be encoded, i.e. they have a
+  // valid representation. If so, "*total_bits" is the exact number of bits
+  // required. Called by Write.
+  static Status CanEncode(const Fields& fields,
+                          size_t* JXL_RESTRICT extension_bits,
+                          size_t* JXL_RESTRICT total_bits);
+
+  static Status Read(BitReader* reader, Fields* JXL_RESTRICT fields);
+
+  // Returns whether enough bits are available to fully read this bundle using
+  // Read. Also returns true in case of a codestream error (other than not
+  // being large enough): that means enough bits are available to determine
+  // there's an error; use Read to get such error status.
+  // NOTE: this advances the BitReader; a different one pointing back at the
+  // original bit position in the codestream must be created to use Read after
+  // this.
+  static bool CanRead(BitReader* reader, Fields* JXL_RESTRICT fields);
+
+  static Status Write(const Fields& fields, BitWriter* JXL_RESTRICT writer,
+                      size_t layer, AuxOut* aux_out);
+
+ private:
+};
+
+// Different subclasses of Visitor are passed to implementations of Fields
+// throughout their lifetime. Templates used to be used for this but dynamic
+// polymorphism produces more compact executables than template reification
+// did.
+class Visitor {
+ public:
+  virtual ~Visitor() = default;
+  virtual Status Visit(Fields* fields, const char* visitor_name) = 0;
+
+  virtual Status Bool(bool default_value, bool* JXL_RESTRICT value) = 0;
+  virtual Status U32(U32Enc, uint32_t, uint32_t*) = 0;
+
+  // Helper to construct U32Enc from U32Distr.
+  Status U32(const U32Distr d0, const U32Distr d1, const U32Distr d2,
+             const U32Distr d3, const uint32_t default_value,
+             uint32_t* JXL_RESTRICT value) {
+    return U32(U32Enc(d0, d1, d2, d3), default_value, value);
+  }
+
+  template <class EnumT>
+  Status Enum(const EnumT default_value, EnumT* JXL_RESTRICT value) {
+    uint32_t u32 = static_cast<uint32_t>(*value);
+    // 00 -> 0
+    // 01 -> 1
+    // 10xxxx -> 2..17
+    // 11yyyyyy -> 18..81
+    JXL_RETURN_IF_ERROR(U32(Val(0), Val(1), BitsOffset(4, 2),
+                            BitsOffset(6, 18),
+                            static_cast<uint32_t>(default_value), &u32));
+    *value = static_cast<EnumT>(u32);
+    return EnumValid(*value);
+  }
+
+  virtual Status Bits(size_t bits, uint32_t default_value,
+                      uint32_t* JXL_RESTRICT value) = 0;
+  virtual Status U64(uint64_t default_value, uint64_t* JXL_RESTRICT value) = 0;
+  virtual Status F16(float default_value, float* JXL_RESTRICT value) = 0;
+
+  // Returns whether VisitFields should visit some subsequent fields.
+  // "condition" is typically from prior fields, e.g. flags.
+  // Overridden by InitVisitor and MaxBitsVisitor.
+  virtual Status Conditional(bool condition) { return condition; }
+
+  // Overridden by InitVisitor, AllDefaultVisitor and CanEncodeVisitor.
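+  // The default implementation below visits the single all_default bit; if
+  // that bit is set, VisitFields returns early and the entire bundle occupies
+  // exactly one bit in the bitstream (see the HOWTO above).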
+  virtual Status AllDefault(const Fields& /*fields*/,
+                            bool* JXL_RESTRICT all_default) {
+    JXL_RETURN_IF_ERROR(Bool(true, all_default));
+    return *all_default;
+  }
+
+  virtual void SetDefault(Fields* /*fields*/) {
+    // Do nothing by default, this is overridden by ReadVisitor.
+  }
+
+  // Returns the result of visiting a nested Bundle.
+  // Overridden by InitVisitor.
+  virtual Status VisitNested(Fields* fields) { return Visit(fields, ""); }
+
+  // Overridden by ReadVisitor. Enables dynamically-sized fields.
+  virtual bool IsReading() const { return false; }
+
+  virtual Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) = 0;
+  virtual Status EndExtensions() = 0;
+
+  // For debugging
+  virtual const char* VisitorName() = 0;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_FIELDS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/fields_test.cc b/third_party/jpeg-xl/lib/jxl/fields_test.cc
new file mode 100644
index 000000000000..88500a25c1bb
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/fields_test.cc
@@ -0,0 +1,443 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/fields.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <array>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/headers.h"
+
+namespace jxl {
+namespace {
+
+// Ensures `value` round-trips and is encoded in exactly
+// `expected_bits_written` bits.
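+// With the U32Enc used below (Val(0), Bits(4), Val(0x7FFFFFFF), Bits(32)),
+// 0 and 0x7FFFFFFF cost only the 2-bit selector, 1..15 cost 2 + 4 bits, and
+// everything else costs 2 + 32 bits; the expectations in U32CoderTest follow
+// directly from this.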
+void TestU32Coder(const uint32_t value, const size_t expected_bits_written) { + U32Coder coder; + const U32Enc enc(Val(0), Bits(4), Val(0x7FFFFFFF), Bits(32)); + + BitWriter writer; + BitWriter::Allotment allotment( + &writer, RoundUpBitsToByteMultiple(U32Coder::MaxEncodedBits(enc))); + + size_t precheck_pos; + EXPECT_TRUE(coder.CanEncode(enc, value, &precheck_pos)); + EXPECT_EQ(expected_bits_written, precheck_pos); + + EXPECT_TRUE(coder.Write(enc, value, &writer)); + EXPECT_EQ(expected_bits_written, writer.BitsWritten()); + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + + BitReader reader(writer.GetSpan()); + const uint32_t decoded_value = coder.Read(enc, &reader); + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); +} + +TEST(FieldsTest, U32CoderTest) { + TestU32Coder(0, 2); + TestU32Coder(1, 6); + TestU32Coder(15, 6); + TestU32Coder(0x7FFFFFFF, 2); + TestU32Coder(128, 34); + TestU32Coder(0x7FFFFFFEu, 34); + TestU32Coder(0x80000000u, 34); + TestU32Coder(0xFFFFFFFFu, 34); +} + +void TestU64Coder(const uint64_t value, const size_t expected_bits_written) { + U64Coder coder; + + BitWriter writer; + BitWriter::Allotment allotment( + &writer, RoundUpBitsToByteMultiple(U64Coder::MaxEncodedBits())); + + size_t precheck_pos; + EXPECT_TRUE(coder.CanEncode(value, &precheck_pos)); + EXPECT_EQ(expected_bits_written, precheck_pos); + + EXPECT_TRUE(coder.Write(value, &writer)); + EXPECT_EQ(expected_bits_written, writer.BitsWritten()); + + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, 0, nullptr); + + BitReader reader(writer.GetSpan()); + const uint64_t decoded_value = coder.Read(&reader); + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); +} + +TEST(FieldsTest, U64CoderTest) { + // Values that should take 2 bits (selector 00): 0 + TestU64Coder(0, 2); + + // Values that should take 6 bits (2 for selector, 4 for value): 1..16 + TestU64Coder(1, 6); + TestU64Coder(2, 6); + TestU64Coder(8, 6); + TestU64Coder(15, 6); + TestU64Coder(16, 6); + + // Values that should take 10 bits (2 for selector, 8 for value): 17..272 + TestU64Coder(17, 10); + TestU64Coder(18, 10); + TestU64Coder(100, 10); + TestU64Coder(271, 10); + TestU64Coder(272, 10); + + // Values that should take 15 bits (2 for selector, 12 for value, 1 for varint + // end): (0)..273..4095 + TestU64Coder(273, 15); + TestU64Coder(274, 15); + TestU64Coder(1000, 15); + TestU64Coder(4094, 15); + TestU64Coder(4095, 15); + + // Take 24 bits (of which 20 actual value): (0)..4096..1048575 + TestU64Coder(4096, 24); + TestU64Coder(4097, 24); + TestU64Coder(10000, 24); + TestU64Coder(1048574, 24); + TestU64Coder(1048575, 24); + + // Take 33 bits (of which 28 actual value): (0)..1048576..268435455 + TestU64Coder(1048576, 33); + TestU64Coder(1048577, 33); + TestU64Coder(10000000, 33); + TestU64Coder(268435454, 33); + TestU64Coder(268435455, 33); + + // Take 42 bits (of which 36 actual value): (0)..268435456..68719476735 + TestU64Coder(268435456ull, 42); + TestU64Coder(268435457ull, 42); + TestU64Coder(1000000000ull, 42); + TestU64Coder(68719476734ull, 42); + TestU64Coder(68719476735ull, 42); + + // Take 51 bits (of which 44 actual value): (0)..68719476736..17592186044415 + TestU64Coder(68719476736ull, 51); + TestU64Coder(68719476737ull, 51); + TestU64Coder(1000000000000ull, 51); + TestU64Coder(17592186044414ull, 51); + TestU64Coder(17592186044415ull, 51); + + // Take 60 bits (of which 52 actual value): + // (0)..17592186044416..4503599627370495 + TestU64Coder(17592186044416ull, 60); + 
TestU64Coder(17592186044417ull, 60);
+  TestU64Coder(100000000000000ull, 60);
+  TestU64Coder(4503599627370494ull, 60);
+  TestU64Coder(4503599627370495ull, 60);
+
+  // Take 69 bits (of which 60 actual value):
+  // (0)..4503599627370496..1152921504606846975
+  TestU64Coder(4503599627370496ull, 69);
+  TestU64Coder(4503599627370497ull, 69);
+  TestU64Coder(10000000000000000ull, 69);
+  TestU64Coder(1152921504606846974ull, 69);
+  TestU64Coder(1152921504606846975ull, 69);
+
+  // Take 73 bits (of which 64 actual value):
+  // (0)..1152921504606846976..18446744073709551615
+  TestU64Coder(1152921504606846976ull, 73);
+  TestU64Coder(1152921504606846977ull, 73);
+  TestU64Coder(10000000000000000000ull, 73);
+  TestU64Coder(18446744073709551614ull, 73);
+  TestU64Coder(18446744073709551615ull, 73);
+}
+
+Status TestF16Coder(const float value) {
+  F16Coder coder;
+
+  size_t max_encoded_bits;
+  // It is not a fatal error if it can't be encoded.
+  if (!coder.CanEncode(value, &max_encoded_bits)) return false;
+  EXPECT_EQ(F16Coder::MaxEncodedBits(), max_encoded_bits);
+
+  BitWriter writer;
+  BitWriter::Allotment allotment(&writer,
+                                 RoundUpBitsToByteMultiple(max_encoded_bits));
+
+  EXPECT_TRUE(coder.Write(value, &writer));
+  EXPECT_EQ(F16Coder::MaxEncodedBits(), writer.BitsWritten());
+  writer.ZeroPadToByte();
+  ReclaimAndCharge(&writer, &allotment, 0, nullptr);
+
+  BitReader reader(writer.GetSpan());
+  float decoded_value;
+  EXPECT_TRUE(coder.Read(&reader, &decoded_value));
+  // All values we test can be represented exactly.
+  EXPECT_EQ(value, decoded_value);
+  EXPECT_TRUE(reader.Close());
+  return true;
+}
+
+TEST(FieldsTest, F16CoderTest) {
+  for (float sign : {-1.0f, 1.0f}) {
+    // (magnitudes below 2^-14 = 1/16384 are subnormal in binary16)
+    for (float mag : {0.0f, 0.5f, 1.0f, 2.0f, 2.5f, 16.015625f, 1.0f / 4096,
+                      1.0f / 16384, 65504.0f}) {
+      EXPECT_TRUE(TestF16Coder(sign * mag));
+    }
+  }
+
+  // Out of range
+  EXPECT_FALSE(TestF16Coder(65504.01f));
+  EXPECT_FALSE(TestF16Coder(-65505.0f));
+}
+
+// Ensures Read(Write()) returns the same fields.
+TEST(FieldsTest, TestRoundtripSize) {
+  for (int i = 0; i < 8; i++) {
+    SizeHeader size;
+    ASSERT_TRUE(size.Set(123 + 77 * i, 7 + i));
+
+    size_t extension_bits = 999, total_bits = 999;  // Initialize as garbage.
+    ASSERT_TRUE(Bundle::CanEncode(size, &extension_bits, &total_bits));
+    EXPECT_EQ(0, extension_bits);
+
+    BitWriter writer;
+    ASSERT_TRUE(WriteSizeHeader(size, &writer, 0, nullptr));
+    EXPECT_EQ(total_bits, writer.BitsWritten());
+    writer.ZeroPadToByte();
+
+    SizeHeader size2;
+    BitReader reader(writer.GetSpan());
+    ASSERT_TRUE(ReadSizeHeader(&reader, &size2));
+    EXPECT_EQ(total_bits, reader.TotalBitsConsumed());
+    EXPECT_TRUE(reader.Close());
+
+    EXPECT_EQ(size.xsize(), size2.xsize());
+    EXPECT_EQ(size.ysize(), size2.ysize());
+  }
+}
+
+// Ensure all values can be reached by the encoding.
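+// frame_origin goes through PackSigned/UnpackSigned (a zigzag-style mapping
+// of signed to unsigned values), so sweeping i across negative and positive
+// origins exercises every branch of the size/offset coder.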
+TEST(FieldsTest, TestCropRect) { + CodecMetadata metadata; + for (int32_t i = -1000; i < 19000; ++i) { + FrameHeader f(&metadata); + f.custom_size_or_origin = true; + f.frame_origin.x0 = i; + f.frame_origin.y0 = i; + f.frame_size.xsize = 1000 + i; + f.frame_size.ysize = 1000 + i; + size_t extension_bits = 0, total_bits = 0; + ASSERT_TRUE(Bundle::CanEncode(f, &extension_bits, &total_bits)); + EXPECT_EQ(0, extension_bits); + EXPECT_GE(total_bits, 9); + } +} +TEST(FieldsTest, TestPreview) { + // (div8 cannot represent 4360, but !div8 can go a little higher) + for (uint32_t i = 1; i < 4360; ++i) { + PreviewHeader p; + ASSERT_TRUE(p.Set(i, i)); + size_t extension_bits = 0, total_bits = 0; + ASSERT_TRUE(Bundle::CanEncode(p, &extension_bits, &total_bits)); + EXPECT_EQ(0, extension_bits); + EXPECT_GE(total_bits, 6); + } +} + +// Ensures Read(Write()) returns the same fields. +TEST(FieldsTest, TestRoundtripFrame) { + CodecMetadata metadata; + FrameHeader h(&metadata); + h.extensions = 0x800; + + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. + ASSERT_TRUE(Bundle::CanEncode(h, &extension_bits, &total_bits)); + EXPECT_EQ(0, extension_bits); + BitWriter writer; + ASSERT_TRUE(WriteFrameHeader(h, &writer, nullptr)); + EXPECT_EQ(total_bits, writer.BitsWritten()); + writer.ZeroPadToByte(); + + FrameHeader h2(&metadata); + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(ReadFrameHeader(&reader, &h2)); + EXPECT_EQ(total_bits, reader.TotalBitsConsumed()); + EXPECT_TRUE(reader.Close()); + + EXPECT_EQ(h.extensions, h2.extensions); + EXPECT_EQ(h.flags, h2.flags); +} + +#ifndef JXL_CRASH_ON_ERROR +// Ensure out-of-bounds values cause an error. +TEST(FieldsTest, TestOutOfRange) { + SizeHeader h; + ASSERT_TRUE(h.Set(0xFFFFFFFFull, 0xFFFFFFFFull)); + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. 
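+  // 0xFFFFFFFF x 0xFFFFFFFF exceeds the dimensions SizeHeader can encode, so
+  // CanEncode must report failure rather than silently truncating.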
+ ASSERT_FALSE(Bundle::CanEncode(h, &extension_bits, &total_bits)); +} +#endif + +struct OldBundle : public Fields { + OldBundle() { Bundle::Init(this); } + const char* Name() const override { return "OldBundle"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Bits(2), Bits(3), Bits(4), 1, &old_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.125f, &old_f)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(7), Bits(12), Bits(16), Bits(32), 0, &old_large)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + return visitor->EndExtensions(); + } + + uint32_t old_small; + float old_f; + uint32_t old_large; + uint64_t extensions; +}; + +struct NewBundle : public Fields { + NewBundle() { Bundle::Init(this); } + const char* Name() const override { return "NewBundle"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Bits(2), Bits(3), Bits(4), 1, &old_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.125f, &old_f)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(7), Bits(12), Bits(16), Bits(32), 0, &old_large)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + if (visitor->Conditional(extensions & 1)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(2), Bits(2), Bits(3), Bits(4), 2, &new_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(-2.0f, &new_f)); + } + if (visitor->Conditional(extensions & 2)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(9), Bits(12), Bits(16), Bits(32), 0, &new_large)); + } + return visitor->EndExtensions(); + } + + uint32_t old_small; + float old_f; + uint32_t old_large; + uint64_t extensions; + + // If extensions & 1 + uint32_t new_small = 2; + float new_f = -2.0f; + // If extensions & 2 + uint32_t new_large = 0; +}; + +TEST(FieldsTest, TestNewDecoderOldData) { + OldBundle old_bundle; + old_bundle.old_large = 123; + old_bundle.old_f = 3.75f; + old_bundle.extensions = 0; + + // Write to bit stream + const size_t kMaxOutBytes = 999; + BitWriter writer; + // Make sure values are initialized by code under test. 
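+  // (12345 is an arbitrary sentinel; CanEncode must overwrite both outputs.)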
+ size_t extension_bits = 12345, total_bits = 12345; + ASSERT_TRUE(Bundle::CanEncode(old_bundle, &extension_bits, &total_bits)); + ASSERT_LE(total_bits, kMaxOutBytes * kBitsPerByte); + EXPECT_EQ(0, extension_bits); + AuxOut aux_out; + ASSERT_TRUE(Bundle::Write(old_bundle, &writer, kLayerHeader, &aux_out)); + + BitWriter::Allotment allotment(&writer, + kMaxOutBytes * kBitsPerByte - total_bits); + writer.Write(20, 0xA55A); // sentinel + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, kLayerHeader, nullptr); + + ASSERT_LE(writer.GetSpan().size(), kMaxOutBytes); + BitReader reader(writer.GetSpan()); + NewBundle new_bundle; + ASSERT_TRUE(Bundle::Read(&reader, &new_bundle)); + EXPECT_EQ(reader.TotalBitsConsumed(), + aux_out.layers[kLayerHeader].total_bits); + EXPECT_EQ(reader.ReadBits(20), 0xA55A); + EXPECT_TRUE(reader.Close()); + + // Old fields are the same in both + EXPECT_EQ(old_bundle.extensions, new_bundle.extensions); + EXPECT_EQ(old_bundle.old_small, new_bundle.old_small); + EXPECT_EQ(old_bundle.old_f, new_bundle.old_f); + EXPECT_EQ(old_bundle.old_large, new_bundle.old_large); + // New fields match their defaults + EXPECT_EQ(2, new_bundle.new_small); + EXPECT_EQ(-2.0f, new_bundle.new_f); + EXPECT_EQ(0, new_bundle.new_large); +} + +TEST(FieldsTest, TestOldDecoderNewData) { + NewBundle new_bundle; + new_bundle.old_large = 123; + new_bundle.extensions = 3; + new_bundle.new_f = 999.0f; + new_bundle.new_large = 456; + + // Write to bit stream + constexpr size_t kMaxOutBytes = 999; + BitWriter writer; + // Make sure values are initialized by code under test. + size_t extension_bits = 12345, total_bits = 12345; + ASSERT_TRUE(Bundle::CanEncode(new_bundle, &extension_bits, &total_bits)); + EXPECT_NE(0, extension_bits); + AuxOut aux_out; + ASSERT_TRUE(Bundle::Write(new_bundle, &writer, kLayerHeader, &aux_out)); + ASSERT_LE(aux_out.layers[kLayerHeader].total_bits, + kMaxOutBytes * kBitsPerByte); + + BitWriter::Allotment allotment( + &writer, + kMaxOutBytes * kBitsPerByte - aux_out.layers[kLayerHeader].total_bits); + // Ensure Read skips the additional fields + writer.Write(20, 0xA55A); // sentinel + writer.ZeroPadToByte(); + ReclaimAndCharge(&writer, &allotment, kLayerHeader, nullptr); + + BitReader reader(writer.GetSpan()); + OldBundle old_bundle; + ASSERT_TRUE(Bundle::Read(&reader, &old_bundle)); + EXPECT_EQ(reader.TotalBitsConsumed(), + aux_out.layers[kLayerHeader].total_bits); + EXPECT_EQ(reader.ReadBits(20), 0xA55A); + EXPECT_TRUE(reader.Close()); + + // Old fields are the same in both + EXPECT_EQ(new_bundle.extensions, old_bundle.extensions); + EXPECT_EQ(new_bundle.old_small, old_bundle.old_small); + EXPECT_EQ(new_bundle.old_f, old_bundle.old_f); + EXPECT_EQ(new_bundle.old_large, old_bundle.old_large); + // (Can't check new fields because old decoder doesn't know about them) +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/filters.cc b/third_party/jpeg-xl/lib/jxl/filters.cc new file mode 100644 index 000000000000..a948d8c7890c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/filters.cc @@ -0,0 +1,99 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/filters.h"
+
+#include "lib/jxl/base/profiler.h"
+
+namespace jxl {
+
+void FilterWeights::Init(const LoopFilter& lf,
+                         const FrameDimensions& frame_dim) {
+  if (lf.epf_iters > 0) {
+    sigma = ImageF(frame_dim.xsize_blocks + 2 * kSigmaPadding,
+                   frame_dim.ysize_blocks + 2 * kSigmaPadding);
+  }
+  if (lf.gab) {
+    GaborishWeights(lf);
+  }
+}
+
+void FilterWeights::GaborishWeights(const LoopFilter& lf) {
+  gab_weights[0] = 1;
+  gab_weights[1] = lf.gab_x_weight1;
+  gab_weights[2] = lf.gab_x_weight2;
+  gab_weights[3] = 1;
+  gab_weights[4] = lf.gab_y_weight1;
+  gab_weights[5] = lf.gab_y_weight2;
+  gab_weights[6] = 1;
+  gab_weights[7] = lf.gab_b_weight1;
+  gab_weights[8] = lf.gab_b_weight2;
+  // Normalize
+  for (size_t c = 0; c < 3; c++) {
+    const float mul =
+        1.0f / (gab_weights[3 * c] +
+                4 * (gab_weights[3 * c + 1] + gab_weights[3 * c + 2]));
+    gab_weights[3 * c] *= mul;
+    gab_weights[3 * c + 1] *= mul;
+    gab_weights[3 * c + 2] *= mul;
+  }
+}
+
+void FilterPipeline::ApplyFiltersRow(const LoopFilter& lf,
+                                     const FilterWeights& filter_weights,
+                                     const Rect& rect, ssize_t y) {
+  PROFILER_ZONE("Gaborish+EPF");
+  JXL_DASSERT(num_filters != 0);  // Must be initialized.
+
+  JXL_ASSERT(y < static_cast<ssize_t>(rect.ysize() + lf.Padding()));
+
+  // The minimum value of the center row "y" needed to process the current
+  // filter.
+  ssize_t rows_needed = -static_cast<ssize_t>(lf.Padding());
+
+  for (size_t i = 0; i < num_filters; i++) {
+    const FilterStep& filter = filters[i];
+
+    rows_needed += filter.filter_def.border;
+
+    // After this "y" points to the rect row for the center of the filter.
+    y -= filter.filter_def.border;
+    if (y < rows_needed) return;
+
+    // Compute the region where we need to apply this filter. Depending on the
+    // step we might need to compute a larger portion than the original rect.
+    const size_t filter_x0 = kMaxFilterPadding - filter.output_col_border;
+    const size_t filter_x1 =
+        filter_x0 + rect.xsize() + 2 * filter.output_col_border;
+
+    // Apply filter to the given region.
+    FilterRows rows(filter.filter_def.border);
+    filter.set_input_rows(filter, &rows, y);
+    filter.set_output_rows(filter, &rows, y);
+
+    // The "y" coordinate used for the sigma image in EPF1. Sigma is padded
+    // with kMaxFilterPadding (or kMaxFilterPadding/kBlockDim rows in sigma)
+    // above and below.
+    const size_t sigma_y = kMaxFilterPadding + rect.y0() + y;
+    if (compute_sigma) {
+      rows.SetSigma(filter_weights.sigma, sigma_y, rect.x0());
+    }
+
+    filter.filter_def.apply(rows, lf, filter_weights, filter_x0, filter_x1,
+                            rect.x0() % kBlockDim, sigma_y % kBlockDim);
+  }
+  JXL_DASSERT(rows_needed == 0);
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/filters.h b/third_party/jpeg-xl/lib/jxl/filters.h
new file mode 100644
index 000000000000..98215484ab3d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/filters.h
@@ -0,0 +1,329 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_FILTERS_H_
+#define LIB_JXL_FILTERS_H_
+
+#include <stddef.h>
+
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_group_border.h"
+#include "lib/jxl/filters_internal.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/loop_filter.h"
+
+namespace jxl {
+
+struct FilterWeights {
+  // Initialize the FilterWeights for the passed LoopFilter and
+  // FrameDimensions.
+  void Init(const LoopFilter& lf, const FrameDimensions& frame_dim);
+
+  // Normalized weights for gaborish, in XYB order, each weight for Manhattan
+  // distance of 0, 1 and 2 respectively.
+  float gab_weights[9];
+
+  // Sigma values for EPF, if enabled.
+  // Note that, for speed reasons, this is actually kInvSigmaNum / sigma.
+  ImageF sigma;
+
+ private:
+  void GaborishWeights(const LoopFilter& lf);
+};
+
+static constexpr size_t kMaxFinalizeRectPadding = 9;
+
+// Line-based EPF only needs to keep in cache 21 lines of the image, so 256 is
+// sufficient for everything to fit in the L2 cache. We add
+// 2*RoundUpTo(kMaxFinalizeRectPadding, kBlockDim) pixels as we might have up
+// to two extra borders on each side.
+constexpr size_t kApplyImageFeaturesTileDim =
+    256 + 2 * RoundUpToBlockDim(kMaxFinalizeRectPadding);
+
+// The maximum row storage needed by the filtering pipeline. This is the sum
+// of the number of input rows needed by each step.
+constexpr size_t kTotalStorageRows = 7 + 5 + 3;  // max is EPF0 + EPF1 + EPF2.
+
+// The maximum sum of all the borders in a chain of filters.
+constexpr size_t kMaxFilterBorder = 1 * kBlockDim;
+
+// The maximum horizontal filter padding ever needed to apply a chain of
+// filters. Intermediate storage must have at least as much padding on each
+// left and right sides. This value must be a multiple of kBlockDim.
+constexpr size_t kMaxFilterPadding = kMaxFilterBorder + kBlockDim;
+static_assert(kMaxFilterPadding % kBlockDim == 0,
+              "kMaxFilterPadding must be a multiple of block size.");
+
+// Same as kMaxFilterBorder and kMaxFilterPadding, but for sigma.
+constexpr size_t kSigmaBorder = kMaxFilterBorder / kBlockDim;
+constexpr size_t kSigmaPadding = kMaxFilterPadding / kBlockDim;
+
+// Utility struct to define input/output rows of row-based loop filters.
+constexpr size_t kMaxBorderSize = 3;
+struct FilterRows {
+  explicit FilterRows(int border_size) : border_size_(border_size) {
+    JXL_DASSERT(border_size <= static_cast<int>(kMaxBorderSize));
+  }
+
+  JXL_INLINE const float* GetInputRow(int row, size_t c) const {
+    // Check that row is within range.
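+    // (`row` is an offset relative to the center row, so the valid range is
+    // [-border_size_, +border_size_].)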
+    JXL_DASSERT(-border_size_ <= row && row <= border_size_);
+    return rows_in_[c] + offsets_in_[kMaxBorderSize + row];
+  }
+
+  float* GetOutputRow(size_t c) const { return rows_out_[c]; }
+
+  const float* GetSigmaRow() const {
+    JXL_DASSERT(row_sigma_ != nullptr);
+    return row_sigma_;
+  }
+
+  template <class RowMap>
+  void SetInput(const Image3F& in, size_t y_offset, ssize_t y0, ssize_t x0,
+                ssize_t full_image_y_offset = 0, ssize_t image_ysize = 0) {
+    RowMap row_map(full_image_y_offset, image_ysize);
+    for (size_t c = 0; c < 3; c++) {
+      rows_in_[c] = in.ConstPlaneRow(c, 0);
+    }
+    for (int32_t i = -border_size_; i <= border_size_; i++) {
+      size_t y = row_map(y0 + i);
+      offsets_in_[i + kMaxBorderSize] =
+          static_cast<ssize_t>((y + y_offset) * in.PixelsPerRow()) + x0;
+    }
+  }
+
+  template <class RowMap>
+  void SetOutput(Image3F* out, size_t y_offset, ssize_t y0, ssize_t x0) {
+    size_t y = RowMap()(y0);
+    for (size_t c = 0; c < 3; c++) {
+      rows_out_[c] = out->PlaneRow(c, y + y_offset) + x0;
+    }
+  }
+
+  // Sets the sigma row for the given y0, x0 input image position. Sigma
+  // images have one pixel per input image block, although they are padded
+  // with two blocks (pixels in sigma) on each one of the four sides. The
+  // (x0, y0) values should include this padding.
+  void SetSigma(const ImageF& sigma, size_t y0, size_t x0) {
+    JXL_DASSERT(x0 % GroupBorderAssigner::kPaddingXRound == 0);
+    row_sigma_ = sigma.ConstRow(y0 / kBlockDim) + x0 / kBlockDim;
+  }
+
+ private:
+  // Base pointer to each one of the planes.
+  const float* JXL_RESTRICT rows_in_[3];
+
+  // Offset to the pixel x0 at the different rows. offsets_in_[kMaxBorderSize]
+  // references the center row, regardless of the border_size_. Only the
+  // center row, border_size_ before and border_size_ after are initialized.
+  // The offset is relative to the base pointer in rows_in_.
+  ssize_t offsets_in_[2 * kMaxBorderSize + 1];
+
+  float* JXL_RESTRICT rows_out_[3];
+
+  const float* JXL_RESTRICT row_sigma_{nullptr};
+
+  const int border_size_;
+};
+
+// Definition of a filter. This specifies the function to be used to apply
+// the filter and its row and column padding requirements.
+struct FilterDefinition {
+  // Function to apply the filter to a given row. The filter constant
+  // parameters are passed in LoopFilter lf and filter_weights. `xoff` is
+  // needed to offset the `x0` value so that it will cause correct accesses to
+  // rows.GetSigmaRow(): there is just one sigma value per 8 pixels, and if
+  // the image rectangle is not aligned to multiples of 8 pixels, we need to
+  // compensate for the difference between x0 and the image position modulo 8.
+  void (*apply)(const FilterRows& rows, const LoopFilter& lf,
+                const FilterWeights& filter_weights, size_t x0, size_t x1,
+                size_t image_y_mod_8, size_t image_x_mod_8);
+
+  // Number of source image rows and cols before and after an input pixel
+  // needed to compute the output of the filter. For a 3x3 convolution this
+  // border will be only 1.
+  size_t border;
+};
+
+// A chain of filters to be applied to a source image. This instance must be
+// initialized by the FilterPipelineInit() function before it can be used.
+class FilterPipeline {
+ public:
+  FilterPipeline() : FilterPipeline(kApplyImageFeaturesTileDim) {}
+  explicit FilterPipeline(size_t max_rect_xsize)
+      : storage{max_rect_xsize + 2 * kMaxFilterPadding, kTotalStorageRows} {
+#if MEMORY_SANITIZER
+    // The padding of the storage may be used uninitialized since we process
+    // multiple SIMD lanes at a time, aligned to a multiple of lanes.
+    // For example, in a hypothetical 3-step filter process where all filters
+    // use 1 pixel border the first filter needs to process 2 pixels more on
+    // each side than the requested rect.x0(), rect.xsize(), while the second
+    // filter needs to process 1 more pixel on each side, however for
+    // performance reasons both will process Lanes(df) more pixels on each
+    // side assuming this Lanes(df) value is more than one. In that case the
+    // second filter will be using one pixel of uninitialized data to generate
+    // an output pixel that won't affect the final output but may cause msan
+    // failures. For this reason we initialize the padding region.
+    for (size_t c = 0; c < 3; c++) {
+      for (size_t y = 0; y < storage.ysize(); y++) {
+        float* row = storage.PlaneRow(c, y);
+        memset(row, 0x77, sizeof(float) * kMaxFilterPadding);
+        memset(row + storage.xsize() - kMaxFilterPadding, 0x77,
+               sizeof(float) * kMaxFilterPadding);
+      }
+    }
+#endif  // MEMORY_SANITIZER
+  }
+
+  FilterPipeline(const FilterPipeline&) = delete;
+  FilterPipeline(FilterPipeline&&) = default;
+
+  // Apply the filter chain to a given row. To apply the filter chain to a
+  // whole image this must be called for `rect.ysize() + 2 * total_border`
+  // values of `y`, in increasing order, starting from `y = -total_border`.
+  void ApplyFiltersRow(const LoopFilter& lf,
+                       const FilterWeights& filter_weights, const Rect& rect,
+                       ssize_t y);
+
+  struct FilterStep {
+    // Sets the input of the filter step as an image region.
+    void SetInput(const Image3F* im_input, const Rect& input_rect,
+                  const Rect& image_rect, size_t image_ysize) {
+      input = im_input;
+      this->input_rect = input_rect;
+      this->image_rect = image_rect;
+      this->image_ysize = image_ysize;
+      JXL_DASSERT(SameSize(input_rect, image_rect));
+      set_input_rows = [](const FilterStep& self, FilterRows* rows,
+                          ssize_t y0) {
+        ssize_t full_image_y_offset =
+            static_cast<ssize_t>(self.image_rect.y0()) -
+            static_cast<ssize_t>(self.input_rect.y0());
+        rows->SetInput<RowMapMirror>(
+            *(self.input), 0, self.input_rect.y0() + y0,
+            self.input_rect.x0() - kMaxFilterPadding, full_image_y_offset,
+            self.image_ysize);
+      };
+    }
+
+    // Sets the input of the filter step as the temporary cyclic storage with
+    // num_rows rows. The value rect.x0() during application will be mapped to
+    // kMaxFilterPadding regardless of the rect being processed.
+    template <size_t num_rows>
+    void SetInputCyclicStorage(const Image3F* storage, size_t offset_rows) {
+      input = storage;
+      input_y_offset = offset_rows;
+      set_input_rows = [](const FilterStep& self, FilterRows* rows,
+                          ssize_t y0) {
+        rows->SetInput<RowMapMod<num_rows>>(*(self.input),
+                                            self.input_y_offset, y0, 0);
+      };
+    }
+
+    // Sets the output of the filter step as the temporary cyclic storage with
+    // num_rows rows. The value rect.x0() during application will be mapped to
+    // kMaxFilterPadding regardless of the rect being processed.
+    template <size_t num_rows>
+    void SetOutputCyclicStorage(Image3F* storage, size_t offset_rows) {
+      output = storage;
+      output_y_offset = offset_rows;
+      set_output_rows = [](const FilterStep& self, FilterRows* rows,
+                           ssize_t y0) {
+        rows->SetOutput<RowMapMod<num_rows>>(self.output,
+                                             self.output_y_offset, y0, 0);
+      };
+    }
+
+    // Set the output of the filter step as the output image. The value
+    // rect.x0() will be mapped to the same value in the output image.
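+    // (No row remapping happens here: the RowMapId used below passes y
+    // through unchanged, unlike the mirrored/modular mappings above.)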
+    void SetOutput(Image3F* im_output, const Rect& output_rect) {
+      output = im_output;
+      this->output_rect = output_rect;
+      set_output_rows = [](const FilterStep& self, FilterRows* rows,
+                           ssize_t y0) {
+        rows->SetOutput<RowMapId>(
+            self.output, 0, self.output_rect.y0() + y0,
+            static_cast<ssize_t>(self.output_rect.x0()) - kMaxFilterPadding);
+      };
+    }
+
+    // The input and output image buffers for the current filter step. Note
+    // that the rows used from these images depend on the RowMap used in the
+    // set_input_rows and set_output_rows functions.
+    const Image3F* input;
+    size_t input_y_offset = 0;
+    Image3F* output;
+    size_t output_y_offset = 0;
+
+    // Input/output rect for the first/last steps of the filter.
+    Rect input_rect;
+    Rect output_rect;
+
+    // Information to properly do RowMapMirror().
+    Rect image_rect;
+    size_t image_ysize;
+
+    // Functions that compute the list of rows needed to process a region for
+    // the given row and starting column.
+    void (*set_input_rows)(const FilterStep&, FilterRows* rows, ssize_t y0);
+    void (*set_output_rows)(const FilterStep&, FilterRows* rows, ssize_t y0);
+
+    // Actual filter descriptor.
+    FilterDefinition filter_def;
+
+    // Number of extra horizontal pixels needed on each side of the output of
+    // this filter to produce the requested rect at the end of the chain. This
+    // value is always 0 for the last filter of the chain but it depends on
+    // the actual filter chain used in other cases.
+    size_t output_col_border;
+  };
+
+  template <size_t border>
+  void AddStep(const FilterDefinition& filter_def) {
+    JXL_DASSERT(num_filters < kMaxFilters);
+    filters[num_filters].filter_def = filter_def;
+
+    if (num_filters > 0) {
+      // If it is not the first step we need to set the previous step output
+      // to a portion of the cyclic storage. We only need as many rows as the
+      // input of the current stage.
+      constexpr size_t num_rows = 2 * border + 1;
+      filters[num_filters - 1].SetOutputCyclicStorage<num_rows>(
+          &storage, storage_rows_used);
+      filters[num_filters].SetInputCyclicStorage<num_rows>(
+          &storage, storage_rows_used);
+      storage_rows_used += num_rows;
+      JXL_DASSERT(storage_rows_used <= kTotalStorageRows);
+    }
+    num_filters++;
+  }
+
+  // Tile storage for ApplyImageFeatures steps. Different groups of rows of
+  // this image are used for the intermediate steps.
+  Image3F storage;
+  size_t storage_rows_used = 0;
+
+  static const size_t kMaxFilters = 4;
+  FilterStep filters[kMaxFilters];
+  size_t num_filters = 0;
+
+  // Whether we need to compute the sigma_row_ during application.
+  bool compute_sigma = false;
+
+  // The total border needed to process this pipeline.
+  size_t total_border = 0;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_FILTERS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/filters_internal.h b/third_party/jpeg-xl/lib/jxl/filters_internal.h
new file mode 100644
index 000000000000..ece032c73c48
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/filters_internal.h
@@ -0,0 +1,64 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_FILTERS_INTERNAL_H_
+#define LIB_JXL_FILTERS_INTERNAL_H_
+
+#include <stddef.h>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+
+// Maps a row to the range [0, image_ysize) mirroring it when outside the [0,
+// image_ysize) range. The input row is offset by `full_image_y_offset`, i.e.
+// row `y` corresponds to row `y + full_image_y_offset` in the full frame.
+struct RowMapMirror {
+  RowMapMirror(ssize_t full_image_y_offset, size_t image_ysize)
+      : full_image_y_offset_(full_image_y_offset), image_ysize_(image_ysize) {}
+  size_t operator()(ssize_t y) {
+    return Mirror(y + full_image_y_offset_, image_ysize_) -
+           full_image_y_offset_;
+  }
+  ssize_t full_image_y_offset_;
+  size_t image_ysize_;
+};
+
+// Maps a row in the range [-16, +inf) to a row number in the range [0, m)
+// using the modulo operation.
+template <size_t m>
+struct RowMapMod {
+  RowMapMod() = default;
+  RowMapMod(ssize_t /*full_image_y_offset*/, size_t /*image_ysize*/) {}
+  size_t operator()(ssize_t y) {
+    JXL_DASSERT(y >= -16);
+    // The `m > 16 ? m : 16 * m` is evaluated at compile time and is a
+    // multiple of m that is at least 16. This is to make sure that the left
+    // operand is positive.
+    return static_cast<size_t>(y + (m > 16 ? m : 16 * m)) % m;
+  }
+};
+
+// Identity mapping. Maps a row in the range [0, ysize) to the same value.
+struct RowMapId {
+  size_t operator()(ssize_t y) {
+    JXL_DASSERT(y >= 0);
+    return y;
+  }
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_FILTERS_INTERNAL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/filters_internal_test.cc b/third_party/jpeg-xl/lib/jxl/filters_internal_test.cc
new file mode 100644
index 000000000000..0c61c39637fb
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/filters_internal_test.cc
@@ -0,0 +1,59 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/filters_internal.h"
+
+#include "gtest/gtest.h"
+
+namespace jxl {
+
+class FiltersInternalTest : public ::testing::Test {};
+
+// Test the mapping of rows using RowMapMod.
+TEST(FiltersInternalTest, RowMapModTest) {
+  RowMapMod<5> m;
+  // Identity part:
+  EXPECT_EQ(0, m(0));
+  EXPECT_EQ(4, m(4));
+
+  // Values larger than the modulus wrap around.
+  EXPECT_EQ(0, m(5));
+  EXPECT_EQ(1, m(11));
+
+  // Values smaller than 0 wrap as well, down to a block below zero.
+  EXPECT_EQ(4, m(-1));
+  EXPECT_EQ(2, m(-8));
+}
+
+// Test the implementation for mirroring of rows.
+TEST(FiltersInternalTest, RowMapMirrorTest) {
+  RowMapMirror m(0, 10);  // Image size of 10 rows.
+
+  EXPECT_EQ(2, m(-3));
+  EXPECT_EQ(1, m(-2));
+  EXPECT_EQ(0, m(-1));
+
+  EXPECT_EQ(0, m(0));
+  EXPECT_EQ(9, m(9));
+
+  EXPECT_EQ(9, m(10));
+  EXPECT_EQ(8, m(11));
+  EXPECT_EQ(7, m(12));
+
+  // It mirrors the rows to infinity.
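+  // (The reflected pattern repeats with period 2 * ysize = 20, so both 21
+  // and 41 land on row 1.)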
+  EXPECT_EQ(1, m(21));
+  EXPECT_EQ(1, m(41));
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/frame_header.cc b/third_party/jpeg-xl/lib/jxl/frame_header.cc
new file mode 100644
index 000000000000..bc63346e3822
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/frame_header.cc
@@ -0,0 +1,375 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/frame_header.h"
+
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/fields.h"
+
+namespace jxl {
+
+constexpr uint8_t YCbCrChromaSubsampling::kHShift[];
+constexpr uint8_t YCbCrChromaSubsampling::kVShift[];
+
+static Status VisitBlendMode(Visitor* JXL_RESTRICT visitor,
+                             BlendMode default_value, BlendMode* blend_mode) {
+  uint32_t encoded = static_cast<uint32_t>(*blend_mode);
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(
+      Val(static_cast<uint32_t>(BlendMode::kReplace)),
+      Val(static_cast<uint32_t>(BlendMode::kAdd)),
+      Val(static_cast<uint32_t>(BlendMode::kBlend)), BitsOffset(2, 3),
+      static_cast<uint32_t>(default_value), &encoded));
+  if (encoded > 4) {
+    return JXL_FAILURE("Invalid blend_mode");
+  }
+  *blend_mode = static_cast<BlendMode>(encoded);
+  return true;
+}
+
+static Status VisitFrameType(Visitor* JXL_RESTRICT visitor,
+                             FrameType default_value, FrameType* frame_type) {
+  uint32_t encoded = static_cast<uint32_t>(*frame_type);
+
+  JXL_QUIET_RETURN_IF_ERROR(
+      visitor->U32(Val(static_cast<uint32_t>(FrameType::kRegularFrame)),
+                   Val(static_cast<uint32_t>(FrameType::kDCFrame)),
+                   Val(static_cast<uint32_t>(FrameType::kReferenceOnly)),
+                   Val(static_cast<uint32_t>(FrameType::kSkipProgressive)),
+                   static_cast<uint32_t>(default_value), &encoded));
+  *frame_type = static_cast<FrameType>(encoded);
+  return true;
+}
+
+BlendingInfo::BlendingInfo() { Bundle::Init(this); }
+
+Status BlendingInfo::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  JXL_QUIET_RETURN_IF_ERROR(
+      VisitBlendMode(visitor, BlendMode::kReplace, &mode));
+  if (visitor->Conditional(nonserialized_has_multiple_extra_channels &&
+                           (mode == BlendMode::kBlend ||
+                            mode == BlendMode::kAlphaWeightedAdd))) {
+    // Up to 11 alpha channels for blending.
+    JXL_QUIET_RETURN_IF_ERROR(visitor->U32(
+        Val(0), Val(1), Val(2), BitsOffset(3, 3), 0, &alpha_channel));
+  }
+  if (visitor->Conditional((nonserialized_has_multiple_extra_channels &&
+                            (mode == BlendMode::kBlend ||
+                             mode == BlendMode::kAlphaWeightedAdd)) ||
+                           mode == BlendMode::kMul)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &clamp));
+  }
+  // 'old' frame for blending. Only necessary if this is not a full frame, or
+  // blending is not kReplace.
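+  // (kReplace on a full frame needs no source; every other combination names
+  // one of the four reference slots to blend onto.)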
+  if (visitor->Conditional(mode != BlendMode::kReplace ||
+                           nonserialized_is_partial_frame)) {
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->U32(Val(0), Val(1), Val(2), Val(3), 0, &source));
+  }
+  return true;
+}
+
+AnimationFrame::AnimationFrame(const CodecMetadata* metadata)
+    : nonserialized_metadata(metadata) {
+  Bundle::Init(this);
+}
+Status AnimationFrame::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->Conditional(nonserialized_metadata != nullptr &&
+                           nonserialized_metadata->m.have_animation)) {
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->U32(Val(0), Val(1), Bits(8), Bits(32), 0, &duration));
+  }
+
+  if (visitor->Conditional(
+          nonserialized_metadata != nullptr &&
+          nonserialized_metadata->m.animation.have_timecodes)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(32, 0, &timecode));
+  }
+  return true;
+}
+
+YCbCrChromaSubsampling::YCbCrChromaSubsampling() { Bundle::Init(this); }
+Passes::Passes() { Bundle::Init(this); }
+Status Passes::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  JXL_QUIET_RETURN_IF_ERROR(
+      visitor->U32(Val(1), Val(2), Val(3), BitsOffset(3, 4), 1, &num_passes));
+  JXL_ASSERT(num_passes <= kMaxNumPasses);  // Cannot happen when reading
+
+  if (visitor->Conditional(num_passes != 1)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->U32(
+        Val(0), Val(1), Val(2), BitsOffset(1, 3), 0, &num_downsample));
+    JXL_ASSERT(num_downsample <= 4);  // 1,2,4,8
+    if (num_downsample > num_passes) {
+      return JXL_FAILURE("num_downsample %u > num_passes %u", num_downsample,
+                         num_passes);
+    }
+
+    for (uint32_t i = 0; i < num_passes - 1; i++) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 0, &shift[i]));
+    }
+    shift[num_passes - 1] = 0;
+
+    for (uint32_t i = 0; i < num_downsample; ++i) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &downsample[i]));
+    }
+    for (uint32_t i = 0; i < num_downsample; ++i) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->U32(Val(0), Val(1), Val(2), Bits(3), 0, &last_pass[i]));
+      if (last_pass[i] >= num_passes) {
+        return JXL_FAILURE("last_pass %u >= num_passes %u", last_pass[i],
+                           num_passes);
+      }
+    }
+  }
+
+  return true;
+}
+FrameHeader::FrameHeader(const CodecMetadata* metadata)
+    : animation_frame(metadata), nonserialized_metadata(metadata) {
+  Bundle::Init(this);
+}
+
+Status ReadFrameHeader(BitReader* JXL_RESTRICT reader,
+                       FrameHeader* JXL_RESTRICT frame) {
+  return Bundle::Read(reader, frame);
+}
+
+Status WriteFrameHeader(const FrameHeader& frame,
+                        BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) {
+  return Bundle::Write(frame, writer, kLayerHeader, aux_out);
+}
+
+Status FrameHeader::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->AllDefault(*this, &all_default)) {
+    // Overwrite all serialized fields, but not any nonserialized_*.
+    visitor->SetDefault(this);
+    return true;
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(
+      VisitFrameType(visitor, FrameType::kRegularFrame, &frame_type));
+  if (visitor->IsReading() && nonserialized_is_preview &&
+      frame_type != FrameType::kRegularFrame) {
+    return JXL_FAILURE("Only a regular frame can be a preview");
+  }
+
+  // FrameEncoding.
+  bool is_modular = (encoding == FrameEncoding::kModular);
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &is_modular));
+  encoding = (is_modular ?
FrameEncoding::kModular : FrameEncoding::kVarDCT);
+
+  // Flags
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U64(0, &flags));
+
+  // Color transform
+  bool xyb_encoded = nonserialized_metadata == nullptr ||
+                     nonserialized_metadata->m.xyb_encoded;
+
+  bool fp = nonserialized_metadata != nullptr &&
+            nonserialized_metadata->m.bit_depth.floating_point_sample;
+
+  if (xyb_encoded) {
+    if (is_modular && fp) {
+      return JXL_FAILURE(
+          "Floating point samples are not supported with XYB color encoding");
+    }
+    color_transform = ColorTransform::kXYB;
+  } else {
+    // Alternate if kYCbCr.
+    bool alternate = color_transform == ColorTransform::kYCbCr;
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &alternate));
+    color_transform =
+        (alternate ? ColorTransform::kYCbCr : ColorTransform::kNone);
+  }
+
+  // Chroma subsampling for YCbCr, if no DC frame is used.
+  if (visitor->Conditional(color_transform == ColorTransform::kYCbCr &&
+                           ((flags & kUseDcFrame) == 0))) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&chroma_subsampling));
+  }
+  if (is_modular && !chroma_subsampling.Is444()) {
+    return JXL_FAILURE(
+        "Chroma subsampling is not supported yet in modular mode");
+  }
+
+  size_t num_extra_channels =
+      nonserialized_metadata != nullptr
+          ? nonserialized_metadata->m.extra_channel_info.size()
+          : 0;
+
+  // Upsampling
+  if (visitor->Conditional((flags & kUseDcFrame) == 0)) {
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &upsampling));
+    if (nonserialized_metadata != nullptr &&
+        visitor->Conditional(num_extra_channels != 0)) {
+      const std::vector<ExtraChannelInfo>& extra_channels =
+          nonserialized_metadata->m.extra_channel_info;
+      extra_channel_upsampling.resize(extra_channels.size(), 1);
+      bool foundAlpha = false;
+      for (size_t i = 0; i < extra_channels.size(); ++i) {
+        uint32_t& ec_upsampling = extra_channel_upsampling[i];
+        JXL_QUIET_RETURN_IF_ERROR(
+            visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &ec_upsampling));
+        if (ec_upsampling != 1) {
+          return JXL_FAILURE(
+              "Upsampling for extra channels not yet implemented");
+        }
+        if (!foundAlpha && extra_channels[i].type == ExtraChannel::kAlpha) {
+          foundAlpha = true;
+          if (ec_upsampling != upsampling) {
+            return JXL_FAILURE("Alpha upsampling != color upsampling");
+          }
+        }
+      }
+    } else {
+      extra_channel_upsampling.clear();
+    }
+  }
+
+  // Modular- or VarDCT-specific data.
+  if (visitor->Conditional(encoding == FrameEncoding::kModular)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 1, &group_size_shift));
+  }
+  if (visitor->Conditional(encoding == FrameEncoding::kVarDCT &&
+                           color_transform == ColorTransform::kXYB)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 3, &x_qm_scale));
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 2, &b_qm_scale));
+  } else {
+    x_qm_scale = b_qm_scale = 2;  // noop
+  }
+
+  // Not useful for kPatchSource
+  if (visitor->Conditional(frame_type != FrameType::kReferenceOnly)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&passes));
+  }
+
+  if (visitor->Conditional(frame_type == FrameType::kDCFrame)) {
+    // Up to 4 pyramid levels - for up to 16384x downsampling.
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->U32(Val(1), Val(2), Val(3), Val(4), 1, &dc_level));
+  }
+  if (frame_type != FrameType::kDCFrame) {
+    dc_level = 0;
+  }
+
+  bool is_partial_frame = false;
+  if (visitor->Conditional(frame_type != FrameType::kDCFrame)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &custom_size_or_origin));
+    if (visitor->Conditional(custom_size_or_origin)) {
+      const U32Enc enc(Bits(8), BitsOffset(11, 256), BitsOffset(14, 2304),
+                       BitsOffset(30, 18688));
+      // Frame offset, only if kRegularFrame or kSkipProgressive.
+      if (visitor->Conditional(frame_type == FrameType::kRegularFrame ||
+                               frame_type == FrameType::kSkipProgressive)) {
+        uint32_t ux0 = PackSigned(frame_origin.x0);
+        uint32_t uy0 = PackSigned(frame_origin.y0);
+        JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &ux0));
+        JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &uy0));
+        frame_origin.x0 = UnpackSigned(ux0);
+        frame_origin.y0 = UnpackSigned(uy0);
+      }
+      // Frame size
+      JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &frame_size.xsize));
+      JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &frame_size.ysize));
+      int32_t image_xsize = default_xsize();
+      int32_t image_ysize = default_ysize();
+      if (frame_type == FrameType::kRegularFrame ||
+          frame_type == FrameType::kSkipProgressive) {
+        is_partial_frame |= frame_origin.x0 > 0;
+        is_partial_frame |= frame_origin.y0 > 0;
+        is_partial_frame |= (static_cast<int32_t>(frame_size.xsize) +
+                             frame_origin.x0) < image_xsize;
+        is_partial_frame |= (static_cast<int32_t>(frame_size.ysize) +
+                             frame_origin.y0) < image_ysize;
+      }
+    }
+  }
+
+  // Blending info, animation info and whether this is the last frame or not.
+  if (visitor->Conditional(frame_type == FrameType::kRegularFrame ||
+                           frame_type == FrameType::kSkipProgressive)) {
+    blending_info.nonserialized_has_multiple_extra_channels =
+        num_extra_channels > 0;
+    blending_info.nonserialized_is_partial_frame = is_partial_frame;
+    JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&blending_info));
+    bool replace_all = (blending_info.mode == BlendMode::kReplace);
+    extra_channel_blending_info.resize(num_extra_channels);
+    for (size_t i = 0; i < num_extra_channels; i++) {
+      auto& ec_blending_info = extra_channel_blending_info[i];
+      ec_blending_info.nonserialized_is_partial_frame = is_partial_frame;
+      ec_blending_info.nonserialized_has_multiple_extra_channels =
+          num_extra_channels > 0;
+      JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&ec_blending_info));
+      replace_all &= (ec_blending_info.mode == BlendMode::kReplace);
+    }
+    if (visitor->IsReading() && nonserialized_is_preview) {
+      if (!replace_all || custom_size_or_origin) {
+        return JXL_FAILURE("Preview is not compatible with blending");
+      }
+    }
+    if (visitor->Conditional(nonserialized_metadata != nullptr &&
+                             nonserialized_metadata->m.have_animation)) {
+      animation_frame.nonserialized_metadata = nonserialized_metadata;
+      JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&animation_frame));
+    }
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &is_last));
+  }
+  if (frame_type != FrameType::kRegularFrame) {
+    is_last = false;
+  }
+
+  // ID that can be used to refer to this frame. 0 for a non-zero-duration
+  // frame means that it will not be referenced. Not necessary for the last
+  // frame.
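+  // (save_as_reference names one of the four in-memory reference slots, the
+  // same slot space that BlendingInfo::source reads from.)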
+  if (visitor->Conditional(frame_type != FrameType::kDCFrame && !is_last)) {
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->U32(Val(0), Val(1), Val(2), Val(3), 0, &save_as_reference));
+  }
+
+  // If this frame is not blended on another frame post-color-transform, it
+  // may be stored for being referenced either before or after the color
+  // transform. If it is blended post-color-transform, it must be blended
+  // after. It must also be blended after if this is a kRegular frame that
+  // does not cover the full frame, as samples outside the partial region are
+  // from a post-color-transform frame.
+  if (frame_type != FrameType::kDCFrame) {
+    if (visitor->Conditional(CanBeReferenced() &&
+                             blending_info.mode == BlendMode::kReplace &&
+                             !is_partial_frame &&
+                             (frame_type == FrameType::kRegularFrame ||
+                              frame_type == FrameType::kSkipProgressive))) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->Bool(false, &save_before_color_transform));
+    } else if (visitor->Conditional(frame_type == FrameType::kReferenceOnly)) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->Bool(true, &save_before_color_transform));
+    }
+  } else {
+    save_before_color_transform = true;
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(VisitNameString(visitor, &name));
+
+  loop_filter.nonserialized_is_modular = is_modular;
+  JXL_RETURN_IF_ERROR(visitor->VisitNested(&loop_filter));
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions));
+  // Extensions: in chronological order of being added to the format.
+  return visitor->EndExtensions();
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/frame_header.h b/third_party/jpeg-xl/lib/jxl/frame_header.h
new file mode 100644
index 000000000000..c58aefe03702
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/frame_header.h
@@ -0,0 +1,501 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_FRAME_HEADER_H_
+#define LIB_JXL_FRAME_HEADER_H_
+
+// Frame header with backward- and forward-compatible extension capability
+// and compressed integer fields.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image_metadata.h"
+#include "lib/jxl/loop_filter.h"
+
+namespace jxl {
+
+// Also used by extra channel names.
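+// The longest branch of the length coder below is BitsOffset(10, 48), so the
+// maximum representable length is 48 + (1 << 10) - 1 = 1071 bytes.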
+static inline Status VisitNameString(Visitor* JXL_RESTRICT visitor,
+                                     std::string* name) {
+  uint32_t name_length = static_cast<uint32_t>(name->length());
+  // Allows layer name lengths up to 1071 bytes
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Bits(4), BitsOffset(5, 16),
+                                         BitsOffset(10, 48), 0, &name_length));
+  if (visitor->IsReading()) {
+    name->resize(name_length);
+  }
+  for (size_t i = 0; i < name_length; i++) {
+    uint32_t c = (*name)[i];
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(8, 0, &c));
+    (*name)[i] = static_cast<char>(c);
+  }
+  return true;
+}
+
+enum class FrameEncoding : uint32_t {
+  kVarDCT,
+  kModular,
+};
+
+enum class ColorTransform : uint32_t {
+  kXYB,    // Values are encoded with XYB. May only be used if
+           // ImageBundle::xyb_encoded.
+  kNone,   // Values are encoded according to the attached color profile. May
+           // only be used if !ImageBundle::xyb_encoded.
+  kYCbCr,  // Values are encoded according to the attached color profile, but
+           // transformed to YCbCr. May only be used if
+           // !ImageBundle::xyb_encoded.
+};
+
+inline std::array<int, 3> JpegOrder(ColorTransform ct, bool is_gray) {
+  if (is_gray) {
+    return {0, 0, 0};
+  }
+  JXL_ASSERT(ct != ColorTransform::kXYB);
+  if (ct == ColorTransform::kYCbCr) {
+    return {1, 0, 2};
+  } else {
+    return {0, 1, 2};
+  }
+}
+
+struct YCbCrChromaSubsampling : public Fields {
+  YCbCrChromaSubsampling();
+  const char* Name() const override { return "YCbCrChromaSubsampling"; }
+  size_t HShift(size_t c) const { return maxhs_ - kHShift[channel_mode_[c]]; }
+  size_t VShift(size_t c) const { return maxvs_ - kVShift[channel_mode_[c]]; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override {
+    // TODO(veluca): consider allowing 4x downsamples
+    for (size_t i = 0; i < 3; i++) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 0, &channel_mode_[i]));
+    }
+    Recompute();
+    return true;
+  }
+
+  uint8_t MaxHShift() const { return maxhs_; }
+  uint8_t MaxVShift() const { return maxvs_; }
+
+  uint8_t RawHShift(size_t c) { return kHShift[channel_mode_[c]]; }
+  uint8_t RawVShift(size_t c) { return kVShift[channel_mode_[c]]; }
+
+  // Uses JPEG channel order (Y, Cb, Cr).
+  Status Set(const uint8_t* hsample, const uint8_t* vsample) {
+    for (size_t c = 0; c < 3; c++) {
+      size_t cjpeg = c < 2 ? c ^ 1 : c;
+      size_t i = 0;
+      for (; i < 4; i++) {
+        if (1 << kHShift[i] == hsample[cjpeg] &&
+            1 << kVShift[i] == vsample[cjpeg]) {
+          channel_mode_[c] = i;
+          break;
+        }
+      }
+      if (i == 4) {
+        return JXL_FAILURE("Invalid subsample mode");
+      }
+    }
+    Recompute();
+    return true;
+  }
+
+  bool Is444() const {
+    for (size_t c : {0, 2}) {
+      if (channel_mode_[c] != channel_mode_[1]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool Is420() const {
+    return channel_mode_[0] == 1 && channel_mode_[1] == 0 &&
+           channel_mode_[2] == 1;
+  }
+
+  bool Is422() const {
+    for (size_t c : {0, 2}) {
+      if (kHShift[channel_mode_[c]] == kHShift[channel_mode_[1]] + 1 &&
+          kVShift[channel_mode_[c]] == kVShift[channel_mode_[1]]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool Is440() const {
+    for (size_t c : {0, 2}) {
+      if (kHShift[channel_mode_[c]] == kHShift[channel_mode_[1]] &&
+          kVShift[channel_mode_[c]] == kVShift[channel_mode_[1]] + 1) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+ private:
+  void Recompute() {
+    maxhs_ = 0;
+    maxvs_ = 0;
+    for (size_t i = 0; i < 3; i++) {
+      maxhs_ = std::max(maxhs_, kHShift[channel_mode_[i]]);
+      maxvs_ = std::max(maxvs_, kVShift[channel_mode_[i]]);
+    }
+  }
+  static constexpr uint8_t kHShift[4] = {0, 1, 1, 0};
+  static constexpr uint8_t kVShift[4] = {0, 1, 0, 1};
+  uint32_t channel_mode_[3];
+  uint8_t maxhs_;
+  uint8_t maxvs_;
+};
+
+// Indicates how to combine the current frame with a previously-saved one. Can
+// be independently controlled for color and extra channels. Formulas are
+// indicative and treat alpha as if it is in range 0.0-1.0. In descriptions
+// below, alpha channel is the extra channel of type alpha used for blending
+// according to the blend_channel, or fully opaque if there is no alpha
+// channel. The blending specified here is used for performing blending
+// *after* color transforms - in linear sRGB if blending a XYB-encoded frame
+// on another XYB-encoded frame, in sRGB if blending a frame with kColorSpace
+// == kSRGB, or in the original colorspace otherwise. Blending in XYB or YCbCr
+// is done by using patches.
+enum class BlendMode {
+  // The new values (in the crop) replace the old ones: sample = new
+  kReplace = 0,
+  // The new values (in the crop) get added to the old ones: sample = old + new
+  kAdd = 1,
+  // The new values (in the crop) replace the old ones if alpha>0:
+  // For the alpha channel that is used as source:
+  // alpha = old + new * (1 - old)
+  // For other channels if !alpha_associated:
+  // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha
+  // For other channels if alpha_associated:
+  // sample = (1 - new_alpha) * old + new
+  // The alpha formula applies to the alpha used for the division in the other
+  // channels formula, and applies to the alpha channel itself if its
+  // blend_channel value matches itself.
+  kBlend = 2,
+  // The new values (in the crop) are added to the old ones if alpha>0:
+  // For the alpha channel that is used as source:
+  // sample = old + new * (1 - old)
+  // For other channels: sample = old + alpha * new
+  kAlphaWeightedAdd = 3,
+  // The new values (in the crop) get multiplied by the old ones:
+  // sample = old * new
+  // The range of the new value matters for multiplication purposes, and its
+  // nominal range of 0..1 is computed the same way as this is done for the
+  // alpha values in kBlend and kAlphaWeightedAdd.
+  // If using kMul as a blend mode for color channels, no color transform is
+  // performed on the current frame.
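+  // E.g. a kMul frame holding a constant 0.5 halves every sample it covers.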
+ kMul = 4, +}; + +struct BlendingInfo : public Fields { + BlendingInfo(); + const char* Name() const override { return "BlendingInfo"; } + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + BlendMode mode; + // Which extra channel to use as alpha channel for blending, only encoded + // for blend modes that involve alpha and if there are more than 1 extra + // channels. + uint32_t alpha_channel; + // Clamp alpha or channel values to 0-1 range. + bool clamp; + // Frame ID to copy from (0-3). Only encoded if blend_mode is not kReplace. + uint32_t source; + + bool nonserialized_has_multiple_extra_channels = false; + bool nonserialized_is_partial_frame = false; +}; + +// Origin of the current frame. Not present for frames of type +// kOnlyPatches. +struct FrameOrigin { + int32_t x0, y0; // can be negative. +}; + +// Size of the current frame. +struct FrameSize { + uint32_t xsize, ysize; +}; + +// AnimationFrame defines duration of animation frames. +struct AnimationFrame : public Fields { + explicit AnimationFrame(const CodecMetadata* metadata); + const char* Name() const override { return "AnimationFrame"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // How long to wait [in ticks, see Animation{}] after rendering. + // May be 0 if the current frame serves as a foundation for another frame. + uint32_t duration; + + uint32_t timecode; // 0xHHMMSSFF + + // Must be set to the one ImageMetadata acting as the full codestream header, + // with correct xyb_encoded, list of extra channels, etc... + const CodecMetadata* nonserialized_metadata = nullptr; +}; + +// For decoding to lower resolutions. Only used for kRegular frames. +struct Passes : public Fields { + Passes(); + const char* Name() const override { return "Passes"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + void GetDownsamplingBracket(size_t pass, int& minShift, int& maxShift) const { + maxShift = 2; + minShift = 0; + for (size_t i = 0;; i++) { + for (uint32_t j = 0; j < num_downsample; ++j) { + if (i <= last_pass[j]) { + if (downsample[j] == 8) minShift = 3; + if (downsample[j] == 4) minShift = 2; + if (downsample[j] == 2) minShift = 1; + if (downsample[j] == 1) minShift = 0; + } + } + if (i == num_passes - 1) minShift = 0; + if (i == pass) return; + maxShift = minShift - 1; + minShift = 0; + } + } + + uint32_t num_passes; // <= kMaxNumPasses + uint32_t num_downsample; // <= num_passes + + // Array of num_downsample pairs. downsample=1/last_pass=num_passes-1 and + // downsample=8/last_pass=0 need not be specified; they are implicit. + uint32_t downsample[kMaxNumPasses]; + uint32_t last_pass[kMaxNumPasses]; + // Array of shift values for each pass. It is implicitly assumed to be 0 for + // the last pass. + uint32_t shift[kMaxNumPasses]; +}; + +enum FrameType { + // A "regular" frame: might be a crop, and will be blended on a previous + // frame, if any, and displayed or blended in future frames. + kRegularFrame = 0, + // A DC frame: this frame is downsampled and will be *only* used as the DC of + // a future frame and, possibly, for previews. Cannot be cropped, blended, or + // referenced by patches or blending modes. Frames that *use* a DC frame + // cannot have non-default sizes either. + kDCFrame = 1, + // A PatchesSource frame: this frame will be only used as a source frame for + // taking patches. Can be cropped, but cannot have non-(0, 0) x0 and y0. + kReferenceOnly = 2, + // Same as kRegularFrame, but not used for progressive rendering. 
This also
+  // implies no early display of DC.
+  kSkipProgressive = 3,
+};
+
+// Image/frame := one or more of these, where the last has is_last = true.
+// Starts at a byte-aligned address "a"; the next pass starts at "a + size".
+struct FrameHeader : public Fields {
+  // Optional postprocessing steps. These flags are the source of truth;
+  // Override must set/clear them rather than change their meaning. Values
+  // chosen such that typical flags == 0 (encoded in only two bits).
+  enum Flags {
+    // Often but not always off => low bit value:
+
+    // Inject noise into decoded output.
+    kNoise = 1,
+
+    // Overlay patches.
+    kPatches = 2,
+
+    // 4, 8 = reserved for future sometimes-off
+
+    // Overlay splines.
+    kSplines = 16,
+
+    kUseDcFrame = 32,  // Implies kSkipAdaptiveDCSmoothing.
+
+    // 64 = reserved for future often-off
+
+    // Almost always on => negated:
+
+    kSkipAdaptiveDCSmoothing = 128,
+  };
+
+  explicit FrameHeader(const CodecMetadata* metadata);
+  const char* Name() const override { return "FrameHeader"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  // Sets/clears `flag` based upon `condition`.
+  void UpdateFlag(const bool condition, const uint64_t flag) {
+    if (condition) {
+      flags |= flag;
+    } else {
+      flags &= ~flag;
+    }
+  }
+
+  // Returns true if this frame is supposed to be saved for future usage by
+  // other frames.
+  bool CanBeReferenced() const {
+    // DC frames cannot be referenced. The last frame cannot be referenced. A
+    // duration 0 frame makes little sense if it is not referenced. A
+    // non-duration 0 frame may or may not be referenced.
+    return !is_last && frame_type != FrameType::kDCFrame &&
+           (animation_frame.duration == 0 || save_as_reference != 0);
+  }
+
+  mutable bool all_default;
+
+  // Always present
+  FrameEncoding encoding;
+  // Some versions of UBSAN complain in VisitFrameType if not initialized.
+  FrameType frame_type = FrameType::kRegularFrame;
+
+  uint64_t flags;
+
+  ColorTransform color_transform;
+  YCbCrChromaSubsampling chroma_subsampling;
+
+  uint32_t group_size_shift;  // only if encoding == kModular
+
+  uint32_t x_qm_scale;  // only if VarDCT and color_transform == kXYB
+  uint32_t b_qm_scale;  // only if VarDCT and color_transform == kXYB
+
+  std::string name;
+
+  // Skipped for kReferenceOnly.
+  Passes passes;
+
+  // Skipped for kDCFrame
+  bool custom_size_or_origin;
+  FrameSize frame_size;
+
+  // Upsampling factors for color and extra channels. Upsampling is always
+  // performed before applying any inverse color transform.
+  // Skipped (1) if kUseDcFrame
+  uint32_t upsampling;
+  std::vector<uint32_t> extra_channel_upsampling;
+
+  // Only for kRegular frames.
+  FrameOrigin frame_origin;
+
+  BlendingInfo blending_info;
+  std::vector<BlendingInfo> extra_channel_blending_info;
+
+  // Animation info for this frame.
+  AnimationFrame animation_frame;
+
+  // This is the last frame.
+  bool is_last;
+
+  // ID to refer to this frame with. 0-3, not present if kDCFrame.
+  // 0 has a special meaning for kRegular frames of nonzero duration: it
+  // defines a frame that will not be referenced in the future.
+  uint32_t save_as_reference;
+
+  // Whether to save this frame before or after the color transform. A frame
+  // that is saved before the color transform can only be used for blending
+  // through patches. On the contrary, a frame that is saved after the color
+  // transform can only be used for blending through blending modes.
+  // Irrelevant for extra channel blending. Can only be true if
+  // blending_info.mode == kReplace and this is not a partial kRegularFrame;
+  // if this is a DC frame, it is always true.
+  bool save_before_color_transform;
+
+  uint32_t dc_level;  // 1-4 if kDCFrame (0 otherwise).
+
+  // Must be set to the one ImageMetadata acting as the full codestream header,
+  // with correct xyb_encoded, list of extra channels, etc...
+  const CodecMetadata* nonserialized_metadata = nullptr;
+
+  // NOTE: This is ignored by AllDefault.
+  LoopFilter loop_filter;
+
+  bool nonserialized_is_preview = false;
+
+  size_t default_xsize() const {
+    if (!nonserialized_metadata) return 0;
+    if (nonserialized_is_preview) {
+      return nonserialized_metadata->m.preview_size.xsize();
+    }
+    return nonserialized_metadata->xsize();
+  }
+
+  size_t default_ysize() const {
+    if (!nonserialized_metadata) return 0;
+    if (nonserialized_is_preview) {
+      return nonserialized_metadata->m.preview_size.ysize();
+    }
+    return nonserialized_metadata->ysize();
+  }
+
+  FrameDimensions ToFrameDimensions() const {
+    size_t xsize = default_xsize();
+    size_t ysize = default_ysize();
+
+    xsize = frame_size.xsize ? frame_size.xsize : xsize;
+    ysize = frame_size.ysize ? frame_size.ysize : ysize;
+
+    if (dc_level != 0) {
+      xsize = DivCeil(xsize, 1 << (3 * dc_level));
+      ysize = DivCeil(ysize, 1 << (3 * dc_level));
+    }
+
+    FrameDimensions frame_dim;
+    frame_dim.Set(xsize, ysize, group_size_shift,
+                  chroma_subsampling.MaxHShift(),
+                  chroma_subsampling.MaxVShift(),
+                  encoding == FrameEncoding::kModular, upsampling);
+    return frame_dim;
+  }
+
+  // True if a color transform should be applied to this frame.
+  bool needs_color_transform() const {
+    return !save_before_color_transform ||
+           frame_type == FrameType::kRegularFrame ||
+           frame_type == FrameType::kSkipProgressive;
+  }
+
+  uint64_t extensions;
+};
+
+Status ReadFrameHeader(BitReader* JXL_RESTRICT reader,
+                       FrameHeader* JXL_RESTRICT frame);
+
+Status WriteFrameHeader(const FrameHeader& frame,
+                        BitWriter* JXL_RESTRICT writer, AuxOut* aux_out);
+
+// Shared by enc/dec. 5F and 13 are by far the most common for d1/2/4/8, 0
+// ensures low overhead for small images.
+static constexpr U32Enc kOrderEnc =
+    U32Enc(Val(0x5F), Val(0x13), Val(0), Bits(kNumOrders));
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_FRAME_HEADER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/gaborish.cc b/third_party/jpeg-xl/lib/jxl/gaborish.cc
new file mode 100644
index 000000000000..6124681b3268
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/gaborish.cc
@@ -0,0 +1,79 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/gaborish.h"
+
+#include <stddef.h>
+
+#include <hwy/base.h>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+
+void GaborishInverse(Image3F* in_out, float mul, ThreadPool* pool) {
+  JXL_ASSERT(mul >= 0.0f);
+
+  // Only an approximation. One or even two 3x3, and rank-1 (separable) 5x5
+  // are insufficient.
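+  // Weight-layout sketch (assuming WeightsSymmetric5's r/d are the distance-1
+  // edge/diagonal taps and R/D/L the outer ring): kGaborish[0] and [1] feed r
+  // and d, [2] and [4] feed R and D, and [3] feeds the eight "knight-move" L
+  // taps - consistent with the 4r+4R+4d+4D+8L normalization sum below.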
+  constexpr float kGaborish[5] = {
+      -0.092359145662814029f, -0.039253623634014627f, 0.016176494530216929f,
+      0.00083458437774987476f, 0.004512465323949319f,
+  };
+  /*
+    better would be:
+    1.0 - mul * (4 * (kGaborish[0] + kGaborish[1] +
+                      kGaborish[2] + kGaborish[4]) +
+                 8 * (kGaborish[3]));
+  */
+  WeightsSymmetric5 weights = {{HWY_REP4(1.0f)},
+                               {HWY_REP4(mul * kGaborish[0])},
+                               {HWY_REP4(mul * kGaborish[2])},
+                               {HWY_REP4(mul * kGaborish[1])},
+                               {HWY_REP4(mul * kGaborish[4])},
+                               {HWY_REP4(mul * kGaborish[3])}};
+  double sum = static_cast<double>(weights.c[0]);
+  sum += 4 * weights.r[0];
+  sum += 4 * weights.R[0];
+  sum += 4 * weights.d[0];
+  sum += 4 * weights.D[0];
+  sum += 8 * weights.L[0];
+  const float normalize = static_cast<float>(1.0 / sum);
+  for (size_t i = 0; i < 4; ++i) {
+    weights.c[i] *= normalize;
+    weights.r[i] *= normalize;
+    weights.R[i] *= normalize;
+    weights.d[i] *= normalize;
+    weights.D[i] *= normalize;
+    weights.L[i] *= normalize;
+  }
+
+  // Reduce memory footprint by only allocating a single plane and swapping it
+  // into the output Image3F. Better still would be tiling.
+  // Note that we cannot *allocate* a plane, as doing so might cause Image3F to
+  // have planes of different stride. Instead, we copy one plane in a temporary
+  // image and reuse the existing planes of the in/out image.
+  ImageF temp = CopyImage(in_out->Plane(2));
+  Symmetric5(in_out->Plane(0), Rect(*in_out), weights, pool,
+             &in_out->Plane(2));
+  Symmetric5(in_out->Plane(1), Rect(*in_out), weights, pool,
+             &in_out->Plane(0));
+  Symmetric5(temp, Rect(*in_out), weights, pool, &in_out->Plane(1));
+  // Now planes are 1, 2, 0.
+  in_out->Plane(0).Swap(in_out->Plane(1));
+  // 2 1 0
+  in_out->Plane(0).Swap(in_out->Plane(2));
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/gaborish.h b/third_party/jpeg-xl/lib/jxl/gaborish.h
new file mode 100644
index 000000000000..349bd1b8b10e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/gaborish.h
@@ -0,0 +1,35 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_GABORISH_H_
+#define LIB_JXL_GABORISH_H_
+
+// Linear smoothing (3x3 convolution) for deblocking without too much blur.
+
+#include <stdint.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+// Used in encoder to reduce the impact of the decoder's smoothing.
+// This is not exact. Works in-place to reduce memory use.
+// The input is typically in XYB space.
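+// Round-trip sketch (mirrors gaborish_test.cc): smoothing with the decoder's
+// 3x3 kernel and then calling GaborishInverse(&img, 0.92718927264540152f,
+// pool) approximately reproduces the original image.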
+void GaborishInverse(Image3F* in_out, float mul, ThreadPool* pool);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_GABORISH_H_
diff --git a/third_party/jpeg-xl/lib/jxl/gaborish_test.cc b/third_party/jpeg-xl/lib/jxl/gaborish_test.cc
new file mode 100644
index 000000000000..dcb767477828
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/gaborish_test.cc
@@ -0,0 +1,80 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/gaborish.h"
+
+#include <hwy/base.h>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/image_test_utils.h"
+
+namespace jxl {
+namespace {
+
+// weight1,2 need not be normalized.
+WeightsSymmetric3 GaborishKernel(float weight1, float weight2) {
+  constexpr float weight0 = 1.0f;
+
+  // Normalize
+  const float mul = 1.0f / (weight0 + 4 * (weight1 + weight2));
+  const float w0 = weight0 * mul;
+  const float w1 = weight1 * mul;
+  const float w2 = weight2 * mul;
+
+  const WeightsSymmetric3 w = {{HWY_REP4(w0)}, {HWY_REP4(w1)}, {HWY_REP4(w2)}};
+  return w;
+}
+
+void ConvolveGaborish(const ImageF& in, float weight1, float weight2,
+                      ThreadPool* pool, ImageF* JXL_RESTRICT out) {
+  JXL_CHECK(SameSize(in, *out));
+  Symmetric3(in, Rect(in), GaborishKernel(weight1, weight2), pool, out);
+}
+
+void TestRoundTrip(const Image3F& in, float max_l1) {
+  Image3F fwd(in.xsize(), in.ysize());
+  ThreadPool* null_pool = nullptr;
+  ConvolveGaborish(in.Plane(0), 0, 0, null_pool, &fwd.Plane(0));
+  ConvolveGaborish(in.Plane(1), 0, 0, null_pool, &fwd.Plane(1));
+  ConvolveGaborish(in.Plane(2), 0, 0, null_pool, &fwd.Plane(2));
+  GaborishInverse(&fwd, 0.92718927264540152f, null_pool);
+  VerifyRelativeError(in, fwd, max_l1, 1E-4f);
+}
+
+TEST(GaborishTest, TestZero) {
+  Image3F in(20, 20);
+  ZeroFillImage(&in);
+  TestRoundTrip(in, 0.0f);
+}
+
+// Disabled: large difference.
+#if 0
+TEST(GaborishTest, TestDirac) {
+  Image3F in(20, 20);
+  ZeroFillImage(&in);
+  in.PlaneRow(1, 10)[10] = 10.0f;
+  TestRoundTrip(in, 0.26f);
+}
+#endif
+
+TEST(GaborishTest, TestFlat) {
+  Image3F in(20, 20);
+  FillImage(1.0f, &in);
+  TestRoundTrip(in, 1E-5f);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/gamma_correct_test.cc b/third_party/jpeg-xl/lib/jxl/gamma_correct_test.cc
new file mode 100644
index 000000000000..d49eb0879210
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/gamma_correct_test.cc
@@ -0,0 +1,46 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+
+#include <cmath>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/enc_gamma_correct.h"
+
+namespace jxl {
+namespace {
+
+TEST(GammaCorrectTest, TestLinearToSrgbEdgeCases) {
+  EXPECT_EQ(0, LinearToSrgb8Direct(0.0));
+  EXPECT_NEAR(0, LinearToSrgb8Direct(1E-6f), 2E-5);
+  EXPECT_EQ(0, LinearToSrgb8Direct(-1E-6f));
+  EXPECT_EQ(0, LinearToSrgb8Direct(-1E6));
+  EXPECT_NEAR(1, LinearToSrgb8Direct(1 - 1E-6f), 1E-5);
+  EXPECT_EQ(1, LinearToSrgb8Direct(1 + 1E-6f));
+  EXPECT_EQ(1, LinearToSrgb8Direct(1E6));
+}
+
+TEST(GammaCorrectTest, TestRoundTrip) {
+  // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter)
+  for (double linear = 0.0; linear <= 1.0; linear += 1E-7) {
+    const double srgb = LinearToSrgb8Direct(linear);
+    const double linear2 = Srgb8ToLinearDirect(srgb);
+    ASSERT_LT(std::abs(linear - linear2), 2E-13)
+        << "linear = " << linear << ", linear2 = " << linear2;
+  }
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/gauss_blur.cc b/third_party/jpeg-xl/lib/jxl/gauss_blur.cc
new file mode 100644
index 000000000000..7e3a9f4e9144
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/gauss_blur.cc
@@ -0,0 +1,625 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/gauss_blur.h"
+
+#include <string.h>
+
+#include <algorithm>
+#include <cmath>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/gauss_blur.cc"
+#include <hwy/cache_control.h>
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/linalg.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::Broadcast;
+#if HWY_TARGET != HWY_SCALAR
+using hwy::HWY_NAMESPACE::ShiftLeftLanes;
+#endif
+using hwy::HWY_NAMESPACE::Vec;
+
+void FastGaussian1D(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                    const float* JXL_RESTRICT in, intptr_t width,
+                    float* JXL_RESTRICT out) {
+  // Although the current output depends on the previous output, we can unroll
+  // up to 4x by precomputing up to fourth powers of the constants. Beyond
+  // that, numerical precision might become a problem. Macro because this is
+  // tested in #if alongside HWY_TARGET.
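+  // Recurrence sketch: each filter k computes
+  //   y[n] = n2*(x[n-N-1] + x[n+N-1]) - d1*y[n-1] - y[n-2]
+  // and the three filters' outputs are summed; the unrolled coefficients for
+  // up to four lanes are precomputed in CreateRecursiveGaussian.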
+#define JXL_GAUSS_MAX_LANES 4
+  using D = HWY_CAPPED(float, JXL_GAUSS_MAX_LANES);
+  using V = Vec<D>;
+  const D d;
+  const V mul_in_1 = Load(d, rg->mul_in + 0 * 4);
+  const V mul_in_3 = Load(d, rg->mul_in + 1 * 4);
+  const V mul_in_5 = Load(d, rg->mul_in + 2 * 4);
+  const V mul_prev_1 = Load(d, rg->mul_prev + 0 * 4);
+  const V mul_prev_3 = Load(d, rg->mul_prev + 1 * 4);
+  const V mul_prev_5 = Load(d, rg->mul_prev + 2 * 4);
+  const V mul_prev2_1 = Load(d, rg->mul_prev2 + 0 * 4);
+  const V mul_prev2_3 = Load(d, rg->mul_prev2 + 1 * 4);
+  const V mul_prev2_5 = Load(d, rg->mul_prev2 + 2 * 4);
+  V prev_1 = Zero(d);
+  V prev_3 = Zero(d);
+  V prev_5 = Zero(d);
+  V prev2_1 = Zero(d);
+  V prev2_3 = Zero(d);
+  V prev2_5 = Zero(d);
+
+  const intptr_t N = rg->radius;
+
+  intptr_t n = -N + 1;
+  // Left side with bounds checks and only write output after n >= 0.
+  const intptr_t first_aligned = RoundUpTo(N + 1, Lanes(d));
+  for (; n < std::min(first_aligned, width); ++n) {
+    const intptr_t left = n - N - 1;
+    const intptr_t right = n + N - 1;
+    const float left_val = left >= 0 ? in[left] : 0.0f;
+    const float right_val = right < width ? in[right] : 0.0f;
+    const V sum = Set(d, left_val + right_val);
+
+    // (Only processing a single lane here, no need to broadcast)
+    V out_1 = sum * mul_in_1;
+    V out_3 = sum * mul_in_3;
+    V out_5 = sum * mul_in_5;
+
+    out_1 = MulAdd(mul_prev2_1, prev2_1, out_1);
+    out_3 = MulAdd(mul_prev2_3, prev2_3, out_3);
+    out_5 = MulAdd(mul_prev2_5, prev2_5, out_5);
+    prev2_1 = prev_1;
+    prev2_3 = prev_3;
+    prev2_5 = prev_5;
+
+    out_1 = MulAdd(mul_prev_1, prev_1, out_1);
+    out_3 = MulAdd(mul_prev_3, prev_3, out_3);
+    out_5 = MulAdd(mul_prev_5, prev_5, out_5);
+    prev_1 = out_1;
+    prev_3 = out_3;
+    prev_5 = out_5;
+
+    if (n >= 0) {
+      out[n] = GetLane(out_1 + out_3 + out_5);
+    }
+  }
+
+  // The above loop is effectively scalar but it is convenient to use the same
+  // prev/prev2 variables, so broadcast to each lane before the unrolled loop.
+#if HWY_TARGET != HWY_SCALAR && JXL_GAUSS_MAX_LANES > 1
+  prev2_1 = Broadcast<0>(prev2_1);
+  prev2_3 = Broadcast<0>(prev2_3);
+  prev2_5 = Broadcast<0>(prev2_5);
+  prev_1 = Broadcast<0>(prev_1);
+  prev_3 = Broadcast<0>(prev_3);
+  prev_5 = Broadcast<0>(prev_5);
+#endif
+
+  // Unrolled, no bounds checking needed.
+  for (; n < width - N + 1 - (JXL_GAUSS_MAX_LANES - 1); n += Lanes(d)) {
+    const V sum = LoadU(d, in + n - N - 1) + LoadU(d, in + n + N - 1);
+
+    // To get a vector of output(s), we multiply broadcasted vectors (of each
+    // input plus the two previous outputs) and add them all together.
+    // Incremental broadcasting and shifting is expected to be cheaper than
+    // horizontal adds or transposing 4x4 values because they run on a
+    // different port, concurrently with the FMA.
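+    // Lane-math sketch: with the shifted multipliers, output lane k receives
+    // sum_{j<=k} mul_in[k-j]*in[j] plus the prev/prev2 contributions, i.e.
+    // the recurrence unrolled four outputs at a time.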
+    const V in0 = Broadcast<0>(sum);
+    V out_1 = in0 * mul_in_1;
+    V out_3 = in0 * mul_in_3;
+    V out_5 = in0 * mul_in_5;
+
+#if HWY_TARGET != HWY_SCALAR && JXL_GAUSS_MAX_LANES >= 2
+    const V in1 = Broadcast<1>(sum);
+    out_1 = MulAdd(ShiftLeftLanes<1>(mul_in_1), in1, out_1);
+    out_3 = MulAdd(ShiftLeftLanes<1>(mul_in_3), in1, out_3);
+    out_5 = MulAdd(ShiftLeftLanes<1>(mul_in_5), in1, out_5);
+
+#if JXL_GAUSS_MAX_LANES >= 4
+    const V in2 = Broadcast<2>(sum);
+    out_1 = MulAdd(ShiftLeftLanes<2>(mul_in_1), in2, out_1);
+    out_3 = MulAdd(ShiftLeftLanes<2>(mul_in_3), in2, out_3);
+    out_5 = MulAdd(ShiftLeftLanes<2>(mul_in_5), in2, out_5);
+
+    const V in3 = Broadcast<3>(sum);
+    out_1 = MulAdd(ShiftLeftLanes<3>(mul_in_1), in3, out_1);
+    out_3 = MulAdd(ShiftLeftLanes<3>(mul_in_3), in3, out_3);
+    out_5 = MulAdd(ShiftLeftLanes<3>(mul_in_5), in3, out_5);
+#endif
+#endif
+
+    out_1 = MulAdd(mul_prev2_1, prev2_1, out_1);
+    out_3 = MulAdd(mul_prev2_3, prev2_3, out_3);
+    out_5 = MulAdd(mul_prev2_5, prev2_5, out_5);
+
+    out_1 = MulAdd(mul_prev_1, prev_1, out_1);
+    out_3 = MulAdd(mul_prev_3, prev_3, out_3);
+    out_5 = MulAdd(mul_prev_5, prev_5, out_5);
+#if HWY_TARGET == HWY_SCALAR || JXL_GAUSS_MAX_LANES == 1
+    prev2_1 = prev_1;
+    prev2_3 = prev_3;
+    prev2_5 = prev_5;
+    prev_1 = out_1;
+    prev_3 = out_3;
+    prev_5 = out_5;
+#else
+    prev2_1 = Broadcast<JXL_GAUSS_MAX_LANES - 2>(out_1);
+    prev2_3 = Broadcast<JXL_GAUSS_MAX_LANES - 2>(out_3);
+    prev2_5 = Broadcast<JXL_GAUSS_MAX_LANES - 2>(out_5);
+    prev_1 = Broadcast<JXL_GAUSS_MAX_LANES - 1>(out_1);
+    prev_3 = Broadcast<JXL_GAUSS_MAX_LANES - 1>(out_3);
+    prev_5 = Broadcast<JXL_GAUSS_MAX_LANES - 1>(out_5);
+#endif
+
+    Store(out_1 + out_3 + out_5, d, out + n);
+  }
+
+  // Remainder handling with bounds checks
+  for (; n < width; ++n) {
+    const intptr_t left = n - N - 1;
+    const intptr_t right = n + N - 1;
+    const float left_val = left >= 0 ? in[left] : 0.0f;
+    const float right_val = right < width ? in[right] : 0.0f;
+    const V sum = Set(d, left_val + right_val);
+
+    // (Only processing a single lane here, no need to broadcast)
+    V out_1 = sum * mul_in_1;
+    V out_3 = sum * mul_in_3;
+    V out_5 = sum * mul_in_5;
+
+    out_1 = MulAdd(mul_prev2_1, prev2_1, out_1);
+    out_3 = MulAdd(mul_prev2_3, prev2_3, out_3);
+    out_5 = MulAdd(mul_prev2_5, prev2_5, out_5);
+    prev2_1 = prev_1;
+    prev2_3 = prev_3;
+    prev2_5 = prev_5;
+
+    out_1 = MulAdd(mul_prev_1, prev_1, out_1);
+    out_3 = MulAdd(mul_prev_3, prev_3, out_3);
+    out_5 = MulAdd(mul_prev_5, prev_5, out_5);
+    prev_1 = out_1;
+    prev_3 = out_3;
+    prev_5 = out_5;
+
+    out[n] = GetLane(out_1 + out_3 + out_5);
+  }
+}
+
+// Ring buffer is for n, n-1, n-2; round up to 4 for faster modulo.
+constexpr size_t kMod = 4;
+
+// Avoids an unnecessary store during warmup.
+struct OutputNone {
+  template <class V>
+  void operator()(const V& /*unused*/, float* JXL_RESTRICT /*pos*/,
+                  ptrdiff_t /*offset*/) const {}
+};
+
+// Common case: write output vectors in all VerticalBlock except warmup.
+struct OutputStore {
+  template <class V>
+  void operator()(const V& out, float* JXL_RESTRICT pos,
+                  ptrdiff_t offset) const {
+    // Stream helps for large images but is slower for images that fit in
+    // cache.
+    Store(out, HWY_FULL(float)(), pos + offset);
+  }
+};
+
+// At top/bottom borders, we don't have two inputs to load, so avoid addition.
+// pos may even point to all zeros if the row is outside the input image.
+class SingleInput {
+ public:
+  explicit SingleInput(const float* pos) : pos_(pos) {}
+  Vec<HWY_FULL(float)> operator()(const size_t offset) const {
+    return Load(HWY_FULL(float)(), pos_ + offset);
+  }
+  const float* pos_;
+};
+
+// In the middle of the image, we need to load from a row above and below, and
+// return the sum.
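+// Both rows are multiplied by the same coefficient (n2_*) in the recurrence,
+// so they can be added before the multiplies.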
+class TwoInputs {
+ public:
+  TwoInputs(const float* pos1, const float* pos2) : pos1_(pos1), pos2_(pos2) {}
+  Vec<HWY_FULL(float)> operator()(const size_t offset) const {
+    const auto in1 = Load(HWY_FULL(float)(), pos1_ + offset);
+    const auto in2 = Load(HWY_FULL(float)(), pos2_ + offset);
+    return in1 + in2;
+  }
+
+ private:
+  const float* pos1_;
+  const float* pos2_;
+};
+
+// Block := kVectors consecutive full vectors (one cache line except on the
+// right boundary, where we can only rely on having one vector). Unrolling to
+// the cache line size improves cache utilization.
+template <size_t kVectors, class V, class Input, class Output>
+void VerticalBlock(const V& d1_1, const V& d1_3, const V& d1_5, const V& n2_1,
+                   const V& n2_3, const V& n2_5, const Input& input,
+                   size_t& ctr, float* ring_buffer, const Output output,
+                   float* JXL_RESTRICT out_pos) {
+  const HWY_FULL(float) d;
+  constexpr size_t kVN = MaxLanes(d);
+  // More cache-friendly to process an entire cache line at a time
+  constexpr size_t kLanes = kVectors * kVN;
+
+  float* JXL_RESTRICT y_1 = ring_buffer + 0 * kLanes * kMod;
+  float* JXL_RESTRICT y_3 = ring_buffer + 1 * kLanes * kMod;
+  float* JXL_RESTRICT y_5 = ring_buffer + 2 * kLanes * kMod;
+
+  const size_t n_0 = (++ctr) % kMod;
+  const size_t n_1 = (ctr - 1) % kMod;
+  const size_t n_2 = (ctr - 2) % kMod;
+
+  for (size_t idx_vec = 0; idx_vec < kVectors; ++idx_vec) {
+    const V sum = input(idx_vec * kVN);
+
+    const V y_n1_1 = Load(d, y_1 + kLanes * n_1 + idx_vec * kVN);
+    const V y_n1_3 = Load(d, y_3 + kLanes * n_1 + idx_vec * kVN);
+    const V y_n1_5 = Load(d, y_5 + kLanes * n_1 + idx_vec * kVN);
+    const V y_n2_1 = Load(d, y_1 + kLanes * n_2 + idx_vec * kVN);
+    const V y_n2_3 = Load(d, y_3 + kLanes * n_2 + idx_vec * kVN);
+    const V y_n2_5 = Load(d, y_5 + kLanes * n_2 + idx_vec * kVN);
+    // (35)
+    const V y1 = MulAdd(n2_1, sum, NegMulSub(d1_1, y_n1_1, y_n2_1));
+    const V y3 = MulAdd(n2_3, sum, NegMulSub(d1_3, y_n1_3, y_n2_3));
+    const V y5 = MulAdd(n2_5, sum, NegMulSub(d1_5, y_n1_5, y_n2_5));
+    Store(y1, d, y_1 + kLanes * n_0 + idx_vec * kVN);
+    Store(y3, d, y_3 + kLanes * n_0 + idx_vec * kVN);
+    Store(y5, d, y_5 + kLanes * n_0 + idx_vec * kVN);
+    output(y1 + y3 + y5, out_pos, idx_vec * kVN);
+  }
+  // NOTE: flushing cache line out_pos hurts performance - less so with
+  // clflushopt than clflush but still a significant slowdown.
+}
+
+// Reads/writes one block (kVectors full vectors) in each row.
+template <size_t kVectors>
+void VerticalStrip(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                   const ImageF& in, const size_t x, ImageF* JXL_RESTRICT out) {
+  // We're iterating vertically, so use multiple full-length vectors (each lane
+  // is one column of row n).
+  using D = HWY_FULL(float);
+  using V = Vec<D>;
+  const D d;
+  constexpr size_t kVN = MaxLanes(d);
+  // More cache-friendly to process an entire cache line at a time
+  constexpr size_t kLanes = kVectors * kVN;
+#if HWY_TARGET == HWY_SCALAR
+  const V d1_1 = Set(d, rg->d1[0 * 4]);
+  const V d1_3 = Set(d, rg->d1[1 * 4]);
+  const V d1_5 = Set(d, rg->d1[2 * 4]);
+  const V n2_1 = Set(d, rg->n2[0 * 4]);
+  const V n2_3 = Set(d, rg->n2[1 * 4]);
+  const V n2_5 = Set(d, rg->n2[2 * 4]);
+#else
+  const V d1_1 = LoadDup128(d, rg->d1 + 0 * 4);
+  const V d1_3 = LoadDup128(d, rg->d1 + 1 * 4);
+  const V d1_5 = LoadDup128(d, rg->d1 + 2 * 4);
+  const V n2_1 = LoadDup128(d, rg->n2 + 0 * 4);
+  const V n2_3 = LoadDup128(d, rg->n2 + 1 * 4);
+  const V n2_5 = LoadDup128(d, rg->n2 + 2 * 4);
+#endif
+
+  const size_t N = rg->radius;
+  const size_t ysize = in.ysize();
+
+  size_t ctr = 0;
+  HWY_ALIGN float ring_buffer[3 * kLanes * kMod] = {0};
+  HWY_ALIGN static constexpr float zero[kLanes] = {0};
+
+  // Warmup: top is out of bounds (zero padded), bottom is usually in-bounds.
+  ssize_t n = -static_cast<ssize_t>(N) + 1;
+  for (; n < 0; ++n) {
+    // bottom is always non-negative since n is initialized to -N + 1.
+    const size_t bottom = n + N - 1;
+    VerticalBlock<kVectors>(
+        d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
+        SingleInput(bottom < ysize ? in.ConstRow(bottom) + x : zero), ctr,
+        ring_buffer, OutputNone(), nullptr);
+  }
+  JXL_DASSERT(n >= 0);
+
+  // Start producing output; top is still out of bounds.
+  for (; static_cast<size_t>(n) < std::min(N + 1, ysize); ++n) {
+    const size_t bottom = n + N - 1;
+    VerticalBlock<kVectors>(
+        d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
+        SingleInput(bottom < ysize ? in.ConstRow(bottom) + x : zero), ctr,
+        ring_buffer, OutputStore(), out->Row(n) + x);
+  }
+
+  // Interior outputs with prefetching and without bounds checks.
+  constexpr size_t kPrefetchRows = 8;
+  for (; n < static_cast<ssize_t>(ysize - N + 1 - kPrefetchRows); ++n) {
+    const size_t top = n - N - 1;
+    const size_t bottom = n + N - 1;
+    VerticalBlock<kVectors>(
+        d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
+        TwoInputs(in.ConstRow(top) + x, in.ConstRow(bottom) + x), ctr,
+        ring_buffer, OutputStore(), out->Row(n) + x);
+    hwy::Prefetch(in.ConstRow(top + kPrefetchRows) + x);
+    hwy::Prefetch(in.ConstRow(bottom + kPrefetchRows) + x);
+  }
+
+  // Bottom border without prefetching and with bounds checks.
+  for (; static_cast<size_t>(n) < ysize; ++n) {
+    const size_t top = n - N - 1;
+    const size_t bottom = n + N - 1;
+    VerticalBlock<kVectors>(
+        d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
+        TwoInputs(in.ConstRow(top) + x,
+                  bottom < ysize ? in.ConstRow(bottom) + x : zero),
+        ctr, ring_buffer, OutputStore(), out->Row(n) + x);
+  }
+}
+
+// Apply 1D vertical scan to multiple columns (one per vector lane).
+// Not yet parallelized.
+void FastGaussianVertical(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                          const ImageF& in, ThreadPool* /*pool*/,
+                          ImageF* JXL_RESTRICT out) {
+  PROFILER_FUNC;
+  JXL_CHECK(SameSize(in, *out));
+
+  constexpr size_t kCacheLineLanes = 64 / sizeof(float);
+  constexpr size_t kVN = MaxLanes(HWY_FULL(float)());
+  constexpr size_t kCacheLineVectors = kCacheLineLanes / kVN;
+
+  size_t x = 0;
+  for (; x + kCacheLineLanes <= in.xsize(); x += kCacheLineLanes) {
+    VerticalStrip<kCacheLineVectors>(rg, in, x, out);
+  }
+  for (; x < in.xsize(); x += kVN) {
+    VerticalStrip<1>(rg, in, x, out);
+  }
+}
+
+// TODO(veluca): consider replacing with FastGaussian.
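+// Role sketch: convolving along x, subsampling by `res` and writing the
+// transpose means two applications yield a full (separable) 2D blur - see
+// ConvolveAndSample further down.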
+ImageF ConvolveXSampleAndTranspose(const ImageF& in,
+                                   const std::vector<float>& kernel,
+                                   const size_t res) {
+  JXL_ASSERT(kernel.size() % 2 == 1);
+  JXL_ASSERT(in.xsize() % res == 0);
+  const size_t offset = res / 2;
+  const size_t out_xsize = in.xsize() / res;
+  ImageF out(in.ysize(), out_xsize);
+  const int r = kernel.size() / 2;
+  HWY_FULL(float) df;
+  std::vector<float> row_tmp(in.xsize() + 2 * r + Lanes(df));
+  float* const JXL_RESTRICT rowp = &row_tmp[r];
+  std::vector<float> padded_k = kernel;
+  padded_k.resize(padded_k.size() + Lanes(df));
+  const float* const kernelp = &padded_k[r];
+  for (size_t y = 0; y < in.ysize(); ++y) {
+    ExtrapolateBorders(in.Row(y), rowp, in.xsize(), r);
+    size_t x = offset, ox = 0;
+    for (; x < static_cast<size_t>(r) && x < in.xsize(); x += res, ++ox) {
+      float sum = 0.0f;
+      for (int i = -r; i <= r; ++i) {
+        sum += rowp[std::max<int>(
+                   0, std::min<int>(static_cast<int>(x) + i, in.xsize()))] *
+               kernelp[i];
+      }
+      out.Row(ox)[y] = sum;
+    }
+    for (; x + r < in.xsize(); x += res, ++ox) {
+      auto sum = Zero(df);
+      for (int i = -r; i <= r; i += Lanes(df)) {
+        sum = MulAdd(LoadU(df, rowp + x + i), LoadU(df, kernelp + i), sum);
+      }
+      out.Row(ox)[y] = GetLane(SumOfLanes(sum));
+    }
+    for (; x < in.xsize(); x += res, ++ox) {
+      float sum = 0.0f;
+      for (int i = -r; i <= r; ++i) {
+        sum += rowp[std::max<int>(
+                   0, std::min<int>(static_cast<int>(x) + i, in.xsize()))] *
+               kernelp[i];
+      }
+      out.Row(ox)[y] = sum;
+    }
+  }
+  return out;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(FastGaussian1D);
+HWY_EXPORT(ConvolveXSampleAndTranspose);
+void FastGaussian1D(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                    const float* JXL_RESTRICT in, intptr_t width,
+                    float* JXL_RESTRICT out) {
+  return HWY_DYNAMIC_DISPATCH(FastGaussian1D)(rg, in, width, out);
+}
+
+HWY_EXPORT(FastGaussianVertical);  // Local function.
+
+void ExtrapolateBorders(const float* const JXL_RESTRICT row_in,
+                        float* const JXL_RESTRICT row_out, const int xsize,
+                        const int radius) {
+  const int lastcol = xsize - 1;
+  for (int x = 1; x <= radius; ++x) {
+    row_out[-x] = row_in[std::min(x, xsize - 1)];
+  }
+  memcpy(row_out, row_in, xsize * sizeof(row_out[0]));
+  for (int x = 1; x <= radius; ++x) {
+    row_out[lastcol + x] = row_in[std::max(0, lastcol - x)];
+  }
+}
+
+ImageF ConvolveXSampleAndTranspose(const ImageF& in,
+                                   const std::vector<float>& kernel,
+                                   const size_t res) {
+  return HWY_DYNAMIC_DISPATCH(ConvolveXSampleAndTranspose)(in, kernel, res);
+}
+
+Image3F ConvolveXSampleAndTranspose(const Image3F& in,
+                                    const std::vector<float>& kernel,
+                                    const size_t res) {
+  return Image3F(ConvolveXSampleAndTranspose(in.Plane(0), kernel, res),
+                 ConvolveXSampleAndTranspose(in.Plane(1), kernel, res),
+                 ConvolveXSampleAndTranspose(in.Plane(2), kernel, res));
+}
+
+ImageF ConvolveAndSample(const ImageF& in, const std::vector<float>& kernel,
+                         const size_t res) {
+  ImageF tmp = ConvolveXSampleAndTranspose(in, kernel, res);
+  return ConvolveXSampleAndTranspose(tmp, kernel, res);
+}
+
+// Implements "Recursive Implementation of the Gaussian Filter Using Truncated
+// Cosine Functions" by Charalampidis [2016].
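+// Method sketch: the Gaussian impulse response is approximated by the sum of
+// three second-order recursive filters derived from truncated cosines with
+// frequencies omega_1, omega_3, omega_5; the equation numbers in the comments
+// below refer to that paper.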
+hwy::AlignedUniquePtr<RecursiveGaussian> CreateRecursiveGaussian(
+    double sigma) {
+  PROFILER_FUNC;
+  auto rg = hwy::MakeUniqueAligned<RecursiveGaussian>();
+  constexpr double kPi = 3.141592653589793238;
+
+  const double radius = roundf(3.2795 * sigma + 0.2546);  // (57), "N"
+
+  // Table I, first row
+  const double pi_div_2r = kPi / (2.0 * radius);
+  const double omega[3] = {pi_div_2r, 3.0 * pi_div_2r, 5.0 * pi_div_2r};
+
+  // (37), k={1,3,5}
+  const double p_1 = +1.0 / std::tan(0.5 * omega[0]);
+  const double p_3 = -1.0 / std::tan(0.5 * omega[1]);
+  const double p_5 = +1.0 / std::tan(0.5 * omega[2]);
+
+  // (44), k={1,3,5}
+  const double r_1 = +p_1 * p_1 / std::sin(omega[0]);
+  const double r_3 = -p_3 * p_3 / std::sin(omega[1]);
+  const double r_5 = +p_5 * p_5 / std::sin(omega[2]);
+
+  // (50), k={1,3,5}
+  const double neg_half_sigma2 = -0.5 * sigma * sigma;
+  const double recip_radius = 1.0 / radius;
+  double rho[3];
+  for (size_t i = 0; i < 3; ++i) {
+    rho[i] = std::exp(neg_half_sigma2 * omega[i] * omega[i]) * recip_radius;
+  }
+
+  // second part of (52), k1,k2 = 1,3; 3,5; 5,1
+  const double D_13 = p_1 * r_3 - r_1 * p_3;
+  const double D_35 = p_3 * r_5 - r_3 * p_5;
+  const double D_51 = p_5 * r_1 - r_5 * p_1;
+
+  // (52), k=5
+  const double recip_d13 = 1.0 / D_13;
+  const double zeta_15 = D_35 * recip_d13;
+  const double zeta_35 = D_51 * recip_d13;
+
+  double A[9] = {p_1,     p_3,     p_5,  //
+                 r_1,     r_3,     r_5,  // (56)
+                 zeta_15, zeta_35, 1};
+  Inv3x3Matrix(A);
+  const double gamma[3] = {1, radius * radius - sigma * sigma,  // (55)
+                           zeta_15 * rho[0] + zeta_35 * rho[1] + rho[2]};
+  double beta[3];
+  MatMul(A, gamma, 3, 3, 1, beta);  // (53)
+
+  // Sanity check: correctly solved for beta (IIR filter weights are
+  // normalized)
+  const double sum = beta[0] * p_1 + beta[1] * p_3 + beta[2] * p_5;  // (39)
+  JXL_ASSERT(std::abs(sum - 1) < 1E-12);
+  (void)sum;
+
+  rg->radius = static_cast<size_t>(radius);
+
+  double n2[3];
+  double d1[3];
+  for (size_t i = 0; i < 3; ++i) {
+    n2[i] = -beta[i] * std::cos(omega[i] * (radius + 1.0));  // (33)
+    d1[i] = -2.0 * std::cos(omega[i]);                       // (33)
+
+    for (size_t lane = 0; lane < 4; ++lane) {
+      rg->n2[4 * i + lane] = static_cast<float>(n2[i]);
+      rg->d1[4 * i + lane] = static_cast<float>(d1[i]);
+    }
+
+    const double d_2 = d1[i] * d1[i];
+
+    // Obtained by expanding (35) for four consecutive outputs via sympy:
+    // n, d, p, pp = symbols('n d p pp')
+    // i0, i1, i2, i3 = symbols('i0 i1 i2 i3')
+    // o0, o1, o2, o3 = symbols('o0 o1 o2 o3')
+    // o0 = n*i0 - d*p - pp
+    // o1 = n*i1 - d*o0 - p
+    // o2 = n*i2 - d*o1 - o0
+    // o3 = n*i3 - d*o2 - o1
+    // Then expand(o3) and gather terms for p(prev), pp(prev2) etc.
+    rg->mul_prev[4 * i + 0] = -d1[i];
+    rg->mul_prev[4 * i + 1] = d_2 - 1.0;
+    rg->mul_prev[4 * i + 2] = -d_2 * d1[i] + 2.0 * d1[i];
+    rg->mul_prev[4 * i + 3] = d_2 * d_2 - 3.0 * d_2 + 1.0;
+    rg->mul_prev2[4 * i + 0] = -1.0;
+    rg->mul_prev2[4 * i + 1] = d1[i];
+    rg->mul_prev2[4 * i + 2] = -d_2 + 1.0;
+    rg->mul_prev2[4 * i + 3] = d_2 * d1[i] - 2.0 * d1[i];
+    rg->mul_in[4 * i + 0] = n2[i];
+    rg->mul_in[4 * i + 1] = -d1[i] * n2[i];
+    rg->mul_in[4 * i + 2] = d_2 * n2[i] - n2[i];
+    rg->mul_in[4 * i + 3] = -d_2 * d1[i] * n2[i] + 2.0 * d1[i] * n2[i];
+  }
+  return rg;
+}
+
+namespace {
+
+// Apply 1D horizontal scan to each row.
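+// (Horizontal pass of the separable blur; FastGaussian below combines it
+// with the vertical pass.)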
+void FastGaussianHorizontal(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                            const ImageF& in, ThreadPool* pool,
+                            ImageF* JXL_RESTRICT out) {
+  PROFILER_FUNC;
+  JXL_CHECK(SameSize(in, *out));
+
+  const intptr_t xsize = in.xsize();
+  RunOnPool(
+      pool, 0, in.ysize(), ThreadPool::SkipInit(),
+      [&](const int task, const int /*thread*/) {
+        const size_t y = task;
+        const float* row_in = in.ConstRow(y);
+        float* JXL_RESTRICT row_out = out->Row(y);
+        FastGaussian1D(rg, row_in, xsize, row_out);
+      },
+      "FastGaussianHorizontal");
+}
+
+}  // namespace
+
+void FastGaussian(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                  const ImageF& in, ThreadPool* pool, ImageF* JXL_RESTRICT temp,
+                  ImageF* JXL_RESTRICT out) {
+  FastGaussianHorizontal(rg, in, pool, temp);
+  HWY_DYNAMIC_DISPATCH(FastGaussianVertical)(rg, *temp, pool, out);
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/gauss_blur.h b/third_party/jpeg-xl/lib/jxl/gauss_blur.h
new file mode 100644
index 000000000000..d5d111cc6423
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/gauss_blur.h
@@ -0,0 +1,103 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_GAUSS_BLUR_H_
+#define LIB_JXL_GAUSS_BLUR_H_
+
+#include <stddef.h>
+
+#include <cmath>
+#include <hwy/aligned_allocator.h>
+#include <vector>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+template <typename T>
+std::vector<T> GaussianKernel(int radius, T sigma) {
+  JXL_ASSERT(sigma > 0.0);
+  std::vector<T> kernel(2 * radius + 1);
+  const T scaler = -1.0 / (2 * sigma * sigma);
+  double sum = 0.0;
+  for (int i = -radius; i <= radius; ++i) {
+    const T val = std::exp(scaler * i * i);
+    kernel[i + radius] = val;
+    sum += val;
+  }
+  for (size_t i = 0; i < kernel.size(); ++i) {
+    kernel[i] /= sum;
+  }
+  return kernel;
+}
+
+// All convolution functions below apply mirroring of the input on the borders
+// in the following way:
+//
+//   input: [a0 a1 a2 ... aN]
+//   mirrored input: [aR ... a1 | a0 a1 a2 .... aN | aN-1 ... aN-R]
+//
+// where R is the radius of the kernel (i.e. kernel size is 2*R+1).
+
+// REQUIRES: in.xsize() and in.ysize() are integer multiples of res.
+ImageF ConvolveAndSample(const ImageF& in, const std::vector<float>& kernel,
+                         const size_t res);
+
+// Private, used by test.
+void ExtrapolateBorders(const float* const JXL_RESTRICT row_in,
+                        float* const JXL_RESTRICT row_out, const int xsize,
+                        const int radius);
+
+// Only for use by CreateRecursiveGaussian and FastGaussian*.
+#pragma pack(push, 1)
+struct RecursiveGaussian {
+  // For k={1,3,5} in that order, each broadcasted 4x for LoadDup128. Used only
+  // for vertical passes.
+  float n2[3 * 4];
+  float d1[3 * 4];
+
+  // We unroll horizontal passes 4x - one output per lane. These are each
+  // lane's multiplier for the previous output (relative to the first of the
+  // four outputs). Indexing: 4 * 0..2 (for {1,3,5}) + 0..3 for the lane index.
+  float mul_prev[3 * 4];
+  // Ditto for the second to last output.
+  float mul_prev2[3 * 4];
+
+  // We multiply a vector of inputs 0..3 by a vector shifted from this array.
+  // in=0 uses all 4 (nonzero) terms; for in=3, the lower three lanes are 0.
+  float mul_in[3 * 4];
+
+  size_t radius;
+};
+#pragma pack(pop)
+
+// Precomputation for FastGaussian*; users may use the same pointer/storage in
+// subsequent calls to FastGaussian* with the same sigma.
+hwy::AlignedUniquePtr<RecursiveGaussian> CreateRecursiveGaussian(double sigma);
+
+// 1D Gaussian with zero-pad boundary handling and runtime independent of
+// sigma.
+void FastGaussian1D(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                    const float* JXL_RESTRICT in, intptr_t width,
+                    float* JXL_RESTRICT out);
+
+// 2D Gaussian with zero-pad boundary handling and runtime independent of
+// sigma.
+void FastGaussian(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
+                  const ImageF& in, ThreadPool* pool, ImageF* JXL_RESTRICT temp,
+                  ImageF* JXL_RESTRICT out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_GAUSS_BLUR_H_
diff --git a/third_party/jpeg-xl/lib/jxl/gauss_blur_test.cc b/third_party/jpeg-xl/lib/jxl/gauss_blur_test.cc
new file mode 100644
index 000000000000..64eda7ecd39c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/gauss_blur_test.cc
@@ -0,0 +1,619 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/gauss_blur.h"
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/base/robust_statistics.h"
+#include "lib/jxl/base/time.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/image_test_utils.h"
+
+namespace jxl {
+
+bool NearEdge(const int64_t width, const int64_t peak) {
+  // When around 3*sigma from the edge, there is negligible truncation.
+  return peak < 10 || peak > width - 10;
+}
+
+// Follow the curve downwards by scanning right from `peak` and verifying
+// identical values at the same offset to the left.
+void VerifySymmetric(const int64_t width, const int64_t peak,
+                     const float* out) {
+  const double tolerance = NearEdge(width, peak) ? 0.015 : 6E-7;
+  for (int64_t i = 1;; ++i) {
+    // Stop if we passed either end of the array
+    if (peak - i < 0 || peak + i >= width) break;
+    EXPECT_GT(out[peak + i - 1] + tolerance, out[peak + i]);  // descending
+    EXPECT_NEAR(out[peak - i], out[peak + i], tolerance);     // symmetric
+  }
+}
+
+void TestImpulseResponse(size_t width, size_t peak) {
+  const auto rg3 = CreateRecursiveGaussian(3.0);
+  const auto rg4 = CreateRecursiveGaussian(4.0);
+  const auto rg5 = CreateRecursiveGaussian(5.0);
+
+  // Extra padding for 4x unrolling
+  auto in = hwy::AllocateAligned<float>(width + 3);
+  memset(in.get(), 0, sizeof(float) * (width + 3));
+  in[peak] = 1.0f;
+
+  auto out3 = hwy::AllocateAligned<float>(width + 3);
+  auto out4 = hwy::AllocateAligned<float>(width + 3);
+  auto out5 = hwy::AllocateAligned<float>(width + 3);
+  FastGaussian1D(rg3, in.get(), width, out3.get());
+  FastGaussian1D(rg4, out3.get(), width, out4.get());
+  FastGaussian1D(rg5, in.get(), width, out5.get());
+
+  VerifySymmetric(width, peak, out3.get());
+  VerifySymmetric(width, peak, out4.get());
+  VerifySymmetric(width, peak, out5.get());
+
+  // Wider kernel has flatter peak
+  EXPECT_LT(out5[peak] + 0.05, out3[peak]);
+
+  // Gauss3 o Gauss4 ~= Gauss5
+  const double tolerance = NearEdge(width, peak) ? 0.04 : 0.01;
+  for (size_t i = 0; i < width; ++i) {
+    EXPECT_NEAR(out4[i], out5[i], tolerance);
+  }
+}
+
+void TestImpulseResponseForWidth(size_t width) {
+  for (size_t i = 0; i < width; ++i) {
+    TestImpulseResponse(width, i);
+  }
+}
+
+TEST(GaussBlurTest, ImpulseResponse) {
+  TestImpulseResponseForWidth(10);  // tiny even
+  TestImpulseResponseForWidth(15);  // small odd
+  TestImpulseResponseForWidth(32);  // power of two
+  TestImpulseResponseForWidth(31);  // power of two - 1
+  TestImpulseResponseForWidth(33);  // power of two + 1
+}
+
+ImageF Convolve(const ImageF& in, const std::vector<float>& kernel) {
+  return ConvolveAndSample(in, kernel, 1);
+}
+
+// Higher-precision version for accuracy test.
+ImageF ConvolveAndTransposeF64(const ImageF& in,
+                               const std::vector<double>& kernel) {
+  JXL_ASSERT(kernel.size() % 2 == 1);
+  ImageF out(in.ysize(), in.xsize());
+  const int r = kernel.size() / 2;
+  std::vector<float> row_tmp(in.xsize() + 2 * r);
+  float* const JXL_RESTRICT rowp = &row_tmp[r];
+  const double* const kernelp = &kernel[r];
+  for (size_t y = 0; y < in.ysize(); ++y) {
+    ExtrapolateBorders(in.Row(y), rowp, in.xsize(), r);
+    for (size_t x = 0, ox = 0; x < in.xsize(); ++x, ++ox) {
+      double sum = 0.0;
+      for (int i = -r; i <= r; ++i) {
+        sum += rowp[std::max<int>(
+                   0, std::min<int>(static_cast<int>(x) + i, in.xsize()))] *
+               kernelp[i];
+      }
+      out.Row(ox)[y] = static_cast<float>(sum);
+    }
+  }
+  return out;
+}
+
+ImageF ConvolveF64(const ImageF& in, const std::vector<double>& kernel) {
+  ImageF tmp = ConvolveAndTransposeF64(in, kernel);
+  return ConvolveAndTransposeF64(tmp, kernel);
+}
+
+void TestDirac2D(size_t xsize, size_t ysize, double sigma) {
+  ImageF in(xsize, ysize);
+  ZeroFillImage(&in);
+  // We anyway ignore the border below, so might as well choose the middle.
+  in.Row(ysize / 2)[xsize / 2] = 1.0f;
+
+  ImageF temp(xsize, ysize);
+  ImageF out(xsize, ysize);
+  const auto rg = CreateRecursiveGaussian(sigma);
+  ThreadPool* null_pool = nullptr;
+  FastGaussian(rg, in, null_pool, &temp, &out);
+
+  const std::vector<float> kernel =
+      GaussianKernel(static_cast<int>(4 * sigma), static_cast<float>(sigma));
+  const ImageF expected = Convolve(in, kernel);
+
+  const double max_l1 = sigma < 1.5 ? 5E-3 : 6E-4;
+  const size_t border = 2 * sigma;
+  VerifyRelativeError(expected, out, max_l1, 1E-8, border);
+}
+
+TEST(GaussBlurTest, Test2D) {
+  const std::vector<int> dimensions{6, 15, 17, 64, 50, 49};
+  for (int xsize : dimensions) {
+    for (int ysize : dimensions) {
+      for (double sigma : {1.0, 2.5, 3.6, 7.0}) {
+        TestDirac2D(static_cast<size_t>(xsize), static_cast<size_t>(ysize),
+                    sigma);
+      }
+    }
+  }
+}
+
+// Slow (44 sec). To run, remove the disabled prefix.
+TEST(GaussBlurTest, DISABLED_SlowTestDirac1D) {
+  const double sigma = 7.0;
+  const auto rg = CreateRecursiveGaussian(sigma);
+
+  // IPOL accuracy test uses 10^-15 tolerance, this is 2*10^-11.
+  const size_t radius = static_cast<size_t>(7 * sigma);
+  const std::vector<double> kernel = GaussianKernel(radius, sigma);
+
+  const size_t length = 16384;
+  ImageF inputs(length, 1);
+  ZeroFillImage(&inputs);
+
+  auto outputs = hwy::AllocateAligned<float>(length);
+
+  // One per center position
+  auto sum_abs_err = hwy::AllocateAligned<double>(length);
+  std::fill(sum_abs_err.get(), sum_abs_err.get() + length, 0.0);
+
+  for (size_t center = radius; center < length - radius; ++center) {
+    inputs.Row(0)[center - 1] = 0.0f;  // reset last peak, entire array now 0
+    inputs.Row(0)[center] = 1.0f;
+    FastGaussian1D(rg, inputs.Row(0), length, outputs.get());
+
+    const ImageF outputs_fir = ConvolveF64(inputs, kernel);
+
+    for (size_t i = 0; i < length; ++i) {
+      const float abs_err = std::abs(outputs[i] - outputs_fir.Row(0)[i]);
+      sum_abs_err[i] += static_cast<double>(abs_err);
+    }
+  }
+
+  const double max_abs_err =
+      *std::max_element(sum_abs_err.get(), sum_abs_err.get() + length);
+  printf("Max abs err: %.8e\n", max_abs_err);
+}
+
+void TestRandom(size_t xsize, size_t ysize, float min, float max, double sigma,
+                double max_l1, double max_rel) {
+  printf("%4zu x %4zu %4.1f %4.1f sigma %.1f\n", xsize, ysize, min, max,
+         sigma);
+  ImageF in(xsize, ysize);
+  RandomFillImage(&in, min, max, 65537 + xsize * 129 + ysize);
+  // FastGaussian/Convolve handle borders differently, so keep those pixels 0.
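+  // (FastGaussian zero-pads while Convolve mirrors; past ~4*sigma the kernel
+  // tail is negligible, so zeroing a 4*sigma-wide border makes both agree.)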
+ const size_t border = 4 * sigma; + SetBorder(border, 0.0f, &in); + + ImageF temp(xsize, ysize); + ImageF out(xsize, ysize); + const auto rg = CreateRecursiveGaussian(sigma); + ThreadPool* null_pool = nullptr; + FastGaussian(rg, in, null_pool, &temp, &out); + + const std::vector kernel = + GaussianKernel(static_cast(4 * sigma), static_cast(sigma)); + const ImageF expected = Convolve(in, kernel); + + VerifyRelativeError(expected, out, max_l1, max_rel, border); +} + +void TestRandomForSizes(float min, float max, double sigma) { + double max_l1 = 5E-3; + double max_rel = 3E-3; + TestRandom(128, 1, min, max, sigma, max_l1, max_rel); + TestRandom(1, 128, min, max, sigma, max_l1, max_rel); + TestRandom(30, 201, min, max, sigma, max_l1 * 1.6, max_rel * 1.2); + TestRandom(201, 30, min, max, sigma, max_l1 * 1.6, max_rel * 1.2); + TestRandom(201, 201, min, max, sigma, max_l1 * 2.0, max_rel * 1.2); +} + +TEST(GaussBlurTest, TestRandom) { + // small non-negative + TestRandomForSizes(0.0f, 10.0f, 3.0f); + TestRandomForSizes(0.0f, 10.0f, 7.0f); + + // small negative + TestRandomForSizes(-4.0f, -1.0f, 3.0f); + TestRandomForSizes(-4.0f, -1.0f, 7.0f); + + // mixed positive/negative + TestRandomForSizes(-6.0f, 6.0f, 3.0f); + TestRandomForSizes(-6.0f, 6.0f, 7.0f); +} + +TEST(GaussBlurTest, TestSign) { + const size_t xsize = 500; + const size_t ysize = 606; + ImageF in(xsize, ysize); + + ZeroFillImage(&in); + const float center[33 * 33] = { + -0.128445f, -0.098473f, -0.121883f, -0.093601f, 0.095665f, -0.271332f, + -0.705475f, -1.324005f, -2.020741f, -1.329464f, 1.834064f, 4.787300f, + 5.834560f, 5.272720f, 3.967960f, 3.547935f, 3.432732f, 3.383015f, + 3.239326f, 3.290806f, 3.298954f, 3.397808f, 3.359730f, 3.533844f, + 3.511856f, 3.436787f, 3.428310f, 3.460209f, 3.550011f, 3.590942f, + 3.593109f, 3.560005f, 3.443165f, 0.089741f, 0.179230f, -0.032997f, + -0.182610f, 0.005669f, -0.244759f, -0.395123f, -0.514961f, -1.003529f, + -1.798656f, -2.377975f, 0.222191f, 3.957664f, 5.946804f, 5.543129f, + 4.290096f, 3.621010f, 3.407257f, 3.392494f, 3.345367f, 3.391903f, + 3.441605f, 3.429260f, 3.444969f, 3.507130f, 3.518612f, 3.443111f, + 3.475948f, 3.536148f, 3.470333f, 3.628311f, 3.600243f, 3.292892f, + -0.226730f, -0.573616f, -0.762165f, -0.398739f, -0.189842f, -0.275921f, + -0.446739f, -0.550037f, -0.461033f, -0.724792f, -1.448349f, -1.814064f, + -0.491032f, 2.817703f, 5.213242f, 5.675629f, 4.864548f, 3.876324f, + 3.535587f, 3.530312f, 3.413765f, 3.386261f, 3.404854f, 3.383472f, + 3.420830f, 3.326496f, 3.257877f, 3.362152f, 3.489609f, 3.619587f, + 3.555805f, 3.423164f, 3.309708f, -0.483940f, -0.502926f, -0.592983f, + -0.492527f, -0.413616f, -0.482555f, -0.475506f, -0.447990f, -0.338120f, + -0.189072f, -0.376427f, -0.910828f, -1.878044f, -1.937927f, 1.423218f, + 4.871609f, 5.767548f, 5.103741f, 3.983868f, 3.633003f, 3.458263f, + 3.507309f, 3.247021f, 3.220612f, 3.326061f, 3.352814f, 3.291061f, + 3.322739f, 3.444302f, 3.506207f, 3.556839f, 3.529575f, 3.457024f, + -0.408161f, -0.431343f, -0.454369f, -0.356419f, -0.380924f, -0.399452f, + -0.439476f, -0.412189f, -0.306816f, -0.008213f, -0.325813f, -0.537842f, + -0.984100f, -1.805332f, -2.028198f, 0.773205f, 4.423046f, 5.604839f, + 5.231617f, 4.080299f, 3.603008f, 3.498741f, 3.517010f, 3.333897f, + 3.381336f, 3.342617f, 3.369686f, 3.434155f, 3.490452f, 3.607029f, + 3.555298f, 3.702297f, 3.618679f, -0.503609f, -0.578564f, -0.419014f, + -0.239883f, 0.269836f, 0.022984f, -0.455067f, -0.621777f, -0.304176f, + -0.163792f, -0.490250f, -0.466637f, -0.391792f, -0.657940f, 
-1.498035f, + -1.895836f, 0.036537f, 3.462456f, 5.586445f, 5.658791f, 4.434784f, + 3.423435f, 3.318848f, 3.202328f, 3.532764f, 3.436687f, 3.354881f, + 3.356941f, 3.382645f, 3.503902f, 3.512867f, 3.632366f, 3.537312f, + -0.274734f, -0.658829f, -0.726532f, -0.281254f, 0.053196f, -0.064991f, + -0.608517f, -0.720966f, -0.070602f, -0.111320f, -0.440956f, -0.492180f, + -0.488762f, -0.569283f, -1.012741f, -1.582779f, -2.101479f, -1.392380f, + 2.451153f, 5.555855f, 6.096313f, 5.230045f, 4.068172f, 3.404274f, + 3.392586f, 3.326065f, 3.156670f, 3.284828f, 3.347012f, 3.319252f, + 3.352310f, 3.610790f, 3.499847f, -0.150600f, -0.314445f, -0.093575f, + -0.057384f, 0.053688f, -0.189255f, -0.263515f, -0.318653f, 0.053246f, + 0.080627f, -0.119553f, -0.152454f, -0.305420f, -0.404869f, -0.385944f, + -0.689949f, -1.204914f, -1.985748f, -1.711361f, 1.260658f, 4.626896f, + 5.888351f, 5.450989f, 4.070587f, 3.539200f, 3.383492f, 3.296318f, + 3.267334f, 3.436028f, 3.463005f, 3.502625f, 3.522282f, 3.403763f, + -0.348049f, -0.302303f, -0.137016f, -0.041737f, -0.164001f, -0.358849f, + -0.469627f, -0.428291f, -0.375797f, -0.246346f, -0.118950f, -0.084229f, + -0.205681f, -0.241199f, -0.391796f, -0.323151f, -0.241211f, -0.834137f, + -1.684219f, -1.972137f, 0.448399f, 4.019985f, 5.648144f, 5.647846f, + 4.295094f, 3.641884f, 3.374790f, 3.197342f, 3.425545f, 3.507481f, + 3.478065f, 3.430889f, 3.341900f, -1.016304f, -0.959221f, -0.909466f, + -0.810715f, -0.590729f, -0.594467f, -0.646721f, -0.629364f, -0.528561f, + -0.551819f, -0.301086f, -0.149101f, -0.060146f, -0.162220f, -0.326210f, + -0.156548f, -0.036293f, -0.426098f, -1.145470f, -1.628998f, -2.003052f, + -1.142891f, 2.885162f, 5.652863f, 5.718426f, 4.911140f, 3.234222f, + 3.473373f, 3.577183f, 3.271603f, 3.410435f, 3.505489f, 3.434032f, + -0.508911f, -0.438797f, -0.437450f, -0.627426f, -0.511745f, -0.304874f, + -0.274246f, -0.261841f, -0.228466f, -0.342491f, -0.528206f, -0.490082f, + -0.516350f, -0.361694f, -0.398514f, -0.276020f, -0.210369f, -0.355938f, + -0.402622f, -0.538864f, -1.249573f, -2.100105f, -0.996178f, 1.886410f, + 4.929745f, 5.630871f, 5.444199f, 4.042740f, 3.739189f, 3.691399f, + 3.391956f, 3.469696f, 3.431232f, 0.204849f, 0.205433f, -0.131927f, + -0.367908f, -0.374378f, -0.126820f, -0.186951f, -0.228565f, -0.081776f, + -0.143143f, -0.379230f, -0.598701f, -0.458019f, -0.295586f, -0.407730f, + -0.245853f, -0.043140f, 0.024242f, -0.038998f, -0.044151f, -0.425991f, + -1.240753f, -1.943146f, -2.174755f, 0.523415f, 4.376751f, 5.956558f, + 5.850082f, 4.403152f, 3.517399f, 3.560753f, 3.554836f, 3.471985f, + -0.508503f, -0.109783f, 0.057747f, 0.190079f, -0.257153f, -0.591980f, + -0.666771f, -0.525391f, -0.293060f, -0.489731f, -0.304855f, -0.259644f, + -0.367825f, -0.346977f, -0.292889f, -0.215652f, -0.120705f, -0.176010f, + -0.422905f, -0.114647f, -0.289749f, -0.374203f, -0.606754f, -1.127949f, + -1.994583f, -0.588058f, 3.415840f, 5.603470f, 5.811581f, 4.959423f, + 3.721760f, 3.710499f, 3.785461f, -0.554588f, -0.565517f, -0.434578f, + -0.012482f, -0.284660f, -0.699795f, -0.957535f, -0.755135f, -0.382034f, + -0.321552f, -0.287571f, -0.279537f, -0.314972f, -0.256287f, -0.372818f, + -0.316017f, -0.287975f, -0.365639f, -0.512589f, -0.420692f, -0.436485f, + -0.295353f, -0.451958f, -0.755459f, -1.272358f, -2.301353f, -1.776161f, + 1.572483f, 4.826286f, 5.741898f, 5.162853f, 4.028049f, 3.686325f, + -0.495590f, -0.664413f, -0.760044f, -0.152634f, -0.286480f, -0.340462f, + 0.076477f, 0.187706f, -0.068787f, -0.293491f, -0.361145f, -0.292515f, + -0.140671f, -0.190723f, 
-0.333302f, -0.368168f, -0.192581f, -0.154499f, + -0.236544f, -0.124405f, -0.208321f, -0.465607f, -0.883080f, -1.104813f, + -1.210567f, -1.415665f, -1.924683f, -1.634758f, 0.601017f, 4.276672f, + 5.501350f, 5.331257f, 3.809288f, -0.727722f, -0.533619f, -0.511524f, + -0.470688f, -0.610710f, -0.575130f, -0.311115f, -0.090420f, -0.297676f, + -0.646118f, -0.742805f, -0.485050f, -0.330910f, -0.275417f, -0.357037f, + -0.425598f, -0.481876f, -0.488941f, -0.393551f, -0.051105f, -0.090755f, + -0.328674f, -0.536369f, -0.533684f, -0.336960f, -0.689194f, -1.187195f, + -1.860954f, -2.290253f, -0.424774f, 3.050060f, 5.083332f, 5.291920f, + -0.343605f, -0.190975f, -0.303692f, -0.456512f, -0.681820f, -0.690693f, + -0.416729f, -0.286446f, -0.442055f, -0.709148f, -0.569160f, -0.382423f, + -0.402321f, -0.383362f, -0.366413f, -0.290718f, -0.110069f, -0.220280f, + -0.279018f, -0.255424f, -0.262081f, -0.487556f, -0.444492f, -0.250500f, + -0.119583f, -0.291557f, -0.537781f, -1.104073f, -1.737091f, -1.697441f, + -0.323456f, 2.042049f, 4.605103f, -0.310631f, -0.279568f, -0.012695f, + -0.160130f, -0.358746f, -0.421101f, -0.559677f, -0.474136f, -0.416565f, + -0.561817f, -0.534672f, -0.519157f, -0.767197f, -0.605831f, -0.186523f, + 0.219872f, 0.264984f, -0.193432f, -0.363182f, -0.467472f, -0.462009f, + -0.571053f, -0.522476f, -0.315903f, -0.237427f, -0.147320f, -0.100201f, + -0.237568f, -0.763435f, -1.242043f, -2.135159f, -1.409485f, 1.236370f, + -0.474247f, -0.517906f, -0.410217f, -0.542244f, -0.795986f, -0.590004f, + -0.388863f, -0.462921f, -0.810627f, -0.778637f, -0.512486f, -0.718025f, + -0.710854f, -0.482513f, -0.318233f, -0.194962f, -0.220116f, -0.421673f, + -0.534233f, -0.403339f, -0.389332f, -0.407303f, -0.437355f, -0.469730f, + -0.359600f, -0.352745f, -0.466755f, -0.414585f, -0.430756f, -0.656822f, + -1.237038f, -2.046097f, -1.574898f, -0.593815f, -0.582165f, -0.336098f, + -0.372612f, -0.554386f, -0.410603f, -0.428276f, -0.647644f, -0.640720f, + -0.582207f, -0.414112f, -0.435547f, -0.435505f, -0.332561f, -0.248116f, + -0.340221f, -0.277855f, -0.352699f, -0.377319f, -0.230850f, -0.313267f, + -0.446270f, -0.346237f, -0.420422f, -0.530781f, -0.400341f, -0.463661f, + -0.209091f, -0.056705f, -0.011772f, -0.169388f, -0.736275f, -1.463017f, + -0.752701f, -0.668865f, -0.329765f, -0.299347f, -0.245667f, -0.286999f, + -0.520420f, -0.675438f, -0.255753f, 0.141357f, -0.079639f, -0.419476f, + -0.374069f, -0.046253f, 0.116116f, -0.145847f, -0.380371f, -0.563412f, + -0.638634f, -0.310116f, -0.260914f, -0.508404f, -0.465508f, -0.527824f, + -0.370979f, -0.305595f, -0.244694f, -0.254490f, 0.009968f, -0.050201f, + -0.331219f, -0.614960f, -0.788208f, -0.483242f, -0.367516f, -0.186951f, + -0.180031f, 0.129711f, -0.127811f, -0.384750f, -0.499542f, -0.418613f, + -0.121635f, 0.203197f, -0.167290f, -0.397270f, -0.355461f, -0.218746f, + -0.376785f, -0.521698f, -0.721581f, -0.845741f, -0.535439f, -0.220882f, + -0.309067f, -0.555248f, -0.690342f, -0.664948f, -0.390102f, 0.020355f, + -0.130447f, -0.173252f, -0.170059f, -0.633663f, -0.956001f, -0.621696f, + -0.388302f, -0.342262f, -0.244370f, -0.386948f, -0.401421f, -0.172979f, + -0.206163f, -0.450058f, -0.525789f, -0.549274f, -0.349251f, -0.474613f, + -0.667976f, -0.435600f, -0.175369f, -0.196877f, -0.202976f, -0.242481f, + -0.258369f, -0.189133f, -0.395397f, -0.765499f, -0.944016f, -0.850967f, + -0.631561f, -0.152493f, -0.046432f, -0.262066f, -0.195919f, 0.048218f, + 0.084972f, 0.039902f, 0.000618f, -0.404430f, -0.447456f, -0.418076f, + -0.631935f, -0.717415f, -0.502888f, -0.530514f, 
-0.747826f, -0.704041f, + -0.674969f, -0.516853f, -0.418446f, -0.327740f, -0.308815f, -0.481636f, + -0.440083f, -0.481720f, -0.341053f, -0.283897f, -0.324368f, -0.352829f, + -0.434349f, -0.545589f, -0.533104f, -0.472755f, -0.570496f, -0.557735f, + -0.708176f, -0.493332f, -0.194416f, -0.186249f, -0.256710f, -0.271835f, + -0.304752f, -0.431267f, -0.422398f, -0.646725f, -0.680801f, -0.249031f, + -0.058567f, -0.213890f, -0.383949f, -0.540291f, -0.549877f, -0.225567f, + -0.037174f, -0.499874f, -0.641010f, -0.628044f, -0.390549f, -0.311497f, + -0.542313f, -0.569565f, -0.473408f, -0.331245f, -0.357197f, -0.285599f, + -0.200157f, -0.201866f, -0.124428f, -0.346016f, -0.392311f, -0.264496f, + -0.285370f, -0.436974f, -0.523483f, -0.410461f, -0.267925f, -0.055016f, + -0.382458f, -0.319771f, -0.049927f, 0.124329f, 0.266102f, -0.106606f, + -0.773647f, -0.973053f, -0.708206f, -0.486137f, -0.319923f, -0.493900f, + -0.490860f, -0.324986f, -0.147346f, -0.146088f, -0.161758f, -0.084396f, + -0.379494f, 0.041626f, -0.113361f, -0.277767f, 0.083366f, 0.126476f, + 0.139057f, 0.038040f, 0.038162f, -0.242126f, -0.411736f, -0.370049f, + -0.455357f, -0.039257f, 0.264442f, -0.271492f, -0.425346f, -0.514847f, + -0.448650f, -0.580399f, -0.652603f, -0.774803f, -0.692524f, -0.579578f, + -0.465206f, -0.386265f, -0.458012f, -0.446594f, -0.284893f, -0.345448f, + -0.350876f, -0.440350f, -0.360378f, -0.270428f, 0.237213f, -0.063602f, + -0.364529f, -0.179867f, 0.078197f, 0.117947f, -0.093410f, -0.359119f, + -0.480961f, -0.540638f, -0.436287f, -0.598576f, -0.253735f, -0.060093f, + -0.549145f, -0.808327f, -0.698593f, -0.595764f, -0.582508f, -0.497353f, + -0.480892f, -0.584240f, -0.665791f, -0.690903f, -0.743446f, -0.796677f, + -0.782391f, -0.649010f, -0.628139f, -0.880848f, -0.829361f, -0.373272f, + -0.223667f, 0.174572f, -0.348743f, -0.798901f, -0.692307f, -0.607609f, + -0.401455f, -0.480919f, -0.450798f, -0.435413f, -0.322338f, -0.228382f, + -0.450466f, -0.504440f, -0.477402f, -0.662224f, -0.583397f, -0.217445f, + -0.157459f, -0.079584f, -0.226168f, -0.488720f, -0.669624f, -0.666878f, + -0.565311f, -0.549625f, -0.364601f, -0.497627f, -0.736897f, -0.763023f, + -0.741020f, -0.404503f, 0.184814f, -0.075315f, -0.281513f, -0.532906f, + -0.405800f, -0.313438f, -0.536652f, -0.403381f, 0.011967f, 0.103310f, + -0.269848f, -0.508656f, -0.445923f, -0.644859f, -0.617870f, -0.500927f, + -0.371559f, -0.125580f, 0.028625f, -0.154713f, -0.442024f, -0.492764f, + -0.199371f, 0.236305f, 0.225925f, 0.075577f, -0.285812f, -0.437145f, + -0.374260f, -0.156693f, -0.129635f, -0.243206f, -0.123058f, 0.162148f, + -0.313152f, -0.337982f, -0.358421f, 0.040070f, 0.038925f, -0.333313f, + -0.351662f, 0.023014f, 0.091362f, -0.282890f, -0.373253f, -0.389050f, + -0.532707f, -0.423347f, -0.349968f, -0.287045f, -0.202442f, -0.308430f, + -0.222801f, -0.106323f, -0.056358f, 0.027222f, 0.390732f, 0.033558f, + -0.160088f, -0.382217f, -0.535282f, -0.515900f, -0.022736f, 0.165665f, + -0.111408f, -0.233784f, -0.312357f, -0.541885f, -0.480022f, -0.482513f, + -0.246254f, 0.132244f, 0.090134f, 0.234634f, -0.089249f, -0.460854f, + -0.515457f, -0.450874f, -0.311031f, -0.387680f, -0.360554f, -0.179241f, + -0.283817f, -0.475815f, -0.246399f, -0.388958f, -0.551140f, -0.496239f, + -0.559879f, -0.379761f, -0.254288f, -0.395111f, -0.613018f, -0.459427f, + -0.263580f, -0.268929f, 0.080826f, 0.115616f, -0.097324f, -0.325310f, + -0.480450f, -0.313286f, -0.310371f, -0.517361f, -0.288288f, -0.112679f, + -0.173241f, -0.221664f, -0.039452f, -0.107578f, -0.089630f, -0.483768f, + 
-0.571087f, -0.497108f, -0.321533f, -0.375492f, -0.540363f, -0.406815f,
+      -0.388512f, -0.514561f, -0.540192f, -0.402412f, -0.232246f, -0.304749f,
+      -0.383724f, -0.679596f, -0.685463f, -0.694538f, -0.642937f, -0.425789f,
+      0.103271f,  -0.194862f, -0.487999f, -0.717281f, -0.681850f, -0.709286f,
+      -0.615398f, -0.554245f, -0.254681f, -0.049950f, -0.002914f, -0.095383f,
+      -0.370911f, -0.564224f, -0.242714f};
+  const size_t xtest = xsize / 2;
+  const size_t ytest = ysize / 2;
+
+  for (intptr_t dy = -16; dy <= 16; ++dy) {
+    float* row = in.Row(ytest + dy);
+    for (intptr_t dx = -16; dx <= 16; ++dx)
+      row[xtest + dx] = center[(dy + 16) * 33 + (dx + 16)];
+  }
+
+  const double sigma = 7.155933;
+
+  ImageF temp(xsize, ysize);
+  ImageF out_rg(xsize, ysize);
+  const auto rg = CreateRecursiveGaussian(sigma);
+  ThreadPool* null_pool = nullptr;
+  FastGaussian(rg, in, null_pool, &temp, &out_rg);
+
+  ImageF out_old;
+  {
+    const std::vector<float> kernel =
+        GaussianKernel(static_cast<int>(4 * sigma), static_cast<float>(sigma));
+    printf("old kernel size %zu\n", kernel.size());
+    out_old = Convolve(in, kernel);
+  }
+
+  printf("rg %.4f old %.4f\n", out_rg.Row(ytest)[xtest],
+         out_old.Row(ytest)[xtest]);
+}
+
+// Returns megapixels/sec. "div" is a divisor for the number of repetitions,
+// used to reduce benchmark duration. Func returns elapsed time.
+template <class Func>
+double Measure(const size_t xsize, const size_t ysize, int div,
+               const Func& func) {
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+  int reps = 10 / div;
+#else
+  int reps = 2000 / div;
+#endif
+  if (reps < 2) reps = 2;
+  std::vector<double> elapsed;
+  for (int i = 0; i < reps; ++i) {
+    elapsed.push_back(func(xsize, ysize));
+  }
+
+  double mean_elapsed;
+  // Potential loss of precision, and also enough samples for mode.
+  if (reps > 50) {
+    std::sort(elapsed.begin(), elapsed.end());
+    mean_elapsed = jxl::HalfSampleMode()(elapsed.data(), elapsed.size());
+  } else {
+    // Skip the first (noisier) measurement.
+    mean_elapsed = Geomean(elapsed.data() + 1, elapsed.size() - 1);
+  }
+  return (xsize * ysize * 1E-6) / mean_elapsed;
+}
+
+void Benchmark1D() {
+  // Uncomment to disable SIMD and force a scalar implementation:
+  // hwy::DisableTargets(~HWY_SCALAR);
+
+  const size_t length = 16384;  // (same value used for running IPOL benchmark)
+  const double sigma = 7.0;     // (from Butteraugli application)
+  // NOTE: MSVC and clang disagree on the required captures, so use =.
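+  // Each lambda below returns elapsed seconds for one filter invocation;
+  // Measure() aggregates the repetitions and converts to megapixels/sec.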
+  const double mps_rg1 =
+      Measure(length, 1, 1, [=](size_t /*xsize*/, size_t /*ysize*/) {
+        ImageF in(length, 1);
+        const float expected = length;
+        FillImage(expected, &in);
+
+        ImageF temp(length, 1);
+        ImageF out(length, 1);
+        const auto rg = CreateRecursiveGaussian(sigma);
+        const double t0 = Now();
+        FastGaussian1D(rg, in.Row(0), length, out.Row(0));
+        const double t1 = Now();
+        // Prevent optimizing out
+        const float actual = out.ConstRow(0)[length / 2];
+        const float rel_err = std::abs(actual - expected) / expected;
+        EXPECT_LT(rel_err, 9E-5);
+        return t1 - t0;
+      });
+  // Report milliseconds for comparison with IPOL benchmark
+  const double milliseconds = (1E-6 * length) / mps_rg1 * 1E3;
+  printf("%5zu @%.1f: rg 1D %e\n", length, sigma, milliseconds);
+}
+
+void Benchmark(size_t xsize, size_t ysize, double sigma) {
+  // Uncomment to disable AVX3 so that AVX2 is benchmarked instead:
+  // hwy::DisableTargets(HWY_AVX3);
+
+  const double mps_rg =
+      Measure(xsize, ysize, 1, [sigma](size_t xsize, size_t ysize) {
+        ImageF in(xsize, ysize);
+        const float expected = xsize + ysize;
+        FillImage(expected, &in);
+
+        ImageF temp(xsize, ysize);
+        ImageF out(xsize, ysize);
+        const auto rg = CreateRecursiveGaussian(sigma);
+        ThreadPool* null_pool = nullptr;
+        const double t0 = Now();
+        FastGaussian(rg, in, null_pool, &temp, &out);
+        const double t1 = Now();
+        // Prevent optimizing out
+        const float actual = out.ConstRow(ysize / 2)[xsize / 2];
+        const float rel_err = std::abs(actual - expected) / expected;
+        EXPECT_LT(rel_err, 9E-5);
+        return t1 - t0;
+      });
+
+  const double mps_fir =
+      Measure(xsize, ysize, 100, [sigma](size_t xsize, size_t ysize) {
+        ImageF in(xsize, ysize);
+        const float expected = xsize + ysize;
+        FillImage(expected, &in);
+        const std::vector<float> kernel = GaussianKernel(
+            static_cast<int>(4 * sigma), static_cast<float>(sigma));
+        const double t0 = Now();
+        const ImageF out = Convolve(in, kernel);
+        const double t1 = Now();
+
+        // Prevent optimizing out
+        const float actual = out.ConstRow(ysize / 2)[xsize / 2];
+        const float rel_err = std::abs(actual - expected) / expected;
+        EXPECT_LT(rel_err, 5E-6);
+        return t1 - t0;
+      });
+
+  const double mps_simd7 =
+      Measure(xsize, ysize, 10, [](size_t xsize, size_t ysize) {
+        ImageF in(xsize, ysize);
+        const float expected = xsize + ysize;
+        FillImage(expected, &in);
+        ImageF out(xsize, ysize);
+        // Gaussian with sigma 1
+        const WeightsSeparable7 weights = {
+            {HWY_REP4(0.383103f), HWY_REP4(0.241843f), HWY_REP4(0.060626f),
+             HWY_REP4(0.00598f)},
+            {HWY_REP4(0.383103f), HWY_REP4(0.241843f), HWY_REP4(0.060626f),
+             HWY_REP4(0.00598f)}};
+        ThreadPool* null_pool = nullptr;
+        const double t0 = Now();
+        Separable7(in, Rect(in), weights, null_pool, &out);
+        const double t1 = Now();
+
+        // Prevent optimizing out
+        const float actual = out.ConstRow(ysize / 2)[xsize / 2];
+        const float rel_err = std::abs(actual - expected) / expected;
+        EXPECT_LT(rel_err, 5E-6);
+        return t1 - t0;
+      });
+
+  printf("%zu,%zu,%.1f,%.1f,%.1f\n", xsize, ysize, mps_fir, mps_simd7, mps_rg);
+}
+
+TEST(GaussBlurTest, BenchmarkTest) {
+  Benchmark1D();
+  Benchmark(77, 177, 7);
+}
+
+TEST(GaussBlurTest, DISABLED_SlowBenchmark) {
+  Benchmark1D();
+
+  // Euler's gamma as a nothing-up-my-sleeve number, so sizes are unlikely to
+  // interact with cache properties
+  const float g = 0.57721566;
+  const size_t d0 = 128;
+  const size_t d1 = static_cast<size_t>(d0 / g);
+  const size_t d2 = static_cast<size_t>(d1 / g);
+  const size_t d3 = static_cast<size_t>(d2 / g);
+  Benchmark(d0, d0, 7);
+  Benchmark(d0, d1, 7);
+  Benchmark(d1, d0, 7);
+  Benchmark(d1, d1, 7);
+
Benchmark(d1, d2, 7); + Benchmark(d2, d1, 7); + Benchmark(d2, d2, 7); + Benchmark(d2, d3, 7); + Benchmark(d3, d2, 7); + Benchmark(d3, d3, 7); + + Benchmark(1920, 1080, 7); + + PROFILER_PRINT_RESULTS(); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/gradient_test.cc b/third_party/jpeg-xl/lib/jxl/gradient_test.cc new file mode 100644 index 000000000000..ed28cdff9d41 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/gradient_test.cc @@ -0,0 +1,214 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_file.h" +#include "lib/jxl/dec_params.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +// Returns distance of point p to line p0..p1, the result is signed and is not +// normalized. +double PointLineDist(double x0, double y0, double x1, double y1, double x, + double y) { + return (y1 - y0) * x - (x1 - x0) * y + x1 * y0 - y1 * x0; +} + +// Generates a test image with a gradient from one color to another. +// Angle in degrees, colors can be given in hex as 0xRRGGBB. The angle is the +// angle in which the change direction happens. +Image3F GenerateTestGradient(uint32_t color0, uint32_t color1, double angle, + size_t xsize, size_t ysize) { + Image3F image(xsize, ysize); + + double x0 = xsize / 2; + double y0 = ysize / 2; + double x1 = x0 + std::sin(angle / 360.0 * 2.0 * kPi); + double y1 = y0 + std::cos(angle / 360.0 * 2.0 * kPi); + + double maxdist = + std::max(fabs(PointLineDist(x0, y0, x1, y1, 0, 0)), + fabs(PointLineDist(x0, y0, x1, y1, xsize, 0))); + + for (size_t c = 0; c < 3; ++c) { + float c0 = ((color0 >> (8 * (2 - c))) & 255); + float c1 = ((color1 >> (8 * (2 - c))) & 255); + for (size_t y = 0; y < ysize; ++y) { + float* row = image.PlaneRow(c, y); + for (size_t x = 0; x < xsize; ++x) { + double dist = PointLineDist(x0, y0, x1, y1, x, y); + double v = ((dist / maxdist) + 1.0) / 2.0; + float color = c0 * (1.0 - v) + c1 * v; + row[x] = color; + } + } + } + + return image; +} + +// Computes the max of the horizontal and vertical second derivative for each +// pixel, where second derivative means absolute value of difference of left +// delta and right delta (top/bottom for vertical direction). +// The radius over which the derivative is computed is only 1 pixel and it only +// checks two angles (hor and ver), but this approximation works well enough. 
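+// For example, a row {1, 2, 4} has left delta 1 and right delta 2 at its
+// middle pixel, so the horizontal term there is |1 - 2| = 1; an exactly
+// linear ramp yields 0, which is what makes this a banding detector.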
+static ImageF Gradient2(const ImageF& image) { + size_t xsize = image.xsize(); + size_t ysize = image.ysize(); + ImageF image2(image.xsize(), image.ysize()); + for (size_t y = 1; y + 1 < ysize; y++) { + const auto* JXL_RESTRICT row0 = image.Row(y - 1); + const auto* JXL_RESTRICT row1 = image.Row(y); + const auto* JXL_RESTRICT row2 = image.Row(y + 1); + auto* row_out = image2.Row(y); + for (size_t x = 1; x + 1 < xsize; x++) { + float ddx = (row1[x] - row1[x - 1]) - (row1[x + 1] - row1[x]); + float ddy = (row1[x] - row0[x]) - (row2[x] - row1[x]); + row_out[x] = std::max(fabsf(ddx), fabsf(ddy)); + } + } + // Copy to the borders + if (ysize > 2) { + auto* JXL_RESTRICT row0 = image2.Row(0); + const auto* JXL_RESTRICT row1 = image2.Row(1); + const auto* JXL_RESTRICT row2 = image2.Row(ysize - 2); + auto* JXL_RESTRICT row3 = image2.Row(ysize - 1); + for (size_t x = 1; x + 1 < xsize; x++) { + row0[x] = row1[x]; + row3[x] = row2[x]; + } + } else { + const auto* row0_in = image.Row(0); + const auto* row1_in = image.Row(ysize - 1); + auto* row0_out = image2.Row(0); + auto* row1_out = image2.Row(ysize - 1); + for (size_t x = 1; x + 1 < xsize; x++) { + // Image too narrow, take first derivative instead + row0_out[x] = row1_out[x] = fabsf(row0_in[x] - row1_in[x]); + } + } + if (xsize > 2) { + for (size_t y = 0; y < ysize; y++) { + auto* row = image2.Row(y); + row[0] = row[1]; + row[xsize - 1] = row[xsize - 2]; + } + } else { + for (size_t y = 0; y < ysize; y++) { + const auto* JXL_RESTRICT row_in = image.Row(y); + auto* row_out = image2.Row(y); + // Image too narrow, take first derivative instead + row_out[0] = row_out[xsize - 1] = fabsf(row_in[0] - row_in[xsize - 1]); + } + } + return image2; +} + +static Image3F Gradient2(const Image3F& image) { + return Image3F(Gradient2(image.Plane(0)), Gradient2(image.Plane(1)), + Gradient2(image.Plane(2))); +} + +/* +Tests if roundtrip with jxl on a gradient image doesn't cause banding. +Only tests if use_gradient is true. Set to false for debugging to see the +distance values. +Angle in degrees, colors can be given in hex as 0xRRGGBB. +*/ +void TestGradient(ThreadPool* pool, uint32_t color0, uint32_t color1, + size_t xsize, size_t ysize, float angle, bool fast_mode, + float butteraugli_distance, bool use_gradient = true) { + CompressParams cparams; + cparams.butteraugli_distance = butteraugli_distance; + if (fast_mode) { + cparams.speed_tier = SpeedTier::kSquirrel; + } + DecompressParams dparams; + + Image3F gradient = GenerateTestGradient(color0, color1, angle, xsize, ysize); + + CodecInOut io; + io.metadata.m.SetUintSamples(8); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + io.SetFromImage(std::move(gradient), io.metadata.m.color_encoding); + + CodecInOut io2; + + PaddedBytes compressed; + AuxOut* aux_out = nullptr; + PassesEncoderState enc_state; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, pool)); + EXPECT_TRUE(io2.Main().TransformTo(io2.metadata.m.color_encoding, pool)); + + if (use_gradient) { + // Test that the gradient map worked. For that, we take a second derivative + // of the image with Gradient2 to measure how linear the change is in x and + // y direction. For a well handled gradient, we expect max values around + // 0.1, while if there is noticeable banding, which means the gradient map + // failed, the values are around 0.5-1.0 (regardless of + // butteraugli_distance). 
+ Image3F gradient2 = Gradient2(*io2.Main().color()); + + std::array image_max; + Image3Max(gradient2, &image_max); + + // TODO(jyrki): These values used to work with 0.2, 0.2, 0.2. + EXPECT_LE(image_max[0], 3.15); + EXPECT_LE(image_max[1], 1.72); + EXPECT_LE(image_max[2], 5.05); + } +} + +static constexpr bool fast_mode = true; + +TEST(GradientTest, SteepGradient) { + ThreadPoolInternal pool(8); + // Relatively steep gradients, colors from the sky of stp.png + TestGradient(&pool, 0xd99d58, 0x889ab1, 512, 512, 90, fast_mode, 3.0); +} + +TEST(GradientTest, SubtleGradient) { + ThreadPoolInternal pool(8); + // Very subtle gradient + TestGradient(&pool, 0xb89b7b, 0xa89b8d, 512, 512, 90, fast_mode, 4.0); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/headers.cc b/third_party/jpeg-xl/lib/jxl/headers.cc new file mode 100644 index 000000000000..217d7e25dce9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/headers.cc @@ -0,0 +1,221 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/headers.h" + +#include "lib/jxl/common.h" +#include "lib/jxl/fields.h" + +namespace jxl { +namespace { + +struct Rational { + constexpr explicit Rational(uint32_t num, uint32_t den) + : num(num), den(den) {} + + // Returns floor(multiplicand * rational). + constexpr uint32_t MulTruncate(uint32_t multiplicand) const { + return uint64_t(multiplicand) * num / den; + } + + uint32_t num; + uint32_t den; +}; + +Rational FixedAspectRatios(uint32_t ratio) { + JXL_ASSERT(0 != ratio && ratio < 8); + // Other candidates: 5/4, 7/5, 14/9, 16/10, 5/3, 21/9, 12/5 + constexpr Rational kRatios[7] = {Rational(1, 1), // square + Rational(12, 10), // + Rational(4, 3), // camera + Rational(3, 2), // mobile camera + Rational(16, 9), // camera/display + Rational(5, 4), // + Rational(2, 1)}; // + return kRatios[ratio - 1]; +} + +uint32_t FindAspectRatio(uint32_t xsize, uint32_t ysize) { + for (uint32_t r = 1; r < 8; ++r) { + if (xsize == FixedAspectRatios(r).MulTruncate(ysize)) { + return r; + } + } + return 0; // Must send xsize instead +} + +} // namespace + +size_t SizeHeader::xsize() const { + if (ratio_ != 0) { + return FixedAspectRatios(ratio_).MulTruncate( + static_cast(ysize())); + } + return small_ ? 
((xsize_div8_minus_1_ + 1) * 8) : xsize_;
+}
+
+Status SizeHeader::Set(size_t xsize64, size_t ysize64) {
+  if (xsize64 > 0xFFFFFFFFull || ysize64 > 0xFFFFFFFFull) {
+    return JXL_FAILURE("Image too large");
+  }
+  const uint32_t xsize32 = static_cast<uint32_t>(xsize64);
+  const uint32_t ysize32 = static_cast<uint32_t>(ysize64);
+  if (xsize64 == 0 || ysize64 == 0) return JXL_FAILURE("Empty image");
+  small_ = xsize64 <= 256 && ysize64 <= 256 && (xsize64 % kBlockDim) == 0 &&
+           (ysize64 % kBlockDim) == 0;
+  if (small_) {
+    ysize_div8_minus_1_ = ysize32 / 8 - 1;
+  } else {
+    ysize_ = ysize32;
+  }
+
+  ratio_ = FindAspectRatio(xsize32, ysize32);
+  if (ratio_ == 0) {
+    if (small_) {
+      xsize_div8_minus_1_ = xsize32 / 8 - 1;
+    } else {
+      xsize_ = xsize32;
+    }
+  }
+  JXL_ASSERT(xsize() == xsize64);
+  JXL_ASSERT(ysize() == ysize64);
+  return true;
+}
+
+Status PreviewHeader::Set(size_t xsize64, size_t ysize64) {
+  const uint32_t xsize32 = static_cast<uint32_t>(xsize64);
+  const uint32_t ysize32 = static_cast<uint32_t>(ysize64);
+  if (xsize64 == 0 || ysize64 == 0) return JXL_FAILURE("Empty preview");
+  div8_ = (xsize64 % kBlockDim) == 0 && (ysize64 % kBlockDim) == 0;
+  if (div8_) {
+    ysize_div8_ = ysize32 / 8;
+  } else {
+    ysize_ = ysize32;
+  }
+
+  ratio_ = FindAspectRatio(xsize32, ysize32);
+  if (ratio_ == 0) {
+    if (div8_) {
+      xsize_div8_ = xsize32 / 8;
+    } else {
+      xsize_ = xsize32;
+    }
+  }
+  JXL_ASSERT(xsize() == xsize64);
+  JXL_ASSERT(ysize() == ysize64);
+  return true;
+}
+
+size_t PreviewHeader::xsize() const {
+  if (ratio_ != 0) {
+    return FixedAspectRatios(ratio_).MulTruncate(
+        static_cast<uint32_t>(ysize()));
+  }
+  return div8_ ? (xsize_div8_ * 8) : xsize_;
+}
+
+SizeHeader::SizeHeader() { Bundle::Init(this); }
+Status SizeHeader::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &small_));
+
+  if (visitor->Conditional(small_)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, 0, &ysize_div8_minus_1_));
+  }
+  if (visitor->Conditional(!small_)) {
+    // (Could still be small, but non-multiple of 8.)
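+    // U32 reads a 2-bit selector picking one of the four encodings; e.g.
+    // BitsOffset(9, 1) stores the value minus 1 in 9 raw bits, so that
+    // branch covers 1..512.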
+ JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(9, 1), BitsOffset(13, 1), + BitsOffset(18, 1), BitsOffset(30, 1), + 1, &ysize_)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &ratio_)); + if (visitor->Conditional(ratio_ == 0 && small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, 0, &xsize_div8_minus_1_)); + } + if (visitor->Conditional(ratio_ == 0 && !small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(9, 1), BitsOffset(13, 1), + BitsOffset(18, 1), BitsOffset(30, 1), + 1, &xsize_)); + } + + return true; +} + +PreviewHeader::PreviewHeader() { Bundle::Init(this); } +Status PreviewHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &div8_)); + + if (visitor->Conditional(div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), Val(32), BitsOffset(5, 1), + BitsOffset(9, 33), 1, &ysize_div8_)); + } + if (visitor->Conditional(!div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(6, 1), BitsOffset(8, 65), + BitsOffset(10, 321), + BitsOffset(12, 1345), 1, &ysize_)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &ratio_)); + if (visitor->Conditional(ratio_ == 0 && div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), Val(32), BitsOffset(5, 1), + BitsOffset(9, 33), 1, &xsize_div8_)); + } + if (visitor->Conditional(ratio_ == 0 && !div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(6, 1), BitsOffset(8, 65), + BitsOffset(10, 321), + BitsOffset(12, 1345), 1, &xsize_)); + } + + return true; +} + +AnimationHeader::AnimationHeader() { Bundle::Init(this); } +Status AnimationHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(100), Val(1000), BitsOffset(10, 1), + BitsOffset(30, 1), 1, &tps_numerator)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(1), Val(1001), BitsOffset(8, 1), + BitsOffset(10, 1), 1, + &tps_denominator)); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Bits(3), Bits(16), Bits(32), 0, &num_loops)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_timecodes)); + return true; +} + +Status ReadSizeHeader(BitReader* JXL_RESTRICT reader, + SizeHeader* JXL_RESTRICT size) { + return Bundle::Read(reader, size); +} + +Status WriteSizeHeader(const SizeHeader& size, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* aux_out) { + const size_t max_bits = Bundle::MaxBits(size); + if (max_bits != SizeHeader::kMaxBits) { + JXL_ABORT("Please update SizeHeader::kMaxBits from %zu to %zu\n", + SizeHeader::kMaxBits, max_bits); + } + + // Only check the number of non-extension bits (extensions are unbounded). + // (Bundle::Write will call CanEncode again, but it is fast because SizeHeader + // is tiny.) + size_t extension_bits, total_bits; + JXL_RETURN_IF_ERROR(Bundle::CanEncode(size, &extension_bits, &total_bits)); + JXL_ASSERT(total_bits - extension_bits < SizeHeader::kMaxBits); + + return Bundle::Write(size, writer, layer, aux_out); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/headers.h b/third_party/jpeg-xl/lib/jxl/headers.h new file mode 100644 index 000000000000..13a44f82be90 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/headers.h @@ -0,0 +1,115 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_HEADERS_H_ +#define LIB_JXL_HEADERS_H_ + +// Codestream headers, also stored in CodecInOut. + +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +// Reserved by ISO/IEC 10918-1. LF causes files opened in text mode to be +// rejected because the marker changes to 0x0D instead. The 0xFF prefix also +// ensures there were no 7-bit transmission limitations. +static constexpr uint8_t kCodestreamMarker = 0x0A; + +// Compact representation of image dimensions (best case: 9 bits) so decoders +// can preallocate early. +class SizeHeader : public Fields { + public: + // All fields are valid after reading at most this many bits. WriteSizeHeader + // verifies this matches Bundle::MaxBits(SizeHeader). + static constexpr size_t kMaxBits = 78; + + SizeHeader(); + const char* Name() const override { return "SizeHeader"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + Status Set(size_t xsize, size_t ysize); + + size_t xsize() const; + size_t ysize() const { + return small_ ? ((ysize_div8_minus_1_ + 1) * 8) : ysize_; + } + + private: + bool small_; // xsize and ysize <= 256 and divisible by 8. + + uint32_t ysize_div8_minus_1_; + uint32_t ysize_; + + uint32_t ratio_; + uint32_t xsize_div8_minus_1_; + uint32_t xsize_; +}; + +// (Similar to SizeHeader but different encoding because previews are smaller) +class PreviewHeader : public Fields { + public: + PreviewHeader(); + const char* Name() const override { return "PreviewHeader"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + Status Set(size_t xsize, size_t ysize); + + size_t xsize() const; + size_t ysize() const { return div8_ ? (ysize_div8_ * 8) : ysize_; } + + private: + bool div8_; // xsize and ysize divisible by 8. + + uint32_t ysize_div8_; + uint32_t ysize_; + + uint32_t ratio_; + uint32_t xsize_div8_; + uint32_t xsize_; +}; + +struct AnimationHeader : public Fields { + AnimationHeader(); + const char* Name() const override { return "AnimationHeader"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Ticks per second (expressed as rational number to support NTSC) + uint32_t tps_numerator; + uint32_t tps_denominator; + + uint32_t num_loops; // 0 means to repeat infinitely. 
+ + bool have_timecodes; +}; + +Status ReadSizeHeader(BitReader* JXL_RESTRICT reader, + SizeHeader* JXL_RESTRICT size); + +Status WriteSizeHeader(const SizeHeader& size, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_HEADERS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/huffman_table.cc b/third_party/jpeg-xl/lib/jxl/huffman_table.cc new file mode 100644 index 000000000000..7286e2b90b68 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/huffman_table.cc @@ -0,0 +1,170 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/huffman_table.h" + +#include /* for memcpy */ +#include + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/dec_huffman.h" + +namespace jxl { + +/* Returns reverse(reverse(key, len) + 1, len), where reverse(key, len) is the + bit-wise reversal of the len least significant bits of key. */ +static inline int GetNextKey(int key, int len) { + int step = 1u << (len - 1); + while (key & step) { + step >>= 1; + } + return (key & (step - 1)) + step; +} + +/* Stores code in table[0], table[step], table[2*step], ..., table[end] */ +/* Assumes that end is an integer multiple of step */ +static inline void ReplicateValue(HuffmanCode* table, int step, int end, + HuffmanCode code) { + do { + end -= step; + table[end] = code; + } while (end > 0); +} + +/* Returns the table width of the next 2nd level table. 
count is the histogram
+   of bit lengths for the remaining symbols, len is the code length of the
+   next processed symbol */
+static inline size_t NextTableBitSize(const uint16_t* const count, size_t len,
+                                      int root_bits) {
+  size_t left = 1u << (len - root_bits);
+  while (len < PREFIX_MAX_BITS) {
+    if (left <= count[len]) break;
+    left -= count[len];
+    ++len;
+    left <<= 1;
+  }
+  return len - root_bits;
+}
+
+uint32_t BuildHuffmanTable(HuffmanCode* root_table, int root_bits,
+                           const uint8_t* const code_lengths,
+                           size_t code_lengths_size, uint16_t* count) {
+  HuffmanCode code;   /* current table entry */
+  HuffmanCode* table; /* next available space in table */
+  size_t len;         /* current code length */
+  size_t symbol;      /* symbol index in original or sorted table */
+  int key;            /* reversed prefix code */
+  int step;           /* step size to replicate values in current table */
+  int low;            /* low bits for current root entry */
+  int mask;           /* mask for low bits */
+  size_t table_bits;  /* key length of current table */
+  int table_size;     /* size of current table */
+  int total_size;     /* sum of root table size and 2nd level table sizes */
+  /* offsets in sorted table for each length */
+  uint16_t offset[PREFIX_MAX_BITS + 1];
+  size_t max_length = 1;
+
+  if (code_lengths_size > 1u << PREFIX_MAX_BITS) return 0;
+
+  /* symbols sorted by code length */
+  std::vector<uint16_t> sorted_storage(code_lengths_size);
+  uint16_t* sorted = sorted_storage.data();
+
+  /* generate offsets into sorted symbol table by code length */
+  {
+    uint16_t sum = 0;
+    for (len = 1; len <= PREFIX_MAX_BITS; len++) {
+      offset[len] = sum;
+      if (count[len]) {
+        sum = static_cast<uint16_t>(sum + count[len]);
+        max_length = len;
+      }
+    }
+  }
+
+  /* sort symbols by length, by symbol order within each length */
+  for (symbol = 0; symbol < code_lengths_size; symbol++) {
+    if (code_lengths[symbol] != 0) {
+      sorted[offset[code_lengths[symbol]]++] = symbol;
+    }
+  }
+
+  table = root_table;
+  table_bits = root_bits;
+  table_size = 1u << table_bits;
+  total_size = table_size;
+
+  /* special case code with only one value */
+  if (offset[PREFIX_MAX_BITS] == 1) {
+    code.bits = 0;
+    code.value = static_cast<uint16_t>(sorted[0]);
+    for (key = 0; key < total_size; ++key) {
+      table[key] = code;
+    }
+    return total_size;
+  }
+
+  /* fill in root table */
+  /* let's reduce the table size to a smaller size if possible, and */
+  /* create the repetitions by memcpy if possible in the coming loop */
+  if (table_bits > max_length) {
+    table_bits = max_length;
+    table_size = 1u << table_bits;
+  }
+  key = 0;
+  symbol = 0;
+  code.bits = 1;
+  step = 2;
+  do {
+    for (; count[code.bits] != 0; --count[code.bits]) {
+      code.value = static_cast<uint16_t>(sorted[symbol++]);
+      ReplicateValue(&table[key], step, table_size, code);
+      key = GetNextKey(key, code.bits);
+    }
+    step <<= 1;
+  } while (++code.bits <= table_bits);
+
+  /* if root_bits != table_bits we only created one fraction of the */
+  /* table, and we need to replicate it now.
*/
+  while (total_size != table_size) {
+    memcpy(&table[table_size], &table[0], table_size * sizeof(table[0]));
+    table_size <<= 1;
+  }
+
+  /* fill in 2nd level tables and add pointers to root table */
+  mask = total_size - 1;
+  low = -1;
+  for (len = root_bits + 1, step = 2; len <= max_length; ++len, step <<= 1) {
+    for (; count[len] != 0; --count[len]) {
+      if ((key & mask) != low) {
+        table += table_size;
+        table_bits = NextTableBitSize(count, len, root_bits);
+        table_size = 1u << table_bits;
+        total_size += table_size;
+        low = key & mask;
+        root_table[low].bits = static_cast<uint8_t>(table_bits + root_bits);
+        root_table[low].value =
+            static_cast<uint16_t>((table - root_table) - low);
+      }
+      code.bits = static_cast<uint8_t>(len - root_bits);
+      code.value = static_cast<uint16_t>(sorted[symbol++]);
+      ReplicateValue(&table[key >> root_bits], step, table_size, code);
+      key = GetNextKey(key, len);
+    }
+  }
+
+  return total_size;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/huffman_table.h b/third_party/jpeg-xl/lib/jxl/huffman_table.h
new file mode 100644
index 000000000000..ff404e8ea296
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/huffman_table.h
@@ -0,0 +1,37 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_HUFFMAN_TABLE_H_
+#define LIB_JXL_HUFFMAN_TABLE_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+namespace jxl {
+
+struct HuffmanCode {
+  uint8_t bits;   /* number of bits used for this symbol */
+  uint16_t value; /* symbol value or table offset */
+};
+
+/* Builds Huffman lookup table assuming code lengths are in symbol order. */
+/* Returns 0 in case of error (invalid tree or memory error), otherwise
+   populated size of table. */
+uint32_t BuildHuffmanTable(HuffmanCode* root_table, int root_bits,
+                           const uint8_t* code_lengths,
+                           size_t code_lengths_size, uint16_t* count);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_HUFFMAN_TABLE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/huffman_tree.cc b/third_party/jpeg-xl/lib/jxl/huffman_tree.cc
new file mode 100644
index 000000000000..ab88c19d3952
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/huffman_tree.cc
@@ -0,0 +1,337 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lib/jxl/huffman_tree.h" + +#include +#include +#include + +#include "lib/jxl/base/status.h" + +namespace jxl { + +void SetDepth(const HuffmanTree& p, HuffmanTree* pool, uint8_t* depth, + uint8_t level) { + if (p.index_left >= 0) { + ++level; + SetDepth(pool[p.index_left], pool, depth, level); + SetDepth(pool[p.index_right_or_value], pool, depth, level); + } else { + depth[p.index_right_or_value] = level; + } +} + +// Sort the root nodes, least popular first. +static JXL_INLINE bool Compare(const HuffmanTree& v0, const HuffmanTree& v1) { + return v0.total_count < v1.total_count; +} + +// This function will create a Huffman tree. +// +// The catch here is that the tree cannot be arbitrarily deep. +// Brotli specifies a maximum depth of 15 bits for "code trees" +// and 7 bits for "code length code trees." +// +// count_limit is the value that is to be faked as the minimum value +// and this minimum value is raised until the tree matches the +// maximum length requirement. +// +// This algorithm is not of excellent performance for very long data blocks, +// especially when population counts are longer than 2**tree_limit, but +// we are not planning to use this with extremely long blocks. +// +// See http://en.wikipedia.org/wiki/Huffman_coding +void CreateHuffmanTree(const uint32_t* data, const size_t length, + const int tree_limit, uint8_t* depth) { + // For block sizes below 64 kB, we never need to do a second iteration + // of this loop. Probably all of our block sizes will be smaller than + // that, so this loop is mostly of academic interest. If we actually + // would need this, we would be better off with the Katajainen algorithm. + for (uint32_t count_limit = 1;; count_limit *= 2) { + std::vector tree; + tree.reserve(2 * length + 1); + + for (size_t i = length; i != 0;) { + --i; + if (data[i]) { + const uint32_t count = std::max(data[i], count_limit - 1); + tree.emplace_back(count, -1, static_cast(i)); + } + } + + const size_t n = tree.size(); + if (n == 1) { + // Fake value; will be fixed on upper level. + depth[tree[0].index_right_or_value] = 1; + break; + } + + std::stable_sort(tree.begin(), tree.end(), Compare); + + // The nodes are: + // [0, n): the sorted leaf nodes that we start with. + // [n]: we add a sentinel here. + // [n + 1, 2n): new parent nodes are added here, starting from + // (n+1). These are naturally in ascending order. + // [2n]: we add a sentinel at the end as well. + // There will be (2n+1) elements at the end. + const HuffmanTree sentinel(std::numeric_limits::max(), -1, -1); + tree.push_back(sentinel); + tree.push_back(sentinel); + + size_t i = 0; // Points to the next leaf node. + size_t j = n + 1; // Points to the next non-leaf node. + for (size_t k = n - 1; k != 0; --k) { + size_t left, right; + if (tree[i].total_count <= tree[j].total_count) { + left = i; + ++i; + } else { + left = j; + ++j; + } + if (tree[i].total_count <= tree[j].total_count) { + right = i; + ++i; + } else { + right = j; + ++j; + } + + // The sentinel node becomes the parent node. + size_t j_end = tree.size() - 1; + tree[j_end].total_count = + tree[left].total_count + tree[right].total_count; + tree[j_end].index_left = static_cast(left); + tree[j_end].index_right_or_value = static_cast(right); + + // Add back the last sentinel node. + tree.push_back(sentinel); + } + JXL_DASSERT(tree.size() == 2 * n + 1); + SetDepth(tree[2 * n - 1], &tree[0], depth, 0); + + // We need to pack the Huffman tree in tree_limit bits. 
+ // If this was not successful, add fake entities to the lowest values + // and retry. + if (*std::max_element(&depth[0], &depth[length]) <= tree_limit) { + break; + } + } +} + +void Reverse(uint8_t* v, size_t start, size_t end) { + --end; + while (start < end) { + uint8_t tmp = v[start]; + v[start] = v[end]; + v[end] = tmp; + ++start; + --end; + } +} + +void WriteHuffmanTreeRepetitions(const uint8_t previous_value, + const uint8_t value, size_t repetitions, + size_t* tree_size, uint8_t* tree, + uint8_t* extra_bits_data) { + JXL_DASSERT(repetitions > 0); + if (previous_value != value) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions == 7) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions < 3) { + for (size_t i = 0; i < repetitions; ++i) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + } + } else { + repetitions -= 3; + size_t start = *tree_size; + while (true) { + tree[*tree_size] = 16; + extra_bits_data[*tree_size] = repetitions & 0x3; + ++(*tree_size); + repetitions >>= 2; + if (repetitions == 0) { + break; + } + --repetitions; + } + Reverse(tree, start, *tree_size); + Reverse(extra_bits_data, start, *tree_size); + } +} + +void WriteHuffmanTreeRepetitionsZeros(size_t repetitions, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data) { + if (repetitions == 11) { + tree[*tree_size] = 0; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions < 3) { + for (size_t i = 0; i < repetitions; ++i) { + tree[*tree_size] = 0; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + } + } else { + repetitions -= 3; + size_t start = *tree_size; + while (true) { + tree[*tree_size] = 17; + extra_bits_data[*tree_size] = repetitions & 0x7; + ++(*tree_size); + repetitions >>= 3; + if (repetitions == 0) { + break; + } + --repetitions; + } + Reverse(tree, start, *tree_size); + Reverse(extra_bits_data, start, *tree_size); + } +} + +static void DecideOverRleUse(const uint8_t* depth, const size_t length, + bool* use_rle_for_non_zero, + bool* use_rle_for_zero) { + size_t total_reps_zero = 0; + size_t total_reps_non_zero = 0; + size_t count_reps_zero = 1; + size_t count_reps_non_zero = 1; + for (size_t i = 0; i < length;) { + const uint8_t value = depth[i]; + size_t reps = 1; + for (size_t k = i + 1; k < length && depth[k] == value; ++k) { + ++reps; + } + if (reps >= 3 && value == 0) { + total_reps_zero += reps; + ++count_reps_zero; + } + if (reps >= 4 && value != 0) { + total_reps_non_zero += reps; + ++count_reps_non_zero; + } + i += reps; + } + *use_rle_for_non_zero = total_reps_non_zero > count_reps_non_zero * 2; + *use_rle_for_zero = total_reps_zero > count_reps_zero * 2; +} + +void WriteHuffmanTree(const uint8_t* depth, size_t length, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data) { + uint8_t previous_value = 8; + + // Throw away trailing zeros. + size_t new_length = length; + for (size_t i = 0; i < length; ++i) { + if (depth[length - i - 1] == 0) { + --new_length; + } else { + break; + } + } + + // First gather statistics on if it is a good idea to do rle. + bool use_rle_for_non_zero = false; + bool use_rle_for_zero = false; + if (length > 50) { + // Find rle coding for longer codes. + // Shorter codes seem not to benefit from rle. + DecideOverRleUse(depth, new_length, &use_rle_for_non_zero, + &use_rle_for_zero); + } + + // Actual rle coding. 
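+  // (As in Brotli: code 16 repeats the previous non-zero length with 2 extra
+  // bits, code 17 emits a run of zeros with 3 extra bits; both are written
+  // by the helpers above.)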
+ for (size_t i = 0; i < new_length;) { + const uint8_t value = depth[i]; + size_t reps = 1; + if ((value != 0 && use_rle_for_non_zero) || + (value == 0 && use_rle_for_zero)) { + for (size_t k = i + 1; k < new_length && depth[k] == value; ++k) { + ++reps; + } + } + if (value == 0) { + WriteHuffmanTreeRepetitionsZeros(reps, tree_size, tree, extra_bits_data); + } else { + WriteHuffmanTreeRepetitions(previous_value, value, reps, tree_size, tree, + extra_bits_data); + previous_value = value; + } + i += reps; + } +} + +namespace { + +uint16_t ReverseBits(int num_bits, uint16_t bits) { + static const size_t kLut[16] = {// Pre-reversed 4-bit values. + 0x0, 0x8, 0x4, 0xc, 0x2, 0xa, 0x6, 0xe, + 0x1, 0x9, 0x5, 0xd, 0x3, 0xb, 0x7, 0xf}; + size_t retval = kLut[bits & 0xf]; + for (int i = 4; i < num_bits; i += 4) { + retval <<= 4; + bits = static_cast(bits >> 4); + retval |= kLut[bits & 0xf]; + } + retval >>= (-num_bits & 0x3); + return static_cast(retval); +} + +} // namespace + +void ConvertBitDepthsToSymbols(const uint8_t* depth, size_t len, + uint16_t* bits) { + // In Brotli, all bit depths are [1..15] + // 0 bit depth means that the symbol does not exist. + const int kMaxBits = 16; // 0..15 are values for bits + uint16_t bl_count[kMaxBits] = {0}; + { + for (size_t i = 0; i < len; ++i) { + ++bl_count[depth[i]]; + } + bl_count[0] = 0; + } + uint16_t next_code[kMaxBits]; + next_code[0] = 0; + { + int code = 0; + for (size_t i = 1; i < kMaxBits; ++i) { + code = (code + bl_count[i - 1]) << 1; + next_code[i] = static_cast(code); + } + } + for (size_t i = 0; i < len; ++i) { + if (depth[i]) { + bits[i] = ReverseBits(depth[i], next_code[depth[i]]++); + } + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/huffman_tree.h b/third_party/jpeg-xl/lib/jxl/huffman_tree.h new file mode 100644 index 000000000000..4550f7750b6f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/huffman_tree.h @@ -0,0 +1,61 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Library for creating Huffman codes from population counts. + +#ifndef LIB_JXL_HUFFMAN_TREE_H_ +#define LIB_JXL_HUFFMAN_TREE_H_ + +#include +#include + +namespace jxl { + +// A node of a Huffman tree. +struct HuffmanTree { + HuffmanTree(uint32_t count, int16_t left, int16_t right) + : total_count(count), index_left(left), index_right_or_value(right) {} + uint32_t total_count; + int16_t index_left; + int16_t index_right_or_value; +}; + +void SetDepth(const HuffmanTree& p, HuffmanTree* pool, uint8_t* depth, + uint8_t level); + +// This function will create a Huffman tree. +// +// The (data,length) contains the population counts. +// The tree_limit is the maximum bit depth of the Huffman codes. +// +// The depth contains the tree, i.e., how many bits are used for +// the symbol. 
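+//
+// For example, population counts {5, 2, 1, 1} yield depths {1, 2, 3, 3}:
+// the two rarest symbols merge first, that pair then merges with the
+// count-2 symbol, and the result merges with the count-5 symbol last.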
+// +// See http://en.wikipedia.org/wiki/Huffman_coding +void CreateHuffmanTree(const uint32_t* data, const size_t length, + const int tree_limit, uint8_t* depth); + +// Write a Huffman tree from bit depths into the bitstream representation +// of a Huffman tree. The generated Huffman tree is to be compressed once +// more using a Huffman tree +void WriteHuffmanTree(const uint8_t* depth, size_t length, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data); + +// Get the actual bit values for a tree of bit depths. +void ConvertBitDepthsToSymbols(const uint8_t* depth, size_t len, + uint16_t* bits); + +} // namespace jxl + +#endif // LIB_JXL_HUFFMAN_TREE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/iaca_test.cc b/third_party/jpeg-xl/lib/jxl/iaca_test.cc new file mode 100644 index 000000000000..1a155a10b758 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/iaca_test.cc @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/base/iaca.h" + +#include "gtest/gtest.h" + +namespace jxl { +namespace { + +TEST(IacaTest, MarkersDefaultToDisabledAndDoNotCrash) { + BeginIACA(); + EndIACA(); +} + +TEST(IacaTest, ScopeDefaultToDisabledAndDoNotCrash) { ScopeIACA iaca; } + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec.cc b/third_party/jpeg-xl/lib/jxl/icc_codec.cc new file mode 100644 index 000000000000..8b469265b846 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec.cc @@ -0,0 +1,383 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/icc_codec.h" + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/icc_codec_common.h" + +namespace jxl { +namespace { + +uint64_t DecodeVarInt(const uint8_t* input, size_t inputSize, size_t* pos) { + size_t i; + uint64_t ret = 0; + for (i = 0; *pos + i < inputSize && i < 10; ++i) { + ret |= uint64_t(input[*pos + i] & 127) << uint64_t(7 * i); + // If the next-byte flag is not set, stop + if ((input[*pos + i] & 128) == 0) break; + } + // TODO: Return a decoding error if i == 10. + *pos += i + 1; + return ret; +} + +// Shuffles or interleaves bytes, for example with width 2, turns "ABCDabcd" +// into "AaBbCcDc". Transposes a matrix of ceil(size / width) columns and +// width rows. 
There are size elements, size may be < width * height, if so the +// last elements of the rightmost column are missing, the missing spots are +// transposed along with the filled spots, and the result has the missing +// elements at the end of the bottom row. The input is the input matrix in +// scanline order but with missing elements skipped (which may occur in multiple +// locations), the output is the result matrix in scanline order (with +// no need to skip missing elements as they are past the end of the data). +void Shuffle(uint8_t* data, size_t size, size_t width) { + size_t height = (size + width - 1) / width; // amount of rows of output + PaddedBytes result(size); + // i = output index, j input index + size_t s = 0, j = 0; + for (size_t i = 0; i < size; i++) { + result[i] = data[j]; + j += height; + if (j >= size) j = ++s; + } + + for (size_t i = 0; i < size; i++) { + data[i] = result[i]; + } +} + +// TODO(eustas): should be 20, or even 18, once DecodeVarInt is improved; +// currently DecodeVarInt does not signal the errors, and marks +// 11 bytes as used even if only 10 are used (and 9 is enough for +// 63-bit values). +constexpr const size_t kPreambleSize = 22; // enough for reading 2 VarInts + +} // namespace + +// Mimics the beginning of UnpredictICC for quick validity check. +// At least kPreambleSize bytes of data should be valid at invocation time. +Status CheckPreamble(const PaddedBytes& data, size_t enc_size, + size_t output_limit) { + const uint8_t* enc = data.data(); + size_t size = data.size(); + size_t pos = 0; + uint64_t osize = DecodeVarInt(enc, size, &pos); + JXL_RETURN_IF_ERROR(CheckIs32Bit(osize)); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t csize = DecodeVarInt(enc, size, &pos); + JXL_RETURN_IF_ERROR(CheckIs32Bit(csize)); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size)); + // We expect that UnpredictICC inflates input, not the other way round. + if (osize + 65536 < enc_size) return JXL_FAILURE("Malformed ICC"); + if (output_limit && osize > output_limit) { + return JXL_FAILURE("Decoded ICC is too large"); + } + return true; +} + +// Decodes the result of PredictICC back to a valid ICC profile. +Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { + if (!result->empty()) return JXL_FAILURE("result must be empty initially"); + size_t pos = 0; + // TODO(lode): technically speaking we need to check that the entire varint + // decoding never goes out of bounds, not just the first byte. This requires + // a DecodeVarInt function that returns an error code. It is safe to use + // DecodeVarInt with out of bounds values, it silently returns, but the + // specification requires an error. Idem for all DecodeVarInt below. + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t osize = DecodeVarInt(enc, size, &pos); // Output size + JXL_RETURN_IF_ERROR(CheckIs32Bit(osize)); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t csize = DecodeVarInt(enc, size, &pos); // Commands size + // Every command is translated to at least on byte. 
+ JXL_RETURN_IF_ERROR(CheckIs32Bit(csize)); + size_t cpos = pos; // pos in commands stream + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size)); + size_t commands_end = cpos + csize; + pos = commands_end; // pos in data stream + + // Header + PaddedBytes header = ICCInitialHeaderPrediction(); + EncodeUint32(0, osize, &header); + for (size_t i = 0; i <= kICCHeaderSize; i++) { + if (result->size() == osize) { + if (cpos != commands_end) return JXL_FAILURE("Not all commands used"); + if (pos != size) return JXL_FAILURE("Not all data used"); + return true; // Valid end + } + if (i == kICCHeaderSize) break; // Done + ICCPredictHeader(result->data(), result->size(), header.data(), i); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + result->push_back(enc[pos++] + header[i]); + } + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + + // Tag list + uint64_t numtags = DecodeVarInt(enc, size, &cpos); + + if (numtags != 0) { + numtags--; + JXL_RETURN_IF_ERROR(CheckIs32Bit(numtags)); + AppendUint32(numtags, result); + uint64_t prevtagstart = kICCHeaderSize + numtags * 12; + uint64_t prevtagsize = 0; + for (;;) { + if (result->size() > osize) return JXL_FAILURE("Invalid result size"); + if (cpos > commands_end) return JXL_FAILURE("Out of bounds"); + if (cpos == commands_end) break; // Valid end + uint8_t command = enc[cpos++]; + uint8_t tagcode = command & 63; + Tag tag; + if (tagcode == 0) { + break; + } else if (tagcode == kCommandTagUnknown) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 4, size)); + tag = DecodeKeyword(enc, size, pos); + pos += 4; + } else if (tagcode == kCommandTagTRC) { + tag = kRtrcTag; + } else if (tagcode == kCommandTagXYZ) { + tag = kRxyzTag; + } else { + if (tagcode - kCommandTagStringFirst >= kNumTagStrings) { + return JXL_FAILURE("Unknown tagcode"); + } + tag = *kTagStrings[tagcode - kCommandTagStringFirst]; + } + AppendKeyword(tag, result); + + uint64_t tagstart; + uint64_t tagsize = prevtagsize; + if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag || + tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag || + tag == kLumiTag) { + tagsize = 20; + } + + if (command & kFlagBitOffset) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + tagstart = DecodeVarInt(enc, size, &cpos); + } else { + JXL_RETURN_IF_ERROR(CheckIs32Bit(prevtagstart)); + tagstart = prevtagstart + prevtagsize; + } + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart)); + AppendUint32(tagstart, result); + if (command & kFlagBitSize) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + tagsize = DecodeVarInt(enc, size, &cpos); + } + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagsize)); + AppendUint32(tagsize, result); + prevtagstart = tagstart; + prevtagsize = tagsize; + + if (tagcode == kCommandTagTRC) { + AppendKeyword(kGtrcTag, result); + AppendUint32(tagstart, result); + AppendUint32(tagsize, result); + AppendKeyword(kBtrcTag, result); + AppendUint32(tagstart, result); + AppendUint32(tagsize, result); + } + + if (tagcode == kCommandTagXYZ) { + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart + tagsize * 2)); + AppendKeyword(kGxyzTag, result); + AppendUint32(tagstart + tagsize, result); + AppendUint32(tagsize, result); + AppendKeyword(kBxyzTag, result); + AppendUint32(tagstart + tagsize * 2, result); + AppendUint32(tagsize, result); + } + } + } + + // Main Content + for (;;) { + if (result->size() > osize) return JXL_FAILURE("Invalid result size"); + if (cpos > commands_end) return JXL_FAILURE("Out of bounds"); + if (cpos == commands_end) break; // Valid end + 
uint8_t command = enc[cpos++]; + if (command == kCommandInsert) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + for (size_t i = 0; i < num; i++) { + result->push_back(enc[pos++]); + } + } else if (command == kCommandShuffle2 || command == kCommandShuffle4) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + PaddedBytes shuffled(num); + for (size_t i = 0; i < num; i++) { + shuffled[i] = enc[pos + i]; + } + if (command == kCommandShuffle2) { + Shuffle(shuffled.data(), num, 2); + } else if (command == kCommandShuffle4) { + Shuffle(shuffled.data(), num, 4); + } + for (size_t i = 0; i < num; i++) { + result->push_back(shuffled[i]); + pos++; + } + } else if (command == kCommandPredict) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(cpos, 2, commands_end)); + uint8_t flags = enc[cpos++]; + + size_t width = (flags & 3) + 1; + if (width == 3) return JXL_FAILURE("Invalid width"); + + int order = (flags & 12) >> 2; + if (order == 3) return JXL_FAILURE("Invalid order"); + + uint64_t stride = width; + if (flags & 16) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + stride = DecodeVarInt(enc, size, &cpos); + if (stride < width) { + return JXL_FAILURE("Invalid stride"); + } + } + // If stride * 4 >= result->size(), return failure. The check + // "size == 0 || ((size - 1) >> 2) < stride" corresponds to + // "stride * 4 >= size", but does not suffer from integer overflow. + // This check is more strict than necessary but follows the specification + // and the encoder should ensure this is followed. + if (result->empty() || ((result->size() - 1u) >> 2u) < stride) { + return JXL_FAILURE("Invalid stride"); + } + + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); // in bytes + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + + PaddedBytes shuffled(num); + for (size_t i = 0; i < num; i++) { + shuffled[i] = enc[pos + i]; + } + if (width > 1) Shuffle(shuffled.data(), num, width); + + size_t start = result->size(); + for (size_t i = 0; i < num; i++) { + uint8_t predicted = LinearPredictICCValue(result->data(), start, i, + stride, width, order); + result->push_back(predicted + shuffled[i]); + } + pos += num; + } else if (command == kCommandXYZ) { + AppendKeyword(kXyz_Tag, result); + for (int i = 0; i < 4; i++) result->push_back(0); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 12, size)); + for (size_t i = 0; i < 12; i++) { + result->push_back(enc[pos++]); + } + } else if (command >= kCommandTypeStartFirst && + command < kCommandTypeStartFirst + kNumTypeStrings) { + AppendKeyword(*kTypeStrings[command - kCommandTypeStartFirst], result); + for (size_t i = 0; i < 4; i++) { + result->push_back(0); + } + } else { + return JXL_FAILURE("Unknown command"); + } + } + + if (pos != size) return JXL_FAILURE("Not all data used"); + if (result->size() != osize) return JXL_FAILURE("Invalid result size"); + + return true; +} + +Status ReadICC(BitReader* JXL_RESTRICT reader, PaddedBytes* JXL_RESTRICT icc, + size_t output_limit) { + icc->clear(); + const auto checkEndOfInput = [&]() -> Status { + if (reader->AllReadsWithinBounds()) return true; + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for reading ICC profile"); + }; + JXL_RETURN_IF_ERROR(checkEndOfInput()); + uint64_t enc_size 
= U64Coder::Read(reader); + if (enc_size > 268435456) { + // Avoid too large memory allocation for invalid file. + // TODO(lode): a more accurate limit would be the filesize of the JXL file, + // if we can have it available here. + return JXL_FAILURE("Too large encoded profile"); + } + PaddedBytes decompressed; + std::vector context_map; + ANSCode code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(reader, kNumICCContexts, &code, &context_map)); + ANSSymbolReader ans_reader(&code, reader); + size_t used_bits_base = reader->TotalBitsConsumed(); + size_t i = 0; + decompressed.resize(std::min(i + 0x400, enc_size)); + + for (; i < std::min(2, enc_size); i++) { + decompressed[i] = ans_reader.ReadHybridUint( + ICCANSContext(i, i > 0 ? decompressed[i - 1] : 0, + i > 1 ? decompressed[i - 2] : 0), + reader, context_map); + } + if (enc_size > kPreambleSize) { + for (; i < kPreambleSize; i++) { + decompressed[i] = ans_reader.ReadHybridUint( + ICCANSContext(i, decompressed[i - 1], decompressed[i - 2]), reader, + context_map); + } + JXL_RETURN_IF_ERROR(checkEndOfInput()); + JXL_RETURN_IF_ERROR(CheckPreamble(decompressed, enc_size, output_limit)); + } + for (; i < enc_size; i++) { + if ((i & 0x3FF) == 0) { + JXL_RETURN_IF_ERROR(checkEndOfInput()); + if ((i > 0) && (((i & 0xFFFF) == 0))) { + float used_bytes = + (reader->TotalBitsConsumed() - used_bits_base) / 8.0f; + if (i > used_bytes * 256) return JXL_FAILURE("Corrupted stream"); + } + decompressed.resize(std::min(i + 0x400, enc_size)); + } + JXL_DASSERT(i >= 2); + decompressed[i] = ans_reader.ReadHybridUint( + ICCANSContext(i, decompressed[i - 1], decompressed[i - 2]), reader, + context_map); + } + JXL_RETURN_IF_ERROR(checkEndOfInput()); + if (!ans_reader.CheckANSFinalState()) { + return JXL_FAILURE("Corrupted ICC profile"); + } + + JXL_RETURN_IF_ERROR( + UnpredictICC(decompressed.data(), decompressed.size(), icc)); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec.h b/third_party/jpeg-xl/lib/jxl/icc_codec.h new file mode 100644 index 000000000000..ec27eccfa36e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec.h @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ICC_CODEC_H_ +#define LIB_JXL_ICC_CODEC_H_ + +// Compressed representation of ICC profiles. + +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Should still be called if `icc.empty()` - if so, writes only 1 bit. +Status WriteICC(const PaddedBytes& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out); + +// `icc` may be empty afterwards - if so, call CreateProfile. Does not append, +// clears any original data that was in icc. 
+// If `output_limit` is not 0, then returns error if resulting profile would be +// longer than `output_limit` +Status ReadICC(BitReader* JXL_RESTRICT reader, PaddedBytes* JXL_RESTRICT icc, + size_t output_limit = 0); + +// Exposed only for testing +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result); + +// Exposed only for testing +Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result); + +} // namespace jxl + +#endif // LIB_JXL_ICC_CODEC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec_common.cc b/third_party/jpeg-xl/lib/jxl/icc_codec_common.cc new file mode 100644 index 000000000000..a1709de034a4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec_common.cc @@ -0,0 +1,201 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/icc_codec_common.h" + +#include + +#include +#include +#include + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/fields.h" + +namespace jxl { +namespace { +static uint8_t ByteKind1(uint8_t b) { + if ('a' <= b && b <= 'z') return 0; + if ('A' <= b && b <= 'Z') return 0; + if ('0' <= b && b <= '9') return 1; + if (b == '.' || b == ',') return 1; + if (b == 0) return 2; + if (b == 1) return 3; + if (b < 16) return 4; + if (b == 255) return 6; + if (b > 240) return 5; + return 7; +} + +static uint8_t ByteKind2(uint8_t b) { + if ('a' <= b && b <= 'z') return 0; + if ('A' <= b && b <= 'Z') return 0; + if ('0' <= b && b <= '9') return 1; + if (b == '.' || b == ',') return 1; + if (b < 16) return 2; + if (b > 240) return 3; + return 4; +} + +template +T PredictValue(T p1, T p2, T p3, int order) { + if (order == 0) return p1; + if (order == 1) return 2 * p1 - p2; + if (order == 2) return 3 * p1 - 3 * p2 + p3; + return 0; +} +} // namespace + +uint32_t DecodeUint32(const uint8_t* data, size_t size, size_t pos) { + return pos + 4 > size ? 0 : LoadBE32(data + pos); +} + +void EncodeUint32(size_t pos, uint32_t value, PaddedBytes* data) { + if (pos + 4 > data->size()) return; + StoreBE32(value, data->data() + pos); +} + +void AppendUint32(uint32_t value, PaddedBytes* data) { + data->resize(data->size() + 4); + EncodeUint32(data->size() - 4, value, data); +} + +typedef std::array Tag; + +Tag DecodeKeyword(const uint8_t* data, size_t size, size_t pos) { + if (pos + 4 > size) return {' ', ' ', ' ', ' '}; + return {data[pos], data[pos + 1], data[pos + 2], data[pos + 3]}; +} + +void EncodeKeyword(const Tag& keyword, uint8_t* data, size_t size, size_t pos) { + if (keyword.size() != 4 || pos + 3 >= size) return; + for (size_t i = 0; i < 4; ++i) data[pos + i] = keyword[i]; +} + +void AppendKeyword(const Tag& keyword, PaddedBytes* data) { + JXL_ASSERT(keyword.size() == 4); + data->append(keyword); +} + +// Checks if a + b > size, taking possible integer overflow into account. 
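A concrete case the overflow branch below catches, assuming 64-bit size_t: with a = SIZE_MAX - 8 and b = 16, the sum wraps around to pos = 7, which would pass the naive pos > size test on its own; the extra pos < a comparison exposes the wrap-around.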
+Status CheckOutOfBounds(size_t a, size_t b, size_t size) { + size_t pos = a + b; + if (pos > size) return JXL_FAILURE("Out of bounds"); + if (pos < a) return JXL_FAILURE("Out of bounds"); // overflow happened + return true; +} + +Status CheckIs32Bit(uint64_t v) { + static constexpr const uint64_t kUpper32 = ~static_cast(0xFFFFFFFF); + if ((v & kUpper32) != 0) return JXL_FAILURE("32-bit value expected"); + return true; +} + +PaddedBytes ICCInitialHeaderPrediction() { + PaddedBytes result(kICCHeaderSize); + for (size_t i = 0; i < kICCHeaderSize; i++) { + result[i] = 0; + } + result[8] = 4; + EncodeKeyword(kMntrTag, result.data(), result.size(), 12); + EncodeKeyword(kRgb_Tag, result.data(), result.size(), 16); + EncodeKeyword(kXyz_Tag, result.data(), result.size(), 20); + EncodeKeyword(kAcspTag, result.data(), result.size(), 36); + result[68] = 0; + result[69] = 0; + result[70] = 246; + result[71] = 214; + result[72] = 0; + result[73] = 1; + result[74] = 0; + result[75] = 0; + result[76] = 0; + result[77] = 0; + result[78] = 211; + result[79] = 45; + return result; +} + +void ICCPredictHeader(const uint8_t* icc, size_t size, uint8_t* header, + size_t pos) { + if (pos == 8 && size >= 8) { + header[80] = icc[4]; + header[81] = icc[5]; + header[82] = icc[6]; + header[83] = icc[7]; + } + if (pos == 41 && size >= 41) { + if (icc[40] == 'A') { + header[41] = 'P'; + header[42] = 'P'; + header[43] = 'L'; + } + if (icc[40] == 'M') { + header[41] = 'S'; + header[42] = 'F'; + header[43] = 'T'; + } + } + if (pos == 42 && size >= 42) { + if (icc[40] == 'S' && icc[41] == 'G') { + header[42] = 'I'; + header[43] = ' '; + } + if (icc[40] == 'S' && icc[41] == 'U') { + header[42] = 'N'; + header[43] = 'W'; + } + } +} + +// Predicts a value with linear prediction of given order (0-2), for integers +// with width bytes and given stride in bytes between values. +// The start position is at start + i, and the relevant modulus of i describes +// which byte of the multi-byte integer is being handled. +// The value start + i must be at least stride * 4. +uint8_t LinearPredictICCValue(const uint8_t* data, size_t start, size_t i, + size_t stride, size_t width, int order) { + size_t pos = start + i; + if (width == 1) { + uint8_t p1 = data[pos - stride]; + uint8_t p2 = data[pos - stride * 2]; + uint8_t p3 = data[pos - stride * 3]; + return PredictValue(p1, p2, p3, order); + } else if (width == 2) { + size_t p = start + (i & ~1); + uint16_t p1 = (data[p - stride * 1] << 8) + data[p - stride * 1 + 1]; + uint16_t p2 = (data[p - stride * 2] << 8) + data[p - stride * 2 + 1]; + uint16_t p3 = (data[p - stride * 3] << 8) + data[p - stride * 3 + 1]; + uint16_t pred = PredictValue(p1, p2, p3, order); + return (i & 1) ? 
(pred & 255) : ((pred >> 8) & 255); + } else { + size_t p = start + (i & ~3); + uint32_t p1 = DecodeUint32(data, pos, p - stride); + uint32_t p2 = DecodeUint32(data, pos, p - stride * 2); + uint32_t p3 = DecodeUint32(data, pos, p - stride * 3); + uint32_t pred = PredictValue(p1, p2, p3, order); + unsigned shiftbytes = 3 - (i & 3); + return (pred >> (shiftbytes * 8)) & 255; + } +} + +size_t ICCANSContext(size_t i, size_t b1, size_t b2) { + if (i <= 128) return 0; + return 1 + ByteKind1(b1) + ByteKind2(b2) * 8; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec_common.h b/third_party/jpeg-xl/lib/jxl/icc_codec_common.h new file mode 100644 index 000000000000..b50246a353ec --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec_common.h @@ -0,0 +1,115 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_ICC_CODEC_COMMON_H_ +#define LIB_JXL_ICC_CODEC_COMMON_H_ + +// Compressed representation of ICC profiles. + +#include +#include + +#include + +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +static constexpr size_t kICCHeaderSize = 128; + +typedef std::array Tag; + +static const Tag kAcspTag = {'a', 'c', 's', 'p'}; +static const Tag kBkptTag = {'b', 'k', 'p', 't'}; +static const Tag kBtrcTag = {'b', 'T', 'R', 'C'}; +static const Tag kBxyzTag = {'b', 'X', 'Y', 'Z'}; +static const Tag kChadTag = {'c', 'h', 'a', 'd'}; +static const Tag kChrmTag = {'c', 'h', 'r', 'm'}; +static const Tag kCprtTag = {'c', 'p', 'r', 't'}; +static const Tag kCurvTag = {'c', 'u', 'r', 'v'}; +static const Tag kDescTag = {'d', 'e', 's', 'c'}; +static const Tag kDmddTag = {'d', 'm', 'd', 'd'}; +static const Tag kDmndTag = {'d', 'm', 'n', 'd'}; +static const Tag kGbd_Tag = {'g', 'b', 'd', ' '}; +static const Tag kGtrcTag = {'g', 'T', 'R', 'C'}; +static const Tag kGxyzTag = {'g', 'X', 'Y', 'Z'}; +static const Tag kKtrcTag = {'k', 'T', 'R', 'C'}; +static const Tag kKxyzTag = {'k', 'X', 'Y', 'Z'}; +static const Tag kLumiTag = {'l', 'u', 'm', 'i'}; +static const Tag kMab_Tag = {'m', 'A', 'B', ' '}; +static const Tag kMba_Tag = {'m', 'B', 'A', ' '}; +static const Tag kMlucTag = {'m', 'l', 'u', 'c'}; +static const Tag kMntrTag = {'m', 'n', 't', 'r'}; +static const Tag kParaTag = {'p', 'a', 'r', 'a'}; +static const Tag kRgb_Tag = {'R', 'G', 'B', ' '}; +static const Tag kRtrcTag = {'r', 'T', 'R', 'C'}; +static const Tag kRxyzTag = {'r', 'X', 'Y', 'Z'}; +static const Tag kSf32Tag = {'s', 'f', '3', '2'}; +static const Tag kTextTag = {'t', 'e', 'x', 't'}; +static const Tag kVcgtTag = {'v', 'c', 'g', 't'}; +static const Tag kWtptTag = {'w', 't', 'p', 't'}; +static const Tag kXyz_Tag = {'X', 'Y', 'Z', ' '}; + +// Tag names focused on RGB and GRAY monitor profiles +static constexpr size_t kNumTagStrings = 17; +static constexpr const Tag* kTagStrings[kNumTagStrings] = { + &kCprtTag, &kWtptTag, &kBkptTag, &kRxyzTag, &kGxyzTag, &kBxyzTag, + &kKxyzTag, &kRtrcTag, &kGtrcTag, &kBtrcTag, &kKtrcTag, &kChadTag, 
+ &kDescTag, &kChrmTag, &kDmndTag, &kDmddTag, &kLumiTag}; + +static constexpr size_t kCommandTagUnknown = 1; +static constexpr size_t kCommandTagTRC = 2; +static constexpr size_t kCommandTagXYZ = 3; +static constexpr size_t kCommandTagStringFirst = 4; + +// Tag types focused on RGB and GRAY monitor profiles +static constexpr size_t kNumTypeStrings = 8; +static constexpr const Tag* kTypeStrings[kNumTypeStrings] = { + &kXyz_Tag, &kDescTag, &kTextTag, &kMlucTag, + &kParaTag, &kCurvTag, &kSf32Tag, &kGbd_Tag}; + +static constexpr size_t kCommandInsert = 1; +static constexpr size_t kCommandShuffle2 = 2; +static constexpr size_t kCommandShuffle4 = 3; +static constexpr size_t kCommandPredict = 4; +static constexpr size_t kCommandXYZ = 10; +static constexpr size_t kCommandTypeStartFirst = 16; + +static constexpr size_t kFlagBitOffset = 64; +static constexpr size_t kFlagBitSize = 128; + +static constexpr size_t kNumICCContexts = 41; + +uint32_t DecodeUint32(const uint8_t* data, size_t size, size_t pos); +void EncodeUint32(size_t pos, uint32_t value, PaddedBytes* data); +void AppendUint32(uint32_t value, PaddedBytes* data); +Tag DecodeKeyword(const uint8_t* data, size_t size, size_t pos); +void EncodeKeyword(const Tag& keyword, uint8_t* data, size_t size, size_t pos); +void AppendKeyword(const Tag& keyword, PaddedBytes* data); + +// Checks if a + b > size, taking possible integer overflow into account. +Status CheckOutOfBounds(size_t a, size_t b, size_t size); +Status CheckIs32Bit(uint64_t v); + +PaddedBytes ICCInitialHeaderPrediction(); +void ICCPredictHeader(const uint8_t* icc, size_t size, uint8_t* header, + size_t pos); +uint8_t LinearPredictICCValue(const uint8_t* data, size_t start, size_t i, + size_t stride, size_t width, int order); +size_t ICCANSContext(size_t i, size_t b1, size_t b2); + +} // namespace jxl + +#endif // LIB_JXL_ICC_CODEC_COMMON_H_ diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec_test.cc b/third_party/jpeg-xl/lib/jxl/icc_codec_test.cc new file mode 100644 index 000000000000..36adf49e334a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec_test.cc @@ -0,0 +1,216 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
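For illustration, a hedged sketch of the command-byte layout these constants imply (ParseTagCommandSketch is a hypothetical name; UnpredictICC in icc_codec.cc performs this decomposition inline):

    #include "lib/jxl/icc_codec_common.h"

    struct TagCommandSketch {
      uint8_t tagcode;   // 1 = literal keyword, 2 = TRC, 3 = XYZ,
                         // 4..20 = index into kTagStrings
      bool has_offset;   // an explicit tag-offset varint follows
      bool has_size;     // an explicit tag-size varint follows
    };

    TagCommandSketch ParseTagCommandSketch(uint8_t command) {
      TagCommandSketch c;
      c.tagcode = command & 63;
      c.has_offset = (command & jxl::kFlagBitOffset) != 0;  // bit 6
      c.has_size = (command & jxl::kFlagBitSize) != 0;      // bit 7
      return c;
    }

When has_offset or has_size is unset, UnpredictICC falls back to prevtagstart + prevtagsize and prevtagsize respectively (or an implied 20 for the fixed-size XYZ-type tags), so consecutive tags laid out contiguously cost no extra bytes.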
+ +#include "lib/jxl/icc_codec.h" + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/enc_icc_codec.h" + +namespace jxl { +namespace { + +void TestProfile(const PaddedBytes& icc) { + BitWriter writer; + ASSERT_TRUE(WriteICC(icc, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + PaddedBytes dec; + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(ReadICC(&reader, &dec)); + ASSERT_TRUE(reader.Close()); + EXPECT_EQ(icc.size(), dec.size()); + if (icc.size() == dec.size()) { + for (size_t i = 0; i < icc.size(); i++) { + EXPECT_EQ(icc[i], dec[i]); + if (icc[i] != dec[i]) break; // One output is enough + } + } +} + +void TestProfile(const std::string& icc) { + PaddedBytes bytes(icc.size()); + for (size_t i = 0; i < icc.size(); i++) { + bytes[i] = icc[i]; + } + TestProfile(bytes); +} + +// Valid profile from one of the images output by the decoder. +static const unsigned char kTestProfile[] = { + 0x00, 0x00, 0x03, 0x80, 0x6c, 0x63, 0x6d, 0x73, 0x04, 0x30, 0x00, 0x00, + 0x6d, 0x6e, 0x74, 0x72, 0x52, 0x47, 0x42, 0x20, 0x58, 0x59, 0x5a, 0x20, + 0x07, 0xe3, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x0f, 0x00, 0x32, 0x00, 0x2e, + 0x61, 0x63, 0x73, 0x70, 0x41, 0x50, 0x50, 0x4c, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xf6, 0xd6, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d, 0x6c, 0x63, 0x6d, 0x73, + 0x5f, 0x07, 0x0d, 0x3e, 0x4d, 0x32, 0xf2, 0x6e, 0x5d, 0x77, 0x26, 0xcc, + 0x23, 0xb0, 0x6a, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, + 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x01, 0x20, 0x00, 0x00, 0x00, 0x42, + 0x63, 0x70, 0x72, 0x74, 0x00, 0x00, 0x01, 0x64, 0x00, 0x00, 0x01, 0x00, + 0x77, 0x74, 0x70, 0x74, 0x00, 0x00, 0x02, 0x64, 0x00, 0x00, 0x00, 0x14, + 0x63, 0x68, 0x61, 0x64, 0x00, 0x00, 0x02, 0x78, 0x00, 0x00, 0x00, 0x2c, + 0x72, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xa4, 0x00, 0x00, 0x00, 0x14, + 0x62, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xb8, 0x00, 0x00, 0x00, 0x14, + 0x67, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xcc, 0x00, 0x00, 0x00, 0x14, + 0x72, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x67, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x62, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x63, 0x68, 0x72, 0x6d, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x24, + 0x64, 0x6d, 0x6e, 0x64, 0x00, 0x00, 0x03, 0x24, 0x00, 0x00, 0x00, 0x28, + 0x64, 0x6d, 0x64, 0x64, 0x00, 0x00, 0x03, 0x4c, 0x00, 0x00, 0x00, 0x32, + 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0x26, + 0x00, 0x00, 0x00, 0x1c, 0x00, 0x52, 0x00, 0x47, 0x00, 0x42, 0x00, 0x5f, + 0x00, 0x44, 0x00, 0x36, 0x00, 0x35, 0x00, 0x5f, 0x00, 0x53, 0x00, 0x52, + 0x00, 0x47, 0x00, 0x5f, 0x00, 0x52, 0x00, 0x65, 0x00, 0x6c, 0x00, 0x5f, + 0x00, 0x37, 0x00, 0x30, 0x00, 0x39, 0x00, 0x00, 0x6d, 0x6c, 0x75, 0x63, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, + 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x00, 0x00, 0x1c, + 0x00, 0x43, 0x00, 0x6f, 0x00, 0x70, 0x00, 0x79, 0x00, 0x72, 0x00, 0x69, + 0x00, 0x67, 0x00, 0x68, 0x00, 0x74, 0x00, 0x20, 0x00, 0x32, 0x00, 0x30, + 0x00, 0x31, 0x00, 0x38, 0x00, 0x20, 0x00, 0x47, 0x00, 0x6f, 0x00, 0x6f, + 0x00, 0x67, 0x00, 
0x6c, 0x00, 0x65, 0x00, 0x20, 0x00, 0x4c, 0x00, 0x4c, + 0x00, 0x43, 0x00, 0x2c, 0x00, 0x20, 0x00, 0x43, 0x00, 0x43, 0x00, 0x2d, + 0x00, 0x42, 0x00, 0x59, 0x00, 0x2d, 0x00, 0x53, 0x00, 0x41, 0x00, 0x20, + 0x00, 0x33, 0x00, 0x2e, 0x00, 0x30, 0x00, 0x20, 0x00, 0x55, 0x00, 0x6e, + 0x00, 0x70, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x74, 0x00, 0x65, 0x00, 0x64, + 0x00, 0x20, 0x00, 0x6c, 0x00, 0x69, 0x00, 0x63, 0x00, 0x65, 0x00, 0x6e, + 0x00, 0x73, 0x00, 0x65, 0x00, 0x28, 0x00, 0x68, 0x00, 0x74, 0x00, 0x74, + 0x00, 0x70, 0x00, 0x73, 0x00, 0x3a, 0x00, 0x2f, 0x00, 0x2f, 0x00, 0x63, + 0x00, 0x72, 0x00, 0x65, 0x00, 0x61, 0x00, 0x74, 0x00, 0x69, 0x00, 0x76, + 0x00, 0x65, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x6d, 0x00, 0x6f, + 0x00, 0x6e, 0x00, 0x73, 0x00, 0x2e, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x67, + 0x00, 0x2f, 0x00, 0x6c, 0x00, 0x69, 0x00, 0x63, 0x00, 0x65, 0x00, 0x6e, + 0x00, 0x73, 0x00, 0x65, 0x00, 0x73, 0x00, 0x2f, 0x00, 0x62, 0x00, 0x79, + 0x00, 0x2d, 0x00, 0x73, 0x00, 0x61, 0x00, 0x2f, 0x00, 0x33, 0x00, 0x2e, + 0x00, 0x30, 0x00, 0x2f, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x67, 0x00, 0x61, + 0x00, 0x6c, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x64, 0x00, 0x65, 0x00, 0x29, + 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf6, 0xd6, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d, 0x73, 0x66, 0x33, 0x32, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x0c, 0x42, 0x00, 0x00, 0x05, 0xde, + 0xff, 0xff, 0xf3, 0x25, 0x00, 0x00, 0x07, 0x93, 0x00, 0x00, 0xfd, 0x90, + 0xff, 0xff, 0xfb, 0xa1, 0xff, 0xff, 0xfd, 0xa2, 0x00, 0x00, 0x03, 0xdc, + 0x00, 0x00, 0xc0, 0x6e, 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x6f, 0xa0, 0x00, 0x00, 0x38, 0xf5, 0x00, 0x00, 0x03, 0x90, + 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x24, 0x9f, + 0x00, 0x00, 0x0f, 0x84, 0x00, 0x00, 0xb6, 0xc4, 0x58, 0x59, 0x5a, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0x97, 0x00, 0x00, 0xb7, 0x87, + 0x00, 0x00, 0x18, 0xd9, 0x70, 0x61, 0x72, 0x61, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x38, 0xe4, 0x00, 0x00, 0xe8, 0xf0, + 0x00, 0x00, 0x17, 0x10, 0x00, 0x00, 0x38, 0xe4, 0x00, 0x00, 0x14, 0xbc, + 0x63, 0x68, 0x72, 0x6d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x00, 0x00, 0xa3, 0xd7, 0x00, 0x00, 0x54, 0x7c, 0x00, 0x00, 0x4c, 0xcd, + 0x00, 0x00, 0x99, 0x9a, 0x00, 0x00, 0x26, 0x67, 0x00, 0x00, 0x0f, 0x5c, + 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0x0c, + 0x00, 0x00, 0x00, 0x1c, 0x00, 0x47, 0x00, 0x6f, 0x00, 0x6f, 0x00, 0x67, + 0x00, 0x6c, 0x00, 0x65, 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x49, 0x00, 0x6d, + 0x00, 0x61, 0x00, 0x67, 0x00, 0x65, 0x00, 0x20, 0x00, 0x63, 0x00, 0x6f, + 0x00, 0x64, 0x00, 0x65, 0x00, 0x63, 0x00, 0x00, +}; + +} // namespace + +TEST(IccCodecTest, Icc) { + // Empty string cannot be tested, encoder checks against writing it. 
+ TestProfile("a"); + TestProfile("ab"); + TestProfile("aaaa"); + + { + // Exactly the ICC header size + PaddedBytes profile(128); + for (size_t i = 0; i < 128; i++) { + profile[i] = 0; + } + TestProfile(profile); + } + + { + PaddedBytes profile; + profile.append(kTestProfile, kTestProfile + sizeof(kTestProfile)); + TestProfile(profile); + } + + // Test substrings of full profile + { + PaddedBytes profile; + for (size_t i = 0; i <= 256; i++) { + profile.push_back(kTestProfile[i]); + TestProfile(profile); + } + } +} + +// kTestProfile after encoding with the ICC codec +static const unsigned char kEncodedTestProfile[] = { + 0x1f, 0x8b, 0x1, 0x13, 0x10, 0x0, 0x0, 0x0, 0x20, 0x4c, 0xcc, 0x3, + 0xe7, 0xa0, 0xa5, 0xa2, 0x90, 0xa4, 0x27, 0xe8, 0x79, 0x1d, 0xe3, 0x26, + 0x57, 0x54, 0xef, 0x0, 0xe8, 0x97, 0x2, 0xce, 0xa1, 0xd7, 0x85, 0x16, + 0xb4, 0x29, 0x94, 0x58, 0xf2, 0x56, 0xc0, 0x76, 0xea, 0x23, 0xec, 0x7c, + 0x73, 0x51, 0x41, 0x40, 0x23, 0x21, 0x95, 0x4, 0x75, 0x12, 0xc9, 0xcc, + 0x16, 0xbd, 0xb6, 0x99, 0xad, 0xf8, 0x75, 0x35, 0xb6, 0x42, 0xae, 0xae, + 0xae, 0x86, 0x56, 0xf8, 0xcc, 0x16, 0x30, 0xb3, 0x45, 0xad, 0xd, 0x40, + 0xd6, 0xd1, 0xd6, 0x99, 0x40, 0xbe, 0xe2, 0xdc, 0x31, 0x7, 0xa6, 0xb9, + 0x27, 0x92, 0x38, 0x0, 0x3, 0x5e, 0x2c, 0xbe, 0xe6, 0xfb, 0x19, 0xbf, + 0xf3, 0x6d, 0xbc, 0x4d, 0x64, 0xe5, 0xba, 0x76, 0xde, 0x31, 0x65, 0x66, + 0x14, 0xa6, 0x3a, 0xc5, 0x8f, 0xb1, 0xb4, 0xba, 0x1f, 0xb1, 0xb8, 0xd4, + 0x75, 0xba, 0x18, 0x86, 0x95, 0x3c, 0x26, 0xf6, 0x25, 0x62, 0x53, 0xfd, + 0x9c, 0x94, 0x76, 0xf6, 0x95, 0x2c, 0xb1, 0xfd, 0xdc, 0xc0, 0xe4, 0x3f, + 0xb3, 0xff, 0x67, 0xde, 0xd5, 0x94, 0xcc, 0xb0, 0x83, 0x2f, 0x28, 0x93, + 0x92, 0x3, 0xa1, 0x41, 0x64, 0x60, 0x62, 0x70, 0x80, 0x87, 0xaf, 0xe7, + 0x60, 0x4a, 0x20, 0x23, 0xb3, 0x11, 0x7, 0x38, 0x38, 0xd4, 0xa, 0x66, + 0xb5, 0x93, 0x41, 0x90, 0x19, 0x17, 0x18, 0x60, 0xa5, 0xb, 0x7a, 0x24, + 0xaa, 0x20, 0x81, 0xac, 0xa9, 0xa1, 0x70, 0xa6, 0x12, 0x8a, 0x4a, 0xa3, + 0xa0, 0xf9, 0x9a, 0x97, 0xe7, 0xa8, 0xac, 0x8, 0xa8, 0xc4, 0x2a, 0x86, + 0xa7, 0x69, 0x1e, 0x67, 0xe6, 0xbe, 0xa4, 0xd3, 0xff, 0x91, 0x61, 0xf6, + 0x8a, 0xe6, 0xb5, 0xb3, 0x61, 0x9f, 0x19, 0x17, 0x98, 0x27, 0x6b, 0xe9, + 0x8, 0x98, 0xe1, 0x21, 0x4a, 0x9, 0xb5, 0xd7, 0xca, 0xfa, 0x94, 0xd0, + 0x69, 0x1a, 0xeb, 0x52, 0x1, 0x4e, 0xf5, 0xf6, 0xdf, 0x7f, 0xe7, 0x29, + 0x70, 0xee, 0x4, 0xda, 0x2f, 0xa4, 0xff, 0xfe, 0xbb, 0x6f, 0xa8, 0xff, + 0xfe, 0xdb, 0xaf, 0x8, 0xf6, 0x72, 0xa1, 0x40, 0x5d, 0xf0, 0x2d, 0x8, + 0x82, 0x5b, 0x87, 0xbd, 0x10, 0x8, 0xe9, 0x7, 0xee, 0x4b, 0x80, 0xda, + 0x4a, 0x4, 0xc5, 0x5e, 0xa0, 0xb7, 0x1e, 0x60, 0xb0, 0x59, 0x76, 0x60, + 0xb, 0x2e, 0x19, 0x8a, 0x2e, 0x1c, 0xe6, 0x6, 0x20, 0xb8, 0x64, 0x18, + 0x2a, 0xcf, 0x51, 0x94, 0xd4, 0xee, 0xc3, 0xfe, 0x39, 0x74, 0xd4, 0x2b, + 0x48, 0xc9, 0x83, 0x4c, 0x9b, 0xd0, 0x4c, 0x35, 0x10, 0xe3, 0x9, 0xf7, + 0x72, 0xf0, 0x7a, 0xe, 0xbf, 0x7d, 0x36, 0x2e, 0x19, 0x7e, 0x3f, 0xc, + 0xf7, 0x93, 0xe7, 0xf4, 0x1d, 0x32, 0xc6, 0xb0, 0x89, 0xad, 0xe0, 0x28, + 0xc1, 0xa7, 0x59, 0xe3, 0x0, +}; + +// Tests that the decoded kEncodedTestProfile matches kTestProfile. 
+TEST(IccCodecTest, EncodedIccProfile) { + jxl::BitReader reader(jxl::Span(kEncodedTestProfile, + sizeof(kEncodedTestProfile))); + jxl::PaddedBytes dec; + ASSERT_TRUE(ReadICC(&reader, &dec)); + ASSERT_TRUE(reader.Close()); + EXPECT_EQ(sizeof(kTestProfile), dec.size()); + if (sizeof(kTestProfile) == dec.size()) { + for (size_t i = 0; i < dec.size(); i++) { + EXPECT_EQ(kTestProfile[i], dec[i]); + if (kTestProfile[i] != dec[i]) break; // One output is enough + } + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/image.cc b/third_party/jpeg-xl/lib/jxl/image.cc new file mode 100644 index 000000000000..7a719a05a55f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image.cc @@ -0,0 +1,304 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/image.h" + +#include // swap + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/image.cc" +#include +#include + +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { + +namespace HWY_NAMESPACE { +size_t GetVectorSize() { return HWY_LANES(uint8_t); } +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { + +HWY_EXPORT(GetVectorSize); // Local function. + +size_t VectorSize() { + static size_t bytes = HWY_DYNAMIC_DISPATCH(GetVectorSize)(); + return bytes; +} + +// Returns distance [bytes] between the start of two consecutive rows, a +// multiple of vector/cache line size but NOT CacheAligned::kAlias - see below. +size_t BytesPerRow(const size_t xsize, const size_t sizeof_t) { + const size_t vec_size = VectorSize(); + size_t valid_bytes = xsize * sizeof_t; + + // Allow unaligned accesses starting at the last valid value - this may raise + // msan errors unless the user calls InitializePaddingForUnalignedAccesses. + // Skip for the scalar case because no extra lanes will be loaded. + if (vec_size != 0) { + valid_bytes += vec_size - sizeof_t; + } + + // Round up to vector and cache line size. + const size_t align = std::max(vec_size, CacheAligned::kAlignment); + size_t bytes_per_row = RoundUpTo(valid_bytes, align); + + // During the lengthy window before writes are committed to memory, CPUs + // guard against read after write hazards by checking the address, but + // only the lower 11 bits. We avoid a false dependency between writes to + // consecutive rows by ensuring their sizes are not multiples of 2 KiB. + // Avoid2K prevents the same problem for the planes of an Image3. 
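A worked instance of the arithmetic above, with an assumed 32-byte vector (the real width is detected at runtime) and the 64-byte alignment / 2 KiB alias constants described in the comments:

    xsize = 500, sizeof_t = 4, vec_size = 32
    valid_bytes   = 500 * 4 + (32 - 4)   = 2028  // slack for unaligned loads
    align         = max(32, 64)          = 64
    bytes_per_row = RoundUpTo(2028, 64)  = 2048  // a multiple of kAlias
    after fix-up  = 2048 + 64            = 2112  // no 2 KiB aliasing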
+ if (bytes_per_row % CacheAligned::kAlias == 0) { + bytes_per_row += align; + } + + JXL_ASSERT(bytes_per_row % align == 0); + return bytes_per_row; +} + +} // namespace + +PlaneBase::PlaneBase(const size_t xsize, const size_t ysize, + const size_t sizeof_t) + : xsize_(static_cast(xsize)), + ysize_(static_cast(ysize)), + orig_xsize_(static_cast(xsize)), + orig_ysize_(static_cast(ysize)) { + // (Can't profile CacheAligned itself because it is used by profiler.h) + PROFILER_FUNC; + + JXL_CHECK(xsize == xsize_); + JXL_CHECK(ysize == ysize_); + + JXL_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8); + + bytes_per_row_ = 0; + // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate + // if nonzero, because "zero" bytes still have padding/bookkeeping overhead. + if (xsize != 0 && ysize != 0) { + bytes_per_row_ = BytesPerRow(xsize, sizeof_t); + bytes_ = AllocateArray(bytes_per_row_ * ysize); + JXL_CHECK(bytes_.get()); + InitializePadding(sizeof_t, Padding::kRoundUp); + } +} + +void PlaneBase::InitializePadding(const size_t sizeof_t, Padding padding) { +#if defined(MEMORY_SANITIZER) || HWY_IDE + if (xsize_ == 0 || ysize_ == 0) return; + + const size_t vec_size = VectorSize(); + if (vec_size == 0) return; // Scalar mode: no padding needed + + const size_t valid_size = xsize_ * sizeof_t; + const size_t initialize_size = padding == Padding::kRoundUp + ? RoundUpTo(valid_size, vec_size) + : valid_size + vec_size - sizeof_t; + if (valid_size == initialize_size) return; + + for (size_t y = 0; y < ysize_; ++y) { + uint8_t* JXL_RESTRICT row = static_cast(VoidRow(y)); +#if defined(__clang__) && (__clang_major__ <= 6) + // There's a bug in msan in clang-6 when handling AVX2 operations. This + // workaround allows tests to pass on msan, although it is slower and + // prevents msan warnings from uninitialized images. + memset(row, 0, initialize_size); +#else + memset(row + valid_size, 0, initialize_size - valid_size); +#endif // clang6 + } +#endif // MEMORY_SANITIZER +} + +void PlaneBase::Swap(PlaneBase& other) { + std::swap(xsize_, other.xsize_); + std::swap(ysize_, other.ysize_); + std::swap(orig_xsize_, other.orig_xsize_); + std::swap(orig_ysize_, other.orig_ysize_); + std::swap(bytes_per_row_, other.bytes_per_row_); + std::swap(bytes_, other.bytes_); +} + +ImageB ImageFromPacked(const uint8_t* packed, const size_t xsize, + const size_t ysize, const size_t bytes_per_row) { + JXL_ASSERT(bytes_per_row >= xsize); + ImageB image(xsize, ysize); + PROFILER_FUNC; + for (size_t y = 0; y < ysize; ++y) { + uint8_t* const JXL_RESTRICT row = image.Row(y); + const uint8_t* const JXL_RESTRICT packed_row = packed + y * bytes_per_row; + memcpy(row, packed_row, xsize); + } + return image; +} + +// Note that using mirroring here gives slightly worse results. +ImageF PadImage(const ImageF& in, const size_t xsize, const size_t ysize) { + JXL_ASSERT(xsize >= in.xsize()); + JXL_ASSERT(ysize >= in.ysize()); + ImageF out(xsize, ysize); + size_t y = 0; + for (; y < in.ysize(); ++y) { + const float* JXL_RESTRICT row_in = in.ConstRow(y); + float* JXL_RESTRICT row_out = out.Row(y); + memcpy(row_out, row_in, in.xsize() * sizeof(row_in[0])); + const int lastcol = in.xsize() - 1; + const float lastval = row_out[lastcol]; + for (size_t x = in.xsize(); x < xsize; ++x) { + row_out[x] = lastval; + } + } + + // TODO(janwas): no need to copy if we can 'extend' image: if rows are + // pointers to any memory? Or allocate larger image before IO? 
+ const int lastrow = in.ysize() - 1; + for (; y < ysize; ++y) { + const float* JXL_RESTRICT row_in = out.ConstRow(lastrow); + float* JXL_RESTRICT row_out = out.Row(y); + memcpy(row_out, row_in, xsize * sizeof(row_out[0])); + } + return out; +} + +Image3F PadImageMirror(const Image3F& in, const size_t xborder, + const size_t yborder) { + size_t xsize = in.xsize(); + size_t ysize = in.ysize(); + Image3F out(xsize + 2 * xborder, ysize + 2 * yborder); + if (xborder > xsize || yborder > ysize) { + for (size_t c = 0; c < 3; c++) { + for (int32_t y = 0; y < static_cast(out.ysize()); y++) { + float* row_out = out.PlaneRow(c, y); + const float* row_in = in.PlaneRow( + c, Mirror(y - static_cast(yborder), in.ysize())); + for (int32_t x = 0; x < static_cast(out.xsize()); x++) { + int32_t xin = Mirror(x - static_cast(xborder), in.xsize()); + row_out[x] = row_in[xin]; + } + } + } + return out; + } + CopyImageTo(in, Rect(xborder, yborder, xsize, ysize), &out); + for (size_t c = 0; c < 3; c++) { + // Horizontal pad. + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xborder; x++) { + out.PlaneRow(c, y + yborder)[x] = + in.ConstPlaneRow(c, y)[xborder - x - 1]; + out.PlaneRow(c, y + yborder)[x + xsize + xborder] = + in.ConstPlaneRow(c, y)[xsize - 1 - x]; + } + } + // Vertical pad. + for (size_t y = 0; y < yborder; y++) { + memcpy(out.PlaneRow(c, y), out.ConstPlaneRow(c, 2 * yborder - 1 - y), + out.xsize() * sizeof(float)); + memcpy(out.PlaneRow(c, y + ysize + yborder), + out.ConstPlaneRow(c, ysize + yborder - 1 - y), + out.xsize() * sizeof(float)); + } + } + return out; +} + +Image3F PadImageToMultiple(const Image3F& in, const size_t N) { + PROFILER_FUNC; + const size_t xsize_blocks = DivCeil(in.xsize(), N); + const size_t ysize_blocks = DivCeil(in.ysize(), N); + const size_t xsize = N * xsize_blocks; + const size_t ysize = N * ysize_blocks; + ImageF out[3]; + for (size_t c = 0; c < 3; ++c) { + out[c] = PadImage(in.Plane(c), xsize, ysize); + } + return Image3F(std::move(out[0]), std::move(out[1]), std::move(out[2])); +} + +void PadImageToBlockMultipleInPlace(Image3F* JXL_RESTRICT in) { + PROFILER_FUNC; + const size_t xsize_orig = in->xsize(); + const size_t ysize_orig = in->ysize(); + const size_t xsize = RoundUpToBlockDim(xsize_orig); + const size_t ysize = RoundUpToBlockDim(ysize_orig); + // Expands image size to the originally-allocated size. + in->ShrinkTo(xsize, ysize); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize_orig; y++) { + float* JXL_RESTRICT row = in->PlaneRow(c, y); + for (size_t x = xsize_orig; x < xsize; x++) { + row[x] = row[xsize_orig - 1]; + } + } + const float* JXL_RESTRICT row_src = in->ConstPlaneRow(c, ysize_orig - 1); + for (size_t y = ysize_orig; y < ysize; y++) { + memcpy(in->PlaneRow(c, y), row_src, xsize * sizeof(float)); + } + } +} + +float DotProduct(const ImageF& a, const ImageF& b) { + double sum = 0.0; + for (size_t y = 0; y < a.ysize(); ++y) { + const float* const JXL_RESTRICT row_a = a.ConstRow(y); + const float* const JXL_RESTRICT row_b = b.ConstRow(y); + for (size_t x = 0; x < a.xsize(); ++x) { + sum += row_a[x] * row_b[x]; + } + } + return sum; +} + +void DownsampleImage(Image3F* opsin, size_t factor) { + JXL_ASSERT(factor != 1); + // Allocate extra space to avoid a reallocation when padding. 
+ Image3F downsampled(DivCeil(opsin->xsize(), factor) + kBlockDim, + DivCeil(opsin->ysize(), factor) + kBlockDim); + downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, + downsampled.ysize() - kBlockDim); + size_t in_stride = opsin->PixelsPerRow(); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < downsampled.ysize(); y++) { + float* row_out = downsampled.PlaneRow(c, y); + const float* row_in = opsin->PlaneRow(c, factor * y); + for (size_t x = 0; x < downsampled.xsize(); x++) { + size_t cnt = 0; + float sum = 0; + for (size_t iy = 0; iy < factor && iy + factor * y < opsin->ysize(); + iy++) { + for (size_t ix = 0; ix < factor && ix + factor * x < opsin->xsize(); + ix++) { + sum += row_in[iy * in_stride + x * factor + ix]; + cnt++; + } + } + row_out[x] = sum / cnt; + } + } + } + *opsin = std::move(downsampled); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/image.h b/third_party/jpeg-xl/lib/jxl/image.h new file mode 100644 index 000000000000..b2c64aae3c34 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image.h @@ -0,0 +1,437 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_IMAGE_H_ +#define LIB_JXL_IMAGE_H_ + +// SIMD/multicore-friendly planar image representation with row accessors. + +#include +#include +#include + +#include +#include // std::move + +#include "lib/jxl/base/cache_aligned.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Type-independent parts of Plane<> - reduces code duplication and facilitates +// moving member function implementations to cc file. +struct PlaneBase { + PlaneBase() + : xsize_(0), + ysize_(0), + orig_xsize_(0), + orig_ysize_(0), + bytes_per_row_(0), + bytes_(nullptr) {} + PlaneBase(size_t xsize, size_t ysize, size_t sizeof_t); + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo() instead. + PlaneBase(const PlaneBase& other) = delete; + PlaneBase& operator=(const PlaneBase& other) = delete; + + // Move constructor (required for returning Image from function) + PlaneBase(PlaneBase&& other) noexcept = default; + + // Move assignment (required for std::vector) + PlaneBase& operator=(PlaneBase&& other) noexcept = default; + + void Swap(PlaneBase& other); + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. May also be used to + // un-shrink the image. Caller is responsible for ensuring xsize/ysize are <= + // the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + JXL_CHECK(xsize <= orig_xsize_); + JXL_CHECK(ysize <= orig_ysize_); + xsize_ = static_cast(xsize); + ysize_ = static_cast(ysize); + // NOTE: we can't recompute bytes_per_row for more compact storage and + // better locality because that would invalidate the image contents. + } + + // How many pixels. 
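The ShrinkTo contract above is what enables the pre-allocate/shrink idiom used by DownsampleImage and PadImageToBlockMultipleInPlace in image.cc; a minimal usage sketch (sizes illustrative):

    #include "lib/jxl/image.h"

    void ShrinkIdiomSketch() {
      jxl::ImageF plane(256 + 8, 256 + 8);  // allocate headroom up front
      plane.ShrinkTo(250, 250);             // expose only the valid region
      // ... produce and consume the 250 x 250 pixels ...
      plane.ShrinkTo(256, 256);  // "un-shrink" up to the original
    }                            // allocation, without reallocating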
+ JXL_INLINE size_t xsize() const { return xsize_; } + JXL_INLINE size_t ysize() const { return ysize_; } + + // NOTE: do not use this for copying rows - the valid xsize may be much less. + JXL_INLINE size_t bytes_per_row() const { return bytes_per_row_; } + + // Raw access to byte contents, for interfacing with other libraries. + // Unsigned char instead of char to avoid surprises (sign extension). + JXL_INLINE uint8_t* bytes() { + void* p = bytes_.get(); + return static_cast(JXL_ASSUME_ALIGNED(p, 64)); + } + JXL_INLINE const uint8_t* bytes() const { + const void* p = bytes_.get(); + return static_cast(JXL_ASSUME_ALIGNED(p, 64)); + } + + protected: + // Returns pointer to the start of a row. + JXL_INLINE void* VoidRow(const size_t y) const { +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + if (y >= ysize_) { + JXL_ABORT("Row(%zu) in (%u x %u) image\n", y, xsize_, ysize_); + } +#endif + + void* row = bytes_.get() + y * bytes_per_row_; + return JXL_ASSUME_ALIGNED(row, 64); + } + + enum class Padding { + // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default. + kRoundUp, + // Allow LoadU(d, row + x) for x = xsize() - 1. This requires an extra + // vector to be initialized. If done by default, this would suppress + // legitimate msan warnings. We therefore require users to explicitly call + // InitializePadding before using unaligned loads (e.g. convolution). + kUnaligned + }; + + // Initializes the minimum bytes required to suppress msan warnings from + // legitimate (according to Padding mode) vector loads/stores on the right + // border, where some lanes are uninitialized and assumed to be unused. + void InitializePadding(size_t sizeof_t, Padding padding); + + // (Members are non-const to enable assignment during move-assignment.) + uint32_t xsize_; // In valid pixels, not including any padding. + uint32_t ysize_; + uint32_t orig_xsize_; + uint32_t orig_ysize_; + size_t bytes_per_row_; // Includes padding. + CacheAlignedUniquePtr bytes_; +}; + +// Single channel, aligned rows separated by padding. T must be POD. +// +// 'Single channel' (one 2D array per channel) simplifies vectorization +// (repeating the same operation on multiple adjacent components) without the +// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients +// can easily iterate over all components in a row and Image requires no +// knowledge of the pixel format beyond the component type "T". +// +// 'Aligned' means each row is aligned to the L1 cache line size. This prevents +// false sharing between two threads operating on adjacent rows. +// +// 'Padding' is still relevant because vectors could potentially be larger than +// a cache line. By rounding up row sizes to the vector size, we allow +// reading/writing ALIGNED vectors whose first lane is a valid sample. This +// avoids needing a separate loop to handle remaining unaligned lanes. +// +// This image layout could also be achieved with a vector and a row accessor +// function, but a class wrapper with support for "deleter" allows wrapping +// existing memory allocated by clients without copying the pixels. It also +// provides convenient accessors for xsize/ysize, which shortens function +// argument lists. Supports move-construction so it can be stored in containers. 
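As a concrete illustration of the aligned-row contract described above, in the style of DotProduct from image.cc (SumPlane is a hypothetical helper):

    #include "lib/jxl/image.h"

    float SumPlane(const jxl::ImageF& img) {
      float sum = 0.0f;
      for (size_t y = 0; y < img.ysize(); ++y) {
        const float* JXL_RESTRICT row = img.ConstRow(y);  // 64-byte aligned
        for (size_t x = 0; x < img.xsize(); ++x) sum += row[x];
      }
      return sum;  // padding lanes past xsize() are never read as pixels
    }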
+template +class Plane : public PlaneBase { + public: + using T = ComponentType; + static constexpr size_t kNumPlanes = 1; + + Plane() = default; + Plane(const size_t xsize, const size_t ysize) + : PlaneBase(xsize, ysize, sizeof(T)) {} + + void InitializePaddingForUnalignedAccesses() { + InitializePadding(sizeof(T), Padding::kUnaligned); + } + + JXL_INLINE T* Row(const size_t y) { return static_cast(VoidRow(y)); } + + // Returns pointer to const (see above). + JXL_INLINE const T* Row(const size_t y) const { + return static_cast(VoidRow(y)); + } + + // Documents that the access is const. + JXL_INLINE const T* ConstRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must + // NOT be used to determine xsize. + JXL_INLINE intptr_t PixelsPerRow() const { + return static_cast(bytes_per_row_ / sizeof(T)); + } +}; + +using ImageSB = Plane; +using ImageB = Plane; +using ImageS = Plane; // signed integer or half-float +using ImageU = Plane; +using ImageI = Plane; +using ImageF = Plane; +using ImageD = Plane; + +// Also works for Image3 and mixed argument types. +template +bool SameSize(const Image1& image1, const Image2& image2) { + return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize(); +} + +template +class Image3; + +// Rectangular region in image(s). Factoring this out of Image instead of +// shifting the pointer by x0/y0 allows this to apply to multiple images with +// different resolutions (e.g. color transform and quantization field). +// Can compare using SameSize(rect1, rect2). +class Rect { + public: + // Most windows are xsize_max * ysize_max, except those on the borders where + // begin + size_max > end. + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max, size_t xend, size_t yend) + : x0_(xbegin), + y0_(ybegin), + xsize_(ClampedSize(xbegin, xsize_max, xend)), + ysize_(ClampedSize(ybegin, ysize_max, yend)) {} + + // Construct with origin and known size (typically from another Rect). + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize, size_t ysize) + : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {} + + // Construct a rect that covers a whole image/plane/ImageBundle etc. + template + explicit Rect(const Image& image) + : Rect(0, 0, image.xsize(), image.ysize()) {} + + Rect() : Rect(0, 0, 0, 0) {} + + Rect(const Rect&) = default; + Rect& operator=(const Rect&) = default; + + // Construct a subrect that resides in an image/plane/ImageBundle etc. + template + Rect Crop(const Image& image) const { + return Rect(x0_, y0_, xsize_, ysize_, image.xsize(), image.ysize()); + } + + // Returns a rect that only contains `num` lines with offset `y` from `y0()`. 
+ Rect Lines(size_t y, size_t num) const { + JXL_DASSERT(y + num <= ysize_); + return Rect(x0_, y0_ + y, xsize_, num); + } + + Rect Line(size_t y) const { return Lines(y, 1); } + + JXL_MUST_USE_RESULT Rect Intersection(const Rect& other) const { + return Rect(std::max(x0_, other.x0_), std::max(y0_, other.y0_), xsize_, + ysize_, std::min(x0_ + xsize_, other.x0_ + other.xsize_), + std::min(y0_ + ysize_, other.y0_ + other.ysize_)); + } + + template + T* Row(Plane* image, size_t y) const { + return image->Row(y + y0_) + x0_; + } + + template + const T* Row(const Plane* image, size_t y) const { + return image->Row(y + y0_) + x0_; + } + + template + T* PlaneRow(Image3* image, const size_t c, size_t y) const { + return image->PlaneRow(c, y + y0_) + x0_; + } + + template + const T* ConstRow(const Plane& image, size_t y) const { + return image.ConstRow(y + y0_) + x0_; + } + + template + const T* ConstPlaneRow(const Image3& image, size_t c, size_t y) const { + return image.ConstPlaneRow(c, y + y0_) + x0_; + } + + // Returns true if this Rect fully resides in the given image. ImageT could be + // Plane or Image3; however if ImageT is Rect, results are nonsensical. + template + bool IsInside(const ImageT& image) const { + return (x0_ + xsize_ <= image.xsize()) && (y0_ + ysize_ <= image.ysize()); + } + + size_t x0() const { return x0_; } + size_t y0() const { return y0_; } + size_t xsize() const { return xsize_; } + size_t ysize() const { return ysize_; } + + private: + // Returns size_max, or whatever is left in [begin, end). + static constexpr size_t ClampedSize(size_t begin, size_t size_max, + size_t end) { + return (begin + size_max <= end) ? size_max + : (end > begin ? end - begin : 0); + } + + size_t x0_; + size_t y0_; + + size_t xsize_; + size_t ysize_; +}; + +// Currently, we abuse Image to either refer to an image that owns its storage +// or one that doesn't. In similar vein, we abuse Image* function parameters to +// either mean "assign to me" or "fill the provided image with data". +// Hopefully, the "assign to me" meaning will go away and most images in the +// codebase will not be backed by own storage. When this happens we can redesign +// Image to be a non-storage-holding view class and introduce BackedImage in +// those places that actually need it. + +// NOTE: we can't use Image as a view because invariants are violated +// (alignment and the presence of padding before/after each "row"). + +// A bundle of 3 same-sized images. Typically constructed by moving from three +// rvalue references to Image. To overwrite an existing Image3 using +// single-channel producers, we also need access to Image*. Constructing +// temporary non-owning Image pointing to one plane of an existing Image3 risks +// dangling references, especially if the wrapper is moved. Therefore, we +// store an array of Image (which are compact enough that size is not a concern) +// and provide Plane+Row accessors. 
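A usage sketch for Rect with the accessors above (VisitWindow is a hypothetical name):

    #include "lib/jxl/image.h"

    void VisitWindow(const jxl::ImageF& img) {
      jxl::Rect window(16, 16, 64, 64);   // origin and requested size
      jxl::Rect safe = window.Crop(img);  // clamped to the image bounds
      for (size_t y = 0; y < safe.ysize(); ++y) {
        const float* row = safe.ConstRow(img, y);  // row y of the window
        (void)row;  // safe.xsize() valid samples start here
      }
    }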
+
+// Currently, we abuse Image to either refer to an image that owns its storage
+// or one that doesn't. In similar vein, we abuse Image* function parameters to
+// either mean "assign to me" or "fill the provided image with data".
+// Hopefully, the "assign to me" meaning will go away and most images in the
+// codebase will not be backed by own storage. When this happens we can
+// redesign Image to be a non-storage-holding view class and introduce
+// BackedImage in those places that actually need it.
+
+// NOTE: we can't use Image as a view because invariants are violated
+// (alignment and the presence of padding before/after each "row").
+
+// A bundle of 3 same-sized images. Typically constructed by moving from three
+// rvalue references to Image. To overwrite an existing Image3 using
+// single-channel producers, we also need access to Image*. Constructing
+// temporary non-owning Image pointing to one plane of an existing Image3 risks
+// dangling references, especially if the wrapper is moved. Therefore, we
+// store an array of Image (which are compact enough that size is not a
+// concern) and provide Plane+Row accessors.
+template <typename ComponentType>
+class Image3 {
+ public:
+  using T = ComponentType;
+  using PlaneT = jxl::Plane<ComponentType>;
+  static constexpr size_t kNumPlanes = 3;
+
+  Image3() : planes_{PlaneT(), PlaneT(), PlaneT()} {}
+
+  Image3(const size_t xsize, const size_t ysize)
+      : planes_{PlaneT(xsize, ysize), PlaneT(xsize, ysize),
+                PlaneT(xsize, ysize)} {}
+
+  Image3(Image3&& other) noexcept {
+    for (size_t i = 0; i < kNumPlanes; i++) {
+      planes_[i] = std::move(other.planes_[i]);
+    }
+  }
+
+  Image3(PlaneT&& plane0, PlaneT&& plane1, PlaneT&& plane2) {
+    JXL_CHECK(SameSize(plane0, plane1));
+    JXL_CHECK(SameSize(plane0, plane2));
+    planes_[0] = std::move(plane0);
+    planes_[1] = std::move(plane1);
+    planes_[2] = std::move(plane2);
+  }
+
+  // Copy construction/assignment is forbidden to avoid inadvertent copies,
+  // which can be very expensive. Use CopyImageTo instead.
+  Image3(const Image3& other) = delete;
+  Image3& operator=(const Image3& other) = delete;
+
+  Image3& operator=(Image3&& other) noexcept {
+    for (size_t i = 0; i < kNumPlanes; i++) {
+      planes_[i] = std::move(other.planes_[i]);
+    }
+    return *this;
+  }
+
+  // Returns row pointer; usage: PlaneRow(idx_plane, y)[x] = val.
+  JXL_INLINE T* PlaneRow(const size_t c, const size_t y) {
+    // Custom implementation instead of calling planes_[c].Row ensures only a
+    // single multiplication is needed for PlaneRow(0..2, y).
+    PlaneRowBoundsCheck(c, y);
+    const size_t row_offset = y * planes_[0].bytes_per_row();
+    void* row = planes_[c].bytes() + row_offset;
+    return static_cast<T*>(JXL_ASSUME_ALIGNED(row, 64));
+  }
+
+  // Returns const row pointer; usage: val = PlaneRow(idx_plane, y)[x].
+  JXL_INLINE const T* PlaneRow(const size_t c, const size_t y) const {
+    PlaneRowBoundsCheck(c, y);
+    const size_t row_offset = y * planes_[0].bytes_per_row();
+    const void* row = planes_[c].bytes() + row_offset;
+    return static_cast<const T*>(JXL_ASSUME_ALIGNED(row, 64));
+  }
+
+  // Returns const row pointer, even if called from a non-const Image3.
+  JXL_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const {
+    PlaneRowBoundsCheck(c, y);
+    return PlaneRow(c, y);
+  }
+
+  JXL_INLINE const PlaneT& Plane(size_t idx) const { return planes_[idx]; }
+
+  JXL_INLINE PlaneT& Plane(size_t idx) { return planes_[idx]; }
+
+  void Swap(Image3& other) {
+    for (size_t c = 0; c < 3; ++c) {
+      other.planes_[c].Swap(planes_[c]);
+    }
+  }
+
+  // Useful for pre-allocating image with some padding for alignment purposes
+  // and later reporting the actual valid dimensions. May also be used to
+  // un-shrink the image. Caller is responsible for ensuring xsize/ysize are <=
+  // the original dimensions.
+  void ShrinkTo(const size_t xsize, const size_t ysize) {
+    for (PlaneT& plane : planes_) {
+      plane.ShrinkTo(xsize, ysize);
+    }
+  }
+
+  // Sizes of all three images are guaranteed to be equal.
+  JXL_INLINE size_t xsize() const { return planes_[0].xsize(); }
+  JXL_INLINE size_t ysize() const { return planes_[0].ysize(); }
+  // Returns offset [bytes] from one row to the next row of the same plane.
+  // WARNING: this must NOT be used to determine xsize, nor for copying rows -
+  // the valid xsize may be much less.
+  JXL_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); }
+  // Returns number of pixels (some of which are padding) per row. Useful for
+  // computing other rows via pointer arithmetic. WARNING: this must NOT be
+  // used to determine xsize.
+  JXL_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); }
+
+ private:
+  void PlaneRowBoundsCheck(const size_t c, const size_t y) const {
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+    if (c >= kNumPlanes || y >= ysize()) {
+      JXL_ABORT("PlaneRow(%zu, %zu) in (%zu x %zu) image\n", c, y, xsize(),
+                ysize());
+    }
+#endif
+  }
+
+ private:
+  PlaneT planes_[kNumPlanes];
+};
+
+using Image3B = Image3<uint8_t>;
+using Image3S = Image3<int16_t>;
+using Image3U = Image3<uint16_t>;
+using Image3I = Image3<int32_t>;
+using Image3F = Image3<float>;
+using Image3D = Image3<double>;
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_IMAGE_H_
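Image3 stores three separately allocated planes rather than interleaved pixels, and PlaneRow hands out a 64-byte-aligned row pointer per plane. A minimal usage sketch (illustrative only; MakeGradient is a hypothetical helper, not from the patch):

    #include "lib/jxl/image.h"

    jxl::Image3F MakeGradient(size_t xsize, size_t ysize) {
      jxl::Image3F img(xsize, ysize);
      for (size_t c = 0; c < jxl::Image3F::kNumPlanes; ++c) {
        for (size_t y = 0; y < ysize; ++y) {
          float* JXL_RESTRICT row = img.PlaneRow(c, y);
          for (size_t x = 0; x < xsize; ++x) {
            row[x] = static_cast<float>(x) / xsize;  // same ramp per channel
          }
        }
      }
      return img;  // move-only: returned via move, never copied
    }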
diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle.cc b/third_party/jpeg-xl/lib/jxl/image_bundle.cc
new file mode 100644
index 000000000000..dee1b46f0bcd
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_bundle.cc
@@ -0,0 +1,134 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/image_bundle.h"
+
+#include
+#include
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/luminance.h"
+
+namespace jxl {
+
+void ImageBundle::ShrinkTo(size_t xsize, size_t ysize) {
+  if (HasColor()) color_.ShrinkTo(xsize, ysize);
+  for (size_t i = 0; i < extra_channels_.size(); ++i) {
+    const auto& eci = metadata_->extra_channel_info[i];
+    extra_channels_[i].ShrinkTo(eci.Size(xsize), eci.Size(ysize));
+  }
+}
+
+// Called by all other SetFrom*.
+void ImageBundle::SetFromImage(Image3F&& color,
+                               const ColorEncoding& c_current) {
+  JXL_CHECK(color.xsize() != 0 && color.ysize() != 0);
+  JXL_CHECK(metadata_->color_encoding.IsGray() == c_current.IsGray());
+  color_ = std::move(color);
+  c_current_ = c_current;
+  VerifySizes();
+}
+
+void ImageBundle::VerifyMetadata() const {
+  JXL_CHECK(!c_current_.ICC().empty());
+  JXL_CHECK(metadata_->color_encoding.IsGray() == IsGray());
+
+  if (metadata_->HasAlpha() && alpha().xsize() == 0) {
+    JXL_ABORT("MD alpha_bits %u IB alpha %zu x %zu\n",
+              metadata_->GetAlphaBits(), alpha().xsize(), alpha().ysize());
+  }
+  const uint32_t alpha_bits = metadata_->GetAlphaBits();
+  JXL_CHECK(alpha_bits <= 32);
+
+  // metadata_->num_extra_channels may temporarily differ from
+  // extra_channels_.size(), e.g. after SetAlpha. They are synced by the next
+  // call to VisitFields.
+}
+
+void ImageBundle::VerifySizes() const {
+  const size_t xs = xsize();
+  const size_t ys = ysize();
+
+  if (HasExtraChannels()) {
+    JXL_CHECK(xs != 0 && ys != 0);
+    for (size_t ec = 0; ec < metadata_->extra_channel_info.size(); ++ec) {
+      const ExtraChannelInfo& eci = metadata_->extra_channel_info[ec];
+      JXL_CHECK(extra_channels_[ec].xsize() == eci.Size(xs));
+      JXL_CHECK(extra_channels_[ec].ysize() == eci.Size(ys));
+    }
+  }
+}
+
+size_t ImageBundle::DetectRealBitdepth() const {
+  return metadata_->bit_depth.bits_per_sample;
+
+  // TODO(lode): let this function return lower bit depth if possible, e.g.
+  // return 8 bits in case the original image came from a 16-bit PNG that
+  // was in fact representable as 8-bit PNG. Ensure that the implementation
+  // returns 16 if e.g. two consecutive 16-bit values appeared in the original
+  // image (such as 32768 and 32769), take into account that e.g. the values
+  // 3-bit can represent is not a superset of the values 2-bit can represent,
+  // and there may be slight imprecisions in the floating point image.
+}
+
+const ImageF& ImageBundle::alpha() const {
+  JXL_ASSERT(HasAlpha());
+  const size_t ec = metadata_->Find(ExtraChannel::kAlpha) -
+                    metadata_->extra_channel_info.data();
+  JXL_ASSERT(ec < extra_channels_.size());
+  return extra_channels_[ec];
+}
+ImageF* ImageBundle::alpha() {
+  JXL_ASSERT(HasAlpha());
+  const size_t ec = metadata_->Find(ExtraChannel::kAlpha) -
+                    metadata_->extra_channel_info.data();
+  JXL_ASSERT(ec < extra_channels_.size());
+  return &extra_channels_[ec];
+}
+
+const ImageF& ImageBundle::depth() const {
+  JXL_ASSERT(HasDepth());
+  const size_t ec = metadata_->Find(ExtraChannel::kDepth) -
+                    metadata_->extra_channel_info.data();
+  JXL_ASSERT(ec < extra_channels_.size());
+  return extra_channels_[ec];
+}
+
+void ImageBundle::SetAlpha(ImageF&& alpha, bool alpha_is_premultiplied) {
+  const ExtraChannelInfo* eci = metadata_->Find(ExtraChannel::kAlpha);
+  // Must call SetAlphaBits first, otherwise we don't know which channel index
+  JXL_CHECK(eci != nullptr);
+  JXL_CHECK(alpha.xsize() != 0 && alpha.ysize() != 0);
+  JXL_CHECK(eci->alpha_associated == alpha_is_premultiplied);
+  extra_channels_.insert(
+      extra_channels_.begin() + (eci - metadata_->extra_channel_info.data()),
+      std::move(alpha));
+  // num_extra_channels is automatically set in visitor
+  VerifySizes();
+}
+
+void ImageBundle::SetExtraChannels(std::vector<ImageF>&& extra_channels) {
+  for (const ImageF& plane : extra_channels) {
+    JXL_CHECK(plane.xsize() != 0 && plane.ysize() != 0);
+  }
+  extra_channels_ = std::move(extra_channels);
+  VerifySizes();
+}
+}  // namespace jxl
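Note how alpha() above recovers the channel index: Find() returns a pointer into extra_channel_info, and extra_channels_ is kept parallel to that vector, so the pointer offset doubles as the index into both containers. A compact restatement (illustrative only):

    #include "lib/jxl/image_metadata.h"

    size_t AlphaChannelIndex(const jxl::ImageMetadata& md) {
      const jxl::ExtraChannelInfo* eci = md.Find(jxl::ExtraChannel::kAlpha);
      JXL_ASSERT(eci != nullptr);  // caller must ensure an alpha entry exists
      return static_cast<size_t>(eci - md.extra_channel_info.data());
    }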
diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle.h b/third_party/jpeg-xl/lib/jxl/image_bundle.h
new file mode 100644
index 000000000000..66d8a8d4147b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_bundle.h
@@ -0,0 +1,261 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_IMAGE_BUNDLE_H_
+#define LIB_JXL_IMAGE_BUNDLE_H_
+
+// The main image or frame consists of a bundle of associated images.
+
+#include
+#include
+
+#include
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/dec_xyb.h"
+#include "lib/jxl/enc_bit_writer.h"
+#include "lib/jxl/field_encodings.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_metadata.h"
+#include "lib/jxl/jpeg/jpeg_data.h"
+#include "lib/jxl/opsin_params.h"
+#include "lib/jxl/quantizer.h"
+
+namespace jxl {
+
+// A bundle of color/alpha/depth/plane images.
+class ImageBundle {
+ public:
+  // Uninitialized state for use as output parameter.
+  ImageBundle() : metadata_(nullptr) {}
+  // Caller is responsible for setting metadata before calling Set*.
+  explicit ImageBundle(const ImageMetadata* metadata) : metadata_(metadata) {}
+
+  // Move-only (allows storing in std::vector).
+  ImageBundle(ImageBundle&&) = default;
+  ImageBundle& operator=(ImageBundle&&) = default;
+
+  ImageBundle Copy() const {
+    ImageBundle copy(metadata_);
+    copy.color_ = CopyImage(color_);
+    copy.c_current_ = c_current_;
+    copy.extra_channels_.reserve(extra_channels_.size());
+    for (const ImageF& plane : extra_channels_) {
+      copy.extra_channels_.emplace_back(CopyImage(plane));
+    }
+
+    copy.jpeg_data =
+        jpeg_data ? make_unique<jpeg::JPEGData>(*jpeg_data) : nullptr;
+    copy.color_transform = color_transform;
+    copy.chroma_subsampling = chroma_subsampling;
+
+    return copy;
+  }
+
+  // -- SIZE
+
+  size_t xsize() const {
+    if (IsJPEG()) return jpeg_data->width;
+    if (color_.xsize() != 0) return color_.xsize();
+    return extra_channels_.empty() ? 0 : extra_channels_[0].xsize();
+  }
+  size_t ysize() const {
+    if (IsJPEG()) return jpeg_data->height;
+    if (color_.ysize() != 0) return color_.ysize();
+    return extra_channels_.empty() ? 0 : extra_channels_[0].ysize();
+  }
+  void ShrinkTo(size_t xsize, size_t ysize);
+
+  // sizes taking orientation into account
+  size_t oriented_xsize() const {
+    if (static_cast<uint32_t>(metadata_->GetOrientation()) > 4) {
+      return ysize();
+    } else {
+      return xsize();
+    }
+  }
+  size_t oriented_ysize() const {
+    if (static_cast<uint32_t>(metadata_->GetOrientation()) > 4) {
+      return xsize();
+    } else {
+      return ysize();
+    }
+  }
+
+  // -- COLOR
+
+  // Whether color() is valid/usable. Returns true in most cases. Even images
+  // with spot colors (one example of when !planes().empty()) typically have a
+  // part that can be converted to RGB.
+  bool HasColor() const { return color_.xsize() != 0; }
+
+  // For resetting the size when switching from a reference to main frame.
+  void RemoveColor() { color_ = Image3F(); }
+
+  // Do not use if !HasColor().
+  const Image3F& color() const {
+    // If this fails, Set* was not called - perhaps because decoding failed?
+    JXL_DASSERT(HasColor());
+    return color_;
+  }
+
+  // Do not use if !HasColor().
+  Image3F* color() {
+    JXL_DASSERT(HasColor());
+    return &color_;
+  }
+
+  // If c_current.IsGray(), all planes must be identical. NOTE: c_current is
+  // independent of metadata()->color_encoding, which is the original, whereas
+  // a decoder might return pixels in a different c_current.
+  // This only sets the color channels; you must also make the extra channels
+  // match the amount that is in the metadata.
+  void SetFromImage(Image3F&& color, const ColorEncoding& c_current);
+
+  // -- COLOR ENCODING
+
+  const ColorEncoding& c_current() const { return c_current_; }
+
+  // Returns whether the color image has identical planes. Once established by
+  // Set*, remains unchanged until a subsequent Set* or TransformTo.
+  bool IsGray() const { return c_current_.IsGray(); }
+
+  bool IsSRGB() const { return c_current_.IsSRGB(); }
+  bool IsLinearSRGB() const {
+    return c_current_.white_point == WhitePoint::kD65 &&
+           c_current_.primaries == Primaries::kSRGB && c_current_.tf.IsLinear();
+  }
+
+  // Set the c_current profile without doing any transformation, e.g. if the
+  // transformation was already applied.
+  void OverrideProfile(const ColorEncoding& new_c_current) {
+    c_current_ = new_c_current;
+  }
+
+  // TODO(lode): TransformTo and CopyTo are implemented in enc_image_bundle.cc,
+  // move these functions out of this header file and class, to
+  // enc_image_bundle.h.
+
+  // Transforms color to c_desired and sets c_current to c_desired. Alpha and
+  // metadata remains unchanged.
+  Status TransformTo(const ColorEncoding& c_desired,
+                     ThreadPool* pool = nullptr);
+  // Copies this:rect, converts to c_desired, and allocates+fills out.
+  Status CopyTo(const Rect& rect, const ColorEncoding& c_desired, Image3B* out,
+                ThreadPool* pool = nullptr) const;
+  Status CopyTo(const Rect& rect, const ColorEncoding& c_desired, Image3F* out,
+                ThreadPool* pool = nullptr) const;
+  Status CopyToSRGB(const Rect& rect, Image3B* out,
+                    ThreadPool* pool = nullptr) const;
+
+  // Detect 'real' bit depth, which can be lower than nominal bit depth
+  // (this is common in PNG), returns 'real' bit depth
+  size_t DetectRealBitdepth() const;
+
+  // -- ALPHA
+
+  void SetAlpha(ImageF&& alpha, bool alpha_is_premultiplied);
+  bool HasAlpha() const {
+    return metadata_->Find(ExtraChannel::kAlpha) != nullptr;
+  }
+  bool AlphaIsPremultiplied() const {
+    const ExtraChannelInfo* eci = metadata_->Find(ExtraChannel::kAlpha);
+    return (eci == nullptr) ? false : eci->alpha_associated;
+  }
+  const ImageF& alpha() const;
+  ImageF* alpha();
+
+  // -- DEPTH
+  bool HasDepth() const {
+    return metadata_->Find(ExtraChannel::kDepth) != nullptr;
+  }
+  const ImageF& depth() const;
+
+  // -- EXTRA CHANNELS
+
+  // Extra channels of unknown interpretation (e.g. spot colors).
+  void SetExtraChannels(std::vector<ImageF>&& extra_channels);
+  void ClearExtraChannels() { extra_channels_.clear(); }
+  bool HasExtraChannels() const { return !extra_channels_.empty(); }
+  const std::vector<ImageF>& extra_channels() const { return extra_channels_; }
+  std::vector<ImageF>& extra_channels() { return extra_channels_; }
+
+  const ImageMetadata* metadata() const { return metadata_; }
+
+  void VerifyMetadata() const;
+
+  void SetDecodedBytes(size_t decoded_bytes) { decoded_bytes_ = decoded_bytes; }
+  size_t decoded_bytes() const { return decoded_bytes_; }
+
+  // -- JPEG transcoding:
+
+  // Returns true if image does or will represent quantized DCT-8 coefficients,
+  // stored in 8x8 pixel regions.
+  bool IsJPEG() const { return jpeg_data != nullptr; }
+
+  std::unique_ptr<jpeg::JPEGData> jpeg_data;
+  // these fields are used to signal the input JPEG color space
+  // NOTE: JPEG doesn't actually provide a way to determine whether YCbCr was
+  // applied or not.
+  ColorTransform color_transform = ColorTransform::kNone;
+  YCbCrChromaSubsampling chroma_subsampling;
+
+  FrameOrigin origin{0, 0};
+  // Animation-related information. This assumes GIF- and APNG-like animation.
+  uint32_t duration = 0;
+  bool use_for_next_frame = false;
+  bool blend = false;
+  std::string name;
+
+ private:
+  // Called after any Set* to ensure their sizes are compatible.
+  void VerifySizes() const;
+
+  // Required for TransformTo so that an ImageBundle is self-sufficient. Always
+  // points to the same thing, but cannot be const-pointer because that
+  // prevents the compiler from generating a move ctor.
+  const ImageMetadata* metadata_;
+
+  // Initialized by Set*:
+  Image3F color_;  // If empty, planes_ is not; all planes equal if IsGray().
+  ColorEncoding c_current_;  // of color_
+
+  // Initialized by SetPlanes; size = ImageMetadata.num_extra_channels
+  std::vector<ImageF> extra_channels_;
+
+  // How many bytes of the input were actually read.
+  size_t decoded_bytes_ = 0;
+};
+
+// Does color transformation from in.c_current() to c_desired if the color
+// encodings are different, or nothing if they are already the same.
+// If color transformation is done, stores the transformed values into store
+// and sets the out pointer to store, else leaves store untouched and sets the
+// out pointer to &in.
+// Returns false if color transform fails.
+Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired,
+                         ThreadPool* pool, ImageBundle* store,
+                         const ImageBundle** out);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_IMAGE_BUNDLE_H_
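TransformIfNeeded (declared above) avoids a copy when the bundle is already in the requested color space: out then simply aliases in, and store stays untouched. A sketch of the calling convention (illustrative only; assumes the ColorEncoding::LinearSRGB helper from color_encoding_internal.h):

    #include "lib/jxl/image_bundle.h"

    jxl::Status GetLinearSRGB(const jxl::ImageBundle& in, jxl::ThreadPool* pool,
                              jxl::ImageBundle* store,
                              const jxl::ImageBundle** out) {
      const jxl::ColorEncoding c = jxl::ColorEncoding::LinearSRGB(in.IsGray());
      return jxl::TransformIfNeeded(in, c, pool, store, out);
    }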
diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle_test.cc b/third_party/jpeg-xl/lib/jxl/image_bundle_test.cc
new file mode 100644
index 000000000000..52bab8b511f7
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_bundle_test.cc
@@ -0,0 +1,45 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/image_bundle.h"
+
+#include "gtest/gtest.h"
+#include "lib/jxl/aux_out.h"
+
+namespace jxl {
+namespace {
+
+TEST(ImageBundleTest, ExtraChannelName) {
+  AuxOut aux_out;
+  BitWriter writer;
+  BitWriter::Allotment allotment(&writer, 99);
+
+  ImageMetadata metadata;
+  ExtraChannelInfo eci;
+  eci.type = ExtraChannel::kBlack;
+  eci.name = "testK";
+  metadata.extra_channel_info.push_back(std::move(eci));
+  ASSERT_TRUE(WriteImageMetadata(metadata, &writer, /*layer=*/0, &aux_out));
+  writer.ZeroPadToByte();
+  ReclaimAndCharge(&writer, &allotment, /*layer=*/0, &aux_out);
+
+  BitReader reader(writer.GetSpan());
+  ImageMetadata metadata_out;
+  ASSERT_TRUE(ReadImageMetadata(&reader, &metadata_out));
+  EXPECT_TRUE(reader.Close());
+  EXPECT_EQ("testK", metadata_out.Find(ExtraChannel::kBlack)->name);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/image_metadata.cc b/third_party/jpeg-xl/lib/jxl/image_metadata.cc
new file mode 100644
index 000000000000..231a2f7a9377
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_metadata.cc
@@ -0,0 +1,425 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/image_metadata.h"
+
+#include
+#include
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/base/byte_order.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_management.h"
+#include "lib/jxl/fields.h"
+
+namespace jxl {
+BitDepth::BitDepth() { Bundle::Init(this); }
+Status BitDepth::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &floating_point_sample));
+  // The same fields (bits_per_sample and exponent_bits_per_sample) are read
+  // in a different way depending on floating_point_sample's value. It's still
+  // default-initialized correctly so using visitor->Conditional is not
+  // required.
+  if (!floating_point_sample) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->U32(
+        Val(8), Val(10), Val(12), BitsOffset(6, 1), 8, &bits_per_sample));
+    exponent_bits_per_sample = 0;
+  } else {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->U32(
+        Val(32), Val(16), Val(24), BitsOffset(6, 1), 32, &bits_per_sample));
+    // The encoded value is exponent_bits_per_sample - 1, encoded in 4 bits
+    // so the value can be in range [1, 16].
+    const uint32_t offset = 1;
+    exponent_bits_per_sample -= offset;
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->Bits(4, 8 - offset, &exponent_bits_per_sample));
+    exponent_bits_per_sample += offset;
+  }
+
+  // Error-checking for floating point ranges.
+  if (floating_point_sample) {
+    if (exponent_bits_per_sample < 2 || exponent_bits_per_sample > 8) {
+      return JXL_FAILURE("Invalid exponent_bits_per_sample: %u",
+                         exponent_bits_per_sample);
+    }
+    int mantissa_bits =
+        static_cast<int>(bits_per_sample) - exponent_bits_per_sample - 1;
+    if (mantissa_bits < 2 || mantissa_bits > 23) {
+      return JXL_FAILURE("Invalid bits_per_sample: %u", bits_per_sample);
+    }
+  } else {
+    if (bits_per_sample > 31) {
+      return JXL_FAILURE("Invalid bits_per_sample: %u", bits_per_sample);
+    }
+  }
+  return true;
+}
+
+CustomTransformData::CustomTransformData() { Bundle::Init(this); }
+Status CustomTransformData::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->AllDefault(*this, &all_default)) {
+    // Overwrite all serialized fields, but not any nonserialized_*.
+    visitor->SetDefault(this);
+    return true;
+  }
+  if (visitor->Conditional(nonserialized_xyb_encoded)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&opsin_inverse_matrix));
+  }
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &custom_weights_mask));
+  if (visitor->Conditional((custom_weights_mask & 0x1) != 0)) {
+    // 4 5x5 kernels, but all of them can be obtained by symmetry from one,
+    // which is symmetric along its main diagonal. The top-left kernel is
+    // defined by
+    //
+    //  0  1  2  3  4
+    //  1  5  6  7  8
+    //  2  6  9 10 11
+    //  3  7 10 12 13
+    //  4  8 11 13 14
+    float constexpr kWeights2[15] = {
+        -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f,
+        0.14111091f,  0.28896755f,  0.00278718f,  -0.01610267f, 0.56661550f,
+        0.03777607f,  -0.01986694f, -0.03144731f, -0.01185068f, -0.00213539f};
+    for (size_t i = 0; i < 15; i++) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->F16(kWeights2[i], &upsampling2_weights[i]));
+    }
+  }
+  if (visitor->Conditional((custom_weights_mask & 0x2) != 0)) {
+    // 16 5x5 kernels, but all of them can be obtained by symmetry from
+    // three, two of which are symmetric along their main diagonals. The top
+    // left 4 kernels are defined by
+    //
+    //  0  1  2  3  4  5  6  7  8  9
+    //  1 10 11 12 13 14 15 16 17 18
+    //  2 11 19 20 21 22 23 24 25 26
+    //  3 12 20 27 28 29 30 31 32 33
+    //  4 13 21 28 34 35 36 37 38 39
+
+    //  5 14 22 29 35 40 41 42 43 44
+    //  6 15 23 30 36 41 45 46 47 48
+    //  7 16 24 31 37 42 46 49 50 51
+    //  8 17 25 32 38 43 47 50 52 53
+    //  9 18 26 33 39 44 48 51 53 54
+    constexpr float kWeights4[55] = {
+        -0.02419067f, -0.03491987f, -0.03693351f, -0.03094285f, -0.00529785f,
+        -0.01663432f, -0.03556863f, -0.03888905f, -0.03516850f, -0.00989469f,
+        0.23651958f,  0.33392945f,  -0.01073543f, -0.01313181f, -0.03556694f,
+        0.13048175f,  0.40103025f,  0.03951150f,  -0.02077584f, 0.46914198f,
+        -0.00209270f, -0.01484589f, -0.04064806f, 0.18942530f,  0.56279892f,
+        0.06674400f,  -0.02335494f, -0.03551682f, -0.00754830f, -0.02267919f,
+        -0.02363578f, 0.00315804f,  -0.03399098f, -0.01359519f, -0.00091653f,
+        -0.00335467f, -0.01163294f, -0.01610294f, -0.00974088f, -0.00191622f,
+        -0.01095446f, -0.03198464f, -0.04455121f, -0.02799790f, -0.00645912f,
+        0.06390599f,  0.22963888f,  0.00630981f,  -0.01897349f, 0.67537268f,
+        0.08483369f,  -0.02534994f, -0.02205197f, -0.01667999f, -0.00384443f};
+    for (size_t i = 0; i < 55; i++) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->F16(kWeights4[i], &upsampling4_weights[i]));
+    }
+  }
+  if (visitor->Conditional((custom_weights_mask & 0x4) != 0)) {
+    // 64 5x5 kernels, all of them can be obtained by symmetry from
+    // 10, 4 of which are symmetric along their main diagonals. The top
+    // left 16 kernels are defined by
+    //   0  1  2  3  4  5  6  7  8  9  a  b  c  d  e  f 10 11 12 13
+    //   1 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 21 22 23 24 25 26
+    //   2 15 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38
+    //   3 16 28 39 3a 3b 3c 3d 3e 3f 40 41 42 43 44 45 46 47 48 49
+    //   4 17 29 3a 4a 4b 4c 4d 4e 4f 50 51 52 53 54 55 56 57 58 59
+
+    //   5 18 2a 3b 4b 5a 5b 5c 5d 5e 5f 60 61 62 63 64 65 66 67 68
+    //   6 19 2b 3c 4c 5b 69 6a 6b 6c 6d 6e 6f 70 71 72 73 74 75 76
+    //   7 1a 2c 3d 4d 5c 6a 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83
+    //   8 1b 2d 3e 4e 5d 6b 78 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f
+    //   9 1c 2e 3f 4f 5e 6c 79 85 90 91 92 93 94 95 96 97 98 99 9a
+
+    //   a 1d 2f 40 50 5f 6d 7a 86 91 9b 9c 9d 9e 9f a0 a1 a2 a3 a4
+    //   b 1e 30 41 51 60 6e 7b 87 92 9c a5 a6 a7 a8 a9 aa ab ac ad
+    //   c 1f 31 42 52 61 6f 7c 88 93 9d a6 ae af b0 b1 b2 b3 b4 b5
+    //   d 20 32 43 53 62 70 7d 89 94 9e a7 af b6 b7 b8 b9 ba bb bc
+    //   e 21 33 44 54 63 71 7e 8a 95 9f a8 b0 b7 bd be bf c0 c1 c2
+
+    //   f 22 34 45 55 64 72 7f 8b 96 a0 a9 b1 b8 be c3 c4 c5 c6 c7
+    //  10 23 35 46 56 65 73 80 8c 97 a1 aa b2 b9 bf c4 c8 c9 ca cb
+    //  11 24 36 47 57 66 74 81 8d 98 a2 ab b3 ba c0 c5 c9 cc cd ce
+    //  12 25 37 48 58 67 75 82 8e 99 a3 ac b4 bb c1 c6 ca cd cf d0
+    //  13 26 38 49 59 68 76 83 8f 9a a4 ad b5 bc c2 c7 cb ce d0 d1
+    constexpr float kWeights8[210] = {
+        -0.02928613f, -0.03706353f, -0.03783812f, -0.03324558f, -0.00447632f,
+        -0.02519406f, -0.03752601f, -0.03901508f, -0.03663285f, -0.00646649f,
+        -0.02066407f, -0.03838633f, -0.04002101f, -0.03900035f, -0.00901973f,
+        -0.01626393f, -0.03954148f, -0.04046620f, -0.03979621f, -0.01224485f,
+        0.29895328f,  0.35757708f,  -0.02447552f, -0.01081748f, -0.04314594f,
+        0.23903219f,  0.41119301f,  -0.00573046f, -0.01450239f, -0.04246845f,
+        0.17567618f,  0.45220643f,  0.02287757f,  -0.01936783f, -0.03583255f,
+        0.11572472f,  0.47416733f,  0.06284440f,  -0.02685066f, 0.42720050f,
+        -0.02248939f, -0.01155273f, -0.04562755f, 0.28689496f,  0.49093869f,
+        -0.00007891f, -0.01545926f, -0.04562659f, 0.21238920f,  0.53980934f,
+        0.03369474f,  -0.02070211f, -0.03866988f, 0.14229550f,  0.56593398f,
+        0.08045181f,  -0.02888298f, -0.03680918f, -0.00542229f, -0.02920477f,
+        -0.02788574f, -0.02118180f, -0.03942402f, -0.00775547f, -0.02433614f,
+        -0.03193943f, -0.02030828f, -0.04044014f, -0.01074016f, -0.01930822f,
+        -0.03620399f, -0.01974125f, -0.03919545f, -0.01456093f, -0.00045072f,
+        -0.00360110f, -0.01020207f, -0.01231907f, -0.00638988f, -0.00071592f,
+        -0.00279122f, -0.00957115f, -0.01288327f, -0.00730937f, -0.00107783f,
+        -0.00210156f, -0.00890705f, -0.01317668f, -0.00813895f, -0.00153491f,
+        -0.02128481f, -0.04173044f, -0.04831487f, -0.03293190f, -0.00525260f,
+        -0.01720322f, -0.04052736f, -0.05045706f, -0.03607317f, -0.00738030f,
+        -0.01341764f, -0.03965629f, -0.05151616f, -0.03814886f, -0.01005819f,
+        0.18968273f,  0.33063684f,  -0.01300105f, -0.01372950f, -0.04017465f,
+        0.13727832f,  0.36402234f,  0.01027890f,  -0.01832107f, -0.03365072f,
+        0.08734506f,  0.38194295f,  0.04338228f,  -0.02525993f, 0.56408126f,
+        0.00458352f,  -0.01648227f, -0.04887868f, 0.24585519f,  0.62026135f,
+        0.04314807f,  -0.02213737f, -0.04158014f, 0.16637289f,  0.65027023f,
+        0.09621636f,  -0.03101388f, -0.04082742f, -0.00904519f, -0.02790922f,
+        -0.02117818f, 0.00798662f,  -0.03995711f, -0.01243427f, -0.02231705f,
+        -0.02946266f, 0.00992055f,  -0.03600283f, -0.01684920f, -0.00111684f,
+        -0.00411204f, -0.01297130f, -0.01723725f, -0.01022545f, -0.00165306f,
+        -0.00313110f, -0.01218016f, -0.01763266f, -0.01125620f, -0.00231663f,
+        -0.01374149f, -0.03797620f, -0.05142937f, -0.03117307f, -0.00581914f,
+        -0.01064003f, -0.03608089f, -0.05272168f, -0.03375670f, -0.00795586f,
+        0.09628104f,  0.27129991f,  -0.00353779f, -0.01734151f, -0.03153981f,
+        0.05686230f,  0.28500998f,  0.02230594f,  -0.02374955f, 0.68214326f,
+        0.05018048f,  -0.02320852f, -0.04383616f, 0.18459474f,  0.71517975f,
+        0.10805613f,  -0.03263677f, -0.03637639f, -0.01394373f, -0.02511203f,
+        -0.01728636f, 0.05407331f,  -0.02867568f, -0.01893131f, -0.00240854f,
+        -0.00446511f, -0.01636187f, -0.02377053f, -0.01522848f, -0.00333334f,
+        -0.00819975f, -0.02964169f, -0.04499287f, -0.02745350f, -0.00612408f,
+        0.02727416f,  0.19446600f,  0.00159832f,  -0.02232473f, 0.74982506f,
+        0.11452620f,  -0.03348048f, -0.01605681f, -0.02070339f, -0.00458223f};
+    for (size_t i = 0; i < 210; i++) {
+      JXL_QUIET_RETURN_IF_ERROR(
+          visitor->F16(kWeights8[i], &upsampling8_weights[i]));
+    }
+  }
+  return true;
+}
+
+ExtraChannelInfo::ExtraChannelInfo() { Bundle::Init(this); }
+Status ExtraChannelInfo::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->AllDefault(*this, &all_default)) {
+    // Overwrite all serialized fields, but not any nonserialized_*.
+    visitor->SetDefault(this);
+    return true;
+  }
+
+  // General
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(ExtraChannel::kAlpha, &type));
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&bit_depth));
+
+  JXL_QUIET_RETURN_IF_ERROR(
+      visitor->U32(Val(0), Val(3), Val(4), BitsOffset(3, 1), 0, &dim_shift));
+  if ((1U << dim_shift) > kGroupDim) {
+    return JXL_FAILURE("dim_shift %u too large", dim_shift);
+  }
+  if (dim_shift != 0) {
+    return JXL_FAILURE("Non-zero dim_shift not yet implemented");
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(VisitNameString(visitor, &name));
+
+  // Conditional
+  if (visitor->Conditional(type == ExtraChannel::kAlpha)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &alpha_associated));
+  }
+  if (visitor->Conditional(type == ExtraChannel::kSpotColor)) {
+    for (float& c : spot_color) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0, &c));
+    }
+  }
+  if (visitor->Conditional(type == ExtraChannel::kCFA)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(1), Bits(2), BitsOffset(4, 3),
+                                           BitsOffset(8, 19), 1, &cfa_channel));
+  }
+  return true;
+}
+
+ImageMetadata::ImageMetadata() { Bundle::Init(this); }
+Status ImageMetadata::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->AllDefault(*this, &all_default)) {
+    // Overwrite all serialized fields, but not any nonserialized_*.
+    visitor->SetDefault(this);
+    return true;
+  }
+
+  // Bundle::AllDefault does not allow usage when reading (it may abort the
+  // program when a codestream has invalid values), but when reading we
+  // overwrite the extra_fields value, so do not need to call AllDefault.
+  bool tone_mapping_default =
+      visitor->IsReading() ? false : Bundle::AllDefault(tone_mapping);
+
+  bool extra_fields = (orientation != 1 || have_preview || have_animation ||
+                       have_intrinsic_size || !tone_mapping_default);
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &extra_fields));
+  if (visitor->Conditional(extra_fields)) {
+    orientation--;
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &orientation));
+    orientation++;
+    // (No need for bounds checking because we read exactly 3 bits)
+
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_intrinsic_size));
+    if (visitor->Conditional(have_intrinsic_size)) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&intrinsic_size));
+    }
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_preview));
+    if (visitor->Conditional(have_preview)) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&preview_size));
+    }
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_animation));
+    if (visitor->Conditional(have_animation)) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&animation));
+    }
+  } else {
+    orientation = 1;  // identity
+    have_intrinsic_size = false;
+    have_preview = false;
+    have_animation = false;
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&bit_depth));
+  JXL_QUIET_RETURN_IF_ERROR(
+      visitor->Bool(true, &modular_16_bit_buffer_sufficient));
+
+  num_extra_channels = extra_channel_info.size();
+  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2),
+                                         BitsOffset(12, 1), 0,
+                                         &num_extra_channels));
+
+  if (visitor->Conditional(num_extra_channels != 0)) {
+    if (visitor->IsReading()) {
+      extra_channel_info.resize(num_extra_channels);
+    }
+    for (ExtraChannelInfo& eci : extra_channel_info) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&eci));
+    }
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &xyb_encoded));
+  JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&color_encoding));
+  if (visitor->Conditional(extra_fields)) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&tone_mapping));
+  }
+
+  // Treat as if only the fields up to extra channels exist.
+  if (visitor->IsReading() && nonserialized_only_parse_basic_info) {
+    return true;
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions));
+  // Extensions: in chronological order of being added to the format.
+  return visitor->EndExtensions();
+}
+
+OpsinInverseMatrix::OpsinInverseMatrix() { Bundle::Init(this); }
+Status OpsinInverseMatrix::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->AllDefault(*this, &all_default)) {
+    // Overwrite all serialized fields, but not any nonserialized_*.
+    visitor->SetDefault(this);
+    return true;
+  }
+  for (int i = 0; i < 9; ++i) {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->F16(
+        DefaultInverseOpsinAbsorbanceMatrix()[i], &inverse_matrix[i]));
+  }
+  for (int i = 0; i < 3; ++i) {
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->F16(kNegOpsinAbsorbanceBiasRGB[i], &opsin_biases[i]));
+  }
+  for (int i = 0; i < 4; ++i) {
+    JXL_QUIET_RETURN_IF_ERROR(
+        visitor->F16(kDefaultQuantBias[i], &quant_biases[i]));
+  }
+  return true;
+}
+
+ToneMapping::ToneMapping() { Bundle::Init(this); }
+Status ToneMapping::VisitFields(Visitor* JXL_RESTRICT visitor) {
+  if (visitor->AllDefault(*this, &all_default)) {
+    // Overwrite all serialized fields, but not any nonserialized_*.
+    visitor->SetDefault(this);
+    return true;
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(
+      visitor->F16(kDefaultIntensityTarget, &intensity_target));
+  if (intensity_target <= 0.f) {
+    return JXL_FAILURE("invalid intensity target");
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.0f, &min_nits));
+  if (min_nits < 0.f || min_nits > intensity_target) {
+    return JXL_FAILURE("invalid min %f vs max %f", min_nits, intensity_target);
+  }
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &relative_to_max_display));
+
+  JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.0f, &linear_below));
+  if (linear_below < 0 || (relative_to_max_display && linear_below > 1.0f)) {
+    return JXL_FAILURE("invalid linear_below %f (%s)", linear_below,
+                       relative_to_max_display ? "relative" : "absolute");
+  }
+
+  return true;
+}
+
+Status ReadImageMetadata(BitReader* JXL_RESTRICT reader,
+                         ImageMetadata* JXL_RESTRICT metadata) {
+  return Bundle::Read(reader, metadata);
+}
+
+Status WriteImageMetadata(const ImageMetadata& metadata,
+                          BitWriter* JXL_RESTRICT writer, size_t layer,
+                          AuxOut* aux_out) {
+  return Bundle::Write(metadata, writer, layer, aux_out);
+}
+
+void ImageMetadata::SetAlphaBits(uint32_t bits, bool alpha_is_premultiplied) {
+  std::vector<ExtraChannelInfo>& eciv = extra_channel_info;
+  ExtraChannelInfo* alpha = Find(ExtraChannel::kAlpha);
+  if (bits == 0) {
+    if (alpha != nullptr) {
+      // Remove the alpha channel from the extra channel info. It's
+      // theoretically possible that there are multiple, remove all in that
+      // case. This ensures a subsequent HasAlpha() will return false.
+      const auto is_alpha = [](const ExtraChannelInfo& eci) {
+        return eci.type == ExtraChannel::kAlpha;
+      };
+      eciv.erase(std::remove_if(eciv.begin(), eciv.end(), is_alpha),
+                 eciv.end());
+    }
+  } else {
+    if (alpha == nullptr) {
+      ExtraChannelInfo info;
+      info.type = ExtraChannel::kAlpha;
+      info.bit_depth.bits_per_sample = bits;
+      info.dim_shift = 0;
+      info.alpha_associated = alpha_is_premultiplied;
+      // Prepend rather than append: in case there already are other extra
+      // channels, prefer the alpha channel to be listed first.
+      eciv.insert(eciv.begin(), info);
+    } else {
+      // Ignores potential extra alpha channels, only sets to first one.
+      alpha->bit_depth.bits_per_sample = bits;
+      alpha->bit_depth.floating_point_sample = false;
+      alpha->bit_depth.exponent_bits_per_sample = 0;
+      alpha->alpha_associated = alpha_is_premultiplied;
+    }
+  }
+  num_extra_channels = extra_channel_info.size();
+}
+}  // namespace jxl
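SetAlphaBits only records the alpha channel in the metadata; the pixel data is attached separately via ImageBundle::SetAlpha, which requires the metadata entry to exist first (its JXL_CHECK(eci != nullptr) enforces the ordering). A sketch (illustrative only; assumes ib was constructed with this same metadata):

    #include <utility>

    #include "lib/jxl/image_bundle.h"

    void AttachAlpha(jxl::ImageMetadata* metadata, jxl::ImageBundle* ib,
                     jxl::ImageF&& alpha) {
      metadata->SetAlphaBits(8);  // creates the kAlpha entry, prepended
      ib->SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false);
    }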
diff --git a/third_party/jpeg-xl/lib/jxl/image_metadata.h b/third_party/jpeg-xl/lib/jxl/image_metadata.h
new file mode 100644
index 000000000000..ab49774d07f2
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_metadata.h
@@ -0,0 +1,390 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Main codestream header bundles, the metadata that applies to all frames.
+
+#ifndef LIB_JXL_IMAGE_METADATA_H_
+#define LIB_JXL_IMAGE_METADATA_H_
+
+#include
+#include
+
+#include
+
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/jpeg/jpeg_data.h"
+#include "lib/jxl/opsin_params.h"
+
+namespace jxl {
+
+// EXIF orientation of the image. This field overrides any field present in
+// actual EXIF metadata. The value tells which transformation the decoder must
+// apply after decoding to display the image with the correct orientation.
+enum class Orientation : uint32_t {
+  // Values 1..8 match the EXIF definitions.
+  kIdentity = 1,
+  kFlipHorizontal,
+  kRotate180,
+  kFlipVertical,
+  kTranspose,
+  kRotate90,
+  kAntiTranspose,
+  kRotate270,
+};
+// Don't need an EnumBits because Orientation is not read via Enum().
+
+enum class ExtraChannel : uint32_t {
+  // First two enumerators (most common) are cheaper to encode
+  kAlpha,
+  kDepth,
+
+  kSpotColor,
+  kSelectionMask,
+  kBlack,  // for CMYK
+  kCFA,    // Bayer channel
+  kThermal,
+  kReserved0,
+  kReserved1,
+  kReserved2,
+  kReserved3,
+  kReserved4,
+  kReserved5,
+  kReserved6,
+  kReserved7,
+  kUnknown,  // disambiguated via name string, raise warning if unsupported
+  kOptional  // like kUnknown but can silently be ignored
+};
+static inline const char* EnumName(ExtraChannel /*unused*/) {
+  return "ExtraChannel";
+}
+static inline constexpr uint64_t EnumBits(ExtraChannel /*unused*/) {
+  using EC = ExtraChannel;
+  return MakeBit(EC::kAlpha) | MakeBit(EC::kDepth) | MakeBit(EC::kSpotColor) |
+         MakeBit(EC::kSelectionMask) | MakeBit(EC::kBlack) | MakeBit(EC::kCFA) |
+         MakeBit(EC::kUnknown) | MakeBit(EC::kOptional);
+}
+
+// Used in ImageMetadata and ExtraChannelInfo.
+struct BitDepth : public Fields {
+  BitDepth();
+  const char* Name() const override { return "BitDepth"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  // Whether the original (uncompressed) samples are floating point or
+  // unsigned integer.
+  bool floating_point_sample;
+
+  // Bit depth of the original (uncompressed) image samples. Must be in the
+  // range [1, 32].
+  uint32_t bits_per_sample;
+
+  // Floating point exponent bits of the original (uncompressed) image
+  // samples, only used if floating_point_sample is true.
+  // If used, the samples are floating point with:
+  //  - 1 sign bit
+  //  - exponent_bits_per_sample exponent bits
+  //  - (bits_per_sample - exponent_bits_per_sample - 1) mantissa bits
+  // If used, exponent_bits_per_sample must be in the range
+  // [2, 8] and the amount of mantissa bits must be in the range [2, 23].
+  // NOTE: exponent_bits_per_sample is 8 for single precision binary32
+  // floating point, 5 for half precision binary16, 7 for fp24.
+  uint32_t exponent_bits_per_sample;
+};
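As a concrete reading of the BitDepth convention: IEEE half precision is 1 sign bit + 5 exponent bits + 10 mantissa bits. A sketch (illustrative only):

    #include "lib/jxl/image_metadata.h"

    jxl::BitDepth HalfPrecision() {
      jxl::BitDepth bd;
      bd.floating_point_sample = true;
      bd.bits_per_sample = 16;
      bd.exponent_bits_per_sample = 5;
      // mantissa bits = bits_per_sample - exponent_bits_per_sample - 1 = 10
      return bd;
    }

This matches what ImageMetadata::SetFloat16Samples() (further below) writes into the main bit_depth field.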
+
+// Describes one extra channel.
+struct ExtraChannelInfo : public Fields {
+  ExtraChannelInfo();
+  const char* Name() const override { return "ExtraChannelInfo"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  size_t Size(size_t size) const {
+    const size_t mask = (1u << dim_shift) - 1;
+    return (size + mask) >> dim_shift;
+  }
+
+  mutable bool all_default;
+
+  ExtraChannel type;
+  BitDepth bit_depth;
+  uint32_t dim_shift;  // downsampled by 2^dim_shift on each axis
+
+  std::string name;  // UTF-8
+
+  // Conditional:
+  bool alpha_associated;  // i.e. premultiplied
+  float spot_color[4];    // spot color in linear RGBA
+  uint32_t cfa_channel;
+};
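Size() above is a ceiling division by 2^dim_shift, so a downsampled extra channel never loses its last partial row or column. Worked numbers (illustrative only):

    #include "lib/jxl/image_metadata.h"

    void SizeExample() {
      jxl::ExtraChannelInfo eci;
      eci.dim_shift = 2;  // channel stored at 1/4 resolution per axis
      // mask = (1u << 2) - 1 = 3; (1023 + 3) >> 2 = 256 = ceil(1023 / 4)
      const size_t samples_per_axis = eci.Size(1023);
      (void)samples_per_axis;
    }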
+
+struct OpsinInverseMatrix : public Fields {
+  OpsinInverseMatrix();
+  const char* Name() const override { return "OpsinInverseMatrix"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  mutable bool all_default;
+
+  float inverse_matrix[9];
+  float opsin_biases[3];
+  float quant_biases[4];
+};
+
+// Information useful for mapping HDR images to lower dynamic range displays.
+struct ToneMapping : public Fields {
+  ToneMapping();
+  const char* Name() const override { return "ToneMapping"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  mutable bool all_default;
+
+  // Upper bound on the intensity level present in the image. For unsigned
+  // integer pixel encodings, this is the brightness of the largest
+  // representable value. The image does not necessarily contain a pixel
+  // actually this bright. An encoder is allowed to set 255 for SDR images
+  // without computing a histogram.
+  float intensity_target;  // [nits]
+
+  // Lower bound on the intensity level present in the image. This may be
+  // loose, i.e. lower than the actual darkest pixel. When tone mapping, a
+  // decoder will map [min_nits, intensity_target] to the display range.
+  float min_nits;
+
+  bool relative_to_max_display;  // see below
+  // The tone mapping will leave unchanged (linear mapping) any pixels whose
+  // brightness is strictly below this. The interpretation depends on
+  // relative_to_max_display. If true, this is a ratio [0, 1] of the maximum
+  // display brightness [nits], otherwise an absolute brightness [nits].
+  float linear_below;
+};
+
+// Contains weights to customize some transforms - in particular, XYB and
+// upsampling.
+struct CustomTransformData : public Fields {
+  CustomTransformData();
+  const char* Name() const override { return "CustomTransformData"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  // Must be set before calling VisitFields. Must equal xyb_encoded of
+  // ImageMetadata, should be set by ImageMetadata during VisitFields.
+  bool nonserialized_xyb_encoded = false;
+
+  mutable bool all_default;
+
+  OpsinInverseMatrix opsin_inverse_matrix;
+
+  uint32_t custom_weights_mask;
+  float upsampling2_weights[15];
+  float upsampling4_weights[55];
+  float upsampling8_weights[210];
+};
+
+// Properties of the original image bundle. This enables Encode(Decode()) to
+// re-create an equivalent image without user input.
+struct ImageMetadata : public Fields {
+  ImageMetadata();
+  const char* Name() const override { return "ImageMetadata"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  // Returns bit depth of the JPEG XL compressed alpha channel, or 0 if no
+  // alpha channel present. In the theoretical case that there are multiple
+  // alpha channels, returns the bit depth of the first.
+  uint32_t GetAlphaBits() const {
+    const ExtraChannelInfo* alpha = Find(ExtraChannel::kAlpha);
+    if (alpha == nullptr) return 0;
+    JXL_ASSERT(alpha->bit_depth.bits_per_sample != 0);
+    return alpha->bit_depth.bits_per_sample;
+  }
+
+  // Sets bit depth of alpha channel, adding extra channel if needed, or
+  // removing all alpha channels if bits is 0.
+  // Assumes integer alpha channel and not designed to support multiple
+  // alpha channels (it's possible to use those features by manipulating
+  // extra_channel_info directly).
+  //
+  // Callers must insert the actual channel image at the same index before any
+  // further modifications to extra_channel_info.
+  void SetAlphaBits(uint32_t bits, bool alpha_is_premultiplied = false);
+
+  bool HasAlpha() const { return GetAlphaBits() != 0; }
+
+  // Sets the original bit depth fields to indicate unsigned integer of the
+  // given bit depth.
+  // TODO(lode): move function to BitDepth
+  void SetUintSamples(uint32_t bits) {
+    bit_depth.bits_per_sample = bits;
+    bit_depth.exponent_bits_per_sample = 0;
+    bit_depth.floating_point_sample = false;
+  }
+  // Sets the original bit depth fields to indicate single precision floating
+  // point.
+  // TODO(lode): move function to BitDepth
+  void SetFloat32Samples() {
+    bit_depth.bits_per_sample = 32;
+    bit_depth.exponent_bits_per_sample = 8;
+    bit_depth.floating_point_sample = true;
+  }
+
+  void SetFloat16Samples() {
+    bit_depth.bits_per_sample = 16;
+    bit_depth.exponent_bits_per_sample = 5;
+    bit_depth.floating_point_sample = true;
+  }
+
+  void SetIntensityTarget(float intensity_target) {
+    tone_mapping.intensity_target = intensity_target;
+  }
+  float IntensityTarget() const {
+    JXL_ASSERT(tone_mapping.intensity_target != 0);
+    return tone_mapping.intensity_target;
+  }
+
+  // Returns first ExtraChannelInfo of the given type, or nullptr if none.
+  const ExtraChannelInfo* Find(ExtraChannel type) const {
+    for (const ExtraChannelInfo& eci : extra_channel_info) {
+      if (eci.type == type) return &eci;
+    }
+    return nullptr;
+  }
+
+  // Returns first ExtraChannelInfo of the given type, or nullptr if none.
+  ExtraChannelInfo* Find(ExtraChannel type) {
+    for (ExtraChannelInfo& eci : extra_channel_info) {
+      if (eci.type == type) return &eci;
+    }
+    return nullptr;
+  }
+
+  Orientation GetOrientation() const {
+    return static_cast<Orientation>(orientation);
+  }
+
+  bool ExtraFieldsDefault() const;
+
+  mutable bool all_default;
+
+  BitDepth bit_depth;
+  bool modular_16_bit_buffer_sufficient;  // otherwise 32 is.
+
+  // Whether the color values of the pixels of frames are encoded in the
+  // codestream using the absolute XYB color space, or using values that
+  // follow the color space defined by the ColorEncoding or ICC profile. This
+  // determines when or whether a CMS (Color Management System) is needed to
+  // get the pixels in a desired color space. In one case, the pixels have one
+  // known color space and a CMS is needed to convert them to the original
+  // image's color space, in the other case the pixels have the color space of
+  // the original image and a CMS is required if a different display space, or
+  // a single known consistent color space for multiple decoded images, is
+  // desired. In all cases, the color space of all frames from a single image
+  // is the same, both VarDCT and modular frames.
+  //
+  // If true: then frames can be decoded to XYB (which can also be converted to
+  // linear and non-linear sRGB with the built in conversion without CMS). The
+  // attached ColorEncoding or ICC profile has no effect on the meaning of the
+  // pixel's color values, but instead indicates what the color profile of the
+  // original image was, and what color profile one should convert to when
+  // decoding to integers to prevent clipping and precision loss. To do that
+  // conversion requires a CMS.
+  //
+  // If false: then the color values of decoded frames are in the space defined
+  // by the attached ColorEncoding or ICC profile. To instead get the pixels in
+  // a chosen known color space, such as sRGB, requires a CMS, since the
+  // attached ColorEncoding or ICC profile could be any arbitrary color space.
+  // This mode is typically used for lossless images encoded as integers.
+  // Frames can also use YCbCr encoding, some frames may and some may not, but
+  // this is not a different color space but a certain encoding of the RGB
+  // values.
+  //
+  // Note: if !xyb_encoded, but the attached color profile indicates XYB (which
+  // can happen either if it's a ColorEncoding with color_space_ ==
+  // ColorSpace::kXYB, or if it's an ICC Profile that has been crafted to
+  // represent XYB), then the frames still may not use ColorEncoding kXYB, they
+  // must still use kNone (or kYCbCr, which would mean applying the YCbCr
+  // transform to the 3-channel XYB data), since with !xyb_encoded, the 3
+  // channels are stored as-is, no matter what meaning the color profile
+  // assigns to them. To use ColorEncoding::kXYB, xyb_encoded must be true.
+  //
+  // This value is defined in image metadata because this is the global
+  // codestream header. This value does not affect the image itself, so is not
+  // image metadata per se, it only affects the encoding, and what color space
+  // the decoder can receive the pixels in without needing a CMS.
+  bool xyb_encoded;
+
+  ColorEncoding color_encoding;
+
+  // These values are initialized to defaults such that the 'extra_fields'
+  // condition in VisitFields uses correctly initialized values.
+  uint32_t orientation = 1;
+  bool have_preview = false;
+  bool have_animation = false;
+  bool have_intrinsic_size = false;
+
+  // If present, the stored image has the dimensions of the first SizeHeader,
+  // but decoders are advised to resample or display per `intrinsic_size`.
+  SizeHeader intrinsic_size;  // only if have_intrinsic_size
+
+  ToneMapping tone_mapping;
+
+  // When reading: deserialized. When writing: automatically set from vector.
+  uint32_t num_extra_channels;
+  std::vector<ExtraChannelInfo> extra_channel_info;
+
+  CustomTransformData transform_data;  // often default
+
+  // Only present if m.have_preview.
+  PreviewHeader preview_size;
+  // Only present if m.have_animation.
+  AnimationHeader animation;
+
+  uint64_t extensions;
+
+  // Option to stop parsing after basic info, and treat as if the later
+  // fields do not participate. Use to parse only basic image information
+  // excluding the final larger or variable sized data.
+  bool nonserialized_only_parse_basic_info = false;
+};
+
+Status ReadImageMetadata(BitReader* JXL_RESTRICT reader,
+                         ImageMetadata* JXL_RESTRICT metadata);
+
+Status WriteImageMetadata(const ImageMetadata& metadata,
+                          BitWriter* JXL_RESTRICT writer, size_t layer,
+                          AuxOut* aux_out);
+
+// All metadata applicable to the entire codestream (dimensions, extra
+// channels, ...)
+struct CodecMetadata {
+  // TODO(lode): use the preview and animation fields too, in place of the
+  // nonserialized_ ones in ImageMetadata.
+  ImageMetadata m;
+  // The size of the codestream: this is the nominal size applicable to all
+  // frames, although some frames can have a different effective size through
+  // crop, dc_level or representing the preview.
+  SizeHeader size;
+  // Often default.
+  CustomTransformData transform_data;
+
+  size_t xsize() const { return size.xsize(); }
+  size_t ysize() const { return size.ysize(); }
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_IMAGE_METADATA_H_
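One detail worth calling out: EXIF orientations 5 through 8 all include a transpose, which is exactly why oriented_xsize()/oriented_ysize() in image_bundle.h swap the dimensions when the value exceeds 4. Restated as a sketch (illustrative only):

    #include "lib/jxl/image_metadata.h"

    bool SwapsXY(jxl::Orientation o) {
      // kTranspose, kRotate90, kAntiTranspose, kRotate270
      return static_cast<uint32_t>(o) > 4;
    }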
diff --git a/third_party/jpeg-xl/lib/jxl/image_ops.h b/third_party/jpeg-xl/lib/jxl/image_ops.h
new file mode 100644
index 000000000000..c572580056ed
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_ops.h
@@ -0,0 +1,822 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_IMAGE_OPS_H_
+#define LIB_JXL_IMAGE_OPS_H_
+
+// Operations on images.
+
+#include
+#include
+#include
+#include
+
+#include "lib/jxl/base/profiler.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+template <typename T>
+void CopyImageTo(const Plane<T>& from, Plane<T>* JXL_RESTRICT to) {
+  PROFILER_ZONE("CopyImage1");
+  JXL_ASSERT(SameSize(from, *to));
+  if (from.ysize() == 0 || from.xsize() == 0) return;
+  for (size_t y = 0; y < from.ysize(); ++y) {
+    const T* JXL_RESTRICT row_from = from.ConstRow(y);
+    T* JXL_RESTRICT row_to = to->Row(y);
+    memcpy(row_to, row_from, from.xsize() * sizeof(T));
+  }
+}
+
+// DEPRECATED - prefer to preallocate result.
+template <typename T>
+Plane<T> CopyImage(const Plane<T>& from) {
+  Plane<T> to(from.xsize(), from.ysize());
+  CopyImageTo(from, &to);
+  return to;
+}
+
+// Copies `from:rect_from` to `to:rect_to`.
+template <typename T>
+void CopyImageTo(const Rect& rect_from, const Plane<T>& from,
+                 const Rect& rect_to, Plane<T>* JXL_RESTRICT to) {
+  PROFILER_ZONE("CopyImageR");
+  JXL_DASSERT(SameSize(rect_from, rect_to));
+  JXL_DASSERT(rect_from.IsInside(from));
+  JXL_DASSERT(rect_to.IsInside(*to));
+  if (rect_from.xsize() == 0) return;
+  for (size_t y = 0; y < rect_from.ysize(); ++y) {
+    const T* JXL_RESTRICT row_from = rect_from.ConstRow(from, y);
+    T* JXL_RESTRICT row_to = rect_to.Row(to, y);
+    memcpy(row_to, row_from, rect_from.xsize() * sizeof(T));
+  }
+}
+
+// DEPRECATED - Returns a copy of the "image" pixels that lie in "rect".
+template <typename T>
+Plane<T> CopyImage(const Rect& rect, const Plane<T>& image) {
+  Plane<T> copy(rect.xsize(), rect.ysize());
+  CopyImageTo(rect, image, &copy);
+  return copy;
+}
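The deprecation comments above push callers toward reusing a preallocated destination; note also that CopyImageTo copies only the valid xsize per row, never bytes_per_row, since row padding is not meaningful data. A sketch (illustrative only; RefreshCopy is a hypothetical helper):

    #include "lib/jxl/image_ops.h"

    void RefreshCopy(const jxl::ImageF& src, jxl::ImageF* dst) {
      // dst must already match src's dimensions; CopyImageTo asserts SameSize.
      jxl::CopyImageTo(src, dst);
    }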
+
+// Copies `from:rect_from` to `to:rect_to`.
+template <typename T>
+void CopyImageTo(const Rect& rect_from, const Image3<T>& from,
+                 const Rect& rect_to, Image3<T>* JXL_RESTRICT to) {
+  PROFILER_ZONE("CopyImageR");
+  JXL_ASSERT(SameSize(rect_from, rect_to));
+  for (size_t c = 0; c < 3; c++) {
+    CopyImageTo(rect_from, from.Plane(c), rect_to, &to->Plane(c));
+  }
+}
+
+template <typename T, typename U>
+void ConvertPlaneAndClamp(const Rect& rect_from, const Plane<T>& from,
+                          const Rect& rect_to, Plane<U>* JXL_RESTRICT to) {
+  PROFILER_ZONE("ConvertPlane");
+  JXL_ASSERT(SameSize(rect_from, rect_to));
+  using M = decltype(T() + U());
+  for (size_t y = 0; y < rect_to.ysize(); ++y) {
+    const T* JXL_RESTRICT row_from = rect_from.ConstRow(from, y);
+    U* JXL_RESTRICT row_to = rect_to.Row(to, y);
+    for (size_t x = 0; x < rect_to.xsize(); ++x) {
+      row_to[x] =
+          std::min<M>(std::max<M>(row_from[x], std::numeric_limits<U>::min()),
+                      std::numeric_limits<U>::max());
+    }
+  }
+}
+
+// Copies `from` to `to`.
+template <typename T>
+void CopyImageTo(const T& from, T* JXL_RESTRICT to) {
+  return CopyImageTo(Rect(from), from, Rect(*to), to);
+}
+
+// Copies `from:rect_from` to `to`.
+template <typename T>
+void CopyImageTo(const Rect& rect_from, const T& from, T* JXL_RESTRICT to) {
+  return CopyImageTo(rect_from, from, Rect(*to), to);
+}
+
+// Copies `from` to `to:rect_to`.
+template <typename T>
+void CopyImageTo(const T& from, const Rect& rect_to, T* JXL_RESTRICT to) {
+  return CopyImageTo(Rect(from), from, rect_to, to);
+}
+
+// Copies `from:rect_from` to `to:rect_to`; also copies `padding` pixels of
+// border around `from:rect_from`, in all directions, whenever they are inside
+// the first image.
+template <typename T>
+void CopyImageToWithPadding(const Rect& from_rect, const T& from,
+                            size_t padding, const Rect& to_rect, T* to) {
+  size_t xextra0 = std::min(padding, from_rect.x0());
+  size_t xextra1 =
+      std::min(padding, from.xsize() - from_rect.x0() - from_rect.xsize());
+  size_t yextra0 = std::min(padding, from_rect.y0());
+  size_t yextra1 =
+      std::min(padding, from.ysize() - from_rect.y0() - from_rect.ysize());
+  JXL_DASSERT(to_rect.x0() >= xextra0);
+  JXL_DASSERT(to_rect.y0() >= yextra0);
+
+  return CopyImageTo(Rect(from_rect.x0() - xextra0, from_rect.y0() - yextra0,
+                          from_rect.xsize() + xextra0 + xextra1,
+                          from_rect.ysize() + yextra0 + yextra1),
+                     from,
+                     Rect(to_rect.x0() - xextra0, to_rect.y0() - yextra0,
+                          to_rect.xsize() + xextra0 + xextra1,
+                          to_rect.ysize() + yextra0 + yextra1),
+                     to);
+}
+
+// DEPRECATED - prefer to preallocate result.
+template <typename T>
+Image3<T> CopyImage(const Image3<T>& from) {
+  Image3<T> copy(from.xsize(), from.ysize());
+  CopyImageTo(from, &copy);
+  return copy;
+}
+
+// DEPRECATED - prefer to preallocate result.
+template <typename T>
+Image3<T> CopyImage(const Rect& rect, const Image3<T>& from) {
+  Image3<T> to(rect.xsize(), rect.ysize());
+  CopyImageTo(rect, from.Plane(0), &to.Plane(0));
+  CopyImageTo(rect, from.Plane(1), &to.Plane(1));
+  CopyImageTo(rect, from.Plane(2), &to.Plane(2));
+  return to;
+}
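CopyImageToWithPadding clamps the halo to what actually exists around from_rect. Worked numbers (illustrative only): with an 8x8 source, from_rect = Rect(1, 1, 4, 4) and padding = 2, only one pixel exists above/left (xextra0 = yextra0 = 1) while 8 - 1 - 4 = 3 exist below/right (xextra1 = yextra1 = 2), so the copy actually covers Rect(0, 0, 7, 7) in both images:

    #include "lib/jxl/image_ops.h"

    void CopyWithHalo(const jxl::ImageF& src, jxl::ImageF* dst) {
      // src and *dst assumed 8x8; see the arithmetic above.
      jxl::CopyImageToWithPadding(jxl::Rect(1, 1, 4, 4), src, /*padding=*/2,
                                  jxl::Rect(1, 1, 4, 4), dst);
    }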
+template +void SetBorder(const size_t thickness, const T value, Image3* image) { + const size_t xsize = image->xsize(); + const size_t ysize = image->ysize(); + // Top: fill entire row + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < std::min(thickness, ysize); ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + std::fill(row, row + xsize, value); + } + + // Bottom: fill entire row + for (size_t y = ysize - thickness; y < ysize; ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + std::fill(row, row + xsize, value); + } + + // Left/right: fill the 'columns' on either side, but only if the image is + // big enough that they don't already belong to the top/bottom rows. + if (ysize >= 2 * thickness) { + for (size_t y = thickness; y < ysize - thickness; ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + std::fill(row, row + thickness, value); + std::fill(row + xsize - thickness, row + xsize, value); + } + } + } +} + +template +void Subtract(const ImageIn& image1, const ImageIn& image2, ImageOut* out) { + using T = typename ImageIn::T; + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] - row2[x]; + } + } +} + +// In-place. +template +void SubtractFrom(const Plane& what, Plane* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstRow(y); + Tout* JXL_RESTRICT row_to = to->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] -= row_what[x]; + } + } +} + +// In-place. +template +void AddTo(const Plane& what, Plane* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstRow(y); + Tout* JXL_RESTRICT row_to = to->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } +} + +template +void AddTo(Rect rectFrom, const Plane& what, Rect rectTo, + Plane* to) { + JXL_ASSERT(SameSize(rectFrom, rectTo)); + const size_t xsize = rectTo.xsize(); + const size_t ysize = rectTo.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = rectFrom.ConstRow(what, y); + Tout* JXL_RESTRICT row_to = rectTo.Row(to, y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } +} + +// Returns linear combination of two grayscale images. +template +Plane LinComb(const T lambda1, const Plane& image1, const T lambda2, + const Plane& image2) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + Plane out(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = lambda1 * row1[x] + lambda2 * row2[x]; + } + } + return out; +} + +// Returns a pixel-by-pixel multiplication of image by lambda. 
+// Returns a pixel-by-pixel multiplication of image by lambda.
+template <typename T>
+Plane<T> ScaleImage(const T lambda, const Plane<T>& image) {
+  Plane<T> out(image.xsize(), image.ysize());
+  for (size_t y = 0; y < image.ysize(); ++y) {
+    const T* const JXL_RESTRICT row = image.Row(y);
+    T* const JXL_RESTRICT row_out = out.Row(y);
+    for (size_t x = 0; x < image.xsize(); ++x) {
+      row_out[x] = lambda * row[x];
+    }
+  }
+  return out;
+}
+
+// Multiplies image by lambda in-place
+template <typename T>
+void ScaleImage(const T lambda, Plane<T>* image) {
+  for (size_t y = 0; y < image->ysize(); ++y) {
+    T* const JXL_RESTRICT row = image->Row(y);
+    for (size_t x = 0; x < image->xsize(); ++x) {
+      row[x] = lambda * row[x];
+    }
+  }
+}
+
+template <typename T>
+Plane<T> Product(const Plane<T>& a, const Plane<T>& b) {
+  Plane<T> c(a.xsize(), a.ysize());
+  for (size_t y = 0; y < a.ysize(); ++y) {
+    const T* const JXL_RESTRICT row_a = a.Row(y);
+    const T* const JXL_RESTRICT row_b = b.Row(y);
+    T* const JXL_RESTRICT row_c = c.Row(y);
+    for (size_t x = 0; x < a.xsize(); ++x) {
+      row_c[x] = row_a[x] * row_b[x];
+    }
+  }
+  return c;
+}
+
+float DotProduct(const ImageF& a, const ImageF& b);
+
+template <typename T>
+void FillImage(const T value, Plane<T>* image) {
+  for (size_t y = 0; y < image->ysize(); ++y) {
+    T* const JXL_RESTRICT row = image->Row(y);
+    for (size_t x = 0; x < image->xsize(); ++x) {
+      row[x] = value;
+    }
+  }
+}
+
+template <typename T>
+void ZeroFillImage(Plane<T>* image) {
+  if (image->xsize() == 0) return;
+  for (size_t y = 0; y < image->ysize(); ++y) {
+    T* const JXL_RESTRICT row = image->Row(y);
+    memset(row, 0, image->xsize() * sizeof(T));
+  }
+}
+
+// Mirrors out of bounds coordinates and returns valid coordinates unchanged.
+// We assume the radius (distance outside the image) is small compared to the
+// image size, otherwise this might not terminate.
+// The mirror is outside the last column (border pixel is also replicated).
+static inline int64_t Mirror(int64_t x, const int64_t xsize) {
+  JXL_DASSERT(xsize != 0);
+
+  // TODO(janwas): replace with branchless version
+  while (x < 0 || x >= xsize) {
+    if (x < 0) {
+      x = -x - 1;
+    } else {
+      x = 2 * xsize - 1 - x;
+    }
+  }
+  return x;
+}
+
+// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size):
+
+// Mirrors (repeating the edge pixel once). Useful for convolutions.
+struct WrapMirror {
+  JXL_INLINE int64_t operator()(const int64_t coord, const int64_t size) const {
+    return Mirror(coord, size);
+  }
+};
+
+// Returns the same coordinate: required for TFNode with Border(), or useful
+// when we know "coord" is already valid (e.g. interior of an image).
+struct WrapUnchanged {
+  JXL_INLINE int64_t operator()(const int64_t coord, int64_t /*size*/) const {
+    return coord;
+  }
+};
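Mirror() makes coordinates periodic with period 2 * xsize, reflecting the second half of each period, so Mirror(-1, 5) == 0, Mirror(5, 5) == 4, Mirror(6, 5) == 3. The TODO above mentions a branchless version; a closed-form equivalent is sketched below as one possibility (not part of the library):

#include <cstdint>

// Closed-form equivalent of the iterative Mirror(): reduce modulo the
// period 2 * xsize, then reflect the second half of the period.
int64_t MirrorClosedForm(int64_t x, int64_t xsize) {
  const int64_t period = 2 * xsize;
  int64_t m = x % period;
  if (m < 0) m += period;  // C++ % keeps the sign of the dividend
  return m < xsize ? m : period - 1 - m;
}
// MirrorClosedForm(-1, 5) == 0, MirrorClosedForm(5, 5) == 4,
// MirrorClosedForm(-6, 5) == 4: same results as the loop above.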
+// Similar to Wrap* but for row pointers (reduces Row() multiplications).
+
+class WrapRowMirror {
+ public:
+  template <class ImageOrView>
+  WrapRowMirror(const ImageOrView& image, size_t ysize)
+      : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {}
+
+  const float* operator()(const float* const JXL_RESTRICT row,
+                          const int64_t stride) const {
+    if (row < first_row_) {
+      const int64_t num_before = first_row_ - row;
+      // Mirrored; one row before => row 0, two before = row 1, ...
+      return first_row_ + num_before - stride;
+    }
+    if (row > last_row_) {
+      const int64_t num_after = row - last_row_;
+      // Mirrored; one row after => last row, two after = last - 1, ...
+      return last_row_ - num_after + stride;
+    }
+    return row;
+  }
+
+ private:
+  const float* const JXL_RESTRICT first_row_;
+  const float* const JXL_RESTRICT last_row_;
+};
+
+struct WrapRowUnchanged {
+  JXL_INLINE const float* operator()(const float* const JXL_RESTRICT row,
+                                     int64_t /*stride*/) const {
+    return row;
+  }
+};
+
+// Sets "thickness" pixels on each border to "value". This is faster than
+// initializing the entire image and overwriting valid/interior pixels.
+template <typename T>
+void SetBorder(const size_t thickness, const T value, Plane<T>* image) {
+  const size_t xsize = image->xsize();
+  const size_t ysize = image->ysize();
+  // Top: fill entire row
+  for (size_t y = 0; y < std::min(thickness, ysize); ++y) {
+    T* const JXL_RESTRICT row = image->Row(y);
+    std::fill(row, row + xsize, value);
+  }
+
+  // Bottom: fill entire row
+  for (size_t y = ysize - thickness; y < ysize; ++y) {
+    T* const JXL_RESTRICT row = image->Row(y);
+    std::fill(row, row + xsize, value);
+  }
+
+  // Left/right: fill the 'columns' on either side, but only if the image is
+  // big enough that they don't already belong to the top/bottom rows.
+  if (ysize >= 2 * thickness) {
+    for (size_t y = thickness; y < ysize - thickness; ++y) {
+      T* const JXL_RESTRICT row = image->Row(y);
+      std::fill(row, row + thickness, value);
+      std::fill(row + xsize - thickness, row + xsize, value);
+    }
+  }
+}
+
+// Computes the minimum and maximum pixel value.
+template <typename T>
+void ImageMinMax(const Plane<T>& image, T* const JXL_RESTRICT min,
+                 T* const JXL_RESTRICT max) {
+  *min = std::numeric_limits<T>::max();
+  *max = std::numeric_limits<T>::lowest();
+  for (size_t y = 0; y < image.ysize(); ++y) {
+    const T* const JXL_RESTRICT row = image.Row(y);
+    for (size_t x = 0; x < image.xsize(); ++x) {
+      *min = std::min(*min, row[x]);
+      *max = std::max(*max, row[x]);
+    }
+  }
+}
+
+// Copies pixels, scaling their value relative to the "from" min/max by
+// "to_range". Example: U8 [0, 255] := [0.0, 1.0], to_range = 1.0 =>
+// outputs [0.0, 1.0].
+template <typename FromType, typename ToType>
+void ImageConvert(const Plane<FromType>& from, const float to_range,
+                  Plane<ToType>* const JXL_RESTRICT to) {
+  JXL_ASSERT(SameSize(from, *to));
+  FromType min_from, max_from;
+  ImageMinMax(from, &min_from, &max_from);
+  const float scale = to_range / (max_from - min_from);
+  for (size_t y = 0; y < from.ysize(); ++y) {
+    const FromType* const JXL_RESTRICT row_from = from.Row(y);
+    ToType* const JXL_RESTRICT row_to = to->Row(y);
+    for (size_t x = 0; x < from.xsize(); ++x) {
+      row_to[x] = static_cast<ToType>((row_from[x] - min_from) * scale);
+    }
+  }
+}
+
+template <typename From>
+Plane<float> ConvertToFloat(const Plane<From>& from) {
+  float factor = 1.0f / std::numeric_limits<From>::max();
+  if (std::is_same<From, float>::value || std::is_same<From, double>::value) {
+    factor = 1.0f;
+  }
+  Plane<float> to(from.xsize(), from.ysize());
+  for (size_t y = 0; y < from.ysize(); ++y) {
+    const From* const JXL_RESTRICT row_from = from.Row(y);
+    float* const JXL_RESTRICT row_to = to.Row(y);
+    for (size_t x = 0; x < from.xsize(); ++x) {
+      row_to[x] = row_from[x] * factor;
+    }
+  }
+  return to;
+}
+
+template <typename T>
+Plane<T> ImageFromPacked(const std::vector<T>& packed, const size_t xsize,
+                         const size_t ysize) {
+  Plane<T> out(xsize, ysize);
+  for (size_t y = 0; y < ysize; ++y) {
+    T* const JXL_RESTRICT row = out.Row(y);
+    const T* const JXL_RESTRICT packed_row = &packed[y * xsize];
+    memcpy(row, packed_row, xsize * sizeof(T));
+  }
+  return out;
+}
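ImageConvert maps the observed [min, max] of the source onto [0, to_range]. The arithmetic, isolated on a flat buffer (names are illustrative, not jxl API):

#include <algorithm>
#include <cstddef>
#include <vector>

// Rescales samples so the observed minimum maps to 0 and the observed
// maximum maps to to_range, mirroring ImageConvert's use of ImageMinMax.
std::vector<float> RescaleToRange(const std::vector<float>& from,
                                  float to_range) {
  const auto [lo_it, hi_it] = std::minmax_element(from.begin(), from.end());
  const float lo = *lo_it;
  const float scale = to_range / (*hi_it - lo);  // caller ensures max != min
  std::vector<float> out(from.size());
  for (std::size_t i = 0; i < from.size(); ++i) out[i] = (from[i] - lo) * scale;
  return out;
}
// RescaleToRange({2, 4, 6}, 1.0f) yields {0.0f, 0.5f, 1.0f}.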
+// Computes independent minimum and maximum values for each plane.
+template <typename T>
+void Image3MinMax(const Image3<T>& image, const Rect& rect,
+                  std::array<T, 3>* out_min, std::array<T, 3>* out_max) {
+  for (size_t c = 0; c < 3; ++c) {
+    T min = std::numeric_limits<T>::max();
+    T max = std::numeric_limits<T>::min();
+    for (size_t y = 0; y < rect.ysize(); ++y) {
+      const T* JXL_RESTRICT row = rect.ConstPlaneRow(image, c, y);
+      for (size_t x = 0; x < rect.xsize(); ++x) {
+        min = std::min(min, row[x]);
+        max = std::max(max, row[x]);
+      }
+    }
+    (*out_min)[c] = min;
+    (*out_max)[c] = max;
+  }
+}
+
+// Computes independent minimum and maximum values for each plane.
+template <typename T>
+void Image3MinMax(const Image3<T>& image, std::array<T, 3>* out_min,
+                  std::array<T, 3>* out_max) {
+  Image3MinMax(image, Rect(image), out_min, out_max);
+}
+
+template <typename T>
+void Image3Max(const Image3<T>& image, std::array<T, 3>* out_max) {
+  for (size_t c = 0; c < 3; ++c) {
+    T max = std::numeric_limits<T>::min();
+    for (size_t y = 0; y < image.ysize(); ++y) {
+      const T* JXL_RESTRICT row = image.ConstPlaneRow(c, y);
+      for (size_t x = 0; x < image.xsize(); ++x) {
+        max = std::max(max, row[x]);
+      }
+    }
+    (*out_max)[c] = max;
+  }
+}
+
+// Computes the sum of the pixels in `rect`.
+template <typename T>
+T ImageSum(const Plane<T>& image, const Rect& rect) {
+  T result = 0;
+  for (size_t y = 0; y < rect.ysize(); ++y) {
+    const T* JXL_RESTRICT row = rect.ConstRow(image, y);
+    for (size_t x = 0; x < rect.xsize(); ++x) {
+      result += row[x];
+    }
+  }
+  return result;
+}
+
+template <typename T>
+T ImageSum(const Plane<T>& image) {
+  return ImageSum(image, Rect(image));
+}
+
+template <typename T>
+std::array<T, 3> Image3Sum(const Image3<T>& image, const Rect& rect) {
+  std::array<T, 3> out_sum{};
+  for (size_t c = 0; c < 3; ++c) {
+    out_sum[c] = ImageSum(image.Plane(c), rect);
+  }
+  return out_sum;
+}
+
+template <typename T>
+std::array<T, 3> Image3Sum(const Image3<T>& image) {
+  return Image3Sum(image, Rect(image));
+}
+
+template <typename T>
+std::vector<T> PackedFromImage(const Plane<T>& image, const Rect& rect) {
+  const size_t xsize = rect.xsize();
+  const size_t ysize = rect.ysize();
+  std::vector<T> packed(xsize * ysize);
+  for (size_t y = 0; y < rect.ysize(); ++y) {
+    memcpy(&packed[y * xsize], rect.ConstRow(image, y), xsize * sizeof(T));
+  }
+  return packed;
+}
+
+template <typename T>
+std::vector<T> PackedFromImage(const Plane<T>& image) {
+  return PackedFromImage(image, Rect(image));
+}
+// Computes the median pixel value.
+template <typename T>
+T ImageMedian(const Plane<T>& image, const Rect& rect) {
+  std::vector<T> pixels = PackedFromImage(image, rect);
+  return Median(&pixels);
+}
+
+template <typename T>
+T ImageMedian(const Plane<T>& image) {
+  return ImageMedian(image, Rect(image));
+}
+
+template <typename T>
+std::array<T, 3> Image3Median(const Image3<T>& image, const Rect& rect) {
+  std::array<T, 3> out_median;
+  for (size_t c = 0; c < 3; ++c) {
+    out_median[c] = ImageMedian(image.Plane(c), rect);
+  }
+  return out_median;
+}
+
+template <typename T>
+std::array<T, 3> Image3Median(const Image3<T>& image) {
+  return Image3Median(image, Rect(image));
+}
+
+template <typename FromType, typename ToType>
+void Image3Convert(const Image3<FromType>& from, const float to_range,
+                   Image3<ToType>* const JXL_RESTRICT to) {
+  JXL_ASSERT(SameSize(from, *to));
+  std::array<FromType, 3> min_from, max_from;
+  Image3MinMax(from, &min_from, &max_from);
+  float scales[3];
+  for (size_t c = 0; c < 3; ++c) {
+    scales[c] = to_range / (max_from[c] - min_from[c]);
+  }
+  float scale = std::min(scales[0], std::min(scales[1], scales[2]));
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < from.ysize(); ++y) {
+      const FromType* JXL_RESTRICT row_from = from.ConstPlaneRow(c, y);
+      ToType* JXL_RESTRICT row_to = to->PlaneRow(c, y);
+      for (size_t x = 0; x < from.xsize(); ++x) {
+        const float to = (row_from[x] - min_from[c]) * scale;
+        row_to[x] = static_cast<ToType>(to);
+      }
+    }
+  }
+}
+
+template <typename From>
+Image3F ConvertToFloat(const Image3<From>& from) {
+  return Image3F(ConvertToFloat(from.Plane(0)), ConvertToFloat(from.Plane(1)),
+                 ConvertToFloat(from.Plane(2)));
+}
+
+template <typename Tin, typename Tout>
+void Subtract(const Image3<Tin>& image1, const Image3<Tin>& image2,
+              Image3<Tout>* out) {
+  const size_t xsize = image1.xsize();
+  const size_t ysize = image1.ysize();
+  JXL_CHECK(xsize == image2.xsize());
+  JXL_CHECK(ysize == image2.ysize());
+
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < ysize; ++y) {
+      const Tin* const JXL_RESTRICT row1 = image1.ConstPlaneRow(c, y);
+      const Tin* const JXL_RESTRICT row2 = image2.ConstPlaneRow(c, y);
+      Tout* const JXL_RESTRICT row_out = out->PlaneRow(c, y);
+      for (size_t x = 0; x < xsize; ++x) {
+        row_out[x] = row1[x] - row2[x];
+      }
+    }
+  }
+}
+
+template <typename Tin, typename Tout>
+void SubtractFrom(const Image3<Tin>& what, Image3<Tout>* to) {
+  const size_t xsize = what.xsize();
+  const size_t ysize = what.ysize();
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < ysize; ++y) {
+      const Tin* JXL_RESTRICT row_what = what.ConstPlaneRow(c, y);
+      Tout* JXL_RESTRICT row_to = to->PlaneRow(c, y);
+      for (size_t x = 0; x < xsize; ++x) {
+        row_to[x] -= row_what[x];
+      }
+    }
+  }
+}
+
+template <typename Tin, typename Tout>
+void AddTo(const Image3<Tin>& what, Image3<Tout>* to) {
+  const size_t xsize = what.xsize();
+  const size_t ysize = what.ysize();
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < ysize; ++y) {
+      const Tin* JXL_RESTRICT row_what = what.ConstPlaneRow(c, y);
+      Tout* JXL_RESTRICT row_to = to->PlaneRow(c, y);
+      for (size_t x = 0; x < xsize; ++x) {
+        row_to[x] += row_what[x];
+      }
+    }
+  }
+}
+// Adds `what` of the size of `rect` to `to` in the position of `rect`.
+template <typename Tin, typename Tout>
+void AddTo(const Rect& rect, const Image3<Tin>& what, Image3<Tout>* to) {
+  const size_t xsize = what.xsize();
+  const size_t ysize = what.ysize();
+  JXL_ASSERT(xsize == rect.xsize());
+  JXL_ASSERT(ysize == rect.ysize());
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < ysize; ++y) {
+      const Tin* JXL_RESTRICT row_what = what.ConstPlaneRow(c, y);
+      Tout* JXL_RESTRICT row_to = rect.PlaneRow(to, c, y);
+      for (size_t x = 0; x < xsize; ++x) {
+        row_to[x] += row_what[x];
+      }
+    }
+  }
+}
+
+template <typename T>
+Image3<T> ScaleImage(const T lambda, const Image3<T>& image) {
+  Image3<T> out(image.xsize(), image.ysize());
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < image.ysize(); ++y) {
+      const T* JXL_RESTRICT row = image.ConstPlaneRow(c, y);
+      T* JXL_RESTRICT row_out = out.PlaneRow(c, y);
+      for (size_t x = 0; x < image.xsize(); ++x) {
+        row_out[x] = lambda * row[x];
+      }
+    }
+  }
+  return out;
+}
+
+// Multiplies image by lambda in-place
+template <typename T>
+void ScaleImage(const T lambda, Image3<T>* image) {
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < image->ysize(); ++y) {
+      T* const JXL_RESTRICT row = image->PlaneRow(c, y);
+      for (size_t x = 0; x < image->xsize(); ++x) {
+        row[x] = lambda * row[x];
+      }
+    }
+  }
+}
+
+// Initializes all planes to the same "value".
+template <typename T>
+void FillImage(const T value, Image3<T>* image) {
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < image->ysize(); ++y) {
+      T* JXL_RESTRICT row = image->PlaneRow(c, y);
+      for (size_t x = 0; x < image->xsize(); ++x) {
+        row[x] = value;
+      }
+    }
+  }
+}
+
+template <typename T>
+void FillPlane(const T value, Plane<T>* image) {
+  for (size_t y = 0; y < image->ysize(); ++y) {
+    T* JXL_RESTRICT row = image->Row(y);
+    for (size_t x = 0; x < image->xsize(); ++x) {
+      row[x] = value;
+    }
+  }
+}
+
+template <typename T>
+void FillImage(const T value, Image3<T>* image, Rect rect) {
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < rect.ysize(); ++y) {
+      T* JXL_RESTRICT row = rect.PlaneRow(image, c, y);
+      for (size_t x = 0; x < rect.xsize(); ++x) {
+        row[x] = value;
+      }
+    }
+  }
+}
+
+template <typename T>
+void FillPlane(const T value, Plane<T>* image, Rect rect) {
+  for (size_t y = 0; y < rect.ysize(); ++y) {
+    T* JXL_RESTRICT row = rect.Row(image, y);
+    for (size_t x = 0; x < rect.xsize(); ++x) {
+      row[x] = value;
+    }
+  }
+}
+
+template <typename T>
+void ZeroFillImage(Image3<T>* image) {
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < image->ysize(); ++y) {
+      T* JXL_RESTRICT row = image->PlaneRow(c, y);
+      memset(row, 0, image->xsize() * sizeof(T));
+    }
+  }
+}
+
+template <typename T>
+void ZeroFillPlane(Plane<T>* image, Rect rect) {
+  for (size_t y = 0; y < rect.ysize(); ++y) {
+    T* JXL_RESTRICT row = rect.Row(image, y);
+    memset(row, 0, rect.xsize() * sizeof(T));
+  }
+}
+
+// First, image is padded horizontally, with the rightmost value.
+// Next, image is padded vertically, by repeating the last line.
+ImageF PadImage(const ImageF& in, size_t xsize, size_t ysize);
+
+// Pads an image with xborder columns on each vertical side and yborder rows
+// above and below, mirroring the image.
+Image3F PadImageMirror(const Image3F& in, size_t xborder, size_t yborder);
+
+// First, image is padded horizontally, with the rightmost value.
+// Next, image is padded vertically, by repeating the last line.
+// Prefer PadImageToBlockMultipleInPlace if padding to kBlockDim.
+Image3F PadImageToMultiple(const Image3F& in, size_t N);
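PadImageToMultiple rounds each dimension up to the next multiple of N, which is plain ceiling division. A self-contained statement of the arithmetic (helper name is illustrative):

#include <cstddef>

// Next multiple of N at or above size, i.e. the padded dimension used by
// pad-to-multiple helpers (requires N > 0).
constexpr std::size_t RoundUpToMultiple(std::size_t size, std::size_t N) {
  return ((size + N - 1) / N) * N;
}
static_assert(RoundUpToMultiple(13, 8) == 16, "13 pads up to 16");
static_assert(RoundUpToMultiple(16, 8) == 16, "multiples are unchanged");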
+// Same as above, but operates in-place. Assumes that the `in` image was
+// allocated large enough.
+void PadImageToBlockMultipleInPlace(Image3F* JXL_RESTRICT in);
+
+// Downsamples an image by a given factor.
+void DownsampleImage(Image3F* opsin, size_t factor);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_IMAGE_OPS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/image_ops_test.cc b/third_party/jpeg-xl/lib/jxl/image_ops_test.cc
new file mode 100644
index 000000000000..841dff7d3a43
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_ops_test.cc
@@ -0,0 +1,142 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/image_ops.h"
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <random>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_test_utils.h"
+
+namespace jxl {
+namespace {
+
+template <typename T>
+void TestPacked(const size_t xsize, const size_t ysize) {
+  Plane<T> image1(xsize, ysize);
+  RandomFillImage(&image1);
+  const std::vector<T>& packed = PackedFromImage(image1);
+  const Plane<T>& image2 = ImageFromPacked(packed, xsize, ysize);
+  EXPECT_TRUE(SamePixels(image1, image2));
+}
+
+TEST(ImageTest, TestPacked) {
+  TestPacked<uint8_t>(1, 1);
+  TestPacked<uint8_t>(7, 1);
+  TestPacked<uint8_t>(1, 7);
+
+  TestPacked<int16_t>(1, 1);
+  TestPacked<int16_t>(7, 1);
+  TestPacked<int16_t>(1, 7);
+
+  TestPacked<uint16_t>(1, 1);
+  TestPacked<uint16_t>(7, 1);
+  TestPacked<uint16_t>(1, 7);
+
+  TestPacked<float>(1, 1);
+  TestPacked<float>(7, 1);
+  TestPacked<float>(1, 7);
+}
+
+// Ensure entire payload is readable/writable for various size/offset combos.
+TEST(ImageTest, TestAllocator) {
+  std::mt19937 rng(129);
+  const size_t k32 = 32;
+  const size_t kAlign = CacheAligned::kAlignment;
+  for (size_t size : {k32 * 1, k32 * 2, k32 * 3, k32 * 4, k32 * 5,
+                      CacheAligned::kAlias, 2 * CacheAligned::kAlias + 4}) {
+    for (size_t offset = 0; offset <= CacheAligned::kAlias; offset += kAlign) {
+      uint8_t* bytes =
+          static_cast<uint8_t*>(CacheAligned::Allocate(size, offset));
+      JXL_CHECK(reinterpret_cast<uintptr_t>(bytes) % kAlign == 0);
+      // Ensure we can write/read the last byte. Use RNG to fool the compiler
+      // into thinking the write is necessary.
+      memset(bytes, 0, size);
+      bytes[size - 1] = 1;  // greatest element
+      std::uniform_int_distribution<uint32_t> dist(0, size - 1);
+      uint32_t pos = dist(rng);  // random but != greatest
+      while (pos == size - 1) {
+        pos = dist(rng);
+      }
+      JXL_CHECK(bytes[pos] < bytes[size - 1]);
+
+      CacheAligned::Free(bytes);
+    }
+  }
+}
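The allocator test above checks two properties: the returned pointer honors the alignment, and the whole payload is addressable. A standalone analogue using C++17 std::aligned_alloc instead of jxl's CacheAligned (the constant 64 stands in for CacheAligned::kAlignment; std::aligned_alloc is absent on MSVC, so this is a sketch, not portable test code):

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>

void CheckAlignedAllocation(std::size_t size) {
  constexpr std::size_t kAlign = 64;
  // std::aligned_alloc requires the size to be a multiple of the alignment.
  const std::size_t padded = ((size + kAlign - 1) / kAlign) * kAlign;
  uint8_t* bytes = static_cast<uint8_t*>(std::aligned_alloc(kAlign, padded));
  assert(bytes != nullptr);
  assert(reinterpret_cast<uintptr_t>(bytes) % kAlign == 0);
  std::memset(bytes, 0, padded);
  bytes[0] = 1;
  bytes[padded - 1] = 1;  // first and last byte must both be writable
  std::free(bytes);
}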
+template <typename T>
+void TestFillImpl(Image3<T>* img, const char* layout) {
+  FillImage(T(1), img);
+  for (size_t y = 0; y < img->ysize(); ++y) {
+    for (size_t c = 0; c < 3; ++c) {
+      T* JXL_RESTRICT row = img->PlaneRow(c, y);
+      for (size_t x = 0; x < img->xsize(); ++x) {
+        if (row[x] != T(1)) {
+          printf("Not 1 at c=%zu %zu, %zu (%zu x %zu) (%s)\n", c, x, y,
+                 img->xsize(), img->ysize(), layout);
+          abort();
+        }
+        row[x] = T(2);
+      }
+    }
+  }
+
+  // Same for ZeroFillImage and swapped c/y loop ordering.
+  ZeroFillImage(img);
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < img->ysize(); ++y) {
+      T* JXL_RESTRICT row = img->PlaneRow(c, y);
+      for (size_t x = 0; x < img->xsize(); ++x) {
+        if (row[x] != T(0)) {
+          printf("Not 0 at c=%zu %zu, %zu (%zu x %zu) (%s)\n", c, x, y,
+                 img->xsize(), img->ysize(), layout);
+          abort();
+        }
+        row[x] = T(3);
+      }
+    }
+  }
+}
+
+template <typename T>
+void TestFillT() {
+  for (uint32_t xsize : {0, 1, 15, 16, 31, 32}) {
+    for (uint32_t ysize : {0, 1, 15, 16, 31, 32}) {
+      Image3<T> image(xsize, ysize);
+      TestFillImpl(&image, "size ctor");
+
+      Image3<T> planar(Plane<T>(xsize, ysize), Plane<T>(xsize, ysize),
+                       Plane<T>(xsize, ysize));
+      TestFillImpl(&planar, "planar");
+    }
+  }
+}
+
+// Ensure y/c/x and c/y/x loops visit pixels no more than once.
+TEST(ImageTest, TestFill) {
+  TestFillT<uint8_t>();
+  TestFillT<int16_t>();
+  TestFillT<float>();
+  TestFillT<double>();
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/image_test_utils.h b/third_party/jpeg-xl/lib/jxl/image_test_utils.h
new file mode 100644
index 000000000000..80ab32264429
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/image_test_utils.h
@@ -0,0 +1,322 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_IMAGE_TEST_UTILS_H_
+#define LIB_JXL_IMAGE_TEST_UTILS_H_
+
+#include <stdint.h>
+
+#include <cmath>
+#include <limits>
+#include <random>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+template <typename T>
+void VerifyEqual(const Plane<T>& expected, const Plane<T>& actual) {
+  JXL_CHECK(SameSize(expected, actual));
+  for (size_t y = 0; y < expected.ysize(); ++y) {
+    const T* const JXL_RESTRICT row_expected = expected.Row(y);
+    const T* const JXL_RESTRICT row_actual = actual.Row(y);
+    for (size_t x = 0; x < expected.xsize(); ++x) {
+      ASSERT_EQ(row_expected[x], row_actual[x]) << x << " " << y;
+    }
+  }
+}
+
+template <typename T>
+void VerifyEqual(const Image3<T>& expected, const Image3<T>& actual) {
+  for (size_t c = 0; c < 3; ++c) {
+    VerifyEqual(expected.Plane(c), actual.Plane(c));
+  }
+}
+
+template <typename T>
+bool SamePixels(const Plane<T>& image1, const Plane<T>& image2,
+                const Rect rect) {
+  if (!rect.IsInside(image1) || !rect.IsInside(image2)) {
+    ADD_FAILURE() << "requested rectangle is not fully inside the image";
+    return false;
+  }
+  size_t mismatches = 0;
+  for (size_t y = rect.y0(); y < rect.ysize(); ++y) {
+    const T* const JXL_RESTRICT row1 = image1.Row(y);
+    const T* const JXL_RESTRICT row2 = image2.Row(y);
+    for (size_t x = rect.x0(); x < rect.xsize(); ++x) {
+      if (row1[x] != row2[x]) {
+        ADD_FAILURE() << "pixel mismatch at " << x << ", " << y << ": "
+                      << double(row1[x]) << " != " << double(row2[x]);
+        if (++mismatches > 4) {
+          return false;
+        }
+      }
+    }
+  }
+  return mismatches == 0;
+}
+
+template <typename T>
+bool SamePixels(const Plane<T>& image1, const Plane<T>& image2) {
+  JXL_CHECK(SameSize(image1, image2));
+  return SamePixels(image1, image2, Rect(image1));
+}
+
+template <typename T>
+bool SamePixels(const Image3<T>& image1, const Image3<T>& image2) {
+  JXL_CHECK(SameSize(image1, image2));
+  for (size_t c = 0; c < 3; ++c) {
+    if (!SamePixels(image1.Plane(c), image2.Plane(c))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Use for floating-point images with fairly large numbers; tolerates small
+// absolute errors and/or small relative errors.
+template <typename T>
+void VerifyRelativeError(const Plane<T>& expected, const Plane<T>& actual,
+                         const double threshold_l1,
+                         const double threshold_relative,
+                         const intptr_t border = 0, const size_t c = 0) {
+  JXL_CHECK(SameSize(expected, actual));
+  const intptr_t xsize = expected.xsize();
+  const intptr_t ysize = expected.ysize();
+
+  // Max over current scanline to give a better idea whether there are
+  // systematic errors or just one outlier. Invalid if negative.
+  double max_l1 = -1;
+  double max_relative = -1;
+  bool any_bad = false;
+  for (intptr_t y = border; y < ysize - border; ++y) {
+    const T* const JXL_RESTRICT row_expected = expected.Row(y);
+    const T* const JXL_RESTRICT row_actual = actual.Row(y);
+    for (intptr_t x = border; x < xsize - border; ++x) {
+      const double l1 = std::abs(row_expected[x] - row_actual[x]);
+
+      // Cannot compute relative, only check/update L1.
+      if (std::abs(row_expected[x]) < 1E-10) {
+        if (l1 > threshold_l1) {
+          any_bad = true;
+          max_l1 = std::max(max_l1, l1);
+        }
+      } else {
+        const double relative = l1 / std::abs(double(row_expected[x]));
+        if (l1 > threshold_l1 && relative > threshold_relative) {
+          // Fails both tolerances => will exit below, update max_*.
+          any_bad = true;
+          max_l1 = std::max(max_l1, l1);
+          max_relative = std::max(max_relative, relative);
+        }
+      }
+    }
+  }
+  if (any_bad) {
+    // Never had a valid relative value, don't print it.
+    if (max_relative < 0) {
+      fprintf(stderr, "c=%zu: max +/- %E exceeds +/- %.2E\n", c, max_l1,
+              threshold_l1);
+    } else {
+      fprintf(stderr, "c=%zu: max +/- %E, x %E exceeds +/- %.2E, x %.2E\n", c,
+              max_l1, max_relative, threshold_l1, threshold_relative);
+    }
+    // Dump the expected image and actual image if the region is small enough.
+    const intptr_t kMaxTestDumpSize = 16;
+    if (xsize <= kMaxTestDumpSize + 2 * border &&
+        ysize <= kMaxTestDumpSize + 2 * border) {
+      fprintf(stderr, "Expected image:\n");
+      for (intptr_t y = border; y < ysize - border; ++y) {
+        const T* const JXL_RESTRICT row_expected = expected.Row(y);
+        for (intptr_t x = border; x < xsize - border; ++x) {
+          fprintf(stderr, "%10lf ", static_cast<double>(row_expected[x]));
+        }
+        fprintf(stderr, "\n");
+      }
+
+      fprintf(stderr, "Actual image:\n");
+      for (intptr_t y = border; y < ysize - border; ++y) {
+        const T* const JXL_RESTRICT row_expected = expected.Row(y);
+        const T* const JXL_RESTRICT row_actual = actual.Row(y);
+        for (intptr_t x = border; x < xsize - border; ++x) {
+          const double l1 = std::abs(row_expected[x] - row_actual[x]);
+
+          bool bad = l1 > threshold_l1;
+          if (row_expected[x] > 1E-10) {
+            const double relative = l1 / std::abs(double(row_expected[x]));
+            bad &= relative > threshold_relative;
+          }
+          if (bad) {
+            fprintf(stderr, "%10lf ", static_cast<double>(row_actual[x]));
+          } else {
+            fprintf(stderr, "%10s ", "==");
+          }
+        }
+        fprintf(stderr, "\n");
+      }
+    }
+
+    // Find first failing x for further debugging.
+    for (intptr_t y = border; y < ysize - border; ++y) {
+      const T* const JXL_RESTRICT row_expected = expected.Row(y);
+      const T* const JXL_RESTRICT row_actual = actual.Row(y);
+
+      for (intptr_t x = border; x < xsize - border; ++x) {
+        const double l1 = std::abs(row_expected[x] - row_actual[x]);
+
+        bool bad = l1 > threshold_l1;
+        if (row_expected[x] > 1E-10) {
+          const double relative = l1 / std::abs(double(row_expected[x]));
+          bad &= relative > threshold_relative;
+        }
+        if (bad) {
+          FAIL() << x << ", " << y << " (" << expected.xsize() << " x "
+                 << expected.ysize() << ") expected "
+                 << static_cast<double>(row_expected[x]) << " actual "
+                 << static_cast<double>(row_actual[x]);
+        }
+      }
+    }
+    return;  // if any_bad, we should have exited.
+  }
+}
+
+template <typename T>
+void VerifyRelativeError(const Image3<T>& expected, const Image3<T>& actual,
+                         const float threshold_l1,
+                         const float threshold_relative,
+                         const intptr_t border = 0) {
+  for (size_t c = 0; c < 3; ++c) {
+    VerifyRelativeError(expected.Plane(c), actual.Plane(c), threshold_l1,
+                        threshold_relative, border, c);
+  }
+}
+
+// Generator for independent, uniformly distributed integers [0, max].
+template <typename T, typename Random>
+class GeneratorRandom {
+ public:
+  GeneratorRandom(Random* rng, const T max) : rng_(*rng), dist_(0, max) {}
+
+  GeneratorRandom(Random* rng, const T min, const T max)
+      : rng_(*rng), dist_(min, max) {}
+
+  T operator()(const size_t x, const size_t y, const int c) const {
+    return dist_(rng_);
+  }
+
+ private:
+  Random& rng_;
+  mutable std::uniform_int_distribution<> dist_;
+};
+
+template <typename Random>
+class GeneratorRandom<float, Random> {
+ public:
+  GeneratorRandom(Random* rng, const float max)
+      : rng_(*rng), dist_(0.0f, max) {}
+
+  GeneratorRandom(Random* rng, const float min, const float max)
+      : rng_(*rng), dist_(min, max) {}
+
+  float operator()(const size_t x, const size_t y, const int c) const {
+    return dist_(rng_);
+  }
+
+ private:
+  Random& rng_;
+  mutable std::uniform_real_distribution<float> dist_;
+};
+
+template <typename Random>
+class GeneratorRandom<double, Random> {
+ public:
+  GeneratorRandom(Random* rng, const double max)
+      : rng_(*rng), dist_(0.0, max) {}
+
+  GeneratorRandom(Random* rng, const double min, const double max)
+      : rng_(*rng), dist_(min, max) {}
+
+  double operator()(const size_t x, const size_t y, const int c) const {
+    return dist_(rng_);
+  }
+
+ private:
+  Random& rng_;
+  mutable std::uniform_real_distribution<> dist_;
+};
+
+// Assigns generator(x, y, 0) to each pixel (x, y).
+template <class Generator, class Image>
+void GenerateImage(const Generator& generator, Image* image) {
+  using T = typename Image::T;
+  for (size_t y = 0; y < image->ysize(); ++y) {
+    T* const JXL_RESTRICT row = image->Row(y);
+    for (size_t x = 0; x < image->xsize(); ++x) {
+      row[x] = generator(x, y, 0);
+    }
+  }
+}
+
+template <typename T>
+void RandomFillImage(Plane<T>* image,
+                     const T max = std::numeric_limits<T>::max()) {
+  std::mt19937_64 rng(129);
+  const GeneratorRandom<T, std::mt19937_64> generator(&rng, max);
+  GenerateImage(generator, image);
+}
+
+template <typename T>
+void RandomFillImage(Plane<T>* image, const T min, const T max,
+                     const int seed) {
+  std::mt19937_64 rng(seed);
+  const GeneratorRandom<T, std::mt19937_64> generator(&rng, min, max);
+  GenerateImage(generator, image);
+}
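VerifyRelativeError, defined above, only reports a pixel as failing when it exceeds BOTH the absolute and the relative thresholds, and skips the relative check for near-zero expected values. The acceptance rule for a single sample, distilled into a standalone predicate (illustration only):

#include <cmath>

// A value passes if it is within the absolute tolerance OR within the
// relative tolerance; near-zero expectations use the absolute check only.
bool WithinTolerance(double expected, double actual, double threshold_l1,
                     double threshold_relative) {
  const double l1 = std::abs(expected - actual);
  if (std::abs(expected) < 1E-10) return l1 <= threshold_l1;
  const double relative = l1 / std::abs(expected);
  return l1 <= threshold_l1 || relative <= threshold_relative;
}
// WithinTolerance(100.0, 100.5, 0.1, 0.01) is true: the error 0.5 exceeds
// the absolute threshold but is only 0.5% of the expected value.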
+// Assigns generator(x, y, c) to each pixel (x, y).
+template <class Generator, typename T>
+void GenerateImage(const Generator& generator, Image3<T>* image) {
+  for (size_t c = 0; c < 3; ++c) {
+    for (size_t y = 0; y < image->ysize(); ++y) {
+      T* JXL_RESTRICT row = image->PlaneRow(c, y);
+      for (size_t x = 0; x < image->xsize(); ++x) {
+        row[x] = generator(x, y, c);
+      }
+    }
+  }
+}
+
+template <typename T>
+void RandomFillImage(Image3<T>* image,
+                     const T max = std::numeric_limits<T>::max()) {
+  std::mt19937_64 rng(129);
+  const GeneratorRandom<T, std::mt19937_64> generator(&rng, max);
+  GenerateImage(generator, image);
+}
+
+template <typename T>
+void RandomFillImage(Image3<T>* image, const T min, const T max,
+                     const int seed) {
+  std::mt19937_64 rng(seed);
+  const GeneratorRandom<T, std::mt19937_64> generator(&rng, min, max);
+  GenerateImage(generator, image);
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_IMAGE_TEST_UTILS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.cc b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.cc
new file mode 100644
index 000000000000..98dca089ec54
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.cc
@@ -0,0 +1,138 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/jpeg/dec_jpeg_data.h"
+
+#include <brotli/decode.h>
+
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_bit_reader.h"
+
+namespace jxl {
+namespace jpeg {
+Status DecodeJPEGData(Span<const uint8_t> encoded, JPEGData* jpeg_data) {
+  Status ret = true;
+  const uint8_t* in = encoded.data();
+  size_t available_in = encoded.size();
+  {
+    BitReader br(encoded);
+    BitReaderScopedCloser br_closer(&br, &ret);
+    JXL_RETURN_IF_ERROR(Bundle::Read(&br, jpeg_data));
+    JXL_RETURN_IF_ERROR(br.JumpToByteBoundary());
+    in += br.TotalBitsConsumed() / 8;
+    available_in -= br.TotalBitsConsumed() / 8;
+  }
+  JXL_RETURN_IF_ERROR(ret);
+
+  BrotliDecoderState* brotli_dec =
+      BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
+
+  struct BrotliDecDeleter {
+    BrotliDecoderState* brotli_dec;
+    ~BrotliDecDeleter() { BrotliDecoderDestroyInstance(brotli_dec); }
+  } brotli_dec_deleter{brotli_dec};
+
+  BrotliDecoderResult result =
+      BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS;
+
+  auto br_read = [&](std::vector<uint8_t>& data) -> Status {
+    size_t available_out = data.size();
+    uint8_t* out = data.data();
+    while (available_out != 0) {
+      if (BrotliDecoderIsFinished(brotli_dec)) {
+        return JXL_FAILURE("Not enough decompressed output");
+      }
+      result = BrotliDecoderDecompressStream(brotli_dec, &available_in, &in,
+                                             &available_out, &out, nullptr);
+      if (result !=
+              BrotliDecoderResult::BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT &&
+          result != BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS) {
+        return JXL_FAILURE(
+            "Brotli decoding error: %s\n",
+            BrotliDecoderErrorString(BrotliDecoderGetErrorCode(brotli_dec)));
+      }
+    }
+    return true;
+  };
+  size_t num_icc = 0;
+  for (size_t i = 0; i < jpeg_data->app_data.size(); i++) {
+    auto& marker = jpeg_data->app_data[i];
+    if (jpeg_data->app_marker_type[i] != AppMarkerType::kUnknown) {
+      // Set the size of the marker.
+      size_t size_minus_1 = marker.size() - 1;
+      marker[1] = size_minus_1 >> 8;
+      marker[2] = size_minus_1 & 0xFF;
+      if (jpeg_data->app_marker_type[i] == AppMarkerType::kICC) {
+        if (marker.size() < 17) {
+          return JXL_FAILURE("ICC markers must be at least 17 bytes");
+        }
+        marker[0] = 0xE2;
+        memcpy(&marker[3], kIccProfileTag, sizeof kIccProfileTag);
+        marker[15] = ++num_icc;
+      }
+    } else {
+      JXL_RETURN_IF_ERROR(br_read(marker));
+      if (marker[1] * 256u + marker[2] + 1u != marker.size()) {
+        return JXL_FAILURE("Incorrect marker size");
+      }
+    }
+  }
+  for (size_t i = 0; i < jpeg_data->app_data.size(); i++) {
+    auto& marker = jpeg_data->app_data[i];
+    if (jpeg_data->app_marker_type[i] == AppMarkerType::kICC) {
+      marker[16] = num_icc;
+    }
+    if (jpeg_data->app_marker_type[i] == AppMarkerType::kExif) {
+      marker[0] = 0xE1;
+      if (marker.size() < 3 + sizeof kExifTag) {
+        return JXL_FAILURE("Incorrect Exif marker size");
+      }
+      memcpy(&marker[3], kExifTag, sizeof kExifTag);
+    }
+    if (jpeg_data->app_marker_type[i] == AppMarkerType::kXMP) {
+      marker[0] = 0xE1;
+      if (marker.size() < 3 + sizeof kXMPTag) {
+        return JXL_FAILURE("Incorrect XMP marker size");
+      }
+      memcpy(&marker[3], kXMPTag, sizeof kXMPTag);
+    }
+  }
+  // TODO(eustas): actually inject ICC profile and check it fits perfectly.
+  for (size_t i = 0; i < jpeg_data->com_data.size(); i++) {
+    auto& marker = jpeg_data->com_data[i];
+    JXL_RETURN_IF_ERROR(br_read(marker));
+    if (marker[1] * 256u + marker[2] + 1u != marker.size()) {
+      return JXL_FAILURE("Incorrect marker size");
+    }
+  }
+  for (size_t i = 0; i < jpeg_data->inter_marker_data.size(); i++) {
+    JXL_RETURN_IF_ERROR(br_read(jpeg_data->inter_marker_data[i]));
+  }
+  JXL_RETURN_IF_ERROR(br_read(jpeg_data->tail_data));
+  if (result != BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS) {
+    return JXL_FAILURE("Invalid brotli-compressed data");
+  }
+
+  if (!BrotliDecoderIsFinished(brotli_dec)) {
+    return JXL_FAILURE("Excess data in compressed stream");
+  }
+  if (available_in != 0) {
+    return JXL_FAILURE("Unused data after brotli stream");
+  }
+
+  return true;
+}
+}  // namespace jpeg
+}  // namespace jxl
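The br_read helper above drives Brotli's streaming API until a fixed-size destination buffer is full. A minimal standalone loop with the same shape, decompressing a buffer whose decompressed size is known in advance (the helper name is an assumption; the BrotliDecoder* calls are the standard Brotli C API from <brotli/decode.h>):

#include <brotli/decode.h>

#include <cstdint>
#include <vector>

// Returns false on corrupt input or if `out` does not have the exact
// decompressed size.
bool BrotliDecompressExact(const std::vector<uint8_t>& in,
                           std::vector<uint8_t>* out) {
  BrotliDecoderState* dec =
      BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
  if (dec == nullptr) return false;
  size_t avail_in = in.size();
  const uint8_t* next_in = in.data();
  size_t avail_out = out->size();
  uint8_t* next_out = out->data();
  BrotliDecoderResult result = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT;
  while (avail_out != 0 && result == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) {
    result = BrotliDecoderDecompressStream(dec, &avail_in, &next_in,
                                           &avail_out, &next_out, nullptr);
  }
  // NEEDS_MORE_INPUT or ERROR here means corrupt/truncated input; SUCCESS
  // with leftover output space means the size was wrong.
  const bool ok = (result == BROTLI_DECODER_RESULT_SUCCESS) && avail_out == 0;
  BrotliDecoderDestroyInstance(dec);
  return ok;
}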
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.h
new file mode 100644
index 000000000000..18e6c66e761e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.h
@@ -0,0 +1,28 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_JPEG_DEC_JPEG_DATA_H_
+#define LIB_JXL_JPEG_DEC_JPEG_DATA_H_
+
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/jpeg/jpeg_data.h"
+
+namespace jxl {
+namespace jpeg {
+Status DecodeJPEGData(Span<const uint8_t> encoded, JPEGData* jpeg_data);
+}
+}  // namespace jxl
+
+#endif  // LIB_JXL_JPEG_DEC_JPEG_DATA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc
new file mode 100644
index 000000000000..e99f3ea4f523
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc
@@ -0,0 +1,992 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/jpeg/dec_jpeg_data_writer.h"
+
+#include <stdlib.h>
+#include <string.h> /* for memset, memcpy */
+
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/jpeg/dec_jpeg_serialization_state.h"
+#include "lib/jxl/jpeg/jpeg_data.h"
+
+namespace jxl {
+namespace jpeg {
+
+namespace {
+
+enum struct SerializationStatus {
+  NEEDS_MORE_INPUT,
+  NEEDS_MORE_OUTPUT,
+  ERROR,
+  DONE
+};
+
+const int kJpegPrecision = 8;
+
+// JpegBitWriter: buffer size
+const size_t kJpegBitWriterChunkSize = 16384;
+
+// DCTCodingState: maximum number of correction bits to buffer
+const int kJPEGMaxCorrectionBits = 1u << 16;
+
+// Returns non-zero if and only if x has a zero byte, i.e. one of
+// x & 0xff, x & 0xff00, ..., x & 0xff00000000000000 is zero.
+static JXL_INLINE uint64_t HasZeroByte(uint64_t x) {
+  return (x - 0x0101010101010101ULL) & ~x & 0x8080808080808080ULL;
+}
+
+void JpegBitWriterInit(JpegBitWriter* bw,
+                       std::deque<OutputChunk>* output_queue) {
+  bw->output = output_queue;
+  bw->chunk = OutputChunk(kJpegBitWriterChunkSize);
+  bw->pos = 0;
+  bw->put_buffer = 0;
+  bw->put_bits = 64;
+  bw->healthy = true;
+  bw->data = bw->chunk.buffer->data();
+}
+
+static JXL_NOINLINE void SwapBuffer(JpegBitWriter* bw) {
+  bw->chunk.len = bw->pos;
+  bw->output->emplace_back(std::move(bw->chunk));
+  bw->chunk = OutputChunk(kJpegBitWriterChunkSize);
+  bw->data = bw->chunk.buffer->data();
+  bw->pos = 0;
+}
+
+static JXL_INLINE void Reserve(JpegBitWriter* bw, size_t n_bytes) {
+  if (JXL_UNLIKELY((bw->pos + n_bytes) > kJpegBitWriterChunkSize)) {
+    SwapBuffer(bw);
+  }
+}
+
+/**
+ * Writes the given byte to the output, writes an extra zero if byte is 0xFF.
+ *
+ * This method is "careless" - caller must make sure that there is enough
+ * space in the output buffer. Emits up to 2 bytes to buffer.
+ */
+static JXL_INLINE void EmitByte(JpegBitWriter* bw, int byte) {
+  bw->data[bw->pos++] = byte;
+  if (byte == 0xFF) bw->data[bw->pos++] = 0;
+}
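EmitByte applies JPEG's byte-stuffing rule: inside entropy-coded data, every 0xFF must be followed by a 0x00 filler so a decoder cannot mistake it for a marker. The same rule applied to a whole buffer at once (standalone illustration, not part of the patch):

#include <cstdint>
#include <vector>

std::vector<uint8_t> StuffBytes(const std::vector<uint8_t>& in) {
  std::vector<uint8_t> out;
  out.reserve(in.size());
  for (uint8_t b : in) {
    out.push_back(b);
    if (b == 0xFF) out.push_back(0x00);  // stuffing byte after 0xFF
  }
  return out;
}
// StuffBytes({0x12, 0xFF, 0x34}) yields {0x12, 0xFF, 0x00, 0x34}.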
+static JXL_INLINE void DischargeBitBuffer(JpegBitWriter* bw) {
+  // At this point we are ready to emit the most significant 6 bytes of
+  // put_buffer_ to the output.
+  // The JPEG format requires that after every 0xFF byte in the entropy
+  // coded section, there is a zero byte, therefore we first check if any of
+  // the 6 most significant bytes of put_buffer_ is 0xFF.
+  Reserve(bw, 12);
+  if (HasZeroByte(~bw->put_buffer | 0xFFFF)) {
+    // We have a 0xFF byte somewhere, examine each byte and append a zero
+    // byte if necessary.
+    EmitByte(bw, (bw->put_buffer >> 56) & 0xFF);
+    EmitByte(bw, (bw->put_buffer >> 48) & 0xFF);
+    EmitByte(bw, (bw->put_buffer >> 40) & 0xFF);
+    EmitByte(bw, (bw->put_buffer >> 32) & 0xFF);
+    EmitByte(bw, (bw->put_buffer >> 24) & 0xFF);
+    EmitByte(bw, (bw->put_buffer >> 16) & 0xFF);
+  } else {
+    // We don't have any 0xFF bytes, output all 6 bytes without checking.
+    bw->data[bw->pos] = (bw->put_buffer >> 56) & 0xFF;
+    bw->data[bw->pos + 1] = (bw->put_buffer >> 48) & 0xFF;
+    bw->data[bw->pos + 2] = (bw->put_buffer >> 40) & 0xFF;
+    bw->data[bw->pos + 3] = (bw->put_buffer >> 32) & 0xFF;
+    bw->data[bw->pos + 4] = (bw->put_buffer >> 24) & 0xFF;
+    bw->data[bw->pos + 5] = (bw->put_buffer >> 16) & 0xFF;
+    bw->pos += 6;
+  }
+  bw->put_buffer <<= 48;
+  bw->put_bits += 48;
+}
+
+static JXL_INLINE void WriteBits(JpegBitWriter* bw, int nbits, uint64_t bits) {
+  // This is an optimization: normally |nbits| is positive. A zero length
+  // means a non-existing Huffman symbol is being encoded; in that case we
+  // mark the writer as unhealthy so the encoder can check it later.
+  if (nbits == 0) {
+    bw->healthy = false;
+    return;
+  }
+  bw->put_bits -= nbits;
+  bw->put_buffer |= (bits << bw->put_bits);
+  if (bw->put_bits <= 16) DischargeBitBuffer(bw);
+}
+
+void EmitMarker(JpegBitWriter* bw, int marker) {
+  Reserve(bw, 2);
+  JXL_DASSERT(marker != 0xFF);
+  bw->data[bw->pos++] = 0xFF;
+  bw->data[bw->pos++] = marker;
+}
+
+bool JumpToByteBoundary(JpegBitWriter* bw, const uint8_t** pad_bits,
+                        const uint8_t* pad_bits_end) {
+  size_t n_bits = bw->put_bits & 7u;
+  uint8_t pad_pattern;
+  if (*pad_bits == nullptr) {
+    pad_pattern = (1u << n_bits) - 1;
+  } else {
+    pad_pattern = 0;
+    const uint8_t* src = *pad_bits;
+    // TODO(eustas): bitwise reading looks insanely inefficient...
+    while (n_bits--) {
+      pad_pattern <<= 1;
+      if (src >= pad_bits_end) return false;
+      // TODO(eustas): DCHECK *src == {0, 1}
+      pad_pattern |= !!*(src++);
+    }
+    *pad_bits = src;
+  }
+
+  Reserve(bw, 16);
+
+  while (bw->put_bits <= 56) {
+    int c = (bw->put_buffer >> 56) & 0xFF;
+    EmitByte(bw, c);
+    bw->put_buffer <<= 8;
+    bw->put_bits += 8;
+  }
+  if (bw->put_bits < 64) {
+    int pad_mask = 0xFFu >> (64 - bw->put_bits);
+    int c = ((bw->put_buffer >> 56) & ~pad_mask) | pad_pattern;
+    EmitByte(bw, c);
+  }
+  bw->put_buffer = 0;
+  bw->put_bits = 64;
+
+  return true;
+}
+
+void JpegBitWriterFinish(JpegBitWriter* bw) {
+  if (bw->pos == 0) return;
+  bw->chunk.len = bw->pos;
+  bw->output->emplace_back(std::move(bw->chunk));
+  bw->chunk = OutputChunk(nullptr, 0);
+  bw->data = nullptr;
+  bw->pos = 0;
+}
+
+void DCTCodingStateInit(DCTCodingState* s) {
+  s->eob_run_ = 0;
+  s->cur_ac_huff_ = nullptr;
+  s->refinement_bits_.clear();
+  s->refinement_bits_.reserve(kJPEGMaxCorrectionBits);
+}
+// Emit all buffered data to the bit stream using the given Huffman code and
+// bit writer.
+static JXL_INLINE void Flush(DCTCodingState* s, JpegBitWriter* bw) {
+  if (s->eob_run_ > 0) {
+    int nbits = FloorLog2Nonzero(s->eob_run_);
+    int symbol = nbits << 4u;
+    WriteBits(bw, s->cur_ac_huff_->depth[symbol],
+              s->cur_ac_huff_->code[symbol]);
+    if (nbits > 0) {
+      WriteBits(bw, nbits, s->eob_run_ & ((1 << nbits) - 1));
+    }
+    s->eob_run_ = 0;
+  }
+  for (size_t i = 0; i < s->refinement_bits_.size(); ++i) {
+    WriteBits(bw, 1, s->refinement_bits_[i]);
+  }
+  s->refinement_bits_.clear();
+}
+
+// Buffer some more data at the end-of-band (the last non-zero or newly
+// non-zero coefficient within the [Ss, Se] spectral band).
+static JXL_INLINE void BufferEndOfBand(DCTCodingState* s,
+                                       const HuffmanCodeTable* ac_huff,
+                                       const std::vector<int>* new_bits,
+                                       JpegBitWriter* bw) {
+  if (s->eob_run_ == 0) {
+    s->cur_ac_huff_ = ac_huff;
+  }
+  ++s->eob_run_;
+  if (new_bits) {
+    s->refinement_bits_.insert(s->refinement_bits_.end(), new_bits->begin(),
+                               new_bits->end());
+  }
+  if (s->eob_run_ == 0x7FFF ||
+      s->refinement_bits_.size() > kJPEGMaxCorrectionBits - kDCTBlockSize + 1) {
+    Flush(s, bw);
+  }
+}
+
+bool BuildHuffmanCodeTable(const JPEGHuffmanCode& huff,
+                           HuffmanCodeTable* table) {
+  int huff_code[kJpegHuffmanAlphabetSize];
+  // +1 for a sentinel element.
+  uint32_t huff_size[kJpegHuffmanAlphabetSize + 1];
+  int p = 0;
+  for (size_t l = 1; l <= kJpegHuffmanMaxBitLength; ++l) {
+    int i = huff.counts[l];
+    if (p + i > kJpegHuffmanAlphabetSize + 1) {
+      return false;
+    }
+    while (i--) huff_size[p++] = l;
+  }
+
+  if (p == 0) {
+    return true;
+  }
+
+  // Reuse sentinel element.
+  int last_p = p - 1;
+  huff_size[last_p] = 0;
+
+  int code = 0;
+  uint32_t si = huff_size[0];
+  p = 0;
+  while (huff_size[p]) {
+    while ((huff_size[p]) == si) {
+      huff_code[p++] = code;
+      code++;
+    }
+    code <<= 1;
+    si++;
+  }
+  for (p = 0; p < last_p; p++) {
+    int i = huff.values[p];
+    table->depth[i] = huff_size[p];
+    table->code[i] = huff_code[p];
+  }
+  return true;
+}
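BuildHuffmanCodeTable assigns canonical Huffman codes: within one code length, symbols get consecutive integers; moving to the next length doubles the running code. The assignment step in isolation, computing codes from a count-per-length histogram (standalone sketch, helper name is illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

// counts[l] = number of symbols with code length l (counts[0] unused).
// Returns the canonical codes in symbol order (shortest lengths first).
std::vector<uint32_t> CanonicalCodes(const std::vector<uint32_t>& counts) {
  std::vector<uint32_t> codes;
  uint32_t code = 0;
  for (std::size_t len = 1; len < counts.size(); ++len) {
    for (uint32_t i = 0; i < counts[len]; ++i) codes.push_back(code++);
    code <<= 1;  // one more bit for the next length
  }
  return codes;
}
// counts = {0, 0, 1, 3}: one 2-bit code, three 3-bit codes ->
// {0b00, 0b010, 0b011, 0b100}.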
+bool EncodeSOI(SerializationState* state) {
+  state->output_queue.push_back(OutputChunk({0xFF, 0xD8}));
+  return true;
+}
+
+bool EncodeEOI(const JPEGData& jpg, SerializationState* state) {
+  state->output_queue.push_back(OutputChunk({0xFF, 0xD9}));
+  state->output_queue.emplace_back(jpg.tail_data);
+  return true;
+}
+
+bool EncodeSOF(const JPEGData& jpg, uint8_t marker, SerializationState* state) {
+  if (marker <= 0xC2) state->is_progressive = (marker == 0xC2);
+
+  const size_t n_comps = jpg.components.size();
+  const size_t marker_len = 8 + 3 * n_comps;
+  state->output_queue.emplace_back(marker_len + 2);
+  uint8_t* data = state->output_queue.back().buffer->data();
+  size_t pos = 0;
+  data[pos++] = 0xFF;
+  data[pos++] = marker;
+  data[pos++] = marker_len >> 8u;
+  data[pos++] = marker_len & 0xFFu;
+  data[pos++] = kJpegPrecision;
+  data[pos++] = jpg.height >> 8u;
+  data[pos++] = jpg.height & 0xFFu;
+  data[pos++] = jpg.width >> 8u;
+  data[pos++] = jpg.width & 0xFFu;
+  data[pos++] = n_comps;
+  for (size_t i = 0; i < n_comps; ++i) {
+    data[pos++] = jpg.components[i].id;
+    data[pos++] = ((jpg.components[i].h_samp_factor << 4u) |
+                   (jpg.components[i].v_samp_factor));
+    const size_t quant_idx = jpg.components[i].quant_idx;
+    if (quant_idx >= jpg.quant.size()) return false;
+    data[pos++] = jpg.quant[quant_idx].index;
+  }
+  return true;
+}
+
+bool EncodeSOS(const JPEGData& jpg, const JPEGScanInfo& scan_info,
+               SerializationState* state) {
+  const size_t n_scans = scan_info.num_components;
+  const size_t marker_len = 6 + 2 * n_scans;
+  state->output_queue.emplace_back(marker_len + 2);
+  uint8_t* data = state->output_queue.back().buffer->data();
+  size_t pos = 0;
+  data[pos++] = 0xFF;
+  data[pos++] = 0xDA;
+  data[pos++] = marker_len >> 8u;
+  data[pos++] = marker_len & 0xFFu;
+  data[pos++] = n_scans;
+  for (size_t i = 0; i < n_scans; ++i) {
+    const JPEGComponentScanInfo& si = scan_info.components[i];
+    if (si.comp_idx >= jpg.components.size()) return false;
+    data[pos++] = jpg.components[si.comp_idx].id;
+    data[pos++] = (si.dc_tbl_idx << 4u) + si.ac_tbl_idx;
+  }
+  data[pos++] = scan_info.Ss;
+  data[pos++] = scan_info.Se;
+  data[pos++] = ((scan_info.Ah << 4u) | (scan_info.Al));
+  return true;
+}
+
+bool EncodeDHT(const JPEGData& jpg, SerializationState* state) {
+  const std::vector<JPEGHuffmanCode>& huffman_code = jpg.huffman_code;
+
+  size_t marker_len = 2;
+  for (size_t i = state->dht_index; i < huffman_code.size(); ++i) {
+    const JPEGHuffmanCode& huff = huffman_code[i];
+    marker_len += kJpegHuffmanMaxBitLength;
+    for (size_t j = 0; j < huff.counts.size(); ++j) {
+      marker_len += huff.counts[j];
+    }
+    if (huff.is_last) break;
+  }
+  state->output_queue.emplace_back(marker_len + 2);
+  uint8_t* data = state->output_queue.back().buffer->data();
+  size_t pos = 0;
+  data[pos++] = 0xFF;
+  data[pos++] = 0xC4;
+  data[pos++] = marker_len >> 8u;
+  data[pos++] = marker_len & 0xFFu;
+  while (true) {
+    const size_t huffman_code_index = state->dht_index++;
+    if (huffman_code_index >= huffman_code.size()) {
+      return false;
+    }
+    const JPEGHuffmanCode& huff = huffman_code[huffman_code_index];
+    size_t index = huff.slot_id;
+    HuffmanCodeTable* huff_table;
+    if (index & 0x10) {
+      index -= 0x10;
+      huff_table = &state->ac_huff_table[index];
+    } else {
+      huff_table = &state->dc_huff_table[index];
+    }
+    // TODO(eustas): cache
+    // TODO(eustas): set up non-existing symbols
+    if (!BuildHuffmanCodeTable(huff, huff_table)) {
+      return false;
+    }
+    size_t total_count = 0;
+    size_t max_length = 0;
+    for (size_t i = 0; i < huff.counts.size(); ++i) {
+      if (huff.counts[i] != 0) {
+        max_length = i;
+      }
+      total_count += huff.counts[i];
+    }
+    --total_count;
+    data[pos++] = huff.slot_id;
+    for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) {
+      data[pos++] = (i == max_length ? huff.counts[i] - 1 : huff.counts[i]);
+    }
+    for (size_t i = 0; i < total_count; ++i) {
+      data[pos++] = huff.values[i];
+    }
+    if (huff.is_last) break;
+  }
+  return true;
+}
+bool EncodeDQT(const JPEGData& jpg, SerializationState* state) {
+  int marker_len = 2;
+  for (size_t i = state->dqt_index; i < jpg.quant.size(); ++i) {
+    const JPEGQuantTable& table = jpg.quant[i];
+    marker_len += 1 + (table.precision ? 2 : 1) * kDCTBlockSize;
+    if (table.is_last) break;
+  }
+  state->output_queue.emplace_back(marker_len + 2);
+  uint8_t* data = state->output_queue.back().buffer->data();
+  size_t pos = 0;
+  data[pos++] = 0xFF;
+  data[pos++] = 0xDB;
+  data[pos++] = marker_len >> 8u;
+  data[pos++] = marker_len & 0xFFu;
+  while (true) {
+    const size_t idx = state->dqt_index++;
+    if (idx >= jpg.quant.size()) {
+      return false;  // corrupt input
+    }
+    const JPEGQuantTable& table = jpg.quant[idx];
+    data[pos++] = (table.precision << 4u) + table.index;
+    for (size_t i = 0; i < kDCTBlockSize; ++i) {
+      int val_idx = kJPEGNaturalOrder[i];
+      int val = table.values[val_idx];
+      if (table.precision) {
+        data[pos++] = val >> 8u;
+      }
+      data[pos++] = val & 0xFFu;
+    }
+    if (table.is_last) break;
+  }
+  return true;
+}
+
+bool EncodeDRI(const JPEGData& jpg, SerializationState* state) {
+  state->seen_dri_marker = true;
+  OutputChunk dri_marker = {0xFF,
+                            0xDD,
+                            0,
+                            4,
+                            static_cast<uint8_t>(jpg.restart_interval >> 8),
+                            static_cast<uint8_t>(jpg.restart_interval & 0xFF)};
+  state->output_queue.push_back(std::move(dri_marker));
+  return true;
+}
+
+bool EncodeRestart(uint8_t marker, SerializationState* state) {
+  state->output_queue.push_back(OutputChunk({0xFF, marker}));
+  return true;
+}
+
+bool EncodeAPP(const JPEGData& jpg, uint8_t marker, SerializationState* state) {
+  // TODO(eustas): check that marker corresponds to payload?
+  (void)marker;
+
+  size_t app_index = state->app_index++;
+  if (app_index >= jpg.app_data.size()) return false;
+  state->output_queue.push_back(OutputChunk({0xFF}));
+  state->output_queue.emplace_back(jpg.app_data[app_index]);
+  return true;
+}
+
+bool EncodeCOM(const JPEGData& jpg, SerializationState* state) {
+  size_t com_index = state->com_index++;
+  if (com_index >= jpg.com_data.size()) return false;
+  state->output_queue.push_back(OutputChunk({0xFF}));
+  state->output_queue.emplace_back(jpg.com_data[com_index]);
+  return true;
+}
+
+bool EncodeInterMarkerData(const JPEGData& jpg, SerializationState* state) {
+  size_t index = state->data_index++;
+  if (index >= jpg.inter_marker_data.size()) return false;
+  state->output_queue.emplace_back(jpg.inter_marker_data[index]);
+  return true;
+}
+
+bool EncodeDCTBlockSequential(const coeff_t* coeffs,
+                              const HuffmanCodeTable& dc_huff,
+                              const HuffmanCodeTable& ac_huff,
+                              int num_zero_runs, coeff_t* last_dc_coeff,
+                              JpegBitWriter* bw) {
+  coeff_t temp2;
+  coeff_t temp;
+  temp2 = coeffs[0];
+  temp = temp2 - *last_dc_coeff;
+  *last_dc_coeff = temp2;
+  temp2 = temp;
+  if (temp < 0) {
+    temp = -temp;
+    temp2--;
+  }
+  int dc_nbits = (temp == 0) ?
0 : (FloorLog2Nonzero(temp) + 1); + WriteBits(bw, dc_huff.depth[dc_nbits], dc_huff.code[dc_nbits]); + if (dc_nbits >= 12) return false; + if (dc_nbits > 0) { + WriteBits(bw, dc_nbits, temp2 & ((1u << dc_nbits) - 1)); + } + int r = 0; + for (int k = 1; k < 64; ++k) { + if ((temp = coeffs[kJPEGNaturalOrder[k]]) == 0) { + r++; + continue; + } + if (temp < 0) { + temp = -temp; + temp2 = ~temp; + } else { + temp2 = temp; + } + while (r > 15) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + int ac_nbits = FloorLog2Nonzero(temp) + 1; + if (ac_nbits >= 16) return false; + int symbol = (r << 4u) + ac_nbits; + WriteBits(bw, ac_huff.depth[symbol], ac_huff.code[symbol]); + WriteBits(bw, ac_nbits, temp2 & ((1 << ac_nbits) - 1)); + r = 0; + } + for (int i = 0; i < num_zero_runs; ++i) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + if (r > 0) { + WriteBits(bw, ac_huff.depth[0], ac_huff.code[0]); + } + return true; +} + +bool EncodeDCTBlockProgressive(const coeff_t* coeffs, + const HuffmanCodeTable& dc_huff, + const HuffmanCodeTable& ac_huff, int Ss, int Se, + int Al, int num_zero_runs, + DCTCodingState* coding_state, + coeff_t* last_dc_coeff, JpegBitWriter* bw) { + bool eob_run_allowed = Ss > 0; + coeff_t temp2; + coeff_t temp; + if (Ss == 0) { + temp2 = coeffs[0] >> Al; + temp = temp2 - *last_dc_coeff; + *last_dc_coeff = temp2; + temp2 = temp; + if (temp < 0) { + temp = -temp; + temp2--; + } + int nbits = (temp == 0) ? 0 : (FloorLog2Nonzero(temp) + 1); + WriteBits(bw, dc_huff.depth[nbits], dc_huff.code[nbits]); + if (nbits > 0) { + WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1)); + } + ++Ss; + } + if (Ss > Se) { + return true; + } + int r = 0; + for (int k = Ss; k <= Se; ++k) { + if ((temp = coeffs[kJPEGNaturalOrder[k]]) == 0) { + r++; + continue; + } + if (temp < 0) { + temp = -temp; + temp >>= Al; + temp2 = ~temp; + } else { + temp >>= Al; + temp2 = temp; + } + if (temp == 0) { + r++; + continue; + } + Flush(coding_state, bw); + while (r > 15) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + int nbits = FloorLog2Nonzero(temp) + 1; + int symbol = (r << 4u) + nbits; + WriteBits(bw, ac_huff.depth[symbol], ac_huff.code[symbol]); + WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1)); + r = 0; + } + if (num_zero_runs > 0) { + Flush(coding_state, bw); + for (int i = 0; i < num_zero_runs; ++i) { + WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]); + r -= 16; + } + } + if (r > 0) { + BufferEndOfBand(coding_state, &ac_huff, nullptr, bw); + if (!eob_run_allowed) { + Flush(coding_state, bw); + } + } + return true; +} + +bool EncodeRefinementBits(const coeff_t* coeffs, + const HuffmanCodeTable& ac_huff, int Ss, int Se, + int Al, DCTCodingState* coding_state, + JpegBitWriter* bw) { + bool eob_run_allowed = Ss > 0; + if (Ss == 0) { + // Emit next bit of DC component. 
+    WriteBits(bw, 1, (coeffs[0] >> Al) & 1);
+    ++Ss;
+  }
+  if (Ss > Se) {
+    return true;
+  }
+  int abs_values[kDCTBlockSize];
+  int eob = 0;
+  for (int k = Ss; k <= Se; k++) {
+    const coeff_t abs_val = std::abs(coeffs[kJPEGNaturalOrder[k]]);
+    abs_values[k] = abs_val >> Al;
+    if (abs_values[k] == 1) {
+      eob = k;
+    }
+  }
+  int r = 0;
+  std::vector<int> refinement_bits;
+  refinement_bits.reserve(kDCTBlockSize);
+  for (int k = Ss; k <= Se; k++) {
+    if (abs_values[k] == 0) {
+      r++;
+      continue;
+    }
+    while (r > 15 && k <= eob) {
+      Flush(coding_state, bw);
+      WriteBits(bw, ac_huff.depth[0xf0], ac_huff.code[0xf0]);
+      r -= 16;
+      for (int bit : refinement_bits) {
+        WriteBits(bw, 1, bit);
+      }
+      refinement_bits.clear();
+    }
+    if (abs_values[k] > 1) {
+      refinement_bits.push_back(abs_values[k] & 1u);
+      continue;
+    }
+    Flush(coding_state, bw);
+    int symbol = (r << 4u) + 1;
+    int new_non_zero_bit = (coeffs[kJPEGNaturalOrder[k]] < 0) ? 0 : 1;
+    WriteBits(bw, ac_huff.depth[symbol], ac_huff.code[symbol]);
+    WriteBits(bw, 1, new_non_zero_bit);
+    for (int bit : refinement_bits) {
+      WriteBits(bw, 1, bit);
+    }
+    refinement_bits.clear();
+    r = 0;
+  }
+  if (r > 0 || !refinement_bits.empty()) {
+    BufferEndOfBand(coding_state, &ac_huff, &refinement_bits, bw);
+    if (!eob_run_allowed) {
+      Flush(coding_state, bw);
+    }
+  }
+  return true;
+}
+
+template <int kMode>
+SerializationStatus JXL_NOINLINE DoEncodeScan(const JPEGData& jpg,
+                                              SerializationState* state) {
+  const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index];
+  EncodeScanState& ss = state->scan_state;
+
+  const int restart_interval =
+      state->seen_dri_marker ? jpg.restart_interval : 0;
+
+  const auto get_next_extra_zero_run_index = [&ss, &scan_info]() -> int {
+    if (ss.extra_zero_runs_pos < scan_info.extra_zero_runs.size()) {
+      return scan_info.extra_zero_runs[ss.extra_zero_runs_pos].block_idx;
+    } else {
+      return -1;
+    }
+  };
+
+  const auto get_next_reset_point = [&ss, &scan_info]() -> int {
+    if (ss.next_reset_point_pos < scan_info.reset_points.size()) {
+      return scan_info.reset_points[ss.next_reset_point_pos++];
+    } else {
+      return -1;
+    }
+  };
+
+  if (ss.stage == EncodeScanState::HEAD) {
+    if (!EncodeSOS(jpg, scan_info, state)) return SerializationStatus::ERROR;
+    JpegBitWriterInit(&ss.bw, &state->output_queue);
+    DCTCodingStateInit(&ss.coding_state);
+    ss.restarts_to_go = restart_interval;
+    ss.next_restart_marker = 0;
+    ss.block_scan_index = 0;
+    ss.extra_zero_runs_pos = 0;
+    ss.next_extra_zero_run_index = get_next_extra_zero_run_index();
+    ss.next_reset_point_pos = 0;
+    ss.next_reset_point = get_next_reset_point();
+    ss.mcu_y = 0;
+    memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff));
+    ss.stage = EncodeScanState::BODY;
+  }
+  JpegBitWriter* bw = &ss.bw;
+  DCTCodingState* coding_state = &ss.coding_state;
+
+  JXL_DASSERT(ss.stage == EncodeScanState::BODY);
+
+  // "Non-interleaved" means color data comes in separate scans, in other words
+  // each scan can contain only one color component.
+  const bool is_interleaved = (scan_info.num_components > 1);
+  int MCUs_per_row = 0;
+  int MCU_rows = 0;
+  jpg.CalculateMcuSize(scan_info, &MCUs_per_row, &MCU_rows);
+  const bool is_progressive = state->is_progressive;
+  const int Al = is_progressive ? scan_info.Al : 0;
+  const int Ss = is_progressive ? scan_info.Ss : 0;
+  const int Se = is_progressive ? scan_info.Se : 63;
+
+  // DC-only is defined by [0..0] spectral range.
+  const bool want_ac = ((Ss != 0) || (Se != 0));
+  // TODO: support streaming decoding again.
+ const bool complete_ac = true; + const bool has_ac = true; + if (want_ac && !has_ac) return SerializationStatus::NEEDS_MORE_INPUT; + + // |has_ac| implies |complete_dc| but not vice versa; for the sake of + // simplicity we pretend they are equal, because they are separated by just a + // few bytes of input. + const bool complete_dc = has_ac; + const bool complete = want_ac ? complete_ac : complete_dc; + // When "incomplete" |ac_dc| tracks information about current ("incomplete") + // band parsing progress. + + // FIXME: Is this always complete? + // const int last_mcu_y = + // complete ? MCU_rows : parsing_state.internal->ac_dc.next_mcu_y * + // v_group; + (void)complete; + const int last_mcu_y = complete ? MCU_rows : 0; + + for (; ss.mcu_y < last_mcu_y; ++ss.mcu_y) { + for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) { + // Possibly emit a restart marker. + if (restart_interval > 0 && ss.restarts_to_go == 0) { + Flush(coding_state, bw); + if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) { + return SerializationStatus::ERROR; + } + EmitMarker(bw, 0xD0 + ss.next_restart_marker); + ss.next_restart_marker += 1; + ss.next_restart_marker &= 0x7; + ss.restarts_to_go = restart_interval; + memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff)); + } + // Encode one MCU + for (size_t i = 0; i < scan_info.num_components; ++i) { + const JPEGComponentScanInfo& si = scan_info.components[i]; + const JPEGComponent& c = jpg.components[si.comp_idx]; + const HuffmanCodeTable& dc_huff = state->dc_huff_table[si.dc_tbl_idx]; + const HuffmanCodeTable& ac_huff = state->ac_huff_table[si.ac_tbl_idx]; + int n_blocks_y = is_interleaved ? c.v_samp_factor : 1; + int n_blocks_x = is_interleaved ? c.h_samp_factor : 1; + for (int iy = 0; iy < n_blocks_y; ++iy) { + for (int ix = 0; ix < n_blocks_x; ++ix) { + int block_y = ss.mcu_y * n_blocks_y + iy; + int block_x = mcu_x * n_blocks_x + ix; + int block_idx = block_y * c.width_in_blocks + block_x; + if (ss.block_scan_index == ss.next_reset_point) { + Flush(coding_state, bw); + ss.next_reset_point = get_next_reset_point(); + } + int num_zero_runs = 0; + if (ss.block_scan_index == ss.next_extra_zero_run_index) { + num_zero_runs = scan_info.extra_zero_runs[ss.extra_zero_runs_pos] + .num_extra_zero_runs; + ++ss.extra_zero_runs_pos; + ss.next_extra_zero_run_index = get_next_extra_zero_run_index(); + } + const coeff_t* coeffs = &c.coeffs[block_idx << 6]; + bool ok; + if (kMode == 0) { + ok = EncodeDCTBlockSequential(coeffs, dc_huff, ac_huff, + num_zero_runs, + ss.last_dc_coeff + si.comp_idx, bw); + } else if (kMode == 1) { + ok = EncodeDCTBlockProgressive( + coeffs, dc_huff, ac_huff, Ss, Se, Al, num_zero_runs, + coding_state, ss.last_dc_coeff + si.comp_idx, bw); + } else { + ok = EncodeRefinementBits(coeffs, ac_huff, Ss, Se, Al, + coding_state, bw); + } + if (!ok) return SerializationStatus::ERROR; + ++ss.block_scan_index; + } + } + } + --ss.restarts_to_go; + } + } + if (ss.mcu_y < MCU_rows) { + if (!bw->healthy) return SerializationStatus::ERROR; + return SerializationStatus::NEEDS_MORE_INPUT; + } + Flush(coding_state, bw); + if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) { + return SerializationStatus::ERROR; + } + JpegBitWriterFinish(bw); + ss.stage = EncodeScanState::HEAD; + state->scan_index++; + if (!bw->healthy) return SerializationStatus::ERROR; + + return SerializationStatus::DONE; +} + +static SerializationStatus JXL_INLINE EncodeScan(const JPEGData& jpg, + SerializationState* state) { + const JPEGScanInfo& scan_info = 
jpg.scan_info[state->scan_index];
+  const bool is_progressive = state->is_progressive;
+  const int Al = is_progressive ? scan_info.Al : 0;
+  const int Ah = is_progressive ? scan_info.Ah : 0;
+  const int Ss = is_progressive ? scan_info.Ss : 0;
+  const int Se = is_progressive ? scan_info.Se : 63;
+  const bool need_sequential =
+      !is_progressive || (Ah == 0 && Al == 0 && Ss == 0 && Se == 63);
+  if (need_sequential) {
+    return DoEncodeScan<0>(jpg, state);
+  } else if (Ah == 0) {
+    return DoEncodeScan<1>(jpg, state);
+  } else {
+    return DoEncodeScan<2>(jpg, state);
+  }
+}
+
+SerializationStatus SerializeSection(uint8_t marker, SerializationState* state,
+                                     const JPEGData& jpg) {
+  const auto to_status = [](bool result) {
+    return result ? SerializationStatus::DONE : SerializationStatus::ERROR;
+  };
+  // TODO(eustas): add and use marker enum
+  switch (marker) {
+    case 0xC0:
+    case 0xC1:
+    case 0xC2:
+    case 0xC9:
+    case 0xCA:
+      return to_status(EncodeSOF(jpg, marker, state));
+
+    case 0xC4:
+      return to_status(EncodeDHT(jpg, state));
+
+    case 0xD0:
+    case 0xD1:
+    case 0xD2:
+    case 0xD3:
+    case 0xD4:
+    case 0xD5:
+    case 0xD6:
+    case 0xD7:
+      return to_status(EncodeRestart(marker, state));
+
+    case 0xD9:
+      return to_status(EncodeEOI(jpg, state));
+
+    case 0xDA:
+      return EncodeScan(jpg, state);
+
+    case 0xDB:
+      return to_status(EncodeDQT(jpg, state));
+
+    case 0xDD:
+      return to_status(EncodeDRI(jpg, state));
+
+    case 0xE0:
+    case 0xE1:
+    case 0xE2:
+    case 0xE3:
+    case 0xE4:
+    case 0xE5:
+    case 0xE6:
+    case 0xE7:
+    case 0xE8:
+    case 0xE9:
+    case 0xEA:
+    case 0xEB:
+    case 0xEC:
+    case 0xED:
+    case 0xEE:
+    case 0xEF:
+      return to_status(EncodeAPP(jpg, marker, state));
+
+    case 0xFE:
+      return to_status(EncodeCOM(jpg, state));
+
+    case 0xFF:
+      return to_status(EncodeInterMarkerData(jpg, state));
+
+    default:
+      return SerializationStatus::ERROR;
+  }
+}
+
+}  // namespace
+
+// TODO(veluca): add streaming support again.
+Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out) {
+  SerializationState ss;
+
+  size_t written = 0;
+  const auto maybe_push_output = [&]() -> Status {
+    if (ss.stage != SerializationState::ERROR) {
+      while (!ss.output_queue.empty()) {
+        auto& chunk = ss.output_queue.front();
+        size_t num_written = out(chunk.next, chunk.len);
+        if (num_written == 0 && chunk.len > 0) {
+          return StatusMessage(Status(StatusCode::kNotEnoughBytes),
+                               "Failed to write output");
+        }
+        chunk.len -= num_written;
+        written += num_written;
+        if (chunk.len == 0) {
+          ss.output_queue.pop_front();
+        }
+      }
+    }
+    return true;
+  };
+
+  while (true) {
+    switch (ss.stage) {
+      case SerializationState::INIT: {
+        // A valid Brunsli stream requires at least the 0xD9 (EOI) marker;
+        // an empty marker order can happen on a corrupted stream, or on
+        // unconditioned JPEGData.
+        // TODO(eustas): check that 0xD9 is the only EOI and is the last one.
+        if (jpg.marker_order.empty()) {
+          ss.stage = SerializationState::ERROR;
+          break;
+        }
+
+        ss.dc_huff_table.resize(kMaxHuffmanTables);
+        ss.ac_huff_table.resize(kMaxHuffmanTables);
+        if (jpg.has_zero_padding_bit) {
+          ss.pad_bits = jpg.padding_bits.data();
+          ss.pad_bits_end = ss.pad_bits + jpg.padding_bits.size();
+        }
+
+        EncodeSOI(&ss);
+        JXL_QUIET_RETURN_IF_ERROR(maybe_push_output());
+        ss.stage = SerializationState::SERIALIZE_SECTION;
+        break;
+      }
+
+      case SerializationState::SERIALIZE_SECTION: {
+        if (ss.section_index >= jpg.marker_order.size()) {
+          ss.stage = SerializationState::DONE;
+          break;
+        }
+        uint8_t marker = jpg.marker_order[ss.section_index];
+        SerializationStatus status = SerializeSection(marker, &ss, jpg);
+        if (status == SerializationStatus::ERROR) {
+          JXL_WARNING("Failed to encode marker 0x%.2x", marker);
+          ss.stage = SerializationState::ERROR;
+          break;
+        }
+        JXL_QUIET_RETURN_IF_ERROR(maybe_push_output());
+        if (status == SerializationStatus::NEEDS_MORE_INPUT) {
+          return JXL_FAILURE("Incomplete serialization data");
+        } else if (status != SerializationStatus::DONE) {
+          JXL_DASSERT(false);
+          ss.stage = SerializationState::ERROR;
+          break;
+        }
+        ++ss.section_index;
+        break;
+      }
+
+      case SerializationState::DONE:
+        JXL_ASSERT(ss.output_queue.empty());
+        return true;
+
+      case SerializationState::ERROR:
+        return JXL_FAILURE("JPEG serialization error");
+    }
+  }
+}
+
+}  // namespace jpeg
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.h
new file mode 100644
index 000000000000..81634ffe9e68
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.h
@@ -0,0 +1,39 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Functions for writing a JPEGData object into a jpeg byte stream.
+
+#ifndef LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_
+#define LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <functional>
+
+#include "lib/jxl/jpeg/jpeg_data.h"
+
+namespace jxl {
+namespace jpeg {
+
+// Function type used to write `len` bytes from `buf` to the output. Returns
+// the number of bytes written.
+using JPEGOutput = std::function<size_t(const uint8_t* buf, size_t len)>;
+
+Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out);
+
+}  // namespace jpeg
+}  // namespace jxl
+
+#endif  // LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_
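Given this callback-based interface, a minimal sketch of a caller that gathers
the serialized stream into memory (the helper name is hypothetical, not part
of the library):

#include <cstddef>
#include <cstdint>
#include <vector>

#include "lib/jxl/jpeg/dec_jpeg_data_writer.h"

// Collects the output of WriteJpeg into a byte vector. Returning `len`
// reports every byte as consumed, so the writer never sees a short write.
jxl::Status SerializeToVector(const jxl::jpeg::JPEGData& jpg,
                              std::vector<uint8_t>* serialized) {
  jxl::jpeg::JPEGOutput out = [serialized](const uint8_t* buf, size_t len) {
    serialized->insert(serialized->end(), buf, buf + len);
    return len;
  };
  return jxl::jpeg::WriteJpeg(jpg, out);
}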
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_output_chunk.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_output_chunk.h
new file mode 100644
index 000000000000..c165b46a6a11
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_output_chunk.h
@@ -0,0 +1,81 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_
+#define LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+namespace jxl {
+namespace jpeg {
+
+/**
+ * A chunk of output data.
+ *
+ * The data producer creates OutputChunks and adds them to the end of the
+ * output queue. Once control flow leaves the producer code, the chunk of
+ * data is considered final and can not be changed; to underline this fact
+ * |next| is a const-pointer.
+ *
+ * The data consumer removes OutputChunks from the beginning of the output
+ * queue. It is possible to consume OutputChunks partially, by updating
+ * |next| and |len|.
+ *
+ * There are 2 types of output chunks:
+ * - owning: actual data is stored in the |buffer| field; the producer fills
+ *   data after the instance is created; it is legal to reduce |len| to show
+ *   that not all the capacity of |buffer| is used
+ * - non-owning: represents the data stored (owned) somewhere else
+ */
+struct OutputChunk {
+  // Non-owning
+  template <typename Bytes>
+  explicit OutputChunk(Bytes& bytes) : len(bytes.size()) {
+    // Deal both with const qualifier and data type.
+    const void* src = bytes.data();
+    next = reinterpret_cast<const uint8_t*>(src);
+  }
+
+  // Non-owning
+  OutputChunk(const uint8_t* data, size_t size) : next(data), len(size) {}
+
+  // Owning
+  explicit OutputChunk(size_t size = 0) {
+    buffer.reset(new std::vector<uint8_t>(size));
+    next = buffer->data();
+    len = size;
+  }
+
+  // Owning
+  OutputChunk(std::initializer_list<uint8_t> bytes) {
+    buffer.reset(new std::vector<uint8_t>(bytes));
+    next = buffer->data();
+    len = bytes.size();
+  }
+
+  const uint8_t* next;
+  size_t len;
+  // TODO(veluca): consider removing the unique_ptr.
+  std::unique_ptr<std::vector<uint8_t>> buffer;
+};
+
+}  // namespace jpeg
+}  // namespace jxl
+
+#endif  // LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_
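A quick usage sketch for the two OutputChunk flavors above (illustrative only;
the helper name is hypothetical, and the queue type matches SerializationState
in the next header):

#include <cstdint>
#include <deque>
#include <utility>

#include "lib/jxl/jpeg/dec_jpeg_output_chunk.h"

void EnqueueExamples(std::deque<jxl::jpeg::OutputChunk>* queue) {
  // Non-owning: wraps bytes owned elsewhere (they must outlive the chunk).
  static const uint8_t kSoi[] = {0xFF, 0xD8};
  queue->emplace_back(kSoi, sizeof(kSoi));
  // Owning, from an initializer list: the chunk keeps its own buffer alive.
  queue->emplace_back(jxl::jpeg::OutputChunk({0xFF, 0xD9}));
  // Owning, pre-sized: the producer fills buffer->data() afterwards and may
  // shrink |len| when not all of the capacity is used.
  jxl::jpeg::OutputChunk scratch(64);
  scratch.len = 2;
  queue->push_back(std::move(scratch));
}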
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_serialization_state.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_serialization_state.h
new file mode 100644
index 000000000000..891dc8500101
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_serialization_state.h
@@ -0,0 +1,104 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_
+#define LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_
+
+#include <deque>
+#include <vector>
+
+#include "lib/jxl/jpeg/dec_jpeg_output_chunk.h"
+#include "lib/jxl/jpeg/jpeg_data.h"
+
+namespace jxl {
+namespace jpeg {
+
+struct HuffmanCodeTable {
+  int depth[256];
+  int code[256];
+};
+
+// Handles the packing of bits into output bytes.
+struct JpegBitWriter {
+  bool healthy;
+  std::deque<OutputChunk>* output;
+  OutputChunk chunk;
+  uint8_t* data;
+  size_t pos;
+  uint64_t put_buffer;
+  int put_bits;
+};
+
+// Holds data that is buffered between 8x8 blocks in progressive mode.
+struct DCTCodingState {
+  // The run length of end-of-band symbols in a progressive scan.
+  int eob_run_;
+  // The huffman table to be used when flushing the state.
+  const HuffmanCodeTable* cur_ac_huff_;
+  // The sequence of currently buffered refinement bits for a successive
+  // approximation scan (one where Ah > 0).
+  std::vector<int> refinement_bits_;
+};
+
+struct EncodeScanState {
+  enum Stage { HEAD, BODY };
+
+  Stage stage = HEAD;
+
+  int mcu_y;
+  JpegBitWriter bw;
+  coeff_t last_dc_coeff[kMaxComponents] = {0};
+  int restarts_to_go;
+  int next_restart_marker;
+  int block_scan_index;
+  DCTCodingState coding_state;
+  size_t extra_zero_runs_pos;
+  int next_extra_zero_run_index;
+  size_t next_reset_point_pos;
+  int next_reset_point;
+};
+
+struct SerializationState {
+  enum Stage {
+    INIT,
+    SERIALIZE_SECTION,
+    DONE,
+    ERROR,
+  };
+
+  Stage stage = INIT;
+
+  std::deque<OutputChunk> output_queue;
+
+  size_t section_index = 0;
+  int dht_index = 0;
+  int dqt_index = 0;
+  int app_index = 0;
+  int com_index = 0;
+  int data_index = 0;
+  int scan_index = 0;
+  std::vector<HuffmanCodeTable> dc_huff_table;
+  std::vector<HuffmanCodeTable> ac_huff_table;
+  const uint8_t* pad_bits = nullptr;
+  const uint8_t* pad_bits_end = nullptr;
+  bool seen_dri_marker = false;
+  bool is_progressive = false;
+
+  EncodeScanState scan_state;
+};
+
+}  // namespace jpeg
+}  // namespace jxl
+
+#endif  // LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.cc b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.cc
new file mode 100644
index 000000000000..dba7db9ab8f2
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.cc
@@ -0,0 +1,378 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/jpeg/enc_jpeg_data.h"
+
+#include <brotli/encode.h>
+#include <string.h>
+
+#include "lib/jxl/jpeg/enc_jpeg_data_reader.h"
+
+namespace jxl {
+namespace jpeg {
+
+namespace {
+
+constexpr int BITS_IN_JSAMPLE = 8;
+using ByteSpan = Span<const uint8_t>;
+
+// TODO(eustas): move to jpeg_data, to use from codec_jpg as well.
+// See if there is a canonically chunked ICC profile and mark corresponding
+// app-tags with AppMarkerType::kICC.
+Status DetectIccProfile(JPEGData& jpeg_data) {
+  JXL_DASSERT(jpeg_data.app_data.size() == jpeg_data.app_marker_type.size());
+  size_t num_icc = 0;
+  size_t num_icc_jpeg = 0;
+  for (size_t i = 0; i < jpeg_data.app_data.size(); i++) {
+    const auto& app = jpeg_data.app_data[i];
+    size_t pos = 0;
+    if (app[pos++] != 0xE2) continue;
+    // At least APPn + size; otherwise it should be intermarker-data.
+    JXL_DASSERT(app.size() >= 3);
+    size_t tag_length = (app[pos] << 8) + app[pos + 1];
+    pos += 2;
+    JXL_DASSERT(app.size() == tag_length + 1);
+    // Empty payload is 2 bytes for tag length itself + signature.
+    if (tag_length < 2 + sizeof kIccProfileTag) continue;
+
+    if (memcmp(&app[pos], kIccProfileTag, sizeof kIccProfileTag) != 0) {
+      continue;
+    }
+    pos += sizeof kIccProfileTag;
+    uint8_t chunk_id = app[pos++];
+    uint8_t num_chunks = app[pos++];
+    if (chunk_id != num_icc + 1) continue;
+    if (num_icc_jpeg == 0) num_icc_jpeg = num_chunks;
+    if (num_icc_jpeg != num_chunks) continue;
+    num_icc++;
+    jpeg_data.app_marker_type[i] = AppMarkerType::kICC;
+  }
+  if (num_icc != num_icc_jpeg) {
+    return JXL_FAILURE("Invalid ICC chunks");
+  }
+  return true;
+}
+
+bool GetMarkerPayload(const uint8_t* data, size_t size, ByteSpan* payload) {
+  if (size < 3) {
+    return false;
+  }
+  size_t hi = data[1];
+  size_t lo = data[2];
+  size_t internal_size = (hi << 8u) | lo;
+  // The second byte of the marker (the type byte) is not counted towards
+  // the size.
+  if (internal_size != size - 1) {
+    return false;
+  }
+  // Cut the marker type byte and the "length" field off the payload.
+  *payload = ByteSpan(data, size);
+  payload->remove_prefix(3);
+  return true;
+}
+
+Status DetectBlobs(jpeg::JPEGData& jpeg_data) {
+  JXL_DASSERT(jpeg_data.app_data.size() == jpeg_data.app_marker_type.size());
+  bool have_exif = false, have_xmp = false;
+  for (size_t i = 0; i < jpeg_data.app_data.size(); i++) {
+    auto& marker = jpeg_data.app_data[i];
+    if (marker.empty() || marker[0] != kApp1) {
+      continue;
+    }
+    ByteSpan payload;
+    if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) {
+      // Something is wrong with this marker; we don't care.
+      continue;
+    }
+    if (!have_exif && payload.size() >= sizeof kExifTag &&
+        !memcmp(payload.data(), kExifTag, sizeof kExifTag)) {
+      jpeg_data.app_marker_type[i] = AppMarkerType::kExif;
+      have_exif = true;
+    }
+    if (!have_xmp && payload.size() >= sizeof kXMPTag &&
+        !memcmp(payload.data(), kXMPTag, sizeof kXMPTag)) {
+      jpeg_data.app_marker_type[i] = AppMarkerType::kXMP;
+      have_xmp = true;
+    }
+  }
+  return true;
+}
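For reference, a sketch of the APP2 payload layout that DetectIccProfile above
and ParseChunkedMarker below expect, as stored in JPEGData::app_data. The
helper name is hypothetical, and kTag is assumed to mirror kIccProfileTag
("ICC_PROFILE" plus its trailing NUL):

#include <cstdint>
#include <vector>

std::vector<uint8_t> MakeIccAppData(uint8_t chunk_index, uint8_t num_chunks,
                                    const std::vector<uint8_t>& icc_part) {
  static const char kTag[] = "ICC_PROFILE";  // sizeof includes the '\0'
  size_t tag_length = 2 + sizeof kTag + 2 + icc_part.size();
  std::vector<uint8_t> app;
  app.push_back(0xE2);  // marker type byte (APP2)
  app.push_back(static_cast<uint8_t>(tag_length >> 8));    // big-endian length:
  app.push_back(static_cast<uint8_t>(tag_length & 0xFF));  // counts itself, not
                                                           // the type byte
  app.insert(app.end(), kTag, kTag + sizeof kTag);
  app.push_back(chunk_index);  // 1-based
  app.push_back(num_chunks);
  app.insert(app.end(), icc_part.begin(), icc_part.end());
  return app;  // app.size() == tag_length + 1, as DetectIccProfile asserts
}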
+
+Status ParseChunkedMarker(const jpeg::JPEGData& src, uint8_t marker_type,
+                          const ByteSpan& tag, PaddedBytes* output,
+                          bool allow_permutations = false) {
+  output->clear();
+
+  std::vector<ByteSpan> chunks;
+  std::vector<bool> presence;
+  size_t expected_number_of_parts = 0;
+  bool is_first_chunk = true;
+  size_t ordinal = 0;
+  for (const auto& marker : src.app_data) {
+    if (marker.empty() || marker[0] != marker_type) {
+      continue;
+    }
+    ByteSpan payload;
+    if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) {
+      // Something is wrong with this marker; we don't care.
+      continue;
+    }
+    if ((payload.size() < tag.size()) ||
+        memcmp(payload.data(), tag.data(), tag.size()) != 0) {
+      continue;
+    }
+    payload.remove_prefix(tag.size());
+    if (payload.size() < 2) {
+      return JXL_FAILURE("Chunk is too small.");
+    }
+    uint8_t index = payload[0];
+    uint8_t total = payload[1];
+    ordinal++;
+    if (!allow_permutations) {
+      if (index != ordinal) return JXL_FAILURE("Invalid chunk order.");
+    }
+
+    payload.remove_prefix(2);
+
+    JXL_RETURN_IF_ERROR(total != 0);
+    if (is_first_chunk) {
+      is_first_chunk = false;
+      expected_number_of_parts = total;
+      // 1-based indices; the 0-th element is added for convenience.
+      chunks.resize(total + 1);
+      presence.resize(total + 1);
+    } else {
+      JXL_RETURN_IF_ERROR(expected_number_of_parts == total);
+    }
+
+    if (index == 0 || index > total) {
+      return JXL_FAILURE("Invalid chunk index.");
+    }
+
+    if (presence[index]) {
+      return JXL_FAILURE("Duplicate chunk.");
+    }
+    presence[index] = true;
+    chunks[index] = payload;
+  }
+
+  for (size_t i = 0; i < expected_number_of_parts; ++i) {
+    // The 0-th element is not used.
+    size_t index = i + 1;
+    if (!presence[index]) {
+      return JXL_FAILURE("Missing chunk.");
+    }
+    output->append(chunks[index]);
+  }
+
+  return true;
+}
+
+Status SetColorEncodingFromJpegData(const jpeg::JPEGData& jpg,
+                                    ColorEncoding* color_encoding) {
+  PaddedBytes icc_profile;
+  if (!ParseChunkedMarker(jpg, kApp2, ByteSpan(kIccProfileTag),
+                          &icc_profile)) {
+    JXL_WARNING("ReJPEG: corrupted ICC profile\n");
+    icc_profile.clear();
+  }
+
+  if (icc_profile.empty()) {
+    bool is_gray = (jpg.components.size() == 1);
+    *color_encoding = ColorEncoding::SRGB(is_gray);
+    return true;
+  }
+
+  return color_encoding->SetICC(std::move(icc_profile));
+}
+Status SetBlobsFromJpegData(const jpeg::JPEGData& jpeg_data, Blobs* blobs) {
+  for (size_t i = 0; i < jpeg_data.app_data.size(); i++) {
+    auto& marker = jpeg_data.app_data[i];
+    if (marker.empty() || marker[0] != kApp1) {
+      continue;
+    }
+    ByteSpan payload;
+    if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) {
+      // Something is wrong with this marker; we don't care.
+      continue;
+    }
+    if (payload.size() >= sizeof kExifTag &&
+        !memcmp(payload.data(), kExifTag, sizeof kExifTag)) {
+      if (blobs->exif.empty()) {
+        blobs->exif.resize(payload.size() - sizeof kExifTag);
+        memcpy(blobs->exif.data(), payload.data() + sizeof kExifTag,
+               payload.size() - sizeof kExifTag);
+      } else {
+        JXL_WARNING(
+            "ReJPEG: multiple Exif blobs, storing only first one in the JPEG "
+            "XL container\n");
+      }
+    }
+    if (payload.size() >= sizeof kXMPTag &&
+        !memcmp(payload.data(), kXMPTag, sizeof kXMPTag)) {
+      if (blobs->xmp.empty()) {
+        blobs->xmp.resize(payload.size() - sizeof kXMPTag);
+        memcpy(blobs->xmp.data(), payload.data() + sizeof kXMPTag,
+               payload.size() - sizeof kXMPTag);
+      } else {
+        JXL_WARNING(
+            "ReJPEG: multiple XMP blobs, storing only first one in the JPEG "
+            "XL container\n");
+      }
+    }
+  }
+  return true;
+}
+
+}  // namespace
+
+Status EncodeJPEGData(JPEGData& jpeg_data, PaddedBytes* bytes) {
+  jpeg_data.app_marker_type.resize(jpeg_data.app_data.size(),
+                                   AppMarkerType::kUnknown);
+  JXL_RETURN_IF_ERROR(DetectIccProfile(jpeg_data));
+  JXL_RETURN_IF_ERROR(DetectBlobs(jpeg_data));
+  BitWriter writer;
+  JXL_RETURN_IF_ERROR(Bundle::Write(jpeg_data, &writer, 0, nullptr));
+  writer.ZeroPadToByte();
+  *bytes = std::move(writer).TakeBytes();
+  BrotliEncoderState* brotli_enc =
+      BrotliEncoderCreateInstance(nullptr, nullptr, nullptr);
+  BrotliEncoderSetParameter(brotli_enc, BROTLI_PARAM_QUALITY, 11);
+  size_t total_data = 0;
+  for (size_t i = 0; i < jpeg_data.app_data.size(); i++) {
+    if (jpeg_data.app_marker_type[i] != AppMarkerType::kUnknown) {
+      continue;
+    }
+    total_data += jpeg_data.app_data[i].size();
+  }
+  for (size_t i = 0; i < jpeg_data.com_data.size(); i++) {
+    total_data += jpeg_data.com_data[i].size();
+  }
+  for (size_t i = 0; i < jpeg_data.inter_marker_data.size(); i++) {
+    total_data += jpeg_data.inter_marker_data[i].size();
+  }
+  total_data += jpeg_data.tail_data.size();
+  size_t initial_size = bytes->size();
+  size_t brotli_capacity = BrotliEncoderMaxCompressedSize(total_data);
+  BrotliEncoderSetParameter(brotli_enc, BROTLI_PARAM_SIZE_HINT, total_data);
+  bytes->resize(bytes->size() + brotli_capacity);
+  size_t enc_size = 0;
+  auto br_append = [&](const std::vector<uint8_t>& data, bool last) {
+    size_t available_in = data.size();
+    const uint8_t* in = data.data();
+    uint8_t* out = &(*bytes)[initial_size + enc_size];
+    do {
+      JXL_CHECK(BrotliEncoderCompressStream(
+          brotli_enc, last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS,
+          &available_in, &in, &brotli_capacity, &out, &enc_size));
+    } while (BrotliEncoderHasMoreOutput(brotli_enc) || available_in > 0);
+  };
+
+  for (size_t i = 0; i < jpeg_data.app_data.size(); i++) {
+    if (jpeg_data.app_marker_type[i] != AppMarkerType::kUnknown) {
+      continue;
+    }
+    br_append(jpeg_data.app_data[i], /*last=*/false);
+  }
+  for (size_t i = 0; i < jpeg_data.com_data.size(); i++) {
+    br_append(jpeg_data.com_data[i], /*last=*/false);
+  }
+  for (size_t i = 0; i < jpeg_data.inter_marker_data.size(); i++) {
+    br_append(jpeg_data.inter_marker_data[i], /*last=*/false);
+  }
+  br_append(jpeg_data.tail_data, /*last=*/true);
+  BrotliEncoderDestroyInstance(brotli_enc);
+  bytes->resize(initial_size + enc_size);
+  return true;
+}
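The br_append lambda above follows the standard BrotliEncoderCompressStream
pattern. Reduced to its essentials, under the assumption that the output
buffer was pre-sized with BrotliEncoderMaxCompressedSize, it is (sketch with
hypothetical names, error handling omitted):

#include <brotli/encode.h>
#include <cstdint>
#include <vector>

// One-shot streaming compression: feed input until the encoder has neither
// pending input nor buffered output, accumulating the total in `total_out`.
size_t CompressAll(BrotliEncoderState* enc, const std::vector<uint8_t>& in,
                   uint8_t* out_buf, size_t out_capacity) {
  size_t available_in = in.size(), total_out = 0;
  const uint8_t* next_in = in.data();
  size_t available_out = out_capacity;
  uint8_t* next_out = out_buf;
  do {
    BrotliEncoderCompressStream(enc, BROTLI_OPERATION_FINISH, &available_in,
                                &next_in, &available_out, &next_out,
                                &total_out);
  } while (BrotliEncoderHasMoreOutput(enc) || available_in > 0);
  return total_out;
}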
+
+Status DecodeImageJPG(const Span<const uint8_t> bytes, CodecInOut* io) {
+  io->frames.clear();
+  io->frames.reserve(1);
+  io->frames.emplace_back(&io->metadata.m);
+  io->Main().jpeg_data = make_unique<jpeg::JPEGData>();
+  jpeg::JPEGData* jpeg_data = io->Main().jpeg_data.get();
+  if (!jpeg::ReadJpeg(bytes.data(), bytes.size(), jpeg::JpegReadMode::kReadAll,
+                      jpeg_data)) {
+    return JXL_FAILURE("Error reading JPEG");
+  }
+  JXL_RETURN_IF_ERROR(
+      SetColorEncodingFromJpegData(*jpeg_data, &io->metadata.m.color_encoding));
+  JXL_RETURN_IF_ERROR(SetBlobsFromJpegData(*jpeg_data, &io->blobs));
+  size_t nbcomp = jpeg_data->components.size();
+  if (nbcomp != 1 && nbcomp != 3) {
+    return JXL_FAILURE("Can only recompress JPEGs with 1 or 3 channels");
+  }
+  YCbCrChromaSubsampling cs;
+  if (nbcomp == 3) {
+    uint8_t hsample[3], vsample[3];
+    for (size_t i = 0; i < nbcomp; i++) {
+      hsample[i] = jpeg_data->components[i].h_samp_factor;
+      vsample[i] = jpeg_data->components[i].v_samp_factor;
+    }
+    JXL_RETURN_IF_ERROR(cs.Set(hsample, vsample));
+  } else if (nbcomp == 1) {
+    uint8_t hsample[3], vsample[3];
+    for (size_t i = 0; i < 3; i++) {
+      hsample[i] = jpeg_data->components[0].h_samp_factor;
+      vsample[i] = jpeg_data->components[0].v_samp_factor;
+    }
+    JXL_RETURN_IF_ERROR(cs.Set(hsample, vsample));
+  }
+  bool is_rgb = false;
+  {
+    const auto& markers = jpeg_data->marker_order;
+    // If there is a JFIF marker, this is YCbCr. Otherwise...
+    if (std::find(markers.begin(), markers.end(), 0xE0) == markers.end()) {
+      // Try to find an 'Adobe' marker.
+      size_t app_markers = 0;
+      size_t i = 0;
+      for (; i < markers.size(); i++) {
+        // This is an APP marker.
+        if ((markers[i] & 0xF0) == 0xE0) {
+          JXL_CHECK(app_markers < jpeg_data->app_data.size());
+          // APP14 marker
+          if (markers[i] == 0xEE) {
+            const auto& data = jpeg_data->app_data[app_markers];
+            if (data.size() == 15 && data[3] == 'A' && data[4] == 'd' &&
+                data[5] == 'o' && data[6] == 'b' && data[7] == 'e') {
+              // 'Adobe' marker.
+              is_rgb = data[14] == 0;
+              break;
+            }
+          }
+          app_markers++;
+        }
+      }
+
+      if (i == markers.size()) {
+        // No 'Adobe' marker, guess from component IDs.
+        is_rgb = nbcomp == 3 && jpeg_data->components[0].id == 'R' &&
+                 jpeg_data->components[1].id == 'G' &&
+                 jpeg_data->components[2].id == 'B';
+      }
+    }
+  }
+
+  io->Main().chroma_subsampling = cs;
+  io->Main().color_transform =
+      !is_rgb ? ColorTransform::kYCbCr : ColorTransform::kNone;
+
+  io->metadata.m.SetIntensityTarget(
+      io->target_nits != 0 ? io->target_nits : kDefaultIntensityTarget);
+  io->metadata.m.SetUintSamples(BITS_IN_JSAMPLE);
+  io->SetFromImage(Image3F(jpeg_data->width, jpeg_data->height),
+                   io->metadata.m.color_encoding);
+  SetIntensityTarget(io);
+  return true;
+}
+
+}  // namespace jpeg
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.h b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.h
new file mode 100644
index 000000000000..2f1c78c8a52a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.h
@@ -0,0 +1,34 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_JPEG_ENC_JPEG_DATA_H_
+#define LIB_JXL_JPEG_ENC_JPEG_DATA_H_
+
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/jpeg/jpeg_data.h"
+
+namespace jxl {
+namespace jpeg {
+Status EncodeJPEGData(JPEGData& jpeg_data, PaddedBytes* bytes);
+
+/**
+ * Decodes bytes containing JPEG codestream into a CodecInOut as coefficients
+ * only, for lossless JPEG transcoding.
+ */
+Status DecodeImageJPG(const Span<const uint8_t> bytes, CodecInOut* io);
+}  // namespace jpeg
+}  // namespace jxl
+
+#endif  // LIB_JXL_JPEG_ENC_JPEG_DATA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.cc b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.cc
new file mode 100644
index 000000000000..a88a729075af
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.cc
@@ -0,0 +1,1151 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/jpeg/enc_jpeg_data_reader.h"
+
+#include <inttypes.h>
+#include <string.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/jpeg/enc_jpeg_huffman_decode.h"
+#include "lib/jxl/jpeg/jpeg_data.h"
+
+// By default only print debug messages when JXL_DEBUG_ON_ERROR is enabled.
+#ifndef JXL_DEBUG_JPEG_DATA_READER
+#define JXL_DEBUG_JPEG_DATA_READER JXL_DEBUG_ON_ERROR
+#endif  // JXL_DEBUG_JPEG_DATA_READER
+
+#define JXL_JPEG_DEBUG(format, ...) \
+  JXL_DEBUG(JXL_DEBUG_JPEG_DATA_READER, format, ##__VA_ARGS__)
+
+namespace jxl {
+namespace jpeg {
+
+namespace {
+static const int kBrunsliMaxSampling = 15;
+static const size_t kBrunsliMaxNumBlocks = 1ull << 24;
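For scale: kBrunsliMaxNumBlocks = 2^24 caps a component at about 16.8 million
8x8 blocks, i.e. 2^24 * 64 = 2^30, roughly 1.07 billion coefficients (a
gigapixel-class image); this is the bound applied to the per-component coeffs
allocation in ProcessSOF below.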
+
+// Macros for commonly used error conditions.
+
+#define JXL_JPEG_VERIFY_LEN(n)                                               \
+  if (*pos + (n) > len) {                                                    \
+    JXL_JPEG_DEBUG("Unexpected end of input: pos=%zu need=%d len=%zu", *pos, \
+                   static_cast<int>(n), len);                                \
+    jpg->error = JPEGReadError::UNEXPECTED_EOF;                              \
+    return false;                                                            \
+  }
+
+#define JXL_JPEG_VERIFY_INPUT(var, low, high, code)           \
+  if ((var) < (low) || (var) > (high)) {                      \
+    JXL_JPEG_DEBUG("Invalid " #var ": %d", static_cast<int>(var)); \
+    jpg->error = JPEGReadError::INVALID_##code;               \
+    return false;                                             \
+  }
+
+#define JXL_JPEG_VERIFY_MARKER_END()                                  \
+  if (start_pos + marker_len != *pos) {                               \
+    JXL_JPEG_DEBUG("Invalid marker length: declared=%zu actual=%zu",  \
+                   marker_len, (*pos - start_pos));                   \
+    jpg->error = JPEGReadError::WRONG_MARKER_SIZE;                    \
+    return false;                                                     \
+  }
+
+#define JXL_JPEG_EXPECT_MARKER()                                       \
+  if (pos + 2 > len || data[pos] != 0xff) {                            \
+    JXL_JPEG_DEBUG(                                                    \
+        "Marker byte (0xff) expected, found: 0x%.2x pos=%zu len=%zu",  \
+        (pos < len ? data[pos] : 0), pos, len);                        \
+    jpg->error = JPEGReadError::MARKER_BYTE_NOT_FOUND;                 \
+    return false;                                                      \
+  }
+
+inline int ReadUint8(const uint8_t* data, size_t* pos) {
+  return data[(*pos)++];
+}
+
+inline int ReadUint16(const uint8_t* data, size_t* pos) {
+  int v = (data[*pos] << 8) + data[*pos + 1];
+  *pos += 2;
+  return v;
+}
+
+// Reads the Start of Frame (SOF) marker segment and fills in *jpg with the
+// parsed data.
+bool ProcessSOF(const uint8_t* data, const size_t len, JpegReadMode mode,
+                size_t* pos, JPEGData* jpg) {
+  if (jpg->width != 0) {
+    JXL_JPEG_DEBUG("Duplicate SOF marker.");
+    jpg->error = JPEGReadError::DUPLICATE_SOF;
+    return false;
+  }
+  const size_t start_pos = *pos;
+  JXL_JPEG_VERIFY_LEN(8);
+  size_t marker_len = ReadUint16(data, pos);
+  int precision = ReadUint8(data, pos);
+  int height = ReadUint16(data, pos);
+  int width = ReadUint16(data, pos);
+  int num_components = ReadUint8(data, pos);
+  JXL_JPEG_VERIFY_INPUT(precision, 8, 8, PRECISION);
+  JXL_JPEG_VERIFY_INPUT(height, 1, kMaxDimPixels, HEIGHT);
+  JXL_JPEG_VERIFY_INPUT(width, 1, kMaxDimPixels, WIDTH);
+  JXL_JPEG_VERIFY_INPUT(num_components, 1, kMaxComponents, NUMCOMP);
+  JXL_JPEG_VERIFY_LEN(3 * num_components);
+  jpg->height = height;
+  jpg->width = width;
+  jpg->components.resize(num_components);
+
+  // Read sampling factors and quant table index for each component.
+  std::vector<bool> ids_seen(256, false);
+  int max_h_samp_factor = 1;
+  int max_v_samp_factor = 1;
+  for (size_t i = 0; i < jpg->components.size(); ++i) {
+    const int id = ReadUint8(data, pos);
+    if (ids_seen[id]) {  // (cf. section B.2.2, syntax of Ci)
+      JXL_JPEG_DEBUG("Duplicate ID %d in SOF.", id);
+      jpg->error = JPEGReadError::DUPLICATE_COMPONENT_ID;
+      return false;
+    }
+    ids_seen[id] = true;
+    jpg->components[i].id = id;
+    int factor = ReadUint8(data, pos);
+    int h_samp_factor = factor >> 4;
+    int v_samp_factor = factor & 0xf;
+    JXL_JPEG_VERIFY_INPUT(h_samp_factor, 1, kBrunsliMaxSampling, SAMP_FACTOR);
+    JXL_JPEG_VERIFY_INPUT(v_samp_factor, 1, kBrunsliMaxSampling, SAMP_FACTOR);
+    jpg->components[i].h_samp_factor = h_samp_factor;
+    jpg->components[i].v_samp_factor = v_samp_factor;
+    jpg->components[i].quant_idx = ReadUint8(data, pos);
+    max_h_samp_factor = std::max(max_h_samp_factor, h_samp_factor);
+    max_v_samp_factor = std::max(max_v_samp_factor, v_samp_factor);
+  }
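  // Worked example of the MCU geometry computed below (values hypothetical):
  // a 35x57 image with max sampling factors 2x2 has 16x16-pixel MCUs, so
  // MCU_cols = DivCeil(35, 16) = 3 and MCU_rows = DivCeil(57, 16) = 4; a
  // component sampled 1x1 then spans 3x4 blocks (width_in_blocks = 3), while
  // the 2x2-sampled component spans 6x8 blocks.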
+
+  // We have checked above that none of the sampling factors are 0, so the
+  // max sampling factors can not be 0.
+  int MCU_rows = DivCeil(jpg->height, max_v_samp_factor * 8);
+  int MCU_cols = DivCeil(jpg->width, max_h_samp_factor * 8);
+  // Compute the block dimensions for each component.
+  for (size_t i = 0; i < jpg->components.size(); ++i) {
+    JPEGComponent* c = &jpg->components[i];
+    if (max_h_samp_factor % c->h_samp_factor != 0 ||
+        max_v_samp_factor % c->v_samp_factor != 0) {
+      JXL_JPEG_DEBUG("Non-integral subsampling ratios.");
+      jpg->error = JPEGReadError::INVALID_SAMPLING_FACTORS;
+      return false;
+    }
+    c->width_in_blocks = MCU_cols * c->h_samp_factor;
+    c->height_in_blocks = MCU_rows * c->v_samp_factor;
+    const uint64_t num_blocks =
+        static_cast<uint64_t>(c->width_in_blocks) * c->height_in_blocks;
+    if (num_blocks > kBrunsliMaxNumBlocks) {
+      JXL_JPEG_DEBUG("Image too large.");
+      jpg->error = JPEGReadError::IMAGE_TOO_LARGE;
+      return false;
+    }
+    if (mode == JpegReadMode::kReadAll) {
+      c->coeffs.resize(num_blocks * kDCTBlockSize);
+    }
+  }
+  JXL_JPEG_VERIFY_MARKER_END();
+  return true;
+}
+
+// Reads the Start of Scan (SOS) marker segment and fills in *scan_info with
+// the parsed data.
+bool ProcessSOS(const uint8_t* data, const size_t len, size_t* pos,
+                JPEGData* jpg) {
+  const size_t start_pos = *pos;
+  JXL_JPEG_VERIFY_LEN(3);
+  size_t marker_len = ReadUint16(data, pos);
+  size_t comps_in_scan = ReadUint8(data, pos);
+  JXL_JPEG_VERIFY_INPUT(comps_in_scan, 1, jpg->components.size(),
+                        COMPS_IN_SCAN);
+
+  JPEGScanInfo scan_info;
+  scan_info.num_components = comps_in_scan;
+  JXL_JPEG_VERIFY_LEN(2 * comps_in_scan);
+  std::vector<bool> ids_seen(256, false);
+  for (size_t i = 0; i < comps_in_scan; ++i) {
+    uint32_t id = ReadUint8(data, pos);
+    if (ids_seen[id]) {  // (cf. section B.2.3, regarding CSj)
+      JXL_JPEG_DEBUG("Duplicate ID %d in SOS.", id);
+      jpg->error = JPEGReadError::DUPLICATE_COMPONENT_ID;
+      return false;
+    }
+    ids_seen[id] = true;
+    bool found_index = false;
+    for (size_t j = 0; j < jpg->components.size(); ++j) {
+      if (jpg->components[j].id == id) {
+        scan_info.components[i].comp_idx = j;
+        found_index = true;
+      }
+    }
+    if (!found_index) {
+      JXL_JPEG_DEBUG("SOS marker: Could not find component with id %d", id);
+      jpg->error = JPEGReadError::COMPONENT_NOT_FOUND;
+      return false;
+    }
+    int c = ReadUint8(data, pos);
+    int dc_tbl_idx = c >> 4;
+    int ac_tbl_idx = c & 0xf;
+    JXL_JPEG_VERIFY_INPUT(dc_tbl_idx, 0, 3, HUFFMAN_INDEX);
+    JXL_JPEG_VERIFY_INPUT(ac_tbl_idx, 0, 3, HUFFMAN_INDEX);
+    scan_info.components[i].dc_tbl_idx = dc_tbl_idx;
+    scan_info.components[i].ac_tbl_idx = ac_tbl_idx;
+  }
+  JXL_JPEG_VERIFY_LEN(3);
+  scan_info.Ss = ReadUint8(data, pos);
+  scan_info.Se = ReadUint8(data, pos);
+  JXL_JPEG_VERIFY_INPUT(static_cast<int>(scan_info.Ss), 0, 63, START_OF_SCAN);
+  JXL_JPEG_VERIFY_INPUT(scan_info.Se, scan_info.Ss, 63, END_OF_SCAN);
+  int c = ReadUint8(data, pos);
+  scan_info.Ah = c >> 4;
+  scan_info.Al = c & 0xf;
+  if (scan_info.Ah != 0 && scan_info.Al != scan_info.Ah - 1) {
+    // Section G.1.1.1.2: Successive approximation control only improves
+    // by one bit at a time. But it's not always respected, so we just issue
+    // a warning.
+    JXL_WARNING("Invalid progressive parameters: Al=%d Ah=%d", scan_info.Al,
+                scan_info.Ah);
+  }
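  // Worked example of the Ah/Al fields checked above: a typical progressive
  // sequence for an AC band is Ah=0,Al=2 (first pass, coefficients effectively
  // divided by 4), then Ah=2,Al=1, then Ah=1,Al=0, each refinement pass adding
  // back one bit of precision; Ah == Al + 1 is exactly what the conformance
  // warning above expects of a refinement scan.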
+  // Check that all the Huffman tables needed for this scan are defined.
+  for (size_t i = 0; i < comps_in_scan; ++i) {
+    bool found_dc_table = false;
+    bool found_ac_table = false;
+    for (size_t j = 0; j < jpg->huffman_code.size(); ++j) {
+      uint32_t slot_id = jpg->huffman_code[j].slot_id;
+      if (slot_id == scan_info.components[i].dc_tbl_idx) {
+        found_dc_table = true;
+      } else if (slot_id == scan_info.components[i].ac_tbl_idx + 16) {
+        found_ac_table = true;
+      }
+    }
+    if (scan_info.Ss == 0 && !found_dc_table) {
+      JXL_JPEG_DEBUG(
+          "SOS marker: Could not find DC Huffman table with index %d",
+          scan_info.components[i].dc_tbl_idx);
+      jpg->error = JPEGReadError::HUFFMAN_TABLE_NOT_FOUND;
+      return false;
+    }
+    if (scan_info.Se > 0 && !found_ac_table) {
+      JXL_JPEG_DEBUG(
+          "SOS marker: Could not find AC Huffman table with index %d",
+          scan_info.components[i].ac_tbl_idx);
+      jpg->error = JPEGReadError::HUFFMAN_TABLE_NOT_FOUND;
+      return false;
+    }
+  }
+  jpg->scan_info.push_back(scan_info);
+  JXL_JPEG_VERIFY_MARKER_END();
+  return true;
+}
+
+// Reads the Define Huffman Table (DHT) marker segment and fills in *jpg with
+// the parsed data. Builds the Huffman decoding table in either dc_huff_lut or
+// ac_huff_lut, depending on the type and slot_id of the Huffman code being
+// read.
+bool ProcessDHT(const uint8_t* data, const size_t len, JpegReadMode mode,
+                std::vector<HuffmanTableEntry>* dc_huff_lut,
+                std::vector<HuffmanTableEntry>* ac_huff_lut, size_t* pos,
+                JPEGData* jpg) {
+  const size_t start_pos = *pos;
+  JXL_JPEG_VERIFY_LEN(2);
+  size_t marker_len = ReadUint16(data, pos);
+  if (marker_len == 2) {
+    JXL_JPEG_DEBUG("DHT marker: no Huffman table found");
+    jpg->error = JPEGReadError::EMPTY_DHT;
+    return false;
+  }
+  while (*pos < start_pos + marker_len) {
+    JXL_JPEG_VERIFY_LEN(1 + kJpegHuffmanMaxBitLength);
+    JPEGHuffmanCode huff;
+    huff.slot_id = ReadUint8(data, pos);
+    int huffman_index = huff.slot_id;
+    int is_ac_table = (huff.slot_id & 0x10) != 0;
+    HuffmanTableEntry* huff_lut;
+    if (is_ac_table) {
+      huffman_index -= 0x10;
+      JXL_JPEG_VERIFY_INPUT(huffman_index, 0, 3, HUFFMAN_INDEX);
+      huff_lut = &(*ac_huff_lut)[huffman_index * kJpegHuffmanLutSize];
+    } else {
+      JXL_JPEG_VERIFY_INPUT(huffman_index, 0, 3, HUFFMAN_INDEX);
+      huff_lut = &(*dc_huff_lut)[huffman_index * kJpegHuffmanLutSize];
+    }
+    huff.counts[0] = 0;
+    int total_count = 0;
+    int space = 1 << kJpegHuffmanMaxBitLength;
+    int max_depth = 1;
+    for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) {
+      int count = ReadUint8(data, pos);
+      if (count != 0) {
+        max_depth = i;
+      }
+      huff.counts[i] = count;
+      total_count += count;
+      space -= count * (1 << (kJpegHuffmanMaxBitLength - i));
+    }
+    if (is_ac_table) {
+      JXL_JPEG_VERIFY_INPUT(total_count, 0, kJpegHuffmanAlphabetSize,
+                            HUFFMAN_CODE);
+    } else {
+      JXL_JPEG_VERIFY_INPUT(total_count, 0, kJpegDCAlphabetSize, HUFFMAN_CODE);
+    }
+    JXL_JPEG_VERIFY_LEN(total_count);
+    std::vector<bool> values_seen(256, false);
+    for (int i = 0; i < total_count; ++i) {
+      int value = ReadUint8(data, pos);
+      if (!is_ac_table) {
+        JXL_JPEG_VERIFY_INPUT(value, 0, kJpegDCAlphabetSize - 1, HUFFMAN_CODE);
+      }
+      if (values_seen[value]) {
+        JXL_JPEG_DEBUG("Duplicate Huffman code value %d", value);
+        jpg->error = JPEGReadError::INVALID_HUFFMAN_CODE;
+        return false;
+      }
+      values_seen[value] = true;
+      huff.values[i] = value;
+    }
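    // Worked example of the `space` bookkeeping above: the code budget starts
    // at 2^16 (kJpegHuffmanMaxBitLength = 16), and a code of length i consumes
    // 2^(16 - i). One 1-bit code plus one 2-bit code consume 2^15 + 2^14,
    // leaving space = 2^14: exactly enough for the artificial all-ones
    // sentinel of length 2 appended below. `space` going negative means the
    // declared lengths overflow the code space, i.e. an invalid table.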
+    // Add an invalid symbol that will have the all 1 code.
+    ++huff.counts[max_depth];
+    huff.values[total_count] = kJpegHuffmanAlphabetSize;
+    space -= (1 << (kJpegHuffmanMaxBitLength - max_depth));
+    if (space < 0) {
+      JXL_JPEG_DEBUG("Invalid Huffman code lengths.");
+      jpg->error = JPEGReadError::INVALID_HUFFMAN_CODE;
+      return false;
+    } else if (space > 0 && huff_lut[0].value != 0xffff) {
+      // Re-initialize the values to an invalid symbol so that we can recognize
+      // it when reading the bit stream using a Huffman code with space > 0.
+      for (int i = 0; i < kJpegHuffmanLutSize; ++i) {
+        huff_lut[i].bits = 0;
+        huff_lut[i].value = 0xffff;
+      }
+    }
+    huff.is_last = (*pos == start_pos + marker_len);
+    if (mode == JpegReadMode::kReadAll) {
+      BuildJpegHuffmanTable(&huff.counts[0], &huff.values[0], huff_lut);
+    }
+    jpg->huffman_code.push_back(huff);
+  }
+  JXL_JPEG_VERIFY_MARKER_END();
+  return true;
+}
+
+// Reads the Define Quantization Table (DQT) marker segment and fills in *jpg
+// with the parsed data.
+bool ProcessDQT(const uint8_t* data, const size_t len, size_t* pos,
+                JPEGData* jpg) {
+  const size_t start_pos = *pos;
+  JXL_JPEG_VERIFY_LEN(2);
+  size_t marker_len = ReadUint16(data, pos);
+  if (marker_len == 2) {
+    JXL_JPEG_DEBUG("DQT marker: no quantization table found");
+    jpg->error = JPEGReadError::EMPTY_DQT;
+    return false;
+  }
+  while (*pos < start_pos + marker_len && jpg->quant.size() < kMaxQuantTables) {
+    JXL_JPEG_VERIFY_LEN(1);
+    int quant_table_index = ReadUint8(data, pos);
+    int quant_table_precision = quant_table_index >> 4;
+    JXL_JPEG_VERIFY_INPUT(quant_table_precision, 0, 1, QUANT_TBL_PRECISION);
+    quant_table_index &= 0xf;
+    JXL_JPEG_VERIFY_INPUT(quant_table_index, 0, 3, QUANT_TBL_INDEX);
+    JXL_JPEG_VERIFY_LEN((quant_table_precision + 1) * kDCTBlockSize);
+    JPEGQuantTable table;
+    table.index = quant_table_index;
+    table.precision = quant_table_precision;
+    for (size_t i = 0; i < kDCTBlockSize; ++i) {
+      int quant_val =
+          quant_table_precision ? ReadUint16(data, pos) : ReadUint8(data, pos);
+      JXL_JPEG_VERIFY_INPUT(quant_val, 1, 65535, QUANT_VAL);
+      table.values[kJPEGNaturalOrder[i]] = quant_val;
+    }
+    table.is_last = (*pos == start_pos + marker_len);
+    jpg->quant.push_back(table);
+  }
+  JXL_JPEG_VERIFY_MARKER_END();
+  return true;
+}
+
+// Reads the DRI marker and saves the restart interval into *jpg.
+bool ProcessDRI(const uint8_t* data, const size_t len, size_t* pos,
+                bool* found_dri, JPEGData* jpg) {
+  if (*found_dri) {
+    JXL_JPEG_DEBUG("Duplicate DRI marker.");
+    jpg->error = JPEGReadError::DUPLICATE_DRI;
+    return false;
+  }
+  *found_dri = true;
+  const size_t start_pos = *pos;
+  JXL_JPEG_VERIFY_LEN(4);
+  size_t marker_len = ReadUint16(data, pos);
+  int restart_interval = ReadUint16(data, pos);
+  jpg->restart_interval = restart_interval;
+  JXL_JPEG_VERIFY_MARKER_END();
+  return true;
+}
+
+// Saves the APP marker segment as a string to *jpg.
+bool ProcessAPP(const uint8_t* data, const size_t len, size_t* pos,
+                JPEGData* jpg) {
+  JXL_JPEG_VERIFY_LEN(2);
+  size_t marker_len = ReadUint16(data, pos);
+  JXL_JPEG_VERIFY_INPUT(marker_len, 2, 65535, MARKER_LEN);
+  JXL_JPEG_VERIFY_LEN(marker_len - 2);
+  JXL_DASSERT(*pos >= 3);
+  // Save the marker type together with the app data.
+  const uint8_t* app_str_start = data + *pos - 3;
+  std::vector<uint8_t> app_str(app_str_start, app_str_start + marker_len + 1);
+  *pos += marker_len - 2;
+  jpg->app_data.push_back(app_str);
+  return true;
+}
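As ProcessAPP above and ProcessCOM below both rely on it, the stored-segment
layout can be made concrete with a small sketch (hypothetical helper; the
big-endian length counts the length field and payload but not the type byte):

#include <cstdint>
#include <vector>

std::vector<uint8_t> MakeMarkerSegment(uint8_t marker_type,
                                       const std::vector<uint8_t>& payload) {
  size_t marker_len = payload.size() + 2;  // includes the two length bytes
  std::vector<uint8_t> seg;
  seg.push_back(marker_type);  // e.g. 0xFE for COM, 0xE0..0xEF for APPn
  seg.push_back(static_cast<uint8_t>(marker_len >> 8));
  seg.push_back(static_cast<uint8_t>(marker_len & 0xFF));
  seg.insert(seg.end(), payload.begin(), payload.end());
  return seg;  // seg.size() == marker_len + 1, matching GetMarkerPayload
}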
+
+// Saves the COM marker segment as a string to *jpg.
+bool ProcessCOM(const uint8_t* data, const size_t len, size_t* pos,
+                JPEGData* jpg) {
+  JXL_JPEG_VERIFY_LEN(2);
+  size_t marker_len = ReadUint16(data, pos);
+  JXL_JPEG_VERIFY_INPUT(marker_len, 2, 65535, MARKER_LEN);
+  JXL_JPEG_VERIFY_LEN(marker_len - 2);
+  const uint8_t* com_str_start = data + *pos - 3;
+  std::vector<uint8_t> com_str(com_str_start, com_str_start + marker_len + 1);
+  *pos += marker_len - 2;
+  jpg->com_data.push_back(com_str);
+  return true;
+}
+
+// Helper structure to read bits from the entropy coded data segment.
+struct BitReaderState {
+  BitReaderState(const uint8_t* data, const size_t len, size_t pos)
+      : data_(data), len_(len) {
+    Reset(pos);
+  }
+
+  void Reset(size_t pos) {
+    pos_ = pos;
+    val_ = 0;
+    bits_left_ = 0;
+    next_marker_pos_ = len_ - 2;
+    FillBitWindow();
+  }
+
+  // Returns the next byte and skips the 0xff/0x00 escape sequences.
+  uint8_t GetNextByte() {
+    if (pos_ >= next_marker_pos_) {
+      ++pos_;
+      return 0;
+    }
+    uint8_t c = data_[pos_++];
+    if (c == 0xff) {
+      uint8_t escape = data_[pos_];
+      if (escape == 0) {
+        ++pos_;
+      } else {
+        // 0xff was followed by a non-zero byte, which means that we found the
+        // start of the next marker segment.
+        next_marker_pos_ = pos_ - 1;
+      }
+    }
+    return c;
+  }
+
+  void FillBitWindow() {
+    if (bits_left_ <= 16) {
+      while (bits_left_ <= 56) {
+        val_ <<= 8;
+        val_ |= (uint64_t)GetNextByte();
+        bits_left_ += 8;
+      }
+    }
+  }
+
+  int ReadBits(int nbits) {
+    FillBitWindow();
+    uint64_t val = (val_ >> (bits_left_ - nbits)) & ((1ULL << nbits) - 1);
+    bits_left_ -= nbits;
+    return val;
+  }
+
+  // Sets *pos to the next stream position where parsing should continue.
+  // Enqueues the padding bits seen (0 or 1).
+  // Returns false if there is inconsistent or invalid padding or the stream
+  // ended too early.
+  bool FinishStream(JPEGData* jpg, size_t* pos) {
+    int npadbits = bits_left_ & 7;
+    if (npadbits > 0) {
+      uint64_t padmask = (1ULL << npadbits) - 1;
+      uint64_t padbits = (val_ >> (bits_left_ - npadbits)) & padmask;
+      if (padbits != padmask) {
+        jpg->has_zero_padding_bit = true;
+      }
+      for (int i = npadbits - 1; i >= 0; --i) {
+        jpg->padding_bits.push_back((padbits >> i) & 1);
+      }
+    }
+    // Give back some bytes that we did not use.
+    int unused_bytes_left = bits_left_ >> 3;
+    while (unused_bytes_left-- > 0) {
+      --pos_;
+      // If we give back a 0 byte, we need to check if it was a 0xff/0x00
+      // escape sequence, and if yes, we need to give back one more byte.
+      if (pos_ < next_marker_pos_ && data_[pos_] == 0 &&
+          data_[pos_ - 1] == 0xff) {
+        --pos_;
+      }
+    }
+    if (pos_ > next_marker_pos_) {
+      // Data ran out before the scan was complete.
+      JXL_JPEG_DEBUG("Unexpected end of scan.");
+      return false;
+    }
+    *pos = pos_;
+    return true;
+  }
+
+  const uint8_t* data_;
+  const size_t len_;
+  size_t pos_;
+  uint64_t val_;
+  int bits_left_;
+  size_t next_marker_pos_;
+};
+
+// Returns the next Huffman-coded symbol.
+int ReadSymbol(const HuffmanTableEntry* table, BitReaderState* br) {
+  int nbits;
+  br->FillBitWindow();
+  int val = (br->val_ >> (br->bits_left_ - 8)) & 0xff;
+  table += val;
+  nbits = table->bits - 8;
+  if (nbits > 0) {
+    br->bits_left_ -= 8;
+    table += table->value;
+    val = (br->val_ >> (br->bits_left_ - nbits)) & ((1 << nbits) - 1);
+    table += val;
+  }
+  br->bits_left_ -= table->bits;
+  return table->value;
+}
+
+/**
+ * Returns the DC diff or AC value for extra bits value x and prefix code s.
+ *
+ * CCITT Rec.
T.81 (1992 E) + * Table F.1 – Difference magnitude categories for DC coding + * SSSS | DIFF values + * ------+-------------------------- + * 0 | 0 + * 1 | –1, 1 + * 2 | –3, –2, 2, 3 + * 3 | –7..–4, 4..7 + * ......|.......................... + * 11 | –2047..–1024, 1024..2047 + * + * CCITT Rec. T.81 (1992 E) + * Table F.2 – Categories assigned to coefficient values + * [ Same as Table F.1, but does not include SSSS equal to 0 and 11] + * + * + * CCITT Rec. T.81 (1992 E) + * F.1.2.1.1 Structure of DC code table + * For each category,... additional bits... appended... to uniquely identify + * which difference... occurred... When DIFF is positive... SSSS... bits of DIFF + * are appended. When DIFF is negative... SSSS... bits of (DIFF – 1) are + * appended... Most significant bit... is 0 for negative differences and 1 for + * positive differences. + * + * In other words the upper half of extra bits range represents DIFF as is. + * The lower half represents the negative DIFFs with an offset. + */ +int HuffExtend(int x, int s) { + JXL_DASSERT(s >= 1); + int half = 1 << (s - 1); + if (x >= half) { + JXL_DASSERT(x < (1 << s)); + return x; + } else { + return x - (1 << s) + 1; + } +} + +// Decodes one 8x8 block of DCT coefficients from the bit stream. +bool DecodeDCTBlock(const HuffmanTableEntry* dc_huff, + const HuffmanTableEntry* ac_huff, int Ss, int Se, int Al, + int* eobrun, bool* reset_state, int* num_zero_runs, + BitReaderState* br, JPEGData* jpg, coeff_t* last_dc_coeff, + coeff_t* coeffs) { + // Nowadays multiplication is even faster than variable shift. + int Am = 1 << Al; + bool eobrun_allowed = Ss > 0; + if (Ss == 0) { + int s = ReadSymbol(dc_huff, br); + if (s >= kJpegDCAlphabetSize) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for DC coefficient.", s); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + int diff = 0; + if (s > 0) { + int bits = br->ReadBits(s); + diff = HuffExtend(bits, s); + } + int coeff = diff + *last_dc_coeff; + const int dc_coeff = coeff * Am; + coeffs[0] = dc_coeff; + // TODO(eustas): is there a more elegant / explicit way to check this? + if (dc_coeff != coeffs[0]) { + JXL_JPEG_DEBUG("Invalid DC coefficient %d", dc_coeff); + jpg->error = JPEGReadError::NON_REPRESENTABLE_DC_COEFF; + return false; + } + *last_dc_coeff = coeff; + ++Ss; + } + if (Ss > Se) { + return true; + } + if (*eobrun > 0) { + --(*eobrun); + return true; + } + *num_zero_runs = 0; + for (int k = Ss; k <= Se; k++) { + int sr = ReadSymbol(ac_huff, br); + if (sr >= kJpegHuffmanAlphabetSize) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for AC coefficient %d", sr, k); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + int r = sr >> 4; + int s = sr & 15; + if (s > 0) { + k += r; + if (k > Se) { + JXL_JPEG_DEBUG("Out-of-band coefficient %d band was %d-%d", k, Ss, Se); + jpg->error = JPEGReadError::OUT_OF_BAND_COEFF; + return false; + } + if (s + Al >= kJpegDCAlphabetSize) { + JXL_JPEG_DEBUG( + "Out of range AC coefficient value: s = %d Al = %d k = %d", s, Al, + k); + jpg->error = JPEGReadError::NON_REPRESENTABLE_AC_COEFF; + return false; + } + int bits = br->ReadBits(s); + int coeff = HuffExtend(bits, s); + coeffs[kJPEGNaturalOrder[k]] = coeff * Am; + *num_zero_runs = 0; + } else if (r == 15) { + k += 15; + ++(*num_zero_runs); + } else { + if (eobrun_allowed && k == Ss && *eobrun == 0) { + // We have two end-of-block runs right after each other, so we signal + // the jpeg encoder to force a state reset at this point. 
+ *reset_state = true; + } + *eobrun = 1 << r; + if (r > 0) { + if (!eobrun_allowed) { + JXL_JPEG_DEBUG("End-of-block run crossing DC coeff."); + jpg->error = JPEGReadError::EOB_RUN_TOO_LONG; + return false; + } + *eobrun += br->ReadBits(r); + } + break; + } + } + --(*eobrun); + return true; +} + +bool RefineDCTBlock(const HuffmanTableEntry* ac_huff, int Ss, int Se, int Al, + int* eobrun, bool* reset_state, BitReaderState* br, + JPEGData* jpg, coeff_t* coeffs) { + // Nowadays multiplication is even faster than variable shift. + int Am = 1 << Al; + bool eobrun_allowed = Ss > 0; + if (Ss == 0) { + int s = br->ReadBits(1); + coeff_t dc_coeff = coeffs[0]; + dc_coeff |= s * Am; + coeffs[0] = dc_coeff; + ++Ss; + } + if (Ss > Se) { + return true; + } + int p1 = Am; + int m1 = -Am; + int k = Ss; + int r; + int s; + bool in_zero_run = false; + if (*eobrun <= 0) { + for (; k <= Se; k++) { + s = ReadSymbol(ac_huff, br); + if (s >= kJpegHuffmanAlphabetSize) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for AC coefficient %d", s, k); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + r = s >> 4; + s &= 15; + if (s) { + if (s != 1) { + JXL_JPEG_DEBUG("Invalid Huffman symbol %d for AC coefficient %d", s, + k); + jpg->error = JPEGReadError::INVALID_SYMBOL; + return false; + } + s = br->ReadBits(1) ? p1 : m1; + in_zero_run = false; + } else { + if (r != 15) { + if (eobrun_allowed && k == Ss && *eobrun == 0) { + // We have two end-of-block runs right after each other, so we + // signal the jpeg encoder to force a state reset at this point. + *reset_state = true; + } + *eobrun = 1 << r; + if (r > 0) { + if (!eobrun_allowed) { + JXL_JPEG_DEBUG("End-of-block run crossing DC coeff."); + jpg->error = JPEGReadError::EOB_RUN_TOO_LONG; + return false; + } + *eobrun += br->ReadBits(r); + } + break; + } + in_zero_run = true; + } + do { + coeff_t thiscoef = coeffs[kJPEGNaturalOrder[k]]; + if (thiscoef != 0) { + if (br->ReadBits(1)) { + if ((thiscoef & p1) == 0) { + if (thiscoef >= 0) { + thiscoef += p1; + } else { + thiscoef += m1; + } + } + } + coeffs[kJPEGNaturalOrder[k]] = thiscoef; + } else { + if (--r < 0) { + break; + } + } + k++; + } while (k <= Se); + if (s) { + if (k > Se) { + JXL_JPEG_DEBUG("Out-of-band coefficient %d band was %d-%d", k, Ss, + Se); + jpg->error = JPEGReadError::OUT_OF_BAND_COEFF; + return false; + } + coeffs[kJPEGNaturalOrder[k]] = s; + } + } + } + if (in_zero_run) { + JXL_JPEG_DEBUG("Extra zero run before end-of-block."); + jpg->error = JPEGReadError::EXTRA_ZERO_RUN; + return false; + } + if (*eobrun > 0) { + for (; k <= Se; k++) { + coeff_t thiscoef = coeffs[kJPEGNaturalOrder[k]]; + if (thiscoef != 0) { + if (br->ReadBits(1)) { + if ((thiscoef & p1) == 0) { + if (thiscoef >= 0) { + thiscoef += p1; + } else { + thiscoef += m1; + } + } + } + coeffs[kJPEGNaturalOrder[k]] = thiscoef; + } + } + } + --(*eobrun); + return true; +} + +bool ProcessRestart(const uint8_t* data, const size_t len, + int* next_restart_marker, BitReaderState* br, + JPEGData* jpg) { + size_t pos = 0; + if (!br->FinishStream(jpg, &pos)) { + jpg->error = JPEGReadError::INVALID_SCAN; + return false; + } + int expected_marker = 0xd0 + *next_restart_marker; + JXL_JPEG_EXPECT_MARKER(); + int marker = data[pos + 1]; + if (marker != expected_marker) { + JXL_JPEG_DEBUG("Did not find expected restart marker %d actual %d", + expected_marker, marker); + jpg->error = JPEGReadError::WRONG_RESTART_MARKER; + return false; + } + br->Reset(pos + 2); + *next_restart_marker += 1; + *next_restart_marker &= 0x7; + return true; 
+}
+
+bool ProcessScan(const uint8_t* data, const size_t len,
+                 const std::vector<HuffmanTableEntry>& dc_huff_lut,
+                 const std::vector<HuffmanTableEntry>& ac_huff_lut,
+                 uint16_t scan_progression[kMaxComponents][kDCTBlockSize],
+                 bool is_progressive, size_t* pos, JPEGData* jpg) {
+  if (!ProcessSOS(data, len, pos, jpg)) {
+    return false;
+  }
+  JPEGScanInfo* scan_info = &jpg->scan_info.back();
+  bool is_interleaved = (scan_info->num_components > 1);
+  int max_h_samp_factor = 1;
+  int max_v_samp_factor = 1;
+  for (size_t i = 0; i < jpg->components.size(); ++i) {
+    max_h_samp_factor =
+        std::max(max_h_samp_factor, jpg->components[i].h_samp_factor);
+    max_v_samp_factor =
+        std::max(max_v_samp_factor, jpg->components[i].v_samp_factor);
+  }
+
+  int MCU_rows = DivCeil(jpg->height, max_v_samp_factor * 8);
+  int MCUs_per_row = DivCeil(jpg->width, max_h_samp_factor * 8);
+  if (!is_interleaved) {
+    const JPEGComponent& c =
+        jpg->components[scan_info->components[0].comp_idx];
+    MCUs_per_row =
+        DivCeil(jpg->width * c.h_samp_factor, 8 * max_h_samp_factor);
+    MCU_rows = DivCeil(jpg->height * c.v_samp_factor, 8 * max_v_samp_factor);
+  }
+  coeff_t last_dc_coeff[kMaxComponents] = {0};
+  BitReaderState br(data, len, *pos);
+  int restarts_to_go = jpg->restart_interval;
+  int next_restart_marker = 0;
+  int eobrun = -1;
+  int block_scan_index = 0;
+  const int Al = is_progressive ? scan_info->Al : 0;
+  const int Ah = is_progressive ? scan_info->Ah : 0;
+  const int Ss = is_progressive ? scan_info->Ss : 0;
+  const int Se = is_progressive ? scan_info->Se : 63;
+  const uint16_t scan_bitmask = Ah == 0 ? (0xffff << Al) : (1u << Al);
+  const uint16_t refinement_bitmask = (1 << Al) - 1;
+  for (size_t i = 0; i < scan_info->num_components; ++i) {
+    int comp_idx = scan_info->components[i].comp_idx;
+    for (int k = Ss; k <= Se; ++k) {
+      if (scan_progression[comp_idx][k] & scan_bitmask) {
+        JXL_JPEG_DEBUG(
+            "Overlapping scans: component=%d k=%d prev_mask: %u cur_mask %u",
+            comp_idx, k, scan_progression[i][k], scan_bitmask);
+        jpg->error = JPEGReadError::OVERLAPPING_SCANS;
+        return false;
+      }
+      if (scan_progression[comp_idx][k] & refinement_bitmask) {
+        JXL_JPEG_DEBUG(
+            "Invalid scan order, a more refined scan was already done: "
+            "component=%d k=%d prev_mask=%u cur_mask=%u",
+            comp_idx, k, scan_progression[i][k], scan_bitmask);
+        jpg->error = JPEGReadError::INVALID_SCAN_ORDER;
+        return false;
+      }
+      scan_progression[comp_idx][k] |= scan_bitmask;
+    }
+  }
+  if (Al > 10) {
+    JXL_JPEG_DEBUG("Scan parameter Al=%d is not supported.", Al);
+    jpg->error = JPEGReadError::NON_REPRESENTABLE_AC_COEFF;
+    return false;
+  }
+  for (int mcu_y = 0; mcu_y < MCU_rows; ++mcu_y) {
+    for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) {
+      // Handle the restart intervals.
+      if (jpg->restart_interval > 0) {
+        if (restarts_to_go == 0) {
+          if (ProcessRestart(data, len, &next_restart_marker, &br, jpg)) {
+            restarts_to_go = jpg->restart_interval;
+            memset(static_cast<void*>(last_dc_coeff), 0,
+                   sizeof(last_dc_coeff));
+            if (eobrun > 0) {
+              JXL_JPEG_DEBUG("End-of-block run too long.");
+              jpg->error = JPEGReadError::EOB_RUN_TOO_LONG;
+              return false;
+            }
+            eobrun = -1;  // fresh start
+          } else {
+            return false;
+          }
+        }
+        --restarts_to_go;
+      }
+      // Decode one MCU.
+ for (size_t i = 0; i < scan_info->num_components; ++i) { + JPEGComponentScanInfo* si = &scan_info->components[i]; + JPEGComponent* c = &jpg->components[si->comp_idx]; + const HuffmanTableEntry* dc_lut = + &dc_huff_lut[si->dc_tbl_idx * kJpegHuffmanLutSize]; + const HuffmanTableEntry* ac_lut = + &ac_huff_lut[si->ac_tbl_idx * kJpegHuffmanLutSize]; + int nblocks_y = is_interleaved ? c->v_samp_factor : 1; + int nblocks_x = is_interleaved ? c->h_samp_factor : 1; + for (int iy = 0; iy < nblocks_y; ++iy) { + for (int ix = 0; ix < nblocks_x; ++ix) { + int block_y = mcu_y * nblocks_y + iy; + int block_x = mcu_x * nblocks_x + ix; + int block_idx = block_y * c->width_in_blocks + block_x; + bool reset_state = false; + int num_zero_runs = 0; + coeff_t* coeffs = &c->coeffs[block_idx * kDCTBlockSize]; + if (Ah == 0) { + if (!DecodeDCTBlock(dc_lut, ac_lut, Ss, Se, Al, &eobrun, + &reset_state, &num_zero_runs, &br, jpg, + &last_dc_coeff[si->comp_idx], coeffs)) { + return false; + } + } else { + if (!RefineDCTBlock(ac_lut, Ss, Se, Al, &eobrun, &reset_state, + &br, jpg, coeffs)) { + return false; + } + } + if (reset_state) { + scan_info->reset_points.emplace_back(block_scan_index); + } + if (num_zero_runs > 0) { + JPEGScanInfo::ExtraZeroRunInfo info; + info.block_idx = block_scan_index; + info.num_extra_zero_runs = num_zero_runs; + scan_info->extra_zero_runs.push_back(info); + } + ++block_scan_index; + } + } + } + } + } + if (eobrun > 0) { + JXL_JPEG_DEBUG("End-of-block run too long."); + jpg->error = JPEGReadError::EOB_RUN_TOO_LONG; + return false; + } + if (!br.FinishStream(jpg, pos)) { + jpg->error = JPEGReadError::INVALID_SCAN; + return false; + } + if (*pos > len) { + JXL_JPEG_DEBUG("Unexpected end of file during scan. pos=%zu len=%zu", *pos, + len); + jpg->error = JPEGReadError::UNEXPECTED_EOF; + return false; + } + return true; +} + +// Changes the quant_idx field of the components to refer to the index of the +// quant table in the jpg->quant array. +bool FixupIndexes(JPEGData* jpg) { + for (size_t i = 0; i < jpg->components.size(); ++i) { + JPEGComponent* c = &jpg->components[i]; + bool found_index = false; + for (size_t j = 0; j < jpg->quant.size(); ++j) { + if (jpg->quant[j].index == c->quant_idx) { + c->quant_idx = j; + found_index = true; + break; + } + } + if (!found_index) { + JXL_JPEG_DEBUG("Quantization table with index %u not found", + c->quant_idx); + jpg->error = JPEGReadError::QUANT_TABLE_NOT_FOUND; + return false; + } + } + return true; +} + +size_t FindNextMarker(const uint8_t* data, const size_t len, size_t pos) { + // kIsValidMarker[i] == 1 means (0xc0 + i) is a valid marker. + static const uint8_t kIsValidMarker[] = { + 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + }; + size_t num_skipped = 0; + while (pos + 1 < len && (data[pos] != 0xff || data[pos + 1] < 0xc0 || + !kIsValidMarker[data[pos + 1] - 0xc0])) { + ++pos; + ++num_skipped; + } + return num_skipped; +} + +} // namespace + +bool ReadJpeg(const uint8_t* data, const size_t len, JpegReadMode mode, + JPEGData* jpg) { + size_t pos = 0; + // Check SOI marker. 
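For reference, the `kIsValidMarker` table in `FindNextMarker` above is indexed by `byte - 0xc0`; written out as a predicate, it accepts exactly the markers the `ReadJpeg` switch below can handle (a readability sketch, not part of the patch; SOI is absent because it is consumed before the loop):

```cpp
#include <cstdint>

// Equivalent to kIsValidMarker[m - 0xc0] for m in [0xc0, 0xff].
static bool IsValidMarker(uint8_t m) {
  return (m >= 0xc0 && m <= 0xc2) ||  // SOF0, SOF1, SOF2
         m == 0xc4 ||                 // DHT
         (m >= 0xd0 && m <= 0xd7) ||  // RST0..RST7
         (m >= 0xd9 && m <= 0xdb) ||  // EOI, SOS, DQT
         m == 0xdd ||                 // DRI
         (m >= 0xe0 && m <= 0xef) ||  // APP0..APP15
         m == 0xfe;                   // COM
}
```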
+  JXL_JPEG_EXPECT_MARKER();
+  int marker = data[pos + 1];
+  pos += 2;
+  if (marker != 0xd8) {
+    JXL_JPEG_DEBUG("Did not find expected SOI marker, actual=%d", marker);
+    jpg->error = JPEGReadError::SOI_NOT_FOUND;
+    return false;
+  }
+  int lut_size = kMaxHuffmanTables * kJpegHuffmanLutSize;
+  std::vector<HuffmanTableEntry> dc_huff_lut(lut_size);
+  std::vector<HuffmanTableEntry> ac_huff_lut(lut_size);
+  bool found_sof = false;
+  bool found_dri = false;
+  uint16_t scan_progression[kMaxComponents][kDCTBlockSize] = {{0}};
+
+  jpg->padding_bits.resize(0);
+  bool is_progressive = false;  // default
+  do {
+    // Read next marker.
+    size_t num_skipped = FindNextMarker(data, len, pos);
+    if (num_skipped > 0) {
+      // Add a fake marker to indicate arbitrary in-between-markers data.
+      jpg->marker_order.push_back(0xff);
+      jpg->inter_marker_data.emplace_back(data + pos, data + pos + num_skipped);
+      pos += num_skipped;
+    }
+    JXL_JPEG_EXPECT_MARKER();
+    marker = data[pos + 1];
+    pos += 2;
+    bool ok = true;
+    switch (marker) {
+      case 0xc0:
+      case 0xc1:
+      case 0xc2:
+        is_progressive = (marker == 0xc2);
+        ok = ProcessSOF(data, len, mode, &pos, jpg);
+        found_sof = true;
+        break;
+      case 0xc4:
+        ok = ProcessDHT(data, len, mode, &dc_huff_lut, &ac_huff_lut, &pos, jpg);
+        break;
+      case 0xd0:
+      case 0xd1:
+      case 0xd2:
+      case 0xd3:
+      case 0xd4:
+      case 0xd5:
+      case 0xd6:
+      case 0xd7:
+        // RST markers do not have any data.
+        break;
+      case 0xd9:
+        // Found end marker.
+        break;
+      case 0xda:
+        if (mode == JpegReadMode::kReadAll) {
+          ok = ProcessScan(data, len, dc_huff_lut, ac_huff_lut,
+                           scan_progression, is_progressive, &pos, jpg);
+        }
+        break;
+      case 0xdb:
+        ok = ProcessDQT(data, len, &pos, jpg);
+        break;
+      case 0xdd:
+        ok = ProcessDRI(data, len, &pos, &found_dri, jpg);
+        break;
+      case 0xe0:
+      case 0xe1:
+      case 0xe2:
+      case 0xe3:
+      case 0xe4:
+      case 0xe5:
+      case 0xe6:
+      case 0xe7:
+      case 0xe8:
+      case 0xe9:
+      case 0xea:
+      case 0xeb:
+      case 0xec:
+      case 0xed:
+      case 0xee:
+      case 0xef:
+        if (mode != JpegReadMode::kReadTables) {
+          ok = ProcessAPP(data, len, &pos, jpg);
+        }
+        break;
+      case 0xfe:
+        if (mode != JpegReadMode::kReadTables) {
+          ok = ProcessCOM(data, len, &pos, jpg);
+        }
+        break;
+      default:
+        JXL_JPEG_DEBUG("Unsupported marker: %d pos=%zu len=%zu", marker, pos,
+                       len);
+        jpg->error = JPEGReadError::UNSUPPORTED_MARKER;
+        ok = false;
+        break;
+    }
+    if (!ok) {
+      return false;
+    }
+    jpg->marker_order.push_back(marker);
+    if (mode == JpegReadMode::kReadHeader && found_sof) {
+      break;
+    }
+  } while (marker != 0xd9);
+
+  if (!found_sof) {
+    JXL_JPEG_DEBUG("Missing SOF marker.");
+    jpg->error = JPEGReadError::SOF_NOT_FOUND;
+    return false;
+  }
+
+  // Supplemental checks.
+  if (mode == JpegReadMode::kReadAll) {
+    if (pos < len) {
+      jpg->tail_data = std::vector<uint8_t>(data + pos, data + len);
+    }
+    if (!FixupIndexes(jpg)) {
+      return false;
+    }
+    if (jpg->huffman_code.empty()) {
+      // Section B.2.4.2: "If a table has never been defined for a particular
+      // destination, then when this destination is specified in a scan header,
+      // the results are unpredictable."
+ JXL_JPEG_DEBUG("Need at least one Huffman code table."); + jpg->error = JPEGReadError::HUFFMAN_TABLE_ERROR; + return false; + } + if (jpg->huffman_code.size() >= kMaxDHTMarkers) { + JXL_JPEG_DEBUG("Too many Huffman tables."); + jpg->error = JPEGReadError::HUFFMAN_TABLE_ERROR; + return false; + } + } + return true; +} + +} // namespace jpeg +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.h b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.h new file mode 100644 index 000000000000..0c45f770a2ea --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.h @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Functions for reading a jpeg byte stream into a JPEGData object. + +#ifndef LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ +#define LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ + +#include +#include + +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +enum class JpegReadMode { + kReadHeader, // only basic headers + kReadTables, // headers and tables (quant, Huffman, ...) + kReadAll, // everything +}; + +// Parses the JPEG stream contained in data[*pos ... len) and fills in *jpg with +// the parsed information. +// If mode is kReadHeader, it fills in only the image dimensions in *jpg. +// Returns false if the data is not valid JPEG, or if it contains an unsupported +// JPEG feature. +bool ReadJpeg(const uint8_t* data, const size_t len, JpegReadMode mode, + JPEGData* jpg); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc new file mode 100644 index 000000000000..e8bbf4dbe546 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc @@ -0,0 +1,112 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/jpeg/enc_jpeg_huffman_decode.h" + +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +// Returns the table width of the next 2nd level table, count is the histogram +// of bit lengths for the remaining symbols, len is the code length of the next +// processed symbol. 
+static inline int NextTableBitSize(const int* count, int len) {
+  int left = 1 << (len - kJpegHuffmanRootTableBits);
+  while (len < static_cast<int>(kJpegHuffmanMaxBitLength)) {
+    left -= count[len];
+    if (left <= 0) break;
+    ++len;
+    left <<= 1;
+  }
+  return len - kJpegHuffmanRootTableBits;
+}
+
+void BuildJpegHuffmanTable(const uint32_t* count, const uint32_t* symbols,
+                           HuffmanTableEntry* lut) {
+  HuffmanTableEntry code;    // current table entry
+  HuffmanTableEntry* table;  // next available space in table
+  int len;         // current code length
+  int idx;         // symbol index
+  int key;         // prefix code
+  int reps;        // number of replicate key values in current table
+  int low;         // low bits for current root entry
+  int table_bits;  // key length of current table
+  int table_size;  // size of current table
+
+  // Make a local copy of the input bit length histogram.
+  int tmp_count[kJpegHuffmanMaxBitLength + 1] = {0};
+  int total_count = 0;
+  for (len = 1; len <= static_cast<int>(kJpegHuffmanMaxBitLength); ++len) {
+    tmp_count[len] = count[len];
+    total_count += tmp_count[len];
+  }
+
+  table = lut;
+  table_bits = kJpegHuffmanRootTableBits;
+  table_size = 1 << table_bits;
+
+  // Special case code with only one value.
+  if (total_count == 1) {
+    code.bits = 0;
+    code.value = symbols[0];
+    for (key = 0; key < table_size; ++key) {
+      table[key] = code;
+    }
+    return;
+  }
+
+  // Fill in root table.
+  key = 0;
+  idx = 0;
+  for (len = 1; len <= kJpegHuffmanRootTableBits; ++len) {
+    for (; tmp_count[len] > 0; --tmp_count[len]) {
+      code.bits = len;
+      code.value = symbols[idx++];
+      reps = 1 << (kJpegHuffmanRootTableBits - len);
+      while (reps--) {
+        table[key++] = code;
+      }
+    }
+  }
+
+  // Fill in 2nd level tables and add pointers to root table.
+  table += table_size;
+  table_size = 0;
+  low = 0;
+  for (len = kJpegHuffmanRootTableBits + 1;
+       len <= static_cast<int>(kJpegHuffmanMaxBitLength); ++len) {
+    for (; tmp_count[len] > 0; --tmp_count[len]) {
+      // Start a new sub-table if the previous one is full.
+      if (low >= table_size) {
+        table += table_size;
+        table_bits = NextTableBitSize(tmp_count, len);
+        table_size = 1 << table_bits;
+        low = 0;
+        lut[key].bits = table_bits + kJpegHuffmanRootTableBits;
+        lut[key].value = (table - lut) - key;
+        ++key;
+      }
+      code.bits = len - kJpegHuffmanRootTableBits;
+      code.value = symbols[idx++];
+      reps = 1 << (table_bits - code.bits);
+      while (reps--) {
+        table[low++] = code;
+      }
+    }
+  }
+}
+
+}  // namespace jpeg
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.h b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.h
new file mode 100644
index 000000000000..034c6a99324c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.h
@@ -0,0 +1,50 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utility function for building a Huffman lookup table for the jpeg decoder.
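A reader of the table built above first indexes with 8 root bits; entries whose `bits` field exceeds `kJpegHuffmanRootTableBits` hold an offset to a 2nd-level table instead of a symbol. A minimal sketch of the matching lookup (the `BitSource` interface with `Peek`/`Skip` is hypothetical; the decoder's actual reader is the `ReadSymbol` helper used earlier in this file):

```cpp
#include <cstdint>

struct HuffmanTableEntry {
  uint8_t bits;    // code length, or combined key length for pointer entries
  uint16_t value;  // symbol value, or offset to the 2nd-level table
};

// Hypothetical bit source with enough lookahead for one code.
struct BitSource {
  uint32_t Peek(int n);  // next n bits without consuming them
  void Skip(int n);      // consume n bits
};

int DecodeSymbol(const HuffmanTableEntry* lut, BitSource* in) {
  const int kRootBits = 8;  // kJpegHuffmanRootTableBits
  const HuffmanTableEntry* entry = lut + in->Peek(kRootBits);
  if (entry->bits > kRootBits) {
    // Root entry points at the 2nd-level table that covers all long codes
    // sharing these 8 leading bits; entry->bits is the combined key length.
    in->Skip(kRootBits);
    entry += entry->value + in->Peek(entry->bits - kRootBits);
    in->Skip(entry->bits);  // remaining (len - 8) bits of the code
    return entry->value;
  }
  in->Skip(entry->bits);  // short code: consume len bits directly
  return entry->value;
}
```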
+
+#ifndef LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_
+#define LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_
+
+#include <stdint.h>
+
+namespace jxl {
+namespace jpeg {
+
+constexpr int kJpegHuffmanRootTableBits = 8;
+// Maximum Huffman lookup table size.
+// According to zlib/examples/enough.c, 758 entries are always enough for
+// an alphabet of 257 symbols (256 + 1 special symbol for the all 1s code) and
+// max bit length 16 if the root table has 8 bits.
+constexpr int kJpegHuffmanLutSize = 758;
+
+struct HuffmanTableEntry {
+  // Initialize the value to an invalid symbol so that we can recognize it
+  // when reading the bit stream using a Huffman code with space > 0.
+  HuffmanTableEntry() : bits(0), value(0xffff) {}
+
+  uint8_t bits;    // number of bits used for this symbol
+  uint16_t value;  // symbol value or table offset
+};
+
+// Builds jpeg-style Huffman lookup table from the given symbols.
+// The symbols are in order of increasing bit lengths. The number of symbols
+// with bit length n is given in counts[n] for each n >= 1.
+void BuildJpegHuffmanTable(const uint32_t* counts, const uint32_t* symbols,
+                           HuffmanTableEntry* lut);
+
+}  // namespace jpeg
+}  // namespace jxl
+
+#endif  // LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.cc b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.cc
new file mode 100644
index 000000000000..c623f4d5cdc4
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.cc
@@ -0,0 +1,449 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/jpeg/jpeg_data.h"
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+namespace jpeg {
+
+namespace {
+enum JPEGComponentType : uint32_t {
+  kGray = 0,
+  kYCbCr = 1,
+  kRGB = 2,
+  kCustom = 3,
+};
+
+struct JPEGInfo {
+  size_t num_app_markers = 0;
+  size_t num_com_markers = 0;
+  size_t num_scans = 0;
+  size_t num_intermarker = 0;
+  bool has_dri = false;
+};
+
+Status VisitMarker(uint8_t* marker, Visitor* visitor, JPEGInfo* info) {
+  uint32_t marker32 = *marker - 0xc0;
+  JXL_RETURN_IF_ERROR(visitor->Bits(6, 0x00, &marker32));
+  *marker = marker32 + 0xc0;
+  if ((*marker & 0xf0) == 0xe0) {
+    info->num_app_markers++;
+  }
+  if (*marker == 0xfe) {
+    info->num_com_markers++;
+  }
+  if (*marker == 0xda) {
+    info->num_scans++;
+  }
+  // We use a fake 0xff marker to signal intermarker data.
+  if (*marker == 0xff) {
+    info->num_intermarker++;
+  }
+  if (*marker == 0xdd) {
+    info->has_dri = true;
+  }
+  return true;
+}
+
+}  // namespace
+
+Status JPEGData::VisitFields(Visitor* visitor) {
+  bool is_gray = components.size() == 1;
+  JXL_RETURN_IF_ERROR(visitor->Bool(false, &is_gray));
+  if (visitor->IsReading()) {
+    components.resize(is_gray ?
1 : 3); + } + JPEGInfo info; + if (visitor->IsReading()) { + uint8_t marker = 0xc0; + do { + JXL_RETURN_IF_ERROR(VisitMarker(&marker, visitor, &info)); + marker_order.push_back(marker); + if (marker_order.size() > 16384) { + return JXL_FAILURE("Too many markers: %zu\n", marker_order.size()); + } + } while (marker != 0xd9); + } else { + if (marker_order.size() > 16384) { + return JXL_FAILURE("Too many markers: %zu\n", marker_order.size()); + } + for (size_t i = 0; i < marker_order.size(); i++) { + JXL_RETURN_IF_ERROR(VisitMarker(&marker_order[i], visitor, &info)); + } + if (!marker_order.empty()) { + // Last marker should always be EOI marker. + JXL_CHECK(marker_order.back() == 0xd9); + } + } + + // Size of the APP and COM markers. + if (visitor->IsReading()) { + app_data.resize(info.num_app_markers); + app_marker_type.resize(info.num_app_markers); + com_data.resize(info.num_com_markers); + scan_info.resize(info.num_scans); + } + JXL_ASSERT(app_data.size() == info.num_app_markers); + JXL_ASSERT(app_marker_type.size() == info.num_app_markers); + JXL_ASSERT(com_data.size() == info.num_com_markers); + JXL_ASSERT(scan_info.size() == info.num_scans); + for (size_t i = 0; i < app_data.size(); i++) { + auto& app = app_data[i]; + // Encodes up to 8 different values. + JXL_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), BitsOffset(1, 2), BitsOffset(2, 4), 0, + reinterpret_cast(&app_marker_type[i]))); + if (app_marker_type[i] != AppMarkerType::kUnknown && + app_marker_type[i] != AppMarkerType::kICC && + app_marker_type[i] != AppMarkerType::kExif && + app_marker_type[i] != AppMarkerType::kXMP) { + return JXL_FAILURE("Unknown app marker type %u", + static_cast(app_marker_type[i])); + } + uint32_t len = app.size() - 1; + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) app.resize(len + 1); + if (app.size() < 3) { + return JXL_FAILURE("Invalid marker size: %zu\n", app.size()); + } + } + for (auto& com : com_data) { + uint32_t len = com.size() - 1; + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) com.resize(len + 1); + if (com.size() < 3) { + return JXL_FAILURE("Invalid marker size: %zu\n", com.size()); + } + } + + uint32_t num_quant_tables = quant.size(); + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 2, &num_quant_tables)); + if (num_quant_tables == 4) { + return JXL_FAILURE("Invalid number of quant tables"); + } + if (visitor->IsReading()) { + quant.resize(num_quant_tables); + } + for (size_t i = 0; i < num_quant_tables; i++) { + if (quant[i].precision > 1) { + return JXL_FAILURE( + "Quant tables with more than 16 bits are not supported"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(1, 0, &quant[i].precision)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, i, &quant[i].index)); + JXL_RETURN_IF_ERROR(visitor->Bool(true, &quant[i].is_last)); + } + + JPEGComponentType component_type = + components.size() == 1 && components[0].id == 1 + ? JPEGComponentType::kGray + : components.size() == 3 && components[0].id == 1 && + components[1].id == 2 && components[2].id == 3 + ? JPEGComponentType::kYCbCr + : components.size() == 3 && components[0].id == 'R' && + components[1].id == 'G' && components[2].id == 'B' + ? 
JPEGComponentType::kRGB + : JPEGComponentType::kCustom; + JXL_RETURN_IF_ERROR( + visitor->Bits(2, JPEGComponentType::kYCbCr, + reinterpret_cast(&component_type))); + uint32_t num_components; + if (component_type == JPEGComponentType::kGray) { + num_components = 1; + } else if (component_type != JPEGComponentType::kCustom) { + num_components = 3; + } else { + num_components = components.size(); + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 3, &num_components)); + if (num_components != 1 && num_components != 3) { + return JXL_FAILURE("Invalid number of components: %u", num_components); + } + } + if (visitor->IsReading()) { + components.resize(num_components); + } + if (component_type == JPEGComponentType::kCustom) { + for (size_t i = 0; i < components.size(); i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(8, 0, &components[i].id)); + } + } else if (component_type == JPEGComponentType::kGray) { + components[0].id = 1; + } else if (component_type == JPEGComponentType::kRGB) { + components[0].id = 'R'; + components[1].id = 'G'; + components[2].id = 'B'; + } else { + components[0].id = 1; + components[1].id = 2; + components[2].id = 3; + } + size_t used_tables = 0; + for (size_t i = 0; i < components.size(); i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &components[i].quant_idx)); + if (components[i].quant_idx >= quant.size()) { + return JXL_FAILURE("Invalid quant table for component %zu: %u\n", i, + components[i].quant_idx); + } + used_tables |= 1U << components[i].quant_idx; + } + if (used_tables + 1 != 1U << quant.size()) { + return JXL_FAILURE( + "Not all quant tables are used (%zu tables, %zx used table mask)", + quant.size(), used_tables); + } + + uint32_t num_huff = huffman_code.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(4), BitsOffset(3, 2), BitsOffset(4, 10), + BitsOffset(6, 26), 4, &num_huff)); + if (visitor->IsReading()) { + huffman_code.resize(num_huff); + } + for (JPEGHuffmanCode& hc : huffman_code) { + bool is_ac = hc.slot_id >> 4; + uint32_t id = hc.slot_id & 0xF; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &is_ac)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &id)); + hc.slot_id = (static_cast(is_ac) << 4) | id; + JXL_RETURN_IF_ERROR(visitor->Bool(true, &hc.is_last)); + size_t num_symbols = 0; + for (size_t i = 0; i <= 16; i++) { + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(3, 2), + Bits(8), 0, &hc.counts[i])); + num_symbols += hc.counts[i]; + } + if (num_symbols < 1) { + // Actually, at least 2 symbols are required, since one of them is EOI. + return JXL_FAILURE("Empty Huffman table"); + } + if (num_symbols > hc.values.size()) { + return JXL_FAILURE("Huffman code too large (%zu)", num_symbols); + } + // Presence flags for 4 * 64 + 1 values. + uint64_t value_slots[5] = {}; + for (size_t i = 0; i < num_symbols; i++) { + // Goes up to 256, included. Might have the same symbol appear twice... + JXL_RETURN_IF_ERROR(visitor->U32(Bits(2), BitsOffset(2, 4), + BitsOffset(4, 8), BitsOffset(8, 1), 0, + &hc.values[i])); + value_slots[hc.values[i] >> 6] |= (uint64_t)1 << (hc.values[i] & 0x3F); + } + if (hc.values[num_symbols - 1] != kJpegHuffmanAlphabetSize) { + return JXL_FAILURE("Missing EOI symbol"); + } + // Last element, denoting EOI, have to be 1 after the loop. 
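The `visitor->U32(...)` calls above use the four-way varint scheme declared in `lib/jxl/fields.h` (included by this header): a 2-bit selector picks one of four alternatives, each either a fixed value (`Val`) or an n-bit payload plus offset (`BitsOffset`); the remaining arguments are the default value and the target. A sketch of the decode side for the `num_huff` field above (the `Reader` type is a hypothetical minimal bit reader, not the library's):

```cpp
#include <cstdint>

// Hypothetical LSB-first bit reader; a real one pulls from the codestream.
struct Reader {
  uint32_t bits;
  int pos = 0;
  explicit Reader(uint32_t b) : bits(b) {}
  uint32_t ReadBits(int n) {
    uint32_t v = (bits >> pos) & ((1u << n) - 1);
    pos += n;
    return v;
  }
};

// Decodes U32(Val(4), BitsOffset(3, 2), BitsOffset(4, 10), BitsOffset(6, 26)):
// the 2-bit selector chooses the alternative, then its payload bits follow.
uint32_t DecodeNumHuff(Reader* r) {
  switch (r->ReadBits(2)) {
    case 0:  return 4;                    // Val(4): no payload bits
    case 1:  return 2 + r->ReadBits(3);   // values 2..9
    case 2:  return 10 + r->ReadBits(4);  // values 10..25
    default: return 26 + r->ReadBits(6);  // values 26..89
  }
}
```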
+ JXL_ASSERT(value_slots[4] == 1); + size_t num_values = 1; + for (size_t i = 0; i < 4; ++i) num_values += hwy::PopCount(value_slots[i]); + if (num_values != num_symbols) { + return JXL_FAILURE("Duplicate Huffman symbols"); + } + if (!is_ac) { + bool only_dc = ((value_slots[0] >> kJpegDCAlphabetSize) | value_slots[1] | + value_slots[2] | value_slots[3]) == 0; + if (!only_dc) return JXL_FAILURE("Huffman symbols out of DC range"); + } + } + + for (auto& scan : scan_info) { + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 1, &scan.num_components)); + if (scan.num_components >= 4) { + return JXL_FAILURE("Invalid number of components in SOS marker"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(6, 0, &scan.Ss)); + JXL_RETURN_IF_ERROR(visitor->Bits(6, 63, &scan.Se)); + JXL_RETURN_IF_ERROR(visitor->Bits(4, 0, &scan.Al)); + JXL_RETURN_IF_ERROR(visitor->Bits(4, 0, &scan.Ah)); + for (size_t i = 0; i < scan.num_components; i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].comp_idx)); + if (scan.components[i].comp_idx >= components.size()) { + return JXL_FAILURE("Invalid component idx in SOS marker"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].ac_tbl_idx)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].dc_tbl_idx)); + } + // TODO(veluca): actually set and use this value. + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), Val(2), BitsOffset(3, 3), + kMaxNumPasses - 1, + &scan.last_needed_pass)); + } + + // From here on, this is data that is not strictly necessary to get a valid + // JPEG, but necessary for bit-exact JPEG reconstruction. + if (info.has_dri) { + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &restart_interval)); + } + + uint64_t padding_spot_limit = scan_info.size(); + + for (auto& scan : scan_info) { + uint32_t num_reset_points = scan.reset_points.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(2, 1), BitsOffset(4, 4), + BitsOffset(16, 20), 0, &num_reset_points)); + if (visitor->IsReading()) { + scan.reset_points.resize(num_reset_points); + } + int last_block_idx = -1; + for (auto& block_idx : scan.reset_points) { + block_idx -= last_block_idx + 1; + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(3, 1), + BitsOffset(5, 9), BitsOffset(28, 41), 0, + &block_idx)); + block_idx += last_block_idx + 1; + if (static_cast(block_idx) < last_block_idx + 1) { + return JXL_FAILURE("Invalid block ID: %u, last block was %d", block_idx, + last_block_idx); + } + // TODO(eustas): better upper boundary could be given at this point; also + // it could be applied during reset_points reading. + if (block_idx > (1u << 30)) { + // At most 8K x 8K x num_channels blocks are expected. That is, + // typically, 1.5 * 2^27. 2^30 should be sufficient for any sane + // image. 
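`reset_points` above is stored as gaps: each block index is reduced by `last_block_idx + 1` before serialization, so a strictly increasing sequence becomes a list of small non-negative deltas that fit the cheap `Val(0)`/`BitsOffset` buckets. A round-trip sketch of that transform:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<uint32_t> reset_points = {3, 4, 10, 11};

  // Encode: strictly increasing indices -> gaps (0 when consecutive).
  std::vector<uint32_t> gaps;
  int last = -1;
  for (uint32_t idx : reset_points) {
    gaps.push_back(idx - (last + 1));
    last = idx;
  }
  assert((gaps == std::vector<uint32_t>{3, 0, 5, 0}));

  // Decode: reverse the transform, as in the visitor loop above.
  last = -1;
  for (uint32_t& g : gaps) {
    g += last + 1;
    last = g;
  }
  assert(gaps == reset_points);
  return 0;
}
```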
+ return JXL_FAILURE("Invalid block ID: %u", block_idx); + } + last_block_idx = block_idx; + } + + uint32_t num_extra_zero_runs = scan.extra_zero_runs.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(2, 1), BitsOffset(4, 4), + BitsOffset(16, 20), 0, + &num_extra_zero_runs)); + if (visitor->IsReading()) { + scan.extra_zero_runs.resize(num_extra_zero_runs); + } + last_block_idx = -1; + for (size_t i = 0; i < scan.extra_zero_runs.size(); ++i) { + uint32_t& block_idx = scan.extra_zero_runs[i].block_idx; + JXL_RETURN_IF_ERROR(visitor->U32( + Val(1), BitsOffset(2, 2), BitsOffset(4, 5), BitsOffset(8, 20), 1, + &scan.extra_zero_runs[i].num_extra_zero_runs)); + block_idx -= last_block_idx + 1; + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(3, 1), + BitsOffset(5, 9), BitsOffset(28, 41), 0, + &block_idx)); + block_idx += last_block_idx + 1; + if (static_cast(block_idx) < last_block_idx + 1) { + return JXL_FAILURE("Invalid block ID: %u, last block was %d", block_idx, + last_block_idx); + } + if (block_idx > (1u << 30)) { + // At most 8K x 8K x num_channels blocks are expected. That is, + // typically, 1.5 * 2^27. 2^30 should be sufficient for any sane + // image. + return JXL_FAILURE("Invalid block ID: %u", block_idx); + } + last_block_idx = block_idx; + } + + if (restart_interval > 0) { + int MCUs_per_row = 0; + int MCU_rows = 0; + CalculateMcuSize(scan, &MCUs_per_row, &MCU_rows); + padding_spot_limit += DivCeil(MCU_rows * MCUs_per_row, restart_interval); + } + } + std::vector inter_marker_data_sizes; + inter_marker_data_sizes.reserve(info.num_intermarker); + for (size_t i = 0; i < info.num_intermarker; ++i) { + uint32_t len = visitor->IsReading() ? 0 : inter_marker_data[i].size(); + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) inter_marker_data_sizes.emplace_back(len); + } + uint32_t tail_data_len = tail_data.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(8, 1), + BitsOffset(16, 257), BitsOffset(22, 65793), + 0, &tail_data_len)); + + JXL_RETURN_IF_ERROR(visitor->Bool(false, &has_zero_padding_bit)); + if (has_zero_padding_bit) { + uint32_t nbit = padding_bits.size(); + JXL_RETURN_IF_ERROR(visitor->Bits(24, 0, &nbit)); + if (nbit > 7 * padding_spot_limit) { + return JXL_FAILURE("Number of padding bits does not correspond to image"); + } + // TODO(eustas): check that that much bits of input are available. + if (visitor->IsReading()) { + padding_bits.resize(nbit); + } + // TODO(eustas): read in (8-64?) bit groups to reduce overhead. + for (uint8_t& bit : padding_bits) { + bool bbit = bit; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &bbit)); + bit = bbit; + } + } + + // Apply postponed actions. + if (visitor->IsReading()) { + tail_data.resize(tail_data_len); + JXL_ASSERT(inter_marker_data_sizes.size() == info.num_intermarker); + inter_marker_data.reserve(info.num_intermarker); + for (size_t i = 0; i < info.num_intermarker; ++i) { + inter_marker_data.emplace_back(inter_marker_data_sizes[i]); + } + } + + return true; +} + +void JPEGData::CalculateMcuSize(const JPEGScanInfo& scan, int* MCUs_per_row, + int* MCU_rows) const { + const bool is_interleaved = (scan.num_components > 1); + const JPEGComponent& base_component = components[scan.components[0].comp_idx]; + // h_group / v_group act as numerators for converting number of blocks to + // number of MCU. In interleaved mode it is 1, so MCU is represented with + // max_*_samp_factor blocks. 
In non-interleaved mode we choose the numerator to
+  // be the sampling factor, so the MCU is always represented by a single
+  // block.
+  const int h_group = is_interleaved ? 1 : base_component.h_samp_factor;
+  const int v_group = is_interleaved ? 1 : base_component.v_samp_factor;
+  int max_h_samp_factor = 1;
+  int max_v_samp_factor = 1;
+  for (const auto& c : components) {
+    max_h_samp_factor = std::max(c.h_samp_factor, max_h_samp_factor);
+    max_v_samp_factor = std::max(c.v_samp_factor, max_v_samp_factor);
+  }
+  *MCUs_per_row = DivCeil(width * h_group, 8 * max_h_samp_factor);
+  *MCU_rows = DivCeil(height * v_group, 8 * max_v_samp_factor);
+}
+
+Status SetJPEGDataFromICC(const PaddedBytes& icc, jpeg::JPEGData* jpeg_data) {
+  size_t icc_pos = 0;
+  for (size_t i = 0; i < jpeg_data->app_data.size(); i++) {
+    if (jpeg_data->app_marker_type[i] != jpeg::AppMarkerType::kICC) {
+      continue;
+    }
+    size_t len = jpeg_data->app_data[i].size() - 17;
+    if (icc_pos + len > icc.size()) {
+      return JXL_FAILURE(
+          "ICC length is less than APP markers: requested %zu more bytes, "
+          "%zu available",
+          len, icc.size() - icc_pos);
+    }
+    memcpy(&jpeg_data->app_data[i][17], icc.data() + icc_pos, len);
+    icc_pos += len;
+  }
+  if (icc_pos != icc.size() && icc_pos != 0) {
+    return JXL_FAILURE("ICC length is more than APP markers");
+  }
+  return true;
+}
+
+}  // namespace jpeg
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.h b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.h
new file mode 100644
index 000000000000..b47c1e948e7d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.h
@@ -0,0 +1,263 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Data structures that represent the non-pixel contents of a jpeg file.
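A worked example of `CalculateMcuSize` above, assuming a hypothetical 100x35 image with 4:2:0 chroma subsampling (luma sampled 2x2, chroma 1x1):

```cpp
#include <cassert>

static int DivCeil(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int width = 100, height = 35;
  const int max_h = 2, max_v = 2;  // 4:2:0: luma holds the max factors

  // Interleaved scan (num_components > 1): h_group = v_group = 1, so one
  // MCU spans max_h*8 x max_v*8 = 16x16 pixels.
  assert(DivCeil(width * 1, 8 * max_h) == 7);   // MCUs per row
  assert(DivCeil(height * 1, 8 * max_v) == 3);  // MCU rows

  // Non-interleaved luma scan: h_group/v_group are the component's own
  // factors (2), so the MCU shrinks to a single 8x8 block.
  assert(DivCeil(width * 2, 8 * max_h) == 13);
  assert(DivCeil(height * 2, 8 * max_v) == 5);
  return 0;
}
```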
+
+#ifndef LIB_JXL_JPEG_JPEG_DATA_H_
+#define LIB_JXL_JPEG_JPEG_DATA_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <array>
+#include <vector>
+
+#include "lib/jxl/common.h"
+#include "lib/jxl/fields.h"
+
+namespace jxl {
+namespace jpeg {
+
+constexpr int kMaxComponents = 4;
+constexpr int kMaxQuantTables = 4;
+constexpr int kMaxHuffmanTables = 4;
+constexpr size_t kJpegHuffmanMaxBitLength = 16;
+constexpr int kJpegHuffmanAlphabetSize = 256;
+constexpr int kJpegDCAlphabetSize = 12;
+constexpr int kMaxDHTMarkers = 512;
+constexpr int kMaxDimPixels = 65535;
+constexpr uint8_t kApp1 = 0xE1;
+constexpr uint8_t kApp2 = 0xE2;
+const uint8_t kIccProfileTag[12] = "ICC_PROFILE";
+const uint8_t kExifTag[6] = "Exif\0";
+const uint8_t kXMPTag[29] = "http://ns.adobe.com/xap/1.0/";
+
+/* clang-format off */
+constexpr uint32_t kJPEGNaturalOrder[80] = {
+  0,   1,  8, 16,  9,  2,  3, 10,
+  17, 24, 32, 25, 18, 11,  4,  5,
+  12, 19, 26, 33, 40, 48, 41, 34,
+  27, 20, 13,  6,  7, 14, 21, 28,
+  35, 42, 49, 56, 57, 50, 43, 36,
+  29, 22, 15, 23, 30, 37, 44, 51,
+  58, 59, 52, 45, 38, 31, 39, 46,
+  53, 60, 61, 54, 47, 55, 62, 63,
+  // extra entries for safety in decoder
+  63, 63, 63, 63, 63, 63, 63, 63,
+  63, 63, 63, 63, 63, 63, 63, 63
+};
+
+constexpr uint32_t kJPEGZigZagOrder[64] = {
+  0,   1,  5,  6, 14, 15, 27, 28,
+  2,   4,  7, 13, 16, 26, 29, 42,
+  3,   8, 12, 17, 25, 30, 41, 43,
+  9,  11, 18, 24, 31, 40, 44, 53,
+  10, 19, 23, 32, 39, 45, 52, 54,
+  20, 22, 33, 38, 46, 51, 55, 60,
+  21, 34, 37, 47, 50, 56, 59, 61,
+  35, 36, 48, 49, 57, 58, 62, 63
+};
+/* clang-format on */
+
+enum struct JPEGReadError {
+  OK = 0,
+  SOI_NOT_FOUND,
+  SOF_NOT_FOUND,
+  UNEXPECTED_EOF,
+  MARKER_BYTE_NOT_FOUND,
+  UNSUPPORTED_MARKER,
+  WRONG_MARKER_SIZE,
+  INVALID_PRECISION,
+  INVALID_WIDTH,
+  INVALID_HEIGHT,
+  INVALID_NUMCOMP,
+  INVALID_SAMP_FACTOR,
+  INVALID_START_OF_SCAN,
+  INVALID_END_OF_SCAN,
+  INVALID_SCAN_BIT_POSITION,
+  INVALID_COMPS_IN_SCAN,
+  INVALID_HUFFMAN_INDEX,
+  INVALID_QUANT_TBL_INDEX,
+  INVALID_QUANT_VAL,
+  INVALID_MARKER_LEN,
+  INVALID_SAMPLING_FACTORS,
+  INVALID_HUFFMAN_CODE,
+  INVALID_SYMBOL,
+  NON_REPRESENTABLE_DC_COEFF,
+  NON_REPRESENTABLE_AC_COEFF,
+  INVALID_SCAN,
+  OVERLAPPING_SCANS,
+  INVALID_SCAN_ORDER,
+  EXTRA_ZERO_RUN,
+  DUPLICATE_DRI,
+  DUPLICATE_SOF,
+  WRONG_RESTART_MARKER,
+  DUPLICATE_COMPONENT_ID,
+  COMPONENT_NOT_FOUND,
+  HUFFMAN_TABLE_NOT_FOUND,
+  HUFFMAN_TABLE_ERROR,
+  QUANT_TABLE_NOT_FOUND,
+  EMPTY_DHT,
+  EMPTY_DQT,
+  OUT_OF_BAND_COEFF,
+  EOB_RUN_TOO_LONG,
+  IMAGE_TOO_LARGE,
+  INVALID_QUANT_TBL_PRECISION,
+};
+
+// Quantization values for an 8x8 pixel block.
+struct JPEGQuantTable {
+  std::array<int32_t, kDCTBlockSize> values;
+  uint32_t precision = 0;
+  // The index of this quantization table as it was parsed from the input
+  // JPEG. Each DQT marker segment contains an 'index' field, and we save
+  // this index here. Valid values are 0 to 3.
+  uint32_t index = 0;
+  // Set to true if this table is the last one within its marker segment.
+  bool is_last = true;
+};
+
+// Huffman code and decoding lookup table used for DC and AC coefficients.
+struct JPEGHuffmanCode {
+  // Bit length histogram.
+  std::array<uint32_t, kJpegHuffmanMaxBitLength + 1> counts = {};
+  // Symbol values sorted by increasing bit lengths.
+  std::array<uint32_t, kJpegHuffmanAlphabetSize + 1> values = {};
+  // The index of the Huffman code in the current set of Huffman codes. For
+  // AC component Huffman codes, 0x10 is added to the index.
+  int slot_id = 0;
+  // Set to true if this Huffman code is the last one within its marker
+  // segment.
+  bool is_last = true;
+};
+
+// Huffman table indexes used for one component of one scan.
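`kJPEGNaturalOrder` and `kJPEGZigZagOrder` above are mutual inverses over their first 64 entries (the natural table's tail is decoder-safety padding). A small sketch of that check, written as a standalone function taking the tables as parameters:

```cpp
#include <cstdint>

// Verifies that a natural-order table and a zig-zag-order table invert each
// other: zig-zag position k lands on raster cell natural[k], and mapping
// that cell back through zigzag must return k.
bool OrdersAreInverse(const uint32_t* natural, const uint32_t* zigzag) {
  for (uint32_t k = 0; k < 64; ++k) {
    if (zigzag[natural[k]] != k) return false;
  }
  return true;
}
```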
+struct JPEGComponentScanInfo {
+  uint32_t comp_idx;
+  uint32_t dc_tbl_idx;
+  uint32_t ac_tbl_idx;
+};
+
+// Contains information that is used in one scan.
+struct JPEGScanInfo {
+  // Parameters used for progressive scans (named the same way as in the
+  // spec):
+  //   Ss : Start of spectral band in zig-zag sequence.
+  //   Se : End of spectral band in zig-zag sequence.
+  //   Ah : Successive approximation bit position, high.
+  //   Al : Successive approximation bit position, low.
+  uint32_t Ss;
+  uint32_t Se;
+  uint32_t Ah;
+  uint32_t Al;
+  uint32_t num_components = 0;
+  std::array<JPEGComponentScanInfo, 4> components;
+  // Last codestream pass that is needed to write this scan.
+  uint32_t last_needed_pass = 0;
+
+  // Extra information required for bit-precise JPEG file reconstruction.
+
+  // Set of block indexes where the JPEG encoder has to flush the
+  // end-of-block runs and refinement bits.
+  std::vector<uint32_t> reset_points;
+  // The number of extra zero runs (Huffman symbol 0xf0) before the end of
+  // block (if nonzero), indexed by block index.
+  // All of these symbols can be omitted without changing the pixel values,
+  // but some jpeg encoders put these at the end of blocks.
+  typedef struct {
+    uint32_t block_idx;
+    uint32_t num_extra_zero_runs;
+  } ExtraZeroRunInfo;
+  std::vector<ExtraZeroRunInfo> extra_zero_runs;
+};
+
+typedef int16_t coeff_t;
+
+// Represents one component of a jpeg file.
+struct JPEGComponent {
+  JPEGComponent()
+      : id(0),
+        h_samp_factor(1),
+        v_samp_factor(1),
+        quant_idx(0),
+        width_in_blocks(0),
+        height_in_blocks(0) {}
+
+  // One-byte id of the component.
+  uint32_t id;
+  // Horizontal and vertical sampling factors.
+  // In interleaved mode, each minimal coded unit (MCU) has
+  // h_samp_factor x v_samp_factor DCT blocks from this component.
+  int h_samp_factor;
+  int v_samp_factor;
+  // The index of the quantization table used for this component.
+  uint32_t quant_idx;
+  // The dimensions of the component measured in 8x8 blocks.
+  uint32_t width_in_blocks;
+  uint32_t height_in_blocks;
+  // The DCT coefficients of this component, laid out block-by-block, divided
+  // through the quantization matrix values.
+  std::vector<coeff_t> coeffs;
+};
+
+enum class AppMarkerType : uint32_t {
+  kUnknown = 0,
+  kICC = 1,
+  kExif = 2,
+  kXMP = 3,
+};
+
+// Represents a parsed jpeg file.
+struct JPEGData : public Fields {
+  JPEGData()
+      : width(0),
+        height(0),
+        restart_interval(0),
+        error(JPEGReadError::OK),
+        has_zero_padding_bit(false) {}
+
+  const char* Name() const override { return "JPEGData"; }
+  // Doesn't serialize everything - skips brotli-encoded data and what is
+  // already encoded in the codestream.
+  Status VisitFields(Visitor* visitor) override;
+
+  void CalculateMcuSize(const JPEGScanInfo& scan, int* MCUs_per_row,
+                        int* MCU_rows) const;
+
+  int width;
+  int height;
+  uint32_t restart_interval;
+  std::vector<std::vector<uint8_t>> app_data;
+  std::vector<AppMarkerType> app_marker_type;
+  std::vector<std::vector<uint8_t>> com_data;
+  std::vector<JPEGQuantTable> quant;
+  std::vector<JPEGHuffmanCode> huffman_code;
+  std::vector<JPEGComponent> components;
+  std::vector<JPEGScanInfo> scan_info;
+  std::vector<uint8_t> marker_order;
+  std::vector<std::vector<uint8_t>> inter_marker_data;
+  std::vector<uint8_t> tail_data;
+  JPEGReadError error;
+
+  // Extra information required for bit-precise JPEG file reconstruction.
+
+  bool has_zero_padding_bit;
+  std::vector<uint8_t> padding_bits;
+};
+
+// Set ICC profile in jpeg_data.
+Status SetJPEGDataFromICC(const PaddedBytes& icc, jpeg::JPEGData* jpeg_data);
+
+}  // namespace jpeg
+}  // namespace jxl
+
+#endif  // LIB_JXL_JPEG_JPEG_DATA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/jxl.syms b/third_party/jpeg-xl/lib/jxl/jxl.syms
new file mode 100644
index 000000000000..0f398d71519f
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jxl.syms
@@ -0,0 +1,5 @@
+{
+  extern "C" {
+    jpegxl_*;
+  };
+};
diff --git a/third_party/jpeg-xl/lib/jxl/jxl.version b/third_party/jpeg-xl/lib/jxl/jxl.version
new file mode 100644
index 000000000000..e0ed12be2507
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jxl.version
@@ -0,0 +1,7 @@
+JXL_0 {
+  global:
+    Jxl*;
+
+  local:
+    *;
+};
diff --git a/third_party/jpeg-xl/lib/jxl/jxl_inspection.h b/third_party/jpeg-xl/lib/jxl/jxl_inspection.h
new file mode 100644
index 000000000000..d3c481a12fee
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jxl_inspection.h
@@ -0,0 +1,31 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_JXL_INSPECTION_H_
+#define LIB_JXL_JXL_INSPECTION_H_
+
+#include <functional>
+
+#include "lib/jxl/image.h"
+
+namespace jxl {
+// Type of the inspection-callback which, if enabled, will be called on
+// various intermediate data during image processing, allowing inspection
+// access.
+//
+// Returns false if processing can be stopped at that point, true otherwise.
+// This is only advisory - it is always OK to just continue processing.
+using InspectorImage3F = std::function<bool(const char*, const Image3F&)>;
+}  // namespace jxl
+
+#endif  // LIB_JXL_JXL_INSPECTION_H_
diff --git a/third_party/jpeg-xl/lib/jxl/jxl_osx.syms b/third_party/jpeg-xl/lib/jxl/jxl_osx.syms
new file mode 100644
index 000000000000..96bc568025d6
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jxl_osx.syms
@@ -0,0 +1 @@
+_Jxl*
diff --git a/third_party/jpeg-xl/lib/jxl/jxl_test.cc b/third_party/jpeg-xl/lib/jxl/jxl_test.cc
new file mode 100644
index 000000000000..3141a4e887b7
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/jxl_test.cc
@@ -0,0 +1,1401 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
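The `InspectorImage3F` callback declared in `jxl_inspection.h` above can be used, for example, to log or veto intermediate planes during decoding. A hypothetical usage sketch (the decoder hook it would be handed to is illustrative, not an API added by this patch):

```cpp
#include <cstdio>

#include "lib/jxl/jxl_inspection.h"

int main() {
  // Log each intermediate image stage; returning true lets processing
  // continue, false signals (advisorily) that it may stop early.
  jxl::InspectorImage3F inspector = [](const char* stage,
                                       const jxl::Image3F& img) {
    std::printf("stage %s: %zu x %zu\n", stage, img.xsize(), img.ysize());
    return true;
  };
  (void)inspector;  // would be registered with a decoder-side hook
  return 0;
}
```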
+ +#include +#include + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/extras/codec_jpg.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_file.h" +#include "lib/jxl/dec_params.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" +#include "tools/box/box.h" + +namespace jxl { +namespace { +using test::Roundtrip; + +#define JXL_TEST_NL 0 // Disabled in code + +void CreateImage1x1(CodecInOut* io) { + Image3F image(1, 1); + ZeroFillImage(&image); + io->metadata.m.SetUintSamples(8); + io->metadata.m.color_encoding = ColorEncoding::SRGB(); + io->SetFromImage(std::move(image), io->metadata.m.color_encoding); +} + +TEST(JxlTest, HeaderSize) { + CodecInOut io; + CreateImage1x1(&io); + + CompressParams cparams; + cparams.butteraugli_distance = 1.5; + DecompressParams dparams; + ThreadPool* pool = nullptr; + + { + CodecInOut io2; + AuxOut aux_out; + Roundtrip(&io, cparams, dparams, pool, &io2, &aux_out); + EXPECT_LE(aux_out.layers[kLayerHeader].total_bits, 34); + } + + { + CodecInOut io2; + io.metadata.m.SetAlphaBits(8); + ImageF alpha(1, 1); + alpha.Row(0)[0] = 1; + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + AuxOut aux_out; + Roundtrip(&io, cparams, dparams, pool, &io2, &aux_out); + EXPECT_LE(aux_out.layers[kLayerHeader].total_bits, 57); + } +} + +TEST(JxlTest, RoundtripSinglePixel) { + CodecInOut io; + CreateImage1x1(&io); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + ThreadPool* pool = nullptr; + CodecInOut io2; + Roundtrip(&io, cparams, dparams, pool, &io2); +} + +// Changing serialized signature causes Decode to fail. 
+#ifndef JXL_CRASH_ON_ERROR +TEST(JxlTest, RoundtripMarker) { + CodecInOut io; + CreateImage1x1(&io); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + AuxOut* aux_out = nullptr; + ThreadPool* pool = nullptr; + + PassesEncoderState enc_state; + for (size_t i = 0; i < 2; ++i) { + PaddedBytes compressed; + EXPECT_TRUE( + EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + compressed[i] ^= 0xFF; + CodecInOut io2; + EXPECT_FALSE(DecodeFile(dparams, compressed, &io2, pool)); + } +} +#endif + +TEST(JxlTest, RoundtripTinyFast) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(32, 32); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 4.0f; + DecompressParams dparams; + + CodecInOut io2; + const size_t enc_bytes = Roundtrip(&io, cparams, dparams, pool, &io2); + printf("32x32 image size %zu bytes\n", enc_bytes); +} + +TEST(JxlTest, RoundtripSmallD1) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + + CodecInOut io_out; + size_t compressed_size; + + { + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + compressed_size = Roundtrip(&io, cparams, dparams, pool, &io_out); + EXPECT_LE(compressed_size, 1000); + EXPECT_LE(ButteraugliDistance(io, io_out, cparams.ba_params, + /*distmap=*/nullptr, pool), + 1.5); + } + + { + // And then, with a lower intensity target than the default, the bitrate + // should be smaller. + CodecInOut io_dim; + io_dim.target_nits = 100; + ASSERT_TRUE(SetFromBytes(Span(orig), &io_dim, pool)); + io_dim.ShrinkTo(io_dim.xsize() / 8, io_dim.ysize() / 8); + EXPECT_LT(Roundtrip(&io_dim, cparams, dparams, pool, &io_out), + compressed_size); + EXPECT_LE(ButteraugliDistance(io_dim, io_out, cparams.ba_params, + /*distmap=*/nullptr, pool), + 1.5); + EXPECT_EQ(io_dim.metadata.m.IntensityTarget(), + io_out.metadata.m.IntensityTarget()); + } +} + +TEST(JxlTest, RoundtripOtherTransforms) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/64px/a2d1un_nkitzmiller_srgb8.png"); + std::unique_ptr io = jxl::make_unique(); + ASSERT_TRUE(SetFromBytes(Span(orig), io.get(), pool)); + + CompressParams cparams; + // Slow modes access linear image for adaptive quant search + cparams.speed_tier = SpeedTier::kKitten; + cparams.color_transform = ColorTransform::kNone; + cparams.butteraugli_distance = 5.0f; + DecompressParams dparams; + + std::unique_ptr io2 = jxl::make_unique(); + const size_t compressed_size = + Roundtrip(io.get(), cparams, dparams, pool, io2.get()); + EXPECT_LE(compressed_size, 23000); + EXPECT_LE(ButteraugliDistance(*io, *io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 6); + + // Check the consistency when performing another roundtrip. 
+ std::unique_ptr io3 = jxl::make_unique(); + const size_t compressed_size2 = + Roundtrip(io.get(), cparams, dparams, pool, io3.get()); + EXPECT_LE(compressed_size2, 23000); + EXPECT_LE(ButteraugliDistance(*io, *io3, cparams.ba_params, + /*distmap=*/nullptr, pool), + 6); +} + +TEST(JxlTest, RoundtripResample2) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize(), io.ysize()); + CompressParams cparams; + cparams.resampling = 2; + DecompressParams dparams; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 15777); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 13.5); +} + +TEST(JxlTest, RoundtripResample4) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize(), io.ysize()); + CompressParams cparams; + cparams.resampling = 4; + DecompressParams dparams; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 6000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 28); +} + +TEST(JxlTest, RoundtripResample8) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize(), io.ysize()); + CompressParams cparams; + cparams.resampling = 8; + DecompressParams dparams; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 2100); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 80); +} + +TEST(JxlTest, RoundtripUnalignedD2) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 12, io.ysize() / 7); + + CompressParams cparams; + cparams.butteraugli_distance = 2.0; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 700); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 3.2); +} + +#if JXL_TEST_NL + +TEST(JxlTest, RoundtripMultiGroupNL) { + ThreadPoolInternal pool(4); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + io.ShrinkTo(600, 1024); // partial X, full Y group + + CompressParams cparams; + DecompressParams dparams; + + cparams.fast_mode = true; + cparams.butteraugli_distance = 1.0f; + CodecInOut io2; + Roundtrip(&io, cparams, dparams, &pool, &io2); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool), + 0.9f); + + cparams.butteraugli_distance = 2.0f; + CodecInOut io3; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io3), 80000); + EXPECT_LE(ButteraugliDistance(io, io3, cparams.ba_params, + /*distmap=*/nullptr, &pool), + 1.5f); +} + +#endif + +TEST(JxlTest, RoundtripMultiGroup) { + ThreadPoolInternal pool(4); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + io.ShrinkTo(600, 1024); + + CompressParams cparams; + DecompressParams 
dparams; + + cparams.butteraugli_distance = 1.0f; + cparams.speed_tier = SpeedTier::kKitten; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 40000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool), + 1.99f); + + cparams.butteraugli_distance = 2.0f; + CodecInOut io3; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io3), 22100); + EXPECT_LE(ButteraugliDistance(io, io3, cparams.ba_params, + /*distmap=*/nullptr, &pool), + 3.0f); +} + +TEST(JxlTest, RoundtripLargeFast) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 265000); +} + +TEST(JxlTest, RoundtripDotsForceEpf) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.epf = 2; + cparams.dots = Override::kOn; + cparams.speed_tier = SpeedTier::kSquirrel; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 265000); +} + +// Checks for differing size/distance in two consecutive runs of distance 2, +// which involves additional processing including adaptive reconstruction. +// Failing this may be a sign of race conditions or invalid memory accesses. +TEST(JxlTest, RoundtripD2Consistent) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 2.0; + DecompressParams dparams; + + // Try each xsize mod kBlockDim to verify right border handling. + for (size_t xsize = 48; xsize > 40; --xsize) { + io.ShrinkTo(xsize, 15); + + CodecInOut io2; + const size_t size2 = Roundtrip(&io, cparams, dparams, &pool, &io2); + + CodecInOut io3; + const size_t size3 = Roundtrip(&io, cparams, dparams, &pool, &io3); + + // Exact same compressed size. + EXPECT_EQ(size2, size3); + + // Exact same distance. + const float dist2 = ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool); + const float dist3 = ButteraugliDistance(io, io3, cparams.ba_params, + /*distmap=*/nullptr, &pool); + EXPECT_EQ(dist2, dist3); + } +} + +// Same as above, but for full image, testing multiple groups. +TEST(JxlTest, RoundtripLargeConsistent) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 2.0; + DecompressParams dparams; + + // Try each xsize mod kBlockDim to verify right border handling. + CodecInOut io2; + const size_t size2 = Roundtrip(&io, cparams, dparams, &pool, &io2); + + CodecInOut io3; + const size_t size3 = Roundtrip(&io, cparams, dparams, &pool, &io3); + + // Exact same compressed size. + EXPECT_EQ(size2, size3); + + // Exact same distance. 
+ const float dist2 = ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool); + const float dist3 = ButteraugliDistance(io, io3, cparams.ba_params, + /*distmap=*/nullptr, &pool); + EXPECT_EQ(dist2, dist3); +} + +#if JXL_TEST_NL + +TEST(JxlTest, RoundtripSmallNL) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 1500); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 1.7); +} + +#endif + +TEST(JxlTest, RoundtripNoGaborishNoAR) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + CompressParams cparams; + cparams.gaborish = Override::kOff; + cparams.epf = 0; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 40000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 2.5); +} + +TEST(JxlTest, RoundtripSmallNoGaborish) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.gaborish = Override::kOff; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 900); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 1.7); +} + +TEST(JxlTest, RoundtripSmallPatchesAlpha) { + ThreadPool* pool = nullptr; + CodecInOut io; + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + Image3F black_with_small_lines(256, 256); + ImageF alpha(black_with_small_lines.xsize(), black_with_small_lines.ysize()); + ZeroFillImage(&black_with_small_lines); + // This pattern should be picked up by the patch detection heuristics. + for (size_t y = 0; y < black_with_small_lines.ysize(); y++) { + float* JXL_RESTRICT row = black_with_small_lines.PlaneRow(1, y); + for (size_t x = 0; x < black_with_small_lines.xsize(); x++) { + if (x % 4 == 0 && (y / 32) % 4 == 0) row[x] = 127.0f; + } + } + io.metadata.m.SetAlphaBits(8); + io.SetFromImage(std::move(black_with_small_lines), + ColorEncoding::LinearSRGB()); + FillImage(1.0f, &alpha); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 0.1f; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 2000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 0.5f); +} + +TEST(JxlTest, RoundtripSmallPatches) { + ThreadPool* pool = nullptr; + CodecInOut io; + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + Image3F black_with_small_lines(256, 256); + ZeroFillImage(&black_with_small_lines); + // This pattern should be picked up by the patch detection heuristics. 
+ for (size_t y = 0; y < black_with_small_lines.ysize(); y++) { + float* JXL_RESTRICT row = black_with_small_lines.PlaneRow(1, y); + for (size_t x = 0; x < black_with_small_lines.xsize(); x++) { + if (x % 4 == 0 && (y / 32) % 4 == 0) row[x] = 127.0f; + } + } + io.SetFromImage(std::move(black_with_small_lines), + ColorEncoding::LinearSRGB()); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.butteraugli_distance = 0.1f; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 2000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 0.5f); +} + +// Test header encoding of original bits per sample +TEST(JxlTest, RoundtripImageBundleOriginalBits) { + ThreadPool* pool = nullptr; + + // Image does not matter, only io.metadata.m and io2.metadata.m are tested. + Image3F image(1, 1); + ZeroFillImage(&image); + CodecInOut io; + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + io.SetFromImage(std::move(image), ColorEncoding::LinearSRGB()); + + CompressParams cparams; + DecompressParams dparams; + + // Test unsigned integers from 1 to 32 bits + for (uint32_t bit_depth = 1; bit_depth <= 32; bit_depth++) { + if (bit_depth == 32) { + // TODO(lode): allow testing 32, however the code below ends up in + // enc_modular which does not support 32. We only want to test the header + // encoding though, so try without modular. + break; + } + + io.metadata.m.SetUintSamples(bit_depth); + CodecInOut io2; + Roundtrip(&io, cparams, dparams, pool, &io2); + + EXPECT_EQ(bit_depth, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io2.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_EQ(0, io2.metadata.m.GetAlphaBits()); + } + + // Test various existing and non-existing floating point formats + for (uint32_t bit_depth = 8; bit_depth <= 32; bit_depth++) { + if (bit_depth != 32) { + // TODO: test other float types once they work + break; + } + + uint32_t exponent_bit_depth; + if (bit_depth < 10) { + exponent_bit_depth = 2; + } else if (bit_depth < 12) { + exponent_bit_depth = 3; + } else if (bit_depth < 16) { + exponent_bit_depth = 4; + } else if (bit_depth < 20) { + exponent_bit_depth = 5; + } else if (bit_depth < 24) { + exponent_bit_depth = 6; + } else if (bit_depth < 28) { + exponent_bit_depth = 7; + } else { + exponent_bit_depth = 8; + } + + io.metadata.m.bit_depth.bits_per_sample = bit_depth; + io.metadata.m.bit_depth.floating_point_sample = true; + io.metadata.m.bit_depth.exponent_bits_per_sample = exponent_bit_depth; + + CodecInOut io2; + Roundtrip(&io, cparams, dparams, pool, &io2); + + EXPECT_EQ(bit_depth, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_TRUE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(exponent_bit_depth, + io2.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_EQ(0, io2.metadata.m.GetAlphaBits()); + } +} + +TEST(JxlTest, RoundtripGrayscale) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_NE(io.xsize(), 0); + io.ShrinkTo(128, 128); + EXPECT_TRUE(io.Main().IsGray()); + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); 
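+  // NOTE(editor): butteraugli_distance is the encoder's perceptual error
+  // target (1.0 is commonly treated as roughly "visually lossless"); the
+  // two blocks below encode at targets 1.0 and 8.0, which is why their
+  // allowed compressed sizes and ButteraugliDistance bounds differ so much.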
+ + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + + { + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + + PaddedBytes compressed; + EXPECT_TRUE( + EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, pool)); + EXPECT_TRUE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 7000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 1.7777777); + } + + // Test with larger butteraugli distance and other settings enabled so + // different jxl codepaths trigger. + { + CompressParams cparams; + cparams.butteraugli_distance = 8.0; + DecompressParams dparams; + + PaddedBytes compressed; + EXPECT_TRUE( + EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, pool)); + EXPECT_TRUE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 1300); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 9.0); + } +} + +TEST(JxlTest, RoundtripAlpha) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + io.ShrinkTo(128, 128); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 5500); + + // TODO(robryk): Fix the following line in presence of different alpha_bits in + // the two contexts. + // EXPECT_TRUE(SamePixels(io.Main().alpha(), io2.Main().alpha())); + // TODO(robryk): Fix the distance estimate used in the encoder. + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 6.3); +} + +TEST(JxlTest, RoundtripAlphaNonMultipleOf8) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + io.ShrinkTo(12, 12); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + DecompressParams dparams; + + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 200); + + // TODO(robryk): Fix the following line in presence of different alpha_bits in + // the two contexts. 
+ // EXPECT_TRUE(SamePixels(io.Main().alpha(), io2.Main().alpha())); + // TODO(robryk): Fix the distance estimate used in the encoder. + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 6.3); +} + +TEST(JxlTest, RoundtripAlpha16) { + ThreadPoolInternal pool(4); + + size_t xsize = 1200, ysize = 160; + Image3F color(xsize, ysize); + ImageF alpha(xsize, ysize); + // Generate 16-bit pattern that uses various colors and alpha values. + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + color.PlaneRow(0, y)[x] = (y * 65535 / ysize) * (1.0f / 65535); + color.PlaneRow(1, y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + color.PlaneRow(2, y)[x] = + ((y + x) * 65535 / (xsize + ysize)) * (1.0f / 65535); + alpha.Row(y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + } + } + const bool is_gray = false; + CodecInOut io; + io.metadata.m.SetUintSamples(16); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + io.SetFromImage(std::move(color), io.metadata.m.color_encoding); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + // The image is wider than 512 pixels to ensure multiple groups are tested. + + ASSERT_NE(io.xsize(), 0); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + + CompressParams cparams; + cparams.butteraugli_distance = 0.5; + // Prevent the test to be too slow, does not affect alpha + cparams.speed_tier = SpeedTier::kSquirrel; + DecompressParams dparams; + + io.metadata.m.SetUintSamples(16); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE( + EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, &pool)); + CodecInOut io2; + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, &pool)); + + EXPECT_TRUE(SamePixels(*io.Main().alpha(), *io2.Main().alpha())); +} + +namespace { +CompressParams CParamsForLossless() { + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.quality_pair = {100, 100}; + cparams.options.predictor = {Predictor::Weighted}; + return cparams; +} +} // namespace + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 3500000); + // If this test fails with a very close to 0.0 but not exactly 0.0 butteraugli + // distance, then there is likely a floating point issue, that could be + // happening either in io or io2. The values of io are generated by + // external_image.cc, and those in io2 by the jxl decoder. If they use + // slightly different floating point operations (say, one casts int to float + // while other divides the int through 255.0f and later multiplies it by + // 255 again) they will get slightly different values. To fix, ensure both + // sides do the following formula for converting integer range 0-255 to + // floating point range 0.0f-255.0f: static_cast(i) + // without any further intermediate operations. + // Note that this precision issue is not a problem in practice if the values + // are equal when rounded to 8-bit int, but currently full exact precision is + // tested. 
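+  // NOTE(editor): the exact conversion the comment above calls for is
+  //   float v = static_cast<float>(i);       // i in [0, 255]
+  // and not, say, (i / 255.0f) * 255.0f, which can differ in the last ulp
+  // for some i and would make the EXPECT_EQ(0.0, ...) below fail.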
+ EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLosslessNoEncoderFastPath)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + cparams.options.skip_encoder_fast_path = true; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 3500000); + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8Falcon)) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CompressParams cparams = CParamsForLossless(); + cparams.speed_tier = SpeedTier::kFalcon; + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 3500000); + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, RoundtripLossless8Alpha) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + EXPECT_EQ(8, io.metadata.m.GetAlphaBits()); + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + + CompressParams cparams = CParamsForLossless(); + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 350000); + // If fails, see note about floating point in RoundtripLossless8. + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool)); + EXPECT_TRUE(SamePixels(*io.Main().alpha(), *io2.Main().alpha())); + EXPECT_EQ(8, io2.metadata.m.GetAlphaBits()); + EXPECT_EQ(8, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io2.metadata.m.bit_depth.exponent_bits_per_sample); +} + +TEST(JxlTest, RoundtripLossless16Alpha) { + ThreadPool* pool = nullptr; + + size_t xsize = 1200, ysize = 160; + Image3F color(xsize, ysize); + ImageF alpha(xsize, ysize); + // Generate 16-bit pattern that uses various colors and alpha values. 
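+  // NOTE(editor): (y * 65535 / ysize) is integer division, so every sample
+  // lands exactly on a 16-bit level k, and k * (1.0f / 65535) is that
+  // level's exact float encoding; this is what lets the lossless roundtrip
+  // below compare the alpha channel bit for bit.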
+ for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + color.PlaneRow(0, y)[x] = (y * 65535 / ysize) * (1.0f / 65535); + color.PlaneRow(1, y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + color.PlaneRow(2, y)[x] = + ((y + x) * 65535 / (xsize + ysize)) * (1.0f / 65535); + alpha.Row(y)[x] = (x * 65535 / xsize) * (1.0f / 65535); + } + } + const bool is_gray = false; + CodecInOut io; + io.metadata.m.SetUintSamples(16); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + io.SetFromImage(std::move(color), io.metadata.m.color_encoding); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + EXPECT_EQ(16, io.metadata.m.GetAlphaBits()); + EXPECT_EQ(16, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + + CompressParams cparams = CParamsForLossless(); + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 7100); + // If this test fails with a very close to 0.0 but not exactly 0.0 butteraugli + // distance, then there is likely a floating point issue, that could be + // happening either in io or io2. The values of io are generated by + // external_image.cc, and those in io2 by the jxl decoder. If they use + // slightly different floating point operations (say, one does "i / 257.0f" + // while the other does "i * (1.0f / 257)" they will get slightly different + // values. To fix, ensure both sides do the following formula for converting + // integer range 0-65535 to Image3F floating point range 0.0f-255.0f: + // "i * (1.0f / 257)". + // Note that this precision issue is not a problem in practice if the values + // are equal when rounded to 16-bit int, but currently full exact precision is + // tested. + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool)); + EXPECT_TRUE(SamePixels(*io.Main().alpha(), *io2.Main().alpha())); + EXPECT_EQ(16, io2.metadata.m.GetAlphaBits()); + EXPECT_EQ(16, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io2.metadata.m.bit_depth.exponent_bits_per_sample); +} + +TEST(JxlTest, RoundtripLossless16AlphaNotMisdetectedAs8Bit) { + ThreadPool* pool = nullptr; + + size_t xsize = 128, ysize = 128; + Image3F color(xsize, ysize); + ImageF alpha(xsize, ysize); + // All 16-bit values, both color and alpha, of this image are below 64. + // This allows testing if a code path wrongly concludes it's an 8-bit instead + // of 16-bit image (or even 6-bit). 
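+  // NOTE(editor): 64 levels would fit in 6 bits, and the largest sample
+  // here, 63 * (1.0f / 65535), is still below the smallest nonzero 8-bit
+  // level (257 / 65535), so any silent narrowing of the sample type would
+  // break the exact-equality checks at the end of this test.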
+ for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + color.PlaneRow(0, y)[x] = (y * 64 / ysize) * (1.0f / 65535); + color.PlaneRow(1, y)[x] = (x * 64 / xsize) * (1.0f / 65535); + color.PlaneRow(2, y)[x] = + ((y + x) * 64 / (xsize + ysize)) * (1.0f / 65535); + alpha.Row(y)[x] = (64 * x / xsize) * (1.0f / 65535); + } + } + const bool is_gray = false; + CodecInOut io; + io.metadata.m.SetUintSamples(16); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = ColorEncoding::SRGB(is_gray); + io.SetFromImage(std::move(color), io.metadata.m.color_encoding); + io.Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false); + + EXPECT_EQ(16, io.metadata.m.GetAlphaBits()); + EXPECT_EQ(16, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + + CompressParams cparams = CParamsForLossless(); + DecompressParams dparams; + + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 3100); + EXPECT_EQ(16, io2.metadata.m.GetAlphaBits()); + EXPECT_EQ(16, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io2.metadata.m.bit_depth.exponent_bits_per_sample); + // If fails, see note about floating point in RoundtripLossless8. + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool)); + EXPECT_TRUE(SamePixels(*io.Main().alpha(), *io2.Main().alpha())); +} + +TEST(JxlTest, RoundtripDots) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0); + + CompressParams cparams; + cparams.dots = Override::kOn; + cparams.butteraugli_distance = 0.04; + cparams.speed_tier = SpeedTier::kSquirrel; + DecompressParams dparams; + + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 400000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 2.2); +} + +TEST(JxlTest, RoundtripNoise) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + ASSERT_NE(io.xsize(), 0); + + CompressParams cparams; + cparams.noise = Override::kOn; + cparams.speed_tier = SpeedTier::kSquirrel; + DecompressParams dparams; + + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.tf.IsSRGB()); + PassesEncoderState enc_state; + AuxOut* aux_out = nullptr; + PaddedBytes compressed; + EXPECT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, aux_out, pool)); + CodecInOut io2; + EXPECT_TRUE(DecodeFile(dparams, compressed, &io2, pool)); + + EXPECT_LE(compressed.size(), 40000); + 
EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 2.2); +} + +TEST(JxlTest, RoundtripLossless8Gray) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + CompressParams cparams = CParamsForLossless(); + DecompressParams dparams; + + EXPECT_TRUE(io.Main().IsGray()); + EXPECT_EQ(8, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io.metadata.m.bit_depth.exponent_bits_per_sample); + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 130000); + // If fails, see note about floating point in RoundtripLossless8. + EXPECT_EQ(0.0, ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool)); + EXPECT_TRUE(io2.Main().IsGray()); + EXPECT_EQ(8, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0, io2.metadata.m.bit_depth.exponent_bits_per_sample); +} + +#if JPEGXL_ENABLE_GIF + +TEST(JxlTest, RoundtripAnimation) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/traffic_light.gif"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_EQ(4, io.frames.size()); + + CompressParams cparams; + DecompressParams dparams; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 3000); + + EXPECT_EQ(io2.frames.size(), io.frames.size()); + test::CoalesceGIFAnimationWithAlpha(&io); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), +#if JXL_HIGH_PRECISION + 1.55); +#else + 1.75); +#endif +} + +TEST(JxlTest, RoundtripLosslessAnimation) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = ReadTestData("jxl/traffic_light.gif"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + ASSERT_EQ(4, io.frames.size()); + + CompressParams cparams = CParamsForLossless(); + DecompressParams dparams; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, pool, &io2), 1200); + + EXPECT_EQ(io2.frames.size(), io.frames.size()); + test::CoalesceGIFAnimationWithAlpha(&io); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, pool), + 5e-4); +} + +#endif // JPEGXL_ENABLE_GIF + +#if JPEGXL_ENABLE_JPEG + +namespace { + +jxl::Status DecompressJxlToJPEGForTest( + const jpegxl::tools::JpegXlContainer& container, jxl::ThreadPool* pool, + jxl::PaddedBytes* output) { + output->clear(); + jxl::Span compressed(container.codestream, + container.codestream_size); + + JXL_RETURN_IF_ERROR(compressed.size() >= 2); + + // JXL case + // Decode to DCT when possible and generate a JPG file. 
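+  // NOTE(editor): keep_dct below asks the decoder for the quantized DCT
+  // coefficients rather than pixels, so the original JPEG entropy data can
+  // be regenerated; the caller (RoundtripJpeg) then checks the regenerated
+  // bytes against the source JPEG and stops comparing after a few reported
+  // mismatches.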
+ jxl::CodecInOut io; + jxl::DecompressParams params; + params.keep_dct = true; + if (!jpegxl::tools::DecodeJpegXlToJpeg(params, container, &io, pool)) { + return JXL_FAILURE("Failed to decode JXL to JPEG"); + } + io.jpeg_quality = 95; + if (!EncodeImageJPG(&io, jxl::JpegEncoder::kLibJpeg, io.jpeg_quality, + jxl::YCbCrChromaSubsampling(), pool, output, + jxl::DecodeTarget::kQuantizedCoeffs)) { + return JXL_FAILURE("Failed to generate JPEG"); + } + return true; +} + +} // namespace + +size_t RoundtripJpeg(const PaddedBytes& jpeg_in, ThreadPool* pool) { + CodecInOut io; + io.dec_target = jxl::DecodeTarget::kQuantizedCoeffs; + EXPECT_TRUE(SetFromBytes(Span(jpeg_in), &io, pool)); + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + PassesEncoderState passes_enc_state; + PaddedBytes compressed, codestream; + + EXPECT_TRUE(EncodeFile(cparams, &io, &passes_enc_state, &codestream, + /*aux_out=*/nullptr, pool)); + jpegxl::tools::JpegXlContainer enc_container; + enc_container.codestream = codestream.data(); + enc_container.codestream_size = codestream.size(); + jpeg::JPEGData data_in = *io.Main().jpeg_data; + jxl::PaddedBytes jpeg_data; + EXPECT_TRUE(EncodeJPEGData(data_in, &jpeg_data)); + enc_container.jpeg_reconstruction = jpeg_data.data(); + enc_container.jpeg_reconstruction_size = jpeg_data.size(); + EXPECT_TRUE(EncodeJpegXlContainerOneShot(enc_container, &compressed)); + + jpegxl::tools::JpegXlContainer container; + EXPECT_TRUE(DecodeJpegXlContainerOneShot(compressed.data(), compressed.size(), + &container)); + PaddedBytes out; + EXPECT_TRUE(DecompressJxlToJPEGForTest(container, pool, &out)); + EXPECT_EQ(out.size(), jpeg_in.size()); + size_t failures = 0; + for (size_t i = 0; i < std::min(out.size(), jpeg_in.size()); i++) { + if (out[i] != jpeg_in[i]) { + EXPECT_EQ(out[i], jpeg_in[i]) + << "byte mismatch " << i << " " << out[i] << " != " << jpeg_in[i]; + if (++failures > 4) { + return compressed.size(); + } + } + } + return compressed.size(); +} + +TEST(JxlTest, RoundtripJpegRecompression444) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png.im_q85_444.jpg"); + // JPEG size is 326'916 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 256000); +} + +TEST(JxlTest, RoundtripJpegRecompressionToPixels) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png.im_q85_444.jpg"); + CodecInOut io; + io.dec_target = jxl::DecodeTarget::kQuantizedCoeffs; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + DecompressParams dparams; + + CodecInOut io3; + Roundtrip(&io, cparams, dparams, &pool, &io3); + + // TODO(eustas): investigate, why SJPEG and JpegRecompression pixels are + // different. 
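+  // NOTE(editor): EXPECT_GE(1.8, dist) reads "1.8 >= dist", i.e. the pixels
+  // decoded from the recompressed stream must stay within Butteraugli
+  // distance 1.8 of a direct decode of the source JPEG (io2).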
+ EXPECT_GE(1.8, ButteraugliDistance(io2, io3, cparams.ba_params, + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, RoundtripJpegRecompressionToPixels420) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png.im_q85_420.jpg"); + CodecInOut io; + io.dec_target = jxl::DecodeTarget::kQuantizedCoeffs; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + DecompressParams dparams; + + CodecInOut io3; + Roundtrip(&io, cparams, dparams, &pool, &io3); + + EXPECT_GE(1.5, ButteraugliDistance(io2, io3, cparams.ba_params, + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, RoundtripJpegRecompressionToPixels_asymmetric) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData( + "imagecompression.info/flower_foveon.png.im_q85_asymmetric.jpg"); + CodecInOut io; + io.dec_target = jxl::DecodeTarget::kQuantizedCoeffs; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + + CodecInOut io2; + ASSERT_TRUE(SetFromBytes(Span(orig), &io2, &pool)); + + CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kYCbCr; + + DecompressParams dparams; + + CodecInOut io3; + Roundtrip(&io, cparams, dparams, &pool, &io3); + + EXPECT_GE(1.5, ButteraugliDistance(io2, io3, cparams.ba_params, + /*distmap=*/nullptr, &pool)); +} + +TEST(JxlTest, RoundtripJpegRecompressionGray) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png.im_q85_gray.jpg"); + // JPEG size is 167'025 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 140000); +} + +TEST(JxlTest, RoundtripJpegRecompression420) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png.im_q85_420.jpg"); + // JPEG size is 226'018 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 181050); +} + +TEST(JxlTest, RoundtripJpegRecompression_luma_subsample) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData( + "imagecompression.info/flower_foveon.png.im_q85_luma_subsample.jpg"); + // JPEG size is 216'069 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 181000); +} + +TEST(JxlTest, RoundtripJpegRecompression444_12) { + // 444 JPEG that has an interesting sampling-factor (1x2, 1x2, 1x2). + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData( + "imagecompression.info/flower_foveon.png.im_q85_444_1x2.jpg"); + // JPEG size is 329'942 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 256000); +} + +TEST(JxlTest, RoundtripJpegRecompression422) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png.im_q85_422.jpg"); + // JPEG size is 265'590 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 209000); +} + +TEST(JxlTest, RoundtripJpegRecompression440) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png.im_q85_440.jpg"); + // JPEG size is 262'249 bytes. + EXPECT_LE(RoundtripJpeg(orig, &pool), 209000); +} + +TEST(JxlTest, RoundtripJpegRecompression_asymmetric) { + // 2x vertical downsample of one chroma channel, 2x horizontal downsample of + // the other. + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData( + "imagecompression.info/flower_foveon.png.im_q85_asymmetric.jpg"); + // JPEG size is 262'249 bytes. 
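+  // NOTE(editor): 209'000 / 262'249 is roughly 0.80, so the bound asserts
+  // that lossless JPEG recompression saves about 20% over the original
+  // JPEG bytes.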
+ EXPECT_LE(RoundtripJpeg(orig, &pool), 209000); +} + +TEST(JxlTest, RoundtripJpegRecompression420Progr) { + ThreadPoolInternal pool(8); + const PaddedBytes orig = ReadTestData( + "imagecompression.info/flower_foveon.png.im_q85_420_progr.jpg"); + EXPECT_LE(RoundtripJpeg(orig, &pool), 181000); +} + +#endif // JPEGXL_ENABLE_JPEG + +TEST(JxlTest, RoundtripProgressive) { + ThreadPoolInternal pool(4); + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, &pool)); + io.ShrinkTo(600, 1024); + + CompressParams cparams; + DecompressParams dparams; + + cparams.butteraugli_distance = 1.0f; + cparams.progressive_dc = true; + cparams.responsive = true; + cparams.progressive_mode = true; + CodecInOut io2; + EXPECT_LE(Roundtrip(&io, cparams, dparams, &pool, &io2), 40000); + EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params, + /*distmap=*/nullptr, &pool), + 4.0f); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/lehmer_code.h b/third_party/jpeg-xl/lib/jxl/lehmer_code.h new file mode 100644 index 000000000000..98d83d8a455b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/lehmer_code.h @@ -0,0 +1,111 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_LEHMER_CODE_H_ +#define LIB_JXL_LEHMER_CODE_H_ + +#include +#include + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Permutation <=> factorial base representation (Lehmer code). + +using LehmerT = uint32_t; + +template +constexpr T ValueOfLowest1Bit(T t) { + return t & -t; +} + +// Computes the Lehmer (factorial basis) code of permutation, an array of n +// unique indices in [0..n), and stores it in code[0..len). N*logN time. +// temp must have n + 1 elements but need not be initialized. +template +void ComputeLehmerCode(const PermutationT* JXL_RESTRICT permutation, + uint32_t* JXL_RESTRICT temp, const size_t n, + LehmerT* JXL_RESTRICT code) { + for (size_t idx = 0; idx < n + 1; ++idx) temp[idx] = 0; + + for (size_t idx = 0; idx < n; ++idx) { + const PermutationT s = permutation[idx]; + + // Compute sum in Fenwick tree + uint32_t penalty = 0; + uint32_t i = s + 1; + while (i != 0) { + penalty += temp[i]; + i &= i - 1; // clear lowest bit + } + JXL_DASSERT(s >= penalty); + code[idx] = s - penalty; + i = s + 1; + // Add operation in Fenwick tree + while (i < n + 1) { + temp[i] += 1; + i += ValueOfLowest1Bit(i); + } + } +} + +// Decodes the Lehmer code in code[0..n) into permutation[0..n). +// temp must have 1 << CeilLog2(n) elements but need not be initialized. 
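+//
+// NOTE(editor): worked example for n = 4: the permutation (2, 0, 3, 1)
+// encodes to the Lehmer code (2, 0, 1, 0), each entry counting the
+// still-unused values smaller than the current element. Decoding inverts
+// this with ranks code[i] + 1: rank 3 of {0,1,2,3} -> 2, rank 1 of
+// {0,1,3} -> 0, rank 2 of {1,3} -> 3, rank 1 of {1} -> 1. Both directions
+// use the Fenwick-tree updates below, hence the N*logN time quoted above.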
+template <typename PermutationT>
+void DecodeLehmerCode(const LehmerT* JXL_RESTRICT code,
+                      uint32_t* JXL_RESTRICT temp, size_t n,
+                      PermutationT* JXL_RESTRICT permutation) {
+  JXL_DASSERT(n != 0);
+  const size_t log2n = CeilLog2Nonzero(n);
+  const size_t padded_n = 1ull << log2n;
+
+  for (size_t i = 0; i < padded_n; i++) {
+    const int32_t i1 = static_cast<int32_t>(i + 1);
+    temp[i] = static_cast<uint32_t>(ValueOfLowest1Bit(i1));
+  }
+
+  for (size_t i = 0; i < n; i++) {
+    JXL_DASSERT(code[i] + i < n);
+    uint32_t rank = code[i] + 1;
+
+    // Extract i-th unused element via implicit order-statistics tree.
+    size_t bit = padded_n;
+    size_t next = 0;
+    for (size_t i = 0; i <= log2n; i++) {
+      const size_t cand = next + bit;
+      JXL_DASSERT(cand >= 1);
+      bit >>= 1;
+      if (temp[cand - 1] < rank) {
+        next = cand;
+        rank -= temp[cand - 1];
+      }
+    }
+
+    permutation[i] = next;
+
+    // Mark as used
+    next += 1;
+    while (next <= padded_n) {
+      temp[next - 1] -= 1;
+      next += ValueOfLowest1Bit(next);
+    }
+  }
+}
+
+} // namespace jxl
+
+#endif // LIB_JXL_LEHMER_CODE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/lehmer_code_test.cc b/third_party/jpeg-xl/lib/jxl/lehmer_code_test.cc
new file mode 100644
index 000000000000..b0dea2d46762
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/lehmer_code_test.cc
@@ -0,0 +1,107 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/lehmer_code.h"
+
+#include <stdint.h>
+#include <string.h>
+
+#include <algorithm>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+
+namespace jxl {
+namespace {
+
+template <typename PermutationT>
+struct WorkingSet {
+  explicit WorkingSet(size_t max_n)
+      : padded_n(1ull << CeilLog2Nonzero(max_n + 1)),
+        permutation(max_n),
+        temp(padded_n),
+        lehmer(max_n),
+        decoded(max_n) {}
+
+  size_t padded_n;
+  std::vector<PermutationT> permutation;
+  std::vector<uint32_t> temp;
+  std::vector<LehmerT> lehmer;
+  std::vector<PermutationT> decoded;
+};
+
+template <typename PermutationT>
+void Roundtrip(size_t n, WorkingSet<PermutationT>* ws) {
+  JXL_ASSERT(n != 0);
+  const size_t padded_n = 1ull << CeilLog2Nonzero(n);
+
+  std::mt19937 rng(n * 65537 + 13);
+
+  // Ensure indices fit into PermutationT
+  EXPECT_LE(n, 1ULL << (sizeof(PermutationT) * 8));
+
+  std::iota(ws->permutation.begin(), ws->permutation.begin() + n, 0);
+
+  // For various random permutations:
+  for (size_t rep = 0; rep < 100; ++rep) {
+    std::shuffle(ws->permutation.begin(), ws->permutation.begin() + n, rng);
+
+    // Must decode to the same permutation
+    ComputeLehmerCode(ws->permutation.data(), ws->temp.data(), n,
+                      ws->lehmer.data());
+    memset(ws->temp.data(), 0, padded_n * 4);
+    DecodeLehmerCode(ws->lehmer.data(), ws->temp.data(), n, ws->decoded.data());
+
+    for (size_t i = 0; i < n; ++i) {
+      EXPECT_EQ(ws->permutation[i], ws->decoded[i]);
+    }
+  }
+}
+
+// Preallocates arrays and tests n = [begin, end).
+template <typename PermutationT>
+void RoundtripSizeRange(ThreadPool* pool, uint32_t begin, uint32_t end) {
+  ASSERT_NE(0, begin); // n = 0 not allowed.
+  std::vector<WorkingSet<PermutationT>> working_sets;
+
+  RunOnPool(
+      pool, begin, end,
+      [&working_sets, end](size_t num_threads) {
+        for (size_t i = 0; i < num_threads; i++) {
+          working_sets.emplace_back(end - 1);
+        }
+        return true;
+      },
+      [&working_sets](int n, int thread) {
+        Roundtrip(n, &working_sets[thread]);
+      },
+      "lehmer test");
+}
+
+TEST(LehmerCodeTest, TestRoundtrips) {
+  ThreadPoolInternal pool(8);
+
+  RoundtripSizeRange<uint16_t>(&pool, 1, 1026);
+
+  // Ensures PermutationT can fit > 16 bit values.
+  RoundtripSizeRange<uint32_t>(&pool, 65536, 65540);
+}
+
+} // namespace
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/libjxl.pc.in b/third_party/jpeg-xl/lib/jxl/libjxl.pc.in
new file mode 100644
index 000000000000..5dca2ac16885
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/libjxl.pc.in
@@ -0,0 +1,12 @@
+prefix=@CMAKE_INSTALL_PREFIX@
+exec_prefix=${prefix}
+libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@
+includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@
+
+Name: libjxl
+Description: Loads and saves JPEG XL files
+Version: @JPEGXL_LIBRARY_VERSION@
+Requires.private: @JPEGXL_LIBRARY_REQUIRES@
+Libs: -L${libdir} -ljxl
+Libs.private: -lm
+Cflags: -I${includedir}
diff --git a/third_party/jpeg-xl/lib/jxl/linalg.cc b/third_party/jpeg-xl/lib/jxl/linalg.cc
new file mode 100644
index 000000000000..6ff1a3838c17
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/linalg.cc
@@ -0,0 +1,244 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/linalg.h"
+
+#include <math.h>
+
+#include <cmath>
+#include <deque>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+
+void AssertSymmetric(const ImageD& A) {
+#if JXL_ENABLE_ASSERT
+  JXL_ASSERT(A.xsize() == A.ysize());
+  for (size_t i = 0; i < A.xsize(); ++i) {
+    for (size_t j = i + 1; j < A.xsize(); ++j) {
+      JXL_ASSERT(std::abs(A.Row(i)[j] - A.Row(j)[i]) < 1e-15);
+    }
+  }
+#endif
+}
+
+void Diagonalize2x2(const double a0, const double a1, const double b, double* c,
+                    double* s) {
+  if (std::abs(b) < 1e-15) {
+    *c = 1.0;
+    *s = 0.0;
+    return;
+  }
+  double phi = std::atan2(2 * b, a1 - a0);
+  double theta = b > 0.0 ? 0.5 * phi : 0.5 * phi + Pi(1.0);
+  *c = std::cos(theta);
+  *s = std::sin(theta);
+}
+
+void GivensRotation(const double x, const double y, double* c, double* s) {
+  if (y == 0.0) {
+    *c = x < 0.0 ?
-1.0 : 1.0; + *s = 0.0; + } else { + const double h = hypot(x, y); + const double d = 1.0 / h; + *c = x * d; + *s = -y * d; + } +} + +void RotateMatrixCols(ImageD* const JXL_RESTRICT U, int i, int j, double c, + double s) { + JXL_ASSERT(U->xsize() == U->ysize()); + const size_t N = U->xsize(); + double* const JXL_RESTRICT u_i = U->Row(i); + double* const JXL_RESTRICT u_j = U->Row(j); + std::vector rot_i, rot_j; + rot_i.reserve(N); + rot_j.reserve(N); + for (size_t k = 0; k < N; ++k) { + rot_i.push_back(u_i[k] * c - u_j[k] * s); + rot_j.push_back(u_i[k] * s + u_j[k] * c); + } + for (size_t k = 0; k < N; ++k) { + u_i[k] = rot_i[k]; + u_j[k] = rot_j[k]; + } +} +void HouseholderReflector(const size_t N, const double* x, double* u) { + const double sigma = x[0] <= 0.0 ? 1.0 : -1.0; + u[0] = x[0] - sigma * std::sqrt(DotProduct(N, x, x)); + for (size_t k = 1; k < N; ++k) { + u[k] = x[k]; + } + double u_norm = 1.0 / std::sqrt(DotProduct(N, u, u)); + for (size_t k = 0; k < N; ++k) { + u[k] *= u_norm; + } +} + +void ConvertToTridiagonal(const ImageD& A, ImageD* const JXL_RESTRICT T, + ImageD* const JXL_RESTRICT U) { + AssertSymmetric(A); + const size_t N = A.xsize(); + *U = Identity(A.xsize()); + *T = CopyImage(A); + std::vector u_stack; + for (size_t k = 0; k + 2 < N; ++k) { + if (DotProduct(N - k - 2, &T->Row(k)[k + 2], &T->Row(k)[k + 2]) > 1e-15) { + ImageD u(N, 1); + ZeroFillImage(&u); + HouseholderReflector(N - k - 1, &T->Row(k)[k + 1], &u.Row(0)[k + 1]); + ImageD v = MatMul(*T, u); + double scale = DotProduct(u, v); + v = LinComb(2.0, v, -2.0 * scale, u); + SubtractFrom(MatMul(u, Transpose(v)), T); + SubtractFrom(MatMul(v, Transpose(u)), T); + u_stack.emplace_back(std::move(u)); + } + } + while (!u_stack.empty()) { + const ImageD& u = u_stack.back(); + ImageD v = MatMul(Transpose(*U), u); + SubtractFrom(ScaleImage(2.0, MatMul(u, Transpose(v))), U); + u_stack.pop_back(); + } +} + +double WilkinsonShift(const double a0, const double a1, const double b) { + const double d = 0.5 * (a0 - a1); + if (d == 0.0) { + return a1 - std::abs(b); + } + const double sign_d = d > 0.0 ? 
1.0 : -1.0; + return a1 - b * b / (d + sign_d * hypotf(d, b)); +} + +void ImplicitQRStep(ImageD* const JXL_RESTRICT U, double* const JXL_RESTRICT a, + double* const JXL_RESTRICT b, int m0, int m1) { + JXL_ASSERT(m1 - m0 > 2); + double x = a[m0] - WilkinsonShift(a[m1 - 2], a[m1 - 1], b[m1 - 1]); + double y = b[m0 + 1]; + for (int k = m0; k < m1 - 1; ++k) { + double c, s; + GivensRotation(x, y, &c, &s); + const double w = c * x - s * y; + const double d = a[k] - a[k + 1]; + const double z = (2 * c * b[k + 1] + d * s) * s; + a[k] -= z; + a[k + 1] += z; + b[k + 1] = d * c * s + (c * c - s * s) * b[k + 1]; + x = b[k + 1]; + if (k > m0) { + b[k] = w; + } + if (k < m1 - 2) { + y = -s * b[k + 2]; + b[k + 2] *= c; + } + RotateMatrixCols(U, k, k + 1, c, s); + } +} + +void ScanInterval(const double* const JXL_RESTRICT a, + const double* const JXL_RESTRICT b, int istart, + const int iend, const double eps, + std::deque >* intervals) { + for (int k = istart; k < iend; ++k) { + if ((k + 1 == iend) || + std::abs(b[k + 1]) < eps * (std::abs(a[k]) + std::abs(a[k + 1]))) { + if (k > istart) { + intervals->push_back(std::make_pair(istart, k + 1)); + } + istart = k + 1; + } + } +} + +void ConvertToDiagonal(const ImageD& A, ImageD* const JXL_RESTRICT diag, + ImageD* const JXL_RESTRICT U) { + AssertSymmetric(A); + const size_t N = A.xsize(); + ImageD T; + ConvertToTridiagonal(A, &T, U); + // From now on, the algorithm keeps the transformed matrix tri-diagonal, + // so we only need to keep track of the diagonal and the off-diagonal entries. + std::vector a(N); + std::vector b(N); + for (size_t k = 0; k < N; ++k) { + a[k] = T.Row(k)[k]; + if (k > 0) b[k] = T.Row(k)[k - 1]; + } + // Run the symmetric tri-diagonal QR algorithm with implicit Wilkinson shift. + const double kEpsilon = 1e-14; + std::deque > intervals; + ScanInterval(&a[0], &b[0], 0, N, kEpsilon, &intervals); + while (!intervals.empty()) { + const int istart = intervals[0].first; + const int iend = intervals[0].second; + intervals.pop_front(); + if (iend == istart + 2) { + double& a0 = a[istart]; + double& a1 = a[istart + 1]; + double& b1 = b[istart + 1]; + double c, s; + Diagonalize2x2(a0, a1, b1, &c, &s); + const double d = a0 - a1; + const double z = (2 * c * b1 + d * s) * s; + a0 -= z; + a1 += z; + b1 = 0.0; + RotateMatrixCols(U, istart, istart + 1, c, s); + } else { + ImplicitQRStep(U, &a[0], &b[0], istart, iend); + ScanInterval(&a[0], &b[0], istart, iend, kEpsilon, &intervals); + } + } + *diag = ImageD(N, 1); + double* const JXL_RESTRICT diag_row = diag->Row(0); + for (size_t k = 0; k < N; ++k) { + diag_row[k] = a[k]; + } +} + +void ComputeQRFactorization(const ImageD& A, ImageD* const JXL_RESTRICT Q, + ImageD* const JXL_RESTRICT R) { + JXL_ASSERT(A.xsize() == A.ysize()); + const size_t N = A.xsize(); + *Q = Identity(N); + *R = CopyImage(A); + std::vector u_stack; + for (size_t k = 0; k + 1 < N; ++k) { + if (DotProduct(N - k - 1, &R->Row(k)[k + 1], &R->Row(k)[k + 1]) > 1e-15) { + ImageD u(N, 1); + FillImage(0.0, &u); + HouseholderReflector(N - k, &R->Row(k)[k], &u.Row(0)[k]); + ImageD v = MatMul(Transpose(u), *R); + SubtractFrom(ScaleImage(2.0, MatMul(u, v)), R); + u_stack.emplace_back(std::move(u)); + } + } + while (!u_stack.empty()) { + const ImageD& u = u_stack.back(); + ImageD v = MatMul(Transpose(u), *Q); + SubtractFrom(ScaleImage(2.0, MatMul(u, v)), Q); + u_stack.pop_back(); + } +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/linalg.h b/third_party/jpeg-xl/lib/jxl/linalg.h new file mode 100644 index 
000000000000..a424eeb2cf50 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/linalg.h @@ -0,0 +1,299 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_LINALG_H_ +#define LIB_JXL_LINALG_H_ + +// Linear algebra. + +#include + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +using ImageD = Plane; + +template +inline T DotProduct(const size_t N, const T* const JXL_RESTRICT a, + const T* const JXL_RESTRICT b) { + T sum = 0.0; + for (size_t k = 0; k < N; ++k) { + sum += a[k] * b[k]; + } + return sum; +} + +template +inline T L2NormSquared(const size_t N, const T* const JXL_RESTRICT a) { + return DotProduct(N, a, a); +} + +template +inline T L1Norm(const size_t N, const T* const JXL_RESTRICT a) { + T sum = 0; + for (size_t k = 0; k < N; ++k) { + sum += a[k] >= 0 ? a[k] : -a[k]; + } + return sum; +} + +inline double DotProduct(const ImageD& a, const ImageD& b) { + JXL_ASSERT(a.ysize() == 1); + JXL_ASSERT(b.ysize() == 1); + JXL_ASSERT(a.xsize() == b.xsize()); + const double* const JXL_RESTRICT row_a = a.Row(0); + const double* const JXL_RESTRICT row_b = b.Row(0); + return DotProduct(a.xsize(), row_a, row_b); +} + +inline ImageD Transpose(const ImageD& A) { + ImageD out(A.ysize(), A.xsize()); + for (size_t x = 0; x < A.xsize(); ++x) { + double* const JXL_RESTRICT row_out = out.Row(x); + for (size_t y = 0; y < A.ysize(); ++y) { + row_out[y] = A.Row(y)[x]; + } + } + return out; +} + +template +Plane MatMul(const Plane& A, const Plane& B) { + JXL_ASSERT(A.ysize() == B.xsize()); + Plane out(A.xsize(), B.ysize()); + for (size_t y = 0; y < B.ysize(); ++y) { + const Tin2* const JXL_RESTRICT row_b = B.Row(y); + Tout* const JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < A.xsize(); ++x) { + row_out[x] = 0.0; + for (size_t k = 0; k < B.xsize(); ++k) { + row_out[x] += A.Row(k)[x] * row_b[k]; + } + } + } + return out; +} + +template +ImageD MatMul(const Plane& A, const Plane& B) { + return MatMul(A, B); +} + +template +ImageI MatMulI(const Plane& A, const Plane& B) { + return MatMul(A, B); +} + +// Computes A = B * C, with sizes rows*cols: A=ha*wa, B=wa*wb, C=ha*wb +template +void MatMul(const T* a, const T* b, int ha, int wa, int wb, T* c) { + std::vector temp(wa); // Make better use of cache lines + for (int x = 0; x < wb; x++) { + for (int z = 0; z < wa; z++) { + temp[z] = b[z * wb + x]; + } + for (int y = 0; y < ha; y++) { + double e = 0; + for (int z = 0; z < wa; z++) { + e += a[y * wa + z] * temp[z]; + } + c[y * wb + x] = e; + } + } +} + +// Computes C = A + factor * B +template +void MatAdd(const T* a, const T* b, F factor, int h, int w, T* c) { + for (int i = 0; i < w * h; i++) { + c[i] = a[i] + b[i] * factor; + } +} + +template +inline Plane Identity(const size_t N) { + Plane out(N, N); + for (size_t i = 0; i < N; ++i) { + T* JXL_RESTRICT row = out.Row(i); + std::fill(row, row + N, 0); 
+ row[i] = static_cast(1.0); + } + return out; +} + +inline ImageD Diagonal(const ImageD& d) { + JXL_ASSERT(d.ysize() == 1); + ImageD out(d.xsize(), d.xsize()); + const double* JXL_RESTRICT row_diag = d.Row(0); + for (size_t k = 0; k < d.xsize(); ++k) { + double* JXL_RESTRICT row_out = out.Row(k); + std::fill(row_out, row_out + d.xsize(), 0.0); + row_out[k] = row_diag[k]; + } + return out; +} + +// Computes c, s such that c^2 + s^2 = 1 and +// [c -s] [x] = [ * ] +// [s c] [y] [ 0 ] +void GivensRotation(double x, double y, double* c, double* s); + +// U = U * Givens(i, j, c, s) +void RotateMatrixCols(ImageD* JXL_RESTRICT U, int i, int j, double c, double s); + +// A is symmetric, U is orthogonal, T is tri-diagonal and +// A = U * T * Transpose(U). +void ConvertToTridiagonal(const ImageD& A, ImageD* JXL_RESTRICT T, + ImageD* JXL_RESTRICT U); + +// A is symmetric, U is orthogonal, and A = U * Diagonal(diag) * Transpose(U). +void ConvertToDiagonal(const ImageD& A, ImageD* JXL_RESTRICT diag, + ImageD* JXL_RESTRICT U); + +// A is square matrix, Q is orthogonal, R is upper triangular and A = Q * R; +void ComputeQRFactorization(const ImageD& A, ImageD* JXL_RESTRICT Q, + ImageD* JXL_RESTRICT R); + +// Inverts a 3x3 matrix in place +template +void Inv3x3Matrix(T* matrix) { + // Intermediate computation is done in double precision. + double temp[9]; + temp[0] = static_cast(matrix[4]) * matrix[8] - + static_cast(matrix[5]) * matrix[7]; + temp[1] = static_cast(matrix[2]) * matrix[7] - + static_cast(matrix[1]) * matrix[8]; + temp[2] = static_cast(matrix[1]) * matrix[5] - + static_cast(matrix[2]) * matrix[4]; + temp[3] = static_cast(matrix[5]) * matrix[6] - + static_cast(matrix[3]) * matrix[8]; + temp[4] = static_cast(matrix[0]) * matrix[8] - + static_cast(matrix[2]) * matrix[6]; + temp[5] = static_cast(matrix[2]) * matrix[3] - + static_cast(matrix[0]) * matrix[5]; + temp[6] = static_cast(matrix[3]) * matrix[7] - + static_cast(matrix[4]) * matrix[6]; + temp[7] = static_cast(matrix[1]) * matrix[6] - + static_cast(matrix[0]) * matrix[7]; + temp[8] = static_cast(matrix[0]) * matrix[4] - + static_cast(matrix[1]) * matrix[3]; + double idet = + 1.0 / (matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6]); + for (int i = 0; i < 9; i++) { + matrix[i] = temp[i] * idet; + } +} + +// Solves system of linear equations A * X = B using the conjugate gradient +// method. Matrix a must be a n*n, symmetric and positive definite. +// Vectors b and x must have n elements +template +void ConjugateGradient(const T* a, int n, const T* b, T* x) { + std::vector r(n); + MatMul(a, x, n, n, 1, r.data()); + MatAdd(b, r.data(), -1, n, 1, r.data()); + std::vector p = r; + T rr; + MatMul(r.data(), r.data(), 1, n, 1, &rr); // inner product + + if (rr == 0) return; // The initial values were already optimal + + for (int i = 0; i < n; i++) { + std::vector ap(n); + MatMul(a, p.data(), n, n, 1, ap.data()); + T alpha; + MatMul(r.data(), ap.data(), 1, n, 1, &alpha); + // Normally alpha couldn't be zero here but if numerical issues caused it, + // return assuming the solution is close. + if (alpha == 0) return; + alpha = rr / alpha; + MatAdd(x, p.data(), alpha, n, 1, x); + MatAdd(r.data(), ap.data(), -alpha, n, 1, r.data()); + + T rr2; + MatMul(r.data(), r.data(), 1, n, 1, &rr2); // inner product + if (rr2 < 1e-20) break; + + T beta = rr2 / rr; + MatAdd(r.data(), p.data(), beta, 1, n, p.data()); + rr = rr2; + } +} + +// Computes optimal coefficients r to approximate points p with linear +// combination of functions f. 
The matrix f has h rows and w columns, r has h +// values, p has w values. h is the amount of functions, w the amount of points. +// Uses the finite element method and minimizes mean square error. +template +void FEM(const T* f, int h, int w, const T* p, T* r) { + // Compute "Gramian" matrix G = F * F^T + // Speed up multiplication by using non-zero intervals in sparse F. + std::vector start(h); + std::vector end(h); + for (int y = 0; y < h; y++) { + start[y] = end[y] = 0; + for (int x = 0; x < w; x++) { + if (f[y * w + x] != 0) { + start[y] = x; + break; + } + } + for (int x = w - 1; x >= 0; x--) { + if (f[y * w + x] != 0) { + end[y] = x + 1; + break; + } + } + } + + std::vector g(h * h); + for (int y = 0; y < h; y++) { + for (int x = 0; x <= y; x++) { + T v = 0; + // Intersection of the two sparse intervals. + int s = std::max(start[x], start[y]); + int e = std::min(end[x], end[y]); + for (int z = s; z < e; z++) { + v += f[x * w + z] * f[y * w + z]; + } + // Symmetric, so two values output at once + g[y * h + x] = v; + g[x * h + y] = v; + } + } + + // B vector: sum of each column of F multiplied by corresponding p + std::vector b(h, 0); + for (int y = 0; y < h; y++) { + T v = 0; + for (int x = 0; x < w; x++) { + v += f[y * w + x] * p[x]; + } + b[y] = v; + } + + ConjugateGradient(g.data(), h, b.data(), r); +} + +} // namespace jxl + +#endif // LIB_JXL_LINALG_H_ diff --git a/third_party/jpeg-xl/lib/jxl/linalg_test.cc b/third_party/jpeg-xl/lib/jxl/linalg_test.cc new file mode 100644 index 000000000000..0df1cf10f92c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/linalg_test.cc @@ -0,0 +1,158 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lib/jxl/linalg.h" + +#include + +#include "gtest/gtest.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +template +Plane RandomMatrix(const size_t xsize, const size_t ysize, Random& rng, + const T vmin, const T vmax) { + Plane A(xsize, ysize); + GeneratorRandom gen(&rng, vmin, vmax); + GenerateImage(gen, &A); + return A; +} + +template +Plane RandomSymmetricMatrix(const size_t N, Random& rng, const T vmin, + const T vmax) { + Plane A = RandomMatrix(N, N, rng, vmin, vmax); + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < i; ++j) { + A.Row(j)[i] = A.Row(i)[j]; + } + } + return A; +} +void VerifyMatrixEqual(const ImageD& A, const ImageD& B, const double eps) { + ASSERT_EQ(A.xsize(), B.xsize()); + ASSERT_EQ(A.ysize(), B.ysize()); + for (size_t y = 0; y < A.ysize(); ++y) { + for (size_t x = 0; x < A.xsize(); ++x) { + ASSERT_NEAR(A.Row(y)[x], B.Row(y)[x], eps); + } + } +} + +void VerifyOrthogonal(const ImageD& A, const double eps) { + VerifyMatrixEqual(Identity(A.xsize()), MatMul(Transpose(A), A), eps); +} + +void VerifyTridiagonal(const ImageD& T, const double eps) { + ASSERT_EQ(T.xsize(), T.ysize()); + for (size_t i = 0; i < T.xsize(); ++i) { + for (size_t j = i + 2; j < T.xsize(); ++j) { + ASSERT_NEAR(T.Row(i)[j], 0.0, eps); + ASSERT_NEAR(T.Row(j)[i], 0.0, eps); + } + } +} + +void VerifyUpperTriangular(const ImageD& R, const double eps) { + ASSERT_EQ(R.xsize(), R.ysize()); + for (size_t i = 0; i < R.xsize(); ++i) { + for (size_t j = i + 1; j < R.xsize(); ++j) { + ASSERT_NEAR(R.Row(i)[j], 0.0, eps); + } + } +} + +TEST(LinAlgTest, ConvertToTridiagonal) { + { + ImageD I = Identity(5); + ImageD T, U; + ConvertToTridiagonal(I, &T, &U); + VerifyMatrixEqual(I, T, 1e-15); + VerifyMatrixEqual(I, U, 1e-15); + } + { + ImageD A = Identity(5); + A.Row(0)[1] = A.Row(1)[0] = 2.0; + A.Row(0)[4] = A.Row(4)[0] = 3.0; + A.Row(2)[3] = A.Row(3)[2] = 2.0; + A.Row(3)[4] = A.Row(4)[3] = 2.0; + ImageD U, d; + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } + std::mt19937_64 rng; + for (int N = 2; N < 100; ++N) { + ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0); + ImageD T, U; + ConvertToTridiagonal(A, &T, &U); + VerifyOrthogonal(U, 1e-12); + VerifyTridiagonal(T, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(T, Transpose(U))), 1e-12); + } +} + +TEST(LinAlgTest, ConvertToDiagonal) { + { + ImageD I = Identity(5); + ImageD U, d; + ConvertToDiagonal(I, &d, &U); + VerifyMatrixEqual(I, U, 1e-15); + for (int k = 0; k < 5; ++k) { + ASSERT_NEAR(d.Row(0)[k], 1.0, 1e-15); + } + } + { + ImageD A = Identity(5); + A.Row(0)[1] = A.Row(1)[0] = 2.0; + A.Row(2)[3] = A.Row(3)[2] = 2.0; + A.Row(3)[4] = A.Row(4)[3] = 2.0; + ImageD U, d; + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } + std::mt19937_64 rng; + for (int N = 2; N < 100; ++N) { + ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0); + ImageD U, d; + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } +} + +TEST(LinAlgTest, ComputeQRFactorization) { + { + ImageD I = Identity(5); + ImageD Q, R; + ComputeQRFactorization(I, &Q, &R); + VerifyMatrixEqual(I, Q, 1e-15); + VerifyMatrixEqual(I, R, 1e-15); + } + std::mt19937_64 rng; + for (int N = 2; N < 100; ++N) { + ImageD A = RandomMatrix(N, N, rng, -1.0, 1.0); + ImageD Q, R; + ComputeQRFactorization(A, &Q, 
&R); + VerifyOrthogonal(Q, 1e-12); + VerifyUpperTriangular(R, 1e-12); + VerifyMatrixEqual(A, MatMul(Q, R), 1e-12); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/loop_filter.cc b/third_party/jpeg-xl/lib/jxl/loop_filter.cc new file mode 100644 index 000000000000..088676d30d26 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/loop_filter.cc @@ -0,0 +1,96 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/loop_filter.h" + +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +LoopFilter::LoopFilter() { Bundle::Init(this); } +Status LoopFilter::VisitFields(Visitor* JXL_RESTRICT visitor) { + // Must come before AllDefault. + + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &gab)); + if (visitor->Conditional(gab)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &gab_custom)); + if (visitor->Conditional(gab_custom)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_x_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_x_weight2)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_y_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_y_weight2)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_b_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_b_weight2)); + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 2, &epf_iters)); + if (visitor->Conditional(epf_iters > 0)) { + if (visitor->Conditional(!nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_sharp_custom)); + if (visitor->Conditional(epf_sharp_custom)) { + for (size_t i = 0; i < kEpfSharpEntries; ++i) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16( + float(i) / float(kEpfSharpEntries - 1), &epf_sharp_lut[i])); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_weight_custom)); + if (visitor->Conditional(epf_weight_custom)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(40.0f, &epf_channel_scale[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(5.0f, &epf_channel_scale[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(3.5f, &epf_channel_scale[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.45f, &epf_pass1_zeroflush)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.6f, &epf_pass2_zeroflush)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_sigma_custom)); + if (visitor->Conditional(epf_sigma_custom)) { + if (visitor->Conditional(!nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.46f, &epf_quant_mul)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.9f, &epf_pass0_sigma_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(6.5f, &epf_pass2_sigma_scale)); + JXL_QUIET_RETURN_IF_ERROR( + 
visitor->F16(0.6666666666666666f, &epf_border_sad_mul)); + } + if (visitor->Conditional(nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.0f, &epf_sigma_for_modular)); + if (epf_sigma_for_modular < 1e-8) { + return JXL_FAILURE("EPF: sigma for modular is too small"); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + // Extensions: in chronological order of being added to the format. + return visitor->EndExtensions(); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/loop_filter.h b/third_party/jpeg-xl/lib/jxl/loop_filter.h new file mode 100644 index 000000000000..6ecec1bb2d7d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/loop_filter.h @@ -0,0 +1,87 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_LOOP_FILTER_H_ +#define LIB_JXL_LOOP_FILTER_H_ + +// Parameters for loop filter(s), stored in each frame. + +#include +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +struct LoopFilter : public Fields { + LoopFilter(); + const char* Name() const override { return "LoopFilter"; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + size_t Padding() const { + static const size_t padding_per_epf_iter[4] = {0, 2, 3, 6}; + return padding_per_epf_iter[epf_iters] + (gab ? 1 : 0); + } + + mutable bool all_default; + + // --- Gaborish convolution + bool gab; + + bool gab_custom; + float gab_x_weight1; + float gab_x_weight2; + float gab_y_weight1; + float gab_y_weight2; + float gab_b_weight1; + float gab_b_weight2; + + // --- Edge-preserving filter + + // Number of EPF stages to apply. 0 means EPF disabled. 1 applies only the + // first stage, 2 applies both stages and 3 applies the first stage twice and + // the second stage once. 
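// For illustration (derived from Padding() above): with epf_iters == 2 and
// gab enabled, Padding() returns padding_per_epf_iter[2] + 1 == 4 pixels.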
+ uint32_t epf_iters; + + bool epf_sharp_custom; + enum { kEpfSharpEntries = 8 }; + float epf_sharp_lut[kEpfSharpEntries]; + + bool epf_weight_custom; // Custom weight params + float epf_channel_scale[3]; // Relative weight of each channel + float epf_pass1_zeroflush; // Minimum weight for first pass + float epf_pass2_zeroflush; // Minimum weight for second pass + + bool epf_sigma_custom; // Custom sigma parameters + float epf_quant_mul; // Sigma is ~ this * quant + float epf_pass0_sigma_scale; // Multiplier for sigma in pass 0 + float epf_pass2_sigma_scale; // Multiplier for sigma in the second pass + float epf_border_sad_mul; // (inverse) multiplier for sigma on borders + + float epf_sigma_for_modular; + + uint64_t extensions; + + bool nonserialized_is_modular = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_LOOP_FILTER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/luminance.cc b/third_party/jpeg-xl/lib/jxl/luminance.cc new file mode 100644 index 000000000000..adbef91a51b2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/luminance.cc @@ -0,0 +1,40 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/luminance.h" + +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" + +namespace jxl { + +void SetIntensityTarget(CodecInOut* io) { + if (io->target_nits != 0) { + io->metadata.m.SetIntensityTarget(io->target_nits); + return; + } + if (io->metadata.m.color_encoding.tf.IsPQ()) { + // Peak luminance of PQ as defined by SMPTE ST 2084:2014. + io->metadata.m.SetIntensityTarget(10000); + } else if (io->metadata.m.color_encoding.tf.IsHLG()) { + // Nominal display peak luminance used as a reference by + // Rec. ITU-R BT.2100-2. + io->metadata.m.SetIntensityTarget(1000); + } else { + // SDR + io->metadata.m.SetIntensityTarget(kDefaultIntensityTarget); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/luminance.h b/third_party/jpeg-xl/lib/jxl/luminance.h new file mode 100644 index 000000000000..44a248bc11c7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/luminance.h @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_LUMINANCE_H_ +#define LIB_JXL_LUMINANCE_H_ + +namespace jxl { + +// Chooses a default intensity target based on the transfer function of the +// image, if known. 
For SDR images or images not known to be HDR, returns +// kDefaultIntensityTarget; for images known to have a PQ or HLG transfer +// function, returns a higher value. If the image metadata already has a +// non-zero intensity target, does nothing. +class CodecInOut; +void SetIntensityTarget(CodecInOut* io); + +} // namespace jxl + +#endif // LIB_JXL_LUMINANCE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/memory_manager_internal.cc b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.cc new file mode 100644 index 000000000000..c4411d6d509a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.cc @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/memory_manager_internal.h" + +#include <stdlib.h> + +namespace jxl { + +void* MemoryManagerDefaultAlloc(void* opaque, size_t size) { + return malloc(size); +} + +void MemoryManagerDefaultFree(void* opaque, void* address) { free(address); } + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/memory_manager_internal.h b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.h new file mode 100644 index 000000000000..21dd8df02abb --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.h @@ -0,0 +1,110 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ +#define LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ + +// Memory allocator with support for alignment + misalignment. + +#include <stddef.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> // memcpy + +#include <atomic> +#include <memory> + +#include "jxl/memory_manager.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Default alloc and free functions. +void* MemoryManagerDefaultAlloc(void* opaque, size_t size); +void MemoryManagerDefaultFree(void* opaque, void* address); + +// Initializes the memory manager instance with the passed one. The +// MemoryManager passed in |memory_manager| may be NULL or contain NULL +// functions, which will be initialized with the default ones. If either alloc +// or free is NULL, then both must be NULL; otherwise this function returns an +// error.
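// Illustrative (hypothetical) caller: passing nullptr installs the
// malloc/free defaults, while a manager that sets only one of alloc/free is
// rejected by the XOR-style check below.
//   JxlMemoryManager mm;
//   JXL_RETURN_IF_ERROR(MemoryManagerInit(&mm, /*memory_manager=*/nullptr));
//   void* buf = MemoryManagerAlloc(&mm, 1024);
//   MemoryManagerFree(&mm, buf);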
+static JXL_INLINE Status MemoryManagerInit( + JxlMemoryManager* self, const JxlMemoryManager* memory_manager) { + if (memory_manager) { + *self = *memory_manager; + } else { + memset(self, 0, sizeof(*self)); + } + if (!self->alloc != !self->free) { + return false; + } + if (!self->alloc) self->alloc = jxl::MemoryManagerDefaultAlloc; + if (!self->free) self->free = jxl::MemoryManagerDefaultFree; + + return true; +} + +static JXL_INLINE void* MemoryManagerAlloc( + const JxlMemoryManager* memory_manager, size_t size) { + return memory_manager->alloc(memory_manager->opaque, size); +} + +static JXL_INLINE void MemoryManagerFree(const JxlMemoryManager* memory_manager, + void* address) { + return memory_manager->free(memory_manager->opaque, address); +} + +// Helper class to be used as a deleter in a unique_ptr call. +class MemoryManagerDeleteHelper { + public: + explicit MemoryManagerDeleteHelper(const JxlMemoryManager* memory_manager) + : memory_manager_(memory_manager) {} + + // Delete and free the passed pointer using the memory_manager. + template + void operator()(T* address) const { + if (!address) { + return; + } + address->~T(); + return memory_manager_->free(memory_manager_->opaque, address); + } + + private: + const JxlMemoryManager* memory_manager_; +}; + +template +using MemoryManagerUniquePtr = std::unique_ptr; + +// Creates a new object T allocating it with the memory allocator into a +// unique_ptr. +template +JXL_INLINE MemoryManagerUniquePtr MemoryManagerMakeUnique( + const JxlMemoryManager* memory_manager, Args&&... args) { + T* mem = + static_cast(memory_manager->alloc(memory_manager->opaque, sizeof(T))); + if (!mem) { + // Allocation error case. + return MemoryManagerUniquePtr(nullptr, + MemoryManagerDeleteHelper(memory_manager)); + } + return MemoryManagerUniquePtr(new (mem) T(std::forward(args)...), + MemoryManagerDeleteHelper(memory_manager)); +} + +} // namespace jxl + +#endif // LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h new file mode 100644 index 000000000000..0197491e1d28 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h @@ -0,0 +1,634 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ +#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ + +#include +#include + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +namespace weighted { +constexpr static size_t kNumPredictors = 4; +constexpr static int64_t kPredExtraBits = 3; +constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1; +constexpr static size_t kNumProperties = 1; + +struct Header : public Fields { + const char *Name() const override { return "WeightedPredictorHeader"; } + // TODO(janwas): move to cc file, avoid including fields.h. 
+ Header() { Bundle::Init(this); } + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + auto visit_p = [visitor](pixel_type val, pixel_type *p) { + uint32_t up = *p; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up)); + *p = up; + return Status(true); + }; + JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3])); + return true; + } + + bool all_default; + pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0; + uint32_t w[kNumPredictors] = {}; +}; + +struct State { + pixel_type_w prediction[kNumPredictors] = {}; + pixel_type_w pred = 0; // *before* removing the added bits. + std::vector pred_errors[kNumPredictors]; + std::vector error; + Header header; + + // Allows to approximate division by a number from 1 to 64. + uint32_t divlookup[64]; + + constexpr static pixel_type_w AddBits(pixel_type_w x) { + return uint64_t(x) << kPredExtraBits; + } + + State(Header header, size_t xsize, size_t ysize) : header(header) { + // Extra margin to avoid out-of-bounds writes. + // All have space for two rows of data. + for (size_t i = 0; i < 4; i++) { + pred_errors[i].resize((xsize + 2) * 2); + } + error.resize((xsize + 2) * 2); + // Initialize division lookup table. + for (int i = 0; i < 64; i++) { + divlookup[i] = (1 << 24) / (i + 1); + } + } + + // Approximates 4+(maxweight<<24)/(x+1), avoiding division + JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const { + int shift = FloorLog2Nonzero(x + 1) - 5; + if (shift < 0) shift = 0; + return 4 + ((maxweight * divlookup[x >> shift]) >> shift); + } + + // Approximates the weighted average of the input values with the given + // weights, avoiding division. Weights must sum to at least 16. + JXL_INLINE pixel_type_w + WeightedAverage(const pixel_type_w *JXL_RESTRICT p, + std::array w) const { + uint32_t weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + weight_sum += w[i]; + } + JXL_DASSERT(weight_sum > 15); + uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4. + weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + w[i] >>= log_weight - 4; + weight_sum += w[i]; + } + // for rounding. + pixel_type_w sum = (weight_sum >> 1) - 1; + for (size_t i = 0; i < kNumPredictors; i++) { + sum += p[i] * w[i]; + } + return (sum * divlookup[weight_sum - 1]) >> 24; + } + + template + JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize, + pixel_type_w N, pixel_type_w W, + pixel_type_w NE, pixel_type_w NW, + pixel_type_w NN, Properties *properties, + size_t offset) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + size_t pos_N = prev_row + x; + size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N; + size_t pos_NW = x > 0 ? 
pos_N - 1 : pos_N; + std::array weights; + for (size_t i = 0; i < kNumPredictors; i++) { + // pred_errors[pos_N] also contains the error of pixel W. + // pred_errors[pos_NW] also contains the error of pixel WW. + weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] + + pred_errors[i][pos_NW]; + weights[i] = ErrorWeight(weights[i], header.w[i]); + } + + N = AddBits(N); + W = AddBits(W); + NE = AddBits(NE); + NW = AddBits(NW); + NN = AddBits(NN); + + pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1]; + pixel_type_w teN = error[pos_N]; + pixel_type_w teNW = error[pos_NW]; + pixel_type_w sumWN = teN + teW; + pixel_type_w teNE = error[pos_NE]; + + if (compute_properties) { + pixel_type_w p = teW; + if (std::abs(teN) > std::abs(p)) p = teN; + if (std::abs(teNW) > std::abs(p)) p = teNW; + if (std::abs(teNE) > std::abs(p)) p = teNE; + (*properties)[offset++] = p; + } + + prediction[0] = W + NE - N; + prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5); + prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5); + prediction[3] = + N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc + + (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >> + 5); + + pred = WeightedAverage(prediction, weights); + + // If all three have the same sign, skip clamping. + if (((teN ^ teW) | (teN ^ teNW)) > 0) { + return (pred + kPredictionRound) >> kPredExtraBits; + } + + // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N). + pixel_type_w mx = std::max(W, std::max(NE, N)); + pixel_type_w mn = std::min(W, std::min(NE, N)); + pred = std::max(mn, std::min(mx, pred)); + return (pred + kPredictionRound) >> kPredExtraBits; + } + + JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y, + size_t xsize) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + val = AddBits(val); + error[cur_row + x] = ClampToRange(pred - val); + for (size_t i = 0; i < kNumPredictors; i++) { + pixel_type_w err = + (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits; + // For predicting in the next row. + pred_errors[i][cur_row + x] = err; + // Add the error on this pixel to the error on the NE pixel. This has the + // effect of adding the error on this pixel to the E and EE pixels. + pred_errors[i][prev_row + x + 1] += err; + } + } +}; + +// Encoder helper function to set the parameters to some presets. 
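// Hypothetical encoder-side use, selecting preset 1 (~default lossless8)
// before encoding:
//   weighted::Header wp_header;
//   weighted::PredictorMode(1, &wp_header);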
+inline void PredictorMode(int i, Header *header) { + switch (i) { + case 0: + // ~ lossless16 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 10; + header->p3Ca = 7; + header->p3Cb = 7; + header->p3Cc = 7; + header->p3Cd = 0; + header->p3Ce = 0; + break; + case 1: + // ~ default lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xb; + header->p1C = 8; + header->p2C = 8; + header->p3Ca = 4; + header->p3Cb = 0; + header->p3Cc = 3; + header->p3Cd = 23; + header->p3Ce = 2; + break; + case 2: + // ~ west lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xd; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 9; + header->p3Ca = 7; + header->p3Cb = 0; + header->p3Cc = 0; + header->p3Cd = 16; + header->p3Ce = 9; + break; + case 3: + // ~ north lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xd; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 8; + header->p3Ca = 0; + header->p3Cb = 16; + header->p3Cc = 0; + header->p3Cd = 23; + header->p3Ce = 0; + break; + case 4: + default: + // something else, because why not + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 10; + header->p3Ca = 5; + header->p3Cb = 5; + header->p3Cc = 5; + header->p3Cd = 12; + header->p3Ce = 4; + break; + } +} +} // namespace weighted + +// Stores a node and its two children at the same time. This significantly +// reduces the number of branches needed during decoding. +struct FlatDecisionNode { + // Property + splitval of the top node. + int32_t property0; // -1 if leaf. + union { + PropertyVal splitval0; + Predictor predictor; + }; + uint32_t childID; // childID is ctx id if leaf. + // Property+splitval of the two child nodes. + union { + PropertyVal splitvals[2]; + int32_t multiplier; + }; + union { + int32_t properties[2]; + int64_t predictor_offset; + }; +}; +using FlatTree = std::vector; + +class MATreeLookup { + public: + explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} + struct LookupResult { + uint32_t context; + Predictor predictor; + int64_t offset; + int32_t multiplier; + }; + LookupResult Lookup(const Properties &properties) const { + uint32_t pos = 0; + while (true) { + const FlatDecisionNode &node = nodes_[pos]; + if (node.property0 < 0) { + return {node.childID, node.predictor, node.predictor_offset, + node.multiplier}; + } + bool p0 = properties[node.property0] <= node.splitval0; + uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; + uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]); + pos = node.childID + (p0 ? off1 : off0); + } + } + + private: + const FlatTree &nodes_; +}; + +static constexpr size_t kExtraPropsPerChannel = 4; +static constexpr size_t kNumNonrefProperties = + kNumStaticProperties + 13 + weighted::kNumProperties; + +constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; +constexpr size_t kGradientProp = 9; + +// Clamps gradient to the min/max of n, w (and l, implicitly). 
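// Worked example: ClampedGradient(n=8, w=12, l=5) gives grad = 8 + 12 - 5 =
// 15, which exceeds M = max(8, 12), so 12 is returned; with l = 10 the
// gradient 8 + 12 - 10 = 10 already lies in [8, 12] and is returned as-is.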
+static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w, + const int32_t l) { + const int32_t m = std::min(n, w); + const int32_t M = std::max(n, w); + // The end result of this operation doesn't overflow or underflow if the + // result is between m and M, but the intermediate value may overflow, so we + // do the intermediate operations in uint32_t and check later if we had an + // overflow or underflow condition comparing m, M and l directly. + // grad = M + m - l = n + w - l + const int32_t grad = + static_cast(static_cast(n) + static_cast(w) - + static_cast(l)); + // We use two sets of ternary operators to force the evaluation of them in + // any case, allowing the compiler to avoid branches and use cmovl/cmovg in + // x86. + const int32_t grad_clamp_M = (l < m) ? M : grad; + return (l > M) ? m : grad_clamp_M; +} + +inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) { + pixel_type_w p = a + b - c; + pixel_type_w pa = std::abs(p - a); + pixel_type_w pb = std::abs(p - b); + return pa < pb ? a : b; +} + +inline void PrecomputeReferences(const Channel &ch, size_t y, + const Image &image, uint32_t i, + Channel *references) { + ZeroFillImage(&references->plane); + uint32_t offset = 0; + size_t num_extra_props = references->w; + intptr_t onerow = references->plane.PixelsPerRow(); + for (int32_t j = static_cast(i) - 1; + j >= 0 && offset < num_extra_props; j--) { + if (image.channel[j].w != image.channel[i].w || + image.channel[j].h != image.channel[i].h) { + continue; + } + if (image.channel[j].hshift != image.channel[i].hshift) continue; + if (image.channel[j].vshift != image.channel[i].vshift) continue; + pixel_type *JXL_RESTRICT rp = references->Row(0) + offset; + const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y); + const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0); + for (size_t x = 0; x < ch.w; x++, rp += onerow) { + pixel_type_w v = rpp[x]; + rp[0] = std::abs(v); + rp[1] = v; + pixel_type_w vleft = (x ? rpp[x - 1] : 0); + pixel_type_w vtop = (y ? rpprev[x] : vleft); + pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft); + pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft); + rp[2] = std::abs(v - vpredicted); + rp[3] = v - vpredicted; + } + + offset += kExtraPropsPerChannel; + } +} + +struct PredictionResult { + int context = 0; + pixel_type_w guess = 0; + Predictor predictor; + int32_t multiplier; +}; + +inline std::string PropertyName(size_t i) { + static_assert(kNumNonrefProperties == 16, "Update this function"); + switch (i) { + case 0: + return "c"; + case 1: + return "g"; + case 2: + return "y"; + case 3: + return "x"; + case 4: + return "|N|"; + case 5: + return "|W|"; + case 6: + return "N"; + case 7: + return "W"; + case 8: + return "W-WW-NW+NWW"; + case 9: + return "W+N-NW"; + case 10: + return "W-NW"; + case 11: + return "NW-N"; + case 12: + return "N-NE"; + case 13: + return "N-NN"; + case 14: + return "W-WW"; + case 15: + return "WGH"; + default: + return "ch[" + ToString(15 - (int)i) + "]"; + } +} + +inline void InitPropsRow( + Properties *p, + const std::array &static_props, + const int y) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + (*p)[i] = static_props[i]; + } + (*p)[2] = y; + (*p)[9] = 0; // local gradient. 
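// Note: (*p)[9] (the local-gradient slot, W+N-NW) is zeroed at row start; in
// detail::Predict, property [8] is then computed as W minus the previous
// pixel's [9], which expands to the W-WW-NW+NWW listed in PropertyName().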
+} + +namespace detail { +enum PredictorMode { + kUseTree = 1, + kUseWP = 2, + kForceComputeProperties = 4, + kAllPredictions = 8, +}; + +template +inline PredictionResult Predict( + Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const size_t x, const size_t y, Predictor predictor, + const MATreeLookup *lookup, const Channel *references, + weighted::State *wp_state, pixel_type_w *predictions) { + // We start in position 3 because of 2 static properties + y. + size_t offset = 3; + constexpr bool compute_properties = + mode & kUseTree || mode & kForceComputeProperties; + pixel_type_w left = (x ? pp[-1] : (y ? pp[-onerow] : 0)); + pixel_type_w top = (y ? pp[-onerow] : left); + pixel_type_w topleft = (x && y ? pp[-1 - onerow] : left); + pixel_type_w topright = (x + 1 < w && y ? pp[1 - onerow] : top); + pixel_type_w leftleft = (x > 1 ? pp[-2] : left); + pixel_type_w toptop = (y > 1 ? pp[-onerow - onerow] : top); + pixel_type_w toprightright = (x + 2 < w && y ? pp[2 - onerow] : topright); + + if (compute_properties) { + // location + (*p)[offset++] = x; + // neighbors + (*p)[offset++] = std::abs(top); + (*p)[offset++] = std::abs(left); + (*p)[offset++] = top; + (*p)[offset++] = left; + + // local gradient + (*p)[offset] = left - (*p)[offset + 1]; + offset++; + // local gradient + (*p)[offset++] = left + top - topleft; + + // FFV1 context properties + (*p)[offset++] = left - topleft; + (*p)[offset++] = topleft - top; + (*p)[offset++] = top - topright; + (*p)[offset++] = top - toptop; + (*p)[offset++] = left - leftleft; + } + + pixel_type_w wp_pred = 0; + if (mode & kUseWP) { + wp_pred = wp_state->Predict( + x, y, w, top, left, topright, topleft, toptop, p, offset); + } + if (compute_properties) { + offset += weighted::kNumProperties; + // Extra properties. 
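// These are the reference properties written by PrecomputeReferences(): for
// each eligible earlier channel, four values per pixel (|v|, v,
// |v - ClampedGradient prediction|, v - prediction).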
+ const pixel_type *JXL_RESTRICT rp = references->Row(x); + for (size_t i = 0; i < references->w; i++) { + (*p)[offset++] = rp[i]; + } + } + PredictionResult result; + if (mode & kUseTree) { + MATreeLookup::LookupResult lr = lookup->Lookup(*p); + result.context = lr.context; + result.guess = lr.offset; + result.multiplier = lr.multiplier; + predictor = lr.predictor; + } + pixel_type_w pred_storage[kNumModularPredictors]; + if (!(mode & kAllPredictions)) { + predictions = pred_storage; + } + predictions[(int)Predictor::Zero] = 0; + predictions[(int)Predictor::Left] = left; + predictions[(int)Predictor::Top] = top; + predictions[(int)Predictor::Select] = Select(left, top, topleft); + predictions[(int)Predictor::Weighted] = wp_pred; + predictions[(int)Predictor::Gradient] = ClampedGradient(left, top, topleft); + predictions[(int)Predictor::TopLeft] = topleft; + predictions[(int)Predictor::TopRight] = topright; + predictions[(int)Predictor::LeftLeft] = leftleft; + predictions[(int)Predictor::Average0] = (left + top) / 2; + predictions[(int)Predictor::Average1] = (left + topleft) / 2; + predictions[(int)Predictor::Average2] = (topleft + top) / 2; + predictions[(int)Predictor::Average3] = (top + topright) / 2; + predictions[(int)Predictor::Average4] = + (6 * top - 2 * toptop + 7 * left + 1 * leftleft + 1 * toprightright + + 3 * topright + 8) / + 16; + result.guess += predictions[(int)predictor]; + result.predictor = predictor; + + return result; +} +} // namespace detail + +inline PredictionResult PredictNoTreeNoWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor) { + return detail::Predict( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictNoTreeWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + weighted::State *wp_state) { + return detail::Predict( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeNoWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references) { + return detail::Predict( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictLearn(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict( + p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, + wp_state, /*predictions=*/nullptr); +} + +inline void PredictLearnAll(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const Channel &references, + weighted::State *wp_state, + pixel_type_w *predictions) { + 
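// Fills `predictions` with every predictor's guess (kAllPredictions) while
// also computing properties and updating the weighted-predictor state; used
// by GatherTreeData while several candidate predictors are still in play.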
detail::Predict( + p, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, &references, wp_state, predictions); +} + +inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + pixel_type_w *predictions) { + detail::Predict( + /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, predictions); +} +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc new file mode 100644 index 000000000000..356cff35d8a7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc @@ -0,0 +1,115 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/modular/encoding/dec_ma.h" + +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +namespace { + +Status ValidateTree( + const Tree &tree, + const std::vector> &prop_bounds, + size_t root) { + if (tree[root].property == -1) return true; + size_t p = tree[root].property; + int val = tree[root].splitval; + if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree"); + // Splitting at max value makes no sense: left range will be exactly same + // as parent, right range will be invalid (min > max). 
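// e.g. with bounds [0, 255], a node splitting at 255 is rejected here,
// since prop_bounds[p].second (255) <= splitval (255).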
+ if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree"); + auto new_bounds = prop_bounds; + new_bounds[p].first = val + 1; + JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild)); + new_bounds[p] = prop_bounds[p]; + new_bounds[p].second = val; + return ValidateTree(tree, new_bounds, tree[root].rchild); +} + +Status DecodeTree(BitReader *br, ANSSymbolReader *reader, + const std::vector &context_map, Tree *tree, + size_t tree_size_limit) { + size_t leaf_id = 0; + size_t to_decode = 1; + tree->clear(); + while (to_decode > 0) { + JXL_RETURN_IF_ERROR(br->AllReadsWithinBounds()); + if (tree->size() > tree_size_limit) { + return JXL_FAILURE("Tree is too large"); + } + to_decode--; + int property = + reader->ReadHybridUint(kPropertyContext, br, context_map) - 1; + if (property < -1 || property >= 256) { + return JXL_FAILURE("Invalid tree property value"); + } + if (property == -1) { + size_t predictor = + reader->ReadHybridUint(kPredictorContext, br, context_map); + if (predictor >= kNumModularPredictors) { + return JXL_FAILURE("Invalid predictor"); + } + int64_t predictor_offset = + UnpackSigned(reader->ReadHybridUint(kOffsetContext, br, context_map)); + uint32_t mul_log = + reader->ReadHybridUint(kMultiplierLogContext, br, context_map); + if (mul_log >= 31) { + return JXL_FAILURE("Invalid multiplier logarithm"); + } + uint32_t mul_bits = + reader->ReadHybridUint(kMultiplierBitsContext, br, context_map); + if (mul_bits + 1 >= 1u << (31u - mul_log)) { + return JXL_FAILURE("Invalid multiplier"); + } + uint32_t multiplier = (mul_bits + 1U) << mul_log; + tree->emplace_back(-1, 0, leaf_id++, 0, static_cast(predictor), + predictor_offset, multiplier); + continue; + } + int splitval = + UnpackSigned(reader->ReadHybridUint(kSplitValContext, br, context_map)); + tree->emplace_back(property, splitval, tree->size() + to_decode + 1, + tree->size() + to_decode + 2, Predictor::Zero, 0, 1); + to_decode += 2; + } + std::vector> prop_bounds; + prop_bounds.resize(256, {std::numeric_limits::min(), + std::numeric_limits::max()}); + return ValidateTree(*tree, prop_bounds, 0); +} +} // namespace + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit) { + std::vector tree_context_map; + ANSCode tree_code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumTreeContexts, &tree_code, &tree_context_map)); + // TODO(eustas): investigate more infinite tree cases. + if (tree_code.degenerate_symbols[tree_context_map[kPropertyContext]] > 0) { + return JXL_FAILURE("Infinite tree"); + } + ANSSymbolReader reader(&tree_code, br); + JXL_RETURN_IF_ERROR(DecodeTree(br, &reader, tree_context_map, tree, + std::min(tree_size_limit, kMaxTreeSize))); + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h new file mode 100644 index 000000000000..2207c34ef90e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h @@ -0,0 +1,75 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ + +#include +#include + +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// inner nodes +struct PropertyDecisionNode { + PropertyVal splitval; + int16_t property; // -1: leaf node, lchild points to leaf node + uint32_t lchild; + uint32_t rchild; + Predictor predictor; + int64_t predictor_offset; + uint32_t multiplier; + + PropertyDecisionNode(int p, int split_val, int lchild, int rchild, + Predictor predictor, int64_t predictor_offset, + uint32_t multiplier) + : splitval(split_val), + property(p), + lchild(lchild), + rchild(rchild), + predictor(predictor), + predictor_offset(predictor_offset), + multiplier(multiplier) {} + PropertyDecisionNode() + : splitval(0), + property(-1), + lchild(0), + rchild(0), + predictor(Predictor::Zero), + predictor_offset(0), + multiplier(1) {} + static PropertyDecisionNode Leaf(Predictor predictor, int64_t offset = 0, + uint32_t multiplier = 1) { + return PropertyDecisionNode(-1, 0, 0, 0, predictor, offset, multiplier); + } + static PropertyDecisionNode Split(int p, int split_val, int lchild, + int rchild = -1) { + if (rchild == -1) rchild = lchild + 1; + return PropertyDecisionNode(p, split_val, lchild, rchild, Predictor::Zero, + 0, 1); + } +}; + +using Tree = std::vector; + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc new file mode 100644 index 000000000000..8f09f86cd22b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc @@ -0,0 +1,509 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
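// A minimal sketch of the Tree helpers declared in dec_ma.h above
// (hypothetical values): a root split on property 0 at 0, with two leaves;
// lchild (index 1) is the strictly-greater branch.
//   Tree tree;
//   tree.push_back(PropertyDecisionNode::Split(/*p=*/0, /*split_val=*/0,
//                                              /*lchild=*/1));  // rchild = 2
//   tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));      // > 0
//   tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));  // <= 0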
+ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// Plot tree (if enabled) and predictor usage map. +constexpr bool kWantDebug = false; +} // namespace + +void GatherTreeData(const Image &image, pixel_type chan, size_t group_id, + const weighted::Header &wp_header, + const ModularOptions &options, TreeSamples &tree_samples, + size_t *total_pixels) { + const Channel &channel = image.channel[chan]; + + JXL_DEBUG_V(7, "Learning %zux%zu channel %d", channel.w, channel.h, chan); + + std::array static_props = {chan, + (int)group_id}; + Properties properties(kNumNonrefProperties + + kExtraPropsPerChannel * options.max_properties); + double pixel_fraction = std::min(1.0f, options.nb_repeats); + // a fraction of 0 is used to disable learning entirely. + if (pixel_fraction > 0) { + pixel_fraction = std::max(pixel_fraction, + std::min(1.0, 1024.0 / (channel.w * channel.h))); + } + uint64_t threshold = + (std::numeric_limits::max() >> 32) * pixel_fraction; + uint64_t s[2] = {0x94D049BB133111EBull, 0xBF58476D1CE4E5B9ull}; + // Xorshift128+ adapted from xorshift128+-inl.h + auto use_sample = [&]() { + auto s1 = s[0]; + const auto s0 = s[1]; + const auto bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return (bits >> 32) <= threshold; + }; + + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + InitPropsRow(&properties, static_props, y); + // TODO(veluca): avoid computing WP if we don't use its property or + // predictions. 
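// use_sample() keeps each pixel with probability ~pixel_fraction: the top 32
// bits of a Xorshift128+ draw are uniform over [0, 2^32), and threshold was
// scaled so that (bits >> 32) <= threshold holds for that fraction of draws.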
+ for (size_t x = 0; x < channel.w; x++) { + pixel_type_w pred[kNumModularPredictors]; + if (tree_samples.NumPredictors() != 1) { + PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references, + &wp_state, pred); + } else { + pred[static_cast(tree_samples.PredictorFromIndex(0))] = + PredictLearn(&properties, channel.w, p + x, onerow, x, y, + tree_samples.PredictorFromIndex(0), references, + &wp_state) + .guess; + } + (*total_pixels)++; + if (use_sample()) { + tree_samples.AddSample(p[x], properties, pred); + } + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } +} + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector &multiplier_info = {}, + StaticPropRange static_prop_range = {}) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (static_prop_range[i][1] == 0) { + static_prop_range[i][1] = std::numeric_limits::max(); + } + } + if (!tree_samples.HasSamples()) { + Tree tree; + tree.emplace_back(); + tree.back().predictor = tree_samples.PredictorFromIndex(0); + tree.back().property = -1; + tree.back().predictor_offset = 0; + tree.back().multiplier = 1; + return tree; + } + float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels; + float required_cost = pixel_fraction * 0.9 + 0.1; + tree_samples.AllSamplesDone(); + Tree tree; + ComputeBestTree(tree_samples, + options.splitting_heuristics_node_threshold * required_cost, + multiplier_info, static_prop_range, + options.fast_decode_multiplier, &tree); + return tree; +} + +constexpr bool kPrintTree = false; + +void PrintTree(const Tree &tree, const std::string &path) { + if (!kPrintTree) return; + FILE *f = fopen((path + ".dot").c_str(), "w"); + fprintf(f, "graph{\n"); + for (size_t cur = 0; cur < tree.size(); cur++) { + if (tree[cur].property < 0) { + fprintf(f, "n%05zu [label=\"%s%+" PRId64 " (x%u)\"];\n", cur, + PredictorName(tree[cur].predictor), tree[cur].predictor_offset, + tree[cur].multiplier); + } else { + fprintf(f, "n%05zu [label=\"%s>%d\"];\n", cur, + PropertyName(tree[cur].property).c_str(), tree[cur].splitval); + fprintf(f, "n%05zu -- n%05d;\n", cur, tree[cur].lchild); + fprintf(f, "n%05zu -- n%05d;\n", cur, tree[cur].rchild); + } + } + fprintf(f, "}\n"); + fclose(f); + JXL_ASSERT( + system(("dot " + path + ".dot -T svg -o " + path + ".svg").c_str()) == 0); +} + +Status EncodeModularChannelMAANS(const Image &image, pixel_type chan, + const weighted::Header &wp_header, + const Tree &global_tree, + std::vector *tokens, AuxOut *aux_out, + size_t group_id, bool skip_encoder_fast_path) { + const Channel &channel = image.channel[chan]; + + JXL_ASSERT(channel.w != 0 && channel.h != 0); + + Image3F predictor_img; + if (kWantDebug) predictor_img = Image3F(channel.w, channel.h); + + JXL_DEBUG_V(6, + "Encoding %zux%zu channel %d, " + "(shift=%i,%i, cshift=%i,%i)", + channel.w, channel.h, chan, channel.hshift, channel.vshift, + channel.hcshift, channel.vcshift); + + std::array static_props = {chan, + (int)group_id}; + bool use_wp, is_wp_only; + bool is_gradient_only; + size_t num_props; + FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp, + &is_wp_only, &is_gradient_only); + Properties properties(num_props); + MATreeLookup tree_lookup(tree); + JXL_DEBUG_V(3, "Encoding using a MA tree with %zu nodes", tree.size()); + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. 
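// If the tree is WP-only (established below via TreeToLookupTable), these
// arrays flatten it into a direct table over the clamped WP property:
// index -> (context id, predictor offset) for [-kPropRangeFast,
// kPropRangeFast).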
+ uint16_t context_lookup[2 * kPropRangeFast] = {}; + int8_t offsets[2 * kPropRangeFast] = {}; + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, context_lookup, offsets); + } + + tokens->reserve(tokens->size() + channel.w * channel.h); + if (is_wp_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(Predictor::Weighted)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = context_lookup[pos]; + int32_t residual = r[x] - guess - offsets[pos]; + tokens->emplace_back(ctx_id, PackSigned(residual)); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(Predictor::Zero)[c]), + &predictor_img.Plane(c)); + } + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + tokens->emplace_back(tree[0].childID, PackSigned(p[x])); + } + } + } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted && + (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 && + tree[0].predictor_offset == 0 && !skip_encoder_fast_path) { + // multiplier is a power of 2. 
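// ...so the division by the multiplier reduces to an arithmetic shift by
// mul_shift = FloorLog2Nonzero(multiplier) in the fast path below.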
+ for (size_t c = 0; c < 3; c++) { + FillImage(static_cast(PredictorColor(tree[0].predictor)[c]), + &predictor_img.Plane(c)); + } + uint32_t mul_shift = FloorLog2Nonzero((uint32_t)tree[0].multiplier); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x, + y, tree[0].predictor); + pixel_type_w residual = r[x] - pred.guess; + JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual); + tokens->emplace_back(tree[0].childID, + PackSigned(residual >> mul_shift)); + } + } + + } else if (!use_wp && !skip_encoder_fast_path) { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + tokens->emplace_back(res.context, + PackSigned(residual / res.multiplier)); + } + } + } else { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_ASSERT(residual % res.multiplier == 0); + tokens->emplace_back(res.context, + PackSigned(residual / res.multiplier)); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + if (kWantDebug && WantDebugOutput(aux_out)) { + aux_out->DumpImage( + ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(), + predictor_img); + } + return true; +} + +Status ModularEncode(const Image &image, const ModularOptions &options, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector *tokens, + size_t *width) { + if (image.error) return JXL_FAILURE("Invalid image"); + size_t nb_channels = image.channel.size(); + int bit_depth = 1, maxval = 1; + while (maxval < image.maxval) { + bit_depth++; + maxval = maxval * 2 + 1; + } + JXL_DEBUG_V(2, "Encoding %zu-channel, %i-bit, %zux%zu image.", nb_channels, + bit_depth, image.w, image.h); + + if (nb_channels < 1) { + return 
true; // is there any use for a zero-channel image? + } + + // encode transforms + GroupHeader header_storage; + if (header == nullptr) header = &header_storage; + Bundle::Init(header); + if (options.predictor == Predictor::Weighted) { + weighted::PredictorMode(options.wp_mode, &header->wp_header); + } + header->transforms = image.transform; + // This doesn't actually work + if (tree != nullptr) { + header->use_global_tree = true; + } + if (tree_samples == nullptr && tree == nullptr) { + JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out)); + } + + TreeSamples tree_samples_storage; + size_t total_pixels_storage = 0; + if (!total_pixels) total_pixels = &total_pixels_storage; + // If there's no tree, compute one (or gather data to). + if (tree == nullptr) { + bool gather_data = tree_samples != nullptr; + if (tree_samples == nullptr) { + JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor( + options.predictor, options.wp_tree_mode)); + JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties( + options.splitting_heuristics_properties, options.wp_tree_mode)); + std::vector pixel_samples; + std::vector diff_samples; + std::vector group_pixel_count; + std::vector channel_pixel_count; + CollectPixelSamples(image, options, 0, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples); + std::vector dummy_multiplier_info; + StaticPropRange range; + tree_samples_storage.PreQuantizeProperties( + range, dummy_multiplier_info, group_pixel_count, channel_pixel_count, + pixel_samples, diff_samples, options.max_property_values); + } + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + GatherTreeData(image, i, group_id, header->wp_header, options, + gather_data ? *tree_samples : tree_samples_storage, + total_pixels); + } + if (gather_data) return true; + } + + JXL_ASSERT((tree == nullptr) == (tokens == nullptr)); + + Tree tree_storage; + std::vector> tokens_storage(1); + // Compute tree. 
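// No tree was supplied: learn one from the gathered samples, tokenize it,
// and emit it (histograms plus tokens) ahead of the channel data.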
+ if (tree == nullptr) { + EntropyEncodingData code; + std::vector context_map; + + std::vector> tree_tokens(1); + tree_storage = + LearnTree(std::move(tree_samples_storage), *total_pixels, options); + tree = &tree_storage; + tokens = &tokens_storage[0]; + + Tree decoded_tree; + TokenizeTree(*tree, &tree_tokens[0], &decoded_tree); + JXL_ASSERT(tree->size() == decoded_tree.size()); + tree_storage = std::move(decoded_tree); + + if (kWantDebug && WantDebugOutput(aux_out)) { + PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id)); + } + // Write tree + BuildAndEncodeHistograms(HistogramParams(), kNumTreeContexts, tree_tokens, + &code, &context_map, writer, kLayerModularTree, + aux_out); + WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree, + aux_out); + } + + size_t image_width = 0; + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + if (image.channel[i].w > image_width) image_width = image.channel[i].w; + if (options.zero_tokens) { + tokens->resize(tokens->size() + image.channel[i].w * image.channel[i].h, + {0, 0}); + } else { + JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS( + image, i, header->wp_header, *tree, tokens, aux_out, group_id, + options.skip_encoder_fast_path)); + } + } + + // Write data if not using a global tree/ANS stream. + if (!header->use_global_tree) { + EntropyEncodingData code; + std::vector context_map; + HistogramParams histo_params; + histo_params.image_widths.push_back(image_width); + BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2, + tokens_storage, &code, &context_map, writer, layer, + aux_out); + WriteTokens(tokens_storage[0], code, context_map, writer, layer, aux_out); + } else { + *width = image_width; + } + return true; +} + +Status ModularGenericCompress(Image &image, const ModularOptions &opts, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector *tokens, + size_t *width) { + if (image.w == 0 || image.h == 0) return true; + ModularOptions options = opts; // Make a copy to modify it. + + if (options.predictor == static_cast(-1)) { + options.predictor = Predictor::Gradient; + } + + size_t bits = writer ? writer->BitsWritten() : 0; + JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer, + group_id, tree_samples, total_pixels, tree, + header, tokens, width)); + bits = writer ? writer->BitsWritten() - bits : 0; + if (writer) { + JXL_DEBUG_V( + 4, "Modular-encoded a %zux%zu maxval=%i nbchans=%zu image in %zu bytes", + image.w, image.h, image.maxval, image.real_nb_channels, bits / 8); + } + (void)bits; + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h new file mode 100644 index 000000000000..c409c0a1958c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h @@ -0,0 +1,58 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ + +#include +#include + +#include + +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +void PrintTree(const Tree &tree, const std::string &path); +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector &multiplier_info = {}, + StaticPropRange static_prop_range = {}); + +// TODO(veluca): make cleaner interfaces. + +Status ModularGenericCompress( + Image &image, const ModularOptions &opts, BitWriter *writer, + AuxOut *aux_out = nullptr, size_t layer = 0, size_t group_id = 0, + // For gathering data for producing a global tree. + TreeSamples *tree_samples = nullptr, size_t *total_pixels = nullptr, + // For encoding with global tree. + const Tree *tree = nullptr, GroupHeader *header = nullptr, + std::vector *tokens = nullptr, size_t *widths = nullptr); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc new file mode 100644 index 000000000000..a9872a88b557 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc @@ -0,0 +1,1048 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
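// Sketch of the simplest call into the API declared in enc_encoding.h above
// (standalone compression, no global tree; variables are hypothetical):
//   BitWriter writer;
//   JXL_RETURN_IF_ERROR(ModularGenericCompress(image, options, &writer));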
+ +#include "lib/jxl/modular/encoding/enc_ma.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "lib/jxl/modular/encoding/ma_common.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" +#include +#include + +#ifndef LIB_JXL_ENC_MODULAR_ENCODING_MA_ +#define LIB_JXL_ENC_MODULAR_ENCODING_MA_ +namespace { +struct Rng { + uint64_t s[2]; + explicit Rng(size_t seed) + : s{0x94D049BB133111EBull, 0xBF58476D1CE4E5B9ull + seed} {} + // Xorshift128+ adapted from xorshift128+-inl.h + uint64_t operator()() { + uint64_t s1 = s[0]; + const uint64_t s0 = s[1]; + const uint64_t bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return bits; + } + static constexpr uint64_t max() { return ~0ULL; } + static constexpr uint64_t min() { return 0; } +}; +} // namespace +#endif + +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +const HWY_FULL(float) df; +const HWY_FULL(int32_t) di; +size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } + +float EstimateBits(const int32_t *counts, int32_t *rounded_counts, + size_t num_symbols) { + // Try to approximate the effect of rounding up nonzero probabilities. + int32_t total = std::accumulate(counts, counts + num_symbols, 0); + const auto min = Set(di, (total + ANS_TAB_SIZE - 1) >> ANS_LOG_TAB_SIZE); + const auto zero_i = Zero(di); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + auto counts_v = LoadU(di, &counts[i]); + counts_v = IfThenElse(counts_v == zero_i, zero_i, + IfThenElse(counts_v < min, min, counts_v)); + StoreU(counts_v, di, &rounded_counts[i]); + } + // Compute entropy of the "rounded" probabilities. + const auto zero = Zero(df); + const size_t total_scalar = + std::accumulate(rounded_counts, rounded_counts + num_symbols, 0); + const auto inv_total = Set(df, 1.0f / total_scalar); + auto bits_lanes = Zero(df); + auto total_v = Set(di, total_scalar); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + const auto counts_v = ConvertTo(df, LoadU(di, &counts[i])); + const auto round_counts_v = LoadU(di, &rounded_counts[i]); + const auto probs = ConvertTo(df, round_counts_v) * inv_total; + const auto nbps = IfThenElse(round_counts_v == total_v, BitCast(di, zero), + BitCast(di, FastLog2f(df, probs))); + bits_lanes -= + IfThenElse(counts_v == zero, zero, counts_v * BitCast(df, nbps)); + } + return GetLane(SumOfLanes(bits_lanes)); +} + +void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, + int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { + // Note that the tree splits on *strictly greater*. 
+  (*tree)[pos].lchild = tree->size();
+  (*tree)[pos].rchild = tree->size() + 1;
+  (*tree)[pos].splitval = splitval;
+  (*tree)[pos].property = property;
+  tree->emplace_back();
+  tree->back().property = -1;
+  tree->back().predictor = rpred;
+  tree->back().predictor_offset = roff;
+  tree->back().multiplier = 1;
+  tree->emplace_back();
+  tree->back().property = -1;
+  tree->back().predictor = lpred;
+  tree->back().predictor_offset = loff;
+  tree->back().multiplier = 1;
+}
+
+enum class IntersectionType { kNone, kPartial, kInside };
+IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
+                               uint32_t &partial_axis, uint32_t &partial_val) {
+  bool partial = false;
+  for (size_t i = 0; i < kNumStaticProperties; i++) {
+    if (haystack[i][0] >= needle[i][1]) {
+      return IntersectionType::kNone;
+    }
+    if (haystack[i][1] <= needle[i][0]) {
+      return IntersectionType::kNone;
+    }
+    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
+      continue;
+    }
+    partial = true;
+    partial_axis = i;
+    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
+      partial_val = haystack[i][0] - 1;
+    } else {
+      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
+                  haystack[i][1] < needle[i][1]);
+      partial_val = haystack[i][1] - 1;
+    }
+  }
+  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
+}
+
+void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
+                      size_t end, size_t prop) {
+  auto cmp = [&](size_t a, size_t b) {
+    return int32_t(tree_samples.Property(prop, a)) -
+           int32_t(tree_samples.Property(prop, b));
+  };
+  Rng rng(0);
+  while (end > begin + 1) {
+    {
+      JXL_ASSERT(end > begin);  // silence clang-tidy.
+      size_t pivot = rng() % (end - begin) + begin;
+      tree_samples.Swap(begin, pivot);
+    }
+    size_t pivot_begin = begin;
+    size_t pivot_end = pivot_begin + 1;
+    for (size_t i = begin + 1; i < end; i++) {
+      JXL_DASSERT(i >= pivot_end);
+      JXL_DASSERT(pivot_end > pivot_begin);
+      int32_t cmp_result = cmp(i, pivot_begin);
+      if (cmp_result < 0) {  // i < pivot, move pivot forward and put i before
+                             // the pivot.
+        tree_samples.ThreeShuffle(pivot_begin, pivot_end, i);
+        pivot_begin++;
+        pivot_end++;
+      } else if (cmp_result == 0) {
+        tree_samples.Swap(pivot_end, i);
+        pivot_end++;
+      }
+    }
+    JXL_DASSERT(pivot_begin >= begin);
+    JXL_DASSERT(pivot_end > pivot_begin);
+    JXL_DASSERT(pivot_end <= end);
+    for (size_t i = begin; i < pivot_begin; i++) {
+      JXL_DASSERT(cmp(i, pivot_begin) < 0);
+    }
+    for (size_t i = pivot_end; i < end; i++) {
+      JXL_DASSERT(cmp(i, pivot_begin) > 0);
+    }
+    for (size_t i = pivot_begin; i < pivot_end; i++) {
+      JXL_DASSERT(cmp(i, pivot_begin) == 0);
+    }
+    // We now have that [begin, pivot_begin) is < pivot, [pivot_begin,
+    // pivot_end) is = pivot, and [pivot_end, end) is > pivot.
+    // If pos falls in the first or the last interval, we continue in that
+    // interval; otherwise, we are done.
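+    // [Editorial note — illustration only, not part of upstream libjxl] This
+    // is a quickselect: e.g. for quantized values {3, 1, 3, 2} with pivot 3
+    // and pos == 2, one pass yields {1, 2 | 3, 3}; pos lands in the "== pivot"
+    // run, so the loop below stops without fully sorting either side.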
+    if (pivot_begin > pos) {
+      end = pivot_begin;
+    } else if (pivot_end < pos) {
+      begin = pivot_end;
+    } else {
+      break;
+    }
+  }
+}
+
+void FindBestSplit(TreeSamples &tree_samples, float threshold,
+                   const std::vector<ModularMultiplierInfo> &mul_info,
+                   StaticPropRange initial_static_prop_range,
+                   float fast_decode_multiplier, Tree *tree) {
+  struct NodeInfo {
+    size_t pos;
+    size_t begin;
+    size_t end;
+    uint64_t used_properties;
+    StaticPropRange static_prop_range;
+  };
+  std::vector<NodeInfo> nodes;
+  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
+                           initial_static_prop_range});
+
+  size_t num_predictors = tree_samples.NumPredictors();
+  size_t num_properties = tree_samples.NumProperties();
+
+  // TODO(veluca): consider parallelizing the search (processing multiple nodes
+  // at a time).
+  while (!nodes.empty()) {
+    size_t pos = nodes.back().pos;
+    size_t begin = nodes.back().begin;
+    size_t end = nodes.back().end;
+    uint64_t used_properties = nodes.back().used_properties;
+    StaticPropRange static_prop_range = nodes.back().static_prop_range;
+    nodes.pop_back();
+    if (begin == end) continue;
+
+    struct SplitInfo {
+      size_t prop = 0;
+      uint32_t val = 0;
+      size_t pos = 0;
+      float lcost = std::numeric_limits<float>::max();
+      float rcost = std::numeric_limits<float>::max();
+      Predictor lpred = Predictor::Zero;
+      Predictor rpred = Predictor::Zero;
+      float Cost() { return lcost + rcost; }
+    };
+
+    SplitInfo best_split_static_constant;
+    SplitInfo best_split_static;
+    SplitInfo best_split_nonstatic;
+    SplitInfo best_split_nowp;
+
+    JXL_DASSERT(begin <= end);
+    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
+
+    // Compute the maximum token in the range.
+    size_t max_symbols = 0;
+    for (size_t pred = 0; pred < num_predictors; pred++) {
+      for (size_t i = begin; i < end; i++) {
+        uint32_t tok = tree_samples.Token(pred, i);
+        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
+      }
+    }
+    max_symbols = Padded(max_symbols);
+    std::vector<int32_t> rounded_counts(max_symbols);
+    std::vector<int32_t> counts(max_symbols * num_predictors);
+    std::vector<size_t> tot_extra_bits(num_predictors);
+    for (size_t pred = 0; pred < num_predictors; pred++) {
+      for (size_t i = begin; i < end; i++) {
+        counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
+            tree_samples.Count(i);
+        tot_extra_bits[pred] +=
+            tree_samples.NBits(pred, i) * tree_samples.Count(i);
+      }
+    }
+
+    float base_bits;
+    {
+      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
+      base_bits = EstimateBits(counts.data() + pred * max_symbols,
+                               rounded_counts.data(), max_symbols) +
+                  tot_extra_bits[pred];
+    }
+
+    SplitInfo *best = &best_split_nonstatic;
+
+    SplitInfo forced_split;
+    // The multiplier ranges cut halfway through the current ranges of static
+    // properties. We do this even if the current node is not a leaf, to
+    // minimize the number of nodes in the resulting tree.
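+    // [Editorial note — illustration only, not part of upstream libjxl] E.g.
+    // if this node covers channels [0, 8) while a multiplier applies to
+    // channels [0, 4) only, BoxIntersects() reports a partial intersection
+    // with val == 3, and a split on the quantized channel property at 3 is
+    // forced so that each side sees exactly one multiplier.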
+    for (size_t i = 0; i < mul_info.size(); i++) {
+      uint32_t axis, val;
+      IntersectionType t =
+          BoxIntersects(static_prop_range, mul_info[i].range, axis, val);
+      if (t == IntersectionType::kNone) continue;
+      if (t == IntersectionType::kInside) {
+        (*tree)[pos].multiplier = mul_info[i].multiplier;
+        break;
+      }
+      if (t == IntersectionType::kPartial) {
+        forced_split.val = tree_samples.QuantizeProperty(axis, val);
+        forced_split.prop = axis;
+        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
+        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
+        best = &forced_split;
+        best->pos = begin;
+        JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
+        for (size_t x = begin; x < end; x++) {
+          if (tree_samples.Property(best->prop, x) <= best->val) {
+            best->pos++;
+          }
+        }
+        break;
+      }
+    }
+
+    if (best != &forced_split) {
+      std::vector<size_t> prop_value_used_count;
+      std::vector<size_t> count_increase;
+      std::vector<size_t> extra_bits_increase;
+      // For each property, compute which of its values are used, and what
+      // tokens correspond to those usages. Then, iterate through the values,
+      // and compute the entropy of each side of the split (of the form `prop >
+      // threshold`). Finally, find the split that minimizes the cost.
+      struct CostInfo {
+        float cost = std::numeric_limits<float>::max();
+        float extra_cost = 0;
+        float Cost() const { return cost + extra_cost; }
+        Predictor pred;  // will be uninitialized in some cases, but never used.
+      };
+      std::vector<CostInfo> costs_l;
+      std::vector<CostInfo> costs_r;
+
+      std::vector<int32_t> counts_above(max_symbols);
+      std::vector<int32_t> counts_below(max_symbols);
+
+      // The lower the threshold, the higher the expected noisiness of the
+      // estimate. Thus, discourage changing predictors.
+      float change_pred_penalty = 800.0f / (100.0f + threshold);
+      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
+           prop++) {
+        costs_l.clear();
+        costs_r.clear();
+        size_t prop_size = tree_samples.NumPropertyValues(prop);
+        if (extra_bits_increase.size() < prop_size) {
+          count_increase.resize(prop_size * max_symbols);
+          extra_bits_increase.resize(prop_size);
+        }
+        // Clear prop_value_used_count (which cannot be cleared "on the go")
+        prop_value_used_count.clear();
+        prop_value_used_count.resize(prop_size);
+
+        size_t first_used = prop_size;
+        size_t last_used = 0;
+
+        // TODO(veluca): consider finding multiple splits along a single
+        // property at the same time, possibly with a bottom-up approach.
+        for (size_t i = begin; i < end; i++) {
+          size_t p = tree_samples.Property(prop, i);
+          prop_value_used_count[p]++;
+          last_used = std::max(last_used, p);
+          first_used = std::min(first_used, p);
+        }
+        costs_l.resize(last_used - first_used);
+        costs_r.resize(last_used - first_used);
+        // For all predictors, compute the right and left costs of each split.
+        for (size_t pred = 0; pred < num_predictors; pred++) {
+          // Compute cost and histogram increments for each property value.
+          for (size_t i = begin; i < end; i++) {
+            size_t p = tree_samples.Property(prop, i);
+            size_t cnt = tree_samples.Count(i);
+            size_t sym = tree_samples.Token(pred, i);
+            count_increase[p * max_symbols + sym] += cnt;
+            extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
+          }
+          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
+                 max_symbols * sizeof counts_above[0]);
+          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
+          size_t extra_bits_below = 0;
+          // Exclude last used: this ensures neither counts_above nor
+          // counts_below is empty.
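+          // [Editorial note — illustration only, not part of upstream libjxl]
+          // The sweep below moves one property bucket at a time from
+          // counts_above into counts_below, so each candidate split
+          // "property > i" is costed in O(max_symbols) without re-scanning
+          // the samples.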
+          for (size_t i = first_used; i < last_used; i++) {
+            if (!prop_value_used_count[i]) continue;
+            extra_bits_below += extra_bits_increase[i];
+            // The increase for this property value has been used, and will not
+            // be used again: clear it. Also below.
+            extra_bits_increase[i] = 0;
+            for (size_t sym = 0; sym < max_symbols; sym++) {
+              counts_above[sym] -= count_increase[i * max_symbols + sym];
+              counts_below[sym] += count_increase[i * max_symbols + sym];
+              count_increase[i * max_symbols + sym] = 0;
+            }
+            float rcost = EstimateBits(counts_above.data(),
+                                       rounded_counts.data(), max_symbols) +
+                          tot_extra_bits[pred] - extra_bits_below;
+            float lcost = EstimateBits(counts_below.data(),
+                                       rounded_counts.data(), max_symbols) +
+                          extra_bits_below;
+            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
+            float penalty = 0;
+            // Never discourage moving away from the Weighted predictor.
+            if (tree_samples.PredictorFromIndex(pred) !=
+                    (*tree)[pos].predictor &&
+                (*tree)[pos].predictor != Predictor::Weighted) {
+              penalty = change_pred_penalty;
+            }
+            // If everything else is equal, disfavour Weighted (slower) and
+            // favour Zero (faster if it's the only predictor used in a
+            // group+channel combination)
+            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
+              penalty += 1e-8;
+            }
+            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
+              penalty -= 1e-8;
+            }
+            if (rcost + penalty < costs_r[i - first_used].Cost()) {
+              costs_r[i - first_used].cost = rcost;
+              costs_r[i - first_used].extra_cost = penalty;
+              costs_r[i - first_used].pred =
+                  tree_samples.PredictorFromIndex(pred);
+            }
+            if (lcost + penalty < costs_l[i - first_used].Cost()) {
+              costs_l[i - first_used].cost = lcost;
+              costs_l[i - first_used].extra_cost = penalty;
+              costs_l[i - first_used].pred =
+                  tree_samples.PredictorFromIndex(pred);
+            }
+          }
+        }
+        // Iterate through the possible splits and find the one with minimum
+        // sum of costs of the two sides.
+        size_t split = begin;
+        for (size_t i = first_used; i < last_used; i++) {
+          if (!prop_value_used_count[i]) continue;
+          split += prop_value_used_count[i];
+          float rcost = costs_r[i - first_used].cost;
+          float lcost = costs_l[i - first_used].cost;
+          // WP was not used + we would use the WP property or predictor
+          bool adds_wp =
+              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
+               (used_properties & (1LU << prop)) == 0) ||
+              ((costs_l[i - first_used].pred == Predictor::Weighted ||
+                costs_r[i - first_used].pred == Predictor::Weighted) &&
+               (*tree)[pos].predictor != Predictor::Weighted);
+          bool zero_entropy_side = rcost == 0 || lcost == 0;
+
+          SplitInfo &best =
+              prop < kNumStaticProperties
+                  ? (zero_entropy_side ? best_split_static_constant
+                                       : best_split_static)
+                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
+          if (lcost + rcost < best.Cost()) {
+            best.prop = prop;
+            best.val = i;
+            best.pos = split;
+            best.lcost = lcost;
+            best.lpred = costs_l[i - first_used].pred;
+            best.rcost = rcost;
+            best.rpred = costs_r[i - first_used].pred;
+          }
+        }
+        // Clear extra_bits_increase and cost_increase for last_used.
+        extra_bits_increase[last_used] = 0;
+        for (size_t sym = 0; sym < max_symbols; sym++) {
+          count_increase[last_used * max_symbols + sym] = 0;
+        }
+      }
+
+      // Try to avoid introducing WP.
+      if (best_split_nowp.Cost() + threshold < base_bits &&
+          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
+        best = &best_split_nowp;
+      }
+      // Split along static props if possible and not significantly more
+      // expensive.
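+      // [Editorial note — illustration only, not part of upstream libjxl]
+      // E.g. with fast_decode_multiplier == 1.01, a static split is accepted
+      // whenever it costs at most 1% more bits than the best dynamic split,
+      // trading stream size for decode speed.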
+      if (best_split_static.Cost() + threshold < base_bits &&
+          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
+        best = &best_split_static;
+      }
+      // Split along static props to create constant nodes if possible.
+      if (best_split_static_constant.Cost() + threshold < base_bits) {
+        best = &best_split_static_constant;
+      }
+    }
+
+    if (best->Cost() + threshold < base_bits) {
+      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
+      pixel_type dequant =
+          tree_samples.UnquantizeProperty(best->prop, best->val);
+      // Split node and try to split children.
+      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
+      // "Sort" according to winning property
+      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop);
+      if (p >= kNumStaticProperties) {
+        used_properties |= 1 << best->prop;
+      }
+      auto new_sp_range = static_prop_range;
+      if (p < kNumStaticProperties) {
+        JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
+        new_sp_range[p][1] = dequant + 1;
+        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
+      }
+      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
+                               used_properties, new_sp_range});
+      new_sp_range = static_prop_range;
+      if (p < kNumStaticProperties) {
+        JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
+        new_sp_range[p][0] = dequant + 1;
+        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
+      }
+      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
+                               used_properties, new_sp_range});
+    }
+  }
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+HWY_EXPORT(FindBestSplit);  // Local function.
+
+void ComputeBestTree(TreeSamples &tree_samples, float threshold,
+                     const std::vector<ModularMultiplierInfo> &mul_info,
+                     StaticPropRange static_prop_range,
+                     float fast_decode_multiplier, Tree *tree) {
+  // TODO(veluca): take into account that different contexts can have different
+  // uint configs.
+  //
+  // Initialize tree.
+  tree->emplace_back();
+  tree->back().property = -1;
+  tree->back().predictor = tree_samples.PredictorFromIndex(0);
+  tree->back().predictor_offset = 0;
+  tree->back().multiplier = 1;
+  JXL_ASSERT(tree_samples.NumProperties() < 64);
+
+  JXL_ASSERT(tree_samples.NumDistinctSamples() <=
+             std::numeric_limits<uint32_t>::max());
+  HWY_DYNAMIC_DISPATCH(FindBestSplit)
+  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
+   tree);
+}
+
+constexpr int TreeSamples::kPropertyRange;
+constexpr uint32_t TreeSamples::kDedupEntryUnused;
+
+Status TreeSamples::SetPredictor(Predictor predictor,
+                                 ModularOptions::WPTreeMode wp_tree_mode) {
+  if (wp_tree_mode == ModularOptions::WPTreeMode::kWPOnly) {
+    predictors = {Predictor::Weighted};
+    residuals.resize(1);
+    return true;
+  }
+  if (wp_tree_mode == ModularOptions::WPTreeMode::kNoWP &&
+      predictor == Predictor::Weighted) {
+    return JXL_FAILURE("Invalid predictor settings");
+  }
+  if (predictor == Predictor::Variable) {
+    for (size_t i = 0; i < kNumModularPredictors; i++) {
+      predictors.push_back(static_cast<Predictor>(i));
+    }
+    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
+    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
+  } else if (predictor == Predictor::Best) {
+    predictors = {Predictor::Weighted, Predictor::Gradient};
+  } else {
+    predictors = {predictor};
+  }
+  if (wp_tree_mode == ModularOptions::WPTreeMode::kNoWP) {
+    auto wp_it =
+        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
+    if (wp_it != predictors.end()) {
+      predictors.erase(wp_it);
+    }
+  }
+  residuals.resize(predictors.size());
+  return true;
+}
+
+Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
+                                  ModularOptions::WPTreeMode wp_tree_mode) {
+  props_to_use = properties;
+  if (wp_tree_mode == ModularOptions::WPTreeMode::kWPOnly) {
+    props_to_use = {kWPProp};
+  }
+  if (wp_tree_mode == ModularOptions::WPTreeMode::kNoWP) {
+    auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
+    if (it != props_to_use.end()) {
+      props_to_use.erase(it);
+    }
+  }
+  if (props_to_use.empty()) {
+    return JXL_FAILURE("Invalid property set configuration");
+  }
+  props.resize(props_to_use.size());
+  return true;
+}
+
+void TreeSamples::InitTable(size_t size) {
+  JXL_DASSERT((size & (size - 1)) == 0);
+  if (dedup_table_.size() == size) return;
+  dedup_table_.resize(size, kDedupEntryUnused);
+  for (size_t i = 0; i < NumDistinctSamples(); i++) {
+    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
+      AddToTable(i);
+    }
+  }
+}
+
+bool TreeSamples::AddToTableAndMerge(size_t a) {
+  size_t pos1 = Hash1(a);
+  size_t pos2 = Hash2(a);
+  if (dedup_table_[pos1] != kDedupEntryUnused &&
+      IsSameSample(a, dedup_table_[pos1])) {
+    JXL_DASSERT(sample_counts[a] == 1);
+    sample_counts[dedup_table_[pos1]]++;
+    // Remove from hash table samples that are saturated.
+    if (sample_counts[dedup_table_[pos1]] ==
+        std::numeric_limits<uint16_t>::max()) {
+      dedup_table_[pos1] = kDedupEntryUnused;
+    }
+    return true;
+  }
+  if (dedup_table_[pos2] != kDedupEntryUnused &&
+      IsSameSample(a, dedup_table_[pos2])) {
+    JXL_DASSERT(sample_counts[a] == 1);
+    sample_counts[dedup_table_[pos2]]++;
+    // Remove from hash table samples that are saturated.
+    if (sample_counts[dedup_table_[pos2]] ==
+        std::numeric_limits<uint16_t>::max()) {
+      dedup_table_[pos2] = kDedupEntryUnused;
+    }
+    return true;
+  }
+  AddToTable(a);
+  return false;
+}
+
+void TreeSamples::AddToTable(size_t a) {
+  size_t pos1 = Hash1(a);
+  size_t pos2 = Hash2(a);
+  if (dedup_table_[pos1] == kDedupEntryUnused) {
+    dedup_table_[pos1] = a;
+  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
+    dedup_table_[pos2] = a;
+  }
+}
+
+void TreeSamples::PrepareForSamples(size_t num_samples) {
+  for (auto &res : residuals) {
+    res.reserve(res.size() + num_samples);
+  }
+  for (auto &p : props) {
+    p.reserve(p.size() + num_samples);
+  }
+  size_t total_num_samples = num_samples + sample_counts.size();
+  size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2);
+  InitTable(next_pow2);
+}
+
+size_t TreeSamples::Hash1(size_t a) const {
+  constexpr uint64_t constant = 0x1e35a7bd;
+  uint64_t h = constant;
+  for (const auto &r : residuals) {
+    h = h * constant + r[a].tok;
+    h = h * constant + r[a].nbits;
+  }
+  for (const auto &p : props) {
+    h = h * constant + p[a];
+  }
+  return (h >> 16) & (dedup_table_.size() - 1);
+}
+size_t TreeSamples::Hash2(size_t a) const {
+  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
+  uint64_t h = constant;
+  for (const auto &p : props) {
+    h = h * constant ^ p[a];
+  }
+  for (const auto &r : residuals) {
+    h = h * constant ^ r[a].tok;
+    h = h * constant ^ r[a].nbits;
+  }
+  return (h >> 16) & (dedup_table_.size() - 1);
+}
+
+bool TreeSamples::IsSameSample(size_t a, size_t b) const {
+  bool ret = true;
+  for (const auto &r : residuals) {
+    if (r[a].tok != r[b].tok) {
+      ret = false;
+    }
+    if (r[a].nbits != r[b].nbits) {
+      ret = false;
+    }
+  }
+  for (const auto &p : props) {
+    if (p[a] != p[b]) {
+      ret = false;
+    }
+  }
+  return ret;
+}
+
+void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
+                            const pixel_type_w *predictions) {
+  for (size_t i = 0; i < predictors.size(); i++) {
+    pixel_type v = pixel - predictions[static_cast<int>(predictors[i])];
+    uint32_t tok, nbits, bits;
+    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
+    JXL_DASSERT(tok < 256);
+    JXL_DASSERT(nbits < 256);
+    residuals[i].emplace_back(
+        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
+  }
+  for (size_t i = 0; i < props_to_use.size(); i++) {
+    props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
+  }
+  sample_counts.push_back(1);
+  num_samples++;
+  if (AddToTableAndMerge(sample_counts.size() - 1)) {
+    for (auto &r : residuals) r.pop_back();
+    for (auto &p : props) p.pop_back();
+    sample_counts.pop_back();
+  }
+}
+
+void TreeSamples::Swap(size_t a, size_t b) {
+  if (a == b) return;
+  for (auto &r : residuals) {
+    std::swap(r[a], r[b]);
+  }
+  for (auto &p : props) {
+    std::swap(p[a], p[b]);
+  }
+  std::swap(sample_counts[a], sample_counts[b]);
+}
+
+void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) {
+  if (b == c) return Swap(a, b);
+  for (auto &r : residuals) {
+    auto tmp = r[a];
+    r[a] = r[c];
+    r[c] = r[b];
+    r[b] = tmp;
+  }
+  for (auto &p : props) {
+    auto tmp = p[a];
+    p[a] = p[c];
+    p[c] = p[b];
+    p[b] = tmp;
+  }
+  auto tmp = sample_counts[a];
+  sample_counts[a] = sample_counts[c];
+  sample_counts[c] = sample_counts[b];
+  sample_counts[b] = tmp;
+}
+
+namespace {
+std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
+                                       size_t num_chunks) {
+  if (histogram.empty()) return {};
+  // TODO(veluca): selecting distinct quantiles is likely not the best
+  // way to go about this.
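+  // [Editorial note — illustration only, not part of upstream libjxl] For
+  // histogram {5, 1, 1, 5} and num_chunks == 2 (sum == 12), cumsum first
+  // exceeds 6 at bin 2, so the result is the single threshold {2}, i.e. the
+  // property is split into "value <= 2" and "value > 2".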
+  std::vector<int32_t> thresholds;
+  size_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
+  size_t cumsum = 0;
+  size_t threshold = 0;
+  for (size_t i = 0; i + 1 < histogram.size(); i++) {
+    cumsum += histogram[i];
+    if (cumsum > (threshold + 1) * sum / num_chunks) {
+      thresholds.push_back(i);
+      while (cumsum >= (threshold + 1) * sum / num_chunks) threshold++;
+    }
+  }
+  return thresholds;
+}
+
+std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
+                                     size_t num_chunks) {
+  if (samples.empty()) return {};
+  int min = *std::min_element(samples.begin(), samples.end());
+  constexpr int kRange = 512;
+  min = std::min(std::max(min, -kRange), kRange);
+  std::vector<uint32_t> counts(2 * kRange + 1);
+  for (int s : samples) {
+    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
+    counts[sample_offset]++;
+  }
+  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
+  for (auto &v : thresholds) v += min;
+  return thresholds;
+}
+}  // namespace
+
+void TreeSamples::PreQuantizeProperties(
+    const StaticPropRange &range,
+    const std::vector<ModularMultiplierInfo> &multiplier_info,
+    const std::vector<uint32_t> &group_pixel_count,
+    const std::vector<uint32_t> &channel_pixel_count,
+    std::vector<pixel_type> &pixel_samples,
+    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
+  // If we have forced splits because of multipliers, choose channel and group
+  // thresholds accordingly.
+  std::vector<int32_t> group_multiplier_thresholds;
+  std::vector<int32_t> channel_multiplier_thresholds;
+  for (const auto &v : multiplier_info) {
+    if (v.range[0][0] != range[0][0]) {
+      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
+    }
+    if (v.range[0][1] != range[0][1]) {
+      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
+    }
+    if (v.range[1][0] != range[1][0]) {
+      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
+    }
+    if (v.range[1][1] != range[1][1]) {
+      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
+    }
+  }
+  std::sort(channel_multiplier_thresholds.begin(),
+            channel_multiplier_thresholds.end());
+  channel_multiplier_thresholds.resize(
+      std::unique(channel_multiplier_thresholds.begin(),
+                  channel_multiplier_thresholds.end()) -
+      channel_multiplier_thresholds.begin());
+  std::sort(group_multiplier_thresholds.begin(),
+            group_multiplier_thresholds.end());
+  group_multiplier_thresholds.resize(
+      std::unique(group_multiplier_thresholds.begin(),
+                  group_multiplier_thresholds.end()) -
+      group_multiplier_thresholds.begin());
+
+  compact_properties.resize(props_to_use.size());
+  auto quantize_channel = [&]() {
+    if (!channel_multiplier_thresholds.empty()) {
+      return channel_multiplier_thresholds;
+    }
+    return QuantizeHistogram(channel_pixel_count, max_property_values);
+  };
+  auto quantize_group_id = [&]() {
+    if (!group_multiplier_thresholds.empty()) {
+      return group_multiplier_thresholds;
+    }
+    return QuantizeHistogram(group_pixel_count, max_property_values);
+  };
+  auto quantize_coordinate = [&]() {
+    std::vector<int32_t> quantized;
+    quantized.reserve(max_property_values - 1);
+    for (size_t i = 0; i + 1 < max_property_values; i++) {
+      quantized.push_back((i + 1) * 256 / max_property_values - 1);
+    }
+    return quantized;
+  };
+  std::vector<int32_t> abs_pixel_thr;
+  std::vector<int32_t> pixel_thr;
+  auto quantize_pixel_property = [&]() {
+    if (pixel_thr.empty()) {
+      pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
+    }
+    return pixel_thr;
+  };
+  auto quantize_abs_pixel_property = [&]() {
+    if (abs_pixel_thr.empty()) {
+      quantize_pixel_property();  // Compute the non-abs thresholds.
+      for (auto &v : pixel_samples) v = std::abs(v);
+      abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
+    }
+    return abs_pixel_thr;
+  };
+  std::vector<int32_t> abs_diff_thr;
+  std::vector<int32_t> diff_thr;
+  auto quantize_diff_property = [&]() {
+    if (diff_thr.empty()) {
+      diff_thr = QuantizeSamples(diff_samples, max_property_values);
+    }
+    return diff_thr;
+  };
+  auto quantize_abs_diff_property = [&]() {
+    if (abs_diff_thr.empty()) {
+      quantize_diff_property();  // Compute the non-abs thresholds.
+      for (auto &v : diff_samples) v = std::abs(v);
+      abs_diff_thr = QuantizeSamples(diff_samples, max_property_values);
+    }
+    return abs_diff_thr;
+  };
+  auto quantize_wp = [&]() {
+    if (max_property_values < 32) {
+      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
+                                  1,    3,   7,   15,  31, 63, 127};
+    }
+    if (max_property_values < 64) {
+      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
+                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
+                                  3,    5,    7,    11,  15,  23,  31,  47,
+                                  63,   95,   127,  191, 255};
+    }
+    return std::vector<int32_t>{
+        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
+        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
+        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
+        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
+        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
+  };
+
+  property_mapping.resize(props_to_use.size());
+  for (size_t i = 0; i < props_to_use.size(); i++) {
+    if (props_to_use[i] == 0) {
+      compact_properties[i] = quantize_channel();
+    } else if (props_to_use[i] == 1) {
+      compact_properties[i] = quantize_group_id();
+    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
+      compact_properties[i] = quantize_coordinate();
+    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
+               props_to_use[i] == 8 ||
+               (props_to_use[i] >= kNumNonrefProperties &&
+                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
+      compact_properties[i] = quantize_pixel_property();
+    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
+               (props_to_use[i] >= kNumNonrefProperties &&
+                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
+      compact_properties[i] = quantize_abs_pixel_property();
+    } else if (props_to_use[i] >= kNumNonrefProperties &&
+               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
+      compact_properties[i] = quantize_abs_diff_property();
+    } else if (props_to_use[i] == kWPProp) {
+      compact_properties[i] = quantize_wp();
+    } else {
+      compact_properties[i] = quantize_diff_property();
+    }
+    property_mapping[i].resize(kPropertyRange * 2 + 1);
+    size_t mapped = 0;
+    for (size_t j = 0; j < property_mapping[i].size(); j++) {
+      while (mapped < compact_properties[i].size() &&
+             static_cast<int>(j) - kPropertyRange >
+                 compact_properties[i][mapped]) {
+        mapped++;
+      }
+      // property_mapping[i] of a value V is `mapped` if
+      // compact_properties[i][mapped] <= j and
+      // compact_properties[i][mapped-1] > j
+      // This is because the decision node in the tree splits on (property) > j,
+      // hence everything that is not > of a threshold should be clustered
+      // together.
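+      // [Editorial note — illustration only, not part of upstream libjxl]
+      // E.g. with thresholds {-2, 3}, a value v maps to the number of
+      // thresholds strictly below v: v == -2 -> 0, v == 0 -> 1, v == 4 -> 2.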
+      property_mapping[i][j] = mapped;
+    }
+  }
+}
+
+void CollectPixelSamples(const Image &image, const ModularOptions &options,
+                         size_t group_id,
+                         std::vector<uint32_t> &group_pixel_count,
+                         std::vector<uint32_t> &channel_pixel_count,
+                         std::vector<pixel_type> &pixel_samples,
+                         std::vector<pixel_type> &diff_samples) {
+  if (group_pixel_count.size() <= group_id) {
+    group_pixel_count.resize(group_id + 1);
+  }
+  if (channel_pixel_count.size() < image.channel.size()) {
+    channel_pixel_count.resize(image.channel.size());
+  }
+  Rng rng(group_id);
+  // Sample 10% of the final number of samples for property quantization.
+  float fraction = options.nb_repeats * 0.1;
+  std::geometric_distribution<uint32_t> dist(fraction);
+  size_t total_pixels = 0;
+  std::vector<size_t> channel_ids;
+  for (size_t i = 0; i < image.channel.size(); i++) {
+    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
+      continue;  // skip empty or width-1 channels.
+    }
+    if (i >= image.nb_meta_channels &&
+        (image.channel[i].w > options.max_chan_size ||
+         image.channel[i].h > options.max_chan_size)) {
+      break;
+    }
+    channel_ids.push_back(i);
+    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
+    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
+    total_pixels += image.channel[i].w * image.channel[i].h;
+  }
+  if (channel_ids.empty()) return;
+  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
+  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
+  size_t i = 0;
+  size_t y = 0;
+  size_t x = 0;
+  auto advance = [&](size_t amount) {
+    x += amount;
+    // Detect row overflow (rare).
+    while (x >= image.channel[channel_ids[i]].w) {
+      x -= image.channel[channel_ids[i]].w;
+      y++;
+      // Detect end-of-channel (even rarer).
+      if (y == image.channel[channel_ids[i]].h) {
+        i++;
+        y = 0;
+        if (i >= channel_ids.size()) {
+          return;
+        }
+      }
+    }
+  };
+  advance(dist(rng));
+  for (; i < channel_ids.size(); advance(dist(rng) + 1)) {
+    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
+    pixel_samples.push_back(row[x]);
+    size_t xp = x == 0 ? 1 : x - 1;
+    diff_samples.push_back(row[x] - row[xp]);
+  }
+}
+
+// TODO(veluca): very simple encoding scheme. This should be improved.
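+// [Editorial note — illustration only, not part of upstream libjxl] A
+// one-split tree is emitted breadth-first: (property + 1, splitval) for the
+// root, then for each of the two leaves a zero property token followed by
+// predictor, packed offset, multiplier log and multiplier bits.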
+void TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
+                  Tree *decoder_tree) {
+  JXL_ASSERT(tree.size() <= kMaxTreeSize);
+  std::queue<int> q;
+  q.push(0);
+  size_t leaf_id = 0;
+  decoder_tree->clear();
+  while (!q.empty()) {
+    int cur = q.front();
+    q.pop();
+    JXL_ASSERT(tree[cur].property >= -1);
+    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
+    if (tree[cur].property == -1) {
+      tokens->emplace_back(kPredictorContext,
+                           static_cast<int>(tree[cur].predictor));
+      tokens->emplace_back(kOffsetContext,
+                           PackSigned(tree[cur].predictor_offset));
+      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
+      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
+      tokens->emplace_back(kMultiplierLogContext, mul_log);
+      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
+      JXL_ASSERT(tree[cur].predictor < Predictor::Best);
+      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
+                                 tree[cur].predictor_offset,
+                                 tree[cur].multiplier);
+      continue;
+    }
+    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
+                               decoder_tree->size() + q.size() + 1,
+                               decoder_tree->size() + q.size() + 2,
+                               Predictor::Zero, 0, 1);
+    q.push(tree[cur].lchild);
+    q.push(tree[cur].rchild);
+    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
+  }
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h
new file mode 100644
index 000000000000..f0bd5d0fb7a9
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h
@@ -0,0 +1,166 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_ENC_MA_H_
+#define LIB_JXL_MODULAR_ENCODING_ENC_MA_H_
+
+#include <numeric>
+
+#include "lib/jxl/enc_ans.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/modular/encoding/dec_ma.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/options.h"
+
+namespace jxl {
+
+// Struct to collect all the data needed to build a tree.
+struct TreeSamples {
+  bool HasSamples() const {
+    return !residuals.empty() && !residuals[0].empty();
+  }
+  size_t NumDistinctSamples() const { return sample_counts.size(); }
+  size_t NumSamples() const { return num_samples; }
+  // Set the predictor to use. Must be called before adding any samples.
+  Status SetPredictor(Predictor predictor,
+                      ModularOptions::WPTreeMode wp_tree_mode);
+  // Set the properties to use. Must be called before adding any samples.
+  Status SetProperties(const std::vector<uint32_t> &properties,
+                       ModularOptions::WPTreeMode wp_tree_mode);
+
+  size_t Token(size_t pred, size_t i) const { return residuals[pred][i].tok; }
+  size_t NBits(size_t pred, size_t i) const {
+    return residuals[pred][i].nbits;
+  }
+  size_t Count(size_t i) const { return sample_counts[i]; }
+  size_t PredictorIndex(Predictor predictor) const {
+    const auto predictor_elem =
+        std::find(predictors.begin(), predictors.end(), predictor);
+    JXL_DASSERT(predictor_elem != predictors.end());
+    return predictor_elem - predictors.begin();
+  }
+  size_t PropertyIndex(size_t property) const {
+    const auto property_elem =
+        std::find(props_to_use.begin(), props_to_use.end(), property);
+    JXL_DASSERT(property_elem != props_to_use.end());
+    return property_elem - props_to_use.begin();
+  }
+  size_t NumPropertyValues(size_t property_index) const {
+    return compact_properties[property_index].size() + 1;
+  }
+  // Returns the *quantized* property value.
+  size_t Property(size_t property_index, size_t i) const {
+    return props[property_index][i];
+  }
+  int UnquantizeProperty(size_t property_index, uint32_t quant) const {
+    JXL_ASSERT(quant < compact_properties[property_index].size());
+    return compact_properties[property_index][quant];
+  }
+
+  Predictor PredictorFromIndex(size_t index) const {
+    JXL_DASSERT(index < predictors.size());
+    return predictors[index];
+  }
+  size_t PropertyFromIndex(size_t index) const {
+    JXL_DASSERT(index < props_to_use.size());
+    return props_to_use[index];
+  }
+  size_t NumPredictors() const { return predictors.size(); }
+  size_t NumProperties() const { return props_to_use.size(); }
+
+  // Preallocate data for a given number of samples. MUST be called before
+  // adding any sample.
+  void PrepareForSamples(size_t num_samples);
+  // Add a sample.
+  void AddSample(pixel_type_w pixel, const Properties &properties,
+                 const pixel_type_w *predictions);
+  // Pre-cluster property values.
+  void PreQuantizeProperties(
+      const StaticPropRange &range,
+      const std::vector<ModularMultiplierInfo> &multiplier_info,
+      const std::vector<uint32_t> &group_pixel_count,
+      const std::vector<uint32_t> &channel_pixel_count,
+      std::vector<pixel_type> &pixel_samples,
+      std::vector<pixel_type> &diff_samples, size_t max_property_values);
+
+  void AllSamplesDone() { dedup_table_ = std::vector<uint32_t>(); }
+
+  uint32_t QuantizeProperty(uint32_t prop, pixel_type v) const {
+    v = std::min(std::max(v, -kPropertyRange), kPropertyRange) + kPropertyRange;
+    return property_mapping[prop][v];
+  }
+
+  // Swaps samples in position a and b. Does nothing if a == b.
+  void Swap(size_t a, size_t b);
+
+  // Cycles samples: a -> b -> c -> a. We assume a <= b <= c, so that we can
+  // just call Swap(a, b) if b==c.
+  void ThreeShuffle(size_t a, size_t b, size_t c);
+
+ private:
+  // TODO(veluca): as the total number of properties and predictors are known
+  // before adding any samples, it might be better to interleave predictors,
+  // properties and counts in a single vector to improve locality.
+  // A first attempt at doing this actually results in much slower encoding,
+  // possibly because of the more complex addressing.
+  struct ResidualToken {
+    uint8_t tok;
+    uint8_t nbits;
+  };
+  // Residual information: token and number of extra bits, per predictor.
+  std::vector<std::vector<ResidualToken>> residuals;
+  // Number of occurrences of each sample.
+  std::vector<uint16_t> sample_counts;
+  // Property values, quantized to at most 256 distinct values.
+  std::vector<std::vector<uint8_t>> props;
+  // Decompactification info for `props`.
+  std::vector<std::vector<int32_t>> compact_properties;
+  // List of properties to use.
+  std::vector<uint32_t> props_to_use;
+  // List of predictors to use.
+  std::vector<Predictor> predictors;
+  // Mapping property value -> quantized property value.
+  static constexpr int kPropertyRange = 511;
+  std::vector<std::vector<uint8_t>> property_mapping;
+  // Number of samples seen.
+  size_t num_samples = 0;
+  // Table for deduplication.
+  static constexpr uint32_t kDedupEntryUnused{static_cast<uint32_t>(-1)};
+  std::vector<uint32_t> dedup_table_;
+
+  // Functions for sample deduplication.
+  bool IsSameSample(size_t a, size_t b) const;
+  size_t Hash1(size_t a) const;
+  size_t Hash2(size_t a) const;
+  void InitTable(size_t size);
+  // Returns true if `a` was already present in the table.
+  bool AddToTableAndMerge(size_t a);
+  void AddToTable(size_t a);
+};
+
+void TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
+                  Tree *decoder_tree);
+
+void CollectPixelSamples(const Image &image, const ModularOptions &options,
+                         size_t group_id,
+                         std::vector<uint32_t> &group_pixel_count,
+                         std::vector<uint32_t> &channel_pixel_count,
+                         std::vector<pixel_type> &pixel_samples,
+                         std::vector<pixel_type> &diff_samples);
+
+void ComputeBestTree(TreeSamples &tree_samples, float threshold,
+                     const std::vector<ModularMultiplierInfo> &mul_info,
+                     StaticPropRange static_prop_range,
+                     float fast_decode_multiplier, Tree *tree);
+
+}  // namespace jxl
+#endif  // LIB_JXL_MODULAR_ENCODING_ENC_MA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc
new file mode 100644
index 000000000000..d8f92b0fc4e0
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc
@@ -0,0 +1,509 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/modular/encoding/encoding.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/options.h"
+#include "lib/jxl/modular/transform/transform.h"
+#include "lib/jxl/toc.h"
+
+namespace jxl {
+
+// Removes all nodes that use a static property (i.e. channel or group ID) from
+// the tree and collapses each node on even levels with its two children to
+// produce a flatter tree. Also computes whether the resulting tree requires
+// using the weighted predictor.
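+// [Editorial note — illustration only, not part of upstream libjxl] E.g. a
+// decision on the channel property is resolved here, at decode time: for
+// channel 3 the "channel > 2" branch is taken unconditionally and the other
+// subtree is dropped, so the flat tree for a given channel and group contains
+// no static-property nodes.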
+FlatTree FilterTree(const Tree &global_tree,
+                    std::array<pixel_type, kNumStaticProperties> &static_props,
+                    size_t *num_props, bool *use_wp, bool *wp_only,
+                    bool *gradient_only) {
+  *num_props = 0;
+  bool has_wp = false;
+  bool has_non_wp = false;
+  *gradient_only = true;
+  const auto mark_property = [&](int32_t p) {
+    if (p == kWPProp) {
+      has_wp = true;
+    } else if (p >= kNumStaticProperties) {
+      has_non_wp = true;
+    }
+    if (p >= kNumStaticProperties && p != kGradientProp) {
+      *gradient_only = false;
+    }
+  };
+  FlatTree output;
+  std::queue<size_t> nodes;
+  nodes.push(0);
+  // Produces a trimmed and flattened tree by doing a BFS visit of the original
+  // tree, ignoring branches that are known to be false and proceeding two
+  // levels at a time to collapse nodes in a flatter tree; if an inner parent
+  // node has a leaf as a child, the leaf is duplicated and an implicit fake
+  // node is added. This allows reducing the number of branches when
+  // traversing the resulting flat tree.
+  while (!nodes.empty()) {
+    size_t cur = nodes.front();
+    nodes.pop();
+    // Skip nodes that we can decide now, by jumping directly to their
+    // children.
+    while (global_tree[cur].property < kNumStaticProperties &&
+           global_tree[cur].property != -1) {
+      if (static_props[global_tree[cur].property] > global_tree[cur].splitval) {
+        cur = global_tree[cur].lchild;
+      } else {
+        cur = global_tree[cur].rchild;
+      }
+    }
+    FlatDecisionNode flat;
+    if (global_tree[cur].property == -1) {
+      flat.property0 = -1;
+      flat.childID = global_tree[cur].lchild;
+      flat.predictor = global_tree[cur].predictor;
+      flat.predictor_offset = global_tree[cur].predictor_offset;
+      flat.multiplier = global_tree[cur].multiplier;
+      *gradient_only &= flat.predictor == Predictor::Gradient;
+      has_wp |= flat.predictor == Predictor::Weighted;
+      has_non_wp |= flat.predictor != Predictor::Weighted;
+      output.push_back(flat);
+      continue;
+    }
+    flat.childID = output.size() + nodes.size() + 1;
+
+    flat.property0 = global_tree[cur].property;
+    *num_props = std::max<size_t>(flat.property0 + 1, *num_props);
+    flat.splitval0 = global_tree[cur].splitval;
+
+    for (size_t i = 0; i < 2; i++) {
+      size_t cur_child =
+          i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild;
+      // Skip nodes that we can decide now.
+      while (global_tree[cur_child].property < kNumStaticProperties &&
+             global_tree[cur_child].property != -1) {
+        if (static_props[global_tree[cur_child].property] >
+            global_tree[cur_child].splitval) {
+          cur_child = global_tree[cur_child].lchild;
+        } else {
+          cur_child = global_tree[cur_child].rchild;
+        }
+      }
+      // We ended up in a leaf, add a dummy decision and two copies of the
+      // leaf.
+      if (global_tree[cur_child].property == -1) {
+        flat.properties[i] = 0;
+        flat.splitvals[i] = 0;
+        nodes.push(cur_child);
+        nodes.push(cur_child);
+      } else {
+        flat.properties[i] = global_tree[cur_child].property;
+        flat.splitvals[i] = global_tree[cur_child].splitval;
+        nodes.push(global_tree[cur_child].lchild);
+        nodes.push(global_tree[cur_child].rchild);
+        *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props);
+      }
+    }
+
+    for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]);
+    mark_property(flat.property0);
+    output.push_back(flat);
+  }
+  if (*num_props > kNumNonrefProperties) {
+    *num_props =
+        DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) *
+            kExtraPropsPerChannel +
+        kNumNonrefProperties;
+  } else {
+    *num_props = kNumNonrefProperties;
+  }
+  *use_wp = has_wp;
+  *wp_only = has_wp && !has_non_wp;
+
+  return output;
+}
+
+Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
+                                 const std::vector<uint8_t> &context_map,
+                                 const Tree &global_tree,
+                                 const weighted::Header &wp_header,
+                                 pixel_type chan, size_t group_id,
+                                 Image *image) {
+  Channel &channel = image->channel[chan];
+
+  std::array<pixel_type, kNumStaticProperties> static_props = {chan,
+                                                               (int)group_id};
+  // TODO(veluca): filter the tree according to static_props.
+
+  // zero pixel channel? could happen
+  if (channel.w == 0 || channel.h == 0) return true;
+
+  channel.resize(channel.w, channel.h);
+  bool tree_has_wp_prop_or_pred = false;
+  bool is_wp_only = false;
+  bool is_gradient_only = false;
+  size_t num_props;
+  FlatTree tree =
+      FilterTree(global_tree, static_props, &num_props,
+                 &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only);
+
+  // From here on, tree lookup returns a *clustered* context ID.
+  // This avoids an extra memory lookup after tree traversal.
+  for (size_t i = 0; i < tree.size(); i++) {
+    if (tree[i].property0 == -1) {
+      tree[i].childID = context_map[tree[i].childID];
+    }
+  }
+
+  JXL_DEBUG_V(3, "Decoded MA tree with %zu nodes", tree.size());
+
+  // MAANS decode
+
+  // Check if this tree is a WP-only tree with a small enough property value
+  // range.
+  // Those contexts are *clustered* context ids. This reduces stack usages and
+  // avoids an extra memory lookup.
+  // Initialized to avoid clang-tidy complaining.
+  uint8_t context_lookup[2 * kPropRangeFast] = {};
+  int8_t multipliers[2 * kPropRangeFast] = {};
+  int8_t offsets[2 * kPropRangeFast] = {};
+  if (is_wp_only) {
+    is_wp_only = TreeToLookupTable(tree, context_lookup, offsets, multipliers);
+  }
+  if (is_gradient_only) {
+    is_gradient_only =
+        TreeToLookupTable(tree, context_lookup, offsets, multipliers);
+  }
+
+  const auto make_pixel = [](uint64_t v, pixel_type multiplier,
+                             pixel_type_w offset) -> pixel_type {
+    JXL_DASSERT((v & 0xFFFFFFFF) == v);
+    pixel_type_w val = UnpackSigned(v);
+    return SaturatingAdd(val * multiplier, offset);
+  };
+
+  if (is_gradient_only) {
+    JXL_DEBUG_V(8, "Gradient fast track.");
+    const intptr_t onerow = channel.plane.PixelsPerRow();
+    for (size_t y = 0; y < channel.h; y++) {
+      pixel_type *JXL_RESTRICT r = channel.Row(y);
+      for (size_t x = 0; x < channel.w; x++) {
+        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
+        pixel_type_w top = (y ? *(r + x - onerow) : left);
+        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
+        int32_t guess = ClampedGradient(top, left, topleft);
+        uint32_t pos =
+            kPropRangeFast +
+            std::min<pixel_type_w>(
+                std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
+                kPropRangeFast - 1);
+        uint32_t ctx_id = context_lookup[pos];
+        uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+        r[x] = make_pixel(v, multipliers[pos],
+                          static_cast<pixel_type_w>(offsets[pos]) + guess);
+      }
+    }
+  } else if (is_wp_only) {
+    JXL_DEBUG_V(8, "WP fast track.");
+    const intptr_t onerow = channel.plane.PixelsPerRow();
+    weighted::State wp_state(wp_header, channel.w, channel.h);
+    Properties properties(1);
+    for (size_t y = 0; y < channel.h; y++) {
+      pixel_type *JXL_RESTRICT r = channel.Row(y);
+      for (size_t x = 0; x < channel.w; x++) {
+        size_t offset = 0;
+        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
+        pixel_type_w top = (y ? *(r + x - onerow) : left);
+        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
+        pixel_type_w topright =
+            (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
+        pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
+        int32_t guess = wp_state.Predict(
+            x, y, channel.w, top, left, topright, topleft, toptop, &properties,
+            offset);
+        uint32_t pos =
+            kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
+                                      kPropRangeFast - 1);
+        uint32_t ctx_id = context_lookup[pos];
+        uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+        r[x] = make_pixel(v, multipliers[pos],
+                          static_cast<pixel_type_w>(offsets[pos]) + guess);
+        wp_state.UpdateErrors(r[x], x, y, channel.w);
+      }
+    }
+  } else if (tree.size() == 1) {
+    // special optimized case: no meta-adaptation, so no need
+    // to compute properties.
+    Predictor predictor = tree[0].predictor;
+    int64_t offset = tree[0].predictor_offset;
+    int32_t multiplier = tree[0].multiplier;
+    size_t ctx_id = tree[0].childID;
+    if (predictor == Predictor::Zero) {
+      uint32_t value;
+      if (reader->IsSingleValueAndAdvance(ctx_id, &value,
+                                          channel.w * channel.h)) {
+        // Special-case: histogram has a single symbol, with no extra bits, and
+        // we use ANS mode.
+        JXL_DEBUG_V(8, "Fastest track.");
+        pixel_type v = make_pixel(value, multiplier, offset);
+        for (size_t y = 0; y < channel.h; y++) {
+          pixel_type *JXL_RESTRICT r = channel.Row(y);
+          std::fill(r, r + channel.w, v);
+        }
+
+      } else {
+        JXL_DEBUG_V(8, "Fast track.");
+        for (size_t y = 0; y < channel.h; y++) {
+          pixel_type *JXL_RESTRICT r = channel.Row(y);
+          for (size_t x = 0; x < channel.w; x++) {
+            uint32_t v = reader->ReadHybridUintClustered(ctx_id, br);
+            r[x] = make_pixel(v, multiplier, offset);
+          }
+        }
+      }
+    } else if (predictor != Predictor::Weighted) {
+      // special optimized case: no meta-adaptation, no wp, so no need to
+      // compute properties
+      JXL_DEBUG_V(8, "Quite fast track.");
+      const intptr_t onerow = channel.plane.PixelsPerRow();
+      for (size_t y = 0; y < channel.h; y++) {
+        pixel_type *JXL_RESTRICT r = channel.Row(y);
+        for (size_t x = 0; x < channel.w; x++) {
+          PredictionResult pred =
+              PredictNoTreeNoWP(channel.w, r + x, onerow, x, y, predictor);
+          pixel_type_w g = pred.guess + offset;
+          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+          // NOTE: pred.multiplier is unset.
+          r[x] = make_pixel(v, multiplier, g);
+        }
+      }
+    } else {
+      // special optimized case: no meta-adaptation, so no need to
+      // compute properties
+      JXL_DEBUG_V(8, "Somewhat fast track.");
+      const intptr_t onerow = channel.plane.PixelsPerRow();
+      weighted::State wp_state(wp_header, channel.w, channel.h);
+      for (size_t y = 0; y < channel.h; y++) {
+        pixel_type *JXL_RESTRICT r = channel.Row(y);
+        for (size_t x = 0; x < channel.w; x++) {
+          pixel_type_w g = PredictNoTreeWP(channel.w, r + x, onerow, x, y,
+                                           predictor, &wp_state)
+                               .guess +
+                           offset;
+          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
+          r[x] = make_pixel(v, multiplier, g);
+          wp_state.UpdateErrors(r[x], x, y, channel.w);
+        }
+      }
+    }
+  } else if (!tree_has_wp_prop_or_pred) {
+    // special optimized case: the weighted predictor and its properties are not
+    // used, so no need to compute weights and properties.
+    JXL_DEBUG_V(8, "Slow track.");
+    MATreeLookup tree_lookup(tree);
+    Properties properties = Properties(num_props);
+    const intptr_t onerow = channel.plane.PixelsPerRow();
+    Channel references(properties.size() - kNumNonrefProperties, channel.w);
+    for (size_t y = 0; y < channel.h; y++) {
+      pixel_type *JXL_RESTRICT p = channel.Row(y);
+      PrecomputeReferences(channel, y, *image, chan, &references);
+      InitPropsRow(&properties, static_props, y);
+      for (size_t x = 0; x < channel.w; x++) {
+        PredictionResult res =
+            PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
+                            tree_lookup, references);
+        uint64_t v = reader->ReadHybridUintClustered(res.context, br);
+        p[x] = make_pixel(v, res.multiplier, res.guess);
+      }
+    }
+  } else {
+    JXL_DEBUG_V(8, "Slowest track.");
+    MATreeLookup tree_lookup(tree);
+    Properties properties = Properties(num_props);
+    const intptr_t onerow = channel.plane.PixelsPerRow();
+    Channel references(properties.size() - kNumNonrefProperties, channel.w);
+    weighted::State wp_state(wp_header, channel.w, channel.h);
+    for (size_t y = 0; y < channel.h; y++) {
+      pixel_type *JXL_RESTRICT p = channel.Row(y);
+      InitPropsRow(&properties, static_props, y);
+      PrecomputeReferences(channel, y, *image, chan, &references);
+      for (size_t x = 0; x < channel.w; x++) {
+        PredictionResult res =
+            PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
+                          tree_lookup, references, &wp_state);
+        uint64_t v = reader->ReadHybridUintClustered(res.context, br);
+        p[x] = make_pixel(v, res.multiplier, res.guess);
+        wp_state.UpdateErrors(p[x], x, y, channel.w);
+      }
+    }
+  }
+  return true;
+}
+
+GroupHeader::GroupHeader() { Bundle::Init(this); }
+
+Status ModularDecode(BitReader *br, Image &image, GroupHeader &header,
+                     size_t group_id, ModularOptions *options,
+                     const Tree *global_tree, const ANSCode *global_code,
+                     const std::vector<uint8_t> *global_ctx_map,
+                     bool allow_truncated_group) {
+  if (image.nb_channels < 1) return true;
+
+  // decode transforms
+  JXL_RETURN_IF_ERROR(Bundle::Read(br, &header));
+  JXL_DEBUG_V(4, "Global option: up to %i back-referencing MA properties.",
+              options->max_properties);
+  JXL_DEBUG_V(3, "Image data underwent %zu transformations: ",
+              header.transforms.size());
+  image.transform = header.transforms;
+  for (Transform &transform : image.transform) {
+    JXL_RETURN_IF_ERROR(transform.MetaApply(image));
+  }
+  if (image.error) {
+    return JXL_FAILURE("Corrupt file. Aborting.");
+  }
+
+  size_t nb_channels = image.channel.size();
+
+  size_t num_chans = 0;
+  for (size_t i = 0; i < nb_channels; i++) {
+    if (!image.channel[i].w || !image.channel[i].h) {
+      continue;  // skip empty channels
+    }
+    if (i >= image.nb_meta_channels &&
+        (image.channel[i].w > options->max_chan_size ||
+         image.channel[i].h > options->max_chan_size)) {
+      break;
+    }
+    num_chans++;
+  }
+  if (num_chans == 0) return true;
+
+  // Read tree.
+  Tree tree_storage;
+  std::vector<uint8_t> context_map_storage;
+  ANSCode code_storage;
+  const Tree *tree = &tree_storage;
+  const ANSCode *code = &code_storage;
+  const std::vector<uint8_t> *context_map = &context_map_storage;
+  if (!header.use_global_tree) {
+    size_t tree_size_limit = 1024 + image.w * image.h * nb_channels;
+    JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, tree_size_limit));
+    JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2,
+                                         &code_storage, &context_map_storage));
+  } else {
+    if (!global_tree || !global_code || !global_ctx_map ||
+        global_tree->empty()) {
+      return JXL_FAILURE("No global tree available but one was requested");
+    }
+    tree = global_tree;
+    code = global_code;
+    context_map = global_ctx_map;
+  }
+
+  size_t distance_multiplier = 0;
+  for (size_t i = 0; i < nb_channels; i++) {
+    Channel &channel = image.channel[i];
+    if (!channel.w || !channel.h) {
+      continue;  // skip empty channels
+    }
+    if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
+                                        channel.h > options->max_chan_size)) {
+      break;
+    }
+    if (channel.w > distance_multiplier) {
+      distance_multiplier = channel.w;
+    }
+  }
+  // Read channels
+  ANSSymbolReader reader(code, br, distance_multiplier);
+  for (size_t i = 0; i < nb_channels; i++) {
+    Channel &channel = image.channel[i];
+    if (!channel.w || !channel.h) {
+      continue;  // skip empty channels
+    }
+    if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
+                                        channel.h > options->max_chan_size)) {
+      break;
+    }
+    JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS(br, &reader, *context_map,
+                                                  *tree, header.wp_header, i,
+                                                  group_id, &image));
+    // Truncated group.
+    if (!br->AllReadsWithinBounds()) {
+      if (!allow_truncated_group) return JXL_FAILURE("Truncated input");
+      ZeroFillImage(&channel.plane);
+      return Status(StatusCode::kNotEnoughBytes);
+    }
+  }
+  if (!reader.CheckANSFinalState()) {
+    return JXL_FAILURE("ANS decode final state failed");
+  }
+  return true;
+}
+
+Status ModularGenericDecompress(BitReader *br, Image &image,
+                                GroupHeader *header, size_t group_id,
+                                ModularOptions *options, int undo_transforms,
+                                const Tree *tree, const ANSCode *code,
+                                const std::vector<uint8_t> *ctx_map,
+                                bool allow_truncated_group) {
+#ifdef JXL_ENABLE_ASSERT
+  std::vector<std::pair<size_t, size_t>> req_sizes(image.channel.size());
+  for (size_t c = 0; c < req_sizes.size(); c++) {
+    req_sizes[c] = {image.channel[c].w, image.channel[c].h};
+  }
+#endif
+  GroupHeader local_header;
+  if (header == nullptr) header = &local_header;
+  // Record the starting position before decoding so the debug output below
+  // reports the number of bytes actually consumed.
+  size_t bit_pos = br->TotalBitsConsumed();
+  auto dec_status = ModularDecode(br, image, *header, group_id, options, tree,
+                                  code, ctx_map, allow_truncated_group);
+  if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
+  if (dec_status.IsFatalError()) return dec_status;
+  image.undo_transforms(header->wp_header, undo_transforms);
+  if (image.error) return JXL_FAILURE("Corrupt file. Aborting.");
+  JXL_DEBUG_V(4, "Modular-decoded a %zux%zu nbchans=%zu image from %zu bytes",
+              image.w, image.h, image.real_nb_channels,
+              (br->TotalBitsConsumed() - bit_pos) / 8);
+  (void)bit_pos;
+#ifdef JXL_ENABLE_ASSERT
+  // Check that after applying all transforms we are back to the requested image
+  // sizes, otherwise there's a programming error with the transformations.
+  if (undo_transforms == -1 || undo_transforms == 0) {
+    JXL_ASSERT(image.channel.size() == req_sizes.size());
+    for (size_t c = 0; c < req_sizes.size(); c++) {
+      JXL_ASSERT(req_sizes[c].first == image.channel[c].w);
+      JXL_ASSERT(req_sizes[c].second == image.channel[c].h);
+    }
+  }
+#endif
+  return dec_status;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h
new file mode 100644
index 000000000000..727efc15cef7
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h
@@ -0,0 +1,150 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_MODULAR_ENCODING_ENCODING_H_
+#define LIB_JXL_MODULAR_ENCODING_ENCODING_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/encoding/dec_ma.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/options.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+// Valid range of properties for using lookup tables instead of trees.
+constexpr int32_t kPropRangeFast = 512;
+
+struct GroupHeader : public Fields {
+  GroupHeader();
+
+  const char *Name() const override { return "GroupHeader"; }
+
+  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
+    JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &use_global_tree));
+    JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&wp_header));
+    uint32_t num_transforms = static_cast<uint32_t>(transforms.size());
+    JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2),
+                                           BitsOffset(8, 18), 0,
+                                           &num_transforms));
+    if (visitor->IsReading()) transforms.resize(num_transforms);
+    for (size_t i = 0; i < num_transforms; i++) {
+      JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&transforms[i]));
+    }
+    return true;
+  }
+
+  bool use_global_tree;
+  weighted::Header wp_header;
+
+  std::vector<Transform> transforms;
+};
+
+FlatTree FilterTree(const Tree &global_tree,
+                    std::array<pixel_type, kNumStaticProperties> &static_props,
+                    size_t *num_props, bool *use_wp, bool *wp_only,
+                    bool *gradient_only);
+
+template <typename T>
+bool TreeToLookupTable(const FlatTree &tree,
+                       T context_lookup[2 * kPropRangeFast],
+                       int8_t offsets[2 * kPropRangeFast],
+                       int8_t multipliers[2 * kPropRangeFast] = nullptr) {
+  struct TreeRange {
+    // Begin *excluded*, end *included*. This works best with > vs <= decision
+    // nodes.
+    int begin, end;
+    size_t pos;
+  };
+  std::vector<TreeRange> ranges;
+  ranges.push_back(TreeRange{-kPropRangeFast - 1, kPropRangeFast - 1, 0});
+  while (!ranges.empty()) {
+    TreeRange cur = ranges.back();
+    ranges.pop_back();
+    if (cur.begin < -kPropRangeFast - 1 || cur.begin >= kPropRangeFast - 1 ||
+        cur.end > kPropRangeFast - 1) {
+      // Tree is outside the allowed range, exit.
+      return false;
+    }
+    auto &node = tree[cur.pos];
+    // Leaf.
+    if (node.property0 == -1) {
+      if (node.predictor_offset < std::numeric_limits<int8_t>::min() ||
+          node.predictor_offset > std::numeric_limits<int8_t>::max()) {
+        return false;
+      }
+      if (node.multiplier < std::numeric_limits<int8_t>::min() ||
+          node.multiplier > std::numeric_limits<int8_t>::max()) {
+        return false;
+      }
+      if (multipliers == nullptr && node.multiplier != 1) {
+        return false;
+      }
+      for (int i = cur.begin + 1; i < cur.end + 1; i++) {
+        context_lookup[i + kPropRangeFast] = node.childID;
+        if (multipliers) multipliers[i + kPropRangeFast] = node.multiplier;
+        offsets[i + kPropRangeFast] = node.predictor_offset;
+      }
+      continue;
+    }
+    // > side of top node.
+    if (node.properties[0] >= kNumStaticProperties) {
+      ranges.push_back(TreeRange({node.splitvals[0], cur.end, node.childID}));
+      ranges.push_back(
+          TreeRange({node.splitval0, node.splitvals[0], node.childID + 1}));
+    } else {
+      ranges.push_back(TreeRange({node.splitval0, cur.end, node.childID}));
+    }
+    // <= side
+    if (node.properties[1] >= kNumStaticProperties) {
+      ranges.push_back(
+          TreeRange({node.splitvals[1], node.splitval0, node.childID + 2}));
+      ranges.push_back(
+          TreeRange({cur.begin, node.splitvals[1], node.childID + 3}));
+    } else {
+      ranges.push_back(
+          TreeRange({cur.begin, node.splitval0, node.childID + 2}));
+    }
+  }
+  return true;
+}
+// TODO(veluca): make cleaner interfaces.
+
+// undo_transforms == N > 0: undo all transforms except the first N
+//                           (e.g. to represent YCbCr420 losslessly)
+// undo_transforms == 0: undo all transforms
+// undo_transforms == -1: undo all transforms but don't clamp to range
+// undo_transforms == -2: don't undo any transform
+Status ModularGenericDecompress(BitReader *br, Image &image,
+                                GroupHeader *header, size_t group_id,
+                                ModularOptions *options,
+                                int undo_transforms = -1,
+                                const Tree *tree = nullptr,
+                                const ANSCode *code = nullptr,
+                                const std::vector<uint8_t> *ctx_map = nullptr,
+                                bool allow_truncated_group = false);
+}  // namespace jxl
+
+#endif  // LIB_JXL_MODULAR_ENCODING_ENCODING_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h
new file mode 100644
index 000000000000..c4e6342cff86
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h
@@ -0,0 +1,37 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
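// ---------------------------------------------------------------------------
// Editorial sketch (not part of the upstream patch): how a caller might drive
// ModularGenericDecompress(), declared in encoding.h above. The BitReader and
// image dimensions are assumed context; only the argument semantics follow
// the comments in that header.
//
//   jxl::Image image(xsize, ysize, /*maxval=*/255, /*nb_chans=*/3);
//   jxl::GroupHeader header;
//   jxl::ModularOptions options;
//   // undo_transforms = -1: undo all transforms, but don't clamp to range.
//   // tree/code/ctx_map stay nullptr, so a local tree is read from `br`.
//   JXL_RETURN_IF_ERROR(jxl::ModularGenericDecompress(
//       br, image, &header, /*group_id=*/0, &options,
//       /*undo_transforms=*/-1));
// ---------------------------------------------------------------------------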
+
+#ifndef LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_
+#define LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_
+
+#include <stddef.h>
+
+namespace jxl {
+
+enum MATreeContext : size_t {
+  kSplitValContext = 0,
+  kPropertyContext = 1,
+  kPredictorContext = 2,
+  kOffsetContext = 3,
+  kMultiplierLogContext = 4,
+  kMultiplierBitsContext = 5,
+
+  kNumTreeContexts = 6,
+};
+
+static constexpr size_t kMaxTreeSize = 1 << 26;
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc
new file mode 100644
index 000000000000..9b880f62e83e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc
@@ -0,0 +1,111 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/modular/modular_image.h"
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+namespace jxl {
+
+void Channel::compute_minmax(pixel_type *min, pixel_type *max) const {
+  pixel_type realmin = std::numeric_limits<pixel_type>::max();
+  pixel_type realmax = std::numeric_limits<pixel_type>::min();
+  for (size_t y = 0; y < h; y++) {
+    const pixel_type *JXL_RESTRICT p = plane.Row(y);
+    for (size_t x = 0; x < w; x++) {
+      if (p[x] < realmin) realmin = p[x];
+      if (p[x] > realmax) realmax = p[x];
+    }
+  }
+
+  if (min) *min = realmin;
+  if (max) *max = realmax;
+}
+
+void Image::undo_transforms(const weighted::Header &wp_header, int keep,
+                            jxl::ThreadPool *pool) {
+  if (keep == -2) return;
+  while ((int)transform.size() > keep && transform.size() > 0) {
+    Transform t = transform.back();
+    JXL_DEBUG_V(4, "Undoing transform %s", t.Name());
+    Status result = t.Inverse(*this, wp_header, pool);
+    if (result == false) {
+      JXL_NOTIFY_ERROR("Error while undoing transform %s.", t.Name());
+      error = true;
+      return;
+    }
+    JXL_DEBUG_V(8, "Undoing transform %s: done", t.Name());
+    transform.pop_back();
+  }
+  if (!keep) {  // clamp the values to the valid range (lossy
+                // compression can produce values outside the range)
+    for (size_t i = 0; i < channel.size(); i++) {
+      for (size_t y = 0; y < channel[i].h; y++) {
+        pixel_type *JXL_RESTRICT p = channel[i].plane.Row(y);
+        for (size_t x = 0; x < channel[i].w; x++, p++) {
+          *p = Clamp1(*p, minval, maxval);
+        }
+      }
+    }
+  }
+}
+
+bool Image::do_transform(const Transform &tr,
+                         const weighted::Header &wp_header) {
+  Transform t = tr;
+  bool did_it = t.Forward(*this, wp_header);
+  if (did_it) transform.push_back(t);
+  return did_it;
+}
+
+Image::Image(size_t iw, size_t ih, int maxval, int nb_chans)
+    : w(iw),
+      h(ih),
+      minval(0),
+      maxval(maxval),
+      nb_channels(nb_chans),
+      real_nb_channels(nb_chans),
+      nb_meta_channels(0),
+      error(false) {
+  for (int i = 0; i < nb_chans; i++) channel.emplace_back(Channel(iw, ih));
+}
+Image::Image()
+    : w(0),
+      h(0),
+      minval(0),
+      maxval(255),
+      nb_channels(0),
+      real_nb_channels(0),
+      nb_meta_channels(0),
+      error(true) {}
+
+Image::~Image() = default;
+
+Image &Image::operator=(Image &&other) noexcept {
+  w = other.w;
+  h = other.h;
+  minval = other.minval;
+  maxval = other.maxval;
+  nb_channels = other.nb_channels;
+  real_nb_channels = other.real_nb_channels;
+  nb_meta_channels = other.nb_meta_channels;
+  error = other.error;
+  channel = std::move(other.channel);
+  transform = std::move(other.transform);
+  return *this;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.h b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h
new file mode 100644
index 000000000000..13b280792935
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h
@@ -0,0 +1,165 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_MODULAR_MODULAR_IMAGE_H_
+#define LIB_JXL_MODULAR_MODULAR_IMAGE_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+
+typedef int32_t pixel_type;  // can use int16_t if it's only for 8-bit images.
+                             // Need some wiggle room for YCoCg / Squeeze etc
+
+typedef int64_t pixel_type_w;
+
+namespace weighted {
+struct Header;
+}
+
+class Channel {
+ public:
+  jxl::Plane<pixel_type> plane;
+  size_t w, h;
+  int hshift, vshift;  // w ~= image.w >> hshift;  h ~= image.h >> vshift
+  int hcshift,
+      vcshift;  // cumulative, i.e. when decoding up to this point, we have data
+                // available with these shifts (for this component)
+  Channel(size_t iw, size_t ih, int hsh = 0, int vsh = 0, int hcsh = 0,
+          int vcsh = 0)
+      : plane(iw, ih),
+        w(iw),
+        h(ih),
+        hshift(hsh),
+        vshift(vsh),
+        hcshift(hcsh),
+        vcshift(vcsh) {}
+  Channel()
+      : plane(0, 0), w(0), h(0), hshift(0), vshift(0), hcshift(0), vcshift(0) {}
+
+  Channel(const Channel& other) = delete;
+  Channel& operator=(const Channel& other) = delete;
+
+  // Move assignment
+  Channel& operator=(Channel&& other) noexcept {
+    w = other.w;
+    h = other.h;
+    hshift = other.hshift;
+    vshift = other.vshift;
+    hcshift = other.hcshift;
+    vcshift = other.vcshift;
+    plane = std::move(other.plane);
+    return *this;
+  }
+
+  // Move constructor
+  Channel(Channel&& other) noexcept = default;
+
+  void resize(pixel_type value = 0) {
+    if (plane.xsize() == w && plane.ysize() == h) return;
+    jxl::Plane<pixel_type> resizedplane(w, h);
+    if (plane.xsize() || plane.ysize()) {
+      // copy pixels over from old plane to new plane
+      size_t y = 0;
+      for (; y < plane.ysize() && y < h; y++) {
+        const pixel_type* JXL_RESTRICT p = plane.Row(y);
+        pixel_type* JXL_RESTRICT rp = resizedplane.Row(y);
+        size_t x = 0;
+        for (; x < plane.xsize() && x < w; x++) rp[x] = p[x];
+        for (; x < w; x++) rp[x] = value;
+      }
+      for (; y < h; y++) {
+        pixel_type* JXL_RESTRICT p = resizedplane.Row(y);
+        for (size_t x = 0; x < w; x++) p[x] = value;
+      }
+    } else if (w && h && value == 0) {
+      size_t ppr = resizedplane.bytes_per_row();
+      memset(resizedplane.bytes(), 0, ppr * h);
+    } else if (w && h) {
+      FillImage(value, &resizedplane);
+    }
+    plane = std::move(resizedplane);
+  }
+  void resize(int nw, int nh) {
+    w = nw;
+    h = nh;
+    resize();
+  }
+  bool is_empty() const { return (plane.ysize() == 0); }
+
+  JXL_INLINE pixel_type* Row(const size_t y) { return plane.Row(y); }
+  JXL_INLINE const pixel_type* Row(const size_t y) const {
+    return plane.Row(y);
+  }
+  void compute_minmax(pixel_type* min, pixel_type* max) const;
+};
+
+class Transform;
+
+class Image {
+ public:
+  std::vector<Channel>
+      channel;  // image data, transforms can dramatically change the number of
+                // channels and their semantics
+  std::vector<Transform>
+      transform;  // keeps track of the transforms that have been applied (and
+                  // that have to be undone when rendering the image)
+
+  size_t w, h;  // actual dimensions of the image (channels may have different
+                // dimensions due to transforms like chroma subsampling and DCT)
+  int minval, maxval;  // actual (largest) range of the channels (actual ranges
+                       // might be different due to transforms; after undoing
+                       // transforms, might still be different due to lossy)
+  size_t nb_channels;  // actual number of distinct channels (after undoing all
+                       // transforms except Palette; can be different from
+                       // channel.size())
+  size_t real_nb_channels;  // real number of channels (after undoing all
+                            // transforms)
+  size_t nb_meta_channels;  // first few channels might contain things like
+                            // palettes or compaction data that are not yet real
+                            // image data
+  bool error;  // true if a fatal error occurred, false otherwise
+
+  Image(size_t iw, size_t ih, int maxval, int nb_chans);
+
+  Image();
+  ~Image();
+
+  Image(const Image& other) = delete;
+  Image& operator=(const Image& other) = delete;
+
+  Image& operator=(Image&& other) noexcept;
+  Image(Image&& other) noexcept = default;
+
+  bool do_transform(const Transform& t, const weighted::Header& wp_header);
+  // undo all except the first 'keep' transforms
+  void undo_transforms(const weighted::Header& wp_header, int keep = 0,
+                       jxl::ThreadPool* pool = nullptr);
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_MODULAR_MODULAR_IMAGE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/options.h b/third_party/jpeg-xl/lib/jxl/modular/options.h
new file mode 100644
index 000000000000..f9be61e80c11
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/options.h
@@ -0,0 +1,178 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_MODULAR_OPTIONS_H_
+#define LIB_JXL_MODULAR_OPTIONS_H_
+
+#include <stdint.h>
+
+#include <array>
+#include <vector>
+
+namespace jxl {
+
+using PropertyVal = int32_t;
+using Properties = std::vector<PropertyVal>;
+
+enum class Predictor : uint32_t {
+  Zero = 0,
+  Left = 1,
+  Top = 2,
+  Average0 = 3,
+  Select = 4,
+  Gradient = 5,
+  Weighted = 6,
+  TopRight = 7,
+  TopLeft = 8,
+  LeftLeft = 9,
+  Average1 = 10,
+  Average2 = 11,
+  Average3 = 12,
+  Average4 = 13,
+  // The following predictors are encoder-only.
+  Best = 14,  // Best of Gradient and Weighted
+  Variable =
+      15,  // Find the best decision tree for predictors/predictor per row
+};
+
+inline const char* PredictorName(Predictor p) {
+  switch (p) {
+    case Predictor::Zero:
+      return "Zero";
+    case Predictor::Left:
+      return "Left";
+    case Predictor::Top:
+      return "Top";
+    case Predictor::Average0:
+      return "Avg0";
+    case Predictor::Average1:
+      return "Avg1";
+    case Predictor::Average2:
+      return "Avg2";
+    case Predictor::Average3:
+      return "Avg3";
+    case Predictor::Average4:
+      return "Avg4";
+    case Predictor::Select:
+      return "Sel";
+    case Predictor::Gradient:
+      return "Grd";
+    case Predictor::Weighted:
+      return "Wgh";
+    case Predictor::TopLeft:
+      return "TopL";
+    case Predictor::TopRight:
+      return "TopR";
+    case Predictor::LeftLeft:
+      return "LL";
+    default:
+      return "INVALID";
+  };
+}
+
+inline std::array<uint8_t, 3> PredictorColor(Predictor p) {
+  switch (p) {
+    case Predictor::Zero:
+      return {0, 0, 0};
+    case Predictor::Left:
+      return {255, 0, 0};
+    case Predictor::Top:
+      return {0, 255, 0};
+    case Predictor::Average0:
+      return {0, 0, 255};
+    case Predictor::Average4:
+      return {192, 128, 128};
+    case Predictor::Select:
+      return {255, 255, 0};
+    case Predictor::Gradient:
+      return {255, 0, 255};
+    case Predictor::Weighted:
+      return {0, 255, 255};
+      // TODO
+    default:
+      return {255, 255, 255};
+  };
+}
+
+constexpr size_t kNumModularPredictors = static_cast<size_t>(Predictor::Best);
+
+static constexpr ssize_t kNumStaticProperties = 2;  // channel, group_id.
+
+using StaticPropRange =
+    std::array<std::array<uint32_t, 2>, kNumStaticProperties>;
+
+struct ModularMultiplierInfo {
+  StaticPropRange range;
+  uint32_t multiplier;
+};
+
+struct ModularOptions {
+  /// Used in both encode and decode:
+
+  // Stop encoding/decoding when reaching a (non-meta) channel that has a
+  // dimension bigger than max_chan_size.
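// (Editorial sketch, not in the upstream source: a decoder that only wants
// the small, preview-sized channels could stop early via
//
//   jxl::ModularOptions options;
//   options.max_chan_size = 256;  // skip channels wider or taller than 256
//
// max_chan_size is the field consulted on the decode path in encoding.cc
// above; most of the remaining fields below are encoder-side knobs.)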
+  size_t max_chan_size = 0xFFFFFF;
+
+  /// Encode options:
+  // Fraction of pixels to look at to learn a MA tree
+  // Number of iterations to do to learn a MA tree
+  // (if zero there is no MA context model)
+  float nb_repeats = .5f;
+
+  // Maximum number of (previous channel) properties to use in the MA trees
+  int max_properties = 0;  // no previous channels
+
+  // Alternative heuristic tweaks.
+  // Properties default to channel, group, weighted, gradient residual, W-NW,
+  // NW-N, N-NE, N-NN
+  std::vector<uint32_t> splitting_heuristics_properties = {0,  1,  15, 9,
+                                                           10, 11, 12, 13};
+  float splitting_heuristics_node_threshold = 96;
+  size_t max_property_values = 32;
+
+  // Predictor to use for each channel.
+  Predictor predictor = static_cast<Predictor>(-1);
+
+  int wp_mode = 0;
+
+  float fast_decode_multiplier = 1.01f;
+
+  // Forces the encoder to produce a tree that is compatible with the WP-only
+  // decode path (or with the no-wp path).
+  enum class WPTreeMode { kWPOnly, kNoWP, kDefault };
+  WPTreeMode wp_tree_mode = WPTreeMode::kDefault;
+
+  // Skip fast paths in the encoder.
+  bool skip_encoder_fast_path = false;
+
+  // Kind of tree to use.
+  // TODO(veluca): add tree kinds for JPEG recompression with CfL enabled,
+  // general AC metadata, different DC qualities, and others.
+  enum class TreeKind {
+    kLearn,
+    kJpegTranscodeACMeta,
+    kFalconACMeta,
+    kACMeta,
+    kWPFixedDC,
+    kGradientFixedDC,
+  };
+  TreeKind tree_kind = TreeKind::kLearn;
+
+  // Ignore the image and just pretend all tokens are zeroes
+  bool zero_tokens = false;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_MODULAR_OPTIONS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/near-lossless.h b/third_party/jpeg-xl/lib/jxl/modular/transform/near-lossless.h
new file mode 100644
index 000000000000..632ca07890fc
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/near-lossless.h
@@ -0,0 +1,85 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_NEAR_LOSSLESS_H_
+#define LIB_JXL_MODULAR_TRANSFORM_NEAR_LOSSLESS_H_
+
+// Very simple lossy preprocessing step.
+// Quantizes the prediction residual (so the entropy coder has an easier job)
+// Obviously there's room for encoder improvement here
+// The decoder doesn't need to know about this step
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/modular_image.h"
+
+namespace jxl {
+
+pixel_type DeltaQuantize(int max_error, pixel_type in, pixel_type prediction) {
+  int resolution = 2 << max_error;
+  pixel_type d = in - prediction;
+  int32_t val = std::abs(d);
+  if (val < std::max(1, resolution / 4 * 3 / 4)) return 0;
+  if (val < resolution / 2 * 3 / 4) {
+    return d > 0 ? resolution / 4 : -resolution / 4;
+  }
+  if (val < resolution * 3 / 4) return d > 0 ? resolution / 2 : -resolution / 2;
+  val = (val + resolution / 2) / resolution * resolution;
+  return d > 0 ? val : -val;
+}
+
+static Status FwdNearLossless(Image& input, size_t begin_c, size_t end_c,
+                              int max_delta_error, Predictor predictor) {
+  if (begin_c < input.nb_meta_channels || begin_c > input.channel.size() ||
+      end_c < input.nb_meta_channels || end_c >= input.channel.size() ||
+      end_c < begin_c) {
+    return JXL_FAILURE("Invalid channel range %zu-%zu", begin_c, end_c);
+  }
+
+  JXL_DEBUG_V(8, "Applying loss on channels %zu-%zu with max delta=%i.",
+              begin_c, end_c, max_delta_error);
+  uint64_t total_error = 0;
+  for (size_t c = begin_c; c <= end_c; c++) {
+    size_t w = input.channel[c].w;
+    size_t h = input.channel[c].h;
+
+    Channel out(w, h);
+    weighted::Header header;
+    weighted::State wp_state(header, w, h);
+    for (size_t y = 0; y < h; y++) {
+      pixel_type* JXL_RESTRICT p_in = input.channel[c].Row(y);
+      pixel_type* JXL_RESTRICT p_out = out.Row(y);
+      for (size_t x = 0; x < w; x++) {
+        PredictionResult pred = PredictNoTreeWP(
+            w, p_out + x, out.plane.PixelsPerRow(), x, y, predictor, &wp_state);
+        pixel_type delta = DeltaQuantize(max_delta_error, p_in[x], pred.guess);
+        pixel_type reconstructed = pred.guess + delta;
+        int e = p_in[x] - reconstructed;
+        total_error += abs(e);
+        p_out[x] = reconstructed;
+        wp_state.UpdateErrors(p_out[x], x, y, w);
+      }
+    }
+    input.channel[c] = std::move(out);
+    // fprintf(stderr, "Avg error: %f\n", total_error * 1.0 / (w * h));
+    JXL_DEBUG_V(9, "  Avg error: %f", total_error * 1.0 / (w * h));
+  }
+  return false;  // don't signal this 'transform' in the bitstream, there is no
+                 // inverse transform to be done
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_MODULAR_TRANSFORM_NEAR_LOSSLESS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h
new file mode 100644
index 000000000000..e9ee55126415
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h
@@ -0,0 +1,710 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_
+#define LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_
+
+#include <set>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/encoding/context_predict.h"
+#include "lib/jxl/modular/modular_image.h"
+
+namespace jxl {
+
+namespace palette_internal {
+
+static constexpr int kMaxPaletteLookupTableSize = 1 << 16;
+
+static constexpr bool kEncodeToHighQualityImplicitPalette = true;
+
+static constexpr int kCubePow = 3;
+
+// 5x5x5 color cube for the larger cube.
+static constexpr int kLargeCube = 5;
+
+// Smaller interleaved color cube to fill the holes of the larger cube.
+static constexpr int kSmallCube = kLargeCube - 1;
+// kSmallCube ** kCubePow
+static constexpr int kLargeCubeOffset = kSmallCube * kSmallCube * kSmallCube;
+
+// Inclusive.
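// (Editorial note, derived from GetPaletteValue() below: the decoded index
// space is partitioned as
//
//   [kMinImplicitPaletteIndex, -1]                   delta palette, 72 signed
//                                                    entries (hence 2*72 - 1)
//   [0, palette_size)                                explicit palette
//   [palette_size, palette_size + kLargeCubeOffset)  interleaved 4x4x4 cube
//   [palette_size + kLargeCubeOffset, ...)           5x5x5 cube
//
// For example, index -2 maps to -(-2 + 1) = 1, i.e. kDeltaPalette[1] =
// {4, 4, 4} with multiplier +1.)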
+static constexpr int kMinImplicitPaletteIndex = -(2 * 72 - 1);
+
+// The purpose of this function is solely to extend the interpretation of
+// palette indices to implicit values. If index < nb_deltas, indicating that the
+// result is a delta palette entry, it is the responsibility of the caller to
+// treat it as such.
+static pixel_type GetPaletteValue(const pixel_type *const palette, int index,
+                                  const size_t c, const int palette_size,
+                                  const int onerow, const int bit_depth) {
+  if (index < 0) {
+    static constexpr std::array<std::array<pixel_type, 3>, 72> kDeltaPalette =
+        {{
+            {0, 0, 0},       {4, 4, 4},       {11, 0, 0},      {0, 0, -13},
+            {0, -12, 0},     {-10, -10, -10}, {-18, -18, -18}, {-27, -27, -27},
+            {-18, -18, 0},   {0, 0, -32},     {-32, 0, 0},     {-37, -37, -37},
+            {0, -32, -32},   {24, 24, 45},    {50, 50, 50},    {-45, -24, -24},
+            {-24, -45, -45}, {0, -24, -24},   {-34, -34, 0},   {-24, 0, -24},
+            {-45, -45, -24}, {64, 64, 64},    {-32, 0, -32},   {0, -32, 0},
+            {-32, 0, 32},    {-24, -45, -24}, {45, 24, 45},    {24, -24, -45},
+            {-45, -24, 24},  {80, 80, 80},    {64, 0, 0},      {0, 0, -64},
+            {0, -64, -64},   {-24, -24, 45},  {96, 96, 96},    {64, 64, 0},
+            {45, -24, -24},  {34, -34, 0},    {112, 112, 112}, {24, -45, -45},
+            {45, 45, -24},   {0, -32, 32},    {24, -24, 45},   {0, 96, 96},
+            {45, -24, 24},   {24, -45, -24},  {-24, -45, 24},  {0, -64, 0},
+            {96, 0, 0},      {128, 128, 128}, {64, 0, 64},     {144, 144, 144},
+            {96, 96, 0},     {-36, -36, 36},  {45, -24, -45},  {45, -45, -24},
+            {0, 0, -96},     {0, 128, 128},   {0, 96, 0},      {45, 24, -45},
+            {-128, 0, 0},    {24, -45, 24},   {-45, 24, -45},  {64, 0, -64},
+            {64, -64, -64},  {96, 0, 96},     {45, -45, 24},   {24, 45, -45},
+            {64, 64, -64},   {128, 128, 0},   {0, 0, -128},    {-24, 45, -45},
+        }};
+    if (c >= kDeltaPalette[0].size()) {
+      return 0;
+    }
+    // Do not open the brackets, otherwise INT32_MIN negation could overflow.
+    index = -(index + 1);
+    index %= 1 + 2 * (kDeltaPalette.size() - 1);
+    static constexpr int kMultiplier[] = {-1, 1};
+    pixel_type result =
+        kDeltaPalette[((index + 1) >> 1)][c] * kMultiplier[index & 1];
+    if (bit_depth > 8) {
+      result *= static_cast<pixel_type>(1) << (bit_depth - 8);
+    }
+    return result;
+  } else if (palette_size <= index && index < palette_size + kLargeCubeOffset) {
+    if (c >= kCubePow) return 0;
+    index -= palette_size;
+    if (c > 0) {
+      int divisor = kSmallCube;
+      for (size_t i = 1; i < c; ++i) {
+        divisor *= kSmallCube;
+      }
+      index /= divisor;
+    }
+    index %= kSmallCube;
+    return (index * ((1 << bit_depth) - 1)) / kSmallCube +
+           (1 << (std::max(0, bit_depth - 3)));
+  } else if (palette_size + kLargeCubeOffset <= index) {
+    if (c >= kCubePow) return 0;
+    index -= palette_size + kLargeCubeOffset;
+    // TODO(eustas): should we take care of ambiguity created by
+    //               index >= kLargeCube ** 3 ?
+    if (c > 0) {
+      int divisor = kLargeCube;
+      for (size_t i = 1; i < c; ++i) {
+        divisor *= kLargeCube;
+      }
+      index /= divisor;
+    }
+    index %= kLargeCube;
+    return (index * ((1 << bit_depth) - 1)) / (kLargeCube - 1);
+  }
+
+  return palette[c * onerow + static_cast<size_t>(index)];
+}
+
+// Template so that it can take vectors of pixel_type or pixel_type_w
+// indifferently.
+template <typename T, typename U>
+float ColorDistance(const T &JXL_RESTRICT a, const U &JXL_RESTRICT b) {
+  JXL_ASSERT(a.size() == b.size());
+  float distance = 0;
+  float ave3 = 0;
+  if (a.size() >= 3) {
+    ave3 = (a[0] + b[0] + a[1] + b[1] + a[2] + b[2]) * (1.21f / 3.0f);
+  }
+  float sum_a = 0, sum_b = 0;
+  for (size_t c = 0; c < a.size(); ++c) {
+    const float difference =
+        static_cast<float>(a[c]) - static_cast<float>(b[c]);
+    float weight = c == 0 ? 3 : c == 1 ? 5 : 2;
+    if (c < 3 && (a[c] + b[c] >= ave3)) {
+      const float add_w[3] = {
+          1.15,
+          1.15,
+          1.12,
+      };
+      weight += add_w[c];
+      if (c == 2 && ((a[2] + b[2]) < 1.22 * ave3)) {
+        weight -= 0.5;
+      }
+    }
+    distance += difference * difference * weight * weight;
+    const int sum_weight = c == 0 ? 3 : c == 1 ? 5 : 1;
+    sum_a += a[c] * sum_weight;
+    sum_b += b[c] * sum_weight;
+  }
+  distance *= 4;
+  float sum_difference = sum_a - sum_b;
+  distance += sum_difference * sum_difference;
+  return distance;
+}
+
+static int QuantizeColorToImplicitPaletteIndex(
+    const std::vector<pixel_type> &color, const int palette_size,
+    const int bit_depth, bool high_quality) {
+  int index = 0;
+  if (high_quality) {
+    int multiplier = 1;
+    for (size_t c = 0; c < color.size(); c++) {
+      int quantized = ((kLargeCube - 1) * color[c] + (1 << (bit_depth - 1))) /
+                      ((1 << bit_depth) - 1);
+      JXL_ASSERT((quantized % kLargeCube) == quantized);
+      index += quantized * multiplier;
+      multiplier *= kLargeCube;
+    }
+    return index + palette_size + kLargeCubeOffset;
+  } else {
+    int multiplier = 1;
+    for (size_t c = 0; c < color.size(); c++) {
+      int value = color[c];
+      value -= 1 << (std::max(0, bit_depth - 3));
+      value = std::max(0, value);
+      int quantized = ((kLargeCube - 1) * value + (1 << (bit_depth - 1))) /
+                      ((1 << bit_depth) - 1);
+      JXL_ASSERT((quantized % kLargeCube) == quantized);
+      if (quantized > kSmallCube - 1) {
+        quantized = kSmallCube - 1;
+      }
+      index += quantized * multiplier;
+      multiplier *= kSmallCube;
+    }
+    return index + palette_size;
+  }
+}
+
+}  // namespace palette_internal
+
+namespace {
+// Returns the sum of a+b. If ever over- / underflow occurs it is reflected
+// in "flags".
+pixel_type CautiousAdd(pixel_type a, pixel_type b, pixel_type *flags) {
+  // Avoid signed integer overflow.
+  pixel_type sum = static_cast<pixel_type>(static_cast<uint32_t>(a) +
+                                           static_cast<uint32_t>(b));
+  // We care only about the highest bit. If sign is different, addition is safe.
+  // If sign is the same, result sign should be the same as of the addends.
+  *flags &= (a ^ b) | (a ^ ~sum);
+  return sum;
+}
+
+bool IsHealthy(pixel_type flags) { return (flags >> 31); }
+}  // namespace
+
+static Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors,
+                         uint32_t nb_deltas, Predictor predictor,
+                         const weighted::Header &wp_header, ThreadPool *pool) {
+  if (input.nb_meta_channels < 1) {
+    return JXL_FAILURE("Error: Palette transform without palette.");
+  }
+  std::atomic<size_t> num_errors{0};
+  int nb = input.channel[0].h;
+  uint32_t c0 = begin_c + 1;
+  if (c0 >= input.channel.size()) {
+    return JXL_FAILURE("Channel is out of range.");
+  }
+  size_t w = input.channel[c0].w;
+  size_t h = input.channel[c0].h;
+  // might be false in case of lossy
+  // JXL_DASSERT(input.channel[c0].minval == 0);
+  // JXL_DASSERT(input.channel[c0].maxval == palette.w-1);
+  if (nb < 1) return JXL_FAILURE("Corrupted transforms");
+  for (int i = 1; i < nb; i++) {
+    input.channel.insert(input.channel.begin() + c0 + 1, Channel(w, h));
+  }
+  const Channel &palette = input.channel[0];
+  const pixel_type *JXL_RESTRICT p_palette = input.channel[0].Row(0);
+  intptr_t onerow = input.channel[0].plane.PixelsPerRow();
+  intptr_t onerow_image = input.channel[c0].plane.PixelsPerRow();
+  const int bit_depth =
+      CeilLog2Nonzero(static_cast<uint32_t>(input.maxval - input.minval + 1));
+
+  if (w == 0) {
+    // Nothing to do.
+    // Avoid touching "empty" channels with non-zero height.
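// (Editorial note: the branches below pick one of three specializations: a
// pure palette lookup for the Zero predictor with no deltas, a Weighted
// predictor path, and a Gradient path that uses CautiousAdd() from above.
// The overflow check in CautiousAdd() is bit-parallel: `flags` starts at -1
// (all bits set) and its sign bit survives `*flags &= (a ^ b) | (a ^ ~sum)`
// only while every addition is sign-safe. For example, a = b = INT32_MAX
// wraps to a negative sum, so bit 31 of both (a ^ b) and (a ^ ~sum) is 0,
// the sign bit of *flags is cleared, and IsHealthy() reports failure.)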
+  } else if (nb_deltas == 0 && predictor == Predictor::Zero) {
+    if (nb == 1) {
+      RunOnPool(
+          pool, 0, h, ThreadPool::SkipInit(),
+          [&](const int task, const int thread) {
+            const size_t y = task;
+            pixel_type *p = input.channel[c0].Row(y);
+            for (size_t x = 0; x < w; x++) {
+              const int index = Clamp1(p[x], 0, (pixel_type)palette.w - 1);
+              p[x] = palette_internal::GetPaletteValue(
+                  p_palette, index, /*c=*/0,
+                  /*palette_size=*/palette.w,
+                  /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+            }
+          },
+          "UndoChannelPalette");
+    } else {
+      RunOnPool(
+          pool, 0, h, ThreadPool::SkipInit(),
+          [&](const int task, const int thread) {
+            const size_t y = task;
+            std::vector<pixel_type *> p_out(nb);
+            const pixel_type *p_index = input.channel[c0].Row(y);
+            for (int c = 0; c < nb; c++)
+              p_out[c] = input.channel[c0 + c].Row(y);
+            for (size_t x = 0; x < w; x++) {
+              const int index = p_index[x];
+              for (int c = 0; c < nb; c++) {
+                p_out[c][x] = palette_internal::GetPaletteValue(
+                    p_palette, index, /*c=*/c,
+                    /*palette_size=*/palette.w,
+                    /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+              }
+            }
+          },
+          "UndoPalette");
+    }
+  } else {
+    // Parallelized per channel.
+    ImageI indices = CopyImage(input.channel[c0].plane);
+    if (predictor == Predictor::Weighted) {
+      RunOnPool(
+          pool, 0, nb, ThreadPool::SkipInit(),
+          [&](size_t c, size_t _) {
+            Channel &channel = input.channel[c0 + c];
+            weighted::State wp_state(wp_header, channel.w, channel.h);
+            for (size_t y = 0; y < channel.h; y++) {
+              pixel_type *JXL_RESTRICT p = channel.Row(y);
+              const pixel_type *JXL_RESTRICT idx = indices.Row(y);
+              for (size_t x = 0; x < channel.w; x++) {
+                int index = idx[x];
+                pixel_type_w val = 0;
+                const pixel_type palette_entry =
+                    palette_internal::GetPaletteValue(
+                        p_palette, index, /*c=*/c,
+                        /*palette_size=*/palette.w, /*onerow=*/onerow,
+                        /*bit_depth=*/bit_depth);
+                if (index < static_cast<int>(nb_deltas)) {
+                  PredictionResult pred =
+                      PredictNoTreeWP(channel.w, p + x, onerow_image, x, y,
+                                      predictor, &wp_state);
+                  val = pred.guess + palette_entry;
+                } else {
+                  val = palette_entry;
+                }
+                p[x] = val;
+                wp_state.UpdateErrors(p[x], x, y, channel.w);
+              }
+            }
+          },
+          "UndoDeltaPaletteWP");
+    } else if (predictor == Predictor::Gradient) {
+      // Gradient is the most common predictor for now. This special case gives
+      // about 20% extra speed.
+      RunOnPool(
+          pool, 0, nb, ThreadPool::SkipInit(),
+          [&](size_t c, size_t _) {
+            pixel_type flags = -1;
+            Channel &channel = input.channel[c0 + c];
+            for (size_t y = 0; y < channel.h; y++) {
+              pixel_type *JXL_RESTRICT p = channel.Row(y);
+              const pixel_type *JXL_RESTRICT idx = indices.Row(y);
+              for (size_t x = 0; x < channel.w; x++) {
+                int index = idx[x];
+                pixel_type val = 0;
+                const pixel_type palette_entry =
+                    palette_internal::GetPaletteValue(
+                        p_palette, index, /*c=*/c,
+                        /*palette_size=*/palette.w,
+                        /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+                if (index < static_cast<int>(nb_deltas)) {
+                  pixel_type left =
+                      x ? p[x - 1] : (y ? *(p + x - onerow_image) : 0);
+                  pixel_type top = y ? *(p + x - onerow_image) : left;
+                  pixel_type topleft =
+                      x && y ? *(p + x - 1 - onerow_image) : left;
+                  val = CautiousAdd(ClampedGradient(left, top, topleft),
+                                    palette_entry, &flags);
+                } else {
+                  val = palette_entry;
+                }
+                p[x] = val;
+              }
+            }
+            if (!IsHealthy(flags)) {
+              num_errors.fetch_add(1, std::memory_order_relaxed);
+            }
+          },
+          "UndoDeltaPaletteGradient");
+    } else {
+      RunOnPool(
+          pool, 0, nb, ThreadPool::SkipInit(),
+          [&](size_t c, size_t _) {
+            Channel &channel = input.channel[c0 + c];
+            for (size_t y = 0; y < channel.h; y++) {
+              pixel_type *JXL_RESTRICT p = channel.Row(y);
+              const pixel_type *JXL_RESTRICT idx = indices.Row(y);
+              for (size_t x = 0; x < channel.w; x++) {
+                int index = idx[x];
+                pixel_type_w val = 0;
+                const pixel_type palette_entry =
+                    palette_internal::GetPaletteValue(
+                        p_palette, index, /*c=*/c,
+                        /*palette_size=*/palette.w,
+                        /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+                if (index < static_cast<int>(nb_deltas)) {
+                  PredictionResult pred = PredictNoTreeNoWP(
+                      channel.w, p + x, onerow_image, x, y, predictor);
+                  val = pred.guess + palette_entry;
+                } else {
+                  val = palette_entry;
+                }
+                p[x] = val;
+              }
+            }
+          },
+          "UndoDeltaPaletteNoWP");
+    }
+  }
+  input.nb_channels += nb - 1;
+  input.nb_meta_channels--;
+  input.channel.erase(input.channel.begin(), input.channel.begin() + 1);
+  return num_errors.load(std::memory_order_relaxed) == 0;
+}
+
+static Status CheckPaletteParams(const Image &image, uint32_t begin_c,
+                                 uint32_t end_c) {
+  uint32_t c1 = begin_c;
+  uint32_t c2 = end_c;
+  // The range is including c1 and c2, so c2 may not be num_channels.
+  if (c1 > image.channel.size() || c2 >= image.channel.size() || c2 < c1) {
+    return JXL_FAILURE("Invalid channel range");
+  }
+  for (size_t c = begin_c + 1; c <= end_c; c++) {
+    if (image.channel[c].w != image.channel[begin_c].w ||
+        image.channel[c].h != image.channel[begin_c].h) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+static Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c,
+                          uint32_t nb_colors, uint32_t nb_deltas, bool lossy) {
+  JXL_RETURN_IF_ERROR(CheckPaletteParams(input, begin_c, end_c));
+
+  size_t nb = end_c - begin_c + 1;
+  input.nb_meta_channels++;
+  if (nb < 1 || input.nb_channels < nb) {
+    return JXL_FAILURE("Corrupted transforms");
+  }
+  input.nb_channels -= nb - 1;
+  input.channel.erase(input.channel.begin() + begin_c + 1,
+                      input.channel.begin() + end_c + 1);
+  Channel pch(nb_colors + nb_deltas, nb);
+  pch.hshift = -1;
+  input.channel.insert(input.channel.begin(), std::move(pch));
+  return true;
+}
+
+static Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c,
+                         uint32_t &nb_colors, bool ordered, bool lossy,
+                         Predictor &predictor,
+                         const weighted::Header &wp_header) {
+  JXL_QUIET_RETURN_IF_ERROR(CheckPaletteParams(input, begin_c, end_c));
+  uint32_t nb = end_c - begin_c + 1;
+
+  size_t w = input.channel[begin_c].w;
+  size_t h = input.channel[begin_c].h;
+
+  Image quantized_input;
+  if (lossy) {
+    quantized_input = Image(w, h, input.maxval, nb);
+    for (size_t c = 0; c < nb; c++) {
+      CopyImageTo(input.channel[begin_c + c].plane,
+                  &quantized_input.channel[c].plane);
+    }
+  }
+
+  JXL_DEBUG_V(
+      7, "Trying to represent channels %i-%i using at most a %i-color palette.",
+      begin_c, end_c, nb_colors);
+  int nb_deltas = 0;
+  bool delta_used = false;
+  std::set<std::vector<pixel_type>>
+      candidate_palette;  // ordered lexicographically
+  std::vector<std::vector<pixel_type>> candidate_palette_imageorder;
+  std::vector<pixel_type> color(nb);
+  std::vector<float> color_with_error(nb);
+  std::vector<const pixel_type *> p_in(nb);
+  for (size_t y = 0; y < h; y++) {
+    for (uint32_t c = 0; c < nb; c++) {
+      p_in[c] = input.channel[begin_c + c].Row(y);
+    }
+    for (size_t x = 0; x < w; x++) {
+      if (lossy && candidate_palette.size() >= nb_colors) break;
+      for (uint32_t c = 0; c < nb; c++) {
+        color[c] = p_in[c][x];
+      }
+      const bool new_color = candidate_palette.insert(color).second;
+      if (new_color) {
+        candidate_palette_imageorder.push_back(color);
+      }
+      if (candidate_palette.size() > nb_colors) {
+        return false;  // too many colors
+      }
+    }
+  }
+  nb_colors = candidate_palette.size();
+  JXL_DEBUG_V(6, "Channels %i-%i can be represented using a %i-color palette.",
+              begin_c, end_c, nb_colors);
+
+  Channel pch(nb_colors, nb);
+  pch.hshift = -1;
+  int x = 0;
+  pixel_type *JXL_RESTRICT p_palette = pch.Row(0);
+  intptr_t onerow = pch.plane.PixelsPerRow();
+  intptr_t onerow_image = input.channel[begin_c].plane.PixelsPerRow();
+  const int bit_depth =
+      CeilLog2Nonzero(static_cast<uint32_t>(input.maxval - input.minval + 1));
+  std::vector<pixel_type> lookup;
+  int minval, maxval;
+  input.channel[begin_c].compute_minmax(&minval, &maxval);
+  int lookup_table_size = maxval - minval + 1;
+  if (lookup_table_size > palette_internal::kMaxPaletteLookupTableSize) {
+    return false;  // too large lookup table
+  }
+  if (nb == 1) {
+    lookup.resize(lookup_table_size);
+  }
+  if (ordered) {
+    JXL_DEBUG_V(7, "Palette of %i colors, using lexicographic order",
+                nb_colors);
+    for (auto pcol : candidate_palette) {
+      JXL_DEBUG_V(9, "  Color %i :  ", x);
+      for (size_t i = 0; i < nb; i++) {
+        p_palette[i * onerow + x] = pcol[i];
+      }
+      if (nb == 1) lookup[pcol[0] - minval] = x;
+      for (size_t i = 0; i < nb; i++) {
+        JXL_DEBUG_V(9, "%i ", pcol[i]);
+      }
+      x++;
+    }
+  } else {
+    JXL_DEBUG_V(7, "Palette of %i colors, using image order", nb_colors);
+    for (auto pcol : candidate_palette_imageorder) {
+      JXL_DEBUG_V(9, "  Color %i :  ", x);
+      for (size_t i = 0; i < nb; i++) p_palette[i * onerow + x] = pcol[i];
+      if (nb == 1) lookup[pcol[0] - minval] = x;
+      for (size_t i = 0; i < nb; i++) JXL_DEBUG_V(9, "%i ", pcol[i]);
+      x++;
+    }
+  }
+  std::vector<weighted::State> wp_states;
+  for (size_t c = 0; c < nb; c++) {
+    wp_states.emplace_back(wp_header, w, h);
+  }
+  std::vector<pixel_type *> p_quant(nb);
+  // Three rows of error for dithering: y to y + 2.
+  // Each row has two pixels of padding in the ends, which is
+  // beneficial for both precision and encoding speed.
+  std::vector<std::vector<float>> error_row[3];
+  if (lossy) {
+    for (int i = 0; i < 3; ++i) {
+      error_row[i].resize(nb);
+      for (size_t c = 0; c < nb; ++c) {
+        error_row[i][c].resize(w + 4);
+      }
+    }
+  }
+  for (size_t y = 0; y < h; y++) {
+    for (size_t c = 0; c < nb; c++) {
+      p_in[c] = input.channel[begin_c + c].Row(y);
+      if (lossy) p_quant[c] = quantized_input.channel[c].Row(y);
+    }
+    pixel_type *JXL_RESTRICT p = input.channel[begin_c].Row(y);
+    if (nb == 1 && !lossy) {
+      for (size_t x = 0; x < w; x++) p[x] = lookup[p[x] - minval];
+    } else {
+      for (size_t x = 0; x < w; x++) {
+        int index;
+        if (!lossy) {
+          for (size_t c = 0; c < nb; c++) color[c] = p_in[c][x];
+          // Exact search.
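// (Editorial note: the loop below is a linear scan over the candidate
// palette, O(nb_colors) per pixel. An exact-lookup alternative would be a
// hypothetical map filled while the palette rows were emitted above, e.g.
//
//   std::map<std::vector<pixel_type>, int> slot_of_color;  // color -> index
//   ...
//   index = slot_of_color.at(color);
//
// presumably the scan is kept because nb_colors is small in practice.)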
+          for (index = 0; static_cast<uint32_t>(index) < nb_colors; index++) {
+            bool found = true;
+            for (size_t c = 0; c < nb; c++) {
+              if (color[c] != p_palette[c * onerow + index]) {
+                found = false;
+                break;
+              }
+            }
+            if (found) break;
+          }
+          if (index < nb_deltas) {
+            delta_used = true;
+          }
+        } else {
+          for (size_t c = 0; c < nb; c++) {
+            color_with_error[c] = p_in[c][x] + error_row[0][c][x + 2];
+            color[c] = std::min<pixel_type>(
+                input.maxval,
+                std::max<pixel_type>(input.minval,
+                                     lroundf(color_with_error[c])));
+          }
+          float best_distance = std::numeric_limits<float>::infinity();
+          int best_index = 0;
+          bool best_is_delta = false;
+          std::vector<pixel_type> best_val(nb, 0);
+          std::vector<pixel_type> quantized_val(nb);
+          std::vector<pixel_type_w> predictions(nb);
+          for (size_t c = 0; c < nb; ++c) {
+            predictions[c] = PredictNoTreeWP(w, p_quant[c] + x, onerow_image, x,
+                                             y, predictor, &wp_states[c])
+                                 .guess;
+          }
+          const auto TryIndex = [&](const int index) {
+            for (size_t c = 0; c < nb; c++) {
+              quantized_val[c] = palette_internal::GetPaletteValue(
+                  p_palette, index, /*c=*/c,
+                  /*palette_size=*/nb_colors,
+                  /*onerow=*/onerow, /*bit_depth=*/bit_depth);
+              if (index < nb_deltas) {
+                quantized_val[c] += predictions[c];
+              }
+            }
+            const float color_distance =
+                32 * palette_internal::ColorDistance(color_with_error,
+                                                     quantized_val);
+            float index_penalty = 0;
+            if (index == -1) {
+              index_penalty = -124;
+            } else if (index < static_cast<int>(nb_colors)) {
+              index_penalty = 2 * std::abs(index);
+            } else if (index < static_cast<int>(nb_colors) +
+                                   palette_internal::kLargeCubeOffset) {
+              index_penalty = 70;
+            } else {
+              index_penalty = 256;
+            }
+            index_penalty *= 1LL << std::max(2 * (bit_depth - 8), 0);
+            const float distance = color_distance + index_penalty;
+            if (distance < best_distance) {
+              best_distance = distance;
+              best_index = index;
+              best_is_delta = index < nb_deltas;
+              best_val.swap(quantized_val);
+            }
+          };
+          for (index = palette_internal::kMinImplicitPaletteIndex;
+               index < static_cast<int>(nb_colors); index++) {
+            TryIndex(index);
+          }
+          TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex(
+              color, nb_colors, bit_depth,
+              /*high_quality=*/false));
+          if (palette_internal::kEncodeToHighQualityImplicitPalette) {
+            TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex(
+                color, nb_colors, bit_depth,
+                /*high_quality=*/true));
+          }
+          index = best_index;
+          delta_used |= best_is_delta;
+          for (size_t c = 0; c < nb; ++c) {
+            wp_states[c].UpdateErrors(best_val[c], x, y, w);
+            p_quant[c][x] = best_val[c];
+          }
+          float len_error = 0;
+          for (size_t c = 0; c < nb; ++c) {
+            float local_error = color_with_error[c] - best_val[c];
+            len_error += local_error * local_error;
+          }
+          len_error = sqrt(len_error);
+          float modulate = 1.0;
+          int len_limit = 38 << std::max(0, bit_depth - 8);
+          if (len_error > len_limit) {
+            modulate *= len_limit / len_error;
+          }
+          for (size_t c = 0; c < nb; ++c) {
+            float local_error = (color_with_error[c] - best_val[c]);
+            float total_error = 0.65 * local_error;
+
+            // If the neighboring pixels have some error in the opposite
+            // direction of total_error, cancel some or all of it out before
+            // spreading among them.
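// (Editorial note on the offsets table below: each entry is {row, col} into
// error_row, where rows 0/1/2 hold the error of image rows y/y+1/y+2 and the
// stored column is x + col with two columns of padding, so col 2 is the
// current pixel's column. The kernel therefore reaches the two pixels to the
// right on the current row and a five-wide window on the next two rows.)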
+            constexpr int offsets[12][2] = {{1, 2}, {0, 3}, {0, 4}, {1, 1},
+                                            {1, 3}, {2, 2}, {1, 0}, {1, 4},
+                                            {2, 1}, {2, 3}, {2, 0}, {2, 4}};
+            float total_available = 0;
+            int n = 0;
+            for (int i = 0; i < 11; ++i) {
+              const int row = offsets[i][0];
+              const int col = offsets[i][1];
+              if (std::signbit(error_row[row][c][x + col]) !=
+                  std::signbit(total_error)) {
+                total_available += error_row[row][c][x + col];
+                n++;
+              }
+            }
+            float weight =
+                std::abs(total_error) / (std::abs(total_available) + 1e-3);
+            weight = std::min(weight, 1.0f);
+            for (int i = 0; i < 11; ++i) {
+              const int row = offsets[i][0];
+              const int col = offsets[i][1];
+              if (std::signbit(error_row[row][c][x + col]) !=
+                  std::signbit(total_error)) {
+                total_error += weight * error_row[row][c][x + col];
+                error_row[row][c][x + col] *= (1 - weight);
+              }
+            }
+            total_error *= modulate;
+            const float remaining_error = (1.0f / 14.) * total_error;
+            error_row[0][c][x + 3] += 2 * remaining_error;
+            error_row[0][c][x + 4] += remaining_error;
+            error_row[1][c][x + 0] += remaining_error;
+            for (int i = 0; i < 5; ++i) {
+              error_row[1][c][x + i] += remaining_error;
+              error_row[2][c][x + i] += remaining_error;
+            }
+          }
+        }
+        p[x] = index;
+      }
+      if (lossy) {
+        for (size_t c = 0; c < nb; ++c) {
+          error_row[0][c].swap(error_row[1][c]);
+          error_row[1][c].swap(error_row[2][c]);
+          std::fill(error_row[2][c].begin(), error_row[2][c].end(), 0.f);
+        }
+      }
+    }
+  }
+  if (!delta_used) {
+    predictor = Predictor::Zero;
+  }
+  input.nb_meta_channels++;
+  if (nb < 1 || input.nb_channels < nb) {
+    return JXL_FAILURE("Corrupted transforms");
+  }
+  input.nb_channels -= nb - 1;
+  input.channel.erase(input.channel.begin() + begin_c + 1,
+                      input.channel.begin() + end_c + 1);
+  input.channel.insert(input.channel.begin(), std::move(pch));
+  return true;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h
new file mode 100644
index 000000000000..81d5b5ab8141
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h
@@ -0,0 +1,510 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_
+#define LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_
+
+// Haar-like transform: halves the resolution in one direction
+// A B   -> (A+B)>>1 in one channel (average)  -> same range as original
+//          channel
+//          A-B - tendency in a new channel ('residual' needed to make the
+//          transform reversible)
+//          -> theoretically range could be 2.5 times larger (2 times without
+//          the 'tendency'), but there should be lots of zeroes
+// Repeated application (alternating horizontal and vertical squeezes) results
+// in downscaling
+//
+// The default coefficient ordering is low-frequency to high-frequency, as in
+// M. Antonini, M. Barlaud, P. Mathieu and I. Daubechies, "Image coding using
+// wavelet transform", IEEE Transactions on Image Processing, vol. 1, no. 2,
+// pp. 205-220, April 1992, doi: 10.1109/83.136597.
+
+#include <stdlib.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/modular/modular_image.h"
+#include "lib/jxl/modular/transform/transform.h"
+
+#define JXL_MAX_FIRST_PREVIEW_SIZE 8
+
+namespace jxl {
+
+/*
+        int avg=(A+B)>>1;
+        int diff=(A-B);
+        int rA=(diff+(avg<<1)+(diff&1))>>1;
+        int rB=rA-diff;
+
+*/
+// |A B|C D|E F|
+//   p   a   n          p=avg(A,B), a=avg(C,D), n=avg(E,F)
+//
+// Goal: estimate C-D (avoiding ringing artifacts)
+// (ensuring that in smooth areas, a zero residual corresponds to a smooth
+// gradient)
+
+// best estimate for C: (B + 2*a)/3
+// best estimate for D: (n + 3*a)/4
+// best estimate for C-D: (4*B - 3*n - a)/12
+
+// avoid ringing by 1) only doing this if B <= a <= n or B >= a >= n
+// (otherwise, this is not a smooth area and we cannot really estimate C-D)
+// 2) making sure that B <= C <= D <= n or B >= C >= D >= n
+
+inline pixel_type_w SmoothTendency(pixel_type_w B, pixel_type_w a,
+                                   pixel_type_w n) {
+  pixel_type_w diff = 0;
+  if (B >= a && a >= n) {
+    diff = (4 * B - 3 * n - a + 6) / 12;
+    //      2C = a<<1 + diff - diff&1 <= 2B  so diff - diff&1 <= 2B - 2a
+    //      2D = a<<1 - diff - diff&1 >= 2n  so diff + diff&1 <= 2a - 2n
+    if (diff - (diff & 1) > 2 * (B - a)) diff = 2 * (B - a) + 1;
+    if (diff + (diff & 1) > 2 * (a - n)) diff = 2 * (a - n);
+  } else if (B <= a && a <= n) {
+    diff = (4 * B - 3 * n - a - 6) / 12;
+    //      2C = a<<1 + diff + diff&1 >= 2B  so diff + diff&1 >= 2B - 2a
+    //      2D = a<<1 - diff + diff&1 <= 2n  so diff - diff&1 >= 2a - 2n
+    if (diff + (diff & 1) < 2 * (B - a)) diff = 2 * (B - a) - 1;
+    if (diff - (diff & 1) < 2 * (a - n)) diff = 2 * (a - n);
+  }
+  return diff;
+}
+
+void InvHSqueeze(Image &input, int c, int rc, ThreadPool *pool) {
+  const Channel &chin = input.channel[c];
+  const Channel &chin_residual = input.channel[rc];
+  // These must be valid since we ran MetaApply already.
+  JXL_ASSERT(chin.w == DivCeil(chin.w + chin_residual.w, 2));
+  JXL_ASSERT(chin.h == chin_residual.h);
+
+  if (chin_residual.w == 0 || chin_residual.h == 0) {
+    input.channel[c].resize(chin.w + chin_residual.w, chin.h);
+    input.channel[c].hshift--;
+    input.channel[c].hcshift--;
+    return;
+  }
+
+  Channel chout(chin.w + chin_residual.w, chin.h, chin.hshift - 1, chin.vshift,
+                chin.hcshift - 1, chin.vcshift);
+  JXL_DEBUG_V(4,
+              "Undoing horizontal squeeze of channel %i using residuals in "
+              "channel %i (going from width %zu to %zu)",
+              c, rc, chin.w, chout.w);
+  RunOnPool(
+      pool, 0, chin.h, ThreadPool::SkipInit(),
+      [&](const int task, const int thread) {
+        const size_t y = task;
+        const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y);
+        const pixel_type *JXL_RESTRICT p_avg = chin.Row(y);
+        pixel_type *JXL_RESTRICT p_out = chout.Row(y);
+
+        // special case for x=0 so we don't have to check x>0
+        pixel_type_w avg = p_avg[0];
+        pixel_type_w next_avg = (1 < chin.w ? p_avg[1] : avg);
+        pixel_type_w tendency = SmoothTendency(avg, avg, next_avg);
+        pixel_type_w diff = p_residual[0] + tendency;
+        pixel_type_w A =
+            ((avg * 2) + diff + (diff > 0 ? -(diff & 1) : (diff & 1))) >> 1;
+        pixel_type_w B = A - diff;
+        p_out[0] = ClampToRange<pixel_type>(A);
+        p_out[1] = ClampToRange<pixel_type>(B);
+
+        for (size_t x = 1; x < chin_residual.w; x++) {
+          pixel_type_w diff_minus_tendency = p_residual[x];
+          pixel_type_w avg = p_avg[x];
+          pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg);
+          pixel_type_w left = p_out[(x << 1) - 1];
+          pixel_type_w tendency = SmoothTendency(left, avg, next_avg);
+          pixel_type_w diff = diff_minus_tendency + tendency;
+          pixel_type_w A =
+              ((avg * 2) + diff + (diff > 0 ? -(diff & 1) : (diff & 1))) >> 1;
+          p_out[x << 1] = ClampToRange<pixel_type>(A);
+          pixel_type_w B = A - diff;
+          p_out[(x << 1) + 1] = ClampToRange<pixel_type>(B);
+        }
+        if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1];
+      },
+      "InvHorizontalSqueeze");
+  input.channel[c] = std::move(chout);
+}
+
+void FwdHSqueeze(Image &input, int c, int rc) {
+  const Channel &chin = input.channel[c];
+
+  JXL_DEBUG_V(4, "Doing horizontal squeeze of channel %i to new channel %i", c,
+              rc);
+
+  Channel chout((chin.w + 1) / 2, chin.h, chin.hshift + 1, chin.vshift,
+                chin.hcshift + 1, chin.vcshift);
+  Channel chout_residual(chin.w - chout.w, chout.h, chin.hshift + 1,
+                         chin.vshift, chin.hcshift, chin.vcshift);
+
+  for (size_t y = 0; y < chout.h; y++) {
+    const pixel_type *JXL_RESTRICT p_in = chin.Row(y);
+    pixel_type *JXL_RESTRICT p_out = chout.Row(y);
+    pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y);
+    for (size_t x = 0; x < chout_residual.w; x++) {
+      pixel_type A = p_in[x * 2];
+      pixel_type B = p_in[x * 2 + 1];
+      pixel_type avg = (A + B + (A > B)) >> 1;
+      p_out[x] = avg;
+
+      pixel_type diff = A - B;
+
+      pixel_type next_avg = avg;
+      if (x + 1 < chout_residual.w) {
+        next_avg = (p_in[x * 2 + 2] + p_in[x * 2 + 3] +
+                    (p_in[x * 2 + 2] > p_in[x * 2 + 3])) >>
+                   1;  // which will be chout.value(y,x+1)
+      } else if (chin.w & 1)
+        next_avg = p_in[x * 2 + 2];
+      pixel_type left = (x > 0 ? p_in[x * 2 - 1] : avg);
+      pixel_type tendency = SmoothTendency(left, avg, next_avg);
+
+      p_res[x] = diff - tendency;
+    }
+    if (chin.w & 1) {
+      int x = chout.w - 1;
+      p_out[x] = p_in[x * 2];
+    }
+  }
+  input.channel[c] = std::move(chout);
+  input.channel.insert(input.channel.begin() + rc, std::move(chout_residual));
+}
+
+void InvVSqueeze(Image &input, int c, int rc, ThreadPool *pool) {
+  const Channel &chin = input.channel[c];
+  const Channel &chin_residual = input.channel[rc];
+  // These must be valid since we ran MetaApply already.
+  JXL_ASSERT(chin.h == DivCeil(chin.h + chin_residual.h, 2));
+  JXL_ASSERT(chin.w == chin_residual.w);
+
+  if (chin_residual.w == 0 || chin_residual.h == 0) {
+    input.channel[c].resize(chin.w, chin.h + chin_residual.h);
+    input.channel[c].vshift--;
+    input.channel[c].vcshift--;
+    return;
+  }
+
+  // Note: chin.h >= chin_residual.h and at most 1 different.
+  Channel chout(chin.w, chin.h + chin_residual.h, chin.hshift, chin.vshift - 1,
+                chin.hcshift, chin.vcshift - 1);
+  JXL_DEBUG_V(
+      4,
+      "Undoing vertical squeeze of channel %i using residuals in channel "
+      "%i (going from height %zu to %zu)",
+      c, rc, chin.h, chout.h);
+
+  intptr_t onerow_in = chin.plane.PixelsPerRow();
+  intptr_t onerow_out = chout.plane.PixelsPerRow();
+  constexpr int kColsPerThread = 64;
+  RunOnPool(
+      pool, 0, DivCeil(chin.w, kColsPerThread), ThreadPool::SkipInit(),
+      [&](const int task, const int thread) {
+        const size_t x0 = task * kColsPerThread;
+        const size_t x1 = std::min((size_t)(task + 1) * kColsPerThread, chin.w);
+        // We only iterate up to std::min(chin_residual.h, chin.h) which is
+        // always chin_residual.h.
+ for (size_t y = 0; y < chin_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y); + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y); + pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1); + for (size_t x = x0; x < x1; x++) { + pixel_type_w diff_minus_tendency = p_residual[x]; + pixel_type_w avg = p_avg[x]; + + pixel_type_w next_avg = avg; + if (y + 1 < chin.h) next_avg = p_avg[x + onerow_in]; + pixel_type_w top = + (y > 0 ? p_out[static_cast(x) - onerow_out] : avg); + pixel_type_w tendency = SmoothTendency(top, avg, next_avg); + pixel_type_w diff = diff_minus_tendency + tendency; + pixel_type_w out = + ((avg * 2) + diff + (diff > 0 ? -(diff & 1) : (diff & 1))) >> 1; + + p_out[x] = ClampToRange(out); + // If the chin_residual.h == chin.h, the output has an even number + // of rows so the next line is fine. Otherwise, this loop won't + // write to the last output row which is handled separately. + p_out[x + onerow_out] = ClampToRange(p_out[x] - diff); + } + } + }, + "InvVertSqueeze"); + + if (chout.h & 1) { + size_t y = chin.h - 1; + const pixel_type *p_avg = chin.Row(y); + pixel_type *p_out = chout.Row(y << 1); + for (size_t x = 0; x < chin.w; x++) { + p_out[x] = p_avg[x]; + } + } + input.channel[c] = std::move(chout); +} + +void FwdVSqueeze(Image &input, int c, int rc) { + const Channel &chin = input.channel[c]; + + JXL_DEBUG_V(4, "Doing vertical squeeze of channel %i to new channel %i", c, + rc); + + Channel chout(chin.w, (chin.h + 1) / 2, chin.hshift, chin.vshift + 1, + chin.hcshift, chin.vcshift + 1); + Channel chout_residual(chin.w, chin.h - chout.h, chin.hshift, chin.vshift + 1, + chin.hcshift, chin.vcshift); + intptr_t onerow_in = chin.plane.PixelsPerRow(); + for (size_t y = 0; y < chout_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_in = chin.Row(y * 2); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y); + for (size_t x = 0; x < chout.w; x++) { + pixel_type A = p_in[x]; + pixel_type B = p_in[x + onerow_in]; + pixel_type avg = (A + B + (A > B)) >> 1; + p_out[x] = avg; + + pixel_type diff = A - B; + + pixel_type next_avg = avg; + if (y + 1 < chout_residual.h) { + next_avg = (p_in[x + 2 * onerow_in] + p_in[x + 3 * onerow_in] + + (p_in[x + 2 * onerow_in] > p_in[x + 3 * onerow_in])) >> + 1; // which will be chout.value(y+1,x) + } else if (chin.h & 1) { + next_avg = p_in[x + 2 * onerow_in]; + } + pixel_type top = + (y > 0 ? 
p_in[static_cast<ssize_t>(x) - onerow_in] : avg);
+      pixel_type tendency = SmoothTendency(top, avg, next_avg);
+
+      p_res[x] = diff - tendency;
+    }
+  }
+  if (chin.h & 1) {
+    size_t y = chout.h - 1;
+    const pixel_type *p_in = chin.Row(y * 2);
+    pixel_type *p_out = chout.Row(y);
+    for (size_t x = 0; x < chout.w; x++) {
+      p_out[x] = p_in[x];
+    }
+  }
+  input.channel[c] = std::move(chout);
+  input.channel.insert(input.channel.begin() + rc, std::move(chout_residual));
+}
+
+void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters,
+                              const Image &image) {
+  int nb_channels = image.nb_channels;
+  // maybe other transforms have been applied before, but let's assume the
+  // first nb_channels channels still contain the 'main' data
+
+  parameters->clear();
+  size_t w = image.channel[image.nb_meta_channels].w;
+  size_t h = image.channel[image.nb_meta_channels].h;
+  JXL_DEBUG_V(7, "Default squeeze parameters for %zux%zu image: ", w, h);
+
+  // do horizontal first on wide images; vertical first on tall images
+  bool wide = (w > h);
+
+  if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w &&
+      image.channel[image.nb_meta_channels + 1].h == h) {
+    // assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0
+    // previews
+    JXL_DEBUG_V(7, "(4:2:0 chroma), %zux%zu image", w, h);
+    // if (!wide) {
+    //   parameters.push_back(0+2);  // vertical chroma squeeze
+    //   parameters.push_back(image.nb_meta_channels+1);
+    //   parameters.push_back(image.nb_meta_channels+2);
+    // }
+    SqueezeParams params;
+    // horizontal chroma squeeze
+    params.horizontal = true;
+    params.in_place = false;
+    params.begin_c = image.nb_meta_channels + 1;
+    params.num_c = 2;
+    parameters->push_back(params);
+    params.horizontal = false;
+    // vertical chroma squeeze
+    parameters->push_back(params);
+  }
+  SqueezeParams params;
+  params.begin_c = image.nb_meta_channels;
+  params.num_c = nb_channels;
+  params.in_place = true;
+
+  if (!wide) {
+    if (h > JXL_MAX_FIRST_PREVIEW_SIZE) {
+      params.horizontal = false;
+      parameters->push_back(params);
+      h = (h + 1) / 2;
+      JXL_DEBUG_V(7, "Vertical (%zux%zu), ", w, h);
+    }
+  }
+  while (w > JXL_MAX_FIRST_PREVIEW_SIZE || h > JXL_MAX_FIRST_PREVIEW_SIZE) {
+    if (w > JXL_MAX_FIRST_PREVIEW_SIZE) {
+      params.horizontal = true;
+      parameters->push_back(params);
+      w = (w + 1) / 2;
+      JXL_DEBUG_V(7, "Horizontal (%zux%zu), ", w, h);
+    }
+    if (h > JXL_MAX_FIRST_PREVIEW_SIZE) {
+      params.horizontal = false;
+      parameters->push_back(params);
+      h = (h + 1) / 2;
+      JXL_DEBUG_V(7, "Vertical (%zux%zu), ", w, h);
+    }
+  }
+  JXL_DEBUG_V(7, "that's it");
+}
+
+Status CheckMetaSqueezeParams(const std::vector<SqueezeParams> &parameters,
+                              int num_channels) {
+  for (size_t i = 0; i < parameters.size(); i++) {
+    int c1 = parameters[i].begin_c;
+    int c2 = parameters[i].begin_c + parameters[i].num_c - 1;
+    if (c1 < 0 || c1 > num_channels || c2 < 0 || c2 >= num_channels ||
+        c2 < c1) {
+      return JXL_FAILURE("Invalid channel range");
+    }
+  }
+  return true;
+}
+
+Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters) {
+  if (parameters->empty()) {
+    DefaultSqueezeParameters(parameters, image);
+  }
+  JXL_RETURN_IF_ERROR(
+      CheckMetaSqueezeParams(*parameters, image.channel.size()));
+
+  for (size_t i = 0; i < parameters->size(); i++) {
+    bool horizontal = (*parameters)[i].horizontal;
+    bool in_place = (*parameters)[i].in_place;
+    uint32_t beginc = (*parameters)[i].begin_c;
+    uint32_t endc = (*parameters)[i].begin_c + (*parameters)[i].num_c - 1;
+
+    uint32_t offset;
+    if (in_place) {
+      offset = endc + 1;
+    } else {
+      offset = image.channel.size();
+    }
+    for (uint32_t c = beginc; c <= endc; c++) {
+      Channel dummy;
+      dummy.hcshift = image.channel[c].hcshift;
+      dummy.vcshift = image.channel[c].vcshift;
+      if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) {
+        return JXL_FAILURE("Too many squeezes: shift > 30");
+      }
+      if (horizontal) {
+        size_t w = image.channel[c].w;
+        image.channel[c].w = (w + 1) / 2;
+        image.channel[c].hshift++;
+        image.channel[c].hcshift++;
+        dummy.w = w - (w + 1) / 2;
+        dummy.h = image.channel[c].h;
+      } else {
+        size_t h = image.channel[c].h;
+        image.channel[c].h = (h + 1) / 2;
+        image.channel[c].vshift++;
+        image.channel[c].vcshift++;
+        dummy.h = h - (h + 1) / 2;
+        dummy.w = image.channel[c].w;
+      }
+      dummy.hshift = image.channel[c].hshift;
+      dummy.vshift = image.channel[c].vshift;
+
+      image.channel.insert(image.channel.begin() + offset + (c - beginc),
+                           std::move(dummy));
+    }
+  }
+  return true;
+}
+
+Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters,
+                  ThreadPool *pool) {
+  if (parameters.empty()) {
+    DefaultSqueezeParameters(&parameters, input);
+  }
+  JXL_RETURN_IF_ERROR(
+      CheckMetaSqueezeParams(parameters, input.channel.size()));
+
+  for (int i = parameters.size() - 1; i >= 0; i--) {
+    bool horizontal = parameters[i].horizontal;
+    bool in_place = parameters[i].in_place;
+    uint32_t beginc = parameters[i].begin_c;
+    uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1;
+    uint32_t offset;
+    if (in_place) {
+      offset = endc + 1;
+    } else {
+      offset = input.channel.size() + beginc - endc - 1;
+    }
+    for (uint32_t c = beginc; c <= endc; c++) {
+      uint32_t rc = offset + c - beginc;
+      if ((input.channel[c].w < input.channel[rc].w) ||
+          (input.channel[c].h < input.channel[rc].h)) {
+        return JXL_FAILURE("Corrupted squeeze transform");
+      }
+      if (input.channel[rc].is_empty()) {
+        input.channel[rc].resize();  // assume all zeroes
+      }
+      if (horizontal) {
+        InvHSqueeze(input, c, rc, pool);
+      } else {
+        InvVSqueeze(input, c, rc, pool);
+      }
+    }
+    input.channel.erase(input.channel.begin() + offset,
+                        input.channel.begin() + offset + (endc - beginc + 1));
+  }
+  return true;
+}
+
+Status FwdSqueeze(Image &input, std::vector<SqueezeParams> parameters,
+                  ThreadPool *pool) {
+  if (parameters.empty()) {
+    DefaultSqueezeParameters(&parameters, input);
+  }
+  JXL_RETURN_IF_ERROR(
+      CheckMetaSqueezeParams(parameters, input.channel.size()));
+
+  for (size_t i = 0; i < parameters.size(); i++) {
+    bool horizontal = parameters[i].horizontal;
+    bool in_place = parameters[i].in_place;
+    uint32_t beginc = parameters[i].begin_c;
+    uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1;
+    uint32_t offset;
+    if (in_place) {
+      offset = endc + 1;
+    } else {
+      offset = input.channel.size();
+    }
+    for (uint32_t c = beginc; c <= endc; c++) {
+      if (horizontal) {
+        FwdHSqueeze(input, c, offset + c - beginc);
+      } else {
+        FwdVSqueeze(input, c, offset + c - beginc);
+      }
+    }
+  }
+  return true;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/subtractgreen.h b/third_party/jpeg-xl/lib/jxl/modular/transform/subtractgreen.h
new file mode 100644
index 000000000000..d65d5d6a2caf
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/modular/transform/subtractgreen.h
@@ -0,0 +1,186 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_SUBTRACTGREEN_H_ +#define LIB_JXL_MODULAR_TRANSFORM_SUBTRACTGREEN_H_ + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +template +void InvSubtractGreenRow(const pixel_type* in0, const pixel_type* in1, + const pixel_type* in2, pixel_type* out0, + pixel_type* out1, pixel_type* out2, size_t w) { + static_assert(transform_type >= 0 && transform_type < 7, + "Invalid transform type"); + int second = transform_type >> 1; + int third = transform_type & 1; + for (size_t x = 0; x < w; x++) { + if (transform_type == 6) { + pixel_type_w Y = in0[x]; + pixel_type_w Co = in1[x]; + pixel_type_w Cg = in2[x]; + pixel_type_w tmp = Y - (Cg >> 1); + pixel_type_w G = Cg + tmp; + pixel_type_w B = tmp - (Co >> 1); + pixel_type_w R = B + Co; + out0[x] = ClampToRange(R); + out1[x] = ClampToRange(G); + out2[x] = ClampToRange(B); + } else { + pixel_type_w First = in0[x]; + pixel_type_w Second = in1[x]; + pixel_type_w Third = in2[x]; + if (third) Third = Third + First; + if (second == 1) { + Second = Second + First; + } else if (second == 2) { + Second = Second + ((First + Third) >> 1); + } + out0[x] = ClampToRange(First); + out1[x] = ClampToRange(Second); + out2[x] = ClampToRange(Third); + } + } +} + +Status InvSubtractGreen(Image& input, size_t begin_c, size_t rct_type) { + size_t m = begin_c; + if (input.nb_channels + input.nb_meta_channels < begin_c + 3) { + return JXL_FAILURE( + "Invalid number of channels to apply inverse subtract_green."); + } + Channel& c0 = input.channel[m + 0]; + Channel& c1 = input.channel[m + 1]; + Channel& c2 = input.channel[m + 2]; + size_t w = c0.w; + size_t h = c0.h; + if (c0.plane.xsize() < w || c0.plane.ysize() < h || c1.plane.xsize() < w || + c1.plane.ysize() < h || c2.plane.xsize() < w || c2.plane.ysize() < h || + c1.w != w || c1.h != h || c2.w != w || c2.h != h) { + return JXL_FAILURE( + "Invalid channel dimensions to apply inverse subtract_green (maybe " + "chroma is subsampled?)."); + } + if (rct_type == 0) { // noop + return true; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + JXL_CHECK(permutation < 6); + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + // Special case: permute-only. Swap channels around. 
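+  // For example, permutation == 1 maps input channels (0, 1, 2) to output
+  // slots (1, 2, 0), while permutation == 3 relies on the permutation / 3
+  // correction terms below to swap the last two slots, giving (0, 2, 1).
+  // (An informal reading of the index arithmetic that follows.)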
+ if (custom == 0) { + Channel ch0 = std::move(input.channel[m]); + Channel ch1 = std::move(input.channel[m + 1]); + Channel ch2 = std::move(input.channel[m + 2]); + input.channel[m + (permutation % 3)] = std::move(ch0); + input.channel[m + ((permutation + 1 + permutation / 3) % 3)] = + std::move(ch1); + input.channel[m + ((permutation + 2 - permutation / 3) % 3)] = + std::move(ch2); + return true; + } + constexpr decltype(&InvSubtractGreenRow<0>) inv_subtract_green_row[] = { + InvSubtractGreenRow<0>, InvSubtractGreenRow<1>, InvSubtractGreenRow<2>, + InvSubtractGreenRow<3>, InvSubtractGreenRow<4>, InvSubtractGreenRow<5>, + InvSubtractGreenRow<6>}; + for (size_t y = 0; y < h; y++) { + const pixel_type* in0 = input.channel[m].Row(y); + const pixel_type* in1 = input.channel[m + 1].Row(y); + const pixel_type* in2 = input.channel[m + 2].Row(y); + pixel_type* out0 = input.channel[m + (permutation % 3)].Row(y); + pixel_type* out1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + pixel_type* out2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + inv_subtract_green_row[custom](in0, in1, in2, out0, out1, out2, w); + } + return true; +} + +Status FwdSubtractGreen(Image& input, size_t begin_c, size_t rct_type) { + if (input.nb_channels + input.nb_meta_channels < begin_c + 3) { + return false; + } + if (rct_type == 0) { // noop + return false; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + size_t m = begin_c; + size_t w = input.channel[m + 0].w; + size_t h = input.channel[m + 0].h; + if (input.channel[m + 1].w < w || input.channel[m + 1].h < h || + input.channel[m + 2].w < w || input.channel[m + 2].h < h) { + return JXL_FAILURE("Invalid channel dimensions to apply subtract_green."); + } + int second = (custom % 7) >> 1; + int third = (custom % 7) & 1; + for (size_t y = 0; y < h; y++) { + const pixel_type* in0 = input.channel[m + (permutation % 3)].Row(y); + const pixel_type* in1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + const pixel_type* in2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + pixel_type* out0 = input.channel[m].Row(y); + pixel_type* out1 = input.channel[m + 1].Row(y); + pixel_type* out2 = input.channel[m + 2].Row(y); + for (size_t x = 0; x < w; x++) { + if (custom == 6) { + pixel_type R = in0[x]; + pixel_type G = in1[x]; + pixel_type B = in2[x]; + out1[x] = R - B; + pixel_type tmp = B + (out1[x] >> 1); + out2[x] = G - tmp; + out0[x] = tmp + (out2[x] >> 1); + } else { + pixel_type First = in0[x]; + pixel_type Second = in1[x]; + pixel_type Third = in2[x]; + if (second == 1) { + Second = Second - First; + } else if (second == 2) { + Second = Second - ((First + Third) >> 1); + } + if (third) Third = Third - First; + out0[x] = First; + out1[x] = Second; + out2[x] = Third; + } + } + } + return true; +} + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_SUBTRACTGREEN_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc new file mode 100644 index 000000000000..f0e7555cd9ea --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc @@ -0,0 +1,106 @@ +// Copyright (c) the 
JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/modular/transform/transform.h" + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/near-lossless.h" +#include "lib/jxl/modular/transform/palette.h" +#include "lib/jxl/modular/transform/squeeze.h" +#include "lib/jxl/modular/transform/subtractgreen.h" + +namespace jxl { + +namespace { +const char *transform_name[static_cast(TransformId::kNumTransforms)] = + {"RCT", "Palette", "Squeeze", "Invalid", "Near-Lossless"}; +} // namespace + +SqueezeParams::SqueezeParams() { Bundle::Init(this); } +Transform::Transform(TransformId id) { + Bundle::Init(this); + this->id = id; +} + +const char *Transform::TransformName() const { + return transform_name[static_cast(id)]; +} + +Status Transform::Forward(Image &input, const weighted::Header &wp_header, + ThreadPool *pool) { + switch (id) { + case TransformId::kRCT: + return FwdSubtractGreen(input, begin_c, rct_type); + case TransformId::kSqueeze: + return FwdSqueeze(input, squeezes, pool); + case TransformId::kPalette: + return FwdPalette(input, begin_c, begin_c + num_c - 1, nb_colors, + ordered_palette, lossy_palette, predictor, wp_header); + case TransformId::kNearLossless: + return FwdNearLossless(input, begin_c, begin_c + num_c - 1, + max_delta_error, predictor); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast(id)); + } +} + +Status Transform::Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool) { + switch (id) { + case TransformId::kRCT: + return InvSubtractGreen(input, begin_c, rct_type); + case TransformId::kSqueeze: + return InvSqueeze(input, squeezes, pool); + case TransformId::kPalette: + return InvPalette(input, begin_c, nb_colors, nb_deltas, predictor, + wp_header, pool); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast(id)); + } +} + +Status Transform::MetaApply(Image &input) { + switch (id) { + case TransformId::kRCT: + JXL_DEBUG_V(2, "Transform: kRCT, rct_type=%" PRIu32, rct_type); + return true; + case TransformId::kSqueeze: + JXL_DEBUG_V(2, "Transform: kSqueeze:"); +#if JXL_DEBUG_V_LEVEL >= 2 + for (const auto ¶ms : squeezes) { + JXL_DEBUG_V( + 2, + " squeeze params: horizontal=%d, in_place=%d, begin_c=%" PRIu32 + ", num_c=%" PRIu32, + params.horizontal, params.in_place, params.begin_c, params.num_c); + } +#endif + return MetaSqueeze(input, &squeezes); + case TransformId::kPalette: + JXL_DEBUG_V(2, + "Transform: kPalette, begin_c=%" PRIu32 ", num_c=%" PRIu32 + ", nb_colors=%" PRIu32 ", nb_deltas=%" PRIu32, + begin_c, num_c, nb_colors, nb_deltas); + return MetaPalette(input, begin_c, begin_c + num_c - 1, nb_colors, + nb_deltas, lossy_palette); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast(id)); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h new file mode 100644 
index 000000000000..18ed038bfc8d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h @@ -0,0 +1,162 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ +#define LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ + +#include +#include +#include + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +enum class TransformId : uint32_t { + // G, R-G, B-G and variants (including YCoCg). + kRCT = 0, + + // Color palette. Parameters are: [begin_c] [end_c] [nb_colors] + kPalette = 1, + + // Squeezing (Haar-style) + kSqueeze = 2, + + // Invalid for now. + kInvalid = 3, + + // this is lossy preprocessing, doesn't have an inverse transform and doesn't + // exist from the decoder point of view + kNearLossless = 4, + + // The total number of transforms. Update this if adding more transformations. + kNumTransforms = 5, +}; + +struct SqueezeParams : public Fields { + const char *Name() const override { return "SqueezeParams"; } + bool horizontal; + bool in_place; + uint32_t begin_c; + uint32_t num_c; + SqueezeParams(); + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &horizontal)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &in_place)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(3), BitsOffset(6, 8), + BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), BitsOffset(4, 4), 2, &num_c)); + return true; + } +}; + +class Transform : public Fields { + public: + TransformId id; + // for Palette and RCT. + uint32_t begin_c; + // for RCT. 42 possible values starting from 0. + uint32_t rct_type; + // Only for Palette and NearLossless. + uint32_t num_c; + // Only for Palette. + uint32_t nb_colors; + uint32_t nb_deltas; + // for Squeeze. Default squeeze if empty. + std::vector squeezes; + // for NearLossless, not serialized. + int max_delta_error; + // Serialized for Palette. + Predictor predictor; + // for Palette, not serialized. + bool ordered_palette = true; + bool lossy_palette = false; + + explicit Transform(TransformId id); + // default constructor for bundles. 
+ Transform() : Transform(TransformId::kNumTransforms) {} + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val((uint32_t)TransformId::kRCT), Val((uint32_t)TransformId::kPalette), + Val((uint32_t)TransformId::kSqueeze), + Val((uint32_t)TransformId::kInvalid), (uint32_t)TransformId::kRCT, + reinterpret_cast(&id))); + if (id == TransformId::kInvalid) { + return JXL_FAILURE("Invalid transform ID"); + } + if (visitor->Conditional(id == TransformId::kRCT || + id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(3), BitsOffset(6, 8), BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + } + if (visitor->Conditional(id == TransformId::kRCT)) { + // 0-41, default YCoCg. + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(6), Bits(2), BitsOffset(4, 2), + BitsOffset(6, 10), 6, &rct_type)); + if (rct_type >= 42) { + return JXL_FAILURE("Invalid transform RCT type"); + } + } + if (visitor->Conditional(id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(3), Val(4), BitsOffset(13, 1), 3, &num_c)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + BitsOffset(8, 0), BitsOffset(10, 256), BitsOffset(12, 1280), + BitsOffset(16, 5376), 256, &nb_colors)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(8, 1), BitsOffset(10, 257), + BitsOffset(16, 1281), 0, &nb_deltas)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bits(4, (uint32_t)Predictor::Zero, + reinterpret_cast(&predictor))); + if (predictor >= Predictor::Best) { + return JXL_FAILURE("Invalid predictor"); + } + } + + if (visitor->Conditional(id == TransformId::kSqueeze)) { + uint32_t num_squeezes = static_cast(squeezes.size()); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(4, 1), BitsOffset(6, 9), + BitsOffset(8, 41), 0, &num_squeezes)); + if (visitor->IsReading()) squeezes.resize(num_squeezes); + for (size_t i = 0; i < num_squeezes; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&squeezes[i])); + } + } + return true; + } + + const char *Name() const override { return "Transform"; } + + // Returns the name of the transform. + const char *TransformName() const; + + Status Forward(Image &input, const weighted::Header &wp_header, + ThreadPool *pool = nullptr); + Status Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool = nullptr); + Status MetaApply(Image &input); +}; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular_test.cc b/third_party/jpeg-xl/lib/jxl/modular_test.cc new file mode 100644 index 000000000000..aed47e7a0bbc --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular_test.cc @@ -0,0 +1,180 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "lib/extras/codec.h" +#include "lib/jxl/aux_out.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/padded_bytes.h" +#include "lib/jxl/base/thread_pool_internal.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_file.h" +#include "lib/jxl/dec_params.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_file.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/modular/encoding/enc_encoding.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testdata.h" + +namespace jxl { +namespace { +using test::Roundtrip; + +void TestLosslessGroups(size_t group_size_shift) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("imagecompression.info/flower_foveon.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.modular_group_size_shift = group_size_shift; + cparams.color_transform = jxl::ColorTransform::kNone; + DecompressParams dparams; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + io.ShrinkTo(io.xsize() / 4, io.ysize() / 4); + + compressed_size = Roundtrip(&io, cparams, dparams, pool, &io_out); + EXPECT_LE(compressed_size, 280000); + EXPECT_LE(ButteraugliDistance(io, io_out, cparams.ba_params, + /*distmap=*/nullptr, pool), + 0.0); +} + +TEST(ModularTest, RoundtripLosslessGroups128) { TestLosslessGroups(0); } + +TEST(ModularTest, JXL_TSAN_SLOW_TEST(RoundtripLosslessGroups512)) { + TestLosslessGroups(2); +} + +TEST(ModularTest, JXL_TSAN_SLOW_TEST(RoundtripLosslessGroups1024)) { + TestLosslessGroups(3); +} + +TEST(ModularTest, RoundtripLossy) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.quality_pair = {80.0f, 80.0f}; + DecompressParams dparams; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + + compressed_size = Roundtrip(&io, cparams, dparams, pool, &io_out); + EXPECT_LE(compressed_size, 40000); + cparams.ba_params.intensity_target = 80.0f; + EXPECT_LE(ButteraugliDistance(io, io_out, cparams.ba_params, + /*distmap=*/nullptr, pool), + 3.0); +} + +TEST(ModularTest, RoundtripLossy16) { + ThreadPool* pool = nullptr; + const PaddedBytes orig = + ReadTestData("raw.pixls/DJI-FC6310-16bit_709_v4_krita.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.quality_pair = {80.0f, 80.0f}; + DecompressParams dparams; + + CodecInOut io_out; + size_t compressed_size; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Span(orig), &io, pool)); + JXL_CHECK(io.TransformTo(ColorEncoding::SRGB(), pool)); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + + compressed_size = Roundtrip(&io, cparams, dparams, pool, &io_out); + EXPECT_LE(compressed_size, 400); + cparams.ba_params.intensity_target = 80.0f; + EXPECT_LE(ButteraugliDistance(io, io_out, cparams.ba_params, + /*distmap=*/nullptr, pool), + 1.5); +} + +TEST(ModularTest, 
RoundtripExtraProperties) { + constexpr size_t kSize = 250; + Image image(kSize, kSize, /*maxval=*/255, 3); + ModularOptions options; + options.max_properties = 4; + options.predictor = Predictor::Zero; + std::mt19937 rng(0); + std::uniform_int_distribution<> dist(0, 8); + for (size_t y = 0; y < kSize; y++) { + for (size_t x = 0; x < kSize; x++) { + image.channel[0].plane.Row(y)[x] = image.channel[2].plane.Row(y)[x] = + dist(rng); + } + } + ZeroFillImage(&image.channel[1].plane); + BitWriter writer; + ASSERT_TRUE(ModularGenericCompress(image, options, &writer)); + writer.ZeroPadToByte(); + Image decoded(kSize, kSize, /*maxval=*/255, image.channel.size()); + for (size_t i = 0; i < image.channel.size(); i++) { + const Channel& ch = image.channel[i]; + decoded.channel[i] = Channel(ch.w, ch.h, ch.hshift, ch.vshift); + } + Status status = true; + { + BitReader reader(writer.GetSpan()); + BitReaderScopedCloser closer(&reader, &status); + ASSERT_TRUE(ModularGenericDecompress(&reader, decoded, /*header=*/nullptr, + /*group_id=*/0, &options)); + } + ASSERT_TRUE(status); + ASSERT_EQ(image.channel.size(), decoded.channel.size()); + for (size_t c = 0; c < image.channel.size(); c++) { + for (size_t y = 0; y < image.channel[c].plane.ysize(); y++) { + for (size_t x = 0; x < image.channel[c].plane.xsize(); x++) { + EXPECT_EQ(image.channel[c].plane.Row(y)[x], + decoded.channel[c].plane.Row(y)[x]) + << "c = " << c << ", x = " << x << ", y = " << y; + } + } + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/noise.h b/third_party/jpeg-xl/lib/jxl/noise.h new file mode 100644 index 000000000000..7066eb13915c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/noise.h @@ -0,0 +1,68 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_NOISE_H_ +#define LIB_JXL_NOISE_H_ + +// Noise parameters shared by encoder/decoder. + +#include + +#include +#include +#include + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +const float kNoisePrecision = 1 << 10; + +struct NoiseParams { + // LUT index is an intensity of pixel / mean intensity of patch + static constexpr size_t kNumNoisePoints = 8; + float lut[kNumNoisePoints]; + + void Clear() { + for (float& i : lut) i = 0; + } + bool HasAny() const { + for (float i : lut) { + if (std::abs(i) > 1e-3f) return true; + } + return false; + } +}; + +static inline std::pair IndexAndFrac(float x) { + constexpr size_t kScaleNumerator = NoiseParams::kNumNoisePoints - 2; + // TODO: instead of 1, this should be a proper Y range. 
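+  // (Informal sketch: x is scaled so that the integer part selects lut[i]
+  // and the fraction interpolates toward lut[i + 1], with i clamped to
+  // kNumNoisePoints - 2 so that i + 1 stays in range. With 8 LUT points,
+  // x = 0.5 scales to 3.0 and yields the pair (3, 0.0f).)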
+ constexpr float kScale = kScaleNumerator / 1; + float scaled_x = std::max(0.f, x * kScale); + float floor_x; + float frac_x = std::modf(scaled_x, &floor_x); + if (JXL_UNLIKELY(scaled_x > kScaleNumerator)) { + floor_x = kScaleNumerator; + } + return std::make_pair(static_cast(static_cast(floor_x)), frac_x); +} + +struct NoiseLevel { + float noise_level; + float intensity; +}; + +} // namespace jxl + +#endif // LIB_JXL_NOISE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/noise_distributions.h b/third_party/jpeg-xl/lib/jxl/noise_distributions.h new file mode 100644 index 000000000000..f7f1d9fef4e7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/noise_distributions.h @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_NOISE_DISTRIBUTIONS_H_ +#define LIB_JXL_NOISE_DISTRIBUTIONS_H_ + +// Noise distributions for testing partial_derivatives and robust_statistics. + +#include +#include + +#include // distributions +#include + +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Unmodified input +struct NoiseNone { + std::string Name() const { return "None"; } + + template + float operator()(const float in, Random* rng) const { + return in; + } +}; + +// Salt+pepper +class NoiseImpulse { + public: + explicit NoiseImpulse(const uint32_t threshold) : threshold_(threshold) {} + std::string Name() const { return "Impulse" + ToString(threshold_); } + + // Sets pixels to 0 if rand < threshold or 1 if rand > ~threshold. + template + float operator()(const float in, Random* rng) const { + const uint32_t rand = (*rng)(); + float out = 0.0f; + if (rand > ~threshold_) { + out = 1.0f; + } + if (rand > threshold_) { + out = in; + } + return out; + } + + private: + const uint32_t threshold_; +}; + +class NoiseUniform { + public: + NoiseUniform(const float min, const float max_exclusive) + : dist_(min, max_exclusive) {} + std::string Name() const { return "Uniform" + ToString(dist_.b()); } + + template + float operator()(const float in, Random* rng) const { + return in + dist_(*rng); + } + + private: + mutable std::uniform_real_distribution dist_; +}; + +// Additive, zero-mean Gaussian. +class NoiseGaussian { + public: + explicit NoiseGaussian(const float stddev) : dist_(0.0f, stddev) {} + std::string Name() const { return "Gaussian" + ToString(dist_.stddev()); } + + template + float operator()(const float in, Random* rng) const { + return in + dist_(*rng); + } + + private: + mutable std::normal_distribution dist_; +}; + +// Integer noise is scaled by 1E-3. +class NoisePoisson { + public: + explicit NoisePoisson(const double mean) : dist_(mean) {} + std::string Name() const { return "Poisson" + ToString(dist_.mean()); } + + template + float operator()(const float in, Random* rng) const { + return in + dist_(*rng) * 1E-3f; + } + + private: + mutable std::poisson_distribution dist_; +}; + +// Returns the result of applying the randomized "noise" function to each pixel. 
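+// A minimal usage sketch (the image name and the constants here are
+// illustrative only):
+//
+//   std::mt19937 rng(42);
+//   NoiseGaussian noise(/*stddev=*/0.1f);
+//   ImageF noisy = AddNoise(clean, noise, &rng);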
+template +ImageF AddNoise(const ImageF& in, const NoiseType& noise, Random* rng) { + const size_t xsize = in.xsize(); + const size_t ysize = in.ysize(); + ImageF out(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const float* JXL_RESTRICT in_row = in.ConstRow(y); + float* JXL_RESTRICT out_row = out.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out_row[x] = noise(in_row[x], rng); + } + } + return out; +} + +template +Image3F AddNoise(const Image3F& in, const NoiseType& noise, Random* rng) { + const size_t xsize = in.xsize(); + const size_t ysize = in.ysize(); + Image3F out(xsize, ysize); + // noise_estimator_test requires this loop order. + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ysize; ++y) { + const float* JXL_RESTRICT in_row = in.ConstPlaneRow(c, y); + float* JXL_RESTRICT out_row = out.PlaneRow(c, y); + + for (size_t x = 0; x < xsize; ++x) { + out_row[x] = noise(in_row[x], rng); + } + } + } + return out; +} + +} // namespace jxl + +#endif // LIB_JXL_NOISE_DISTRIBUTIONS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/opsin_image_test.cc b/third_party/jpeg-xl/lib/jxl/opsin_image_test.cc new file mode 100644 index 000000000000..0d95c71bacfd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_image_test.cc @@ -0,0 +1,136 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/linalg.h" +#include "lib/jxl/opsin_params.h" + +namespace jxl { +namespace { + +class OpsinImageTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(OpsinImageTargetTest); + +TEST_P(OpsinImageTargetTest, MaxCubeRootError) { TestCubeRoot(); } + +// Convert a single linear sRGB color to xyb, using the exact image conversion +// procedure that jpeg xl uses. +void LinearSrgbToOpsin(float rgb_r, float rgb_g, float rgb_b, + float* JXL_RESTRICT xyb_x, float* JXL_RESTRICT xyb_y, + float* JXL_RESTRICT xyb_b) { + Image3F linear(1, 1); + linear.PlaneRow(0, 0)[0] = rgb_r; + linear.PlaneRow(1, 0)[0] = rgb_g; + linear.PlaneRow(2, 0)[0] = rgb_b; + + ImageMetadata metadata; + metadata.SetFloat32Samples(); + metadata.color_encoding = ColorEncoding::LinearSRGB(); + ImageBundle ib(&metadata); + ib.SetFromImage(std::move(linear), metadata.color_encoding); + Image3F opsin(1, 1); + (void)ToXYB(ib, /*pool=*/nullptr, &opsin); + + *xyb_x = opsin.PlaneRow(0, 0)[0]; + *xyb_y = opsin.PlaneRow(1, 0)[0]; + *xyb_b = opsin.PlaneRow(2, 0)[0]; +} + +// Convert a single XYB color to linear sRGB, using the exact image conversion +// procedure that jpeg xl uses. 
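+// (This is the inverse direction of LinearSrgbToOpsin above; the two are
+// composed in OpsinRoundtripTestRGB below.)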
+void OpsinToLinearSrgb(float xyb_x, float xyb_y, float xyb_b, + float* JXL_RESTRICT rgb_r, float* JXL_RESTRICT rgb_g, + float* JXL_RESTRICT rgb_b) { + Image3F opsin(1, 1); + opsin.PlaneRow(0, 0)[0] = xyb_x; + opsin.PlaneRow(1, 0)[0] = xyb_y; + opsin.PlaneRow(2, 0)[0] = xyb_b; + Image3F linear(1, 1); + OpsinParams opsin_params; + opsin_params.Init(/*intensity_target=*/255.0f); + OpsinToLinear(opsin, Rect(opsin), nullptr, &linear, opsin_params); + *rgb_r = linear.PlaneRow(0, 0)[0]; + *rgb_g = linear.PlaneRow(1, 0)[0]; + *rgb_b = linear.PlaneRow(2, 0)[0]; +} + +void OpsinRoundtripTestRGB(float r, float g, float b) { + float xyb_x, xyb_y, xyb_b; + LinearSrgbToOpsin(r, g, b, &xyb_x, &xyb_y, &xyb_b); + float r2, g2, b2; + OpsinToLinearSrgb(xyb_x, xyb_y, xyb_b, &r2, &g2, &b2); + EXPECT_NEAR(r, r2, 1e-3); + EXPECT_NEAR(g, g2, 1e-3); + EXPECT_NEAR(b, b2, 1e-3); +} + +TEST(OpsinImageTest, VerifyOpsinAbsorbanceInverseMatrix) { + float matrix[9]; // writable copy + for (int i = 0; i < 9; i++) { + matrix[i] = GetOpsinAbsorbanceInverseMatrix()[i]; + } + Inv3x3Matrix(matrix); + for (int i = 0; i < 9; i++) { + EXPECT_NEAR(matrix[i], kOpsinAbsorbanceMatrix[i], 1e-6); + } +} + +TEST(OpsinImageTest, OpsinRoundtrip) { + OpsinRoundtripTestRGB(0, 0, 0); + OpsinRoundtripTestRGB(1. / 255, 1. / 255, 1. / 255); + OpsinRoundtripTestRGB(128. / 255, 128. / 255, 128. / 255); + OpsinRoundtripTestRGB(1, 1, 1); + + OpsinRoundtripTestRGB(0, 0, 1. / 255); + OpsinRoundtripTestRGB(0, 0, 128. / 255); + OpsinRoundtripTestRGB(0, 0, 1); + + OpsinRoundtripTestRGB(0, 1. / 255, 0); + OpsinRoundtripTestRGB(0, 128. / 255, 0); + OpsinRoundtripTestRGB(0, 1, 0); + + OpsinRoundtripTestRGB(1. / 255, 0, 0); + OpsinRoundtripTestRGB(128. / 255, 0, 0); + OpsinRoundtripTestRGB(1, 0, 0); +} + +TEST(OpsinImageTest, VerifyZero) { + // Test that black color (zero energy) is 0,0,0 in xyb. + float x, y, b; + LinearSrgbToOpsin(0, 0, 0, &x, &y, &b); + EXPECT_NEAR(0, x, 1e-9); + EXPECT_NEAR(0, y, 1e-7); + EXPECT_NEAR(0, b, 1e-7); +} + +TEST(OpsinImageTest, VerifyGray) { + // Test that grayscale colors have a fixed y/b ratio and x==0. + for (size_t i = 1; i < 255; i++) { + float x, y, b; + LinearSrgbToOpsin(i / 255., i / 255., i / 255., &x, &y, &b); + EXPECT_NEAR(0, x, 1e-6); + EXPECT_NEAR(kYToBRatio, b / y, 3e-5); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/opsin_inverse_test.cc b/third_party/jpeg-xl/lib/jxl/opsin_inverse_test.cc new file mode 100644 index 000000000000..3afd5686aa28 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_inverse_test.cc @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "gtest/gtest.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_test_utils.h" + +namespace jxl { +namespace { + +TEST(OpsinInverseTest, LinearInverseInverts) { + Image3F linear(128, 128); + RandomFillImage(&linear, 1.0f); + + CodecInOut io; + io.metadata.m.SetFloat32Samples(); + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + io.SetFromImage(CopyImage(linear), io.metadata.m.color_encoding); + ThreadPool* null_pool = nullptr; + Image3F opsin(io.xsize(), io.ysize()); + (void)ToXYB(io.Main(), null_pool, &opsin); + + OpsinParams opsin_params; + opsin_params.Init(/*intensity_target=*/255.0f); + OpsinToLinearInplace(&opsin, /*pool=*/nullptr, opsin_params); + + VerifyRelativeError(linear, opsin, 3E-3, 2E-4); +} + +TEST(OpsinInverseTest, YcbCrInverts) { + Image3F rgb(128, 128); + RandomFillImage(&rgb, 1.0f); + + ThreadPool* null_pool = nullptr; + Image3F ycbcr(rgb.xsize(), rgb.ysize()); + RgbToYcbcr(rgb.Plane(0), rgb.Plane(1), rgb.Plane(2), &ycbcr.Plane(1), + &ycbcr.Plane(0), &ycbcr.Plane(2), null_pool); + + Image3F rgb2(rgb.xsize(), rgb.ysize()); + YcbcrToRgb(ycbcr, &rgb2, Rect(rgb)); + + VerifyRelativeError(rgb, rgb2, 4E-5, 4E-7); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/opsin_params.cc b/third_party/jpeg-xl/lib/jxl/opsin_params.cc new file mode 100644 index 000000000000..2fe2a41c9126 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_params.cc @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/opsin_params.h" + +#include + +#include "lib/jxl/linalg.h" + +namespace jxl { + +#define INVERSE_OPSIN_FROM_SPEC 1 + +const float* GetOpsinAbsorbanceInverseMatrix() { +#if INVERSE_OPSIN_FROM_SPEC + return DefaultInverseOpsinAbsorbanceMatrix(); +#else // INVERSE_OPSIN_FROM_SPEC + // Compute the inverse opsin matrix from the forward matrix. Less precise + // than taking the values from the specification, but must be used if the + // forward transform is changed and the spec will require updating. 
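+  // (A function-local static initialized by an immediately-invoked lambda:
+  // since C++11 this initialization runs exactly once and is thread-safe,
+  // so the inversion cost is paid on the first call only.)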
+ static const float* const kInverse = [] { + static float inverse[9]; + for (int i = 0; i < 9; i++) { + inverse[i] = kOpsinAbsorbanceMatrix[i]; + } + Inv3x3Matrix(inverse); + return inverse; + }(); + return kInverse; +#endif // INVERSE_OPSIN_FROM_SPEC +} + +void InitSIMDInverseMatrix(const float* JXL_RESTRICT inverse, + float* JXL_RESTRICT simd_inverse, + float intensity_target) { + for (size_t i = 0; i < 9; ++i) { + simd_inverse[4 * i] = simd_inverse[4 * i + 1] = simd_inverse[4 * i + 2] = + simd_inverse[4 * i + 3] = inverse[i] * (255.0f / intensity_target); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/opsin_params.h b/third_party/jpeg-xl/lib/jxl/opsin_params.h new file mode 100644 index 000000000000..80b44af09d58 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_params.h @@ -0,0 +1,83 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_OPSIN_PARAMS_H_ +#define LIB_JXL_OPSIN_PARAMS_H_ + +// Constants that define the XYB color space. + +#include + +#include + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// Parameters for opsin absorbance. +static const float kM02 = 0.078f; +static const float kM00 = 0.30f; +static const float kM01 = 1.0f - kM02 - kM00; + +static const float kM12 = 0.078f; +static const float kM10 = 0.23f; +static const float kM11 = 1.0f - kM12 - kM10; + +static const float kM20 = 0.24342268924547819f; +static const float kM21 = 0.20476744424496821f; +static const float kM22 = 1.0f - kM20 - kM21; + +static const float kBScale = 1.0f; +static const float kYToBRatio = 1.0f; // works better with 0.50017729543783418 +static const float kBToYRatio = 1.0f / kYToBRatio; + +static const float kB0 = 0.0037930732552754493f; +static const float kB1 = kB0; +static const float kB2 = kB0; + +// Opsin absorbance matrix is now frozen. +static const float kOpsinAbsorbanceMatrix[9] = { + kM00, kM01, kM02, kM10, kM11, kM12, kM20, kM21, kM22, +}; + +// Must be the inverse matrix of kOpsinAbsorbanceMatrix and match the spec. +static inline const float* DefaultInverseOpsinAbsorbanceMatrix() { + static float kDefaultInverseOpsinAbsorbanceMatrix[9] = { + 11.031566901960783f, -9.866943921568629f, -0.16462299647058826f, + -3.254147380392157f, 4.418770392156863f, -0.16462299647058826f, + -3.6588512862745097f, 2.7129230470588235f, 1.9459282392156863f}; + return kDefaultInverseOpsinAbsorbanceMatrix; +} + +// Returns 3x3 row-major matrix inverse of kOpsinAbsorbanceMatrix. +// opsin_image_test verifies this is actually the inverse. 
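+// (Informally: multiplying the row-major 3x3 kOpsinAbsorbanceMatrix by the
+// returned matrix should give the identity up to float rounding; see
+// VerifyOpsinAbsorbanceInverseMatrix in opsin_image_test.cc.)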
+const float* GetOpsinAbsorbanceInverseMatrix();
+
+void InitSIMDInverseMatrix(const float* JXL_RESTRICT inverse,
+                           float* JXL_RESTRICT simd_inverse,
+                           float intensity_target);
+
+static const float kOpsinAbsorbanceBias[3] = {
+    kB0,
+    kB1,
+    kB2,
+};
+
+static const float kNegOpsinAbsorbanceBiasRGB[4] = {
+    -kOpsinAbsorbanceBias[0], -kOpsinAbsorbanceBias[1],
+    -kOpsinAbsorbanceBias[2], 1.0f};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_OPSIN_PARAMS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/optimize.cc b/third_party/jpeg-xl/lib/jxl/optimize.cc
new file mode 100644
index 000000000000..71cee1ed2990
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/optimize.cc
@@ -0,0 +1,172 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/optimize.h"
+
+#include <algorithm>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+namespace optimize {
+
+namespace {
+
+// simplex vector must be sorted by first element of its elements
+std::vector<double> Midpoint(const std::vector<std::vector<double>>& simplex) {
+  JXL_CHECK(!simplex.empty());
+  JXL_CHECK(simplex.size() == simplex[0].size());
+  int dim = simplex.size() - 1;
+  std::vector<double> result(dim + 1, 0);
+  for (int i = 0; i < dim; i++) {
+    for (int k = 0; k < dim; k++) {
+      result[i + 1] += simplex[k][i + 1];
+    }
+    result[i + 1] /= dim;
+  }
+  return result;
+}
+
+// first element ignored
+std::vector<double> Subtract(const std::vector<double>& a,
+                             const std::vector<double>& b) {
+  JXL_CHECK(a.size() == b.size());
+  std::vector<double> result(a.size());
+  result[0] = 0;
+  for (size_t i = 1; i < result.size(); i++) {
+    result[i] = a[i] - b[i];
+  }
+  return result;
+}
+
+// first element ignored
+std::vector<double> Add(const std::vector<double>& a,
+                        const std::vector<double>& b) {
+  JXL_CHECK(a.size() == b.size());
+  std::vector<double> result(a.size());
+  result[0] = 0;
+  for (size_t i = 1; i < result.size(); i++) {
+    result[i] = a[i] + b[i];
+  }
+  return result;
+}
+
+// first element ignored
+std::vector<double> Average(const std::vector<double>& a,
+                            const std::vector<double>& b) {
+  JXL_CHECK(a.size() == b.size());
+  std::vector<double> result(a.size());
+  result[0] = 0;
+  for (size_t i = 1; i < result.size(); i++) {
+    result[i] = 0.5 * (a[i] + b[i]);
+  }
+  return result;
+}
+
+// vec: [0] will contain the objective function, [1:] will
+// contain the vector position for the objective function.
+// fun: the function evaluates the value.
+void Eval(std::vector<double>* vec,
+          const std::function<double(const std::vector<double>&)>& fun) {
+  std::vector<double> args(vec->begin() + 1, vec->end());
+  (*vec)[0] = fun(args);
+}
+
+void Sort(std::vector<std::vector<double>>* simplex) {
+  std::sort(simplex->begin(), simplex->end());
+}
+
+// Main iteration step of Nelder-Mead like optimization.
+void Reflect(std::vector<std::vector<double>>* simplex,
+             const std::function<double(const std::vector<double>&)>& fun) {
+  Sort(simplex);
+  const std::vector<double>& last = simplex->back();
+  std::vector<double> mid = Midpoint(*simplex);
+  std::vector<double> diff = Subtract(mid, last);
+  std::vector<double> mirrored = Add(mid, diff);
+  Eval(&mirrored, fun);
+  if (mirrored[0] > (*simplex)[simplex->size() - 2][0]) {
+    // Still the worst, shrink towards the best.
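+    // (Shrink step of this Nelder-Mead variant: the reflected point failed
+    // to beat the second-worst vertex, so the worst vertex is pulled halfway
+    // toward the current best instead of keeping the reflection.)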
+    std::vector<double> shrinking = Average(simplex->back(), (*simplex)[0]);
+    Eval(&shrinking, fun);
+    simplex->back() = shrinking;
+  } else if (mirrored[0] < (*simplex)[0][0]) {
+    // new best
+    std::vector<double> even_further = Add(mirrored, diff);
+    Eval(&even_further, fun);
+    if (even_further[0] < mirrored[0]) {
+      mirrored = even_further;
+    }
+    simplex->back() = mirrored;
+  } else {
+    // not a best, not a worst point
+    simplex->back() = mirrored;
+  }
+}
+
+// Initialize the simplex at origin.
+std::vector<std::vector<double>> InitialSimplex(
+    int dim, double amount, const std::vector<double>& init,
+    const std::function<double(const std::vector<double>&)>& fun) {
+  std::vector<double> best(1 + dim, 0);
+  std::copy(init.begin(), init.end(), best.begin() + 1);
+  Eval(&best, fun);
+  std::vector<std::vector<double>> result{best};
+  for (int i = 0; i < dim; i++) {
+    best = result[0];
+    best[i + 1] += amount;
+    Eval(&best, fun);
+    result.push_back(best);
+    Sort(&result);
+  }
+  return result;
+}
+
+// For comparing the same with the python tool
+/*void RunSimplexExternal(
+    int dim, double amount, int max_iterations,
+    const std::function<double((const std::vector<double>&))>& fun) {
+  vector<double> vars;
+  for (int i = 0; i < dim; i++) {
+    vars.push_back(atof(getenv(StrCat("VAR", i).c_str())));
+  }
+  double result = fun(vars);
+  std::cout << "Result=" << result;
+}*/
+
+}  // namespace
+
+std::vector<double> RunSimplex(
+    int dim, double amount, int max_iterations,
+    const std::vector<double>& init,
+    const std::function<double(const std::vector<double>&)>& fun) {
+  std::vector<std::vector<double>> simplex =
+      InitialSimplex(dim, amount, init, fun);
+  for (int i = 0; i < max_iterations; i++) {
+    Sort(&simplex);
+    Reflect(&simplex, fun);
+  }
+  return simplex[0];
+}
+
+std::vector<double> RunSimplex(
+    int dim, double amount, int max_iterations,
+    const std::function<double(const std::vector<double>&)>& fun) {
+  std::vector<double> init(dim, 0.0);
+  return RunSimplex(dim, amount, max_iterations, init, fun);
+}
+
+}  // namespace optimize
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/optimize.h b/third_party/jpeg-xl/lib/jxl/optimize.h
new file mode 100644
index 000000000000..dbd891997ede
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/optimize.h
@@ -0,0 +1,227 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Utility functions for optimizing multi-dimensional nonlinear functions.
+
+#ifndef LIB_JXL_OPTIMIZE_H_
+#define LIB_JXL_OPTIMIZE_H_
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <vector>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+namespace optimize {
+
+// An array type of numeric values that supports math operations with
+// operator-, operator+, etc.
+template <typename T, size_t N>
+class Array {
+ public:
+  Array() = default;
+  explicit Array(T v) {
+    for (size_t i = 0; i < N; i++) v_[i] = v;
+  }
+
+  size_t size() const { return N; }
+
+  T& operator[](size_t index) {
+    JXL_DASSERT(index < N);
+    return v_[index];
+  }
+  T operator[](size_t index) const {
+    JXL_DASSERT(index < N);
+    return v_[index];
+  }
+
+ private:
+  // The values used by this Array.
+  T v_[N];
+};
+
+template <typename T, size_t N>
+Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
+  Array<T, N> z;
+  for (size_t i = 0; i < N; ++i) {
+    z[i] = x[i] + y[i];
+  }
+  return z;
+}
+
+template <typename T, size_t N>
+Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
+  Array<T, N> z;
+  for (size_t i = 0; i < N; ++i) {
+    z[i] = x[i] - y[i];
+  }
+  return z;
+}
+
+template <typename T, size_t N>
+Array<T, N> operator*(T v, const Array<T, N>& x) {
+  Array<T, N> y;
+  for (size_t i = 0; i < N; ++i) {
+    y[i] = v * x[i];
+  }
+  return y;
+}
+
+template <typename T, size_t N>
+T operator*(const Array<T, N>& x, const Array<T, N>& y) {
+  T r = 0.0;
+  for (size_t i = 0; i < N; ++i) {
+    r += x[i] * y[i];
+  }
+  return r;
+}
+
+// Runs Nelder-Mead like optimization. Runs for max_iterations times,
+// fun gets called with a vector of size dim as argument, and returns the score
+// based on those parameters (lower is better). Returns a vector of dim+1
+// dimensions, where the first value is the optimal value of the function and
+// the rest is the argmin value. Use init to pass an initial guess of where
+// the optimal value is.
+//
+// Usage example:
+//
+//   RunSimplex(2, 0.1, 100, [](const std::vector<double>& v) {
+//     return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
+//   });
+//
+// Returns (0.0, 5, 7)
+std::vector<double> RunSimplex(
+    int dim, double amount, int max_iterations,
+    const std::function<double(const std::vector<double>&)>& fun);
+std::vector<double> RunSimplex(
+    int dim, double amount, int max_iterations,
+    const std::vector<double>& init,
+    const std::function<double(const std::vector<double>&)>& fun);
+
+// Implementation of the Scaled Conjugate Gradient method described in the
+// following paper:
+//   Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
+//   Learning", Neural Networks, Vol. 6. pp. 525-533, 1993
+//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
+//
+// The Function template parameter is a class that has the following method:
+//
+//   // Returns the value of the function at point w and sets *df to be the
+//   // negative gradient vector of the function at point w.
+//   double Compute(const optimize::Array<T, N>& w,
+//                  optimize::Array<T, N>* df) const;
+//
+// Returns a vector w, such that |df(w)| < grad_norm_threshold.
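+//
+// A minimal usage sketch (QuadraticLoss stands for any functor with the
+// Compute() method described above; see optimize_test.cc for real ones):
+//
+//   QuadraticLoss f;
+//   optimize::Array<double, 2> w0(0.0);
+//   auto w = OptimizeWithScaledConjugateGradientMethod(f, w0, 1e-8, 1000);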
+template <typename T, size_t N, typename Function>
+Array<T, N> OptimizeWithScaledConjugateGradientMethod(
+    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
+    size_t max_iters) {
+  const size_t n = w0.size();
+  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
+  const T sigma0 = static_cast<T>(0.0001);
+  const T l_min = static_cast<T>(1.0e-15);
+  const T l_max = static_cast<T>(1.0e15);
+
+  Array<T, N> w = w0;
+  Array<T, N> wp;
+  Array<T, N> r;
+  Array<T, N> rt;
+  Array<T, N> e;
+  Array<T, N> p;
+  T psq;
+  T fp;
+  T D;
+  T d;
+  T m;
+  T a;
+  T b;
+  T s;
+  T t;
+
+  T fw = f.Compute(w, &r);
+  T rsq = r * r;
+  e = r;
+  p = r;
+  T l = static_cast<T>(1.0);
+  bool success = true;
+  size_t n_success = 0;
+  size_t k = 0;
+
+  while (k++ < max_iters) {
+    if (success) {
+      m = -(p * r);
+      if (m >= 0) {
+        p = r;
+        m = -(p * r);
+      }
+      psq = p * p;
+      s = sigma0 / std::sqrt(psq);
+      f.Compute(w + (s * p), &rt);
+      t = (p * (r - rt)) / s;
+    }
+
+    d = t + l * psq;
+    if (d <= 0) {
+      d = l * psq;
+      l = l - t / psq;
+    }
+
+    a = -m / d;
+    wp = w + a * p;
+    fp = f.Compute(wp, &rt);
+
+    D = 2.0 * (fp - fw) / (a * m);
+    if (D >= 0.0) {
+      success = true;
+      n_success++;
+      w = wp;
+    } else {
+      success = false;
+    }
+
+    if (success) {
+      e = r;
+      r = rt;
+      rsq = r * r;
+      fw = fp;
+      if (rsq <= rsq_threshold) {
+        break;
+      }
+    }
+
+    if (D < 0.25) {
+      l = std::min(4.0 * l, l_max);
+    } else if (D > 0.75) {
+      l = std::max(0.25 * l, l_min);
+    }
+
+    if ((n_success % n) == 0) {
+      p = r;
+      l = 1.0;
+    } else if (success) {
+      b = ((e - r) * r) / m;
+      p = b * p + r;
+    }
+  }
+
+  return w;
+}
+
+}  // namespace optimize
+}  // namespace jxl
+
+#endif  // LIB_JXL_OPTIMIZE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/optimize_test.cc b/third_party/jpeg-xl/lib/jxl/optimize_test.cc
new file mode 100644
index 000000000000..9ea695119bed
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/optimize_test.cc
@@ -0,0 +1,118 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/optimize.h"
+
+#include <cmath>
+
+#include "gtest/gtest.h"
+
+namespace jxl {
+namespace optimize {
+namespace {
+
+// The maximum number of iterations for the test.
+static const size_t kMaxTestIter = 100000;
+
+// F(w) = (w - w_min)^2.
+struct SimpleQuadraticFunction {
+  typedef Array<double, 2> ArrayType;
+  explicit SimpleQuadraticFunction(const ArrayType& w0) : w_min(w0) {}
+
+  double Compute(const ArrayType& w, ArrayType* df) const {
+    ArrayType dw = w - w_min;
+    *df = -2.0 * dw;
+    return dw * dw;
+  }
+
+  ArrayType w_min;
+};
+
+// F(alpha, beta, gamma| x,y) = \sum_i(y_i - (alpha x_i ^ gamma + beta))^2.
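+// Here w[0] = alpha, w[1] = gamma (the exponent) and w[2] = beta; Compute()
+// returns the squared-error loss and fills *df with the negative gradient,
+// matching the contract of OptimizeWithScaledConjugateGradientMethod.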
+struct PowerFunction {
+  explicit PowerFunction(const std::vector<double>& x0,
+                         const std::vector<double>& y0)
+      : x(x0), y(y0) {}
+
+  typedef Array<double, 3> ArrayType;
+  double Compute(const ArrayType& w, ArrayType* df) const {
+    double loss_function = 0;
+    (*df)[0] = 0;
+    (*df)[1] = 0;
+    (*df)[2] = 0;
+    for (size_t ind = 0; ind < y.size(); ++ind) {
+      if (x[ind] != 0) {
+        double l_f = y[ind] - (w[0] * pow(x[ind], w[1]) + w[2]);
+        (*df)[0] += 2.0 * l_f * pow(x[ind], w[1]);
+        (*df)[1] += 2.0 * l_f * w[0] * pow(x[ind], w[1]) * log(x[ind]);
+        (*df)[2] += 2.0 * l_f * 1;
+        loss_function += l_f * l_f;
+      }
+    }
+    return loss_function;
+  }
+
+  std::vector<double> x;
+  std::vector<double> y;
+};
+
+TEST(OptimizeTest, SimpleQuadraticFunction) {
+  SimpleQuadraticFunction::ArrayType w_min;
+  w_min[0] = 1.0;
+  w_min[1] = 2.0;
+  SimpleQuadraticFunction f(w_min);
+  SimpleQuadraticFunction::ArrayType w(0.);
+  static const double kPrecision = 1e-8;
+  w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision,
+                                                          kMaxTestIter);
+  EXPECT_NEAR(w[0], 1.0, kPrecision);
+  EXPECT_NEAR(w[1], 2.0, kPrecision);
+}
+
+TEST(OptimizeTest, PowerFunction) {
+  std::vector<double> x(10);
+  std::vector<double> y(10);
+  for (int ind = 0; ind < 10; ++ind) {
+    x[ind] = 1. * ind;
+    y[ind] = 2. * pow(x[ind], 3) + 5.;
+  }
+  PowerFunction f(x, y);
+  PowerFunction::ArrayType w(0.);
+
+  static const double kPrecision = 0.01;
+  w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision,
+                                                          kMaxTestIter);
+  EXPECT_NEAR(w[0], 2.0, kPrecision);
+  EXPECT_NEAR(w[1], 3.0, kPrecision);
+  EXPECT_NEAR(w[2], 5.0, kPrecision);
+}
+
+TEST(OptimizeTest, SimplexOptTest) {
+  auto f = [](const std::vector<double>& x) -> double {
+    double t1 = x[0] - 1.0;
+    double t2 = x[1] + 1.5;
+    return 2.0 + t1 * t1 + t2 * t2;
+  };
+  auto opt = RunSimplex(2, 0.01, 100, f);
+  EXPECT_EQ(opt.size(), 3);
+
+  static const double kPrecision = 0.01;
+  EXPECT_NEAR(opt[0], 2.0, kPrecision);
+  EXPECT_NEAR(opt[1], 1.0, kPrecision);
+  EXPECT_NEAR(opt[2], -1.5, kPrecision);
+}
+
+}  // namespace
+}  // namespace optimize
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/padded_bytes_test.cc b/third_party/jpeg-xl/lib/jxl/padded_bytes_test.cc
new file mode 100644
index 000000000000..4c01b9af2aca
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/padded_bytes_test.cc
@@ -0,0 +1,135 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/padded_bytes.h"
+
+#include <numeric>  // iota
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace jxl {
+namespace {
+
+TEST(PaddedBytesTest, TestNonEmptyFirstByteZero) {
+  PaddedBytes pb(1);
+  EXPECT_EQ(0, pb[0]);
+  // Even after resizing...
+  pb.resize(20);
+  EXPECT_EQ(0, pb[0]);
+  // And reserving.
+  pb.reserve(200);
+  EXPECT_EQ(0, pb[0]);
+}
+
+TEST(PaddedBytesTest, TestEmptyFirstByteZero) {
+  PaddedBytes pb(0);
+  // After resizing - new zero is written despite there being nothing to copy.
+  pb.resize(20);
+  EXPECT_EQ(0, pb[0]);
+}
+
+TEST(PaddedBytesTest, TestFillWithoutReserve) {
+  PaddedBytes pb;
+  for (size_t i = 0; i < 170; ++i) {
+    pb.push_back(i);
+  }
+  EXPECT_EQ(170, pb.size());
+  EXPECT_GE(pb.capacity(), 170);
+}
+
+TEST(PaddedBytesTest, TestFillWithExactReserve) {
+  PaddedBytes pb;
+  pb.reserve(170);
+  for (size_t i = 0; i < 170; ++i) {
+    pb.push_back(i);
+  }
+  EXPECT_EQ(170, pb.size());
+  EXPECT_EQ(pb.capacity(), 170);
+}
+
+TEST(PaddedBytesTest, TestFillWithMoreReserve) {
+  PaddedBytes pb;
+  pb.reserve(171);
+  for (size_t i = 0; i < 170; ++i) {
+    pb.push_back(i);
+  }
+  EXPECT_EQ(170, pb.size());
+  EXPECT_GT(pb.capacity(), 170);
+}
+
+// Can assign() a subset of the valid data.
+TEST(PaddedBytesTest, TestAssignFromWithin) {
+  PaddedBytes pb;
+  pb.reserve(256);
+  for (size_t i = 0; i < 256; ++i) {
+    pb.push_back(i);
+  }
+  pb.assign(pb.data() + 64, pb.data() + 192);
+  EXPECT_EQ(128, pb.size());
+  for (size_t i = 0; i < 128; ++i) {
+    EXPECT_EQ(i + 64, pb[i]);
+  }
+}
+
+// Can assign() a range with both valid and previously-allocated data.
+TEST(PaddedBytesTest, TestAssignReclaim) {
+  PaddedBytes pb;
+  pb.reserve(256);
+  for (size_t i = 0; i < 256; ++i) {
+    pb.push_back(i);
+  }
+
+  const uint8_t* mem = pb.data();
+  pb.resize(200);
+  // Just shrank without reallocating
+  EXPECT_EQ(mem, pb.data());
+  EXPECT_EQ(256, pb.capacity());
+
+  // Reclaim part of initial allocation
+  pb.assign(pb.data() + 100, pb.data() + 240);
+  EXPECT_EQ(140, pb.size());
+
+  for (size_t i = 0; i < 140; ++i) {
+    EXPECT_EQ(i + 100, pb[i]);
+  }
+}
+
+// Can assign() smaller and larger ranges outside the current allocation.
+TEST(PaddedBytesTest, TestAssignOutside) {
+  PaddedBytes pb;
+  pb.resize(400);
+  std::iota(pb.begin(), pb.end(), 1);
+
+  std::vector<uint8_t> small(64);
+  std::iota(small.begin(), small.end(), 500);
+
+  pb.assign(small.data(), small.data() + small.size());
+  EXPECT_EQ(64, pb.size());
+  for (size_t i = 0; i < 64; ++i) {
+    EXPECT_EQ((i + 500) & 0xFF, pb[i]);
+  }
+
+  std::vector<uint8_t> large(1000);
+  std::iota(large.begin(), large.end(), 600);
+
+  pb.assign(large.data(), large.data() + large.size());
+  EXPECT_EQ(1000, pb.size());
+  for (size_t i = 0; i < 1000; ++i) {
+    EXPECT_EQ((i + 600) & 0xFF, pb[i]);
+  }
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/passes_state.cc b/third_party/jpeg-xl/lib/jxl/passes_state.cc
new file mode 100644
index 000000000000..ce998e0411ec
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/passes_state.cc
@@ -0,0 +1,77 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lib/jxl/passes_state.h" + +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/common.h" + +namespace jxl { + +Status InitializePassesSharedState(const FrameHeader& frame_header, + PassesSharedState* JXL_RESTRICT shared, + bool encoder) { + JXL_ASSERT(frame_header.nonserialized_metadata != nullptr); + shared->frame_header = frame_header; + shared->metadata = frame_header.nonserialized_metadata; + shared->frame_dim = frame_header.ToFrameDimensions(); + shared->image_features.patches.SetPassesSharedState(shared); + + const FrameDimensions& frame_dim = shared->frame_dim; + + shared->ac_strategy = + AcStrategyImage(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->raw_quant_field = + ImageI(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->epf_sharpness = + ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->cmap = ColorCorrelationMap(frame_dim.xsize, frame_dim.ysize); + + // In the decoder, we allocate coeff orders afterwards, when we know how many + // we will actually need. + shared->coeff_order_size = kCoeffOrderMaxSize; + if (encoder && + shared->coeff_orders.size() < + frame_header.passes.num_passes * kCoeffOrderMaxSize && + frame_header.encoding == FrameEncoding::kVarDCT) { + shared->coeff_orders.resize(frame_header.passes.num_passes * + kCoeffOrderMaxSize); + } + + shared->quant_dc = ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + if (!(frame_header.flags & FrameHeader::kUseDcFrame) || encoder) { + shared->dc_storage = + Image3F(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + } else { + if (frame_header.dc_level == 4) { + return JXL_FAILURE("Invalid DC level for kUseDcFrame: %u", + frame_header.dc_level); + } + shared->dc = &shared->dc_frames[frame_header.dc_level]; + if (shared->dc->xsize() == 0) { + return JXL_FAILURE( + "kUseDcFrame specified for dc_level %u, but no frame was decoded " + "with level %u", + frame_header.dc_level, frame_header.dc_level + 1); + } + ZeroFillImage(&shared->quant_dc); + } + + shared->dc_storage = Image3F(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/passes_state.h b/third_party/jpeg-xl/lib/jxl/passes_state.h new file mode 100644 index 000000000000..3b21bd8d42a7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/passes_state.h @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef LIB_JXL_PASSES_STATE_H_
+#define LIB_JXL_PASSES_STATE_H_
+
+#include "lib/jxl/ac_context.h"
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_patch_dictionary.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/loop_filter.h"
+#include "lib/jxl/noise.h"
+#include "lib/jxl/quant_weights.h"
+#include "lib/jxl/quantizer.h"
+#include "lib/jxl/splines.h"
+
+// Structures that hold the (en/de)coder state for a JPEG XL kVarDCT
+// (en/de)coder.
+
+namespace jxl {
+
+struct ImageFeatures {
+  NoiseParams noise_params;
+  PatchDictionary patches;
+  Splines splines;
+};
+
+// State common to both encoder and decoder.
+// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding)
+struct PassesSharedState {
+  PassesSharedState() : frame_header(nullptr) {}
+
+  // Headers and metadata.
+  const CodecMetadata* metadata;
+  FrameHeader frame_header;
+
+  FrameDimensions frame_dim;
+
+  // Control fields and parameters.
+  AcStrategyImage ac_strategy;
+
+  // Dequant matrices + quantizer.
+  DequantMatrices matrices;
+  Quantizer quantizer{&matrices};
+  ImageI raw_quant_field;
+
+  // Per-block side information for EPF detail preservation.
+  ImageB epf_sharpness;
+
+  ColorCorrelationMap cmap;
+
+  ImageFeatures image_features;
+
+  // Memory area for storing coefficient orders.
+  // `coeff_order_size` is the size used by *one* set of coefficient orders (at
+  // most kCoeffOrderMaxSize). A set of coefficient orders is present for each
+  // pass.
+  size_t coeff_order_size = 0;
+  std::vector<coeff_order_t> coeff_orders;
+
+  // Decoder-side DC and quantized DC.
+  ImageB quant_dc;
+  Image3F dc_storage;
+  const Image3F* JXL_RESTRICT dc = &dc_storage;
+
+  BlockCtxMap block_ctx_map;
+
+  Image3F dc_frames[4];
+
+  struct {
+    ImageBundle storage;
+    // Can either point to `storage`, if this is a frame that is not stored in
+    // the CodecInOut, or can point to an existing ImageBundle.
+    // TODO(veluca): pointing to ImageBundles in CodecInOut is not possible for
+    // now, as they are stored in a vector and thus may be moved. Fix this.
+    ImageBundle* JXL_RESTRICT frame = &storage;
+    // ImageBundle doesn't yet have a simple way to state it is in XYB.
+    bool ib_is_in_xyb = true;
+  } reference_frames[4] = {};
+
+  // Number of pre-clustered sets of histograms (with the same ctx map), per
+  // pass. Encoded as num_histograms_ - 1.
+  size_t num_histograms = 0;
+
+  bool IsGrayscale() const { return metadata->m.color_encoding.IsGray(); }
+
+  Rect GroupRect(size_t group_index) const {
+    const size_t gx = group_index % frame_dim.xsize_groups;
+    const size_t gy = group_index / frame_dim.xsize_groups;
+    const Rect rect(gx * frame_dim.group_dim, gy * frame_dim.group_dim,
+                    frame_dim.group_dim, frame_dim.group_dim, frame_dim.xsize,
+                    frame_dim.ysize);
+    return rect;
+  }
+
+  Rect PaddedGroupRect(size_t group_index) const {
+    const size_t gx = group_index % frame_dim.xsize_groups;
+    const size_t gy = group_index / frame_dim.xsize_groups;
+    const Rect rect(gx * frame_dim.group_dim, gy * frame_dim.group_dim,
+                    frame_dim.group_dim, frame_dim.group_dim,
+                    frame_dim.xsize_padded, frame_dim.ysize_padded);
+    return rect;
+  }
+
+  Rect BlockGroupRect(size_t group_index) const {
+    const size_t gx = group_index % frame_dim.xsize_groups;
+    const size_t gy = group_index / frame_dim.xsize_groups;
+    const Rect rect(gx * (frame_dim.group_dim >> 3),
+                    gy * (frame_dim.group_dim >> 3), frame_dim.group_dim >> 3,
+                    frame_dim.group_dim >> 3, frame_dim.xsize_blocks,
+                    frame_dim.ysize_blocks);
+    return rect;
+  }
+
+  Rect DCGroupRect(size_t group_index) const {
+    const size_t gx = group_index % frame_dim.xsize_dc_groups;
+    const size_t gy = group_index / frame_dim.xsize_dc_groups;
+    const Rect rect(gx * frame_dim.group_dim, gy * frame_dim.group_dim,
+                    frame_dim.group_dim, frame_dim.group_dim,
+                    frame_dim.xsize_blocks, frame_dim.ysize_blocks);
+    return rect;
+  }
+};
+
+// Initializes the state information that is shared between encoder and
+// decoder.
+Status InitializePassesSharedState(const FrameHeader& frame_header,
+                                   PassesSharedState* JXL_RESTRICT shared,
+                                   bool encoder = false);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_PASSES_STATE_H_
diff --git a/third_party/jpeg-xl/lib/jxl/passes_test.cc b/third_party/jpeg-xl/lib/jxl/passes_test.cc
new file mode 100644
index 000000000000..44da23c8e7ad
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/passes_test.cc
@@ -0,0 +1,398 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+
+#include <string>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_file.h"
+#include "lib/jxl/dec_params.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_file.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/test_utils.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+namespace {
+using test::Roundtrip;
+
+TEST(PassesTest, RoundtripSmallPasses) {
+  ThreadPool* pool = nullptr;
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, pool));
+  io.ShrinkTo(io.xsize() / 8, io.ysize() / 8);
+
+  CompressParams cparams;
+  cparams.butteraugli_distance = 1.0;
+  cparams.progressive_mode = true;
+  DecompressParams dparams;
+
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, pool, &io2);
+  EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params,
+                                /*distmap=*/nullptr, pool),
+            1.5);
+}
+
+TEST(PassesTest, RoundtripUnalignedPasses) {
+  ThreadPool* pool = nullptr;
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, pool));
+  io.ShrinkTo(io.xsize() / 12, io.ysize() / 7);
+
+  CompressParams cparams;
+  cparams.butteraugli_distance = 2.0;
+  cparams.progressive_mode = true;
+  DecompressParams dparams;
+
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, pool, &io2);
+  EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params,
+                                /*distmap=*/nullptr, pool),
+            3.2);
+}
+
+TEST(PassesTest, RoundtripMultiGroupPasses) {
+  ThreadPoolInternal pool(4);
+  const PaddedBytes orig =
+      ReadTestData("imagecompression.info/flower_foveon.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, &pool));
+  io.ShrinkTo(600, 1024);  // partial X, full Y group
+
+  CompressParams cparams;
+  DecompressParams dparams;
+
+  cparams.butteraugli_distance = 1.0f;
+  cparams.progressive_mode = true;
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, &pool, &io2);
+  EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params,
+                                /*distmap=*/nullptr, &pool),
+            1.99f);
+
+  cparams.butteraugli_distance = 2.0f;
+  CodecInOut io3;
+  Roundtrip(&io, cparams, dparams, &pool, &io3);
+  EXPECT_LE(ButteraugliDistance(io, io3, cparams.ba_params,
+                                /*distmap=*/nullptr, &pool),
+            3.0f);
+}
+
+TEST(PassesTest, RoundtripLargeFastPasses) {
+  ThreadPoolInternal pool(8);
+  const PaddedBytes orig =
+      ReadTestData("imagecompression.info/flower_foveon.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, &pool));
+
+  CompressParams cparams;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  cparams.progressive_mode = true;
+  DecompressParams dparams;
+
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, &pool, &io2);
+}
+
+// Checks for differing size/distance in two consecutive runs of distance 2,
+// which involves additional processing including adaptive reconstruction.
+// Failing this may be a sign of race conditions or invalid memory accesses.
+TEST(PassesTest, RoundtripProgressiveConsistent) {
+  ThreadPoolInternal pool(8);
+  const PaddedBytes orig =
+      ReadTestData("imagecompression.info/flower_foveon.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, &pool));
+
+  CompressParams cparams;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  cparams.progressive_mode = true;
+  cparams.butteraugli_distance = 2.0;
+  DecompressParams dparams;
+
+  // Try each xsize mod kBlockDim to verify right border handling.
+  for (size_t xsize = 48; xsize > 40; --xsize) {
+    io.ShrinkTo(xsize, 15);
+
+    CodecInOut io2;
+    const size_t size2 = Roundtrip(&io, cparams, dparams, &pool, &io2);
+
+    CodecInOut io3;
+    const size_t size3 = Roundtrip(&io, cparams, dparams, &pool, &io3);
+
+    // Exact same compressed size.
+    EXPECT_EQ(size2, size3);
+
+    // Exact same distance.
+    const float dist2 = ButteraugliDistance(io, io2, cparams.ba_params,
+                                            /*distmap=*/nullptr, &pool);
+    const float dist3 = ButteraugliDistance(io, io3, cparams.ba_params,
+                                            /*distmap=*/nullptr, &pool);
+    EXPECT_EQ(dist2, dist3);
+  }
+}
+
+TEST(PassesTest, AllDownsampleFeasible) {
+  ThreadPoolInternal pool(8);
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, &pool));
+
+  PaddedBytes compressed;
+  AuxOut aux;
+
+  CompressParams cparams;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  cparams.progressive_mode = true;
+  cparams.butteraugli_distance = 1.0;
+  PassesEncoderState enc_state;
+  ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, &aux, &pool));
+
+  EXPECT_LE(compressed.size(), 240000);
+  float target_butteraugli[9] = {};
+  target_butteraugli[1] = 2.5f;
+  target_butteraugli[2] = 14.5f;
+  target_butteraugli[4] = 20.0f;
+  target_butteraugli[8] = 80.0f;
+
+  // The default progressive encoding scheme should make all these downsampling
+  // factors achievable.
+  // TODO(veluca): re-enable downsampling 16.
+  std::vector<size_t> downsamplings = {1, 2, 4, 8};  //, 16};
+
+  auto check = [&](uint32_t task, uint32_t /* thread */) -> void {
+    const size_t downsampling = downsamplings[task];
+    DecompressParams dparams;
+    dparams.max_downsampling = downsampling;
+    CodecInOut output;
+    ASSERT_TRUE(DecodeFile(dparams, compressed, &output, nullptr));
+    EXPECT_EQ(output.xsize(), io.xsize()) << "downsampling = " << downsampling;
+    EXPECT_EQ(output.ysize(), io.ysize()) << "downsampling = " << downsampling;
+    EXPECT_LE(ButteraugliDistance(io, output, cparams.ba_params,
+                                  /*distmap=*/nullptr, nullptr),
+              target_butteraugli[downsampling])
+        << "downsampling: " << downsampling;
+  };
+  pool.Run(0, downsamplings.size(), ThreadPool::SkipInit(), check);
+}
+
+TEST(PassesTest, AllDownsampleFeasibleQProgressive) {
+  ThreadPoolInternal pool(8);
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, &pool));
+
+  PaddedBytes compressed;
+  AuxOut aux;
+
+  CompressParams cparams;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  cparams.qprogressive_mode = true;
+  cparams.butteraugli_distance = 1.0;
+  PassesEncoderState enc_state;
+  ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, &aux, &pool));
+
+  EXPECT_LE(compressed.size(), 220000);
+
+  float target_butteraugli[9] = {};
+  target_butteraugli[1] = 3.0f;
+  target_butteraugli[2] = 6.0f;
+  target_butteraugli[4] = 10.0f;
+  target_butteraugli[8] = 80.0f;
+
+  // The default progressive encoding scheme should make all these downsampling
+  // factors achievable.
+  std::vector<size_t> downsamplings = {1, 2, 4, 8};
+
+  auto check = [&](uint32_t task, uint32_t /* thread */) -> void {
+    const size_t downsampling = downsamplings[task];
+    DecompressParams dparams;
+    dparams.max_downsampling = downsampling;
+    CodecInOut output;
+    ASSERT_TRUE(DecodeFile(dparams, compressed, &output, nullptr));
+    EXPECT_EQ(output.xsize(), io.xsize()) << "downsampling = " << downsampling;
+    EXPECT_EQ(output.ysize(), io.ysize()) << "downsampling = " << downsampling;
+    EXPECT_LE(ButteraugliDistance(io, output, cparams.ba_params,
+                                  /*distmap=*/nullptr, nullptr),
+              target_butteraugli[downsampling])
+        << "downsampling: " << downsampling;
+  };
+  pool.Run(0, downsamplings.size(), ThreadPool::SkipInit(), check);
+}
+
+TEST(PassesTest, ProgressiveDownsample2DegradesCorrectlyGrayscale) {
+  ThreadPoolInternal pool(8);
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/cvo9xd_keong_macan_grayscale.png");
+  CodecInOut io_orig;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io_orig, &pool));
+  Rect rect(0, 0, io_orig.xsize(), 128);
+  // need 2 DC groups for the DC frame to actually be progressive.
+  Image3F large(4242, rect.ysize());
+  ZeroFillImage(&large);
+  CopyImageTo(rect, *io_orig.Main().color(), rect, &large);
+  CodecInOut io;
+  io.metadata = io_orig.metadata;
+  io.SetFromImage(std::move(large), io_orig.Main().c_current());
+
+  PaddedBytes compressed;
+  AuxOut aux;
+
+  CompressParams cparams;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  cparams.progressive_dc = 1;
+  cparams.responsive = true;
+  cparams.qprogressive_mode = true;
+  cparams.butteraugli_distance = 1.0;
+  PassesEncoderState enc_state;
+  ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, &aux, &pool));
+
+  EXPECT_LE(compressed.size(), 10000);
+
+  DecompressParams dparams;
+  dparams.max_downsampling = 1;
+  CodecInOut output;
+  ASSERT_TRUE(DecodeFile(dparams, compressed, &output, nullptr));
+
+  dparams.max_downsampling = 2;
+  CodecInOut output_d2;
+  ASSERT_TRUE(DecodeFile(dparams, compressed, &output_d2, nullptr));
+
+  // 0 if reading all the passes, ~15 if skipping the 8x pass.
+  float butteraugli_distance_down2_full =
+      ButteraugliDistance(output, output_d2, cparams.ba_params,
+                          /*distmap=*/nullptr, nullptr);
+
+  EXPECT_LE(butteraugli_distance_down2_full, 3.0f);
+  EXPECT_GE(butteraugli_distance_down2_full, 1.0f);
+}
+
+TEST(PassesTest, ProgressiveDownsample2DegradesCorrectly) {
+  ThreadPoolInternal pool(8);
+  const PaddedBytes orig =
+      ReadTestData("imagecompression.info/flower_foveon.png");
+  CodecInOut io_orig;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io_orig, &pool));
+  Rect rect(0, 0, io_orig.xsize(), 128);
+  // need 2 DC groups for the DC frame to actually be progressive.
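+  // (With the default 256x256 groups, a DC group covers 2048x2048 pixels, so
+  // a width of 4242 spans multiple DC groups horizontally.)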
+  Image3F large(4242, rect.ysize());
+  ZeroFillImage(&large);
+  CopyImageTo(rect, *io_orig.Main().color(), rect, &large);
+  CodecInOut io;
+  io.SetFromImage(std::move(large), io_orig.Main().c_current());
+
+  PaddedBytes compressed;
+  AuxOut aux;
+
+  CompressParams cparams;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  cparams.progressive_dc = 1;
+  cparams.responsive = true;
+  cparams.qprogressive_mode = true;
+  cparams.butteraugli_distance = 1.0;
+  PassesEncoderState enc_state;
+  ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, &aux, &pool));
+
+  EXPECT_LE(compressed.size(), 220000);
+
+  DecompressParams dparams;
+  dparams.max_downsampling = 1;
+  CodecInOut output;
+  ASSERT_TRUE(DecodeFile(dparams, compressed, &output, nullptr));
+
+  dparams.max_downsampling = 2;
+  CodecInOut output_d2;
+  ASSERT_TRUE(DecodeFile(dparams, compressed, &output_d2, nullptr));
+
+  // 0 if reading all the passes, ~15 if skipping the 8x pass.
+  float butteraugli_distance_down2_full =
+      ButteraugliDistance(output, output_d2, cparams.ba_params,
+                          /*distmap=*/nullptr, nullptr);
+
+  EXPECT_LE(butteraugli_distance_down2_full, 3.0f);
+  EXPECT_GE(butteraugli_distance_down2_full, 1.0f);
+}
+
+TEST(PassesTest, NonProgressiveDCImage) {
+  ThreadPoolInternal pool(8);
+  const PaddedBytes orig =
+      ReadTestData("imagecompression.info/flower_foveon.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, &pool));
+
+  PaddedBytes compressed;
+  AuxOut aux;
+
+  CompressParams cparams;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  cparams.progressive_mode = false;
+  cparams.butteraugli_distance = 2.0;
+  PassesEncoderState enc_state;
+  ASSERT_TRUE(EncodeFile(cparams, &io, &enc_state, &compressed, &aux, &pool));
+
+  // Even in non-progressive mode, it should be possible to return a DC-only
+  // image.
+  DecompressParams dparams;
+  dparams.max_downsampling = 100;
+  CodecInOut output;
+  ASSERT_TRUE(DecodeFile(dparams, compressed, &output, &pool));
+  EXPECT_EQ(output.xsize(), io.xsize());
+  EXPECT_EQ(output.ysize(), io.ysize());
+}
+
+TEST(PassesTest, RoundtripSmallNoGaborishPasses) {
+  ThreadPool* pool = nullptr;
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, pool));
+  io.ShrinkTo(io.xsize() / 8, io.ysize() / 8);
+
+  CompressParams cparams;
+  cparams.gaborish = Override::kOff;
+  cparams.butteraugli_distance = 1.0;
+  cparams.progressive_mode = true;
+  DecompressParams dparams;
+
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, pool, &io2);
+  EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params,
+                                /*distmap=*/nullptr, pool),
+            1.7);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/patch_dictionary_internal.h b/third_party/jpeg-xl/lib/jxl/patch_dictionary_internal.h
new file mode 100644
index 000000000000..ee8b7da5ad48
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/patch_dictionary_internal.h
@@ -0,0 +1,115 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_
+#define LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_
+
+#include "lib/jxl/dec_patch_dictionary.h"
+#include "lib/jxl/passes_state.h"  // for PassesSharedState
+
+namespace jxl {
+
+// Context numbers as specified in Section C.4.5, Listing C.2:
+enum Contexts {
+  kNumRefPatchContext = 0,
+  kReferenceFrameContext = 1,
+  kPatchSizeContext = 2,
+  kPatchReferencePositionContext = 3,
+  kPatchPositionContext = 4,
+  kPatchBlendModeContext = 5,
+  kPatchOffsetContext = 6,
+  kPatchCountContext = 7,
+  kPatchAlphaChannelContext = 8,
+  kPatchClampContext = 9,
+  kNumPatchDictionaryContexts
+};
+
+template <bool add>
+void PatchDictionary::Apply(Image3F* opsin, const Rect& opsin_rect,
+                            const Rect& image_rect) const {
+  JXL_CHECK(SameSize(opsin_rect, image_rect));
+  size_t num = 0;
+  for (size_t y = image_rect.y0(); y < image_rect.y0() + image_rect.ysize();
+       y++) {
+    if (y + 1 >= patch_starts_.size()) continue;
+    float* JXL_RESTRICT rows[3] = {
+        opsin_rect.PlaneRow(opsin, 0, y - image_rect.y0()),
+        opsin_rect.PlaneRow(opsin, 1, y - image_rect.y0()),
+        opsin_rect.PlaneRow(opsin, 2, y - image_rect.y0()),
+    };
+    for (size_t id = patch_starts_[y]; id < patch_starts_[y + 1]; id++) {
+      num++;
+      const PatchPosition& pos = positions_[sorted_patches_[id]];
+      size_t by = pos.y;
+      size_t bx = pos.x;
+      size_t xsize = pos.ref_pos.xsize;
+      JXL_DASSERT(y >= by);
+      JXL_DASSERT(y < by + pos.ref_pos.ysize);
+      size_t iy = y - by;
+      size_t ref = pos.ref_pos.ref;
+      if (bx >= image_rect.x0() + image_rect.xsize()) continue;
+      if (bx + xsize < image_rect.x0()) continue;
+      // TODO(veluca): check that the reference frame is in XYB.
+      const float* JXL_RESTRICT ref_rows[3] = {
+          shared_->reference_frames[ref].frame->color()->ConstPlaneRow(
+              0, pos.ref_pos.y0 + iy) +
+              pos.ref_pos.x0,
+          shared_->reference_frames[ref].frame->color()->ConstPlaneRow(
+              1, pos.ref_pos.y0 + iy) +
+              pos.ref_pos.x0,
+          shared_->reference_frames[ref].frame->color()->ConstPlaneRow(
+              2, pos.ref_pos.y0 + iy) +
+              pos.ref_pos.x0,
+      };
+      // TODO(veluca): use the same code as in dec_reconstruct.cc.
+      for (size_t ix = 0; ix < xsize; ix++) {
+        // TODO(veluca): hoist branches and checks.
+        // TODO(veluca): implement for extra channels.
+        if (bx + ix < image_rect.x0()) continue;
+        if (bx + ix >= image_rect.x0() + image_rect.xsize()) continue;
+        for (size_t c = 0; c < 3; c++) {
+          if (add) {
+            if (pos.blending[0].mode == PatchBlendMode::kAdd) {
+              rows[c][bx + ix - image_rect.x0()] += ref_rows[c][ix];
+            } else if (pos.blending[0].mode == PatchBlendMode::kReplace) {
+              rows[c][bx + ix - image_rect.x0()] = ref_rows[c][ix];
+            } else if (pos.blending[0].mode == PatchBlendMode::kNone) {
+              // Nothing to do.
+            } else {
+              // Checked in decoding code.
+              JXL_ABORT("Blending mode %u not yet implemented",
+                        (uint32_t)pos.blending[0].mode);
+            }
+          } else {
+            if (pos.blending[0].mode == PatchBlendMode::kAdd) {
+              rows[c][bx + ix - image_rect.x0()] -= ref_rows[c][ix];
+            } else if (pos.blending[0].mode == PatchBlendMode::kReplace) {
+              rows[c][bx + ix - image_rect.x0()] = 0;
+            } else if (pos.blending[0].mode == PatchBlendMode::kNone) {
+              // Nothing to do.
+            } else {
+              // Checked in decoding code.
+              JXL_ABORT("Blending mode %u not yet implemented",
+                        (uint32_t)pos.blending[0].mode);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/patch_dictionary_test.cc b/third_party/jpeg-xl/lib/jxl/patch_dictionary_test.cc
new file mode 100644
index 000000000000..410763d7f378
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/patch_dictionary_test.cc
@@ -0,0 +1,64 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/dec_params.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/image_test_utils.h"
+#include "lib/jxl/test_utils.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+namespace {
+
+using ::jxl::test::Roundtrip;
+
+TEST(PatchDictionaryTest, GrayscaleModular) {
+  ThreadPool* pool = nullptr;
+  const PaddedBytes orig = ReadTestData("jxl/grayscale_patches.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, pool));
+
+  CompressParams cparams;
+  cparams.color_transform = jxl::ColorTransform::kNone;
+  cparams.modular_mode = true;
+  cparams.patches = jxl::Override::kOn;
+  DecompressParams dparams;
+
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, pool, &io2);
+  VerifyEqual(*io.Main().color(), *io2.Main().color());
+}
+
+TEST(PatchDictionaryTest, GrayscaleVarDCT) {
+  ThreadPool* pool = nullptr;
+  const PaddedBytes orig = ReadTestData("jxl/grayscale_patches.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, pool));
+
+  CompressParams cparams;
+  cparams.patches = jxl::Override::kOn;
+  DecompressParams dparams;
+
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, pool, &io2);
+  EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params,
+                                /*distmap=*/nullptr, pool),
+            2);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/preview_test.cc b/third_party/jpeg-xl/lib/jxl/preview_test.cc
new file mode 100644
index 000000000000..b396ce227578
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/preview_test.cc
@@ -0,0 +1,92 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+
+#include <string>
+
+#include "gtest/gtest.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/override.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/dec_file.h"
+#include "lib/jxl/dec_params.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_file.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/headers.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/test_utils.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+namespace {
+using test::Roundtrip;
+
+TEST(PreviewTest, RoundtripGivenPreview) {
+  ThreadPool* pool = nullptr;
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png");
+  CodecInOut io;
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, pool));
+  io.ShrinkTo(io.xsize() / 8, io.ysize() / 8);
+  // Same as main image
+  io.preview_frame = io.Main().Copy();
+  const size_t preview_xsize = 15;
+  const size_t preview_ysize = 27;
+  io.preview_frame.ShrinkTo(preview_xsize, preview_ysize);
+  io.metadata.m.have_preview = true;
+  ASSERT_TRUE(io.metadata.m.preview_size.Set(io.preview_frame.xsize(),
+                                             io.preview_frame.ysize()));
+
+  CompressParams cparams;
+  cparams.butteraugli_distance = 2.0;
+  cparams.speed_tier = SpeedTier::kSquirrel;
+  DecompressParams dparams;
+
+  dparams.preview = Override::kOff;
+
+  CodecInOut io2;
+  Roundtrip(&io, cparams, dparams, pool, &io2);
+  EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params,
+                                /*distmap=*/nullptr, pool),
+            2.5);
+  EXPECT_EQ(0, io2.preview_frame.xsize());
+
+  dparams.preview = Override::kOn;
+
+  CodecInOut io3;
+  Roundtrip(&io, cparams, dparams, pool, &io3);
+  EXPECT_EQ(preview_xsize, io3.metadata.m.preview_size.xsize());
+  EXPECT_EQ(preview_ysize, io3.metadata.m.preview_size.ysize());
+  EXPECT_EQ(preview_xsize, io3.preview_frame.xsize());
+  EXPECT_EQ(preview_ysize, io3.preview_frame.ysize());
+
+  EXPECT_LE(ButteraugliDistance(io.preview_frame, io3.preview_frame,
+                                cparams.ba_params,
+                                /*distmap=*/nullptr, pool),
+            2.5);
+  EXPECT_LE(ButteraugliDistance(io.Main(), io3.Main(), cparams.ba_params,
+                                /*distmap=*/nullptr, pool),
+            2.5);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/progressive_split.cc b/third_party/jpeg-xl/lib/jxl/progressive_split.cc
new file mode 100644
index 000000000000..e580c5de4896
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/progressive_split.cc
@@ -0,0 +1,137 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lib/jxl/progressive_split.h" + +#include + +#include +#include + +#include "lib/jxl/common.h" +#include "lib/jxl/image.h" + +namespace jxl { + +bool ProgressiveSplitter::SuperblockIsSalient(size_t row_start, + size_t col_start, size_t num_rows, + size_t num_cols) const { + if (saliency_map_ == nullptr || saliency_map_->xsize() == 0 || + saliency_threshold_ == 0.0) { + // If we do not have a saliency-map, or the threshold says to include + // every block, we straightaway classify the superblock as 'salient'. + return true; + } + const size_t row_end = std::min(saliency_map_->ysize(), row_start + num_rows); + const size_t col_end = std::min(saliency_map_->xsize(), col_start + num_cols); + for (size_t num_row = row_start; num_row < row_end; num_row++) { + const float* JXL_RESTRICT map_row = saliency_map_->ConstRow(num_row); + for (size_t num_col = col_start; num_col < col_end; num_col++) { + if (map_row[num_col] >= saliency_threshold_) { + // One of the blocks covered by this superblock is above the saliency + // threshold. + return true; + } + } + } + // We did not see any block above the saliency threshold. + return false; +} + +template +void ProgressiveSplitter::SplitACCoefficients( + const T* JXL_RESTRICT block, size_t size, const AcStrategy& acs, size_t bx, + size_t by, size_t offset, T* JXL_RESTRICT output[kMaxNumPasses][3]) { + auto shift_right_round0 = [&](T v, int shift) { + T one_if_negative = static_cast(v) >> 31; + T add = (one_if_negative << shift) - one_if_negative; + return (v + add) >> shift; + }; + // Early quit for the simple case of only one pass. + if (mode_.num_passes == 1) { + for (size_t c = 0; c < 3; c++) { + memcpy(output[0][c] + offset, block + c * size, sizeof(T) * size); + } + return; + } + size_t ncoeffs_all_done_from_earlier_passes = 1; + size_t previous_pass_salient_only = false; + + int previous_pass_shift = 0; + for (size_t num_pass = 0; num_pass < mode_.num_passes; num_pass++) { // pass + // Zero out output block. + for (size_t c = 0; c < 3; c++) { + memset(output[num_pass][c] + offset, 0, size * sizeof(T)); + } + const bool current_pass_salient_only = mode_.passes[num_pass].salient_only; + const int pass_shift = mode_.passes[num_pass].shift; + size_t frame_ncoeffs = mode_.passes[num_pass].num_coefficients; + for (size_t c = 0; c < 3; c++) { // color-channel + size_t xsize = acs.covered_blocks_x(); + size_t ysize = acs.covered_blocks_y(); + CoefficientLayout(&ysize, &xsize); + if (current_pass_salient_only || previous_pass_salient_only) { + // Current or previous pass is salient-only. + const bool superblock_is_salient = + SuperblockIsSalient(by, bx, ysize, xsize); + if (current_pass_salient_only != superblock_is_salient) { + // Current pass is salient-only, but block is not salient, + // OR last pass was salient-only, and block is salient + // (hence was already included in last pass). + continue; + } + } + for (size_t y = 0; y < ysize * frame_ncoeffs; y++) { // superblk-y + for (size_t x = 0; x < xsize * frame_ncoeffs; x++) { // superblk-x + size_t pos = y * xsize * kBlockDim + x; + if (x < xsize * ncoeffs_all_done_from_earlier_passes && + y < ysize * ncoeffs_all_done_from_earlier_passes) { + // This coefficient was already included in an earlier pass, + // which included a genuinely smaller set of coefficients + // (= is not about saliency-splitting). + continue; + } + T v = block[c * size + pos]; + // Previous pass discarded some bits: do not encode them again. 
+          if (previous_pass_shift != 0) {
+            T previous_v = shift_right_round0(v, previous_pass_shift) *
+                           (1 << previous_pass_shift);
+            v -= previous_v;
+          }
+          output[num_pass][c][offset + pos] = shift_right_round0(v, pass_shift);
+        }  // superblk-x
+      }    // superblk-y
+    }      // color-channel
+    if (!current_pass_salient_only) {
+      // We just finished a non-salient pass.
+      // Hence, we are now guaranteed to have included all coeffs up to
+      // frame_ncoeffs in every block, unless the current pass is shifted.
+      if (mode_.passes[num_pass].shift == 0) {
+        ncoeffs_all_done_from_earlier_passes = frame_ncoeffs;
+      }
+    }
+    previous_pass_salient_only = current_pass_salient_only;
+    previous_pass_shift = mode_.passes[num_pass].shift;
+  }  // num_pass
+}
+
+template void ProgressiveSplitter::SplitACCoefficients<int32_t>(
+    const int32_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t,
+    size_t, int32_t* JXL_RESTRICT[kMaxNumPasses][3]);
+
+template void ProgressiveSplitter::SplitACCoefficients<int16_t>(
+    const int16_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t,
+    size_t, int16_t* JXL_RESTRICT[kMaxNumPasses][3]);
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/progressive_split.h b/third_party/jpeg-xl/lib/jxl/progressive_split.h
new file mode 100644
index 000000000000..4a3513f3c0dc
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/progressive_split.h
@@ -0,0 +1,157 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_PROGRESSIVE_SPLIT_H_
+#define LIB_JXL_PROGRESSIVE_SPLIT_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct_util.h"
+#include "lib/jxl/frame_header.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/splines.h"
+
+// Functions to split DCT coefficients in multiple passes. All the passes of a
+// single frame are added together.
+
+namespace jxl {
+
+constexpr size_t kNoDownsamplingFactor = std::numeric_limits<size_t>::max();
+
+struct PassDefinition {
+  // Side of the square of the coefficients that should be kept in each 8x8
+  // block. Must be greater than 1, and at most 8. Should be in non-decreasing
+  // order.
+  size_t num_coefficients;
+
+  // How much to shift the encoded values by, with rounding.
+  size_t shift;
+
+  // Whether or not we should include only salient blocks.
+  // TODO(veluca): ignored for now.
+  bool salient_only;
+
+  // If specified, this indicates that if the requested downsampling factor is
+  // sufficiently high, then it is fine to stop decoding after this pass.
+  // By default, passes are not marked as being suitable for any downsampling.
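+  // For example, a pass declaring suitable_for_downsampling_of_at_least = 8
+  // lets a decoder that was asked for an 8x (or more) downsampled image stop
+  // decoding once this pass has been read.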
+  size_t suitable_for_downsampling_of_at_least;
+};
+
+struct ProgressiveMode {
+  size_t num_passes = 1;
+  PassDefinition passes[kMaxNumPasses] = {PassDefinition{
+      /*num_coefficients=*/8, /*shift=*/0, /*salient_only=*/false,
+      /*suitable_for_downsampling_of_at_least=*/1}};
+
+  ProgressiveMode() = default;
+
+  template <size_t nump>
+  explicit ProgressiveMode(const PassDefinition (&p)[nump]) {
+    JXL_ASSERT(nump <= kMaxNumPasses);
+    num_passes = nump;
+    PassDefinition previous_pass{
+        /*num_coefficients=*/1, /*shift=*/0,
+        /*salient_only=*/false,
+        /*suitable_for_downsampling_of_at_least=*/kNoDownsamplingFactor};
+    size_t last_downsampling_factor = kNoDownsamplingFactor;
+    for (size_t i = 0; i < nump; i++) {
+      JXL_ASSERT(p[i].num_coefficients > previous_pass.num_coefficients ||
+                 (p[i].num_coefficients == previous_pass.num_coefficients &&
+                  !p[i].salient_only && previous_pass.salient_only) ||
+                 (p[i].num_coefficients == previous_pass.num_coefficients &&
+                  p[i].shift < previous_pass.shift));
+      JXL_ASSERT(p[i].suitable_for_downsampling_of_at_least ==
+                     kNoDownsamplingFactor ||
+                 p[i].suitable_for_downsampling_of_at_least <=
+                     last_downsampling_factor);
+      if (p[i].suitable_for_downsampling_of_at_least !=
+          kNoDownsamplingFactor) {
+        last_downsampling_factor = p[i].suitable_for_downsampling_of_at_least;
+      }
+      previous_pass = passes[i] = p[i];
+    }
+  }
+};
+
+class ProgressiveSplitter {
+ public:
+  void SetProgressiveMode(ProgressiveMode mode) { mode_ = mode; }
+
+  void SetSaliencyMap(const ImageF* saliency_map) {
+    saliency_map_ = saliency_map;
+  }
+
+  void SetSaliencyThreshold(float threshold) {
+    saliency_threshold_ = threshold;
+  }
+
+  size_t GetNumPasses() const { return mode_.num_passes; }
+
+  void InitPasses(Passes* JXL_RESTRICT passes) const {
+    passes->num_passes = static_cast<uint32_t>(GetNumPasses());
+    passes->num_downsample = 0;
+    JXL_ASSERT(passes->num_passes != 0);
+    if (passes->num_passes == 1) return;  // Done, arrays are empty
+
+    for (uint32_t i = 0; i < mode_.num_passes - 1; ++i) {
+      const size_t min_downsampling_factor =
+          mode_.passes[i].suitable_for_downsampling_of_at_least;
+      passes->shift[i] = mode_.passes[i].shift;
+      if (1 < min_downsampling_factor &&
+          min_downsampling_factor != kNoDownsamplingFactor) {
+        passes->downsample[passes->num_downsample] = min_downsampling_factor;
+        passes->last_pass[passes->num_downsample] = i;
+        passes->num_downsample += 1;
+      }
+    }
+  }
+
+  template <typename T>
+  void SplitACCoefficients(const T* JXL_RESTRICT block, size_t size,
+                           const AcStrategy& acs, size_t bx, size_t by,
+                           size_t offset,
+                           T* JXL_RESTRICT output[kMaxNumPasses][3]);
+
+ private:
+  bool SuperblockIsSalient(size_t row_start, size_t col_start, size_t num_rows,
+                           size_t num_cols) const;
+  ProgressiveMode mode_;
+
+  // Not owned, must remain valid.
+  const ImageF* saliency_map_ = nullptr;
+  float saliency_threshold_ = 0.0;
+};
+
+extern template void ProgressiveSplitter::SplitACCoefficients<int32_t>(
+    const int32_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t,
+    size_t, int32_t* JXL_RESTRICT[kMaxNumPasses][3]);
+
+extern template void ProgressiveSplitter::SplitACCoefficients<int16_t>(
+    const int16_t* JXL_RESTRICT, size_t, const AcStrategy&, size_t, size_t,
+    size_t, int16_t* JXL_RESTRICT[kMaxNumPasses][3]);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_PROGRESSIVE_SPLIT_H_
diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights.cc b/third_party/jpeg-xl/lib/jxl/quant_weights.cc
new file mode 100644
index 000000000000..e13becfddaca
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/quant_weights.cc
@@ -0,0 +1,1191 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "lib/jxl/quant_weights.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <utility>
+
+#include "lib/jxl/base/bits.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/dec_modular.h"
+#include "lib/jxl/fields.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+// kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y)
+// coefficient in component c. Higher weights correspond to finer quantization
+// intervals and more bits spent in encoding.
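+//
+// For example, ComputeQuantTable below stores 1.0f / weight as the matrix
+// entry, so doubling a coefficient's weight halves its quantization step.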
+
+namespace {
+
+static constexpr const float kAlmostZero = 1e-8f;
+
+void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights,
+                         float* weights) {
+  for (size_t c = 0; c < 3; c++) {
+    size_t start = c * 64;
+    weights[start] = 0xBAD;
+    weights[start + 1] = weights[start + 8] = dct2weights[c][0];
+    weights[start + 9] = dct2weights[c][1];
+    for (size_t y = 0; y < 2; y++) {
+      for (size_t x = 0; x < 2; x++) {
+        weights[start + y * 8 + x + 2] = dct2weights[c][2];
+        weights[start + (y + 2) * 8 + x] = dct2weights[c][2];
+      }
+    }
+    for (size_t y = 0; y < 2; y++) {
+      for (size_t x = 0; x < 2; x++) {
+        weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3];
+      }
+    }
+    for (size_t y = 0; y < 4; y++) {
+      for (size_t x = 0; x < 4; x++) {
+        weights[start + y * 8 + x + 4] = dct2weights[c][4];
+        weights[start + (y + 4) * 8 + x] = dct2weights[c][4];
+      }
+    }
+    for (size_t y = 0; y < 4; y++) {
+      for (size_t x = 0; x < 4; x++) {
+        weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5];
+      }
+    }
+  }
+}
+
+void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights,
+                             float* weights) {
+  for (size_t c = 0; c < 3; c++) {
+    for (int i = 0; i < 64; i++) {
+      weights[64 * c + i] = idweights[c][0];
+    }
+    weights[64 * c + 1] = idweights[c][1];
+    weights[64 * c + 8] = idweights[c][1];
+    weights[64 * c + 9] = idweights[c][2];
+  }
+}
+
+float Mult(float v) {
+  if (v > 0) return 1 + v;
+  return 1 / (1 - v);
+}
+
+float Interpolate(float pos, float max, const float* array, size_t len) {
+  float scaled_pos = pos * (len - 1) / max;
+  size_t idx = scaled_pos;
+  JXL_ASSERT(idx + 1 < len);
+  float a = array[idx];
+  float b = array[idx + 1];
+  return a * pow(b / a, scaled_pos - idx);
+}
+
+// Computes quant weights for a COLS*ROWS-sized transform, using num_bands
+// eccentricity bands. If print_mode is 1, prints the resulting matrix; if
+// print_mode is 2, prints the matrix in a format suitable for a 3d plot with
+// gnuplot.
+template <size_t print_mode = 0>
+Status GetQuantWeights(
+    size_t ROWS, size_t COLS,
+    const DctQuantWeightParams::DistanceBandsArray& distance_bands,
+    size_t num_bands, float* out) {
+  for (size_t c = 0; c < 3; c++) {
+    if (print_mode) {
+      fprintf(stderr, "Channel %zu\n", c);
+    }
+    float bands[DctQuantWeightParams::kMaxDistanceBands] = {
+        distance_bands[c][0]};
+    for (size_t i = 1; i < num_bands; i++) {
+      bands[i] = bands[i - 1] * Mult(distance_bands[c][i]);
+      if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
+    }
+    for (size_t y = 0; y < ROWS; y++) {
+      for (size_t x = 0; x < COLS; x++) {
+        float dx = 1.0f * x / (COLS - 1);
+        float dy = 1.0f * y / (ROWS - 1);
+        float distance = std::sqrt(dx * dx + dy * dy);
+        float weight =
+            num_bands == 1
+                ? bands[0]
+                : Interpolate(distance, std::sqrt(2) + 1e-6f, bands,
+                              num_bands);
+
+        if (print_mode == 1) {
+          fprintf(stderr, "%15.12f, ", weight);
+        }
+        if (print_mode == 2) {
+          fprintf(stderr, "%zu %zu %15.12f\n", x, y, weight);
+        }
+        out[c * COLS * ROWS + y * COLS + x] = weight;
+      }
+      if (print_mode) fprintf(stderr, "\n");
+      if (print_mode == 1) fprintf(stderr, "\n");
+    }
+    if (print_mode) fprintf(stderr, "\n");
+  }
+  return true;
+}
+
+Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) {
+  params->num_distance_bands =
+      br->ReadFixedBits<DctQuantWeightParams::kLog2MaxDistanceBands>() + 1;
+  for (size_t c = 0; c < 3; c++) {
+    for (size_t i = 0; i < params->num_distance_bands; i++) {
+      JXL_RETURN_IF_ERROR(F16Coder::Read(br, &params->distance_bands[c][i]));
+    }
+    if (params->distance_bands[c][0] < kAlmostZero) {
+      return JXL_FAILURE("Distance band seed is too small");
+    }
+    params->distance_bands[c][0] *= 64.0f;
+  }
+  return true;
+}
+
+Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x,
+              size_t required_size_y, size_t idx,
+              ModularFrameDecoder* modular_frame_decoder) {
+  size_t required_size = required_size_x * required_size_y;
+  required_size_x *= kBlockDim;
+  required_size_y *= kBlockDim;
+  int mode = br->ReadFixedBits<kLog2NumQuantModes>();
+  switch (mode) {
+    case QuantEncoding::kQuantModeLibrary: {
+      encoding->predefined = br->ReadFixedBits<kCeilLog2NumPredefinedTables>();
+      if (encoding->predefined >= kNumPredefinedTables) {
+        return JXL_FAILURE("Invalid predefined table");
+      }
+      break;
+    }
+    case QuantEncoding::kQuantModeID: {
+      if (required_size != 1) return JXL_FAILURE("Invalid mode");
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 3; i++) {
+          JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i]));
+          if (std::abs(encoding->idweights[c][i]) < kAlmostZero) {
+            return JXL_FAILURE("ID Quantizer is too small");
+          }
+          encoding->idweights[c][i] *= 64;
+        }
+      }
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT2: {
+      if (required_size != 1) return JXL_FAILURE("Invalid mode");
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 6; i++) {
+          JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i]));
+          if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) {
+            return JXL_FAILURE("Quantizer is too small");
+          }
+          encoding->dct2weights[c][i] *= 64;
+        }
+      }
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT4X8: {
+      if (required_size != 1) return JXL_FAILURE("Invalid mode");
+      for (size_t c = 0; c < 3; c++) {
+        JXL_RETURN_IF_ERROR(
+            F16Coder::Read(br, &encoding->dct4x8multipliers[c]));
+        if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) {
+          return JXL_FAILURE("DCT4X8 multiplier is too small");
+        }
+      }
+      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT4: {
+      if (required_size != 1) return JXL_FAILURE("Invalid mode");
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 2; i++) {
+          JXL_RETURN_IF_ERROR(
+              F16Coder::Read(br, &encoding->dct4multipliers[c][i]));
+          if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) {
+            return JXL_FAILURE("DCT4 multiplier is too small");
+          }
+        }
+      }
+      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
+      break;
+    }
+    case QuantEncoding::kQuantModeAFV: {
+      if (required_size != 1) return JXL_FAILURE("Invalid mode");
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t i = 0; i < 9; i++) {
+          JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i]));
+        }
+        for (size_t i = 0; i < 6; i++) {
+          encoding->afv_weights[c][i] *= 64;
+        }
+        JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
+        JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4));
+      }
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT: {
+      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
+      break;
+    }
+    case QuantEncoding::kQuantModeRAW: {
+      // Set mode early, to avoid mem-leak.
+      encoding->mode = QuantEncoding::kQuantModeRAW;
+      JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable(
+          required_size_x, required_size_y, br, encoding, idx,
+          modular_frame_decoder));
+      break;
+    }
+    default:
+      return JXL_FAILURE("Invalid quantization table encoding");
+  }
+  encoding->mode = QuantEncoding::Mode(mode);
+  return true;
+}
+
+// TODO(veluca): SIMD-fy. With 256x256, this is actually slow.
+Status ComputeQuantTable(const QuantEncoding& encoding,
+                         float* JXL_RESTRICT table,
+                         float* JXL_RESTRICT inv_table, size_t table_num,
+                         DequantMatrices::QuantTable kind, size_t* pos) {
+  std::vector<float> weights(3 * kMaxQuantTableSize);
+
+  constexpr size_t N = kBlockDim;
+  size_t wrows = 8 * DequantMatrices::required_size_x[kind],
+         wcols = 8 * DequantMatrices::required_size_y[kind];
+  size_t num = wrows * wcols;
+
+  switch (encoding.mode) {
+    case QuantEncoding::kQuantModeLibrary: {
+      // Library and copy quant encoding should get replaced by the actual
+      // parameters by the caller.
+      JXL_ASSERT(false);
+      break;
+    }
+    case QuantEncoding::kQuantModeID: {
+      JXL_ASSERT(num == kDCTBlockSize);
+      GetQuantWeightsIdentity(encoding.idweights, weights.data());
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT2: {
+      JXL_ASSERT(num == kDCTBlockSize);
+      GetQuantWeightsDCT2(encoding.dct2weights, weights.data());
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT4: {
+      JXL_ASSERT(num == kDCTBlockSize);
+      float weights4x4[3 * 4 * 4];
+      // Always use 4x4 GetQuantWeights for DCT4 quantization tables.
+      JXL_RETURN_IF_ERROR(
+          GetQuantWeights(4, 4, encoding.dct_params.distance_bands,
+                          encoding.dct_params.num_distance_bands, weights4x4));
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t y = 0; y < kBlockDim; y++) {
+          for (size_t x = 0; x < kBlockDim; x++) {
+            weights[c * num + y * kBlockDim + x] =
+                weights4x4[c * 16 + (y / 2) * 4 + (x / 2)];
+          }
+        }
+        weights[c * num + 1] /= encoding.dct4multipliers[c][0];
+        weights[c * num + N] /= encoding.dct4multipliers[c][0];
+        weights[c * num + N + 1] /= encoding.dct4multipliers[c][1];
+      }
+      break;
+    }
+    case QuantEncoding::kQuantModeDCT4X8: {
+      JXL_ASSERT(num == kDCTBlockSize);
+      float weights4x8[3 * 4 * 8];
+      // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables.
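+      // (Each 4x8 weight is replicated into two vertically adjacent rows of
+      // the 8x8 block via the y / 2 indexing below.)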
+ JXL_RETURN_IF_ERROR( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8)); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + weights[c * num + y * kBlockDim + x] = + weights4x8[c * 32 + (y / 2) * 8 + x]; + } + } + weights[c * num + N] /= encoding.dct4x8multipliers[c]; + } + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(GetQuantWeights( + wrows, wcols, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights.data())); + break; + } + case QuantEncoding::kQuantModeRAW: { + if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) { + return JXL_FAILURE("Invalid table encoding"); + } + for (size_t i = 0; i < 3 * num; i++) { + weights[i] = + 1.f / (encoding.qraw.qtable_den * (*encoding.qraw.qtable)[i]); + } + break; + } + case QuantEncoding::kQuantModeAFV: { + constexpr float kFreqs[] = { + 0xBAD, + 0xBAD, + 0.8517778890324296, + 5.37778436506804, + 0xBAD, + 0xBAD, + 4.734747904497923, + 5.449245381693219, + 1.6598270267479331, + 4, + 7.275749096817861, + 10.423227632456525, + 2.662932286148962, + 7.630657783650829, + 8.962388608184032, + 12.97166202570235, + }; + + float weights4x8[3 * 4 * 8]; + JXL_RETURN_IF_ERROR(( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8))); + float weights4x4[3 * 4 * 4]; + JXL_RETURN_IF_ERROR((GetQuantWeights( + 4, 4, encoding.dct_params_afv_4x4.distance_bands, + encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); + + constexpr float lo = 0.8517778890324296; + constexpr float hi = 12.97166202570235 - lo + 1e-6; + for (size_t c = 0; c < 3; c++) { + float bands[4]; + bands[0] = encoding.afv_weights[c][5]; + if (bands[0] < 0) return JXL_FAILURE("Invalid AFV bands"); + for (size_t i = 1; i < 4; i++) { + bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]); + if (bands[i] < 0) return JXL_FAILURE("Invalid AFV bands"); + } + size_t start = c * 64; + auto set_weight = [&start, &weights](size_t x, size_t y, float val) { + weights[start + y * 8 + x] = val; + }; + weights[start] = 1; // Not used, but causes MSAN error otherwise. + // Weights for (0, 1) and (1, 0). + set_weight(0, 1, encoding.afv_weights[c][0]); + set_weight(1, 0, encoding.afv_weights[c][1]); + // AFV special weights for 3-pixel corner. + set_weight(0, 2, encoding.afv_weights[c][2]); + set_weight(2, 0, encoding.afv_weights[c][3]); + set_weight(2, 2, encoding.afv_weights[c][4]); + + // All other AFV weights. + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + if (x < 2 && y < 2) continue; + float val = Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4); + set_weight(2 * x, 2 * y, val); + } + } + + // Put 4x8 weights in odd rows, except (1, 0). + for (size_t y = 0; y < kBlockDim / 2; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + if (x == 0 && y == 0) continue; + weights[c * num + (2 * y + 1) * kBlockDim + x] = + weights4x8[c * 32 + y * 8 + x]; + } + } + // Put 4x4 weights in even rows / odd columns, except (0, 1). 
+      for (size_t y = 0; y < kBlockDim / 2; y++) {
+        for (size_t x = 0; x < kBlockDim / 2; x++) {
+          if (x == 0 && y == 0) continue;
+          weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] =
+              weights4x4[c * 16 + y * 4 + x];
+        }
+      }
+      }
+      break;
+    }
+  }
+  size_t prev_pos = *pos;
+  for (size_t c = 0; c < 3; c++) {
+    for (size_t i = 0; i < num; i++) {
+      float val = 1.0f / weights[c * num + i];
+      if (val > std::numeric_limits<float>::max() || val < 0) {
+        return JXL_FAILURE("Invalid quantization table");
+      }
+      table[*pos] = val;
+      inv_table[*pos] = 1.0f / val;
+      (*pos)++;
+    }
+  }
+  // Ensure that the lowest frequencies have a 0 inverse table.
+  // This does not affect en/decoding, but allows AC strategy selection to be
+  // slightly simpler.
+  size_t xs = DequantMatrices::required_size_x[kind];
+  size_t ys = DequantMatrices::required_size_y[kind];
+  CoefficientLayout(&ys, &xs);
+  for (size_t c = 0; c < 3; c++) {
+    for (size_t y = 0; y < ys; y++) {
+      for (size_t x = 0; x < xs; x++) {
+        inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs +
+                  x] = 0;
+      }
+    }
+  }
+  return true;
+}
+
+} // namespace
+
+// These definitions are needed before C++17.
+constexpr size_t DequantMatrices::required_size_[];
+constexpr size_t DequantMatrices::required_size_x[];
+constexpr size_t DequantMatrices::required_size_y[];
+constexpr DequantMatrices::QuantTable DequantMatrices::kQuantTable[];
+
+Status DequantMatrices::Decode(BitReader* br,
+                               ModularFrameDecoder* modular_frame_decoder) {
+  size_t all_default = br->ReadBits(1);
+  size_t num_tables = all_default ? 0 : static_cast<size_t>(kNum);
+  encodings_.clear();
+  encodings_.resize(kNum, QuantEncoding::Library(0));
+  for (size_t i = 0; i < num_tables; i++) {
+    JXL_RETURN_IF_ERROR(
+        jxl::Decode(br, &encodings_[i], required_size_x[i % kNum],
+                    required_size_y[i % kNum], i, modular_frame_decoder));
+  }
+  return DequantMatrices::Compute();
+}
+
+Status DequantMatrices::DecodeDC(BitReader* br) {
+  bool all_default = br->ReadBits(1);
+  if (!all_default) {
+    for (size_t c = 0; c < 3; c++) {
+      JXL_RETURN_IF_ERROR(F16Coder::Read(br, &dc_quant_[c]));
+      dc_quant_[c] *= 1.0f / 128.0f;
+      // Negative values and nearly zero are invalid values.
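+      // (A single lower bound suffices: dc_quant_[c] < kAlmostZero rejects
+      // negative, zero, and near-zero values in one comparison.)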
+ if (dc_quant_[c] < kAlmostZero) { + return JXL_FAILURE("Invalid dc_quant: coefficient is too small."); + } + inv_dc_quant_[c] = 1.0f / dc_quant_[c]; + } + } + return true; +} + +constexpr float V(float v) { return static_cast(v); } + +namespace { +struct DequantMatricesLibraryDef { + // DCT8 + static constexpr const QuantEncodingInternal DCT() { + return QuantEncodingInternal::DCT(DctQuantWeightParams({{{ + V(3150.0), + V(0.0), + V(-0.4), + V(-0.4), + V(-0.4), + V(-2.0), + }, + { + V(560.0), + V(0.0), + V(-0.3), + V(-0.3), + V(-0.3), + V(-0.3), + }, + { + V(512.0), + V(-2.0), + V(-1.0), + V(0.0), + V(-1.0), + V(-2.0), + }}}, + 6)); + } + + // Identity + static constexpr const QuantEncodingInternal IDENTITY() { + return QuantEncodingInternal::Identity({{{ + V(280.0), + V(3160.0), + V(3160.0), + }, + { + V(60.0), + V(864.0), + V(864.0), + }, + { + V(18.0), + V(200.0), + V(200.0), + }}}); + } + + // DCT2 + static constexpr const QuantEncodingInternal DCT2X2() { + return QuantEncodingInternal::DCT2({{{ + V(3840.0), + V(2560.0), + V(1280.0), + V(640.0), + V(480.0), + V(300.0), + }, + { + V(960.0), + V(640.0), + V(320.0), + V(180.0), + V(140.0), + V(120.0), + }, + { + V(640.0), + V(320.0), + V(128.0), + V(64.0), + V(32.0), + V(16.0), + }}}); + } + + // DCT4 (quant_kind 3) + static constexpr const QuantEncodingInternal DCT4X4() { + return QuantEncodingInternal::DCT4(DctQuantWeightParams({{{ + V(2200.0), + V(0.0), + V(0.0), + V(0.0), + }, + { + V(392.0), + V(0.0), + V(0.0), + V(0.0), + }, + { + V(112.0), + V(-0.25), + V(-0.25), + V(-0.5), + }}}, + 4), + /* kMul */ + {{{ + V(1.0), + V(1.0), + }, + { + V(1.0), + V(1.0), + }, + { + V(1.0), + V(1.0), + }}}); + } + + // DCT16 + static constexpr const QuantEncodingInternal DCT16X16() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(8996.8725711814115328), + V(-1.3000777393353804), + V(-0.49424529824571225), + V(-0.439093774457103443), + V(-0.6350101832695744), + V(-0.90177264050827612), + V(-1.6162099239887414), + }, + { + V(3191.48366296844234752), + V(-0.67424582104194355), + V(-0.80745813428471001), + V(-0.44925837484843441), + V(-0.35865440981033403), + V(-0.31322389111877305), + V(-0.37615025315725483), + }, + { + V(1157.50408145487200256), + V(-2.0531423165804414), + V(-1.4), + V(-0.50687130033378396), + V(-0.42708730624733904), + V(-1.4856834539296244), + V(-4.9209142884401604), + }}}, + 7)); + } + + // DCT32 + static constexpr const QuantEncodingInternal DCT32X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(15718.40830982518931456), + V(-1.025), + V(-0.98), + V(-0.9012), + V(-0.4), + V(-0.48819395464), + V(-0.421064), + V(-0.27), + }, + { + V(7305.7636810695983104), + V(-0.8041958212306401), + V(-0.7633036457487539), + V(-0.55660379990111464), + V(-0.49785304658857626), + V(-0.43699592683512467), + V(-0.40180866526242109), + V(-0.27321683125358037), + }, + { + V(3803.53173721215041536), + V(-3.060733579805728), + V(-2.0413270132490346), + V(-2.0235650159727417), + V(-0.5495389509954993), + V(-0.4), + V(-0.4), + V(-0.3), + }}}, + 8)); + } + + // DCT16X8 + static constexpr const QuantEncodingInternal DCT8X16() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(7240.7734393502), + V(-0.7), + V(-0.7), + V(-0.2), + V(-0.2), + V(-0.2), + V(-0.5), + }, + { + V(1448.15468787004), + V(-0.5), + V(-0.5), + V(-0.5), + V(-0.2), + V(-0.2), + V(-0.2), + }, + { + V(506.854140754517), + V(-1.4), + V(-0.2), + V(-0.5), + V(-0.5), + V(-1.5), + V(-3.6), + }}}, + 7)); + } + + // DCT32X8 + static 
constexpr const QuantEncodingInternal DCT8X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(16283.2494710648897), + V(-1.7812845336559429), + V(-1.6309059012653515), + V(-1.0382179034313539), + V(-0.85), + V(-0.7), + V(-0.9), + V(-1.2360638576849587), + }, + { + V(5089.15750884921511936), + V(-0.320049391452786891), + V(-0.35362849922161446), + V(-0.30340000000000003), + V(-0.61), + V(-0.5), + V(-0.5), + V(-0.6), + }, + { + V(3397.77603275308720128), + V(-0.321327362693153371), + V(-0.34507619223117997), + V(-0.70340000000000003), + V(-0.9), + V(-1.0), + V(-1.0), + V(-1.1754605576265209), + }}}, + 8)); + } + + // DCT32X16 + static constexpr const QuantEncodingInternal DCT16X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(13844.97076442300573), + V(-0.97113799999999995), + V(-0.658), + V(-0.42026), + V(-0.22712), + V(-0.2206), + V(-0.226), + V(-0.6), + }, + { + V(4798.964084220744293), + V(-0.61125308982767057), + V(-0.83770786552491361), + V(-0.79014862079498627), + V(-0.2692727459704829), + V(-0.38272769465388551), + V(-0.22924222653091453), + V(-0.20719098826199578), + }, + { + V(1807.236946760964614), + V(-1.2), + V(-1.2), + V(-0.7), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}, + 8)); + } + + // DCT4X8 and 8x4 + static constexpr const QuantEncodingInternal DCT4X8() { + return QuantEncodingInternal::DCT4X8( + DctQuantWeightParams({{ + { + V(2198.050556016380522), + V(-0.96269623020744692), + V(-0.76194253026666783), + V(-0.6551140670773547), + }, + { + V(764.3655248643528689), + V(-0.92630200888366945), + V(-0.9675229603596517), + V(-0.27845290869168118), + }, + { + V(527.107573587542228), + V(-1.4594385811273854), + V(-1.450082094097871593), + V(-1.5843722511996204), + }, + }}, + 4), + /* kMuls */ + {{ + V(1.0), + V(1.0), + V(1.0), + }}); + } + // AFV + static const QuantEncodingInternal AFV0() { + return QuantEncodingInternal::AFV(DCT4X8().dct_params, DCT4X4().dct_params, + {{{ + // 4x4/4x8 DC tendency. + V(3072.0), + V(3072.0), + // AFV corner. + V(256.0), + V(256.0), + V(256.0), + // AFV high freqs. + V(414.0), + V(0.0), + V(0.0), + V(0.0), + }, + { + // 4x4/4x8 DC tendency. + V(1024.0), + V(1024.0), + // AFV corner. + V(50), + V(50), + V(50), + // AFV high freqs. + V(58.0), + V(0.0), + V(0.0), + V(0.0), + }, + { + // 4x4/4x8 DC tendency. + V(384.0), + V(384.0), + // AFV corner. + V(12.0), + V(12.0), + V(12.0), + // AFV high freqs. 
+ V(22.0), + V(-0.25), + V(-0.25), + V(-0.25), + }}}); + } + + // DCT64 + static const QuantEncodingInternal DCT64X64() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(0.9 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }, + { + V(0.9 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }, + { + V(0.9 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}, + 8)); + } + + // DCT64X32 + static const QuantEncodingInternal DCT32X64() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(0.65 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }, + { + V(0.65 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }, + { + V(0.65 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}, + 8)); + } + // DCT128X128 + static const QuantEncodingInternal DCT128X128() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(1.8 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }, + { + V(1.8 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }, + { + V(1.8 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}, + 8)); + } + + // DCT128X64 + static const QuantEncodingInternal DCT64X128() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(1.3 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }, + { + V(1.3 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }, + { + V(1.3 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}, + 8)); + } + // DCT256X256 + static const QuantEncodingInternal DCT256X256() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(3.6 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }, + { + V(3.6 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }, + { + V(3.6 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}, + 8)); + } + + // DCT256X128 + static const QuantEncodingInternal DCT128X256() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{ + V(2.6 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + 
V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }, + { + V(2.6 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }, + { + V(2.6 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}, + 8)); + } +}; +} // namespace + +const DequantMatrices::DequantLibraryInternal DequantMatrices::LibraryInit() { + static_assert(kNum == 17, + "Update this function when adding new quantization kinds."); + static_assert(kNumPredefinedTables == 1, + "Update this function when adding new quantization matrices to " + "the library."); + + // The library and the indices need to be kept in sync manually. + static_assert(0 == DCT, "Update the DequantLibrary array below."); + static_assert(1 == IDENTITY, "Update the DequantLibrary array below."); + static_assert(2 == DCT2X2, "Update the DequantLibrary array below."); + static_assert(3 == DCT4X4, "Update the DequantLibrary array below."); + static_assert(4 == DCT16X16, "Update the DequantLibrary array below."); + static_assert(5 == DCT32X32, "Update the DequantLibrary array below."); + static_assert(6 == DCT8X16, "Update the DequantLibrary array below."); + static_assert(7 == DCT8X32, "Update the DequantLibrary array below."); + static_assert(8 == DCT16X32, "Update the DequantLibrary array below."); + static_assert(9 == DCT4X8, "Update the DequantLibrary array below."); + static_assert(10 == AFV0, "Update the DequantLibrary array below."); + static_assert(11 == DCT64X64, "Update the DequantLibrary array below."); + static_assert(12 == DCT32X64, "Update the DequantLibrary array below."); + static_assert(13 == DCT128X128, "Update the DequantLibrary array below."); + static_assert(14 == DCT64X128, "Update the DequantLibrary array below."); + static_assert(15 == DCT256X256, "Update the DequantLibrary array below."); + static_assert(16 == DCT128X256, "Update the DequantLibrary array below."); + return DequantMatrices::DequantLibraryInternal{ + DequantMatricesLibraryDef::DCT(), + DequantMatricesLibraryDef::IDENTITY(), + DequantMatricesLibraryDef::DCT2X2(), + DequantMatricesLibraryDef::DCT4X4(), + DequantMatricesLibraryDef::DCT16X16(), + DequantMatricesLibraryDef::DCT32X32(), + DequantMatricesLibraryDef::DCT8X16(), + DequantMatricesLibraryDef::DCT8X32(), + DequantMatricesLibraryDef::DCT16X32(), + DequantMatricesLibraryDef::DCT4X8(), + DequantMatricesLibraryDef::AFV0(), + DequantMatricesLibraryDef::DCT64X64(), + DequantMatricesLibraryDef::DCT32X64(), + // Same default for large transforms (128+) as for 64x* transforms. + DequantMatricesLibraryDef::DCT128X128(), + DequantMatricesLibraryDef::DCT64X128(), + DequantMatricesLibraryDef::DCT256X256(), + DequantMatricesLibraryDef::DCT128X256(), + }; +} + +const QuantEncoding* DequantMatrices::Library() { + static const DequantMatrices::DequantLibraryInternal kDequantLibrary = + DequantMatrices::LibraryInit(); + // Downcast the result to a const QuantEncoding* from QuantEncodingInternal* + // since the subclass (QuantEncoding) doesn't add any new members and users + // will need to upcast to QuantEncodingInternal to access the members of that + // class. This allows to have kDequantLibrary as a constexpr value while still + // allowing to create QuantEncoding::RAW() instances that use std::vector in + // C++11. 
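+  // Illustrative use (the call site here is an assumption, not part of this
+  // patch):
+  //   const QuantEncoding* lib = DequantMatrices::Library();
+  //   const QuantEncoding& dct8_default = lib[DequantMatrices::DCT];
+  // With kNumPredefinedTables == 1 there is exactly one entry per table kind.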
+  return reinterpret_cast<const QuantEncoding*>(kDequantLibrary.data());
+}
+
+Status DequantMatrices::Compute() {
+  size_t pos = 0;
+
+  struct DefaultMatrices {
+    DefaultMatrices() {
+      const QuantEncoding* library = Library();
+      size_t pos = 0;
+      for (size_t i = 0; i < kNum; i++) {
+        JXL_CHECK(ComputeQuantTable(library[i], table, inv_table, i,
+                                    QuantTable(i), &pos));
+      }
+      JXL_CHECK(pos == kTotalTableSize);
+    }
+    HWY_ALIGN_MAX float table[kTotalTableSize];
+    HWY_ALIGN_MAX float inv_table[kTotalTableSize];
+  };
+
+  static const DefaultMatrices& default_matrices =
+      *hwy::MakeUniqueAligned<DefaultMatrices>().release();
+
+  JXL_ASSERT(encodings_.size() == kNum);
+
+  bool has_nondefault_matrix = false;
+  for (const auto& enc : encodings_) {
+    if (enc.mode != QuantEncoding::kQuantModeLibrary) {
+      has_nondefault_matrix = true;
+    }
+  }
+  if (has_nondefault_matrix) {
+    table_storage_ = hwy::AllocateAligned<float>(2 * kTotalTableSize);
+    table_ = table_storage_.get();
+    inv_table_ = table_storage_.get() + kTotalTableSize;
+    for (size_t table = 0; table < kNum; table++) {
+      size_t prev_pos = pos;
+      if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) {
+        size_t num = required_size_[table] * kDCTBlockSize;
+        memcpy(table_storage_.get() + prev_pos,
+               default_matrices.table + prev_pos, num * sizeof(float) * 3);
+        memcpy(table_storage_.get() + kTotalTableSize + prev_pos,
+               default_matrices.inv_table + prev_pos, num * sizeof(float) * 3);
+        pos += num * 3;
+      } else {
+        JXL_RETURN_IF_ERROR(
+            ComputeQuantTable(encodings_[table], table_storage_.get(),
+                              table_storage_.get() + kTotalTableSize, table,
+                              QuantTable(table), &pos));
+      }
+    }
+    JXL_ASSERT(pos == kTotalTableSize);
+  } else {
+    table_ = default_matrices.table;
+    inv_table_ = default_matrices.inv_table;
+  }
+
+  return true;
+}
+
+} // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights.h b/third_party/jpeg-xl/lib/jxl/quant_weights.h
new file mode 100644
index 000000000000..1f621b0ce867
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/quant_weights.h
@@ -0,0 +1,478 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_QUANT_WEIGHTS_H_
+#define LIB_JXL_QUANT_WEIGHTS_H_
+
+#include <stdint.h>
+#include <string.h>
+
+#include <array>
+#include <hwy/aligned_allocator.h>
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/cache_aligned.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+template <typename T, size_t N>
+constexpr T ArraySum(T (&a)[N], size_t i = N - 1) {
+  static_assert(N > 0, "Trying to compute the sum of an empty array");
+  return i == 0 ? a[0] : a[i] + ArraySum(a, i - 1);
+}
+
+static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea;
+static constexpr size_t kNumPredefinedTables = 1;
+static constexpr size_t kCeilLog2NumPredefinedTables = 0;
+static constexpr size_t kLog2NumQuantModes = 3;
+
+struct DctQuantWeightParams {
+  static constexpr size_t kLog2MaxDistanceBands = 4;
+  static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands);
+  typedef std::array<std::array<float, kMaxDistanceBands>, 3>
+      DistanceBandsArray;
+
+  size_t num_distance_bands = 0;
+  DistanceBandsArray distance_bands = {};
+
+  constexpr DctQuantWeightParams() : num_distance_bands(0) {}
+
+  constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands,
+                                 size_t num_dist_bands)
+      : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {}
+
+  template <size_t num_dist_bands>
+  explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) {
+    num_distance_bands = num_dist_bands;
+    for (size_t c = 0; c < 3; c++) {
+      memcpy(distance_bands[c].data(), dist_bands[c],
+             sizeof(float) * num_dist_bands);
+    }
+  }
+};
+
+// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding)
+struct QuantEncodingInternal {
+  enum Mode {
+    kQuantModeLibrary,
+    kQuantModeID,
+    kQuantModeDCT2,
+    kQuantModeDCT4,
+    kQuantModeDCT4X8,
+    kQuantModeAFV,
+    kQuantModeDCT,
+    kQuantModeRAW,
+  };
+
+  template <Mode mode>
+  struct Tag {};
+
+  typedef std::array<std::array<float, 3>, 3> IdWeights;
+  typedef std::array<std::array<float, 6>, 3> DCT2Weights;
+  typedef std::array<std::array<float, 2>, 3> DCT4Multipliers;
+  typedef std::array<std::array<float, 9>, 3> AFVWeights;
+  typedef std::array<float, 3> DCT4x8Multipliers;
+
+  static constexpr QuantEncodingInternal Library(uint8_t predefined) {
+    return ((predefined < kNumPredefinedTables) ||
+            JXL_ABORT("Assert predefined < kNumPredefinedTables")),
+           QuantEncodingInternal(Tag<kQuantModeLibrary>(), predefined);
+  }
+  constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */,
+                                  uint8_t predefined)
+      : mode(kQuantModeLibrary), predefined(predefined) {}
+
+  // Identity
+  // xybweights is an array of {xweights, yweights, bweights}.
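+  // For instance (values are illustrative only, not a shipped table):
+  //   QuantEncodingInternal::Identity({{{10, 20, 20}, {5, 8, 8}, {3, 4, 4}}})
+  // passes one weight triple per XYB channel; the first entry covers most
+  // coefficients and the other two are the special-cased low-frequency
+  // weights consumed by GetQuantWeightsIdentity.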
+ static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { + return QuantEncodingInternal(Tag(), xybweights); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const IdWeights& xybweights) + : mode(kQuantModeID), idweights(xybweights) {} + + // DCT2 + static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { + return QuantEncodingInternal(Tag(), xybweights); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DCT2Weights& xybweights) + : mode(kQuantModeDCT2), dct2weights(xybweights) {} + + // DCT4 + static constexpr QuantEncodingInternal DCT4( + const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { + return QuantEncodingInternal(Tag(), params, xybmul); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params, + const DCT4Multipliers& xybmul) + : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} + + // DCT4x8 + static constexpr QuantEncodingInternal DCT4X8( + const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { + return QuantEncodingInternal(Tag(), params, xybmul); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params, + const DCT4x8Multipliers& xybmul) + : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} + + // DCT + static constexpr QuantEncodingInternal DCT( + const DctQuantWeightParams& params) { + return QuantEncodingInternal(Tag(), params); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params) + : mode(kQuantModeDCT), dct_params(params) {} + + // AFV + static constexpr QuantEncodingInternal AFV( + const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, const AFVWeights& weights) { + return QuantEncodingInternal(Tag(), params4x8, params4x4, + weights); + } + constexpr QuantEncodingInternal(Tag /* tag */, + const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, + const AFVWeights& weights) + : mode(kQuantModeAFV), + dct_params(params4x8), + afv_weights(weights), + dct_params_afv_4x4(params4x4) {} + + // This constructor is not constexpr so it can't be used in any of the + // constexpr cases above. + explicit QuantEncodingInternal(Mode mode) : mode(mode) {} + + Mode mode; + + // Weights for DCT4+ tables. + DctQuantWeightParams dct_params; + + union { + // Weights for identity. + IdWeights idweights; + + // Weights for DCT2. + DCT2Weights dct2weights; + + // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. + DCT4Multipliers dct4multipliers; + + // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, + // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + + // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated + // as in GetQuantWeights for DC and are used for other coefficients. + AFVWeights afv_weights = {}; + + // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. + DCT4x8Multipliers dct4x8multipliers; + + // Only used in kQuantModeRAW mode. + struct { + // explicit quantization table (like in JPEG) + std::vector* qtable = nullptr; + float qtable_den = 1.f / (8 * 255); + } qraw; + }; + + // Weights for 4x4 sub-block in AFV. + DctQuantWeightParams dct_params_afv_4x4; + + union { + // Which predefined table to use. Only used if mode is kQuantModeLibrary. + uint8_t predefined = 0; + + // Which other quant table to copy; must copy from a table that comes before + // the current one. 
Only used if mode is kQuantModeCopy.
+    uint8_t source;
+  };
+};
+
+class QuantEncoding final : public QuantEncodingInternal {
+ public:
+  QuantEncoding(const QuantEncoding& other)
+      : QuantEncodingInternal(
+            static_cast<const QuantEncodingInternal&>(other)) {
+    if (mode == kQuantModeRAW && qraw.qtable) {
+      // Need to make a copy of the passed *qtable.
+      qraw.qtable = new std::vector<int>(*other.qraw.qtable);
+    }
+  }
+  QuantEncoding(QuantEncoding&& other) noexcept
+      : QuantEncodingInternal(
+            static_cast<const QuantEncodingInternal&>(other)) {
+    // Steal the qtable from the other object if any.
+    if (mode == kQuantModeRAW) {
+      other.qraw.qtable = nullptr;
+    }
+  }
+  QuantEncoding& operator=(const QuantEncoding& other) {
+    if (mode == kQuantModeRAW && qraw.qtable) {
+      delete qraw.qtable;
+    }
+    *static_cast<QuantEncodingInternal*>(this) =
+        QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other));
+    if (mode == kQuantModeRAW && qraw.qtable) {
+      // Need to make a copy of the passed *qtable.
+      qraw.qtable = new std::vector<int>(*other.qraw.qtable);
+    }
+    return *this;
+  }
+
+  ~QuantEncoding() {
+    if (mode == kQuantModeRAW && qraw.qtable) {
+      delete qraw.qtable;
+    }
+  }
+
+  // Wrappers of the QuantEncodingInternal:: static functions that return a
+  // QuantEncoding instead. This is using the explicit and private cast from
+  // QuantEncodingInternal to QuantEncoding, which would be inlined anyway.
+  // In general, you should use these wrappers. The only reason to directly
+  // create a QuantEncodingInternal instance is if you need a constexpr version
+  // of this class. Note that RAW() is not supported in that case since it uses
+  // a std::vector.
+  static QuantEncoding Library(uint8_t predefined) {
+    return QuantEncoding(QuantEncodingInternal::Library(predefined));
+  }
+  static QuantEncoding Identity(const IdWeights& xybweights) {
+    return QuantEncoding(QuantEncodingInternal::Identity(xybweights));
+  }
+  static QuantEncoding DCT2(const DCT2Weights& xybweights) {
+    return QuantEncoding(QuantEncodingInternal::DCT2(xybweights));
+  }
+  static QuantEncoding DCT4(const DctQuantWeightParams& params,
+                            const DCT4Multipliers& xybmul) {
+    return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul));
+  }
+  static QuantEncoding DCT4X8(const DctQuantWeightParams& params,
+                              const DCT4x8Multipliers& xybmul) {
+    return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul));
+  }
+  static QuantEncoding DCT(const DctQuantWeightParams& params) {
+    return QuantEncoding(QuantEncodingInternal::DCT(params));
+  }
+  static QuantEncoding AFV(const DctQuantWeightParams& params4x8,
+                           const DctQuantWeightParams& params4x4,
+                           const AFVWeights& weights) {
+    return QuantEncoding(
+        QuantEncodingInternal::AFV(params4x8, params4x4, weights));
+  }
+
+  // RAW, note that this one is not a constexpr one.
+  static QuantEncoding RAW(const std::vector<int>& qtable, int shift = 0) {
+    QuantEncoding encoding(kQuantModeRAW);
+    encoding.qraw.qtable = new std::vector<int>();
+    *encoding.qraw.qtable = qtable;
+    encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255));
+    return encoding;
+  }
+
+ private:
+  explicit QuantEncoding(const QuantEncodingInternal& other)
+      : QuantEncodingInternal(other) {}
+
+  explicit QuantEncoding(QuantEncodingInternal::Mode mode)
+      : QuantEncodingInternal(mode) {}
+};
+
+// A constexpr QuantEncodingInternal instance is often downcasted to the
+// QuantEncoding subclass even if the instance wasn't an instance of the
+// subclass. This is safe because the user will upcast to QuantEncodingInternal
+// to access any of its members.
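+// (Hence the static_assert just below: the reinterpret_cast in Library() is
+// only sound while QuantEncoding adds no data members of its own.)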
+static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), + "Don't add any members to QuantEncoding"); + +// Let's try to keep these 2**N for possible future simplicity. +const float kInvDCQuant[3] = { + 4096.0f, + 512.0f, + 256.0f, +}; + +const float kDCQuant[3] = { + 1.0f / kInvDCQuant[0], + 1.0f / kInvDCQuant[1], + 1.0f / kInvDCQuant[2], +}; + +class ModularFrameEncoder; +class ModularFrameDecoder; + +class DequantMatrices { + public: + enum QuantTable : size_t { + DCT = 0, + IDENTITY, + DCT2X2, + DCT4X4, + DCT16X16, + DCT32X32, + // DCT16X8 + DCT8X16, + // DCT32X8 + DCT8X32, + // DCT32X16 + DCT16X32, + DCT4X8, + // DCT8X4 + AFV0, + // AFV1 + // AFV2 + // AFV3 + DCT64X64, + // DCT64X32, + DCT32X64, + DCT128X128, + // DCT128X64, + DCT64X128, + DCT256X256, + // DCT256X128, + DCT128X256, + kNum + }; + + static constexpr QuantTable kQuantTable[] = { + QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, + QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, + QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, + QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, + QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, + QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, + QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, + QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, + QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, + }; + static_assert(AcStrategy::kNumValidStrategies == + sizeof(kQuantTable) / sizeof *kQuantTable, + "Update this array when adding or removing AC strategies."); + + DequantMatrices() { + encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0)); + size_t pos = 0; + size_t offsets[kNum * 3]; + for (size_t i = 0; i < size_t(QuantTable::kNum); i++) { + encodings_[i] = QuantEncoding::Library(0); + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; + } + pos += 3 * num; + } + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + for (size_t c = 0; c < 3; c++) { + table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; + } + } + // Default quantization tables need to be valid. + JXL_CHECK(Compute()); + } + + static const QuantEncoding* Library(); + + typedef std::array + DequantLibraryInternal; + // Return the array of library kNumPredefinedTables QuantEncoding entries as + // a constexpr array. Use Library() to obtain a pointer to the copy in the + // .cc file. + static const DequantLibraryInternal LibraryInit(); + + JXL_INLINE size_t MatrixOffset(size_t quant_kind, size_t c) const { + JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); + return table_offsets_[quant_kind * 3 + c]; + } + + // Returns aligned memory. + JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const { + JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); + return &table_[MatrixOffset(quant_kind, c)]; + } + + JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const { + JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); + return &inv_table_[MatrixOffset(quant_kind, c)]; + } + + // DC quants are used in modular mode for XYB multipliers. + JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } + JXL_INLINE const float* DCQuants() const { return dc_quant_; } + + JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } + + // For encoder. 
+ void SetEncodings(const std::vector& encodings) { + encodings_ = encodings; + } + + // For encoder. + void SetDCQuant(const float dc[3]) { + for (size_t c = 0; c < 3; c++) { + dc_quant_[c] = 1.0f / dc[c]; + inv_dc_quant_[c] = dc[c]; + } + } + + Status Decode(BitReader* br, + ModularFrameDecoder* modular_frame_decoder = nullptr); + Status DecodeDC(BitReader* br); + + const std::vector& encodings() const { return encodings_; } + + static constexpr size_t required_size_x[] = {1, 1, 1, 1, 2, 4, 1, 1, 2, + 1, 1, 8, 4, 16, 8, 32, 16}; + static_assert(kNum == sizeof(required_size_x) / sizeof(*required_size_x), + "Update this array when adding or removing quant tables."); + + static constexpr size_t required_size_y[] = {1, 1, 1, 1, 2, 4, 2, 4, 4, + 1, 1, 8, 8, 16, 16, 32, 32}; + static_assert(kNum == sizeof(required_size_y) / sizeof(*required_size_y), + "Update this array when adding or removing quant tables."); + + private: + Status Compute(); + + static constexpr size_t required_size_[] = { + 1, 1, 1, 1, 4, 16, 2, 4, 8, 1, 1, 64, 32, 256, 128, 1024, 512}; + static_assert(kNum == sizeof(required_size_) / sizeof(*required_size_), + "Update this array when adding or removing quant tables."); + static constexpr size_t kTotalTableSize = + ArraySum(required_size_) * kDCTBlockSize * 3; + + // kTotalTableSize entries followed by kTotalTableSize for inv_table + hwy::AlignedFreeUniquePtr table_storage_; + const float* table_; + const float* inv_table_; + float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; + float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; + size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; + std::vector encodings_; +}; + +} // namespace jxl + +#endif // LIB_JXL_QUANT_WEIGHTS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc b/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc new file mode 100644 index 000000000000..4ab149770b07 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc @@ -0,0 +1,249 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "lib/jxl/quant_weights.h" + +#include + +#include +#include +#include // HWY_ALIGN_MAX +#include +#include +#include + +#include "lib/jxl/dct_for_test.h" +#include "lib/jxl/dec_transforms_testonly.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/enc_transforms.h" + +namespace jxl { +namespace { + +template +void CheckSimilar(T a, T b) { + EXPECT_EQ(a, b); +} +// minimum exponent = -15. +template <> +void CheckSimilar(float a, float b) { + float m = std::max(std::abs(a), std::abs(b)); + // 10 bits of precision are used in the format. Relative error should be + // below 2^-10. 
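+// (2^-10 = 1/1024 is where the m / 1024.0f bound below comes from: a value
+// such as 3150.0f may legitimately come back from the 10-bit-mantissa half
+// float off by up to about 3150 / 1024 ~= 3.1.)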
+  EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b;
+}
+
+TEST(QuantWeightsTest, DC) {
+  DequantMatrices mat;
+  float dc_quant[3] = {1e+5, 1e+3, 1e+1};
+  DequantMatricesSetCustomDC(&mat, dc_quant);
+  for (size_t c = 0; c < 3; c++) {
+    CheckSimilar(mat.InvDCQuant(c), dc_quant[c]);
+  }
+}
+
+void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) {
+  ASSERT_TRUE(encodings.size() == DequantMatrices::kNum);
+  DequantMatrices mat;
+  CodecMetadata metadata;
+  FrameHeader frame_header(&metadata);
+  ModularFrameEncoder encoder(frame_header, CompressParams{});
+  DequantMatricesSetCustom(&mat, encodings, &encoder);
+  const std::vector<QuantEncoding>& encodings_dec = mat.encodings();
+  for (size_t i = 0; i < encodings.size(); i++) {
+    const QuantEncoding& e = encodings[i];
+    const QuantEncoding& d = encodings_dec[i];
+    // Check values roundtripped correctly.
+    EXPECT_EQ(e.mode, d.mode);
+    EXPECT_EQ(e.predefined, d.predefined);
+    EXPECT_EQ(e.source, d.source);
+
+    EXPECT_EQ(static_cast<size_t>(e.dct_params.num_distance_bands),
+              static_cast<size_t>(d.dct_params.num_distance_bands));
+    for (size_t c = 0; c < 3; c++) {
+      for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
+        CheckSimilar(e.dct_params.distance_bands[c][j],
+                     d.dct_params.distance_bands[c][j]);
+      }
+    }
+
+    if (e.mode == QuantEncoding::kQuantModeRAW) {
+      EXPECT_FALSE(!e.qraw.qtable);
+      EXPECT_FALSE(!d.qraw.qtable);
+      EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size());
+      for (size_t j = 0; j < e.qraw.qtable->size(); j++) {
+        EXPECT_EQ((*e.qraw.qtable)[j], (*d.qraw.qtable)[j]);
+      }
+      EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f);
+    } else {
+      // Modes other than kQuantModeRAW use one of the other fields used
+      // here, which all happen to be arrays of floats.
+      for (size_t c = 0; c < 3; c++) {
+        for (size_t j = 0; j < 3; j++) {
+          CheckSimilar(e.idweights[c][j], d.idweights[c][j]);
+        }
+        for (size_t j = 0; j < 6; j++) {
+          CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]);
+        }
+        for (size_t j = 0; j < 2; j++) {
+          CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]);
+        }
+        CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]);
+        for (size_t j = 0; j < 9; j++) {
+          CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]);
+        }
+        for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
+          CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j],
+                       d.dct_params_afv_4x4.distance_bands[c][j]);
+        }
+      }
+    }
+  }
+}
+
+TEST(QuantWeightsTest, AllDefault) {
+  std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
+                                       QuantEncoding::Library(0));
+  RoundtripMatrices(encodings);
+}
+
+void TestSingleQuantMatrix(DequantMatrices::QuantTable kind) {
+  std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
+                                       QuantEncoding::Library(0));
+  encodings[kind] = DequantMatrices::Library()[kind];
+  RoundtripMatrices(encodings);
+}
+
+// Ensure we can reasonably represent default quant tables.
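+// Each test below feeds exactly one library table through
+// RoundtripMatrices(), so a failure pinpoints the table kind whose encoded
+// parameters drift; e.g. TestSingleQuantMatrix(DequantMatrices::DCT)
+// exercises the default 8x8 DCT table.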
+TEST(QuantWeightsTest, DCT) { TestSingleQuantMatrix(DequantMatrices::DCT); } +TEST(QuantWeightsTest, IDENTITY) { + TestSingleQuantMatrix(DequantMatrices::IDENTITY); +} +TEST(QuantWeightsTest, DCT2X2) { + TestSingleQuantMatrix(DequantMatrices::DCT2X2); +} +TEST(QuantWeightsTest, DCT4X4) { + TestSingleQuantMatrix(DequantMatrices::DCT4X4); +} +TEST(QuantWeightsTest, DCT16X16) { + TestSingleQuantMatrix(DequantMatrices::DCT16X16); +} +TEST(QuantWeightsTest, DCT32X32) { + TestSingleQuantMatrix(DequantMatrices::DCT32X32); +} +TEST(QuantWeightsTest, DCT8X16) { + TestSingleQuantMatrix(DequantMatrices::DCT8X16); +} +TEST(QuantWeightsTest, DCT8X32) { + TestSingleQuantMatrix(DequantMatrices::DCT8X32); +} +TEST(QuantWeightsTest, DCT16X32) { + TestSingleQuantMatrix(DequantMatrices::DCT16X32); +} +TEST(QuantWeightsTest, DCT4X8) { + TestSingleQuantMatrix(DequantMatrices::DCT4X8); +} +TEST(QuantWeightsTest, AFV0) { TestSingleQuantMatrix(DequantMatrices::AFV0); } +TEST(QuantWeightsTest, RAW) { + std::vector encodings(DequantMatrices::kNum, + QuantEncoding::Library(0)); + std::vector matrix(3 * 32 * 32); + std::mt19937 rng; + std::uniform_int_distribution dist(1, 255); + for (size_t i = 0; i < matrix.size(); i++) matrix[i] = dist(rng); + encodings[DequantMatrices::kQuantTable[AcStrategy::DCT32X32]] = + QuantEncoding::RAW(matrix, 2); + RoundtripMatrices(encodings); +} + +class QuantWeightsTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest); + +TEST_P(QuantWeightsTargetTest, DCTUniform) { + constexpr float kUniformQuant = 4; + float weights[3][2] = {{1.0f / kUniformQuant, 0}, + {1.0f / kUniformQuant, 0}, + {1.0f / kUniformQuant, 0}}; + DctQuantWeightParams dct_params(weights); + std::vector encodings(DequantMatrices::kNum, + QuantEncoding::DCT(dct_params)); + DequantMatrices dequant_matrices; + CodecMetadata metadata; + FrameHeader frame_header(&metadata); + ModularFrameEncoder encoder(frame_header, CompressParams{}); + DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder); + + const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, + 1.0f / kUniformQuant}; + DequantMatricesSetCustomDC(&dequant_matrices, dc_quant); + + HWY_ALIGN_MAX float scratch_space[16 * 16 * 2]; + + // DCT8 + { + HWY_ALIGN_MAX float pixels[64]; + std::iota(std::begin(pixels), std::end(pixels), 0); + HWY_ALIGN_MAX float coeffs[64]; + const AcStrategy::Type dct = AcStrategy::DCT; + TransformFromPixels(dct, pixels, 8, coeffs, scratch_space); + HWY_ALIGN_MAX double slow_coeffs[64]; + for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i]; + DCTSlow<8>(slow_coeffs); + + for (size_t i = 0; i < 64; i++) { + // DCTSlow doesn't multiply/divide by 1/N, so we do it manually. 
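+      // (Both branches then apply quantize-then-dequantize,
+      // round(x / step) * step, so the comparison after the inverse
+      // transforms checks that Matrix(dct, 0) really holds the uniform step.)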
+ slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; + coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * + dequant_matrices.Matrix(dct, 0)[i]; + } + IDCTSlow<8>(slow_coeffs); + TransformToPixels(dct, coeffs, pixels, 8, scratch_space); + for (size_t i = 0; i < 64; i++) { + EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); + } + } + + // DCT16 + { + HWY_ALIGN_MAX float pixels[64 * 4]; + std::iota(std::begin(pixels), std::end(pixels), 0); + HWY_ALIGN_MAX float coeffs[64 * 4]; + const AcStrategy::Type dct = AcStrategy::DCT16X16; + TransformFromPixels(dct, pixels, 16, coeffs, scratch_space); + HWY_ALIGN_MAX double slow_coeffs[64 * 4]; + for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i]; + DCTSlow<16>(slow_coeffs); + + for (size_t i = 0; i < 64 * 4; i++) { + slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; + coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * + dequant_matrices.Matrix(dct, 0)[i]; + } + + IDCTSlow<16>(slow_coeffs); + TransformToPixels(dct, coeffs, pixels, 16, scratch_space); + for (size_t i = 0; i < 64 * 4; i++) { + EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); + } + } + + // Check that all matrices have the same DC quantization, i.e. that they all + // have the same scaling. + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + EXPECT_NEAR(dequant_matrices.Matrix(i, 0)[0], kUniformQuant, 1e-6); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/quantizer-inl.h b/third_party/jpeg-xl/lib/jxl/quantizer-inl.h new file mode 100644 index 000000000000..47b8e7218f4c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quantizer-inl.h @@ -0,0 +1,82 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(LIB_JXL_QUANTIZER_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_QUANTIZER_INL_H_ +#undef LIB_JXL_QUANTIZER_INL_H_ +#else +#define LIB_JXL_QUANTIZER_INL_H_ +#endif + +#include + +#include +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Vec; + +template +HWY_INLINE HWY_MAYBE_UNUSED Vec> AdjustQuantBias( + DI di, const size_t c, const Vec quant_i, + const float* HWY_RESTRICT biases) { + const Rebind df; + +#if JXL_HIGH_PRECISION + const auto quant = ConvertTo(df, quant_i); + + // Compare |quant|, keep sign bit for negating result. + const auto kSign = BitCast(df, Set(di, INT32_MIN)); + const auto sign = And(quant, kSign); // TODO(janwas): = abs ^ orig + const auto abs_quant = AndNot(kSign, quant); + + // If |x| is 1, kZeroBias creates a different bias for each channel. 
+ // We're implementing the following: + // if (quant == 0) return 0; + // if (quant == 1) return biases[c]; + // if (quant == -1) return -biases[c]; + // return quant - biases[3] / quant; + + // Integer comparison is not helpful because Clang incurs bypass penalties + // from unnecessarily mixing integer and float. + const auto is_01 = abs_quant < Set(df, 1.125f); + const auto not_0 = abs_quant > Zero(df); + + // Bitwise logic is faster than quant * biases[c]. + const auto one_bias = IfThenElseZero(not_0, Xor(Set(df, biases[c]), sign)); + + // About 2E-5 worse than ReciprocalNR or division. + const auto bias = + NegMulAdd(Set(df, biases[3]), ApproximateReciprocal(quant), quant); + + return IfThenElse(is_01, one_bias, bias); +#else + auto sign = IfThenElseZero(quant_i < Zero(di), Set(di, INT32_MIN)); + return BitCast(df, IfThenElse(Abs(quant_i) == Set(di, 1), + sign | BitCast(di, Set(df, biases[c])), + BitCast(di, ConvertTo(df, quant_i)))); +#endif +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_QUANTIZER_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/quantizer.cc b/third_party/jpeg-xl/lib/jxl/quantizer.cc new file mode 100644 index 000000000000..36d12f351b01 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quantizer.cc @@ -0,0 +1,159 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lib/jxl/quantizer.h" + +#include +#include + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/robust_statistics.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +static const int kDefaultQuant = 64; + +constexpr int Quantizer::kQuantMax; + +Quantizer::Quantizer(const DequantMatrices* dequant) + : Quantizer(dequant, kDefaultQuant, kGlobalScaleDenom / kDefaultQuant) {} + +Quantizer::Quantizer(const DequantMatrices* dequant, int quant_dc, + int global_scale) + : global_scale_(global_scale), quant_dc_(quant_dc), dequant_(dequant) { + JXL_ASSERT(dequant_ != nullptr); + RecomputeFromGlobalScale(); + inv_quant_dc_ = inv_global_scale_ / quant_dc_; + + memcpy(zero_bias_, kZeroBiasDefault, sizeof(kZeroBiasDefault)); +} + +void Quantizer::ComputeGlobalScaleAndQuant(float quant_dc, float quant_median, + float quant_median_absd) { + // Target value for the median value in the quant field. + const float kQuantFieldTarget = 3.80987740592518214386f; + // We reduce the median of the quant field by the median absolute deviation: + // higher resolution on highly varying quant fields. + float scale = kGlobalScaleDenom * (quant_median - quant_median_absd) / + kQuantFieldTarget; + // Ensure that new_global_scale is positive and no more than 1<<15. 
+ if (scale < 1) scale = 1; + if (scale > (1 << 15)) scale = 1 << 15; + int new_global_scale = static_cast(scale); + // Ensure that quant_dc_ will always be at least + // kGlobalScaleDenom/kGlobalScaleNumerator. + const int scaled_quant_dc = + static_cast(quant_dc * kGlobalScaleNumerator); + if (new_global_scale > scaled_quant_dc) { + new_global_scale = scaled_quant_dc; + if (new_global_scale <= 0) new_global_scale = 1; + } + global_scale_ = new_global_scale; + // Code below uses inv_global_scale_. + RecomputeFromGlobalScale(); + + float fval = quant_dc * inv_global_scale_ + 0.5f; + fval = std::min(1 << 16, fval); + const int new_quant_dc = static_cast(fval); + quant_dc_ = new_quant_dc; + + // quant_dc_ was updated, recompute values. + RecomputeFromGlobalScale(); +} + +void Quantizer::SetQuantFieldRect(const ImageF& qf, const Rect& rect, + ImageI* JXL_RESTRICT raw_quant_field) { + for (size_t y = 0; y < rect.ysize(); ++y) { + const float* JXL_RESTRICT row_qf = rect.ConstRow(qf, y); + int32_t* JXL_RESTRICT row_qi = rect.Row(raw_quant_field, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + int val = ClampVal(row_qf[x] * inv_global_scale_ + 0.5f); + row_qi[x] = val; + } + } +} + +void Quantizer::SetQuantField(const float quant_dc, const ImageF& qf, + ImageI* JXL_RESTRICT raw_quant_field) { + JXL_CHECK(SameSize(*raw_quant_field, qf)); + std::vector data(qf.xsize() * qf.ysize()); + for (size_t y = 0; y < qf.ysize(); ++y) { + const float* JXL_RESTRICT row_qf = qf.Row(y); + for (size_t x = 0; x < qf.xsize(); ++x) { + float quant = row_qf[x]; + data[qf.xsize() * y + x] = quant; + } + } + const float quant_median = Median(&data); + const float quant_median_absd = MedianAbsoluteDeviation(data, quant_median); + ComputeGlobalScaleAndQuant(quant_dc, quant_median, quant_median_absd); + SetQuantFieldRect(qf, Rect(qf), raw_quant_field); +} + +void Quantizer::SetQuant(float quant_dc, float quant_ac, + ImageI* JXL_RESTRICT raw_quant_field) { + ComputeGlobalScaleAndQuant(quant_dc, quant_ac, 0); + int val = ClampVal(quant_ac * inv_global_scale_ + 0.5f); + FillImage(val, raw_quant_field); +} + +Status QuantizerParams::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + BitsOffset(11, 1), BitsOffset(11, 2049), BitsOffset(12, 4097), + BitsOffset(16, 8193), 1, &global_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), BitsOffset(5, 1), + BitsOffset(8, 1), BitsOffset(16, 1), 1, + &quant_dc)); + return true; +} + +Status Quantizer::Encode(BitWriter* writer, size_t layer, + AuxOut* aux_out) const { + QuantizerParams params; + params.global_scale = global_scale_; + params.quant_dc = quant_dc_; + return Bundle::Write(params, writer, layer, aux_out); +} + +Status Quantizer::Decode(BitReader* reader) { + QuantizerParams params; + JXL_RETURN_IF_ERROR(Bundle::Read(reader, ¶ms)); + global_scale_ = static_cast(params.global_scale); + quant_dc_ = static_cast(params.quant_dc); + RecomputeFromGlobalScale(); + return true; +} + +void Quantizer::DumpQuantizationMap(const ImageI& raw_quant_field) const { + printf("Global scale: %d (%.7f)\nDC quant: %d\n", global_scale_, + global_scale_ * 1.0 / kGlobalScaleDenom, quant_dc_); + printf("AC quantization Map:\n"); + for (size_t y = 0; y < raw_quant_field.ysize(); ++y) { + for (size_t x = 0; x < raw_quant_field.xsize(); ++x) { + printf(" %3d", raw_quant_field.Row(y)[x]); + } + printf("\n"); + } +} + +static constexpr JXL_INLINE int QuantizeValue(float value, float inv_step) { + return static_cast(value * inv_step + (value >= 0 ? 
.5f : -.5f)); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/quantizer.h b/third_party/jpeg-xl/lib/jxl/quantizer.h new file mode 100644 index 000000000000..c01b56ccffdb --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quantizer.h @@ -0,0 +1,187 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIB_JXL_QUANTIZER_H_ +#define LIB_JXL_QUANTIZER_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/aux_out_fwd.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/linalg.h" +#include "lib/jxl/quant_weights.h" + +// Quantizes DC and AC coefficients, with separate quantization tables according +// to the quant_kind (which is currently computed from the AC strategy and the +// block index inside that strategy). + +namespace jxl { + +static constexpr int kGlobalScaleDenom = 1 << 16; +static constexpr int kGlobalScaleNumerator = 4096; + +// zero-biases for quantizing channels X, Y, B +static constexpr float kZeroBiasDefault[3] = {0.5f, 0.5f, 0.5f}; + +// Returns adjusted version of a quantized integer, such that its value is +// closer to the expected value of the original. +// The residuals of AC coefficients that we quantize are not uniformly +// distributed. Numerical experiments show that they have a distribution with +// the "shape" of 1/(1+x^2) [up to some coefficients]. This means that the +// expected value of a coefficient that gets quantized to x will not be x +// itself, but (at least with reasonable approximation): +// - 0 if x is 0 +// - x * biases[c] if x is 1 or -1 +// - x - biases[3]/x otherwise +// This follows from computing the distribution of the quantization bias, which +// can be approximated fairly well by /x when |x| is at least two. +static constexpr float kBiasNumerator = 0.145f; + +static constexpr float kDefaultQuantBias[4] = { + 1.0f - 0.05465007330715401f, + 1.0f - 0.07005449891748593f, + 1.0f - 0.049935103337343655f, + 0.145f, +}; + +class Quantizer { + public: + explicit Quantizer(const DequantMatrices* dequant); + Quantizer(const DequantMatrices* dequant, int quant_dc, int global_scale); + + static constexpr int kQuantMax = 256; + + static JXL_INLINE int ClampVal(float val) { + return static_cast(std::max(1.0f, std::min(val, kQuantMax))); + } + + // Recomputes other derived fields after global_scale_ has changed. 
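+  // (Worked numbers: global_scale_ = 32768 with kGlobalScaleDenom = 1 << 16
+  // yields global_scale_float_ = 0.5 and inv_global_scale_ = 2.0; with
+  // quant_dc_ = 64 that makes inv_quant_dc_ = 2.0 / 64 = 0.03125.)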
+ void RecomputeFromGlobalScale() { + global_scale_float_ = global_scale_ * (1.0 / kGlobalScaleDenom); + inv_global_scale_ = 1.0 * kGlobalScaleDenom / global_scale_; + inv_quant_dc_ = inv_global_scale_ / quant_dc_; + for (size_t c = 0; c < 3; c++) { + mul_dc_[c] = GetDcStep(c); + inv_mul_dc_[c] = GetInvDcStep(c); + } + } + + // Returns scaling factor such that Scale() * (RawDC() or RawQuantField()) + // pixels yields the same float values returned by GetQuantField. + JXL_INLINE float Scale() const { return global_scale_float_; } + + // Reciprocal of Scale(). + JXL_INLINE float InvGlobalScale() const { return inv_global_scale_; } + + void SetQuantFieldRect(const ImageF& qf, const Rect& rect, + ImageI* JXL_RESTRICT raw_quant_field); + + void SetQuantField(float quant_dc, const ImageF& qf, + ImageI* JXL_RESTRICT raw_quant_field); + + void SetQuant(float quant_dc, float quant_ac, + ImageI* JXL_RESTRICT raw_quant_field); + + // Returns the DC quantization base value, which is currently global (not + // adaptive). The actual scale factor used to dequantize pixels in channel c + // is: inv_quant_dc() * dequant_->DCQuant(c). + float inv_quant_dc() const { return inv_quant_dc_; } + + // Dequantize by multiplying with this times dequant_matrix. + float inv_quant_ac(int32_t quant) const { return inv_global_scale_ / quant; } + + Status Encode(BitWriter* writer, size_t layer, AuxOut* aux_out) const; + + Status Decode(BitReader* reader); + + void DumpQuantizationMap(const ImageI& raw_quant_field) const; + + JXL_INLINE const float* DequantMatrix(size_t quant_kind, size_t c) const { + return dequant_->Matrix(quant_kind, c); + } + + JXL_INLINE const float* InvDequantMatrix(size_t quant_kind, size_t c) const { + return dequant_->InvMatrix(quant_kind, c); + } + + JXL_INLINE size_t DequantMatrixOffset(size_t quant_kind, size_t c) const { + return dequant_->MatrixOffset(quant_kind, c); + } + + // Calculates DC quantization step. 
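+  // (E.g. with inv_quant_dc_ = 2.0f and dequant_->DCQuant(1) = 1 / 512.0f,
+  // the Y-channel DC step is 2 / 512 = 0.00390625.)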
+  JXL_INLINE float GetDcStep(size_t c) const {
+    return inv_quant_dc_ * dequant_->DCQuant(c);
+  }
+  JXL_INLINE float GetInvDcStep(size_t c) const {
+    return dequant_->InvDCQuant(c) * (global_scale_float_ * quant_dc_);
+  }
+
+  JXL_INLINE const float* MulDC() const { return mul_dc_; }
+  JXL_INLINE const float* InvMulDC() const { return inv_mul_dc_; }
+
+  JXL_INLINE void ClearDCMul() {
+    std::fill(mul_dc_, mul_dc_ + 4, 1);
+    std::fill(inv_mul_dc_, inv_mul_dc_ + 4, 1);
+  }
+
+  void ComputeGlobalScaleAndQuant(float quant_dc, float quant_median,
+                                  float quant_median_absd);
+
+ private:
+  float mul_dc_[4];
+  float inv_mul_dc_[4];
+
+  // These are serialized:
+  int global_scale_;
+  int quant_dc_;
+
+  // These are derived from global_scale_:
+  float inv_global_scale_;
+  float global_scale_float_;  // reciprocal of inv_global_scale_
+  float inv_quant_dc_;
+
+  float zero_bias_[3];
+  const DequantMatrices* dequant_;
+};
+
+struct QuantizerParams : public Fields {
+  QuantizerParams() { Bundle::Init(this); }
+  const char* Name() const override { return "QuantizerParams"; }
+
+  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
+
+  uint32_t global_scale;
+  uint32_t quant_dc;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_QUANTIZER_H_
diff --git a/third_party/jpeg-xl/lib/jxl/quantizer_test.cc b/third_party/jpeg-xl/lib/jxl/quantizer_test.cc
new file mode 100644
index 000000000000..34bd48d21a27
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/quantizer_test.cc
@@ -0,0 +1,91 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
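For reference before the unit tests: the bias rule documented in quantizer.h above becomes concrete with a small standalone sketch. AdjustQuantBias below is a hypothetical re-implementation of the three cases from that comment (it is not part of this patch); the constants are copied from kDefaultQuantBias.

// Standalone illustration of the bias rule described in quantizer.h above;
// not part of the vendored sources. The function name is hypothetical.
#include <cstdio>

static constexpr float kBias[4] = {1.0f - 0.05465007330715401f,
                                   1.0f - 0.07005449891748593f,
                                   1.0f - 0.049935103337343655f, 0.145f};

// Expected value of a coefficient quantized to x in channel c:
// 0 when x == 0, x * kBias[c] when |x| == 1, x - kBias[3]/x otherwise.
float AdjustQuantBias(int c, float x) {
  if (x == 0.0f) return 0.0f;
  if (x == 1.0f || x == -1.0f) return x * kBias[c];
  return x - kBias[3] / x;
}

int main() {
  // Channel Y (c = 1): 0 stays 0, 1 shrinks to ~0.9299, 2 becomes 1.9275.
  std::printf("%f %f %f\n", AdjustQuantBias(1, 0.0f), AdjustQuantBias(1, 1.0f),
              AdjustQuantBias(1, 2.0f));
}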
+
+#include "lib/jxl/quantizer.h"
+
+#include <random>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/image_test_utils.h"
+
+namespace jxl {
+namespace {
+
+void TestEquivalence(int qxsize, int qysize, const Quantizer& quantizer1,
+                     const Quantizer& quantizer2) {
+  ASSERT_NEAR(quantizer1.inv_quant_dc(), quantizer2.inv_quant_dc(), 1e-7);
+}
+
+TEST(QuantizerTest, QuantizerParams) {
+  for (uint32_t i = 1; i < 10000; ++i) {
+    QuantizerParams p;
+    p.global_scale = i;
+    size_t extension_bits = 0, total_bits = 0;
+    EXPECT_TRUE(Bundle::CanEncode(p, &extension_bits, &total_bits));
+    EXPECT_EQ(0, extension_bits);
+    EXPECT_GE(total_bits, 4);
+  }
+}
+
+TEST(QuantizerTest, BitStreamRoundtripSameQuant) {
+  const int qxsize = 8;
+  const int qysize = 8;
+  DequantMatrices dequant;
+  Quantizer quantizer1(&dequant);
+  ImageI raw_quant_field(qxsize, qysize);
+  quantizer1.SetQuant(0.17f, 0.17f, &raw_quant_field);
+  BitWriter writer;
+  EXPECT_TRUE(quantizer1.Encode(&writer, 0, nullptr));
+  writer.ZeroPadToByte();
+  const size_t bits_written = writer.BitsWritten();
+  Quantizer quantizer2(&dequant, qxsize, qysize);
+  BitReader reader(writer.GetSpan());
+  EXPECT_TRUE(quantizer2.Decode(&reader));
+  EXPECT_TRUE(reader.JumpToByteBoundary());
+  EXPECT_EQ(reader.TotalBitsConsumed(), bits_written);
+  EXPECT_TRUE(reader.Close());
+  TestEquivalence(qxsize, qysize, quantizer1, quantizer2);
+}
+
+TEST(QuantizerTest, BitStreamRoundtripRandomQuant) {
+  const int qxsize = 8;
+  const int qysize = 8;
+  DequantMatrices dequant;
+  Quantizer quantizer1(&dequant);
+  ImageI raw_quant_field(qxsize, qysize);
+  quantizer1.SetQuant(0.17f, 0.17f, &raw_quant_field);
+  std::mt19937_64 rng;
+  std::uniform_int_distribution<> uniform(1, 256);
+  float quant_dc = 0.17f;
+  ImageF qf(qxsize, qysize);
+  RandomFillImage(&qf, 1.0f);
+  quantizer1.SetQuantField(quant_dc, qf, &raw_quant_field);
+  BitWriter writer;
+  EXPECT_TRUE(quantizer1.Encode(&writer, 0, nullptr));
+  writer.ZeroPadToByte();
+  const size_t bits_written = writer.BitsWritten();
+  Quantizer quantizer2(&dequant, qxsize, qysize);
+  BitReader reader(writer.GetSpan());
+  EXPECT_TRUE(quantizer2.Decode(&reader));
+  EXPECT_TRUE(reader.JumpToByteBoundary());
+  EXPECT_EQ(reader.TotalBitsConsumed(), bits_written);
+  EXPECT_TRUE(reader.Close());
+  TestEquivalence(qxsize, qysize, quantizer1, quantizer2);
+}
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/rational_polynomial-inl.h b/third_party/jpeg-xl/lib/jxl/rational_polynomial-inl.h
new file mode 100644
index 000000000000..c2842acd788c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/rational_polynomial-inl.h
@@ -0,0 +1,103 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Fast SIMD evaluation of rational polynomials for approximating functions.
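Before the SIMD implementation that follows, it may help to see the scalar computation it vectorizes: evaluate numerator and denominator polynomials with Horner's rule, then divide. This sketch is illustrative only and not part of the patch; it ignores the 4x-replicated coefficient layout that the LoadDup128-based version consumes (one coefficient per degree here, highest degree last).

// Scalar model of EvalRationalPolynomial below; illustrative only.
template <size_t NP, size_t NQ>
float EvalRationalScalar(float x, const float (&p)[NP], const float (&q)[NQ]) {
  float yp = p[NP - 1];  // Horner for the numerator P
  float yq = q[NQ - 1];  // Horner for the denominator Q
  for (size_t i = NP - 1; i-- > 0;) yp = yp * x + p[i];
  for (size_t i = NQ - 1; i-- > 0;) yq = yq * x + q[i];
  return yp / yq;  // the SIMD code may replace this with a fast reciprocal
}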
+
+#if defined(LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_
+#undef LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_
+#else
+#define LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_
+#endif
+
+#include <stddef.h>
+
+#include <hwy/highway.h>
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+// Primary template: default to actual division.
+template <typename T, class V>
+struct FastDivision {
+  HWY_INLINE V operator()(const V n, const V d) const { return n / d; }
+};
+// Partial specialization for float vectors.
+template <class V>
+struct FastDivision<float, V> {
+  // One Newton-Raphson iteration.
+  static HWY_INLINE V ReciprocalNR(const V x) {
+    const auto rcp = ApproximateReciprocal(x);
+    const auto sum = rcp + rcp;
+    const auto x_rcp = x * rcp;
+    return NegMulAdd(x_rcp, rcp, sum);
+  }
+
+  V operator()(const V n, const V d) const {
+#if 1  // Faster on SKX
+    return n / d;
+#else
+    return n * ReciprocalNR(d);
+#endif
+  }
+};
+
+// Approximates smooth functions via rational polynomials (i.e. dividing two
+// polynomials). Evaluates polynomials via Horner's scheme, which is faster than
+// Clenshaw recurrence for Chebyshev polynomials. LoadDup128 allows us to
+// specify constants (replicated 4x) independently of the lane count.
+template <size_t NP, size_t NQ, class D, class V, typename T>
+HWY_INLINE HWY_MAYBE_UNUSED V EvalRationalPolynomial(const D d, const V x,
+                                                     const T (&p)[NP],
+                                                     const T (&q)[NQ]) {
+  constexpr size_t kDegP = NP / 4 - 1;
+  constexpr size_t kDegQ = NQ / 4 - 1;
+  auto yp = LoadDup128(d, &p[kDegP * 4]);
+  auto yq = LoadDup128(d, &q[kDegQ * 4]);
+  // We use pointer arithmetic to refer to &p[(kDegP - n) * 4] to avoid a
+  // compiler warning that the index is out of bounds since we are already
+  // checking that it is not out of bounds with (kDegP >= n) and the access
+  // will be optimized away. Similarly with q and kDegQ.
+  HWY_FENCE;
+  if (kDegP >= 1) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 1) * 4)));
+  if (kDegQ >= 1) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 1) * 4)));
+  HWY_FENCE;
+  if (kDegP >= 2) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 2) * 4)));
+  if (kDegQ >= 2) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 2) * 4)));
+  HWY_FENCE;
+  if (kDegP >= 3) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 3) * 4)));
+  if (kDegQ >= 3) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 3) * 4)));
+  HWY_FENCE;
+  if (kDegP >= 4) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 4) * 4)));
+  if (kDegQ >= 4) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 4) * 4)));
+  HWY_FENCE;
+  if (kDegP >= 5) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 5) * 4)));
+  if (kDegQ >= 5) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 5) * 4)));
+  HWY_FENCE;
+  if (kDegP >= 6) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 6) * 4)));
+  if (kDegQ >= 6) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 6) * 4)));
+  HWY_FENCE;
+  if (kDegP >= 7) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 7) * 4)));
+  if (kDegQ >= 7) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 7) * 4)));
+
+  return FastDivision<T, V>()(yp, yq);
+}
+
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+#endif  // LIB_JXL_RATIONAL_POLYNOMIAL_INL_H_
diff --git a/third_party/jpeg-xl/lib/jxl/rational_polynomial_test.cc b/third_party/jpeg-xl/lib/jxl/rational_polynomial_test.cc
new file mode 100644
index 000000000000..a2eb6b009089
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/rational_polynomial_test.cc
@@ -0,0 +1,248 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <cmath>
+#include <string>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/rational_polynomial_test.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+#include <hwy/tests/test_util-inl.h>
+
+#include "lib/jxl/base/descriptive_statistics.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/rational_polynomial-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+using T = float;  // required by EvalLog2
+using D = HWY_FULL(T);
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::ShiftLeft;
+using hwy::HWY_NAMESPACE::ShiftRight;
+
+// Generic: only computes polynomial
+struct EvalPoly {
+  template <size_t NP, size_t NQ>
+  T operator()(T x, const T (&p)[NP], const T (&q)[NQ]) const {
+    const HWY_FULL(T) d;
+    const auto vx = Set(d, x);
+    const auto approx = EvalRationalPolynomial(d, vx, p, q);
+    return GetLane(approx);
+  }
+};
+
+// Range reduction for log2
+struct EvalLog2 {
+  template <size_t NP, size_t NQ>
+  T operator()(T x, const T (&p)[NP], const T (&q)[NQ]) const {
+    const HWY_FULL(T) d;
+    auto vx = Set(d, x);
+
+    const HWY_FULL(int32_t) di;
+    const auto x_bits = BitCast(di, vx);
+    // Cannot handle negative numbers / NaN.
+    JXL_DASSERT(AllTrue(Abs(x_bits) == x_bits));
+
+    // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops
+    const auto exp_bits = x_bits - Set(di, 0x3f2aaaab);  // = 2/3
+    // Shifted exponent = log2; also used to clear mantissa.
+    const auto exp_shifted = ShiftRight<23>(exp_bits);
+    const auto mantissa = BitCast(d, x_bits - ShiftLeft<23>(exp_shifted));
+    const auto exp_val = ConvertTo(d, exp_shifted);
+    vx = mantissa - Set(d, 1.0f);
+
+    const auto approx = EvalRationalPolynomial(d, vx, p, q) + exp_val;
+    return GetLane(approx);
+  }
+};
+
+// Functions to approximate:
+
+T LinearToSrgb8Direct(T val) {
+  if (val < 0.0) return 0.0;
+  if (val >= 255.0) return 255.0;
+  if (val <= 10.0 / 12.92) return val * 12.92;
+  return 255.0 * (std::pow(val / 255.0, 1.0 / 2.4) * 1.055 - 0.055);
+}
+
+T SimpleGamma(T v) {
+  static const T kGamma = 0.387494322593;
+  static const T limit = 43.01745241042018;
+  T bright = v - limit;
+  if (bright >= 0) {
+    static const T mul = 0.0383723643799;
+    v -= bright * mul;
+  }
+  static const T limit2 = 94.68634353321337;
+  T bright2 = v - limit2;
+  if (bright2 >= 0) {
+    static const T mul = 0.22885405968;
+    v -= bright2 * mul;
+  }
+  static const T offset = 0.156775786057;
+  static const T scale = 8.898059160493739;
+  T retval = scale * (offset + pow(v, kGamma));
+  return retval;
+}
+
+// Runs CaratheodoryFejer and verifies the polynomial using a lot of samples to
+// return the biggest error.
+template <size_t NP, size_t NQ, class Eval>
+T RunApproximation(T x0, T x1, const T (&p)[NP], const T (&q)[NQ],
+                   const Eval& eval, T func_to_approx(T)) {
+  Stats err;
+
+  T lastPrint = 0;
+  // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter)
+  for (T x = x0; x <= x1; x += (x1 - x0) / 10000.0) {
+    const T f = func_to_approx(x);
+    const T g = eval(x, p, q);
+    err.Notify(fabs(g - f));
+    if (x == x0 || x - lastPrint > (x1 - x0) / 20.0) {
+      printf("x: %11.6f, f: %11.6f, g: %11.6f, e: %11.6f\n", x, f, g,
+             fabs(g - f));
+      lastPrint = x;
+    }
+  }
+  printf("%s\n", err.ToString().c_str());
+
+  return err.Max();
+}
+
+void TestSimpleGamma() {
+  const T p[4 * (6 + 1)] = {
+      HWY_REP4(-5.0646949363741811E-05), HWY_REP4(6.7369380528439771E-05),
+      HWY_REP4(8.9376652530412794E-05), HWY_REP4(2.1153513301520462E-06),
+      HWY_REP4(-6.9130322970386449E-08), HWY_REP4(3.9424752749293728E-10),
+      HWY_REP4(1.2360288207619576E-13)};
+
+  const T q[4 * (6 + 1)] = {
+      HWY_REP4(-6.6389733798591366E-06), HWY_REP4(1.3299859726565908E-05),
+      HWY_REP4(3.8538748358398873E-06), HWY_REP4(-2.8707687262928236E-08),
+      HWY_REP4(-6.6897385800005434E-10), HWY_REP4(6.1428748869186003E-12),
+      HWY_REP4(-2.5475738169252870E-15)};
+
+  const T err = RunApproximation(0.77, 274.579999999999984, p, q, EvalPoly(),
+                                 SimpleGamma);
+  EXPECT_LT(err, 0.05);
+}
+
+void TestLinearToSrgb8Direct() {
+  const T p[4 * (5 + 1)] = {
+      HWY_REP4(-9.5357499040105154E-05), HWY_REP4(4.6761186249798248E-04),
+      HWY_REP4(2.5708174333943594E-04), HWY_REP4(1.5250087770436082E-05),
+      HWY_REP4(1.1946768008931187E-07), HWY_REP4(5.9916446295972850E-11)};
+
+  const T q[4 * (4 + 1)] = {
+      HWY_REP4(1.8932479758079768E-05), HWY_REP4(2.7312342474687321E-05),
+      HWY_REP4(4.3901204783327006E-06), HWY_REP4(1.0417787306920273E-07),
+      HWY_REP4(3.0084206762140419E-10)};
+
+  const T err =
+      RunApproximation(0.77, 255, p, q, EvalPoly(), LinearToSrgb8Direct);
+  EXPECT_LT(err, 0.05);
+}
+
+void TestExp() {
+  const T p[4 * (2 + 1)] = {HWY_REP4(9.6266879665530902E-01),
+                            HWY_REP4(4.8961265681586763E-01),
+                            HWY_REP4(8.2619259189548433E-02)};
+  const T q[4 * (2 + 1)] = {HWY_REP4(9.6259895571622622E-01),
+                            HWY_REP4(-4.7272457588933831E-01),
+                            HWY_REP4(7.4802088567547664E-02)};
+  const T err =
+      RunApproximation(-1, 1, p, q, EvalPoly(), [](T x) { return T(exp(x)); });
+  EXPECT_LT(err, 1E-4);
+}
+
+void TestNegExp() {
+  // 4,3 is the min required for monotonicity; max error in 0,10: 751 ppm
+  // no benefit for k>50.
+  const T p[4 * (4 + 1)] = {
+      HWY_REP4(5.9580258551150123E-02), HWY_REP4(-2.5073728806886408E-02),
+      HWY_REP4(4.1561830213689248E-03), HWY_REP4(-3.1815408488900372E-04),
+      HWY_REP4(9.3866690094906802E-06)};
+  const T q[4 * (3 + 1)] = {
+      HWY_REP4(5.9579108238812878E-02), HWY_REP4(3.4542074345478582E-02),
+      HWY_REP4(8.7263562483501714E-03), HWY_REP4(1.4095109143061216E-03)};
+
+  const T err =
+      RunApproximation(0, 10, p, q, EvalPoly(), [](T x) { return T(exp(-x)); });
+  EXPECT_LT(err, sizeof(T) == 8 ? 2E-5 : 3E-5);
+}
+
+void TestSin() {
+  const T p[4 * (6 + 1)] = {
+      HWY_REP4(1.5518122109203780E-05), HWY_REP4(2.3388958643675966E+00),
+      HWY_REP4(-8.6705520940849157E-01), HWY_REP4(-1.9702294764873535E-01),
+      HWY_REP4(1.2193404314472320E-01), HWY_REP4(-1.7373966109788839E-02),
+      HWY_REP4(7.8829435883034796E-04)};
+  const T q[4 * (5 + 1)] = {
+      HWY_REP4(2.3394371422557279E+00), HWY_REP4(-8.7028221081288615E-01),
+      HWY_REP4(2.0052872219658430E-01), HWY_REP4(-3.2460335995264836E-02),
+      HWY_REP4(3.1546157932479282E-03), HWY_REP4(-1.6692542019380155E-04)};
+
+  const T err = RunApproximation(0, Pi(1) * 2, p, q, EvalPoly(),
+                                 [](T x) { return T(sin(x)); });
+  EXPECT_LT(err, sizeof(T) == 8 ? 5E-4 : 7E-4);
+}
+
+void TestLog() {
+  HWY_ALIGN const T p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06),
+                                      HWY_REP4(1.4287160470083755E+00),
+                                      HWY_REP4(7.4245873327820566E-01)};
+  HWY_ALIGN const T q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01),
+                                      HWY_REP4(1.0096718572241148E+00),
+                                      HWY_REP4(1.7409343003366853E-01)};
+  const T err = RunApproximation(1E-6, 1000, p, q, EvalLog2(), std::log2);
+  printf("%E\n", err);
+}
+
+HWY_NOINLINE void TestRationalPolynomial() {
+  TestSimpleGamma();
+  TestLinearToSrgb8Direct();
+  TestExp();
+  TestNegExp();
+  TestSin();
+  TestLog();
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+class RationalPolynomialTest : public hwy::TestWithParamTarget {};
+HWY_TARGET_INSTANTIATE_TEST_SUITE_P(RationalPolynomialTest);
+
+HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestSimpleGamma);
+HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestLinearToSrgb8Direct);
+HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestExp);
+HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestNegExp);
+HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestSin);
+HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestLog);
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/robust_statistics_test.cc b/third_party/jpeg-xl/lib/jxl/robust_statistics_test.cc
new file mode 100644
index 000000000000..6ddb32a1e605
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/robust_statistics_test.cc
@@ -0,0 +1,159 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/base/robust_statistics.h"
+
+#include <stdio.h>
+
+#include <numeric>  // partial_sum
+#include <random>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/noise_distributions.h"
+
+namespace jxl {
+namespace {
+
+TEST(RobustStatisticsTest, TestMode) {
+  // Enough to populate bins. We have to sort this many values.
+  constexpr size_t kReps = 15000;
+  constexpr size_t kBins = 101;
+
+  std::mt19937 rng(65537);
+
+  // Place Poisson mean at 1/10, 2/10 .. 9/10 of the bin range.
+  for (int frac = 1; frac < 10; ++frac) {
+    printf("===========================frac %d\n", frac);
+
+    NoisePoisson noise(frac * kBins / 10);
+    std::vector<float> values;
+    values.reserve(kReps);
+
+    uint32_t bins[kBins] = {0};
+
+    std::uniform_real_distribution<float> jitter(-1E-3f, 1E-3f);
+    for (size_t rep = 0; rep < kReps; ++rep) {
+      // Scale back to integer, add jitter to avoid too many repeated values.
+      const float poisson = noise(0.0f, &rng) * 1E3f + jitter(rng);
+
+      values.push_back(poisson);
+
+      const int idx_bin = static_cast<int>(poisson);
+      if (idx_bin < static_cast<int>(kBins)) {
+        bins[idx_bin] += 1;
+      }  // else skip instead of clamping to avoid bias
+    }
+
+    // // Print histogram
+    // for (const uint32_t b : bins) {
+    //   printf("%u\n", b);
+    // }
+
+    // (Smoothed) argmax and median for verification
+    float smoothed[kBins];
+    smoothed[0] = bins[0];
+    smoothed[kBins - 1] = bins[kBins - 1];
+    for (size_t i = 1; i < kBins - 1; ++i) {
+      smoothed[i] = (2 * bins[i] + bins[i - 1] + bins[i + 1]) * 0.25f;
+    }
+    const float argmax =
+        std::max_element(smoothed, smoothed + kBins) - smoothed;
+    const float median = Median(&values);
+
+    std::sort(values.begin(), values.end());
+    const float hsm = HalfSampleMode()(values.data(), values.size());
+
+    uint32_t cdf[kBins];
+    std::partial_sum(bins, bins + kBins, cdf);
+    const int hrm = HalfRangeMode()(cdf, kBins);
+
+    const auto is_near = [](const float expected, const float actual) {
+      return std::abs(expected - actual) <= 1.0f + 1E-5f;
+    };
+    EXPECT_TRUE(is_near(hsm, argmax) || is_near(hsm, median));
+    EXPECT_TRUE(is_near(hrm, argmax) || is_near(hrm, median));
+
+    printf("hsm %.1f hrm %d argmax %.1f median %f\n", hsm, hrm, argmax, median);
+    const int center = static_cast<int>(argmax);
+    printf("%d %d %d %d %d\n", bins[center - 2], bins[center - 1], bins[center],
+           bins[center + 1], bins[center + 2]);
+  }
+}
+
+// Ensures Median3/5 return the same results as Median.
+TEST(RobustStatisticsTest, TestMedian) {
+  std::vector<float> v3(3), v5(5);
+
+  std::uniform_real_distribution<float> dist(-100.0f, 100.0f);
+  std::mt19937 rng(129);
+
+#ifdef NDEBUG
+  constexpr size_t kReps = 100000;
+#else
+  constexpr size_t kReps = 100;
+#endif
+  for (size_t i = 0; i < kReps; ++i) {
+    v3[0] = dist(rng);
+    v3[1] = dist(rng);
+    v3[2] = dist(rng);
+    for (size_t j = 0; j < 5; ++j) {
+      v5[j] = dist(rng);
+    }
+
+    JXL_ASSERT(Median(&v3) == Median3(v3[0], v3[1], v3[2]));
+    JXL_ASSERT(Median(&v5) == Median5(v5[0], v5[1], v5[2], v5[3], v5[4]));
+  }
+}
+
+template <class Noise>
+void TestLine(const Noise& noise, float max_l1_limit, float mad_limit) {
+  std::vector points;
+  Line perfect(0.6f, 2.0f);
+
+  // Random spacing of X (must be unique)
+  float x = -100.0f;
+  std::mt19937_64 rng(129);
+  std::uniform_real_distribution<float> x_dist(1E-6f, 10.0f);
+  for (size_t ix = 0; ix < 500; ++ix) {
+    x += x_dist(rng);
+    const float y = noise(perfect(x), &rng);
+    points.emplace_back(x, y);
+    // printf("%f,%f\n", x, y);
+  }
+
+  Line est(points);
+  float max_l1, mad;
+  EvaluateQuality(est, points, &max_l1, &mad);
+  printf("x %f slope=%.2f b=%.2f max_l1 %f mad %f\n", x, est.slope(),
+         est.intercept(), max_l1, mad);
+
+  EXPECT_LE(max_l1, max_l1_limit);
+  EXPECT_LE(mad, mad_limit);
+}
+
+TEST(RobustStatisticsTest, CleanLine) {
+  const NoiseNone noise;
+  TestLine(noise, 1E-6, 1E-7);
+}
+TEST(RobustStatisticsTest, Uniform) {
+  const NoiseUniform noise(-100.0f, 100.0f);
+  TestLine(noise, 107, 53);
+}
+TEST(RobustStatisticsTest, Gauss) {
+  const NoiseGaussian noise(10.0f);
+  TestLine(noise, 37, 7);
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/roundtrip_test.cc b/third_party/jpeg-xl/lib/jxl/roundtrip_test.cc
new file mode 100644
index 000000000000..0f73d79e4e5d
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/roundtrip_test.cc
@@ -0,0 +1,625 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "jxl/decode.h"
+#include "jxl/decode_cxx.h"
+#include "jxl/encode.h"
+#include "jxl/encode_cxx.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/dec_external_image.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/encode_internal.h"
+#include "lib/jxl/test_utils.h"
+#include "lib/jxl/testdata.h"
+
+namespace {
+
+// Converts a test image to a CodecInOut.
+// icc_profile can be empty to automatically deduce profile from the pixel
+// format, or filled in to force this ICC profile
+jxl::CodecInOut ConvertTestImage(const std::vector<uint8_t>& buf,
+                                 const size_t xsize, const size_t ysize,
+                                 const JxlPixelFormat& pixel_format,
+                                 const jxl::PaddedBytes& icc_profile) {
+  jxl::CodecInOut io;
+  io.SetSize(xsize, ysize);
+
+  bool is_gray =
+      pixel_format.num_channels == 1 || pixel_format.num_channels == 2;
+  bool has_alpha =
+      pixel_format.num_channels == 2 || pixel_format.num_channels == 4;
+
+  io.metadata.m.color_encoding.SetColorSpace(is_gray ?
+      jxl::ColorSpace::kGray : jxl::ColorSpace::kRGB);
+  if (has_alpha) {
+    // Note: alpha > 16 not yet supported by the C++ codec
+    switch (pixel_format.data_type) {
+      case JXL_TYPE_UINT8:
+        io.metadata.m.SetAlphaBits(8);
+        break;
+      case JXL_TYPE_UINT16:
+      case JXL_TYPE_UINT32:
+      case JXL_TYPE_FLOAT:
+      case JXL_TYPE_FLOAT16:
+        io.metadata.m.SetAlphaBits(16);
+        break;
+      default:
+        EXPECT_TRUE(false) << "Roundtrip tests for data type "
+                           << pixel_format.data_type << " not yet implemented.";
+    }
+  }
+  size_t bitdepth = 0;
+  switch (pixel_format.data_type) {
+    case JXL_TYPE_FLOAT:
+      bitdepth = 32;
+      io.metadata.m.SetFloat32Samples();
+      break;
+    case JXL_TYPE_FLOAT16:
+      bitdepth = 16;
+      io.metadata.m.SetFloat16Samples();
+      break;
+    case JXL_TYPE_UINT8:
+      bitdepth = 8;
+      io.metadata.m.SetUintSamples(8);
+      break;
+    case JXL_TYPE_UINT16:
+      bitdepth = 16;
+      io.metadata.m.SetUintSamples(16);
+      break;
+    default:
+      EXPECT_TRUE(false) << "Roundtrip tests for data type "
+                         << pixel_format.data_type << " not yet implemented.";
+  }
+  jxl::ColorEncoding color_encoding;
+  if (!icc_profile.empty()) {
+    jxl::PaddedBytes icc_profile_copy(icc_profile);
+    EXPECT_TRUE(color_encoding.SetICC(std::move(icc_profile_copy)));
+  } else if (pixel_format.data_type == JXL_TYPE_FLOAT) {
+    color_encoding = jxl::ColorEncoding::LinearSRGB(is_gray);
+  } else {
+    color_encoding = jxl::ColorEncoding::SRGB(is_gray);
+  }
+  EXPECT_TRUE(
+      ConvertFromExternal(jxl::Span<const uint8_t>(buf.data(), buf.size()),
+                          xsize, ysize, color_encoding, has_alpha,
+                          /*alpha_is_premultiplied=*/false,
+                          /*bits_per_sample=*/bitdepth, pixel_format.endianness,
+                          /*flipped_y=*/false, /*pool=*/nullptr, &io.Main()));
+  return io;
+}
+
+// Stores a float in big endian
+void StoreBEFloat(float value, uint8_t* p) {
+  uint32_t u;
+  memcpy(&u, &value, 4);
+  StoreBE32(u, p);
+}
+
+// Stores a float in little endian
+void StoreLEFloat(float value, uint8_t* p) {
+  uint32_t u;
+  memcpy(&u, &value, 4);
+  StoreLE32(u, p);
+}
+
+// Loads a float in big endian
+float LoadBEFloat(const uint8_t* p) {
+  float value;
+  const uint32_t u = LoadBE32(p);
+  memcpy(&value, &u, 4);
+  return value;
+}
+
+// Loads a float in little endian
+float LoadLEFloat(const uint8_t* p) {
+  float value;
+  const uint32_t u = LoadLE32(p);
+  memcpy(&value, &u, 4);
+  return value;
+}
+
+template <typename T>
+T ConvertTestPixel(const float val);
+
+template <>
+float ConvertTestPixel<float>(const float val) {
+  return val;
+}
+
+template <>
+uint16_t ConvertTestPixel<uint16_t>(const float val) {
+  return (uint16_t)(val * UINT16_MAX);
+}
+
+template <>
+uint8_t ConvertTestPixel<uint8_t>(const float val) {
+  return (uint8_t)(val * UINT8_MAX);
+}
+
+// Returns a test image.
+template <typename T>
+std::vector<uint8_t> GetTestImage(const size_t xsize, const size_t ysize,
+                                  const JxlPixelFormat& pixel_format) {
+  std::vector<T> pixels(xsize * ysize * pixel_format.num_channels);
+  for (size_t y = 0; y < ysize; y++) {
+    for (size_t x = 0; x < xsize; x++) {
+      for (size_t chan = 0; chan < pixel_format.num_channels; chan++) {
+        float val;
+        switch (chan % 4) {
+          case 0:
+            val = static_cast<float>(y) / static_cast<float>(ysize);
+            break;
+          case 1:
+            val = static_cast<float>(x) / static_cast<float>(xsize);
+            break;
+          case 2:
+            val = static_cast<float>(x + y) / static_cast<float>(xsize + ysize);
+            break;
+          case 3:
+            val = static_cast<float>(x * y) / static_cast<float>(xsize * ysize);
+            break;
+        }
+        pixels[(y * xsize + x) * pixel_format.num_channels + chan] =
+            ConvertTestPixel<T>(val);
+      }
+    }
+  }
+  std::vector<uint8_t> bytes(pixels.size() * sizeof(T));
+  memcpy(bytes.data(), pixels.data(), sizeof(T) * pixels.size());
+  return bytes;
+}
+
+void EncodeWithEncoder(JxlEncoder* enc, std::vector<uint8_t>* compressed) {
+  compressed->resize(64);
+  uint8_t* next_out = compressed->data();
+  size_t avail_out = compressed->size() - (next_out - compressed->data());
+  JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT;
+  while (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+    process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out);
+    if (process_result == JXL_ENC_NEED_MORE_OUTPUT) {
+      size_t offset = next_out - compressed->data();
+      compressed->resize(compressed->size() * 2);
+      next_out = compressed->data() + offset;
+      avail_out = compressed->size() - offset;
+    }
+  }
+  compressed->resize(next_out - compressed->data());
+  EXPECT_EQ(JXL_ENC_SUCCESS, process_result);
+}
+
+// Generates some pixels using some dimensions and pixel_format,
+// compresses them, and verifies that the decoded version is similar to the
+// original pixels.
+template <typename T>
+void VerifyRoundtripCompression(const size_t xsize, const size_t ysize,
+                                const JxlPixelFormat& input_pixel_format,
+                                const JxlPixelFormat& output_pixel_format,
+                                const bool lossless, const bool use_container) {
+  const std::vector<uint8_t> original_bytes =
+      GetTestImage<T>(xsize, ysize, input_pixel_format);
+  jxl::CodecInOut original_io =
+      ConvertTestImage(original_bytes, xsize, ysize, input_pixel_format, {});
+
+  JxlEncoder* enc = JxlEncoderCreate(nullptr);
+  EXPECT_NE(nullptr, enc);
+
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc, use_container));
+  JxlBasicInfo basic_info;
+  jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &input_pixel_format);
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  basic_info.uses_original_profile = lossless;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info));
+  JxlColorEncoding color_encoding;
+  if (input_pixel_format.data_type == JXL_TYPE_FLOAT) {
+    JxlColorEncodingSetToLinearSRGB(
+        &color_encoding,
+        /*is_gray=*/input_pixel_format.num_channels < 3);
+  } else {
+    JxlColorEncodingSetToSRGB(&color_encoding,
+                              /*is_gray=*/input_pixel_format.num_channels < 3);
+  }
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding));
+  JxlEncoderOptions* opts = JxlEncoderOptionsCreate(enc, nullptr);
+  JxlEncoderOptionsSetLossless(opts, lossless);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderAddImageFrame(opts, &input_pixel_format,
+                                    (void*)original_bytes.data(),
+                                    original_bytes.size()));
+  JxlEncoderCloseInput(enc);
+
+  std::vector<uint8_t> compressed;
+  EncodeWithEncoder(enc, &compressed);
+  JxlEncoderDestroy(enc);
+
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+  EXPECT_NE(nullptr, dec);
+
+  const uint8_t* next_in = compressed.data();
+  size_t avail_in = compressed.size();
+
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO |
+                                               JXL_DEC_COLOR_ENCODING |
+                                               JXL_DEC_FULL_IMAGE));
+
+  JxlDecoderSetInput(dec, next_in, avail_in);
+  EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec));
+  size_t buffer_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderImageOutBufferSize(
+                                 dec, &output_pixel_format, &buffer_size));
+  if (&input_pixel_format == &output_pixel_format) {
+    EXPECT_EQ(buffer_size, original_bytes.size());
+  }
+
+  JxlBasicInfo info;
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info));
+  EXPECT_EQ(xsize, info.xsize);
+  EXPECT_EQ(ysize, info.ysize);
+
+  EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec));
+
+  size_t icc_profile_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderGetICCProfileSize(dec, &output_pixel_format,
+                                        JXL_COLOR_PROFILE_TARGET_DATA,
+                                        &icc_profile_size));
+  jxl::PaddedBytes icc_profile(icc_profile_size);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderGetColorAsICCProfile(
+                dec, &output_pixel_format, JXL_COLOR_PROFILE_TARGET_DATA,
+                icc_profile.data(), icc_profile.size()));
+
+  std::vector<uint8_t> decoded_bytes(buffer_size);
+
+  EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec));
+
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer(
+                                 dec, &output_pixel_format,
+                                 decoded_bytes.data(), decoded_bytes.size()));
+
+  EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec));
+
+  JxlDecoderDestroy(dec);
+
+  jxl::CodecInOut decoded_io = ConvertTestImage(
+      decoded_bytes, xsize, ysize, output_pixel_format, icc_profile);
+
+  jxl::ButteraugliParams ba;
+  float butteraugli_score = ButteraugliDistance(original_io, decoded_io, ba,
+                                                /*distmap=*/nullptr, nullptr);
+  if (lossless) {
+    EXPECT_LE(butteraugli_score, 0.0f);
+  } else {
+    EXPECT_LE(butteraugli_score, 2.0f);
+  }
+}
+
+}  // namespace
+
+TEST(RoundtripTest, FloatFrameRoundtripTest) {
+  for (int use_container = 0; use_container < 2; use_container++) {
+    for (int lossless = 0; lossless < 2; lossless++) {
+      for (uint32_t num_channels = 1; num_channels < 5; num_channels++) {
+        // There's no support (yet) for lossless extra float channels, so we
+        // don't test it.
+        if (num_channels % 2 != 0 || !lossless) {
+          JxlPixelFormat pixel_format = JxlPixelFormat{
+              num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
+          VerifyRoundtripCompression<float>(63, 129, pixel_format, pixel_format,
+                                            (bool)lossless,
+                                            (bool)use_container);
+        }
+      }
+    }
+  }
+}
+
+TEST(RoundtripTest, Uint16FrameRoundtripTest) {
+  for (int use_container = 0; use_container < 2; use_container++) {
+    for (int lossless = 0; lossless < 2; lossless++) {
+      for (uint32_t num_channels = 1; num_channels < 5; num_channels++) {
+        JxlPixelFormat pixel_format =
+            JxlPixelFormat{num_channels, JXL_TYPE_UINT16, JXL_NATIVE_ENDIAN, 0};
+        VerifyRoundtripCompression<uint16_t>(63, 129, pixel_format,
+                                             pixel_format, (bool)lossless,
+                                             (bool)use_container);
+      }
+    }
+  }
+}
+
+TEST(RoundtripTest, Uint8FrameRoundtripTest) {
+  for (int use_container = 0; use_container < 2; use_container++) {
+    for (int lossless = 0; lossless < 2; lossless++) {
+      for (uint32_t num_channels = 1; num_channels < 5; num_channels++) {
+        JxlPixelFormat pixel_format =
+            JxlPixelFormat{num_channels, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
+        VerifyRoundtripCompression<uint8_t>(63, 129, pixel_format, pixel_format,
+                                            (bool)lossless,
+                                            (bool)use_container);
+      }
+    }
+  }
+}
+
+TEST(RoundtripTest, TestNonlinearSrgbAsXybEncoded) {
+  for (int use_container = 0; use_container < 2; use_container++) {
+    for (uint32_t num_channels = 1; num_channels < 5; num_channels++) {
+      JxlPixelFormat pixel_format_in =
+          JxlPixelFormat{num_channels, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
+      JxlPixelFormat pixel_format_out =
+          JxlPixelFormat{num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
+      VerifyRoundtripCompression<uint8_t>(
+          63, 129, pixel_format_in, pixel_format_out,
+          /*lossless=*/false, (bool)use_container);
+    }
+  }
+}
+
+TEST(RoundtripTest, ExtraBoxesTest) {
+  JxlPixelFormat pixel_format =
+      JxlPixelFormat{4, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
+  const size_t xsize = 61;
+  const size_t ysize = 71;
+
+  const std::vector<uint8_t> original_bytes =
+      GetTestImage<float>(xsize, ysize, pixel_format);
+  jxl::CodecInOut original_io =
+      ConvertTestImage(original_bytes, xsize, ysize, pixel_format, {});
+
+  JxlEncoder* enc = JxlEncoderCreate(nullptr);
+  EXPECT_NE(nullptr, enc);
+
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc, true));
+  JxlBasicInfo basic_info;
+  jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format);
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  basic_info.uses_original_profile = false;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info));
+  JxlColorEncoding color_encoding;
+  if (pixel_format.data_type == JXL_TYPE_FLOAT) {
+    JxlColorEncodingSetToLinearSRGB(&color_encoding,
+                                    /*is_gray=*/pixel_format.num_channels < 3);
+  } else {
+    JxlColorEncodingSetToSRGB(&color_encoding,
+                              /*is_gray=*/pixel_format.num_channels < 3);
+  }
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding));
+  JxlEncoderOptions* opts = JxlEncoderOptionsCreate(enc, nullptr);
+  JxlEncoderOptionsSetLossless(opts, false);
+  EXPECT_EQ(
+      JXL_ENC_SUCCESS,
+      JxlEncoderAddImageFrame(opts, &pixel_format, (void*)original_bytes.data(),
+                              original_bytes.size()));
+  JxlEncoderCloseInput(enc);
+
+  std::vector<uint8_t> compressed;
+  EncodeWithEncoder(enc, &compressed);
+  JxlEncoderDestroy(enc);
+
+  std::vector<uint8_t> extra_data(1023);
+  jxl::AppendBoxHeader(jxl::MakeBoxType("crud"), extra_data.size(), false,
+                       &compressed);
+  compressed.insert(compressed.end(), extra_data.begin(), extra_data.end());
+
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+  EXPECT_NE(nullptr, dec);
+
+  const uint8_t* next_in = compressed.data();
+  size_t avail_in = compressed.size();
+
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO |
+                                               JXL_DEC_COLOR_ENCODING |
+                                               JXL_DEC_FULL_IMAGE));
+
+  JxlDecoderSetInput(dec, next_in, avail_in);
+  EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec));
+  size_t buffer_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderImageOutBufferSize(dec, &pixel_format, &buffer_size));
+  EXPECT_EQ(buffer_size, original_bytes.size());
+
+  JxlBasicInfo info;
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info));
+  EXPECT_EQ(xsize, info.xsize);
+  EXPECT_EQ(ysize, info.ysize);
+
+  EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec));
+
+  size_t icc_profile_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderGetICCProfileSize(dec, &pixel_format,
+                                        JXL_COLOR_PROFILE_TARGET_DATA,
+                                        &icc_profile_size));
+  jxl::PaddedBytes icc_profile(icc_profile_size);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderGetColorAsICCProfile(
+                dec, &pixel_format, JXL_COLOR_PROFILE_TARGET_DATA,
+                icc_profile.data(), icc_profile.size()));
+
+  std::vector<uint8_t> decoded_bytes(buffer_size);
+
+  EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec));
+
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer(dec, &pixel_format,
+                                                         decoded_bytes.data(),
+                                                         decoded_bytes.size()));
+
+  EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec));
+
+  JxlDecoderDestroy(dec);
+
+  jxl::CodecInOut decoded_io =
+      ConvertTestImage(decoded_bytes, xsize, ysize, pixel_format, icc_profile);
+
+  jxl::ButteraugliParams ba;
+  float butteraugli_score = ButteraugliDistance(original_io, decoded_io, ba,
+                                                /*distmap=*/nullptr, nullptr);
+  EXPECT_LE(butteraugli_score, 2.0f);
+}
+
+TEST(RoundtripTest, TestICCProfile) {
+  // This ICC profile is not a valid ICC profile, however neither the encoder
+  // nor the decoder parse this profile, and the bytes should be passed on
+  // correctly through the roundtrip.
+  jxl::PaddedBytes icc;
+  for (size_t i = 0; i < 200; i++) {
+    icc.push_back(i ^ 55);
+  }
+
+  JxlPixelFormat format =
+      JxlPixelFormat{3, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
+
+  size_t xsize = 25;
+  size_t ysize = 37;
+  const std::vector<uint8_t> original_bytes =
+      GetTestImage<uint8_t>(xsize, ysize, format);
+
+  JxlEncoder* enc = JxlEncoderCreate(nullptr);
+  EXPECT_NE(nullptr, enc);
+
+  JxlBasicInfo basic_info;
+  jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &format);
+  basic_info.xsize = xsize;
+  basic_info.ysize = ysize;
+  basic_info.uses_original_profile = JXL_FALSE;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info));
+
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetICCProfile(enc, icc.data(), icc.size()));
+  JxlEncoderOptions* opts = JxlEncoderOptionsCreate(enc, nullptr);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderAddImageFrame(opts, &format, (void*)original_bytes.data(),
+                                    original_bytes.size()));
+  JxlEncoderCloseInput(enc);
+
+  std::vector<uint8_t> compressed;
+  EncodeWithEncoder(enc, &compressed);
+  JxlEncoderDestroy(enc);
+
+  JxlDecoder* dec = JxlDecoderCreate(nullptr);
+  EXPECT_NE(nullptr, dec);
+
+  const uint8_t* next_in = compressed.data();
+  size_t avail_in = compressed.size();
+
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO |
+                                               JXL_DEC_COLOR_ENCODING |
+                                               JXL_DEC_FULL_IMAGE));
+
+  JxlDecoderSetInput(dec, next_in, avail_in);
+  EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec));
+  size_t buffer_size;
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderImageOutBufferSize(dec, &format, &buffer_size));
+  EXPECT_EQ(buffer_size, original_bytes.size());
+
+  JxlBasicInfo info;
+  EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info));
+  EXPECT_EQ(xsize, info.xsize);
+  EXPECT_EQ(ysize, info.ysize);
+
+  EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec));
+
+  size_t dec_icc_size;
+  EXPECT_EQ(
+      JXL_DEC_SUCCESS,
+      JxlDecoderGetICCProfileSize(
+          dec, &format, JXL_COLOR_PROFILE_TARGET_ORIGINAL, &dec_icc_size));
+  EXPECT_EQ(icc.size(), dec_icc_size);
+  jxl::PaddedBytes dec_icc(dec_icc_size);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderGetColorAsICCProfile(dec, &format,
+                                           JXL_COLOR_PROFILE_TARGET_ORIGINAL,
+                                           dec_icc.data(), dec_icc.size()));
+
+  std::vector<uint8_t> decoded_bytes(buffer_size);
+
+  EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec));
+
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSetImageOutBuffer(dec, &format, decoded_bytes.data(),
+                                        decoded_bytes.size()));
+
+  EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec));
+
+  EXPECT_EQ(icc, dec_icc);
+
+  JxlDecoderDestroy(dec);
+}
+
+TEST(RoundtripTest, TestJPEGReconstruction) {
+  const std::string jpeg_path =
+      "imagecompression.info/flower_foveon.png.im_q85_420.jpg";
+  const jxl::PaddedBytes orig = jxl::ReadTestData(jpeg_path);
+  jxl::CodecInOut orig_io;
+  ASSERT_TRUE(
+      SetFromBytes(jxl::Span<const uint8_t>(orig), &orig_io, /*pool=*/nullptr));
+
+  JxlEncoderPtr enc = JxlEncoderMake(nullptr);
+  JxlEncoderOptions* options = JxlEncoderOptionsCreate(enc.get(), NULL);
+
+  JxlBasicInfo basic_info;
+  basic_info.exponent_bits_per_sample = 0;
+  basic_info.bits_per_sample = 8;
+  basic_info.alpha_bits = 0;
+  basic_info.alpha_exponent_bits = 0;
+  basic_info.xsize = orig_io.xsize();
+  basic_info.ysize = orig_io.ysize();
+  basic_info.uses_original_profile = true;
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info));
+  JxlColorEncoding color_encoding;
+  JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false);
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderSetColorEncoding(enc.get(), &color_encoding));
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc.get(), JXL_TRUE));
+  EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderStoreJPEGMetadata(enc.get(), JXL_TRUE));
+  EXPECT_EQ(JXL_ENC_SUCCESS,
+            JxlEncoderAddJPEGFrame(options, orig.data(), orig.size()));
+  JxlEncoderCloseInput(enc.get());
+
+  std::vector<uint8_t> compressed;
+  EncodeWithEncoder(enc.get(), &compressed);
+
+  JxlDecoderPtr dec = JxlDecoderMake(nullptr);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSubscribeEvents(
+                dec.get(), JXL_DEC_JPEG_RECONSTRUCTION | JXL_DEC_FULL_IMAGE));
+  JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size());
+  EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec.get()));
+  std::vector<uint8_t> reconstructed_buffer(128);
+  EXPECT_EQ(JXL_DEC_SUCCESS,
+            JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data(),
+                                    reconstructed_buffer.size()));
+  size_t used = 0;
+  JxlDecoderStatus dec_process_result = JXL_DEC_JPEG_NEED_MORE_OUTPUT;
+  while (dec_process_result == JXL_DEC_JPEG_NEED_MORE_OUTPUT) {
+    used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get());
+    reconstructed_buffer.resize(reconstructed_buffer.size() * 2);
+    EXPECT_EQ(
+        JXL_DEC_SUCCESS,
+        JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data() + used,
+                                reconstructed_buffer.size() - used));
+    dec_process_result = JxlDecoderProcessInput(dec.get());
+  }
+  ASSERT_EQ(JXL_DEC_FULL_IMAGE, dec_process_result);
+  used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get());
+  ASSERT_EQ(used, orig.size());
+  EXPECT_EQ(0, memcmp(reconstructed_buffer.data(), orig.data(), used));
+}
diff --git a/third_party/jpeg-xl/lib/jxl/speed_tier_test.cc b/third_party/jpeg-xl/lib/jxl/speed_tier_test.cc
new file mode 100644
index 000000000000..ac0ac1e49d86
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/speed_tier_test.cc
@@ -0,0 +1,113 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
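One pattern worth calling out from the tests above: EncodeWithEncoder and the JPEG-reconstruction loop both hand the codec a buffer and, whenever it reports that more output space is needed, record how much was already written, double the buffer, and retry. A condensed, library-free sketch of that loop follows; Process is a hypothetical stand-in for calls such as JxlEncoderProcessOutput and is not part of the patch.

// Generic grow-and-retry output collection, mirroring the tests above.
#include <cstddef>
#include <cstdint>
#include <vector>

enum class Status { kSuccess, kNeedMoreOutput };

template <typename Process>  // Process(uint8_t* out, size_t* avail) -> Status
std::vector<uint8_t> CollectOutput(Process process) {
  std::vector<uint8_t> out(64);
  size_t used = 0;
  for (;;) {
    size_t avail = out.size() - used;  // space offered this round
    const Status st = process(out.data() + used, &avail);
    used = out.size() - avail;         // avail now holds the leftover space
    if (st != Status::kNeedMoreOutput) break;
    out.resize(out.size() * 2);        // double the buffer and retry
  }
  out.resize(used);
  return out;
}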
+
+#include <stddef.h>
+
+#include "gtest/gtest.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/padded_bytes.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/dec_file.h"
+#include "lib/jxl/dec_params.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_file.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_test_utils.h"
+#include "lib/jxl/test_utils.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+namespace {
+
+struct SpeedTierTestParams {
+  explicit SpeedTierTestParams(const SpeedTier speed_tier,
+                               const bool shrink8 = false)
+      : speed_tier(speed_tier), shrink8(shrink8) {}
+  SpeedTier speed_tier;
+  bool shrink8;
+};
+
+std::ostream& operator<<(std::ostream& os, SpeedTierTestParams params) {
+  auto previous_flags = os.flags();
+  os << std::boolalpha;
+  os << "SpeedTierTestParams{" << SpeedTierName(params.speed_tier)
+     << ", /*shrink8=*/" << params.shrink8 << "}";
+  os.flags(previous_flags);
+  return os;
+}
+
+class SpeedTierTest : public testing::TestWithParam<SpeedTierTestParams> {};
+
+JXL_GTEST_INSTANTIATE_TEST_SUITE_P(
+    SpeedTierTestInstantiation, SpeedTierTest,
+    testing::Values(SpeedTierTestParams{SpeedTier::kCheetah,
+                                        /*shrink8=*/true},
+                    SpeedTierTestParams{SpeedTier::kCheetah,
+                                        /*shrink8=*/false},
+                    SpeedTierTestParams{SpeedTier::kFalcon,
+                                        /*shrink8=*/true},
+                    SpeedTierTestParams{SpeedTier::kFalcon,
+                                        /*shrink8=*/false},
+                    SpeedTierTestParams{SpeedTier::kHare,
+                                        /*shrink8=*/true},
+                    SpeedTierTestParams{SpeedTier::kHare,
+                                        /*shrink8=*/false},
+                    SpeedTierTestParams{SpeedTier::kWombat,
+                                        /*shrink8=*/true},
+                    SpeedTierTestParams{SpeedTier::kWombat,
+                                        /*shrink8=*/false},
+                    SpeedTierTestParams{SpeedTier::kSquirrel,
+                                        /*shrink8=*/true},
+                    SpeedTierTestParams{SpeedTier::kSquirrel,
+                                        /*shrink8=*/false},
+                    SpeedTierTestParams{SpeedTier::kKitten,
+                                        /*shrink8=*/true},
+                    SpeedTierTestParams{SpeedTier::kKitten,
+                                        /*shrink8=*/false},
+                    // Only downscaled image for Tortoise mode.
+                    SpeedTierTestParams{SpeedTier::kTortoise,
+                                        /*shrink8=*/true}));
+
+TEST_P(SpeedTierTest, Roundtrip) {
+  const PaddedBytes orig =
+      ReadTestData("wesaturate/500px/u76c0g_bliznaca_srgb8.png");
+  CodecInOut io;
+  ThreadPoolInternal pool(8);
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io, &pool));
+
+  const SpeedTierTestParams& params = GetParam();
+
+  if (params.shrink8) {
+    io.ShrinkTo(io.xsize() / 8, io.ysize() / 8);
+  }
+
+  CompressParams cparams;
+  cparams.speed_tier = params.speed_tier;
+  DecompressParams dparams;
+
+  CodecInOut io2;
+  test::Roundtrip(&io, cparams, dparams, nullptr, &io2);
+
+  // Can be 2.2 in non-hare mode.
+  EXPECT_LE(ButteraugliDistance(io, io2, cparams.ba_params,
+                                /*distmap=*/nullptr, /*pool=*/nullptr),
+            2.8);
+}
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/splines.cc b/third_party/jpeg-xl/lib/jxl/splines.cc
new file mode 100644
index 000000000000..d4df452389bf
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/splines.cc
@@ -0,0 +1,523 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/splines.h"
+
+#include <algorithm>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/dct_scales.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/opsin_params.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/splines.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/fast_math-inl.h"
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+// Given a set of DCT coefficients, this returns the result of performing cosine
+// interpolation on the original samples.
+float ContinuousIDCT(const float dct[32], float t) {
+  // We compute here the DCT-3 of the `dct` vector, rescaled by a factor of
+  // sqrt(32). This is such that an input vector {x, 0, ..., 0} produces a
+  // constant result of x. dct[0] was scaled in Dequantize() to allow uniform
+  // treatment of all the coefficients.
+  constexpr float kMultipliers[32] = {
+      kPi / 32 * 0,  kPi / 32 * 1,  kPi / 32 * 2,  kPi / 32 * 3,  kPi / 32 * 4,
+      kPi / 32 * 5,  kPi / 32 * 6,  kPi / 32 * 7,  kPi / 32 * 8,  kPi / 32 * 9,
+      kPi / 32 * 10, kPi / 32 * 11, kPi / 32 * 12, kPi / 32 * 13, kPi / 32 * 14,
+      kPi / 32 * 15, kPi / 32 * 16, kPi / 32 * 17, kPi / 32 * 18, kPi / 32 * 19,
+      kPi / 32 * 20, kPi / 32 * 21, kPi / 32 * 22, kPi / 32 * 23, kPi / 32 * 24,
+      kPi / 32 * 25, kPi / 32 * 26, kPi / 32 * 27, kPi / 32 * 28, kPi / 32 * 29,
+      kPi / 32 * 30, kPi / 32 * 31,
+  };
+  HWY_CAPPED(float, 32) df;
+  auto result = Zero(df);
+  const auto tandhalf = Set(df, t + 0.5f);
+  for (int i = 0; i < 32; i += Lanes(df)) {
+    auto cos_arg = LoadU(df, kMultipliers + i) * tandhalf;
+    auto cos = FastCosf(df, cos_arg);
+    auto local_res = LoadU(df, dct + i) * cos;
+    result = MulAdd(Set(df, square_root<2>::value), local_res, result);
+  }
+  return GetLane(SumOfLanes(result));
+}
+
+// Splats a single Gaussian on the image.
+void DrawGaussian(Image3F* const opsin, const Rect& opsin_rect,
+                  const Rect& image_rect, const Spline::Point& center,
+                  const float intensity, const float color[3],
+                  const float sigma, std::vector<int32_t>& xs,
+                  std::vector<int32_t>& ys,
+                  std::vector<float>& local_intensity_storage) {
+  constexpr float kDistanceMultiplier = 4.605170185988091f;  // -2 * log(0.1)
+  // Distance beyond which exp(-d^2 / (2 * sigma^2)) drops below 0.1.
+  const float maximum_distance = sigma * sigma * kDistanceMultiplier;
+  const auto xbegin_s =
+      std::max<ssize_t>(image_rect.x0(), center.x - maximum_distance + .5f);
+  const auto xend_s =
+      std::min<ssize_t>(center.x + maximum_distance + .5f,
+                        image_rect.x0() + image_rect.xsize() - 1);
+  const auto ybegin_s =
+      std::max<ssize_t>(image_rect.y0(), center.y - maximum_distance + .5f);
+  const auto yend_s =
+      std::min<ssize_t>(center.y + maximum_distance + .5f,
+                        image_rect.y0() + image_rect.ysize() - 1);
+  if ((xend_s) <= 0 || (xend_s < xbegin_s)) return;
+  const size_t xbegin = xbegin_s;
+  const size_t xend = xend_s;
+  if ((yend_s <= 0) || (yend_s < ybegin_s)) return;
+  const size_t ybegin = ybegin_s;
+  const size_t yend = yend_s;
+  const size_t opsin_stride = opsin->PixelsPerRow();
+  float* JXL_RESTRICT rows[3] = {
+      opsin_rect.PlaneRow(opsin, 0, ybegin - image_rect.y0()),
+      opsin_rect.PlaneRow(opsin, 1, ybegin - image_rect.y0()),
+      opsin_rect.PlaneRow(opsin, 2, ybegin - image_rect.y0()),
+  };
+  const size_t nx = xend + 1 - xbegin;
+  const size_t ny = yend + 1 - ybegin;
+  HWY_FULL(float) df;
+  if (xs.size() < nx * ny) {
+    size_t sz = DivCeil(nx * ny, Lanes(df)) * Lanes(df);
+    xs.resize(sz);
+    ys.resize(sz);
+    local_intensity_storage.resize(sz);
+  }
+  for (size_t y = ybegin; y <= yend; ++y) {
+    for (size_t x = xbegin; x <= xend; ++x) {
+      xs[(y - ybegin) * nx + (x - xbegin)] = x;
+      ys[(y - ybegin) * nx + (x - xbegin)] = y;
+    }
+  }
+  Rebind<int32_t, decltype(df)> di;
+  const auto inv_sigma = Set(df, 1.0f / sigma);
+  const auto half = Set(df, 0.5f);
+  const auto one_over_2s2 = Set(df, 0.353553391f);
+  const auto sigma_over_4_times_intensity = Set(df, .25f * sigma * intensity);
+  for (size_t i = 0; i < nx * ny; i += Lanes(df)) {
+    const auto x = ConvertTo(df, LoadU(di, &xs[i]));
+    const auto y = ConvertTo(df, LoadU(di, &ys[i]));
+    const auto dx = x - Set(df, center.x);
+    const auto dy = y - Set(df, center.y);
+    const auto sqd = MulAdd(dx, dx, dy * dy);
+    const auto distance = Sqrt(sqd);
+    const auto one_dimensional_factor =
+        FastErff(df, MulAdd(distance, half, one_over_2s2) * inv_sigma) -
+        FastErff(df, MulSub(distance, half, one_over_2s2) * inv_sigma);
+    const auto local_intensity = sigma_over_4_times_intensity *
+                                 one_dimensional_factor *
+                                 one_dimensional_factor;
+    StoreU(local_intensity, df, &local_intensity_storage[i]);
+  }
+  ssize_t off = -static_cast<ssize_t>(image_rect.x0());
+  for (size_t y = ybegin; y <= yend; ++y) {
+    HWY_CAPPED(float, 1) df;
+    for (size_t x = xbegin; x <= xend; ++x) {
+      const auto local_intensity = Load(
+          df, local_intensity_storage.data() + (y - ybegin) * nx + x - xbegin);
+      for (size_t c = 0; c < 3; ++c) {
+        const auto cm = Set(df, color[c]);
+        const auto in = LoadU(df, rows[c] + x + off);
+        StoreU(MulAdd(cm, local_intensity, in), df, rows[c] + x + off);
+      }
+    }
+    off += opsin_stride;
+  }
+}
+
+void DrawFromPoints(
+    Image3F* const opsin, const Rect& opsin_rect, const Rect& image_rect,
+    const Spline& spline, bool add,
+    const std::vector<std::pair<Spline::Point, float>>& points_to_draw,
+    float arc_length) {
+  float inv_arc_length = 1.0f / arc_length;
+  int k = 0;
+  std::vector<int32_t> xs, ys;
+  std::vector<float> local_intensity_storage;
+  for (const auto& point_to_draw : points_to_draw) {
+    const Spline::Point& point = point_to_draw.first;
+    const float multiplier = add ?
+        point_to_draw.second : -point_to_draw.second;
+    const float progress_along_arc =
+        std::min(1.f, (k * kDesiredRenderingDistance) * inv_arc_length);
+    ++k;
+    float color[3];
+    for (size_t c = 0; c < 3; ++c) {
+      color[c] =
+          ContinuousIDCT(spline.color_dct[c], (32 - 1) * progress_along_arc);
+    }
+    const float sigma =
+        ContinuousIDCT(spline.sigma_dct, (32 - 1) * progress_along_arc);
+    DrawGaussian(opsin, opsin_rect, image_rect, point, multiplier, color, sigma,
+                 xs, ys, local_intensity_storage);
+  }
+}
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(DrawFromPoints);
+
+namespace {
+
+// Maximum number of spline control points per frame is
+// std::min(kMaxNumControlPoints, xsize * ysize / 2)
+constexpr size_t kMaxNumControlPoints = 1u << 20u;
+constexpr size_t kMaxNumControlPointsPerPixelRatio = 2;
+
+// X, Y, B, sigma.
+float ColorQuantizationWeight(const int32_t adjustment, const int channel,
+                              const int i) {
+  const float multiplier = adjustment >= 0 ? 1.f + .125f * adjustment
+                                           : 1.f / (1.f + .125f * -adjustment);
+
+  static constexpr float kChannelWeight[] = {0.0042f, 0.075f, 0.07f, .3333f};
+
+  return multiplier / kChannelWeight[channel];
+}
+
+Status DecodeAllStartingPoints(std::vector<Spline::Point>* const points,
+                               BitReader* const br, ANSSymbolReader* reader,
+                               const std::vector<uint8_t>& context_map,
+                               size_t num_splines) {
+  points->clear();
+  points->reserve(num_splines);
+  int64_t last_x = 0;
+  int64_t last_y = 0;
+  for (size_t i = 0; i < num_splines; i++) {
+    int64_t x =
+        reader->ReadHybridUint(kStartingPositionContext, br, context_map);
+    int64_t y =
+        reader->ReadHybridUint(kStartingPositionContext, br, context_map);
+    if (i != 0) {
+      x = UnpackSigned(x) + last_x;
+      y = UnpackSigned(y) + last_y;
+    }
+    points->emplace_back(static_cast<float>(x), static_cast<float>(y));
+    last_x = x;
+    last_y = y;
+  }
+  return true;
+}
+
+struct Vector {
+  float x, y;
+  Vector operator-() const { return {-x, -y}; }
+  Vector operator+(const Vector& other) const {
+    return {x + other.x, y + other.y};
+  }
+  float SquaredNorm() const { return x * x + y * y; }
+};
+Vector operator*(const float k, const Vector& vec) {
+  return {k * vec.x, k * vec.y};
+}
+
+Spline::Point operator+(const Spline::Point& p, const Vector& vec) {
+  return {p.x + vec.x, p.y + vec.y};
+}
+Spline::Point operator-(const Spline::Point& p, const Vector& vec) {
+  return p + -vec;
+}
+Vector operator-(const Spline::Point& a, const Spline::Point& b) {
+  return {a.x - b.x, a.y - b.y};
+}
+
+std::vector<Spline::Point> DrawCentripetalCatmullRomSpline(
+    std::vector<Spline::Point> points) {
+  if (points.size() <= 1) return points;
+  // Number of points to compute between each control point.
+  static constexpr int kNumPoints = 16;
+  std::vector<Spline::Point> result;
+  result.reserve((points.size() - 1) * kNumPoints + 1);
+  points.insert(points.begin(), points[0] + (points[0] - points[1]));
+  points.push_back(points[points.size() - 1] +
+                   (points[points.size() - 1] - points[points.size() - 2]));
+  // points has at least 4 elements at this point.
+  for (size_t start = 0; start < points.size() - 3; ++start) {
+    // 4 of them are used, and we draw from p[1] to p[2].
+    const Spline::Point* const p = &points[start];
+    result.push_back(p[1]);
+    float t[4] = {0};
+    for (int k = 1; k < 4; ++k) {
+      t[k] = std::sqrt(hypotf(p[k].x - p[k - 1].x, p[k].y - p[k - 1].y)) +
+             t[k - 1];
+    }
+    for (int i = 1; i < kNumPoints; ++i) {
+      const float tt =
+          t[1] + (static_cast<float>(i) / kNumPoints) * (t[2] - t[1]);
+      Spline::Point a[3];
+      for (int k = 0; k < 3; ++k) {
+        a[k] = p[k] + ((tt - t[k]) / (t[k + 1] - t[k])) * (p[k + 1] - p[k]);
+      }
+      Spline::Point b[2];
+      for (int k = 0; k < 2; ++k) {
+        b[k] = a[k] + ((tt - t[k]) / (t[k + 2] - t[k])) * (a[k + 1] - a[k]);
+      }
+      result.push_back(b[0] + ((tt - t[1]) / (t[2] - t[1])) * (b[1] - b[0]));
+    }
+  }
+  result.push_back(points[points.size() - 2]);
+  return result;
+}
+
+// Move along the line segments defined by `points`, `kDesiredRenderingDistance`
+// pixels at a time, and call `functor` with each point and the actual distance
+// to the previous point (which will always be kDesiredRenderingDistance except
+// possibly for the very last point).
+template <typename Points, typename Functor>
+void ForEachEquallySpacedPoint(const Points& points, const Functor& functor) {
+  JXL_ASSERT(!points.empty());
+  Spline::Point current = points.front();
+  functor(current, kDesiredRenderingDistance);
+  auto next = points.begin();
+  while (next != points.end()) {
+    const Spline::Point* previous = &current;
+    float arclength_from_previous = 0.f;
+    for (;;) {
+      if (next == points.end()) {
+        functor(*previous, arclength_from_previous);
+        return;
+      }
+      const float arclength_to_next =
+          std::sqrt((*next - *previous).SquaredNorm());
+      if (arclength_from_previous + arclength_to_next >=
+          kDesiredRenderingDistance) {
+        current =
+            *previous + ((kDesiredRenderingDistance - arclength_from_previous) /
+                         arclength_to_next) *
+                            (*next - *previous);
+        functor(current, kDesiredRenderingDistance);
+        break;
+      }
+      arclength_from_previous += arclength_to_next;
+      previous = &*next;
+      ++next;
+    }
+  }
+}
+
+}  // namespace
+
+QuantizedSpline::QuantizedSpline(const Spline& original,
+                                 const int32_t quantization_adjustment,
+                                 float ytox, float ytob) {
+  JXL_ASSERT(!original.control_points.empty());
+  control_points_.reserve(original.control_points.size() - 1);
+  const Spline::Point& starting_point = original.control_points.front();
+  int previous_x = static_cast<int>(roundf(starting_point.x)),
+      previous_y = static_cast<int>(roundf(starting_point.y));
+  int previous_delta_x = 0, previous_delta_y = 0;
+  for (auto it = original.control_points.begin() + 1;
+       it != original.control_points.end(); ++it) {
+    const int new_x = static_cast<int>(roundf(it->x));
+    const int new_y = static_cast<int>(roundf(it->y));
+    const int new_delta_x = new_x - previous_x;
+    const int new_delta_y = new_y - previous_y;
+    control_points_.emplace_back(new_delta_x - previous_delta_x,
+                                 new_delta_y - previous_delta_y);
+    previous_delta_x = new_delta_x;
+    previous_delta_y = new_delta_y;
+    previous_x = new_x;
+    previous_y = new_y;
+  }
+
+  for (int c = 0; c < 3; ++c) {
+    float factor = c == 0 ? ytox : c == 1 ?
+    for (int i = 0; i < 32; ++i) {
+      const float coefficient =
+          original.color_dct[c][i] -
+          factor * color_dct_[1][i] /
+              ColorQuantizationWeight(quantization_adjustment, 1, i);
+      color_dct_[c][i] = static_cast<int>(
+          roundf(coefficient *
+                 ColorQuantizationWeight(quantization_adjustment, c, i)));
+    }
+  }
+  for (int i = 0; i < 32; ++i) {
+    sigma_dct_[i] = static_cast<int>(
+        roundf(original.sigma_dct[i] *
+               ColorQuantizationWeight(quantization_adjustment, 3, i)));
+  }
+}
+
+Spline QuantizedSpline::Dequantize(const Spline::Point& starting_point,
+                                   const int32_t quantization_adjustment,
+                                   float ytox, float ytob) const {
+  Spline result;
+
+  result.control_points.reserve(control_points_.size() + 1);
+  int current_x = static_cast<int>(roundf(starting_point.x)),
+      current_y = static_cast<int>(roundf(starting_point.y));
+  result.control_points.push_back(Spline::Point{static_cast<float>(current_x),
+                                                static_cast<float>(current_y)});
+  int current_delta_x = 0, current_delta_y = 0;
+  for (const auto& point : control_points_) {
+    current_delta_x += point.first;
+    current_delta_y += point.second;
+    current_x += current_delta_x;
+    current_y += current_delta_y;
+    result.control_points.push_back(Spline::Point{
+        static_cast<float>(current_x), static_cast<float>(current_y)});
+  }
+
+  for (int c = 0; c < 3; ++c) {
+    for (int i = 0; i < 32; ++i) {
+      result.color_dct[c][i] =
+          color_dct_[c][i] * (i == 0 ? 1.0f / square_root<2>::value : 1.0f) /
+          ColorQuantizationWeight(quantization_adjustment, c, i);
+    }
+  }
+  for (int i = 0; i < 32; ++i) {
+    result.color_dct[0][i] += ytox * result.color_dct[1][i];
+    result.color_dct[2][i] += ytob * result.color_dct[1][i];
+  }
+  for (int i = 0; i < 32; ++i) {
+    result.sigma_dct[i] =
+        sigma_dct_[i] * (i == 0 ? 1.0f / square_root<2>::value : 1.0f) /
+        ColorQuantizationWeight(quantization_adjustment, 3, i);
+  }
+
+  return result;
+}
+
+Status QuantizedSpline::Decode(const std::vector<uint8_t>& context_map,
+                               ANSSymbolReader* const decoder,
+                               BitReader* const br, size_t max_control_points,
+                               size_t* total_num_control_points) {
+  const size_t num_control_points =
+      decoder->ReadHybridUint(kNumControlPointsContext, br, context_map);
+  *total_num_control_points += num_control_points;
+  if (*total_num_control_points > max_control_points) {
+    return JXL_FAILURE("Too many control points: %zu",
+                       *total_num_control_points);
+  }
+  control_points_.resize(num_control_points);
+  for (std::pair<int64_t, int64_t>& control_point : control_points_) {
+    control_point.first = UnpackSigned(
+        decoder->ReadHybridUint(kControlPointsContext, br, context_map));
+    control_point.second = UnpackSigned(
+        decoder->ReadHybridUint(kControlPointsContext, br, context_map));
+  }
+
+  const auto decode_dct = [decoder, br, &context_map](int dct[32]) -> Status {
+    for (int i = 0; i < 32; ++i) {
+      dct[i] =
+          UnpackSigned(decoder->ReadHybridUint(kDCTContext, br, context_map));
+    }
+    return true;
+  };
+  for (int c = 0; c < 3; ++c) {
+    JXL_RETURN_IF_ERROR(decode_dct(color_dct_[c]));
+  }
+  JXL_RETURN_IF_ERROR(decode_dct(sigma_dct_));
+  return true;
+}
+
+Status Splines::Decode(jxl::BitReader* br, size_t num_pixels) {
+  std::vector<uint8_t> context_map;
+  ANSCode code;
+  JXL_RETURN_IF_ERROR(
+      DecodeHistograms(br, kNumSplineContexts, &code, &context_map));
+  ANSSymbolReader decoder(&code, br);
+  const size_t num_splines =
+      1 + decoder.ReadHybridUint(kNumSplinesContext, br, context_map);
+  size_t max_control_points = std::min(
+      kMaxNumControlPoints, num_pixels / kMaxNumControlPointsPerPixelRatio);
+  if (num_splines > max_control_points) {
+    return JXL_FAILURE("Too many splines: %zu", num_splines);
+  }
+  JXL_RETURN_IF_ERROR(DecodeAllStartingPoints(&starting_points_, br, &decoder,
+                                              context_map, num_splines));
+
+  quantization_adjustment_ = UnpackSigned(
+      decoder.ReadHybridUint(kQuantizationAdjustmentContext, br, context_map));
+
+  splines_.clear();
+  splines_.reserve(num_splines);
+  size_t num_control_points = num_splines;
+  for (size_t i = 0; i < num_splines; ++i) {
+    QuantizedSpline spline;
+    JXL_RETURN_IF_ERROR(spline.Decode(context_map, &decoder, br,
+                                      max_control_points, &num_control_points));
+    splines_.push_back(std::move(spline));
+  }
+
+  JXL_RETURN_IF_ERROR(decoder.CheckANSFinalState());
+
+  if (!HasAny()) {
+    return JXL_FAILURE("Decoded splines but got none");
+  }
+
+  return true;
+}
+
+Status Splines::AddTo(Image3F* const opsin, const Rect& opsin_rect,
+                      const Rect& image_rect,
+                      const ColorCorrelationMap& cmap) const {
+  return Apply</*add=*/true>(opsin, opsin_rect, image_rect, cmap);
+}
+
+Status Splines::SubtractFrom(Image3F* const opsin,
+                             const ColorCorrelationMap& cmap) const {
+  return Apply</*add=*/false>(opsin, Rect(*opsin), Rect(*opsin), cmap);
+}
+
+template <bool add>
+Status Splines::Apply(Image3F* const opsin, const Rect& opsin_rect,
+                      const Rect& image_rect,
+                      const ColorCorrelationMap& cmap) const {
+  for (size_t i = 0; i < splines_.size(); ++i) {
+    const Spline spline =
+        splines_[i].Dequantize(starting_points_[i], quantization_adjustment_,
+                               cmap.YtoXRatio(0), cmap.YtoBRatio(0));
+    if (std::adjacent_find(spline.control_points.begin(),
+                           spline.control_points.end()) !=
+        spline.control_points.end()) {
+      return JXL_FAILURE("identical successive control points in spline %zu",
+                         i);
+    }
+    std::vector<std::pair<Spline::Point, float>> points_to_draw;
+    ForEachEquallySpacedPoint(
+        DrawCentripetalCatmullRomSpline(spline.control_points),
+        [&](const Spline::Point& point, const float multiplier) {
+          points_to_draw.emplace_back(point, multiplier);
+        });
+    const float arc_length =
+        (points_to_draw.size() - 2) * kDesiredRenderingDistance +
+        points_to_draw.back().second;
+    if (arc_length <= 0.f) {
+      // This spline wouldn't have any effect.
+      continue;
+    }
+    HWY_DYNAMIC_DISPATCH(DrawFromPoints)
+    (opsin, opsin_rect, image_rect, spline, add, points_to_draw, arc_length);
+  }
+  return true;
+}
+
+}  // namespace jxl
+#endif  // HWY_ONCE
diff --git a/third_party/jpeg-xl/lib/jxl/splines.h b/third_party/jpeg-xl/lib/jxl/splines.h
new file mode 100644
index 000000000000..0a7bb7135e1c
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/splines.h
@@ -0,0 +1,133 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_SPLINES_H_
+#define LIB_JXL_SPLINES_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <utility>
+#include <vector>
+
+#include "lib/jxl/ans_params.h"
+#include "lib/jxl/aux_out.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/dec_ans.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/entropy_coder.h"
+#include "lib/jxl/image.h"
+
+namespace jxl {
+
+static constexpr float kDesiredRenderingDistance = 1.f;
+
+enum SplineEntropyContexts : size_t {
+  kQuantizationAdjustmentContext = 0,
+  kStartingPositionContext,
+  kNumSplinesContext,
+  kNumControlPointsContext,
+  kControlPointsContext,
+  kDCTContext,
+  kNumSplineContexts
+};
+
+struct Spline {
+  struct Point {
+    Point() : x(0.0f), y(0.0f) {}
+    Point(float x, float y) : x(x), y(y) {}
+    float x, y;
+    bool operator==(const Point& other) const {
+      return std::fabs(x - other.x) < 1e-3f && std::fabs(y - other.y) < 1e-3f;
+    }
+  };
+  std::vector<Point> control_points;
+  // X, Y, B.
+  float color_dct[3][32];
+  // Splines are drawn by normalized Gaussian splatting. This controls the
+  // Gaussian's parameter along the spline.
+  float sigma_dct[32];
+};
+
+class QuantizedSplineEncoder;
+
+class QuantizedSpline {
+ public:
+  QuantizedSpline() = default;
+  explicit QuantizedSpline(const Spline& original,
+                           int32_t quantization_adjustment, float ytox,
+                           float ytob);
+
+  Spline Dequantize(const Spline::Point& starting_point,
+                    int32_t quantization_adjustment, float ytox,
+                    float ytob) const;
+
+  Status Decode(const std::vector<uint8_t>& context_map,
+                ANSSymbolReader* decoder, BitReader* br,
+                size_t max_control_points, size_t* total_num_control_points);
+
+ private:
+  friend class QuantizedSplineEncoder;
+
+  std::vector<std::pair<int64_t, int64_t>>
+      control_points_;  // Double delta-encoded.
+  int color_dct_[3][32] = {};
+  int sigma_dct_[32] = {};
+};
+
+class Splines {
+ public:
+  Splines() = default;
+  explicit Splines(const int32_t quantization_adjustment,
+                   std::vector<QuantizedSpline> splines,
+                   std::vector<Spline::Point> starting_points)
+      : quantization_adjustment_(quantization_adjustment),
+        splines_(std::move(splines)),
+        starting_points_(std::move(starting_points)) {}
+
+  bool HasAny() const { return !splines_.empty(); }
+
+  Status Decode(BitReader* br, size_t num_pixels);
+
+  Status AddTo(Image3F* opsin, const Rect& opsin_rect, const Rect& image_rect,
+               const ColorCorrelationMap& cmap) const;
+  Status SubtractFrom(Image3F* opsin, const ColorCorrelationMap& cmap) const;
+
+  const std::vector<QuantizedSpline>& QuantizedSplines() const {
+    return splines_;
+  }
+  const std::vector<Spline::Point>& StartingPoints() const {
+    return starting_points_;
+  }
+
+  int32_t GetQuantizationAdjustment() const { return quantization_adjustment_; }
+
+ private:
+  template <bool add>
+  Status Apply(Image3F* opsin, const Rect& opsin_rect, const Rect& image_rect,
+               const ColorCorrelationMap& cmap) const;
+
+  // If positive, quantization weights are multiplied by 1 + this/8, which
+  // increases precision. If negative, they are divided by 1 - this/8. If 0,
+  // they are unchanged.
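+  // For example, an adjustment of 2 multiplies the weights by 1 + 2/8 = 1.25,
+  // and an adjustment of -8 divides them by 1 - (-8)/8 = 2.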
+  int32_t quantization_adjustment_ = 0;
+  std::vector<QuantizedSpline> splines_;
+  std::vector<Spline::Point> starting_points_;
+};
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_SPLINES_H_
diff --git a/third_party/jpeg-xl/lib/jxl/splines_gbench.cc b/third_party/jpeg-xl/lib/jxl/splines_gbench.cc
new file mode 100644
index 000000000000..c0674aca8748
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/splines_gbench.cc
@@ -0,0 +1,60 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "benchmark/benchmark.h"
+#include "lib/jxl/splines.h"
+
+namespace jxl {
+namespace {
+
+constexpr int kQuantizationAdjustment = 0;
+const ColorCorrelationMap* const cmap = new ColorCorrelationMap;
+const float kYToX = cmap->YtoXRatio(0);
+const float kYToB = cmap->YtoBRatio(0);
+
+void BM_Splines(benchmark::State& state) {
+  const size_t n = state.range();
+
+  std::vector<Spline> spline_data = {
+      {/*control_points=*/{
+           {9, 54}, {118, 159}, {97, 3}, {10, 40}, {150, 25}, {120, 300}},
+       /*color_dct=*/
+       {{0.03125f, 0.00625f, 0.003125f}, {1.f, 0.321875f}, {1.f, 0.24375f}},
+       /*sigma_dct=*/{0.3125f, 0.f, 0.f, 0.0625f}}};
+  std::vector<QuantizedSpline> quantized_splines;
+  std::vector<Spline::Point> starting_points;
+  for (const Spline& spline : spline_data) {
+    quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX,
+                                   kYToB);
+    starting_points.push_back(spline.control_points.front());
+  }
+  Splines splines(kQuantizationAdjustment, std::move(quantized_splines),
+                  std::move(starting_points));
+
+  Image3F drawing_area(320, 320);
+  ZeroFillImage(&drawing_area);
+  for (auto _ : state) {
+    for (size_t i = 0; i < n; ++i) {
+      JXL_CHECK(splines.AddTo(&drawing_area, Rect(drawing_area),
+                              Rect(drawing_area), *cmap));
+    }
+  }
+
+  state.SetItemsProcessed(n * state.iterations());
+}
+
+BENCHMARK(BM_Splines)->Range(1, 1 << 10);
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/splines_test.cc b/third_party/jpeg-xl/lib/jxl/splines_test.cc
new file mode 100644
index 000000000000..b2d9270c6094
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/splines_test.cc
@@ -0,0 +1,321 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/splines.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "lib/extras/codec.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_splines.h"
+#include "lib/jxl/image_test_utils.h"
+#include "lib/jxl/testdata.h"
+
+namespace jxl {
+
+std::ostream& operator<<(std::ostream& os, const Spline::Point& p) {
+  return os << "(" << p.x << ", " << p.y << ")";
+}
+
+std::ostream& operator<<(std::ostream& os, const Spline& spline) {
+  return os << "(spline with " << spline.control_points.size()
+            << " control points)";
+}
+
+namespace {
+
+using ::testing::AllOf;
+using ::testing::Field;
+using ::testing::FloatNear;
+using ::testing::Pointwise;
+
+constexpr int kQuantizationAdjustment = 0;
+const ColorCorrelationMap* const cmap = new ColorCorrelationMap;
+const float kYToX = cmap->YtoXRatio(0);
+const float kYToB = cmap->YtoBRatio(0);
+
+constexpr float kTolerance = 0.003125;
+
+std::vector<Spline> DequantizeSplines(const Splines& splines) {
+  const auto& quantized_splines = splines.QuantizedSplines();
+  const auto& starting_points = splines.StartingPoints();
+  JXL_ASSERT(quantized_splines.size() == starting_points.size());
+
+  std::vector<Spline> dequantized;
+  for (size_t i = 0; i < quantized_splines.size(); ++i) {
+    dequantized.push_back(quantized_splines[i].Dequantize(
+        starting_points[i], kQuantizationAdjustment, kYToX, kYToB));
+  }
+  return dequantized;
+}
+
+MATCHER(ControlPointIs, "") {
+  const Spline::Point& actual = std::get<0>(arg);
+  const Spline::Point& expected = std::get<1>(arg);
+  return testing::ExplainMatchResult(
+      AllOf(Field(&Spline::Point::x, FloatNear(expected.x, kTolerance)),
+            Field(&Spline::Point::y, FloatNear(expected.y, kTolerance))),
+      actual, result_listener);
+}
+
+MATCHER(ControlPointsMatch, "") {
+  const Spline& actual = std::get<0>(arg);
+  const Spline& expected = std::get<1>(arg);
+  return testing::ExplainMatchResult(
+      Field(&Spline::control_points,
+            Pointwise(ControlPointIs(), expected.control_points)),
+      actual, result_listener);
+}
+
+MATCHER(SplinesMatch, "") {
+  const Spline& actual = std::get<0>(arg);
+  const Spline& expected = std::get<1>(arg);
+  if (!testing::ExplainMatchResult(ControlPointsMatch(), arg,
+                                   result_listener)) {
+    return false;
+  }
+  for (int i = 0; i < 3; ++i) {
+    size_t color_dct_size =
+        sizeof(expected.color_dct[i]) / sizeof(expected.color_dct[i][0]);
+    for (size_t j = 0; j < color_dct_size; j++) {
+      testing::StringMatchResultListener color_dct_listener;
+      if (!testing::ExplainMatchResult(
+              FloatNear(expected.color_dct[i][j], kTolerance),
+              actual.color_dct[i][j], &color_dct_listener)) {
+        *result_listener << ", where color_dct[" << i << "][" << j
+                         << "] don't match, " << color_dct_listener.str();
+        return false;
+      }
+    }
+  }
+  size_t sigma_dct_size =
+      sizeof(expected.sigma_dct) / sizeof(expected.sigma_dct[0]);
+  for (size_t i = 0; i < sigma_dct_size; i++) {
+    testing::StringMatchResultListener sigma_listener;
+    if (!testing::ExplainMatchResult(
+            FloatNear(expected.sigma_dct[i], kTolerance), actual.sigma_dct[i],
+            &sigma_listener)) {
+      *result_listener << ", where sigma_dct[" << i << "] don't match, "
+                       << sigma_listener.str();
+      return false;
+    }
+  }
+  return true;
+}
+
+}  // namespace
+
+TEST(SplinesTest, Serialization) {
+  std::vector<Spline> spline_data = {
+      {/*control_points=*/{
+           {109, 54}, {218, 159}, {80, 3}, {110, 274}, {94, 185}, {17, 277}},
+       /*color_dct=*/
+       {{36.3, 39.7, 23.2, 67.5, 4.4, 71.5, 62.3, 32.3, 92.2, 10.1, 10.8,
+         9.2, 6.1, 10.5, 79.1, 7, 24.6, 90.8, 5.5, 84, 43.8, 49,
+         33.5,
+         78.9, 54.5, 77.9, 62.1, 51.4, 36.4, 14.3, 83.7, 35.4},
+        {9.4, 53.4, 9.5, 74.9, 72.7, 26.7, 7.9, 0.9, 84.9, 23.2, 26.5,
+         31.1, 91, 11.7, 74.1, 39.3, 23.7, 82.5, 4.8, 2.7, 61.2, 96.4,
+         13.7, 66.7, 62.9, 82.4, 5.9, 98.7, 21.5, 7.9, 51.7, 63.1},
+        {48, 39.3, 6.9, 26.3, 33.3, 6.2, 1.7, 98.9, 59.9, 59.6, 95,
+         61.3, 82.7, 53, 6.1, 30.4, 34.7, 96.9, 93.4, 17, 38.8, 80.8,
+         63, 18.6, 43.6, 32.3, 61, 20.2, 24.3, 28.3, 69.1, 62.4}},
+       /*sigma_dct=*/{32.7, 21.5, 44.4, 1.8, 45.8, 90.6, 29.3, 59.2,
+                      23.7, 85.2, 84.8, 27.2, 42.1, 84.1, 50.6, 17.6,
+                      93.7, 4.9, 2.6, 69.8, 94.9, 52, 24.3, 18.8,
+                      12.1, 95.7, 28.5, 81.4, 89.9, 31.4, 74.8, 52}},
+      {/*control_points=*/{{172, 309},
+                           {196, 277},
+                           {42, 238},
+                           {114, 350},
+                           {307, 290},
+                           {316, 269},
+                           {124, 66},
+                           {233, 267}},
+       /*color_dct=*/
+       {{15, 28.9, 22, 6.6, 41.8, 83, 8.6, 56.8, 68.9, 9.7, 5.4,
+         19.8, 70.8, 90, 52.5, 65.2, 7.8, 23.5, 26.4, 72.2, 64.7, 87.1,
+         1.3, 67.5, 46, 68.4, 65.4, 35.5, 29.1, 13, 41.6, 23.9},
+        {47.7, 79.4, 62.7, 29.1, 96.8, 18.5, 17.6, 15.2, 80.5, 56, 96.2,
+         59.9, 26.7, 96.1, 92.3, 42.1, 35.8, 54, 23.2, 55, 76, 35.8,
+         58.4, 88.7, 2.4, 78.1, 95.6, 27.5, 6.6, 78.5, 24.1, 69.8},
+        {43.8, 96.5, 0.9, 95.1, 49.1, 71.2, 25.1, 33.6, 75.2, 95, 82.1,
+         19.7, 10.5, 44.9, 50, 93.3, 83.5, 99.5, 64.6, 54, 3.5, 99.7,
+         45.3, 82.1, 22.4, 37.9, 60, 32.2, 12.6, 4.6, 65.5, 96.4}},
+       /*sigma_dct=*/{72.5, 2.6, 41.7, 2.2, 39.7, 79.1, 69.6, 19.9,
+                      92.3, 71.5, 41.9, 62.1, 30, 49.4, 70.3, 45.3,
+                      62.5, 47.2, 46.7, 41.2, 90.8, 46.8, 91.2, 55,
+                      8.1, 69.6, 25.4, 84.7, 61.7, 27.6, 3.7, 46.9}},
+      {/*control_points=*/{{100, 186},
+                           {257, 97},
+                           {170, 49},
+                           {25, 169},
+                           {309, 104},
+                           {232, 237},
+                           {385, 101},
+                           {122, 168},
+                           {26, 300},
+                           {390, 88}},
+       /*color_dct=*/
+       {{16.9, 64.8, 4.2, 10.6, 23.5, 17, 79.3, 5.7, 60.4, 16.6, 94.9,
+         63.7, 87.6, 10.5, 3.8, 61.1, 22.9, 81.9, 80.4, 40.5, 45.9, 25.4,
+         39.8, 30, 50.2, 90.4, 27.9, 93.7, 65.1, 48.2, 22.3, 43.9},
+        {24.9, 66, 3.5, 90.2, 97.1, 15.8, 35.6, 0.6, 68, 39.6, 24.4,
+         85.9, 57.7, 77.6, 47.5, 67.9, 4.3, 5.4, 91.2, 58.5, 0.1, 52.2,
+         3.5, 47.8, 63.2, 43.5, 85.8, 35.8, 50.2, 35.9, 19.2, 48.2},
+        {82.8, 44.9, 76.4, 39.5, 94.1, 14.3, 89.8, 10, 10.5, 74.5, 56.3,
+         65.8, 7.8, 23.3, 52.8, 99.3, 56.8, 46, 76.7, 13.5, 67, 22.4,
+         29.9, 43.3, 70.3, 26, 74.3, 53.9, 62, 19.1, 49.3, 46.7}},
+       /*sigma_dct=*/{83.5, 1.7, 25.1, 18.7, 46.5, 75.3, 28, 62.3,
+                      50.3, 23.3, 85.6, 96, 45.8, 33.1, 33.4, 52.9,
+                      26.3, 58.5, 19.6, 70, 92.6, 22.5, 57, 21.6,
+                      76.8, 87.5, 22.9, 66.3, 35.7, 35.6, 56.8, 67.2}},
+  };
+
+  std::vector<QuantizedSpline> quantized_splines;
+  std::vector<Spline::Point> starting_points;
+  for (const Spline& spline : spline_data) {
+    quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX,
+                                   kYToB);
+    starting_points.push_back(spline.control_points.front());
+  }
+
+  Splines splines(kQuantizationAdjustment, std::move(quantized_splines),
+                  std::move(starting_points));
+  const std::vector<Spline> quantized_spline_data = DequantizeSplines(splines);
+  EXPECT_THAT(quantized_spline_data,
+              Pointwise(ControlPointsMatch(), spline_data));
+
+  BitWriter writer;
+  EncodeSplines(splines, &writer, kLayerSplines, HistogramParams(), nullptr);
+  writer.ZeroPadToByte();
+  const size_t bits_written = writer.BitsWritten();
+
+  printf("Wrote %zu bits of splines.\n", bits_written);
+
+  BitReader reader(writer.GetSpan());
+  Splines decoded_splines;
+  ASSERT_TRUE(decoded_splines.Decode(&reader, /*num_pixels=*/1000));
+  ASSERT_TRUE(reader.JumpToByteBoundary());
+  EXPECT_EQ(reader.TotalBitsConsumed(), bits_written);
+  ASSERT_TRUE(reader.Close());
+
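+  // The comparison below is against the already-quantized spline data, not
+  // the original spline_data: quantization loss is expected here, decoding
+  // loss is not.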
+  const std::vector<Spline> decoded_spline_data =
+      DequantizeSplines(decoded_splines);
+  EXPECT_THAT(decoded_spline_data,
+              Pointwise(SplinesMatch(), quantized_spline_data));
+}
+
+#ifdef JXL_CRASH_ON_ERROR
+TEST(SplinesTest, DISABLED_TooManySplinesTest) {
+#else
+TEST(SplinesTest, TooManySplinesTest) {
+#endif
+  // This is more than the limit for 1000 pixels.
+  const size_t kNumSplines = 300;
+
+  std::vector<QuantizedSpline> quantized_splines;
+  std::vector<Spline::Point> starting_points;
+  for (size_t i = 0; i < kNumSplines; i++) {
+    Spline spline = {
+        /*control_points=*/{{1.f + i, 2}, {10.f + i, 25}, {30.f + i, 300}},
+        /*color_dct=*/
+        {{1.f, 0.2f, 0.1f}, {35.7f, 10.3f}, {35.7f, 7.8f}},
+        /*sigma_dct=*/{10.f, 0.f, 0.f, 2.f}};
+    quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX,
+                                   kYToB);
+    starting_points.push_back(spline.control_points.front());
+  }
+
+  Splines splines(kQuantizationAdjustment, std::move(quantized_splines),
+                  std::move(starting_points));
+  BitWriter writer;
+  EncodeSplines(splines, &writer, kLayerSplines,
+                HistogramParams(SpeedTier::kFalcon, 1), nullptr);
+  writer.ZeroPadToByte();
+  // Re-read splines.
+  BitReader reader(writer.GetSpan());
+  Splines decoded_splines;
+  EXPECT_FALSE(decoded_splines.Decode(&reader, /*num_pixels=*/1000));
+  EXPECT_TRUE(reader.Close());
+}
+
+#ifdef JXL_CRASH_ON_ERROR
+TEST(SplinesTest, DISABLED_DuplicatePoints) {
+#else
+TEST(SplinesTest, DuplicatePoints) {
+#endif
+  std::vector<Spline::Point> control_points{
+      {9, 54}, {118, 159}, {97, 3},  // Repeated.
+      {97, 3}, {10, 40}, {150, 25}, {120, 300}};
+  Spline spline{control_points,
+                /*color_dct=*/
+                {{1.f, 0.2f, 0.1f}, {35.7f, 10.3f}, {35.7f, 7.8f}},
+                /*sigma_dct=*/{10.f, 0.f, 0.f, 2.f}};
+  std::vector<Spline> spline_data{spline};
+  std::vector<QuantizedSpline> quantized_splines;
+  std::vector<Spline::Point> starting_points;
+  for (const Spline& spline : spline_data) {
+    quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX,
+                                   kYToB);
+    starting_points.push_back(spline.control_points.front());
+  }
+  Splines splines(kQuantizationAdjustment, std::move(quantized_splines),
+                  std::move(starting_points));
+
+  Image3F image(320, 320);
+  ZeroFillImage(&image);
+  EXPECT_FALSE(splines.AddTo(&image, Rect(image), Rect(image), *cmap));
+}
+
+TEST(SplinesTest, Drawing) {
+  CodecInOut io_expected;
+  const PaddedBytes orig = ReadTestData("jxl/splines.png");
+  ASSERT_TRUE(SetFromBytes(Span<const uint8_t>(orig), &io_expected,
+                           /*pool=*/nullptr));
+
+  std::vector<Spline::Point> control_points{{9, 54}, {118, 159}, {97, 3},
+                                            {10, 40}, {150, 25}, {120, 300}};
+  const Spline spline{
+      control_points,
+      /*color_dct=*/
+      {{0.03125f, 0.00625f, 0.003125f}, {1.f, 0.321875f}, {1.f, 0.24375f}},
+      /*sigma_dct=*/{0.3125f, 0.f, 0.f, 0.0625f}};
+  std::vector<Spline> spline_data = {spline};
+  std::vector<QuantizedSpline> quantized_splines;
+  std::vector<Spline::Point> starting_points;
+  for (const Spline& spline : spline_data) {
+    quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX,
+                                   kYToB);
+    starting_points.push_back(spline.control_points.front());
+  }
+  Splines splines(kQuantizationAdjustment, std::move(quantized_splines),
+                  std::move(starting_points));
+
+  Image3F image(320, 320);
+  ZeroFillImage(&image);
+  ASSERT_TRUE(splines.AddTo(&image, Rect(image), Rect(image), *cmap));
+
+  OpsinParams opsin_params{};
+  opsin_params.Init(kDefaultIntensityTarget);
+  (void)OpsinToLinearInplace(&image, /*pool=*/nullptr, opsin_params);
+
+  CodecInOut io_actual;
+  io_actual.SetFromImage(CopyImage(image), ColorEncoding::LinearSRGB());
+  ASSERT_TRUE(io_actual.TransformTo(io_expected.Main().c_current()));
+
+  VerifyRelativeError(*io_expected.Main().color(), *io_actual.Main().color(),
+                      1e-2f, 1e-1f);
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/test_utils.h b/third_party/jpeg-xl/lib/jxl/test_utils.h
new file mode 100644
index 000000000000..b8fd6870c084
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/test_utils.h
@@ -0,0 +1,350 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_TEST_UTILS_H_
+#define LIB_JXL_TEST_UTILS_H_
+
+// Macros and functions useful for tests.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "jxl/codestream_header.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/codec_in_out.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/dec_file.h"
+#include "lib/jxl/dec_params.h"
+#include "lib/jxl/enc_external_image.h"
+#include "lib/jxl/enc_file.h"
+#include "lib/jxl/enc_params.h"
+
+#ifdef JXL_DISABLE_SLOW_TESTS
+#define JXL_SLOW_TEST(X) DISABLED_##X
+#else
+#define JXL_SLOW_TEST(X) X
+#endif  // JXL_DISABLE_SLOW_TESTS
+
+#ifdef THREAD_SANITIZER
+#define JXL_TSAN_SLOW_TEST(X) DISABLED_##X
+#else
+#define JXL_TSAN_SLOW_TEST(X) X
+#endif  // THREAD_SANITIZER
+
+// googletest before 1.10 didn't define INSTANTIATE_TEST_SUITE_P() but instead
+// used INSTANTIATE_TEST_CASE_P which is now deprecated.
+#ifdef INSTANTIATE_TEST_SUITE_P
+#define JXL_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_SUITE_P
+#else
+#define JXL_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_CASE_P
+#endif
+
+namespace jxl {
+namespace test {
+
+void JxlBasicInfoSetFromPixelFormat(JxlBasicInfo* basic_info,
+                                    const JxlPixelFormat* pixel_format) {
+  switch (pixel_format->data_type) {
+    case JXL_TYPE_FLOAT:
+      basic_info->bits_per_sample = 32;
+      basic_info->exponent_bits_per_sample = 8;
+      break;
+    case JXL_TYPE_FLOAT16:
+      basic_info->bits_per_sample = 16;
+      basic_info->exponent_bits_per_sample = 5;
+      break;
+    case JXL_TYPE_UINT8:
+      basic_info->bits_per_sample = 8;
+      basic_info->exponent_bits_per_sample = 0;
+      break;
+    case JXL_TYPE_UINT16:
+      basic_info->bits_per_sample = 16;
+      basic_info->exponent_bits_per_sample = 0;
+      break;
+    case JXL_TYPE_UINT32:
+      basic_info->bits_per_sample = 32;
+      basic_info->exponent_bits_per_sample = 0;
+      break;
+    case JXL_TYPE_BOOLEAN:
+      basic_info->bits_per_sample = 1;
+      basic_info->exponent_bits_per_sample = 0;
+      break;
+  }
+  if (pixel_format->num_channels == 2 || pixel_format->num_channels == 4) {
+    basic_info->alpha_exponent_bits = 0;
+    if (basic_info->bits_per_sample == 32) {
+      basic_info->alpha_bits = 16;
+    } else {
+      basic_info->alpha_bits = basic_info->bits_per_sample;
+    }
+  } else {
+    basic_info->alpha_exponent_bits = 0;
+    basic_info->alpha_bits = 0;
+  }
+}
+
+MATCHER_P(MatchesPrimariesAndTransferFunction, color_encoding, "") {
+  return arg.primaries == color_encoding.primaries &&
+         arg.tf.IsSame(color_encoding.tf);
+}
+
+MATCHER(MatchesPrimariesAndTransferFunction, "") {
+  return testing::ExplainMatchResult(
+      MatchesPrimariesAndTransferFunction(std::get<1>(arg)), std::get<0>(arg),
+      result_listener);
+}
+
+// Returns compressed size [bytes].
+size_t Roundtrip(const CodecInOut* io, const CompressParams& cparams,
+                 const DecompressParams& dparams, ThreadPool* pool,
+                 CodecInOut* JXL_RESTRICT io2, AuxOut* aux_out = nullptr) {
+  PaddedBytes compressed;
+
+  std::vector<ColorEncoding> original_metadata_encodings;
+  std::vector<ColorEncoding> original_current_encodings;
+  for (const ImageBundle& ib : io->frames) {
+    // Remember original encoding, will be returned by decoder.
+    original_metadata_encodings.push_back(ib.metadata()->color_encoding);
+    // c_current should not change during encoding.
+    original_current_encodings.push_back(ib.c_current());
+  }
+
+  std::unique_ptr<PassesEncoderState> enc_state =
+      jxl::make_unique<PassesEncoderState>();
+  EXPECT_TRUE(
+      EncodeFile(cparams, io, enc_state.get(), &compressed, aux_out, pool));
+
+  std::vector<ColorEncoding> metadata_encodings_1;
+  for (const ImageBundle& ib1 : io->frames) {
+    metadata_encodings_1.push_back(ib1.metadata()->color_encoding);
+  }
+
+  // Should still be in the same color space after encoding.
+  EXPECT_THAT(metadata_encodings_1,
+              testing::Pointwise(MatchesPrimariesAndTransferFunction(),
+                                 original_metadata_encodings));
+
+  EXPECT_TRUE(DecodeFile(dparams, compressed, io2, pool));
+
+  std::vector<ColorEncoding> metadata_encodings_2;
+  std::vector<ColorEncoding> current_encodings_2;
+  for (const ImageBundle& ib2 : io2->frames) {
+    metadata_encodings_2.push_back(ib2.metadata()->color_encoding);
+    current_encodings_2.push_back(ib2.c_current());
+  }
+
+  EXPECT_THAT(io2->frames, testing::SizeIs(io->frames.size()));
+  // We always produce the original color encoding if a color transform hook is
+  // set.
+  EXPECT_THAT(current_encodings_2,
+              testing::Pointwise(MatchesPrimariesAndTransferFunction(),
+                                 original_current_encodings));
+
+  // Decoder returns the originals passed to the encoder.
+  EXPECT_THAT(metadata_encodings_2,
+              testing::Pointwise(MatchesPrimariesAndTransferFunction(),
+                                 original_metadata_encodings));
+
+  return compressed.size();
+}
+
+void CoalesceGIFAnimationWithAlpha(CodecInOut* io) {
+  ImageBundle canvas = io->frames[0].Copy();
+  for (size_t i = 1; i < io->frames.size(); i++) {
+    const ImageBundle& frame = io->frames[i];
+    ImageBundle rendered = canvas.Copy();
+    for (size_t y = 0; y < frame.ysize(); y++) {
+      float* row0 =
+          rendered.color()->PlaneRow(0, frame.origin.y0 + y) + frame.origin.x0;
+      float* row1 =
+          rendered.color()->PlaneRow(1, frame.origin.y0 + y) + frame.origin.x0;
+      float* row2 =
+          rendered.color()->PlaneRow(2, frame.origin.y0 + y) + frame.origin.x0;
+      float* rowa =
+          rendered.alpha()->Row(frame.origin.y0 + y) + frame.origin.x0;
+      const float* row0f = frame.color().PlaneRow(0, y);
+      const float* row1f = frame.color().PlaneRow(1, y);
+      const float* row2f = frame.color().PlaneRow(2, y);
+      const float* rowaf = frame.alpha().Row(y);
+      for (size_t x = 0; x < frame.xsize(); x++) {
+        if (rowaf[x] != 0) {
+          row0[x] = row0f[x];
+          row1[x] = row1f[x];
+          row2[x] = row2f[x];
+          rowa[x] = rowaf[x];
+        }
+      }
+    }
+    if (frame.use_for_next_frame) {
+      canvas = rendered.Copy();
+    }
+    io->frames[i] = std::move(rendered);
+  }
+}
+
+// A POD descriptor of a ColorEncoding. Only used in tests as the return value
+// of AllEncodings().
+struct ColorEncodingDescriptor {
+  ColorSpace color_space;
+  WhitePoint white_point;
+  Primaries primaries;
+  TransferFunction tf;
+  RenderingIntent rendering_intent;
+};
+
+static inline ColorEncoding ColorEncodingFromDescriptor(
+    const ColorEncodingDescriptor& desc) {
+  ColorEncoding c;
+  c.SetColorSpace(desc.color_space);
+  c.white_point = desc.white_point;
+  c.primaries = desc.primaries;
+  c.tf.SetTransferFunction(desc.tf);
+  c.rendering_intent = desc.rendering_intent;
+  return c;
+}
+
+// Define the operator<< for tests.
+static inline ::std::ostream& operator<<(::std::ostream& os,
+                                         const ColorEncodingDescriptor& c) {
+  return os << "ColorEncoding/" << Description(ColorEncodingFromDescriptor(c));
+}
+
+// Returns ColorEncodingDescriptors, which are only used in tests. To obtain a
+// ColorEncoding object call ColorEncodingFromDescriptor and then call
+// ColorEncoding::CreateProfile() on that object to generate a profile.
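+//
+// Illustrative only (per the note above, not exercised in this header):
+//   ColorEncoding c = ColorEncodingFromDescriptor(AllEncodings().front());
+//   c.CreateProfile();  // generate the ICC profile before first use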
+std::vector<ColorEncodingDescriptor> AllEncodings() {
+  std::vector<ColorEncodingDescriptor> all_encodings;
+  all_encodings.reserve(300);
+  ColorEncoding c;
+
+  for (ColorSpace cs : Values<ColorSpace>()) {
+    if (cs == ColorSpace::kUnknown || cs == ColorSpace::kXYB) continue;
+    c.SetColorSpace(cs);
+
+    for (WhitePoint wp : Values<WhitePoint>()) {
+      if (wp == WhitePoint::kCustom) continue;
+      if (c.ImplicitWhitePoint() && c.white_point != wp) continue;
+      c.white_point = wp;
+
+      for (Primaries primaries : Values<Primaries>()) {
+        if (primaries == Primaries::kCustom) continue;
+        if (!c.HasPrimaries()) continue;
+        c.primaries = primaries;
+
+        for (TransferFunction tf : Values<TransferFunction>()) {
+          if (tf == TransferFunction::kUnknown) continue;
+          if (c.tf.SetImplicit() &&
+              (c.tf.IsGamma() || c.tf.GetTransferFunction() != tf)) {
+            continue;
+          }
+          c.tf.SetTransferFunction(tf);
+
+          for (RenderingIntent ri : Values<RenderingIntent>()) {
+            ColorEncodingDescriptor cdesc;
+            cdesc.color_space = cs;
+            cdesc.white_point = wp;
+            cdesc.primaries = primaries;
+            cdesc.tf = tf;
+            cdesc.rendering_intent = ri;
+            all_encodings.push_back(cdesc);
+          }
+        }
+      }
+    }
+  }
+
+  return all_encodings;
+}
+
+// Returns a test image with some autogenerated pixel content, using 16 bits
+// per channel, big endian order, 1 to 4 channels.
+// The seed parameter allows creating images with different pixel content.
+std::vector<uint8_t> GetSomeTestImage(size_t xsize, size_t ysize,
+                                      size_t num_channels, uint16_t seed) {
+  // Cause more significant image difference for successive seeds.
+  seed = static_cast<uint16_t>(seed * 77);
+  size_t num_pixels = xsize * ysize;
+  // 16 bits per channel, big endian, 4 channels
+  std::vector<uint8_t> pixels(num_pixels * num_channels * 2);
+  // Create pixel content to test, actual content does not matter as long as it
+  // can be compared after roundtrip.
+  for (size_t y = 0; y < ysize; y++) {
+    for (size_t x = 0; x < xsize; x++) {
+      uint16_t r = (65535 - x * y) ^ seed;
+      uint16_t g = (x << 8) + y + seed;
+      uint16_t b = (y << 8) + x * seed;
+      uint16_t a = 32768 + x * 256 - y;
+      // put some shape in there for visual debugging
+      if (x * x + y * y < 1000) {
+        std::swap(r, g);
+        b = 0;
+      }
+      size_t i = (y * xsize + x) * 2 * num_channels;
+      pixels[i + 0] = (r >> 8);
+      pixels[i + 1] = (r & 255);
+      if (num_channels >= 2) {
+        // This may store what is called 'g' in the alpha channel of a
+        // 2-channel image, but that's ok since the content is arbitrary.
+        pixels[i + 2] = (g >> 8);
+        pixels[i + 3] = (g & 255);
+      }
+      if (num_channels >= 3) {
+        pixels[i + 4] = (b >> 8);
+        pixels[i + 5] = (b & 255);
+      }
+      if (num_channels >= 4) {
+        pixels[i + 6] = (a >> 8);
+        pixels[i + 7] = (a & 255);
+      }
+    }
+  }
+  return pixels;
+}
+
+// Returns a CodecInOut based on the buf, xsize, ysize, and the assumption
+// that the buffer was created using `GetSomeTestImage`.
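+// (For a 2-channel buffer this means gray + alpha: ConvertFromExternal below
+// is called with is_gray == true and has_alpha == true.)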
+jxl::CodecInOut SomeTestImageToCodecInOut(const std::vector<uint8_t>& buf,
+                                          size_t num_channels, size_t xsize,
+                                          size_t ysize) {
+  jxl::CodecInOut io;
+  io.SetSize(xsize, ysize);
+  io.metadata.m.SetAlphaBits(16);
+  io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(
+      /*is_gray=*/num_channels == 1 || num_channels == 2);
+  EXPECT_TRUE(ConvertFromExternal(
+      jxl::Span<const uint8_t>(buf.data(), buf.size()), xsize, ysize,
+      jxl::ColorEncoding::SRGB(/*is_gray=*/num_channels == 1 ||
+                               num_channels == 2),
+      /*has_alpha=*/num_channels == 2 || num_channels == 4,
+      /*alpha_is_premultiplied=*/false, /*bits_per_sample=*/16, JXL_BIG_ENDIAN,
+      /*flipped_y=*/false, /*pool=*/nullptr,
+      /*ib=*/&io.Main()));
+  return io;
+}
+
+}  // namespace test
+
+bool operator==(const jxl::PaddedBytes& a, const jxl::PaddedBytes& b) {
+  if (a.size() != b.size()) return false;
+  if (memcmp(a.data(), b.data(), a.size()) != 0) return false;
+  return true;
+}
+
+// Allow using EXPECT_EQ on jxl::PaddedBytes
+bool operator!=(const jxl::PaddedBytes& a, const jxl::PaddedBytes& b) {
+  return !(a == b);
+}
+}  // namespace jxl
+
+#endif  // LIB_JXL_TEST_UTILS_H_
diff --git a/third_party/jpeg-xl/lib/jxl/testdata.h b/third_party/jpeg-xl/lib/jxl/testdata.h
new file mode 100644
index 000000000000..cf49e81d0eaf
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/testdata.h
@@ -0,0 +1,69 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_TESTDATA_H_
+#define LIB_JXL_TESTDATA_H_
+
+#ifdef __EMSCRIPTEN__
+#include <emscripten.h>
+#endif
+
+#include <string>
+
+#include "lib/jxl/base/file_io.h"
+
+namespace jxl {
+
+static inline PaddedBytes ReadTestData(const std::string& filename) {
+  std::string full_path = std::string(TEST_DATA_PATH "/") + filename;
+  PaddedBytes data;
+  bool ok = ReadFile(full_path, &data);
+#ifdef __EMSCRIPTEN__
+  // Fallback in case FS is not supported in current JS engine.
+  if (!ok) {
+    // {size_t size, uint8_t* bytes} pair.
+    uint32_t size_bytes[2] = {0, 0};
+    EM_ASM(
+        {
+          let buffer = null;
+          try {
+            buffer = readbuffer(UTF8ToString($0));
+          } catch {
+          }
+          if (!buffer) return;
+          let bytes = new Uint8Array(buffer);
+          let size = bytes.length;
+          let out = _malloc(size);
+          if (!out) return;
+          HEAP8.set(bytes, out);
+          HEAP32[$1 >> 2] = size;
+          HEAP32[($1 + 4) >> 2] = out;
+        },
+        full_path.c_str(), size_bytes);
+    size_t size = size_bytes[0];
+    uint8_t* bytes = reinterpret_cast<uint8_t*>(size_bytes[1]);
+    if (size) {
+      data.append(bytes, bytes + size);
+      free(reinterpret_cast<void*>(bytes));
+      ok = true;
+    }
+  }
+#endif
+  JXL_CHECK(ok);
+  return data;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_TESTDATA_H_
diff --git a/third_party/jpeg-xl/lib/jxl/tf_gbench.cc b/third_party/jpeg-xl/lib/jxl/tf_gbench.cc
new file mode 100644
index 000000000000..e56fdb62408b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/tf_gbench.cc
@@ -0,0 +1,152 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "benchmark/benchmark.h"
+#include "lib/jxl/image_ops.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/tf_gbench.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/transfer_functions-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+#define RUN_BENCHMARK(F)                                              \
+  constexpr size_t kNum = 1 << 12;                                    \
+  HWY_FULL(float) d;                                                  \
+  /* Three parallel runs, as this will run on R, G and B. */          \
+  auto sum1 = Zero(d);                                                \
+  auto sum2 = Zero(d);                                                \
+  auto sum3 = Zero(d);                                                \
+  for (auto _ : state) {                                              \
+    auto x = Set(d, 1e-5);                                            \
+    auto v1 = Set(d, 1e-5);                                           \
+    auto v2 = Set(d, 1.1e-5);                                         \
+    auto v3 = Set(d, 1.2e-5);                                         \
+    for (size_t i = 0; i < kNum; i++) {                               \
+      sum1 += F(d, v1);                                               \
+      sum2 += F(d, v2);                                               \
+      sum3 += F(d, v3);                                               \
+      v1 += x;                                                        \
+      v2 += x;                                                        \
+      v3 += x;                                                        \
+    }                                                                 \
+  }                                                                   \
+  /* floats per second */                                             \
+  state.SetItemsProcessed(kNum * state.iterations() * Lanes(d) * 3);  \
+  benchmark::DoNotOptimize(sum1 + sum2 + sum3);
+
+#define RUN_BENCHMARK_SCALAR(F)                                       \
+  constexpr size_t kNum = 1 << 12;                                    \
+  /* Three parallel runs, as this will run on R, G and B. */          \
+  float sum1 = 0, sum2 = 0, sum3 = 0;                                 \
+  for (auto _ : state) {                                              \
+    float x = 1e-5;                                                   \
+    float v1 = 1e-5;                                                  \
+    float v2 = 1.1e-5;                                                \
+    float v3 = 1.2e-5;                                                \
+    for (size_t i = 0; i < kNum; i++) {                               \
+      sum1 += F(v1);                                                  \
+      sum2 += F(v2);                                                  \
+      sum3 += F(v3);                                                  \
+      v1 += x;                                                        \
+      v2 += x;                                                        \
+      v3 += x;                                                        \
+    }                                                                 \
+  }                                                                   \
+  /* floats per second */                                             \
+  state.SetItemsProcessed(kNum * state.iterations() * 3);             \
+  benchmark::DoNotOptimize(sum1 + sum2 + sum3);
+
+HWY_NOINLINE void BM_FastSRGB(benchmark::State& state) {
+  RUN_BENCHMARK(FastLinearToSRGB);
+}
+
+HWY_NOINLINE void BM_TFSRGB(benchmark::State& state) {
+  RUN_BENCHMARK(TF_SRGB().EncodedFromDisplay);
+}
+
+HWY_NOINLINE void BM_PQDFE(benchmark::State& state) {
+  RUN_BENCHMARK(TF_PQ().DisplayFromEncoded);
+}
+
+HWY_NOINLINE void BM_PQEFD(benchmark::State& state) {
+  RUN_BENCHMARK(TF_PQ().EncodedFromDisplay);
+}
+
+HWY_NOINLINE void BM_PQSlowDFE(benchmark::State& state) {
+  RUN_BENCHMARK_SCALAR(TF_PQ().DisplayFromEncoded);
+}
+
+HWY_NOINLINE void BM_PQSlowEFD(benchmark::State& state) {
+  RUN_BENCHMARK_SCALAR(TF_PQ().EncodedFromDisplay);
+}
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+namespace {
+
+HWY_EXPORT(BM_FastSRGB);
+HWY_EXPORT(BM_TFSRGB);
+HWY_EXPORT(BM_PQDFE);
+HWY_EXPORT(BM_PQEFD);
+HWY_EXPORT(BM_PQSlowDFE);
+HWY_EXPORT(BM_PQSlowEFD);
+
+float SRGB_pow(float x) {
+  return x < 0.0031308f ? 12.92f * x : 1.055f * powf(x, 1.0f / 2.4f) - 0.055f;
+}
+
+void BM_FastSRGB(benchmark::State& state) {
+  HWY_DYNAMIC_DISPATCH(BM_FastSRGB)(state);
+}
+void BM_TFSRGB(benchmark::State& state) {
+  HWY_DYNAMIC_DISPATCH(BM_TFSRGB)(state);
+}
+void BM_PQDFE(benchmark::State& state) {
+  HWY_DYNAMIC_DISPATCH(BM_PQDFE)(state);
+}
+void BM_PQEFD(benchmark::State& state) {
+  HWY_DYNAMIC_DISPATCH(BM_PQEFD)(state);
+}
+void BM_PQSlowDFE(benchmark::State& state) {
+  HWY_DYNAMIC_DISPATCH(BM_PQSlowDFE)(state);
+}
+void BM_PQSlowEFD(benchmark::State& state) {
+  HWY_DYNAMIC_DISPATCH(BM_PQSlowEFD)(state);
+}
+
+void BM_SRGB_pow(benchmark::State& state) { RUN_BENCHMARK_SCALAR(SRGB_pow); }
+
+BENCHMARK(BM_FastSRGB);
+BENCHMARK(BM_TFSRGB);
+BENCHMARK(BM_SRGB_pow);
+BENCHMARK(BM_PQDFE);
+BENCHMARK(BM_PQEFD);
+BENCHMARK(BM_PQSlowDFE);
+BENCHMARK(BM_PQSlowEFD);
+
+}  // namespace
+}  // namespace jxl
+#endif
diff --git a/third_party/jpeg-xl/lib/jxl/toc.cc b/third_party/jpeg-xl/lib/jxl/toc.cc
new file mode 100644
index 000000000000..35c2b7a6b1c5
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/toc.cc
@@ -0,0 +1,106 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/toc.h"
+
+#include <stdint.h>
+
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/coeff_order.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/fields.h"
+
+namespace jxl {
+size_t MaxBits(const size_t num_sizes) {
+  const size_t entry_bits = U32Coder::MaxEncodedBits(kTocDist) * num_sizes;
+  // permutation bit (not its tokens!), padding, entries, padding.
+  return 1 + kBitsPerByte + entry_bits + kBitsPerByte;
+}
+
+Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader,
+                        std::vector<uint64_t>* JXL_RESTRICT offsets,
+                        std::vector<uint32_t>* JXL_RESTRICT sizes,
+                        uint64_t* total_size) {
+  if (toc_entries > 65536) {
+    // Prevent out-of-memory errors when an invalid JXL codestream yields a
+    // bogus toc_entries value such as 2720436919446.
+    // TODO(lode): verify whether 65536 is a reasonable upper bound
+    return JXL_FAILURE("too many toc entries");
+  }
+
+  const auto check_bit_budget = [&](size_t num_entries) -> Status {
+    // U32Coder reads 2 bits to recognize the variant, and the cheapest
+    // kTocDist variant is Bits(10), so at least 12 bits are required per
+    // TOC entry.
+    size_t minimal_bit_cost = num_entries * (2 + 10);
+    size_t bit_budget = reader->TotalBytes() * 8;
+    size_t expenses = reader->TotalBitsConsumed();
+    if ((expenses <= bit_budget) &&
+        (minimal_bit_cost <= bit_budget - expenses)) {
+      return true;
+    }
+    return JXL_STATUS(StatusCode::kNotEnoughBytes, "Not enough bytes for TOC");
+  };
+
+  JXL_DASSERT(offsets != nullptr && sizes != nullptr);
+  std::vector<coeff_order_t> permutation;
+  if (reader->ReadFixedBits<1>() == 1 && toc_entries > 0) {
+    // The permutation description is skipped entirely when toc_entries is 0.
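+    // A set permutation bit means the TOC entries were stored in a custom
+    // order; the permutation decoded here is applied to offsets and sizes at
+    // the end of this function.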
+    JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries));
+    permutation.resize(toc_entries);
+    JXL_RETURN_IF_ERROR(
+        DecodePermutation(/*skip=*/0, toc_entries, permutation.data(), reader));
+  }
+
+  JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary());
+  JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries));
+  sizes->clear();
+  sizes->reserve(toc_entries);
+  for (size_t i = 0; i < toc_entries; ++i) {
+    sizes->push_back(U32Coder::Read(kTocDist, reader));
+  }
+  JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary());
+  JXL_RETURN_IF_ERROR(check_bit_budget(0));
+
+  // Prefix sum starting with 0 and ending with the offset of the last group.
+  offsets->clear();
+  offsets->reserve(toc_entries);
+  uint64_t offset = 0;
+  for (size_t i = 0; i < toc_entries; ++i) {
+    if (offset + (*sizes)[i] < offset) {
+      return JXL_FAILURE("group offset overflow");
+    }
+    offsets->push_back(offset);
+    offset += (*sizes)[i];
+  }
+  if (total_size) {
+    *total_size = offset;
+  }
+
+  if (!permutation.empty()) {
+    std::vector<uint64_t> permuted_offsets;
+    std::vector<uint32_t> permuted_sizes;
+    permuted_offsets.reserve(toc_entries);
+    permuted_sizes.reserve(toc_entries);
+    for (coeff_order_t index : permutation) {
+      permuted_offsets.push_back((*offsets)[index]);
+      permuted_sizes.push_back((*sizes)[index]);
+    }
+    std::swap(*offsets, permuted_offsets);
+    std::swap(*sizes, permuted_sizes);
+  }
+
+  return true;
+}
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/toc.h b/third_party/jpeg-xl/lib/jxl/toc.h
new file mode 100644
index 000000000000..6955e19837a9
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/toc.h
@@ -0,0 +1,59 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_TOC_H_
+#define LIB_JXL_TOC_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dec_bit_reader.h"
+#include "lib/jxl/field_encodings.h"
+
+namespace jxl {
+
+// (2+bits) = 2,3,4 bytes so encoders can patch TOC after encoding.
+// 30 is sufficient for 4K channels of uncompressed 16-bit samples.
+constexpr U32Enc kTocDist(Bits(10), BitsOffset(14, 1024), BitsOffset(22, 17408),
+                          BitsOffset(30, 4211712));
+
+size_t MaxBits(const size_t num_sizes);
+
+// TODO(veluca): move these to FrameDimensions.
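+// TOC entry order, as computed by the helpers below: DC global, then
+// num_dc_groups DC group entries, then (if present) AC global, then
+// num_passes * num_groups AC group entries.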
+static JXL_INLINE size_t AcGroupIndex(size_t pass, size_t group,
+                                      size_t num_groups, size_t num_dc_groups,
+                                      bool has_ac_global) {
+  return 1 + num_dc_groups + static_cast<size_t>(has_ac_global) +
+         pass * num_groups + group;
+}
+
+static JXL_INLINE size_t NumTocEntries(size_t num_groups, size_t num_dc_groups,
+                                       size_t num_passes, bool has_ac_global) {
+  if (num_groups == 1 && num_passes == 1) return 1;
+  return AcGroupIndex(0, 0, num_groups, num_dc_groups, has_ac_global) +
+         num_groups * num_passes;
+}
+
+Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader,
+                        std::vector<uint64_t>* JXL_RESTRICT offsets,
+                        std::vector<uint32_t>* JXL_RESTRICT sizes,
+                        uint64_t* total_size);
+
+}  // namespace jxl
+
+#endif  // LIB_JXL_TOC_H_
diff --git a/third_party/jpeg-xl/lib/jxl/toc_test.cc b/third_party/jpeg-xl/lib/jxl/toc_test.cc
new file mode 100644
index 000000000000..e235214ed895
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/toc_test.cc
@@ -0,0 +1,102 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/jxl/toc.h"
+
+#include <random>
+
+#include "gtest/gtest.h"
+#include "lib/jxl/aux_out_fwd.h"
+#include "lib/jxl/base/span.h"
+#include "lib/jxl/common.h"
+#include "lib/jxl/enc_toc.h"
+
+namespace jxl {
+namespace {
+
+void Roundtrip(size_t num_entries, bool permute, std::mt19937* rng) {
+  // Generate a random permutation.
+  std::vector<coeff_order_t> permutation(num_entries);
+  std::vector<size_t> inv_permutation(num_entries);
+  for (size_t i = 0; i < num_entries; i++) {
+    permutation[i] = i;
+    inv_permutation[i] = i;
+  }
+  if (permute) {
+    std::shuffle(permutation.begin(), permutation.end(), *rng);
+    for (size_t i = 0; i < num_entries; i++) {
+      inv_permutation[permutation[i]] = i;
+    }
+  }
+
+  // Generate num_entries groups of random (byte-aligned) length.
+  std::vector<BitWriter> group_codes(num_entries);
+  for (BitWriter& writer : group_codes) {
+    const size_t max_bits = (*rng)() & 0xFFF;
+    BitWriter::Allotment allotment(&writer, max_bits + kBitsPerByte);
+    size_t i = 0;
+    for (; i + BitWriter::kMaxBitsPerCall < max_bits;
+         i += BitWriter::kMaxBitsPerCall) {
+      writer.Write(BitWriter::kMaxBitsPerCall, 0);
+    }
+    for (; i < max_bits; i += 1) {
+      writer.Write(/*n_bits=*/1, 0);
+    }
+    writer.ZeroPadToByte();
+    AuxOut aux_out;
+    ReclaimAndCharge(&writer, &allotment, 0, &aux_out);
+  }
+
+  BitWriter writer;
+  AuxOut aux_out;
+  ASSERT_TRUE(WriteGroupOffsets(group_codes, permute ? &permutation : nullptr,
+                                &writer, &aux_out));
+
+  BitReader reader(writer.GetSpan());
+  std::vector<uint64_t> group_offsets;
+  std::vector<uint32_t> group_sizes;
+  uint64_t total_size;
+  ASSERT_TRUE(ReadGroupOffsets(num_entries, &reader, &group_offsets,
+                               &group_sizes, &total_size));
+  ASSERT_EQ(num_entries, group_offsets.size());
+  ASSERT_EQ(num_entries, group_sizes.size());
+  EXPECT_TRUE(reader.Close());
+
+  uint64_t prefix_sum = 0;
+  for (size_t i = 0; i < num_entries; ++i) {
+    EXPECT_EQ(prefix_sum, group_offsets[inv_permutation[i]]);
+
+    EXPECT_EQ(0, group_codes[i].BitsWritten() % kBitsPerByte);
+    prefix_sum += group_codes[i].BitsWritten() / kBitsPerByte;
+
+    if (i + 1 < num_entries) {
+      EXPECT_EQ(
+          group_offsets[inv_permutation[i]] + group_sizes[inv_permutation[i]],
+          group_offsets[inv_permutation[i + 1]]);
+    }
+  }
+  EXPECT_EQ(prefix_sum, total_size);
+}
+
+TEST(TocTest, Test) {
+  std::mt19937 rng(12345);
+  for (size_t num_entries = 0; num_entries < 10; ++num_entries) {
+    for (bool permute : std::vector<bool>{false, true}) {
+      Roundtrip(num_entries, permute, &rng);
+    }
+  }
+}
+
+}  // namespace
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/lib/jxl/transfer_functions-inl.h b/third_party/jpeg-xl/lib/jxl/transfer_functions-inl.h
new file mode 100644
index 000000000000..b2cd83aff490
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/transfer_functions-inl.h
@@ -0,0 +1,364 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Transfer functions for color encodings.
+
+#if defined(LIB_JXL_TRANSFER_FUNCTIONS_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_TRANSFER_FUNCTIONS_INL_H_
+#undef LIB_JXL_TRANSFER_FUNCTIONS_INL_H_
+#else
+#define LIB_JXL_TRANSFER_FUNCTIONS_INL_H_
+#endif
+
+#include <algorithm>
+#include <cmath>
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/rational_polynomial-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// Definitions for BT.2100-2 transfer functions (used inside/outside SIMD):
+// "display" is linear light (nits) normalized to [0, 1].
+// "encoded" is a nonlinear encoding (e.g. PQ) in [0, 1].
+// "scene" is a linear function of photon counts, normalized to [0, 1].
+
+// Despite the stated ranges, we need unbounded transfer functions: see
+// http://www.littlecms.com/CIC18_UnboundedCMM.pdf. Inputs can be negative or
+// above 1 due to chromatic adaptation. To avoid severe round-trip errors
+// caused by clamping, we mirror negative inputs via copysign (f(-x) = -f(x),
+// see
+// https://developer.apple.com/documentation/coregraphics/cgcolorspace/1644735-extendedsrgb)
+// and extend the function domains above 1.
+
+// Hybrid Log-Gamma.
+class TF_HLG {
+ public:
+  // EOTF. e = encoded.
+  JXL_INLINE double DisplayFromEncoded(const double e) const {
+    const double lifted = e * (1.0 - kBeta) + kBeta;
+    return OOTF(InvOETF(lifted));
+  }
+
+  // Inverse EOTF. d = display.
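+  // Undoes the kBeta "lift" applied in DisplayFromEncoded:
+  // e = (OETF(InvOOTF(d)) - kBeta) / (1 - kBeta).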
+  JXL_INLINE double EncodedFromDisplay(const double d) const {
+    const double lifted = OETF(InvOOTF(d));
+    const double e = (lifted - kBeta) * (1.0 / (1.0 - kBeta));
+    return e;
+  }
+
+ private:
+  // OETF (defines the HLG approach). s = scene, returns encoded.
+  JXL_INLINE double OETF(double s) const {
+    if (s == 0.0) return 0.0;
+    const double original_sign = s;
+    s = std::abs(s);
+
+    if (s <= kDiv12) return copysignf(std::sqrt(3.0 * s), original_sign);
+
+    const double e = kA * std::log(12 * s - kB) + kC;
+    JXL_ASSERT(e > 0.0);
+    return copysignf(e, original_sign);
+  }
+
+  // e = encoded, returns scene.
+  JXL_INLINE double InvOETF(double e) const {
+    if (e == 0.0) return 0.0;
+    const double original_sign = e;
+    e = std::abs(e);
+
+    if (e <= 0.5) return copysignf(e * e * (1.0 / 3), original_sign);
+
+    const double s = (std::exp((e - kC) * kRA) + kB) * kDiv12;
+    JXL_ASSERT(s >= 0);
+    return copysignf(s, original_sign);
+  }
+
+  // s = scene, returns display.
+  JXL_INLINE double OOTF(const double s) const {
+    // The actual (red channel) OOTF is RD = alpha * YS^(gamma-1) * RS, where
+    // YS = 0.2627 * RS + 0.6780 * GS + 0.0593 * BS. Let alpha = 1 so we return
+    // "display" (normalized [0, 1]) instead of nits. Our transfer function
+    // interface does not allow a dependency on YS. Fortunately, the system
+    // gamma at 334 nits is 1.0, so this reduces to RD = RS.
+    return s;
+  }
+
+  // d = display, returns scene.
+  JXL_INLINE double InvOOTF(const double d) const {
+    return d;  // see OOTF().
+  }
+
+  // Assume 1000:1 contrast @ 200 nits => gamma 0.9
+  static constexpr double kBeta = 0.04;  // = sqrt(3 * contrast^(1/gamma))
+
+  static constexpr double kA = 0.17883277;
+  static constexpr double kRA = 1.0 / kA;
+  static constexpr double kB = 1 - 4 * kA;
+  static constexpr double kC = 0.5599107295;
+  static constexpr double kDiv12 = 1.0 / 12;
+};
+
+// Perceptual Quantization
+class TF_PQ {
+ public:
+  // EOTF (defines the PQ approach). e = encoded.
+  JXL_INLINE double DisplayFromEncoded(double e) const {
+    if (e == 0.0) return 0.0;
+    const double original_sign = e;
+    e = std::abs(e);
+
+    const double xp = std::pow(e, 1.0 / kM2);
+    const double num = std::max(xp - kC1, 0.0);
+    const double den = kC2 - kC3 * xp;
+    JXL_DASSERT(den != 0.0);
+    const double d = std::pow(num / den, 1.0 / kM1);
+    JXL_DASSERT(d >= 0.0);  // Equal for e ~= 1E-9
+    return copysignf(d, original_sign);
+  }
+
+  // Maximum error 3e-6
+  template <class D, class V>
+  JXL_INLINE V DisplayFromEncoded(D d, V x) const {
+    const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du;
+    const V kSign = BitCast(d, Set(du, 0x80000000u));
+    const V original_sign = And(x, kSign);
+    x = AndNot(kSign, x);  // abs
+    // 4-over-4-degree rational polynomial approximation on x+x*x. This
+    // improves the maximum error by about 5x over a rational polynomial for x.
+    auto xpxx = MulAdd(x, x, x);
+    HWY_ALIGN constexpr float p[(4 + 1) * 4] = {
+        HWY_REP4(2.62975656e-04f), HWY_REP4(-6.23553089e-03f),
+        HWY_REP4(7.38602301e-01f), HWY_REP4(2.64553172e+00f),
+        HWY_REP4(5.50034862e-01f),
+    };
+    HWY_ALIGN constexpr float q[(4 + 1) * 4] = {
+        HWY_REP4(4.21350107e+02f), HWY_REP4(-4.28736818e+02f),
+        HWY_REP4(1.74364667e+02f), HWY_REP4(-3.39078883e+01f),
+        HWY_REP4(2.67718770e+00f),
+    };
+    auto magnitude = EvalRationalPolynomial(d, xpxx, p, q);
+    return Or(AndNot(kSign, magnitude), original_sign);
+  }
+
+  // Inverse EOTF. d = display.
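+  // e = ((c1 + c2 * d^m1) / (1 + c3 * d^m1))^m2, mirrored for negative d
+  // (f(-x) = -f(x)).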
+  JXL_INLINE double EncodedFromDisplay(double d) const {
+    if (d == 0.0) return 0.0;
+    const double original_sign = d;
+    d = std::abs(d);
+
+    const double xp = std::pow(d, kM1);
+    const double num = kC1 + xp * kC2;
+    const double den = 1.0 + xp * kC3;
+    const double e = std::pow(num / den, kM2);
+    JXL_DASSERT(e > 0.0);
+    return copysignf(e, original_sign);
+  }
+
+  // Maximum error 7e-7.
+  template <class D, class V>
+  JXL_INLINE V EncodedFromDisplay(D d, V x) const {
+    const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du;
+    const V kSign = BitCast(d, Set(du, 0x80000000u));
+    const V original_sign = And(x, kSign);
+    x = AndNot(kSign, x);  // abs
+    // 4-over-4-degree rational polynomial approximation on x**0.25, with two
+    // different polynomials above and below 1e-4.
+    auto xto025 = Sqrt(Sqrt(x));
+    HWY_ALIGN constexpr float p[(4 + 1) * 4] = {
+        HWY_REP4(1.351392e-02f), HWY_REP4(-1.095778e+00f),
+        HWY_REP4(5.522776e+01f), HWY_REP4(1.492516e+02f),
+        HWY_REP4(4.838434e+01f),
+    };
+    HWY_ALIGN constexpr float q[(4 + 1) * 4] = {
+        HWY_REP4(1.012416e+00f), HWY_REP4(2.016708e+01f),
+        HWY_REP4(9.263710e+01f), HWY_REP4(1.120607e+02f),
+        HWY_REP4(2.590418e+01f),
+    };
+
+    HWY_ALIGN constexpr float plo[(4 + 1) * 4] = {
+        HWY_REP4(9.863406e-06f), HWY_REP4(3.881234e-01f),
+        HWY_REP4(1.352821e+02f), HWY_REP4(6.889862e+04f),
+        HWY_REP4(-2.864824e+05f),
+    };
+    HWY_ALIGN constexpr float qlo[(4 + 1) * 4] = {
+        HWY_REP4(3.371868e+01f), HWY_REP4(1.477719e+03f),
+        HWY_REP4(1.608477e+04f), HWY_REP4(-4.389884e+04f),
+        HWY_REP4(-2.072546e+05f),
+    };
+
+    auto magnitude = IfThenElse(x < Set(d, 1e-4f),
+                                EvalRationalPolynomial(d, xto025, plo, qlo),
+                                EvalRationalPolynomial(d, xto025, p, q));
+    return Or(AndNot(kSign, magnitude), original_sign);
+  }
+
+ private:
+  static constexpr double kM1 = 2610.0 / 16384;
+  static constexpr double kM2 = (2523.0 / 4096) * 128;
+  static constexpr double kC1 = 3424.0 / 4096;
+  static constexpr double kC2 = (2413.0 / 4096) * 32;
+  static constexpr double kC3 = (2392.0 / 4096) * 32;
+};
+
+// sRGB
+class TF_SRGB {
+ public:
+  template <class V>
+  JXL_INLINE V DisplayFromEncoded(V x) const {
+    const HWY_FULL(float) d;
+    const HWY_FULL(uint32_t) du;
+    const V kSign = BitCast(d, Set(du, 0x80000000u));
+    const V original_sign = And(x, kSign);
+    x = AndNot(kSign, x);  // abs
+
+    // TODO(janwas): range reduction
+    // Computed via af_cheb_rational (k=100); replicated 4x.
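+    // The polynomial branch approximates the sRGB EOTF
+    // ((x + 0.055) / 1.055)^2.4 above kThreshSRGBToLinear; below it, the
+    // exact linear branch x * kLowDivInv (i.e. x / 12.92) is used.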
+
+// sRGB
+class TF_SRGB {
+ public:
+  template <class V>
+  JXL_INLINE V DisplayFromEncoded(V x) const {
+    const HWY_FULL(float) d;
+    const HWY_FULL(uint32_t) du;
+    const V kSign = BitCast(d, Set(du, 0x80000000u));
+    const V original_sign = And(x, kSign);
+    x = AndNot(kSign, x);  // abs
+
+    // TODO(janwas): range reduction
+    // Computed via af_cheb_rational (k=100); replicated 4x.
+    HWY_ALIGN constexpr float p[(4 + 1) * 4] = {
+        2.200248328e-04f, 2.200248328e-04f, 2.200248328e-04f, 2.200248328e-04f,
+        1.043637593e-02f, 1.043637593e-02f, 1.043637593e-02f, 1.043637593e-02f,
+        1.624820318e-01f, 1.624820318e-01f, 1.624820318e-01f, 1.624820318e-01f,
+        7.961564959e-01f, 7.961564959e-01f, 7.961564959e-01f, 7.961564959e-01f,
+        8.210152774e-01f, 8.210152774e-01f, 8.210152774e-01f, 8.210152774e-01f,
+    };
+    HWY_ALIGN constexpr float q[(4 + 1) * 4] = {
+        2.631846970e-01f,  2.631846970e-01f,  2.631846970e-01f,
+        2.631846970e-01f,  1.076976492e+00f,  1.076976492e+00f,
+        1.076976492e+00f,  1.076976492e+00f,  4.987528350e-01f,
+        4.987528350e-01f,  4.987528350e-01f,  4.987528350e-01f,
+        -5.512498495e-02f, -5.512498495e-02f, -5.512498495e-02f,
+        -5.512498495e-02f, 6.521209011e-03f,  6.521209011e-03f,
+        6.521209011e-03f,  6.521209011e-03f,
+    };
+    const V linear = x * Set(d, kLowDivInv);
+    const V poly = EvalRationalPolynomial(d, x, p, q);
+    const V magnitude =
+        IfThenElse(x > Set(d, kThreshSRGBToLinear), poly, linear);
+    return Or(AndNot(kSign, magnitude), original_sign);
+  }
+
+  // Error ~5e-07
+  template <class D, class V>
+  JXL_INLINE V EncodedFromDisplay(D d, V x) const {
+    const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du;
+    const V kSign = BitCast(d, Set(du, 0x80000000u));
+    const V original_sign = And(x, kSign);
+    x = AndNot(kSign, x);  // abs
+
+    // Computed via af_cheb_rational (k=100); replicated 4x.
+    HWY_ALIGN constexpr float p[(4 + 1) * 4] = {
+        -5.135152395e-04f, -5.135152395e-04f, -5.135152395e-04f,
+        -5.135152395e-04f, 5.287254571e-03f,  5.287254571e-03f,
+        5.287254571e-03f,  5.287254571e-03f,  3.903842876e-01f,
+        3.903842876e-01f,  3.903842876e-01f,  3.903842876e-01f,
+        1.474205315e+00f,  1.474205315e+00f,  1.474205315e+00f,
+        1.474205315e+00f,  7.352629620e-01f,  7.352629620e-01f,
+        7.352629620e-01f,  7.352629620e-01f,
+    };
+    HWY_ALIGN constexpr float q[(4 + 1) * 4] = {
+        1.004519624e-02f, 1.004519624e-02f, 1.004519624e-02f, 1.004519624e-02f,
+        3.036675394e-01f, 3.036675394e-01f, 3.036675394e-01f, 3.036675394e-01f,
+        1.340816930e+00f, 1.340816930e+00f, 1.340816930e+00f, 1.340816930e+00f,
+        9.258482155e-01f, 9.258482155e-01f, 9.258482155e-01f, 9.258482155e-01f,
+        2.424867759e-02f, 2.424867759e-02f, 2.424867759e-02f, 2.424867759e-02f,
+    };
+    const V linear = x * Set(d, kLowDiv);
+    const V poly = EvalRationalPolynomial(d, Sqrt(x), p, q);
+    const V magnitude =
+        IfThenElse(x > Set(d, kThreshLinearToSRGB), poly, linear);
+    return Or(AndNot(kSign, magnitude), original_sign);
+  }
+
+ private:
+  static constexpr float kThreshSRGBToLinear = 0.04045f;
+  static constexpr float kThreshLinearToSRGB = 0.0031308f;
+  static constexpr float kLowDiv = 12.92f;
+  static constexpr float kLowDivInv = 1.0f / kLowDiv;
+};
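The TF_SRGB polynomials approximate the standard piecewise sRGB transfer pair. For reference, the exact formulas, using the same kThresh* cutoffs and the 12.92 linear segment:

```cpp
#include <cmath>
#include <cstdio>

// Exact sRGB EOTF: encoded -> linear (what DisplayFromEncoded approximates).
double SrgbToLinear(double e) {
  return e <= 0.04045 ? e / 12.92 : std::pow((e + 0.055) / 1.055, 2.4);
}

// Exact inverse: linear -> encoded (what EncodedFromDisplay approximates).
double LinearToSrgb(double d) {
  return d <= 0.0031308 ? d * 12.92 : 1.055 * std::pow(d, 1.0 / 2.4) - 0.055;
}

int main() {
  for (double e = 0.0; e <= 1.0; e += 0.25) {
    std::printf("%.2f -> %f -> %f\n", e, SrgbToLinear(e),
                LinearToSrgb(SrgbToLinear(e)));
  }
}
```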
+
+// Linear to sRGB conversion with error of at most 1.2e-4.
+template <typename D, typename V>
+V FastLinearToSRGB(D d, V v) {
+  const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du;
+  const hwy::HWY_NAMESPACE::Rebind<int32_t, D> di;
+  // Convert to 0.25 - 0.5 range.
+  auto v025_05 =
+      BitCast(d, (BitCast(du, v) | Set(du, 0x3e800000)) & Set(du, 0x3effffff));
+  // third degree polynomial approximation between 0.25 and 0.5
+  // of 1.055/2^(7/2.4) * x^(1/2.4) * 0.5. A degree 4 polynomial only improves
+  // accuracy by about 3x.
+  auto d1 = MulAdd(v025_05, Set(d, 0.059914046f), Set(d, -0.108894556f));
+  auto d2 = MulAdd(d1, v025_05, Set(d, 0.107963754f));
+  auto pow = MulAdd(d2, v025_05, Set(d, 0.018092343f));
+  // Compute extra multiplier depending on exponent. Valid exponent range for
+  // [0.0031308f, 1.0) is 0...8 after subtracting 118.
+  // The next three constants contain a representation of the powers of
+  // 2**(1/2.4) = 2**(5/12) times two; in particular, bits from 26 to 31 are
+  // always the same and in k2to512powers_basebits, and the two arrays contain
+  // the next groups of 8 bits. This ends up being a 22-bit representation
+  // (with a mantissa of 13 bits). The choice of polynomial to approximate is
+  // such that the multiplication factor has the highest 5 bits constant, and
+  // that the factor for the lowest possible exponent is a power of two (thus
+  // making the additional bits 0, which is used to correctly merge back
+  // together the floats).
+  constexpr uint32_t k2to512powers_basebits = 0x40000000;
+  HWY_ALIGN constexpr uint8_t k2to512powers_25to18bits[16] = {
+      0x0, 0xa, 0x19, 0x26, 0x32, 0x41, 0x4d, 0x5c,
+      0x68, 0x75, 0x83, 0x8f, 0xa0, 0xaa, 0xb9, 0xc6,
+  };
+  HWY_ALIGN constexpr uint8_t k2to512powers_17to10bits[16] = {
+      0x0, 0xb7, 0x4, 0xd, 0xcb, 0xe7, 0x41, 0x68,
+      0x51, 0xd1, 0xeb, 0xf2, 0x0, 0xb7, 0x4, 0xd,
+  };
+  // Note that vld1q_s8_x2 on ARM seems to actually be slower.
+#if HWY_TARGET != HWY_SCALAR
+  using hwy::HWY_NAMESPACE::ShiftLeft;
+  using hwy::HWY_NAMESPACE::ShiftRight;
+  // Every lane of exp is now (if cast to byte) {0, 0, 0, <index for lookup>}.
+  auto exp = ShiftRight<23>(BitCast(di, v)) - Set(di, 118);
+  auto pow25to18bits = TableLookupBytes(
+      LoadDup128(di,
+                 reinterpret_cast<const int32_t*>(k2to512powers_25to18bits)),
+      exp);
+  auto pow17to10bits = TableLookupBytes(
+      LoadDup128(di,
+                 reinterpret_cast<const int32_t*>(k2to512powers_17to10bits)),
+      exp);
+  // Now, pow* contain {0, 0, 0, <looked-up byte>}. Here we take advantage of
+  // the fact that each table has its position 0 equal to 0.
+  // We can now just reassemble the float.
+  auto mul =
+      BitCast(d, ShiftLeft<18>(pow25to18bits) | ShiftLeft<10>(pow17to10bits) |
+                     Set(di, k2to512powers_basebits));
+#else
+  // Fallback for scalar.
+  uint32_t exp = ((BitCast(di, v).raw >> 23) - 118) & 0xf;
+  auto mul = BitCast(d, Set(di, (k2to512powers_25to18bits[exp] << 18) |
+                                    (k2to512powers_17to10bits[exp] << 10) |
+                                    k2to512powers_basebits));
+#endif
+  return IfThenElse(v < Set(d, 0.0031308f), v * Set(d, 12.92f),
+                    MulAdd(pow, mul, Set(d, -0.055)));
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // LIB_JXL_TRANSFER_FUNCTIONS_INL_H_
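FastLinearToSRGB splits x^(1/2.4) into a polynomial on a mantissa restricted to [0.25, 0.5) and a 16-entry exponent table. The identity behind the split can be shown with frexp in scalar code (an illustrative sketch only; the vendored kernel packs its table into the byte arrays above):

```cpp
#include <cmath>
#include <cstdio>

// x^(1/2.4) = m^(1/2.4) * 2^(e/2.4) after splitting x = m * 2^e. The vector
// code performs the same split by masking the float's exponent bits and
// fetching the 2^(e/2.4) factor with TableLookupBytes.
double PowOneOver24(double x) {
  int e;
  const double m = std::frexp(x, &e);  // x = m * 2^e, m in [0.5, 1)
  return std::pow(m, 1.0 / 2.4) * std::pow(2.0, e / 2.4);
}

int main() {
  for (double x : {0.0031308, 0.18, 0.5, 1.0}) {
    std::printf("%f: %f vs %f\n", x, PowOneOver24(x), std::pow(x, 1.0 / 2.4));
  }
}
```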
diff --git a/third_party/jpeg-xl/lib/jxl/transpose-inl.h b/third_party/jpeg-xl/lib/jxl/transpose-inl.h
new file mode 100644
index 000000000000..72119c628ffa
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/transpose-inl.h
@@ -0,0 +1,210 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Block transpose for DCT/IDCT
+
+#if defined(LIB_JXL_TRANSPOSE_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_TRANSPOSE_INL_H_
+#undef LIB_JXL_TRANSPOSE_INL_H_
+#else
+#define LIB_JXL_TRANSPOSE_INL_H_
+#endif
+
+#include <stddef.h>
+
+#include <hwy/highway.h>
+#include <type_traits>
+
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/dct_block-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+#ifndef JXL_INLINE_TRANSPOSE
+// Workaround for issue #42 - (excessive?) inlining causes invalid codegen.
+#if defined(__arm__)
+#define JXL_INLINE_TRANSPOSE HWY_NOINLINE
+#else
+#define JXL_INLINE_TRANSPOSE HWY_INLINE
+#endif
+#endif  // JXL_INLINE_TRANSPOSE
+
+// Simple wrapper that ensures that a function will not be inlined.
+template <typename T, typename... Args>
+JXL_NOINLINE void NoInlineWrapper(const T& f, const Args&... args) {
+  return f(args...);
+}
+
+template <bool enabled>
+struct TransposeSimdTag {};
+
+// TODO(veluca): it's not super useful to have this in the SIMD namespace.
+template <size_t ROWS_or_0, size_t COLS_or_0, class From, class To>
+JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag<false>,
+                                                const From& from, const To& to,
+                                                size_t ROWSp, size_t COLSp) {
+  size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0;
+  size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0;
+  for (size_t n = 0; n < ROWS; ++n) {
+    for (size_t m = 0; m < COLS; ++m) {
+      to.Write(from.Read(n, m), m, n);
+    }
+  }
+}
+
+// TODO(veluca): AVX3?
+#if HWY_CAP_GE256
+constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) {
+  return ROWS % 8 == 0 && COLS % 8 == 0;
+}
+
+template <size_t ROWS_or_0, size_t COLS_or_0, class From, class To>
+JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag<true>,
+                                                const From& from, const To& to,
+                                                size_t ROWSp, size_t COLSp) {
+  size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0;
+  size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0;
+  static_assert(MaxLanes(BlockDesc<8>()) == 8, "Invalid descriptor size");
+  static_assert(ROWS_or_0 % 8 == 0, "Invalid number of rows");
+  static_assert(COLS_or_0 % 8 == 0, "Invalid number of columns");
+  for (size_t n = 0; n < ROWS; n += 8) {
+    for (size_t m = 0; m < COLS; m += 8) {
+      auto i0 = from.LoadPart(BlockDesc<8>(), n + 0, m + 0);
+      auto i1 = from.LoadPart(BlockDesc<8>(), n + 1, m + 0);
+      auto i2 = from.LoadPart(BlockDesc<8>(), n + 2, m + 0);
+      auto i3 = from.LoadPart(BlockDesc<8>(), n + 3, m + 0);
+      auto i4 = from.LoadPart(BlockDesc<8>(), n + 4, m + 0);
+      auto i5 = from.LoadPart(BlockDesc<8>(), n + 5, m + 0);
+      auto i6 = from.LoadPart(BlockDesc<8>(), n + 6, m + 0);
+      auto i7 = from.LoadPart(BlockDesc<8>(), n + 7, m + 0);
+      // Surprisingly, this straightforward implementation (24 cycles on port5)
+      // is faster than load128+insert and LoadDup128+ConcatUpperLower+blend.
+      const auto q0 = InterleaveLower(i0, i2);
+      const auto q1 = InterleaveLower(i1, i3);
+      const auto q2 = InterleaveUpper(i0, i2);
+      const auto q3 = InterleaveUpper(i1, i3);
+      const auto q4 = InterleaveLower(i4, i6);
+      const auto q5 = InterleaveLower(i5, i7);
+      const auto q6 = InterleaveUpper(i4, i6);
+      const auto q7 = InterleaveUpper(i5, i7);
+
+      const auto r0 = InterleaveLower(q0, q1);
+      const auto r1 = InterleaveUpper(q0, q1);
+      const auto r2 = InterleaveLower(q2, q3);
+      const auto r3 = InterleaveUpper(q2, q3);
+      const auto r4 = InterleaveLower(q4, q5);
+      const auto r5 = InterleaveUpper(q4, q5);
+      const auto r6 = InterleaveLower(q6, q7);
+      const auto r7 = InterleaveUpper(q6, q7);
+
+      i0 = ConcatLowerLower(r4, r0);
+      i1 = ConcatLowerLower(r5, r1);
+      i2 = ConcatLowerLower(r6, r2);
+      i3 = ConcatLowerLower(r7, r3);
+      i4 = ConcatUpperUpper(r4, r0);
+      i5 = ConcatUpperUpper(r5, r1);
+      i6 = ConcatUpperUpper(r6, r2);
+      i7 = ConcatUpperUpper(r7, r3);
+      to.StorePart(BlockDesc<8>(), i0, m + 0, n + 0);
+      to.StorePart(BlockDesc<8>(), i1, m + 1, n + 0);
+      to.StorePart(BlockDesc<8>(), i2, m + 2, n + 0);
+      to.StorePart(BlockDesc<8>(), i3, m + 3, n + 0);
+      to.StorePart(BlockDesc<8>(), i4, m + 4, n + 0);
+      to.StorePart(BlockDesc<8>(), i5, m + 5, n + 0);
+      to.StorePart(BlockDesc<8>(), i6, m + 6, n + 0);
+      to.StorePart(BlockDesc<8>(), i7, m + 7, n + 0);
+    }
+  }
+}
+#elif HWY_TARGET != HWY_SCALAR
+constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) {
+  return ROWS % 4 == 0 && COLS % 4 == 0;
+}
+
+template <size_t ROWS_or_0, size_t COLS_or_0, class From, class To>
+JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag<true>,
+                                                const From& from, const To& to,
+                                                size_t ROWSp, size_t COLSp) {
+  size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0;
+  size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0;
+  static_assert(MaxLanes(BlockDesc<4>()) == 4, "Invalid descriptor size");
+  static_assert(ROWS_or_0 % 4 == 0, "Invalid number of rows");
+  static_assert(COLS_or_0 % 4 == 0, "Invalid number of columns");
+  for (size_t n = 0; n < ROWS; n += 4) {
+    for (size_t m = 0; m < COLS; m += 4) {
+      const auto p0 = from.LoadPart(BlockDesc<4>(), n + 0, m + 0);
+      const auto p1 = from.LoadPart(BlockDesc<4>(), n + 1, m + 0);
+      const auto p2 = from.LoadPart(BlockDesc<4>(), n + 2, m + 0);
+      const auto p3 = from.LoadPart(BlockDesc<4>(), n + 3, m + 0);
+
+      const auto q0 = InterleaveLower(p0, p2);
+      const auto q1 = InterleaveLower(p1, p3);
+      const auto q2 = InterleaveUpper(p0, p2);
+      const auto q3 = InterleaveUpper(p1, p3);
+
+      const auto r0 = InterleaveLower(q0, q1);
+      const auto r1 = InterleaveUpper(q0, q1);
+      const auto r2 = InterleaveLower(q2, q3);
+      const auto r3 = InterleaveUpper(q2, q3);
+
+      to.StorePart(BlockDesc<4>(), r0, m + 0, n + 0);
+      to.StorePart(BlockDesc<4>(), r1, m + 1, n + 0);
+      to.StorePart(BlockDesc<4>(), r2, m + 2, n + 0);
+      to.StorePart(BlockDesc<4>(), r3, m + 3, n + 0);
+    }
+  }
+}
+#else
+constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) { return false; }
+#endif
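The interleave sequences implement an in-register transpose. A scalar model of the 4x4 case (mirroring the InterleaveLower/InterleaveUpper lane semantics with std::array) shows why two interleave rounds produce the transposed rows:

```cpp
#include <array>
#include <cstdio>

using Row = std::array<float, 4>;

// Interleave the low (or high) halves of two rows, like
// InterleaveLower/InterleaveUpper on 128-bit float vectors.
Row InterleaveLo(const Row& a, const Row& b) { return {a[0], b[0], a[1], b[1]}; }
Row InterleaveHi(const Row& a, const Row& b) { return {a[2], b[2], a[3], b[3]}; }

int main() {
  Row p0{0, 1, 2, 3}, p1{4, 5, 6, 7}, p2{8, 9, 10, 11}, p3{12, 13, 14, 15};
  const Row q0 = InterleaveLo(p0, p2);  // 0 8 1 9
  const Row q1 = InterleaveLo(p1, p3);  // 4 12 5 13
  const Row q2 = InterleaveHi(p0, p2);  // 2 10 3 11
  const Row q3 = InterleaveHi(p1, p3);  // 6 14 7 15
  const Row r0 = InterleaveLo(q0, q1);  // 0 4 8 12  == column 0
  const Row r1 = InterleaveHi(q0, q1);  // 1 5 9 13  == column 1
  const Row r2 = InterleaveLo(q2, q3);  // 2 6 10 14 == column 2
  const Row r3 = InterleaveHi(q2, q3);  // 3 7 11 15 == column 3
  for (const Row& r : {r0, r1, r2, r3}) {
    std::printf("%g %g %g %g\n", r[0], r[1], r[2], r[3]);
  }
}
```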
+
+template <size_t N, size_t M, typename = void>
+struct Transpose {
+  template <typename From, typename To>
+  static void Run(const From& from, const To& to) {
+    // This does not guarantee anything, just saves from the most stupid
+    // mistakes.
+    JXL_DASSERT(from.Address(0, 0) != to.Address(0, 0));
+    TransposeSimdTag<TransposeUseSimd(N, M)> tag;
+    GenericTransposeBlock<N, M>(tag, from, to, N, M);
+  }
+};
+
+// Avoid inlining and unrolling transposes for large blocks.
+template <size_t N, size_t M>
+struct Transpose<
+    N, M, typename std::enable_if<(N >= 8 && M >= 8 && N * M >= 512)>::type> {
+  template <typename From, typename To>
+  static void Run(const From& from, const To& to) {
+    // This does not guarantee anything, just saves from the most stupid
+    // mistakes.
+    JXL_DASSERT(from.Address(0, 0) != to.Address(0, 0));
+    TransposeSimdTag<TransposeUseSimd(N, M)> tag;
+    constexpr void (*transpose)(TransposeSimdTag<TransposeUseSimd(N, M)>,
+                                const From&, const To&, size_t, size_t) =
+        GenericTransposeBlock<0, 0, From, To>;
+    NoInlineWrapper(transpose, tag, from, to, N, M);
+  }
+};
+
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // LIB_JXL_TRANSPOSE_INL_H_
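Transpose<N, M>::Run is duck-typed: From and To only need Read, Write and Address (plus LoadPart/StorePart on the SIMD paths); the real adapters are the DCT block wrappers in dct_block-inl.h. A hypothetical pair of row-major adapters, exercised through the scalar loop for illustration (PlainFrom/PlainTo are invented names, not the vendored API):

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical row-major adapter exposing the interface that the scalar
// GenericTransposeBlock path expects.
struct PlainFrom {
  const float* data;
  size_t stride;
  float Read(size_t row, size_t col) const { return data[row * stride + col]; }
  const float* Address(size_t row, size_t col) const {
    return data + row * stride + col;
  }
};
struct PlainTo {
  float* data;
  size_t stride;
  void Write(float v, size_t row, size_t col) const {
    data[row * stride + col] = v;
  }
  float* Address(size_t row, size_t col) const {
    return data + row * stride + col;
  }
};

// Equivalent of the TransposeSimdTag<false> overload above.
void TransposeScalar(const PlainFrom& from, const PlainTo& to, size_t rows,
                     size_t cols) {
  for (size_t n = 0; n < rows; ++n) {
    for (size_t m = 0; m < cols; ++m) to.Write(from.Read(n, m), m, n);
  }
}

int main() {
  float in[4] = {1, 2, 3, 4}, out[4] = {};
  TransposeScalar(PlainFrom{in, 2}, PlainTo{out, 2}, 2, 2);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 1 3 2 4
}
```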
diff --git a/third_party/jpeg-xl/lib/jxl/xorshift128plus-inl.h b/third_party/jpeg-xl/lib/jxl/xorshift128plus-inl.h
new file mode 100644
index 000000000000..0536a685357a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/xorshift128plus-inl.h
@@ -0,0 +1,97 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Fast but weak random generator.
+
+#if defined(LIB_JXL_XORSHIFT128PLUS_INL_H_) == defined(HWY_TARGET_TOGGLE)
+#ifdef LIB_JXL_XORSHIFT128PLUS_INL_H_
+#undef LIB_JXL_XORSHIFT128PLUS_INL_H_
+#else
+#define LIB_JXL_XORSHIFT128PLUS_INL_H_
+#endif
+
+#include <stdint.h>
+
+#include <hwy/highway.h>
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::ShiftLeft;
+using hwy::HWY_NAMESPACE::ShiftRight;
+
+// Adapted from https://github.com/vpxyz/xorshift/blob/master/xorshift128plus/
+// (MIT-license)
+class Xorshift128Plus {
+ public:
+  // 8 independent generators (= single iteration for AVX-512)
+  enum { N = 8 };
+
+  explicit HWY_MAYBE_UNUSED Xorshift128Plus(const uint64_t seed) {
+    // Init state using SplitMix64 generator
+    s0_[0] = SplitMix64(seed + 0x9E3779B97F4A7C15ull);
+    s1_[0] = SplitMix64(s0_[0]);
+    for (size_t i = 1; i < N; ++i) {
+      s0_[i] = SplitMix64(s1_[i - 1]);
+      s1_[i] = SplitMix64(s0_[i]);
+    }
+  }
+
+  HWY_INLINE HWY_MAYBE_UNUSED void Fill(uint64_t* HWY_RESTRICT random_bits) {
+#if HWY_CAP_INTEGER64
+    const HWY_FULL(uint64_t) d;
+    for (size_t i = 0; i < N; i += Lanes(d)) {
+      auto s1 = Load(d, s0_ + i);
+      const auto s0 = Load(d, s1_ + i);
+      const auto bits = s1 + s0;  // b, c
+      Store(s0, d, s0_ + i);
+      s1 ^= ShiftLeft<23>(s1);
+      Store(bits, d, random_bits + i);
+      s1 ^= s0 ^ ShiftRight<18>(s1) ^ ShiftRight<5>(s0);
+      Store(s1, d, s1_ + i);
+    }
+#else
+    for (size_t i = 0; i < N; ++i) {
+      auto s1 = s0_[i];
+      const auto s0 = s1_[i];
+      const auto bits = s1 + s0;  // b, c
+      s0_[i] = s0;
+      s1 ^= s1 << 23;
+      random_bits[i] = bits;
+      s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
+      s1_[i] = s1;
+    }
+#endif
+  }
+
+ private:
+  static uint64_t SplitMix64(uint64_t z) {
+    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
+    z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
+    return z ^ (z >> 31);
+  }
+
+  HWY_ALIGN uint64_t s0_[N];
+  HWY_ALIGN uint64_t s1_[N];
+};
+
+}  // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // LIB_JXL_XORSHIFT128PLUS_INL_H_
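Fill advances eight independent xorshift128+ lanes per call. One lane of the same recurrence, seeded identically via SplitMix64, fits in a few lines of scalar code; for seed 12345 its first output should reproduce lane 0 of the golden table in the test below:

```cpp
#include <cstdint>
#include <cstdio>

// One lane of the generator above; the state update mirrors the scalar
// (#else) branch of Fill, and the seeding matches lane 0 of the constructor.
class Xorshift128PlusScalar {
 public:
  explicit Xorshift128PlusScalar(uint64_t seed) {
    s0_ = SplitMix64(seed + 0x9E3779B97F4A7C15ull);
    s1_ = SplitMix64(s0_);
  }
  uint64_t Next() {
    uint64_t s1 = s0_;
    const uint64_t s0 = s1_;
    const uint64_t bits = s1 + s0;  // output comes from the pre-update state
    s0_ = s0;
    s1 ^= s1 << 23;
    s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
    s1_ = s1;
    return bits;
  }

 private:
  static uint64_t SplitMix64(uint64_t z) {
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
    z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
    return z ^ (z >> 31);
  }
  uint64_t s0_, s1_;
};

int main() {
  Xorshift128PlusScalar rng(12345);
  for (int i = 0; i < 4; ++i) {
    std::printf("%016llx\n", static_cast<unsigned long long>(rng.Next()));
  }
}
```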
diff --git a/third_party/jpeg-xl/lib/jxl/xorshift128plus_test.cc b/third_party/jpeg-xl/lib/jxl/xorshift128plus_test.cc
new file mode 100644
index 000000000000..c66bbf96c39a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/xorshift128plus_test.cc
@@ -0,0 +1,381 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <algorithm>
+#include <vector>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/xorshift128plus_test.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+#include <hwy/tests/test_util-inl.h>
+
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/xorshift128plus-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::ShiftRight;
+
+// Define to nonzero in order to print the (new) golden outputs.
+#define PRINT_RESULTS 0
+
+const size_t kVectors = 64;
+
+#if PRINT_RESULTS
+
+template <size_t kNumLanes>
+void Print(const uint64_t (&result)[kNumLanes]) {
+  printf("{ ");
+  for (int i = 0; i < kNumLanes; ++i) {
+    if (i != 0) {
+      printf(", ");
+    }
+    printf("0x%016llXull", result[i]);
+  }
+  printf("},\n");
+}
+
+#else  // PRINT_RESULTS
+
+const uint64_t kExpected[kVectors][Xorshift128Plus::N] = {
+    {0x6E901576D477CBB1ull, 0xE9E53789195DA2A2ull, 0xB681F6DDA5E0AE99ull,
+     0x8EFD18CE21FD6896ull, 0xA898A80DF75CF532ull, 0x50CEB2C9E2DE7E32ull,
+     0x3CA7C2FEB25C0DD0ull, 0xA4D0866B80B4D836ull},
+    {0x8CD6A1E6233D3A26ull, 0x3D4603ADE98B112Dull, 0xDC427AF674019E36ull,
+     0xE28B4D230705AC53ull, 0x7297E9BBA88783DDull, 0x34D3D23CFCD9B41Aull,
+     0x5A223615ADBE96B8ull, 0xE5EB529027CFBD01ull},
+    {0xC1894CF00DFAC6A2ull, 0x18EDF8AE9085E404ull, 0x8E936625296B4CCDull,
+     0x31971EF3A14A899Bull, 0xBE87535FCE0BF26Aull, 0x576F7A752BC6649Full,
+     0xA44CBADCE0C6B937ull, 0x3DBA819BB17A353Aull},
+    {0x27CE38DFCC1C5EB6ull, 0x920BEB5606340256ull, 0x3986CBC40C9AFC2Cull,
+     0xE22BCB3EEB1E191Eull, 0x6E1FCDD3602A8FBAull, 0x052CB044E5415A29ull,
+     0x46266646EFB9ECD7ull, 0x8F44914618D29335ull},
+    {0xDD30AEDF72A362C5ull, 0xBC1D824E16BB98F4ull, 0x9EA6009C2AA3D2F1ull,
+     0xF65C0FBBE17AF081ull, 0x22424D06A8738991ull, 0x8A62763F2B7611D2ull,
+     0x2F3E89F722637939ull, 0x84D338BEF50AFD50ull},
+    {0x00F46494898E2B0Bull, 0x81239DC4FB8E8003ull, 0x414AD93EC5773FE7ull,
+     0x791473C450E4110Full, 0x87F127BF68C959ACull, 0x6429282D695EF67Bull,
+     0x661082E11546CBA8ull, 0x5815D53FA5436BFDull},
+    {0xB3DEADAB9BE6E0F9ull, 0xAA1B7B8F7CED0202ull, 0x4C5ED437699D279Eull,
+     0xA4471727F1CB39D3ull, 0xE439DA193F802F70ull, 0xF89401BB04FA6493ull,
+     0x3B08045A4FE898BAull, 0x32137BFE98227950ull},
+    {0xFBAE4A092897FEF3ull, 0x0639F6CE56E71C8Eull, 0xF0AD6465C07F0C1Eull,
+     0xFF8E28563361DCE5ull, 0xC2013DB7F86BC6B9ull, 0x8EFCC0503330102Full,
+     0x3F6B767EA5C4DA40ull, 0xB9864B950B2232E1ull},
+    {0x76EB58DE8E5EC22Aull, 0x9BBBF49A18B32F4Full, 0xC8405F02B2B2FAB9ull,
+     0xC3E122A5F146BC34ull, 0xC90BB046660F5765ull, 0xB933981310DBECCFull,
+     0x5A2A7BFC9126FD1Cull, 0x8BB388C94DF87901ull},
+    {0x753EB89AD63EF3C3ull, 0xF24AAF40C89D65ADull, 0x23F68931C1A6AA6Dull,
+     0xF47E79BF702C6DD0ull, 0xA3AD113244EE7EAEull, 0xD42CBEA28F793DC3ull,
+     0xD896FCF1820F497Cull, 0x042B86D2818948C1ull},
+    {0x8F2A4FC5A4265763ull, 0xEC499E6F95EAA10Cull, 0xE3786D4ECCD0DEB5ull,
+     0xC725C53D3AC4CC43ull, 0x065A4ACBBF83610Eull, 0x35C61C9FEF167129ull,
+     0x7B720AEAA7D70048ull, 0x14206B841377D039ull},
+    {0xAD27D78BF96055F6ull, 0x5F43B20FF47ADCD4ull, 0xE184C2401E2BF71Eull,
+     0x30B263D78990045Dull, 0xC22F00EBFF9BA201ull, 0xAE7F86522B53A562ull,
+     0x2853312BC039F0A4ull, 0x868D619E6549C3C8ull},
+    {0xFD5493D8AE9A8371ull, 0x773D5E224DF61B3Bull, 0x5377C54FBB1A8280ull,
+     0xCAD4DE3B8265CAFAull, 0xCDF3F19C91EBD5F6ull, 0xC8EA0F182D73BD78ull,
+     0x220502D593433FF1ull, 0xB81205E612DC31B1ull},
+    {0x8F32A39EAEDA4C70ull, 0x1D4B0914AA4DAC7Full, 0x56EF1570F3A8B405ull,
+     0x29812CB17404A592ull, 0x97A2AAF69CAE90F2ull, 0x12BF5E02778BBFE5ull,
+     0x9D4B55AD42A05FD2ull, 0x06C2BAB5E6086620ull},
+    {0x8DB4B9648302B253ull, 0xD756AD9E3AEA12C7ull, 0x68709B7F11D4B188ull,
+     0x7CC299DDCD707A4Bull, 0x97B860C370A7661Dull, 0xCECD314FC20E64F5ull,
+     0x55F412CDFB4C7EC3ull, 0x55EE97591193B525ull},
+    {0xCF70F3ACA96E6254ull, 0x022FEDECA2E09F46ull, 0x686823DB60AE1ECFull,
+     0xFD36190D3739830Eull, 0x74E1C09027F68120ull, 0xB5883A835C093842ull,
+     0x93E1EFB927E9E4E3ull, 0xB2721E249D7E5EBEull},
+    {0x69B6E21C44188CB8ull, 0x5D6CFB853655A7AAull, 0x3E001A0B425A66DCull,
+     0x8C57451103A5138Full,
0x7BF8B4BE18EAB402ull, 0x494102EB8761A365ull, + 0xB33796A9F6A81F0Eull, 0x10005AB3BCCFD960ull}, + {0xB2CF25740AE965DCull, 0x6F7C1DF7EF53D670ull, 0x648DD6087AC2251Eull, + 0x040955D9851D487Dull, 0xBD550FC7E21A7F66ull, 0x57408F484DEB3AB5ull, + 0x481E24C150B506C1ull, 0x72C0C3EAF91A40D6ull}, + {0x1997A481858A5D39ull, 0x539718F4BEF50DC1ull, 0x2EC4DC4787E7E368ull, + 0xFF1CE78879419845ull, 0xE219A93DD6F6DD30ull, 0x85328618D02FEC1Aull, + 0xC86E02D969181B20ull, 0xEBEC8CD8BBA34E6Eull}, + {0x28B55088A16CE947ull, 0xDD25AC11E6350195ull, 0xBD1F176694257B1Cull, + 0x09459CCF9FCC9402ull, 0xF8047341E386C4E4ull, 0x7E8E9A9AD984C6C0ull, + 0xA4661E95062AA092ull, 0x70A9947005ED1152ull}, + {0x4C01CF75DBE98CCDull, 0x0BA076CDFC7373B9ull, 0x6C5E7A004B57FB59ull, + 0x336B82297FD3BC56ull, 0x7990C0BE74E8D60Full, 0xF0275CC00EC5C8C8ull, + 0x6CF29E682DFAD2E9ull, 0xFA4361524BD95D72ull}, + {0x631D2A19FF62F018ull, 0x41C43863B985B3FAull, 0xE052B2267038EFD9ull, + 0xE2A535FAC575F430ull, 0xE004EEA90B1FF5B8ull, 0x42DFE2CA692A1F26ull, + 0x90FB0BFC9A189ECCull, 0x4484102BD3536BD0ull}, + {0xD027134E9ACCA5A5ull, 0xBBAB4F966D476A9Bull, 0x713794A96E03D693ull, + 0x9F6335E6B94CD44Aull, 0xC5090C80E7471617ull, 0x6D9C1B0C87B58E33ull, + 0x1969CE82E31185A5ull, 0x2099B97E87754EBEull}, + {0x60EBAF4ED934350Full, 0xC26FBF0BA5E6ECFFull, 0x9E54150F0312EC57ull, + 0x0973B48364ED0041ull, 0x800A523241426CFCull, 0x03AB5EC055F75989ull, + 0x8CF315935DEEB40Aull, 0x83D3FC0190BD1409ull}, + {0x26D35394CF720A51ull, 0xCE9EAA15243CBAFEull, 0xE2B45FBAF21B29E0ull, + 0xDB92E98EDE73F9E0ull, 0x79B16F5101C26387ull, 0x1AC15959DE88C86Full, + 0x387633AEC6D6A580ull, 0xA6FC05807BFC5EB8ull}, + {0x2D26C8E47C6BADA9ull, 0x820E6EC832D52D73ull, 0xB8432C3E0ED0EE5Bull, + 0x0F84B3C4063AAA87ull, 0xF393E4366854F651ull, 0x749E1B4D2366A567ull, + 0x805EACA43480D004ull, 0x244EBF3AA54400A5ull}, + {0xBFDC3763AA79F75Aull, 0x9E3A74CC751F41DBull, 0xF401302A149DBC55ull, + 0x6B25F7973D7BF7BCull, 0x13371D34FDBC3DAEull, 0xC5E1998C8F484DCDull, + 0x7031B8AE5C364464ull, 0x3847F0C4F3DA2C25ull}, + {0x24C6387D2C0F1225ull, 0x77CCE960255C67A4ull, 0x21A0947E497B10EBull, + 0xBB5DB73A825A9D7Eull, 0x26294A41999E553Dull, 0x3953E0089F87D925ull, + 0x3DAE6E5D4E5EAAFEull, 0x74B545460341A7AAull}, + {0x710E5EB08A7DB820ull, 0x7E43C4E77CAEA025ull, 0xD4C91529C8B060C1ull, + 0x09AE26D8A7B0CA29ull, 0xAB9F356BB360A772ull, 0xB68834A25F19F6E9ull, + 0x79B8D9894C5734E2ull, 0xC6847E7C8FFD265Full}, + {0x10C4BCB06A5111E6ull, 0x57CB50955B6A2516ull, 0xEF53C87798B6995Full, + 0xAB38E15BBD8D0197ull, 0xA51C6106EFF73C93ull, 0x83D7F0E2270A7134ull, + 0x0923FD330397FCE5ull, 0xF9DE54EDFE58FB45ull}, + {0x07D44833ACCD1A94ull, 0xAAD3C9E945E2F9F3ull, 0xABF4C879B876AA37ull, + 0xF29C69A21B301619ull, 0x2DDCE959111C788Bull, 0x7CEDB48F8AC1729Bull, + 0x93F3BA9A02B659BEull, 0xF20A87FF17933CBEull}, + {0x8E96EBE93180CFE6ull, 0x94CAA12873937079ull, 0x05F613D9380D4189ull, + 0xBCAB40C1DC79F38Aull, 0x0AD8907B7C61D19Eull, 0x88534E189D103910ull, + 0x2DB2FAABA160AB8Full, 0xA070E7506B06F15Cull}, + {0x6FB1FCDAFFEF87A9ull, 0xE735CF25337A090Dull, 0x172C6EDCEFEF1825ull, + 0x76957EA49EF0542Dull, 0x819BF4CD250F7C49ull, 0xD6FF23E4AD00C4D4ull, + 0xE79673C1EC358FF0ull, 0xAC9C048144337938ull}, + {0x4C5387FF258B3AF4ull, 0xEDB68FAEC2CB1AA3ull, 0x02A624E67B4E1DA4ull, + 0x5C44797A38E08AF2ull, 0x36546A70E9411B4Bull, 0x47C17B24D2FD9675ull, + 0x101957AAA020CA26ull, 0x47A1619D4779F122ull}, + {0xF84B8BCDC92D9A3Cull, 0x951D7D2C74B3066Bull, 0x7AC287C06EDDD9B2ull, + 0x4C38FC476608D38Full, 0x224D793B19CB4BCDull, 0x835A255899BF1A41ull, + 0x4AD250E9F62DB4ABull, 0xD9B44F4B58781096ull}, + 
{0xABBAF99A8EB5C6B8ull, 0xFB568E900D3A9F56ull, 0x11EDF63D23C5DF11ull, + 0xA9C3011D3FA7C5A8ull, 0xAEDD3CF11AFFF725ull, 0xABCA472B5F1EDD6Bull, + 0x0600B6BB5D879804ull, 0xDB4DE007F22191A0ull}, + {0xD76CC9EFF0CE9392ull, 0xF5E0A772B59BA49Aull, 0x7D1AE1ED0C1261B5ull, + 0x79224A33B5EA4F4Aull, 0x6DD825D80C40EA60ull, 0x47FC8E747E51C953ull, + 0x695C05F72888BF98ull, 0x1A012428440B9015ull}, + {0xD754DD61F9B772BFull, 0xC4A2FCF4C0F9D4EBull, 0x461167CDF67A24A2ull, + 0x434748490EBCB9D4ull, 0x274DD9CDCA5781DEull, 0x36BAC63BA9A85209ull, + 0x30324DAFDA36B70Full, 0x337570DB4FE6DAB3ull}, + {0xF46CBDD57C551546ull, 0x8E02507E676DA3E3ull, 0xD826245A8C15406Dull, + 0xDFB38A5B71113B72ull, 0x5EA38454C95B16B5ull, 0x28C054FB87ABF3E1ull, + 0xAA2724C0BA1A8096ull, 0xECA83EC980304F2Full}, + {0x6AA76EC294EB3303ull, 0x42D4CDB2A8032E3Bull, 0x7999EDF75DCD8735ull, + 0xB422BFFE696CCDCCull, 0x8F721461FD7CCDFEull, 0x148E1A5814FDE253ull, + 0x4DC941F4375EF8FFull, 0x27B2A9E0EB5B49CFull}, + {0xCEA592EF9343EBE1ull, 0xF7D38B5FA7698903ull, 0x6CCBF352203FEAB6ull, + 0x830F3095FCCDA9C5ull, 0xDBEEF4B81B81C8F4ull, 0x6D7EB9BCEECA5CF9ull, + 0xC58ABB0FBE436C69ull, 0xE4B97E6DB2041A4Bull}, + {0x7E40FC772978AF14ull, 0xCDDA4BBAE28354A1ull, 0xE4F993B832C32613ull, + 0xD3608093C68A4B35ull, 0x9A3B60E01BEE3699ull, 0x03BEF248F3288713ull, + 0x70B9294318F3E9B4ull, 0x8D2ABB913B8610DEull}, + {0x37F209128E7D8B2Cull, 0x81D2AB375BD874BCull, 0xA716A1B7373F7408ull, + 0x0CEE97BEC4706540ull, 0xA40C5FD9CDBC1512ull, 0x73CAF6C8918409E7ull, + 0x45E11BCEDF0BBAA1ull, 0x612C612BFF6E6605ull}, + {0xF8ECB14A12D0F649ull, 0xDA683CD7C01BA1ACull, 0xA2203F7510E124C1ull, + 0x7F83E52E162F3C78ull, 0x77D2BB73456ACADBull, 0x37FC34FC840BBA6Full, + 0x3076BC7D4C6EBC1Full, 0x4F514123632B5FA9ull}, + {0x44D789DED935E884ull, 0xF8291591E09FEC9Full, 0xD9CED2CF32A2E4B7ull, + 0x95F70E1EB604904Aull, 0xDE438FE43C14F6ABull, 0x4C8D23E4FAFCF8D8ull, + 0xC716910A3067EB86ull, 0x3D6B7915315095D3ull}, + {0x3170FDBADAB92095ull, 0x8F1963933FC5650Bull, 0x72F94F00ABECFEABull, + 0x6E3AE826C6AAB4CEull, 0xA677A2BF31068258ull, 0x9660CDC4F363AF10ull, + 0xD81A15A152379EF1ull, 0x5D7D285E1080A3F9ull}, + {0xDAD5DDFF9A2249B3ull, 0x6F9721D926103FAEull, 0x1418CBB83FFA349Aull, + 0xE71A30AD48C012B2ull, 0xBE76376C63751132ull, 0x3496467ACA713AE6ull, + 0x8D7EC01369F991A3ull, 0xD8C73A88B96B154Eull}, + {0x8B5D9C74AEB4833Aull, 0xF914FB3F867B912Full, 0xB894EA034936B1DCull, + 0x8A16D21BE51C4F5Bull, 0x31FF048ED582D98Eull, 0xB95AB2F4DC65B820ull, + 0x04082B9170561AF7ull, 0xA215610A5DC836FAull}, + {0xB2ADE592C092FAACull, 0x7A1E683BCBF13294ull, 0xC7A4DBF86858C096ull, + 0x3A49940F97BFF316ull, 0xCAE5C06B82C46703ull, 0xC7F413A0F951E2BDull, + 0x6665E7BB10EB5916ull, 0x86F84A5A94EDE319ull}, + {0x4EA199D8FAA79CA3ull, 0xDFA26E5BF1981704ull, 0x0F5E081D37FA4E01ull, + 0x9CB632F89CD675CDull, 0x4A09DB89D48C0304ull, 0x88142742EA3C7672ull, + 0xAC4F149E6D2E9BDBull, 0x6D9E1C23F8B1C6C6ull}, + {0xD58BE47B92DEC0E9ull, 0x8E57573645E34328ull, 0x4CC094CCB5FB5126ull, + 0x5F1D66AF6FB40E3Cull, 0x2BA15509132D3B00ull, 0x0D6545646120E567ull, + 0x3CF680C45C223666ull, 0x96B28E32930179DAull}, + {0x5900C45853AC7990ull, 0x61881E3E8B7FF169ull, 0x4DE5F835DF2230FFull, + 0x4427A9E7932F73FFull, 0x9B641BAD379A8C8Dull, 0xDF271E5BF98F4E5Cull, + 0xDFDA16DB830FF5EEull, 0x371C7E7CFB89C0E9ull}, + {0x4410A8576247A250ull, 0x6AD2DA12B45AC0D9ull, 0x18DFC72AAC85EECCull, + 0x06FC8BB2A0EF25C8ull, 0xEB287619C85E6118ull, 0x19553ECA67F25A2Cull, + 0x3B9557F1DCEC5BAAull, 0x7BAD9E8B710D1079ull}, + {0x34F365D66BD22B28ull, 0xE6E124B9F10F835Dull, 0x0573C38ABF2B24DCull, + 0xD32E6AF10A0125AEull, 
0x383590ACEA979519ull, 0x8376ED7A39E28205ull, + 0xF0B7F184DCBDA435ull, 0x062A203390E31794ull}, + {0xA2AFFD7E41918760ull, 0x7F90FC1BD0819C86ull, 0x5033C08E5A969533ull, + 0x2707AF5C6D039590ull, 0x57BBD5980F17DF9Cull, 0xD3FE6E61D763268Aull, + 0x9E0A0AE40F335A3Bull, 0x43CF4EB0A99613C5ull}, + {0xD4D2A397CE1A7C2Eull, 0x3DF7CE7CC3212DADull, 0x0880F0D5D356C75Aull, + 0xA8AFC44DD03B1346ull, 0x79263B46C13A29E0ull, 0x11071B3C0ED58E7Aull, + 0xED46DC9F538406BFull, 0x2C94974F2B94843Dull}, + {0xE246E13C39AB5D5Eull, 0xAC1018489D955B20ull, 0x8601B558771852B8ull, + 0x110BD4C06DB40173ull, 0x738FC8A18CCA0EBBull, 0x6673E09BE0EA76E5ull, + 0x024BC7A0C7527877ull, 0x45E6B4652E2EC34Eull}, + {0xD1ED26A1A375CDC8ull, 0xAABC4E896A617CB8ull, 0x0A9C9E8E57D753C6ull, + 0xA3774A75FEB4C30Eull, 0x30B816C01C93E49Eull, 0xF405BABC06D2408Cull, + 0xCC0CE6B4CE788ABCull, 0x75E7922D0447956Cull}, + {0xD07C1676A698BC95ull, 0x5F9AEA4840E2D860ull, 0xD5FC10D58BDF6F02ull, + 0xF190A2AD4BC2EEA7ull, 0x0C24D11F51726931ull, 0xDB646899A16B6512ull, + 0x7BC10670047B1DD8ull, 0x2413A5ABCD45F092ull}, + {0x4E66892190CFD923ull, 0xF10162440365EC8Eull, 0x158ACA5A6A2280AEull, + 0x0D60ED11C0224166ull, 0x7CD2E9A71B9D7488ull, 0x450D7289706AB2A3ull, + 0x88FAE34EC9A0D7DCull, 0x96FF9103575A97DAull}, + {0x77990FAC6046C446ull, 0xB174B5FB30C76676ull, 0xE352CE3EB56CF82Aull, + 0xC6039B6873A9A082ull, 0xE3F80F3AE333148Aull, 0xB853BA24BA3539B9ull, + 0xE8863E52ECCB0C74ull, 0x309B4CC1092CC245ull}, + {0xBC2B70BEE8388D9Full, 0xE48D92AE22216DCEull, 0xF15F3BF3E2C15D8Full, + 0x1DD964D4812D8B24ull, 0xD56AF02FB4665E4Cull, 0x98002200595BD9A3ull, + 0x049246D50BB8FA12ull, 0x1B542DF485B579B9ull}, + {0x2347409ADFA8E497ull, 0x36015C2211D62498ull, 0xE9F141F32EB82690ull, + 0x1F839912D0449FB9ull, 0x4E4DCFFF2D02D97Cull, 0xF8A03AB4C0F625C9ull, + 0x0605F575795DAC5Cull, 0x4746C9BEA0DDA6B1ull}, + {0xCA5BB519ECE7481Bull, 0xFD496155E55CA945ull, 0xF753B9DBB1515F81ull, + 0x50549E8BAC0F70E7ull, 0x8614FB0271E21C60ull, 0x60C72947EB0F0070ull, + 0xA6511C10AEE742B6ull, 0x48FB48F2CACCB43Eull}}; + +#endif // PRINT_RESULTS + +// Ensures Xorshift128+ returns consistent and unchanging values. 
+void TestGolden() {
+  HWY_ALIGN Xorshift128Plus rng(12345);
+  for (uint64_t vector = 0; vector < kVectors; ++vector) {
+    HWY_ALIGN uint64_t lanes[Xorshift128Plus::N];
+    rng.Fill(lanes);
+#if PRINT_RESULTS
+    Print(lanes);
+#else
+    for (size_t i = 0; i < Xorshift128Plus::N; ++i) {
+      ASSERT_EQ(kExpected[vector][i], lanes[i])
+          << "Where vector=" << vector << " i=" << i;
+    }
+#endif
+  }
+}
+
+// Output changes when given different seeds
+void TestSeedChanges() {
+  HWY_ALIGN uint64_t lanes[Xorshift128Plus::N];
+
+  std::vector<uint64_t> first;
+  constexpr size_t kNumSeeds = 16384;
+  first.reserve(kNumSeeds);
+
+  // All 14-bit seeds
+  for (size_t seed = 0; seed < kNumSeeds; ++seed) {
+    HWY_ALIGN Xorshift128Plus rng(seed);
+
+    rng.Fill(lanes);
+    first.push_back(lanes[0]);
+  }
+
+  // All outputs are unique
+  ASSERT_EQ(kNumSeeds, first.size());
+  std::sort(first.begin(), first.end());
+  first.erase(std::unique(first.begin(), first.end()), first.end());
+  EXPECT_EQ(kNumSeeds, first.size());
+}
+
+void TestFloat() {
+  ThreadPoolInternal pool(8);
+
+#ifdef JXL_DISABLE_SLOW_TESTS
+  const uint32_t kMaxSeed = 2048;
+#else   // JXL_DISABLE_SLOW_TESTS
+  const uint32_t kMaxSeed = 16384;  // All 14-bit seeds
+#endif  // JXL_DISABLE_SLOW_TESTS
+  pool.Run(0, kMaxSeed, ThreadPool::SkipInit(),
+           [](const int seed, const int /*thread*/) {
+             HWY_ALIGN Xorshift128Plus rng(seed);
+
+             const HWY_FULL(uint32_t) du;
+             const HWY_FULL(float) df;
+             HWY_ALIGN uint64_t batch[Xorshift128Plus::N];
+             HWY_ALIGN float lanes[MaxLanes(df)];
+             double sum = 0.0;
+             size_t count = 0;
+             const size_t kReps = 2000;
+             for (size_t reps = 0; reps < kReps; ++reps) {
+               rng.Fill(batch);
+               for (size_t i = 0; i < Xorshift128Plus::N * 2; i += Lanes(df)) {
+                 const auto bits =
+                     Load(du, reinterpret_cast<const uint32_t*>(batch) + i);
+                 // 1.0 + 23 random mantissa bits = [1, 2)
+                 const auto rand12 =
+                     BitCast(df, ShiftRight<9>(bits) | Set(du, 0x3F800000));
+                 const auto rand01 = rand12 - Set(df, 1.0f);
+                 Store(rand01, df, lanes);
+                 for (float lane : lanes) {
+                   sum += lane;
+                   count += 1;
+                   EXPECT_LE(lane, 1.0f);
+                   EXPECT_GE(lane, 0.0f);
+                 }
+               }
+             }
+
+             // Verify average (uniform distribution)
+             EXPECT_NEAR(0.5, sum / count, 0.00702);
+           });
+}
+
+// Not more than one 64-bit zero
+void TestNotZero() {
+  ThreadPoolInternal pool(8);
+
+#ifdef JXL_DISABLE_SLOW_TESTS
+  const uint32_t kMaxSeed = 500;
+#else   // JXL_DISABLE_SLOW_TESTS
+  const uint32_t kMaxSeed = 2000;
+#endif  // JXL_DISABLE_SLOW_TESTS
+  pool.Run(0, kMaxSeed, ThreadPool::SkipInit(),
+           [](const int task, const int /*thread*/) {
+             HWY_ALIGN uint64_t lanes[Xorshift128Plus::N];
+
+             HWY_ALIGN Xorshift128Plus rng(task);
+             size_t num_zero = 0;
+             for (size_t vectors = 0; vectors < 10000; ++vectors) {
+               rng.Fill(lanes);
+               for (uint64_t lane : lanes) {
+                 num_zero += static_cast<size_t>(lane == 0);
+               }
+             }
+             EXPECT_LE(num_zero, 1);
+           });
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+
+class Xorshift128Test : public hwy::TestWithParamTarget {};
+
+HWY_TARGET_INSTANTIATE_TEST_SUITE_P(Xorshift128Test);
+
+HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestNotZero);
+HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestGolden);
+HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestSeedChanges);
+HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestFloat);
+
+}  // namespace jxl
+#endif
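TestFloat turns random bits into floats in [0, 1) without a division: 23 random bits become the mantissa, the exponent field is forced to 0x3F800000 (values in [1, 2)), and 1 is subtracted. The same trick in scalar form:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Builds a float in [0, 1) from the top 23 bits of `bits`, mirroring the
// ShiftRight<9> | 0x3F800000 sequence in TestFloat.
float FloatFromBits(uint32_t bits) {
  const uint32_t repr = (bits >> 9) | 0x3F800000u;  // 1.mantissa, exponent 0
  float f;
  std::memcpy(&f, &repr, sizeof(f));
  return f - 1.0f;
}

int main() {
  for (uint32_t bits : {0u, 0x80000000u, 0xFFFFFFFFu}) {
    std::printf("0x%08X -> %f\n", bits, FloatFromBits(bits));
  }
}
```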
diff --git a/third_party/jpeg-xl/lib/jxl_benchmark.cmake b/third_party/jpeg-xl/lib/jxl_benchmark.cmake
new file mode 100644
index 000000000000..0f33c55c6d95
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl_benchmark.cmake
@@ -0,0 +1,47 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# All files ending in "_gbench.cc" are considered Google benchmark files and
+# should be listed here.
+set(JPEGXL_INTERNAL_SOURCES_GBENCH
+  extras/tone_mapping_gbench.cc
+  jxl/dec_external_image_gbench.cc
+  jxl/enc_external_image_gbench.cc
+  jxl/splines_gbench.cc
+  jxl/tf_gbench.cc
+)
+
+# benchmark.h doesn't work in our MINGW setup since it ends up including the
+# wrong stdlib header. We don't run gbench on MINGW targets anyway.
+if(NOT MINGW)
+
+# This is the Google benchmark project (https://github.com/google/benchmark).
+find_package(benchmark QUIET)
+
+if(benchmark_FOUND)
+  # Compiles all the benchmark files into a single binary. Individual
+  # benchmarks can be run with --benchmark_filter.
+  add_executable(jxl_gbench "${JPEGXL_INTERNAL_SOURCES_GBENCH}")
+
+  target_compile_definitions(jxl_gbench PRIVATE
+    -DTEST_DATA_PATH="${PROJECT_SOURCE_DIR}/third_party/testdata")
+  target_link_libraries(jxl_gbench
+    jxl_extras-static
+    jxl-static
+    benchmark::benchmark
+    benchmark::benchmark_main
+  )
+endif()  # benchmark_FOUND
+
+endif()  # MINGW
diff --git a/third_party/jpeg-xl/lib/jxl_extras.cmake b/third_party/jpeg-xl/lib/jxl_extras.cmake
new file mode 100644
index 000000000000..c81e0ebe3c5a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl_extras.cmake
@@ -0,0 +1,118 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(JPEGXL_EXTRAS_SOURCES
+  extras/codec.cc
+  extras/codec.h
+  # codec_jpg is always included for loading of lossless reconstruction, but
+  # decoding to pixels is only supported if libjpeg is found and
+  # JPEGXL_ENABLE_JPEG=1.
+  extras/codec_jpg.cc
+  extras/codec_jpg.h
+  extras/codec_pgx.cc
+  extras/codec_pgx.h
+  extras/codec_png.cc
+  extras/codec_png.h
+  extras/codec_pnm.cc
+  extras/codec_pnm.h
+  extras/codec_psd.cc
+  extras/codec_psd.h
+  extras/tone_mapping.cc
+  extras/tone_mapping.h
+)
+
+# We only define a static library for jxl_extras since it uses internal parts
+# of the jxl library which are not accessible from outside the library in the
+# shared library case.
+add_library(jxl_extras-static STATIC "${JPEGXL_EXTRAS_SOURCES}")
+target_compile_options(jxl_extras-static PRIVATE "${JPEGXL_INTERNAL_FLAGS}")
+set_property(TARGET jxl_extras-static PROPERTY POSITION_INDEPENDENT_CODE ON)
+target_include_directories(jxl_extras-static PUBLIC "${PROJECT_SOURCE_DIR}")
+target_link_libraries(jxl_extras-static PUBLIC
+  jxl-static
+  lodepng
+)
+
+find_package(GIF 5)
+if(GIF_FOUND)
+  target_sources(jxl_extras-static PRIVATE
+    extras/codec_gif.cc
+    extras/codec_gif.h
+  )
+  target_include_directories(jxl_extras-static PUBLIC "${GIF_INCLUDE_DIRS}")
+  target_link_libraries(jxl_extras-static PUBLIC ${GIF_LIBRARIES})
+  target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_GIF=1)
+  if(JPEGXL_DEP_LICENSE_DIR)
+    configure_file("${JPEGXL_DEP_LICENSE_DIR}/libgif-dev/copyright"
+                   ${PROJECT_BINARY_DIR}/LICENSE.libgif COPYONLY)
+  endif()  # JPEGXL_DEP_LICENSE_DIR
+endif()
+
+find_package(JPEG)
+if(JPEG_FOUND)
+  target_include_directories(jxl_extras-static PUBLIC "${JPEG_INCLUDE_DIRS}")
+  target_link_libraries(jxl_extras-static PUBLIC ${JPEG_LIBRARIES})
+  target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_JPEG=1)
+  if(JPEGXL_DEP_LICENSE_DIR)
+    configure_file("${JPEGXL_DEP_LICENSE_DIR}/libjpeg-dev/copyright"
+                   ${PROJECT_BINARY_DIR}/LICENSE.libjpeg COPYONLY)
+  endif()  # JPEGXL_DEP_LICENSE_DIR
+endif()
+
+find_package(ZLIB)  # dependency of PNG
+find_package(PNG)
+if(PNG_FOUND AND ZLIB_FOUND)
+  target_sources(jxl_extras-static PRIVATE
+    extras/codec_apng.cc
+    extras/codec_apng.h
+  )
+  target_include_directories(jxl_extras-static PUBLIC "${PNG_INCLUDE_DIRS}")
+  target_link_libraries(jxl_extras-static PUBLIC ${PNG_LIBRARIES})
+  target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_APNG=1)
+  if(JPEGXL_DEP_LICENSE_DIR)
+    configure_file("${JPEGXL_DEP_LICENSE_DIR}/zlib1g-dev/copyright"
+                   ${PROJECT_BINARY_DIR}/LICENSE.zlib COPYONLY)
+    configure_file("${JPEGXL_DEP_LICENSE_DIR}/libpng-dev/copyright"
+                   ${PROJECT_BINARY_DIR}/LICENSE.libpng COPYONLY)
+  endif()  # JPEGXL_DEP_LICENSE_DIR
+endif()
+
+if (JPEGXL_ENABLE_SJPEG)
+  target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_SJPEG=1)
+  target_link_libraries(jxl_extras-static PUBLIC sjpeg)
+endif ()
+
+if (JPEGXL_ENABLE_OPENEXR)
+pkg_check_modules(OpenEXR IMPORTED_TARGET OpenEXR)
+if (OpenEXR_FOUND)
+  target_sources(jxl_extras-static PRIVATE
+    extras/codec_exr.cc
+    extras/codec_exr.h
+  )
+  target_compile_definitions(jxl_extras-static PUBLIC -DJPEGXL_ENABLE_EXR=1)
+  target_link_libraries(jxl_extras-static PUBLIC PkgConfig::OpenEXR)
+  if(JPEGXL_DEP_LICENSE_DIR)
+    configure_file("${JPEGXL_DEP_LICENSE_DIR}/libopenexr-dev/copyright"
+                   ${PROJECT_BINARY_DIR}/LICENSE.libopenexr COPYONLY)
+  endif()  # JPEGXL_DEP_LICENSE_DIR
+  # OpenEXR generates exceptions, so we need exception support to catch them.
+  # Actually those flags counteract the ones set in JPEGXL_INTERNAL_FLAGS.
+ if (NOT WIN32) + set_source_files_properties(extras/codec_exr.cc PROPERTIES COMPILE_FLAGS -fexceptions) + if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + set_source_files_properties(extras/codec_exr.cc PROPERTIES COMPILE_FLAGS -fcxx-exceptions) + endif() + endif() +endif() # OpenEXR_FOUND +endif() # JPEGXL_ENABLE_OPENEXR diff --git a/third_party/jpeg-xl/lib/jxl_profiler.cmake b/third_party/jpeg-xl/lib/jxl_profiler.cmake new file mode 100644 index 000000000000..f9154e390d76 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl_profiler.cmake @@ -0,0 +1,40 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(JPEGXL_PROFILER_SOURCES + profiler/profiler.cc + profiler/profiler.h + profiler/tsc_timer.h +) + +### Static library. +add_library(jxl_profiler STATIC ${JPEGXL_PROFILER_SOURCES}) +target_link_libraries(jxl_profiler hwy) + +target_compile_options(jxl_profiler PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(jxl_profiler PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET jxl_profiler PROPERTY POSITION_INDEPENDENT_CODE ON) + +target_include_directories(jxl_profiler + PRIVATE "${PROJECT_SOURCE_DIR}") + +set_target_properties(jxl_profiler PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 +) + +# Make every library linking against the jxl_profiler define this macro to +# enable the profiler. +target_compile_definitions(jxl_profiler + PUBLIC -DPROFILER_ENABLED) diff --git a/third_party/jpeg-xl/lib/jxl_tests.cmake b/third_party/jpeg-xl/lib/jxl_tests.cmake new file mode 100644 index 000000000000..c2a4d841b4cf --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl_tests.cmake @@ -0,0 +1,142 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set(TEST_FILES + extras/codec_test.cc + jxl/ac_strategy_test.cc + jxl/adaptive_reconstruction_test.cc + jxl/alpha_test.cc + jxl/ans_common_test.cc + jxl/ans_test.cc + jxl/bit_reader_test.cc + jxl/bits_test.cc + jxl/blending_test.cc + jxl/butteraugli_test.cc + jxl/byte_order_test.cc + jxl/coeff_order_test.cc + jxl/color_encoding_internal_test.cc + jxl/color_management_test.cc + jxl/compressed_image_test.cc + jxl/convolve_test.cc + jxl/data_parallel_test.cc + jxl/dct_test.cc + jxl/decode_test.cc + jxl/descriptive_statistics_test.cc + jxl/enc_external_image_test.cc + jxl/encode_test.cc + jxl/entropy_coder_test.cc + jxl/fast_math_test.cc + jxl/fields_test.cc + jxl/filters_internal_test.cc + jxl/gaborish_test.cc + jxl/gamma_correct_test.cc + jxl/gauss_blur_test.cc + jxl/gradient_test.cc + jxl/iaca_test.cc + jxl/icc_codec_test.cc + jxl/image_bundle_test.cc + jxl/image_ops_test.cc + jxl/jxl_test.cc + jxl/lehmer_code_test.cc + jxl/linalg_test.cc + jxl/modular_test.cc + jxl/opsin_image_test.cc + jxl/opsin_inverse_test.cc + jxl/optimize_test.cc + jxl/padded_bytes_test.cc + jxl/passes_test.cc + jxl/patch_dictionary_test.cc + jxl/preview_test.cc + jxl/quant_weights_test.cc + jxl/quantizer_test.cc + jxl/rational_polynomial_test.cc + jxl/robust_statistics_test.cc + jxl/roundtrip_test.cc + jxl/speed_tier_test.cc + jxl/splines_test.cc + jxl/toc_test.cc + jxl/xorshift128plus_test.cc + threads/thread_parallel_runner_test.cc + ### Files before this line are handled by build_cleaner.py + # TODO(deymo): Move this to tools/ + ../tools/box/box_test.cc +) + +# Test-only library code. +set(TESTLIB_FILES + jxl/dct_for_test.h + jxl/dec_transforms_testonly.cc + jxl/dec_transforms_testonly.h + jxl/image_test_utils.h + jxl/test_utils.h + jxl/testdata.h +) + +find_package(GTest) + +# Library with test-only code shared between all tests. +add_library(jxl_testlib-static STATIC ${TESTLIB_FILES}) + target_compile_options(jxl_testlib-static PRIVATE + ${JPEGXL_INTERNAL_FLAGS} + ${JPEGXL_COVERAGE_FLAGS} + ) +target_compile_definitions(jxl_testlib-static PUBLIC + -DTEST_DATA_PATH="${PROJECT_SOURCE_DIR}/third_party/testdata") +target_include_directories(jxl_testlib-static PUBLIC + "${PROJECT_SOURCE_DIR}" +) +target_link_libraries(jxl_testlib-static hwy) + +# Individual test binaries: +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests) +foreach (TESTFILE IN LISTS TEST_FILES) + # The TESTNAME is the name without the extension or directory. + get_filename_component(TESTNAME ${TESTFILE} NAME_WE) + add_executable(${TESTNAME} ${TESTFILE}) + if(JPEGXL_EMSCRIPTEN) + # The emscripten linking step takes too much memory and crashes during the + # wasm-opt step when using -O2 optimization level + set_target_properties(${TESTNAME} PROPERTIES LINK_FLAGS "\ + -O1 \ + -s TOTAL_MEMORY=1536MB \ + -s SINGLE_FILE=1 \ + ") + endif() + target_compile_options(${TESTNAME} PRIVATE + ${JPEGXL_INTERNAL_FLAGS} + # Add coverage flags to the test binary so code in the private headers of + # the library is also instrumented when running tests that execute it. + ${JPEGXL_COVERAGE_FLAGS} + ) + target_link_libraries(${TESTNAME} + box + jxl-static + jxl_threads-static + jxl_extras-static + jxl_testlib-static + gmock + GTest::GTest + GTest::Main + ) + # Output test targets in the test directory. 
+ set_target_properties(${TESTNAME} PROPERTIES PREFIX "tests/") + if (WIN32 AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") + set_target_properties(${TESTNAME} PROPERTIES COMPILE_FLAGS "-Wno-error") + endif () + if(${CMAKE_VERSION} VERSION_LESS "3.10.3") + gtest_discover_tests(${TESTNAME} TIMEOUT 240) + else () + gtest_discover_tests(${TESTNAME} DISCOVERY_TIMEOUT 240) + endif () +endforeach () diff --git a/third_party/jpeg-xl/lib/jxl_threads.cmake b/third_party/jpeg-xl/lib/jxl_threads.cmake new file mode 100644 index 000000000000..0cea13982faa --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl_threads.cmake @@ -0,0 +1,108 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +find_package(Threads REQUIRED) + +set(JPEGXL_THREADS_SOURCES + threads/thread_parallel_runner.cc + threads/thread_parallel_runner_internal.cc + threads/thread_parallel_runner_internal.h +) + +### Define the jxl_threads shared or static target library. The ${target} +# parameter should already be created with add_library(), but this function +# sets all the remaining common properties. +function(_set_jxl_threads _target) + +target_compile_options(${_target} PRIVATE ${JPEGXL_INTERNAL_FLAGS}) +target_compile_options(${_target} PUBLIC ${JPEGXL_COVERAGE_FLAGS}) +set_property(TARGET ${_target} PROPERTY POSITION_INDEPENDENT_CODE ON) + +target_include_directories(${_target} + PRIVATE + "${PROJECT_SOURCE_DIR}" + PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_BINARY_DIR}/include") + +target_link_libraries(${_target} + PUBLIC ${JPEGXL_COVERAGE_FLAGS} Threads::Threads +) + +set_target_properties(${_target} PROPERTIES + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1 + DEFINE_SYMBOL JXL_THREADS_INTERNAL_LIBRARY_BUILD +) + +# Always install the library as jxl_threads.{a,so} file without the "-static" +# suffix, except in Windows. +if (NOT WIN32) + set_target_properties(${_target} PROPERTIES OUTPUT_NAME "jxl_threads") +endif() +install(TARGETS ${_target} DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +endfunction() + + +### Static library. +add_library(jxl_threads-static STATIC ${JPEGXL_THREADS_SOURCES}) +_set_jxl_threads(jxl_threads-static) + +# Make jxl_threads symbols neither imported nor exported when using the static +# library. These will have hidden visibility anyway in the static library case +# in unix. +target_compile_definitions(jxl_threads-static + PUBLIC -DJXL_THREADS_STATIC_DEFINE) + + +### Public shared library. +if (((NOT DEFINED "${TARGET_SUPPORTS_SHARED_LIBS}") OR + TARGET_SUPPORTS_SHARED_LIBS) AND NOT JPEGXL_STATIC) +add_library(jxl_threads SHARED ${JPEGXL_THREADS_SOURCES}) +_set_jxl_threads(jxl_threads) + +set_target_properties(jxl_threads PROPERTIES + VERSION ${JPEGXL_LIBRARY_VERSION} + SOVERSION ${JPEGXL_LIBRARY_SOVERSION} + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") + +# Compile the shared library such that the JXL_THREADS_EXPORT symbols are +# exported. 
Users of the library will not set this flag and therefore import +# those symbols. +target_compile_definitions(jxl_threads + PRIVATE -DJXL_THREADS_INTERNAL_LIBRARY_BUILD) + +# Generate the jxl/jxl_threads_export.h header, we only need to generate it once +# but we can use it from both libraries. +generate_export_header(jxl_threads + BASE_NAME JXL_THREADS + EXPORT_FILE_NAME include/jxl/jxl_threads_export.h) +else() +add_library(jxl_threads ALIAS jxl_threads-static) +# When not building the shared library generate the jxl_threads_export.h header +# only based on the static target. +generate_export_header(jxl_threads-static + BASE_NAME JXL_THREADS + EXPORT_FILE_NAME include/jxl/jxl_threads_export.h) +endif() # TARGET_SUPPORTS_SHARED_LIBS AND NOT JPEGXL_STATIC + + +### Add a pkg-config file for libjxl_threads. +set(JPEGXL_THREADS_LIBRARY_REQUIRES "") +configure_file("${CMAKE_CURRENT_SOURCE_DIR}/threads/libjxl_threads.pc.in" + "libjxl_threads.pc" @ONLY) +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/libjxl_threads.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") diff --git a/third_party/jpeg-xl/lib/lib.gni b/third_party/jpeg-xl/lib/lib.gni new file mode 100644 index 000000000000..99857ebf6fb5 --- /dev/null +++ b/third_party/jpeg-xl/lib/lib.gni @@ -0,0 +1,383 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Source files definitions for GN-based build systems. 
+ +# Library version macros +libjxl_version_defines = [ + "JPEGXL_MAJOR_VERSION=0", + "JPEGXL_MINOR_VERSION=3", + "JPEGXL_PATCH_VERSION=7", +] + +libjxl_dec_sources = [ + "jxl/ac_context.h", + "jxl/ac_strategy.cc", + "jxl/ac_strategy.h", + "jxl/alpha.cc", + "jxl/alpha.h", + "jxl/ans_common.cc", + "jxl/ans_common.h", + "jxl/ans_params.h", + "jxl/aux_out.cc", + "jxl/aux_out.h", + "jxl/aux_out_fwd.h", + "jxl/base/arch_macros.h", + "jxl/base/bits.h", + "jxl/base/byte_order.h", + "jxl/base/cache_aligned.cc", + "jxl/base/cache_aligned.h", + "jxl/base/compiler_specific.h", + "jxl/base/data_parallel.cc", + "jxl/base/data_parallel.h", + "jxl/base/descriptive_statistics.cc", + "jxl/base/descriptive_statistics.h", + "jxl/base/file_io.h", + "jxl/base/iaca.h", + "jxl/base/os_macros.h", + "jxl/base/override.h", + "jxl/base/padded_bytes.cc", + "jxl/base/padded_bytes.h", + "jxl/base/profiler.h", + "jxl/base/robust_statistics.h", + "jxl/base/span.h", + "jxl/base/status.cc", + "jxl/base/status.h", + "jxl/base/thread_pool_internal.h", + "jxl/base/time.cc", + "jxl/base/time.h", + "jxl/blending.cc", + "jxl/blending.h", + "jxl/chroma_from_luma.cc", + "jxl/chroma_from_luma.h", + "jxl/codec_in_out.h", + "jxl/coeff_order.cc", + "jxl/coeff_order.h", + "jxl/coeff_order_fwd.h", + "jxl/color_encoding_internal.cc", + "jxl/color_encoding_internal.h", + "jxl/color_management.cc", + "jxl/color_management.h", + "jxl/common.h", + "jxl/compressed_dc.cc", + "jxl/compressed_dc.h", + "jxl/convolve-inl.h", + "jxl/convolve.cc", + "jxl/convolve.h", + "jxl/dct-inl.h", + "jxl/dct_block-inl.h", + "jxl/dct_scales.cc", + "jxl/dct_scales.h", + "jxl/dct_util.h", + "jxl/dec_ans.cc", + "jxl/dec_ans.h", + "jxl/dec_bit_reader.h", + "jxl/dec_cache.h", + "jxl/dec_context_map.cc", + "jxl/dec_context_map.h", + "jxl/dec_external_image.cc", + "jxl/dec_external_image.h", + "jxl/dec_frame.cc", + "jxl/dec_frame.h", + "jxl/dec_group.cc", + "jxl/dec_group.h", + "jxl/dec_group_border.cc", + "jxl/dec_group_border.h", + "jxl/dec_huffman.cc", + "jxl/dec_huffman.h", + "jxl/dec_modular.cc", + "jxl/dec_modular.h", + "jxl/dec_noise.cc", + "jxl/dec_noise.h", + "jxl/dec_params.h", + "jxl/dec_patch_dictionary.cc", + "jxl/dec_patch_dictionary.h", + "jxl/dec_reconstruct.cc", + "jxl/dec_reconstruct.h", + "jxl/dec_transforms-inl.h", + "jxl/dec_upsample.cc", + "jxl/dec_upsample.h", + "jxl/dec_xyb-inl.h", + "jxl/dec_xyb.cc", + "jxl/dec_xyb.h", + "jxl/decode.cc", + "jxl/enc_bit_writer.cc", + "jxl/enc_bit_writer.h", + "jxl/entropy_coder.cc", + "jxl/entropy_coder.h", + "jxl/epf.cc", + "jxl/epf.h", + "jxl/fast_math-inl.h", + "jxl/field_encodings.h", + "jxl/fields.cc", + "jxl/fields.h", + "jxl/filters.cc", + "jxl/filters.h", + "jxl/filters_internal.h", + "jxl/frame_header.cc", + "jxl/frame_header.h", + "jxl/gauss_blur.cc", + "jxl/gauss_blur.h", + "jxl/headers.cc", + "jxl/headers.h", + "jxl/huffman_table.cc", + "jxl/huffman_table.h", + "jxl/icc_codec.cc", + "jxl/icc_codec.h", + "jxl/icc_codec_common.cc", + "jxl/icc_codec_common.h", + "jxl/image.cc", + "jxl/image.h", + "jxl/image_bundle.cc", + "jxl/image_bundle.h", + "jxl/image_metadata.cc", + "jxl/image_metadata.h", + "jxl/image_ops.h", + "jxl/jpeg/dec_jpeg_data.cc", + "jxl/jpeg/dec_jpeg_data.h", + "jxl/jpeg/dec_jpeg_data_writer.cc", + "jxl/jpeg/dec_jpeg_data_writer.h", + "jxl/jpeg/dec_jpeg_output_chunk.h", + "jxl/jpeg/dec_jpeg_serialization_state.h", + "jxl/jpeg/jpeg_data.cc", + "jxl/jpeg/jpeg_data.h", + "jxl/jxl_inspection.h", + "jxl/lehmer_code.h", + "jxl/linalg.h", + "jxl/loop_filter.cc", + "jxl/loop_filter.h", + 
"jxl/luminance.cc", + "jxl/luminance.h", + "jxl/memory_manager_internal.cc", + "jxl/memory_manager_internal.h", + "jxl/modular/encoding/context_predict.h", + "jxl/modular/encoding/dec_ma.cc", + "jxl/modular/encoding/dec_ma.h", + "jxl/modular/encoding/encoding.cc", + "jxl/modular/encoding/encoding.h", + "jxl/modular/encoding/ma_common.h", + "jxl/modular/modular_image.cc", + "jxl/modular/modular_image.h", + "jxl/modular/options.h", + "jxl/modular/transform/near-lossless.h", + "jxl/modular/transform/palette.h", + "jxl/modular/transform/squeeze.h", + "jxl/modular/transform/subtractgreen.h", + "jxl/modular/transform/transform.cc", + "jxl/modular/transform/transform.h", + "jxl/noise.h", + "jxl/noise_distributions.h", + "jxl/opsin_params.cc", + "jxl/opsin_params.h", + "jxl/passes_state.cc", + "jxl/passes_state.h", + "jxl/patch_dictionary_internal.h", + "jxl/quant_weights.cc", + "jxl/quant_weights.h", + "jxl/quantizer-inl.h", + "jxl/quantizer.cc", + "jxl/quantizer.h", + "jxl/rational_polynomial-inl.h", + "jxl/splines.cc", + "jxl/splines.h", + "jxl/toc.cc", + "jxl/toc.h", + "jxl/transfer_functions-inl.h", + "jxl/transpose-inl.h", + "jxl/xorshift128plus-inl.h", +] + +libjxl_enc_sources = [ + "jxl/butteraugli/butteraugli.cc", + "jxl/butteraugli/butteraugli.h", + "jxl/butteraugli_wrapper.cc", + "jxl/dec_file.cc", + "jxl/dec_file.h", + "jxl/enc_ac_strategy.cc", + "jxl/enc_ac_strategy.h", + "jxl/enc_adaptive_quantization.cc", + "jxl/enc_adaptive_quantization.h", + "jxl/enc_ans.cc", + "jxl/enc_ans.h", + "jxl/enc_ans_params.h", + "jxl/enc_ar_control_field.cc", + "jxl/enc_ar_control_field.h", + "jxl/enc_butteraugli_comparator.cc", + "jxl/enc_butteraugli_comparator.h", + "jxl/enc_butteraugli_pnorm.cc", + "jxl/enc_butteraugli_pnorm.h", + "jxl/enc_cache.cc", + "jxl/enc_cache.h", + "jxl/enc_chroma_from_luma.cc", + "jxl/enc_chroma_from_luma.h", + "jxl/enc_cluster.cc", + "jxl/enc_cluster.h", + "jxl/enc_coeff_order.cc", + "jxl/enc_coeff_order.h", + "jxl/enc_color_management.cc", + "jxl/enc_color_management.h", + "jxl/enc_comparator.cc", + "jxl/enc_comparator.h", + "jxl/enc_context_map.cc", + "jxl/enc_context_map.h", + "jxl/enc_detect_dots.cc", + "jxl/enc_detect_dots.h", + "jxl/enc_dot_dictionary.cc", + "jxl/enc_dot_dictionary.h", + "jxl/enc_entropy_coder.cc", + "jxl/enc_entropy_coder.h", + "jxl/enc_external_image.cc", + "jxl/enc_external_image.h", + "jxl/enc_fast_heuristics.cc", + "jxl/enc_file.cc", + "jxl/enc_file.h", + "jxl/enc_frame.cc", + "jxl/enc_frame.h", + "jxl/enc_gamma_correct.h", + "jxl/enc_group.cc", + "jxl/enc_group.h", + "jxl/enc_heuristics.cc", + "jxl/enc_heuristics.h", + "jxl/enc_huffman.cc", + "jxl/enc_huffman.h", + "jxl/enc_icc_codec.cc", + "jxl/enc_icc_codec.h", + "jxl/enc_image_bundle.cc", + "jxl/enc_image_bundle.h", + "jxl/enc_modular.cc", + "jxl/enc_modular.h", + "jxl/enc_noise.cc", + "jxl/enc_noise.h", + "jxl/enc_params.h", + "jxl/enc_patch_dictionary.cc", + "jxl/enc_patch_dictionary.h", + "jxl/enc_quant_weights.cc", + "jxl/enc_quant_weights.h", + "jxl/enc_splines.cc", + "jxl/enc_splines.h", + "jxl/enc_toc.cc", + "jxl/enc_toc.h", + "jxl/enc_transforms-inl.h", + "jxl/enc_transforms.cc", + "jxl/enc_transforms.h", + "jxl/enc_xyb.cc", + "jxl/enc_xyb.h", + "jxl/encode.cc", + "jxl/encode_internal.h", + "jxl/gaborish.cc", + "jxl/gaborish.h", + "jxl/huffman_tree.cc", + "jxl/huffman_tree.h", + "jxl/jpeg/enc_jpeg_data.cc", + "jxl/jpeg/enc_jpeg_data.h", + "jxl/jpeg/enc_jpeg_data_reader.cc", + "jxl/jpeg/enc_jpeg_data_reader.h", + "jxl/jpeg/enc_jpeg_huffman_decode.cc", + 
"jxl/jpeg/enc_jpeg_huffman_decode.h", + "jxl/linalg.cc", + "jxl/modular/encoding/enc_encoding.cc", + "jxl/modular/encoding/enc_encoding.h", + "jxl/modular/encoding/enc_ma.cc", + "jxl/modular/encoding/enc_ma.h", + "jxl/optimize.cc", + "jxl/optimize.h", + "jxl/progressive_split.cc", + "jxl/progressive_split.h", +] + +libjxl_gbench_sources = [ + "extras/tone_mapping_gbench.cc", + "jxl/dec_external_image_gbench.cc", + "jxl/enc_external_image_gbench.cc", + "jxl/splines_gbench.cc", + "jxl/tf_gbench.cc", +] + +libjxl_tests_sources = [ + "jxl/ac_strategy_test.cc", + "jxl/adaptive_reconstruction_test.cc", + "jxl/alpha_test.cc", + "jxl/ans_common_test.cc", + "jxl/ans_test.cc", + "jxl/bit_reader_test.cc", + "jxl/bits_test.cc", + "jxl/blending_test.cc", + "jxl/butteraugli_test.cc", + "jxl/byte_order_test.cc", + "jxl/coeff_order_test.cc", + "jxl/color_encoding_internal_test.cc", + "jxl/color_management_test.cc", + "jxl/compressed_image_test.cc", + "jxl/convolve_test.cc", + "jxl/data_parallel_test.cc", + "jxl/dct_test.cc", + "jxl/decode_test.cc", + "jxl/descriptive_statistics_test.cc", + "jxl/enc_external_image_test.cc", + "jxl/encode_test.cc", + "jxl/entropy_coder_test.cc", + "jxl/fast_math_test.cc", + "jxl/fields_test.cc", + "jxl/filters_internal_test.cc", + "jxl/gaborish_test.cc", + "jxl/gamma_correct_test.cc", + "jxl/gauss_blur_test.cc", + "jxl/gradient_test.cc", + "jxl/iaca_test.cc", + "jxl/icc_codec_test.cc", + "jxl/image_bundle_test.cc", + "jxl/image_ops_test.cc", + "jxl/jxl_test.cc", + "jxl/lehmer_code_test.cc", + "jxl/linalg_test.cc", + "jxl/modular_test.cc", + "jxl/opsin_image_test.cc", + "jxl/opsin_inverse_test.cc", + "jxl/optimize_test.cc", + "jxl/padded_bytes_test.cc", + "jxl/passes_test.cc", + "jxl/patch_dictionary_test.cc", + "jxl/preview_test.cc", + "jxl/quant_weights_test.cc", + "jxl/quantizer_test.cc", + "jxl/rational_polynomial_test.cc", + "jxl/robust_statistics_test.cc", + "jxl/roundtrip_test.cc", + "jxl/speed_tier_test.cc", + "jxl/splines_test.cc", + "jxl/toc_test.cc", + "jxl/xorshift128plus_test.cc", +] + +# Test-only library code. +libjxl_testlib_sources = [ + "jxl/dct_for_test.h", + "jxl/dec_transforms_testonly.cc", + "jxl/dec_transforms_testonly.h", + "jxl/image_test_utils.h", + "jxl/test_utils.h", + "jxl/testdata.h", +] + +libjxl_threads_sources = [ + "threads/thread_parallel_runner.cc", + "threads/thread_parallel_runner_internal.cc", + "threads/thread_parallel_runner_internal.h", +] + +libjxl_profiler_sources = [ + "profiler/profiler.cc", + "profiler/profiler.h", + "profiler/tsc_timer.h", +] diff --git a/third_party/jpeg-xl/lib/profiler/profiler.cc b/third_party/jpeg-xl/lib/profiler/profiler.cc new file mode 100644 index 000000000000..bf9ca3a6c9e8 --- /dev/null +++ b/third_party/jpeg-xl/lib/profiler/profiler.cc @@ -0,0 +1,606 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lib/jxl/base/profiler.h"
+
+#if PROFILER_ENABLED
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>  // memcpy
+
+#include <algorithm>  // sort
+#include <atomic>
+#include <cinttypes>  // PRIu64
+#include <limits>
+#include <new>
+
+#include "lib/jxl/base/robust_statistics.h"
+
+// Non-portable aspects:
+// - 128-bit load/store (write-combining, UpdateOrAdd)
+// - RDTSCP timestamps (serializing, high-resolution)
+// - assumes string literals are stored within an 8 MiB range
+// - compiler-specific annotations (restrict, alignment, fences)
+
+// How many mebibytes to allocate (if PROFILER_ENABLED) per thread that
+// enters at least one zone. Once this buffer is full, the thread will analyze
+// and discard packets, thus temporarily adding some observer overhead.
+// Each zone occupies 16 bytes.
+#ifndef PROFILER_THREAD_STORAGE
+#define PROFILER_THREAD_STORAGE 32ULL
+#endif
+
+#define PROFILER_PRINT_OVERHEAD 0
+
+#if PROFILER_BUFFER
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+
+// Overwrites `to` without loading it into cache (read-for-ownership).
+// Copies kCacheLineSize bytes from/to naturally aligned addresses.
+void StreamCacheLine(const Packet* JXL_RESTRICT from, Packet* JXL_RESTRICT to) {
+  constexpr size_t kLanes = 16 / sizeof(Packet);
+  static_assert(kLanes == 2, "Update descriptor type");
+  const HWY_CAPPED(uint64_t, kLanes) d;
+  JXL_COMPILER_FENCE;
+  const uint64_t* JXL_RESTRICT from64 = reinterpret_cast<const uint64_t*>(from);
+  const auto v0 = Load(d, from64 + 0 * kLanes);
+  const auto v1 = Load(d, from64 + 1 * kLanes);
+  const auto v2 = Load(d, from64 + 2 * kLanes);
+  const auto v3 = Load(d, from64 + 3 * kLanes);
+  // Fences prevent the compiler from reordering loads/stores, which may
+  // interfere with write-combining.
+  JXL_COMPILER_FENCE;
+  uint64_t* JXL_RESTRICT to64 = reinterpret_cast<uint64_t*>(to);
+  Stream(v0, d, to64 + 0 * kLanes);
+  Stream(v1, d, to64 + 1 * kLanes);
+  Stream(v2, d, to64 + 2 * kLanes);
+  Stream(v3, d, to64 + 3 * kLanes);
+  JXL_COMPILER_FENCE;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#endif  // PROFILER_BUFFER
+
+namespace jxl {
+namespace {
+
+// Upper bounds for various fixed-size data structures (guarded via JXL_ASSERT):
+
+// How many unique threads can enter a zone (those that don't do not count).
+// Memory use is about kMaxThreads * PROFILER_THREAD_STORAGE MiB.
+// WARNING: fiber libraries and multiple ThreadPool can spawn >100 threads.
+constexpr size_t kMaxThreads = 1024;
+
+// Maximum nesting of zones.
+constexpr size_t kMaxDepth = 64;
+
+// Total number of zones.
+constexpr size_t kMaxZones = 256;
+
+// Returns the address of a string literal. Assuming zone names are also
+// literals and stored nearby, we can represent them as offsets, which are
+// faster to compute than hashes or even a static index.
+//
+// This function must not be static - each call (even from other translation
+// units) must return the same value.
+uintptr_t StringOrigin() {
+  // Chosen such that no zone name is a prefix nor suffix of this string
+  // to ensure they aren't merged (offset 0 identifies zone-exit packets).
+  static const char* string_origin = "__#Origin#__";
+
+  return reinterpret_cast<uintptr_t>(string_origin) - Packet::kOffsetBias;
+}
+
+// Representation of an active zone, stored in a stack. Used to deduct
+// child duration from the parent's self time. POD.
+struct ProfilerNode {
+  Packet packet;
+  uint64_t child_total;
+};
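[Editor's note] The StringOrigin() scheme above is worth unpacking: zone names are string literals, so instead of hashing them the profiler stores each name as a small offset from a reference literal. A minimal standalone sketch of the idea, not part of the patch (kOffsetBias mirrors the Packet class declared later in profiler.h; the 8 MiB proximity assumption comes from the comments above):

```cpp
#include <cstdint>

// Reference literal; kOffsetBias = 1ULL << (kOffsetBits - 1), kOffsetBits = 27.
static const char* kOrigin = "__#Origin#__";
constexpr uint64_t kOffsetBias = 1ULL << 26;

uint64_t BiasedOffset(const char* zone_name) {
  const uintptr_t origin = reinterpret_cast<uintptr_t>(kOrigin) - kOffsetBias;
  // Small iff the linker placed both literals within ~8 MiB of each other.
  return reinterpret_cast<uintptr_t>(zone_name) - origin;
}

const char* RecoverName(uintptr_t origin, uint64_t biased_offset) {
  return reinterpret_cast<const char*>(origin + biased_offset);
}
```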
+
+// Holds statistics for all zones with the same name. POD.
+struct Accumulator {
+  static constexpr size_t kNumCallBits = 64 - Packet::kOffsetBits;
+
+  uintptr_t BiasedOffset() const { return num_calls >> kNumCallBits; }
+  uint64_t NumCalls() const { return num_calls & ((1ULL << kNumCallBits) - 1); }
+
+  // UpdateOrAdd relies upon this layout.
+  uint64_t num_calls = 0;  // upper bits = biased_offset.
+  uint64_t total_duration = 0;
+};
+#if JXL_ARCH_X64
+static_assert(sizeof(Accumulator) == 2 * sizeof(uint64_t), "Accumulator size");
+#endif
+
+template <typename T>
+inline T ClampedSubtract(const T minuend, const T subtrahend) {
+  if (subtrahend > minuend) {
+    return 0;
+  }
+  return minuend - subtrahend;
+}
+
+}  // namespace
+
+// Per-thread call graph (stack) and Accumulator for each zone.
+class Results {
+ public:
+  Results() {
+    // Zero-initialize first accumulator to avoid a check for num_zones_ == 0.
+    memset(zones_, 0, sizeof(Accumulator));
+  }
+
+  // Used for computing overhead when this thread encounters its first Zone.
+  // This has no observable effect apart from increasing "analyze_elapsed_".
+  uint64_t ZoneDuration(const Packet* packets) {
+    JXL_CHECK(depth_ == 0);
+    JXL_CHECK(num_zones_ == 0);
+    AnalyzePackets(packets, 2);
+    const uint64_t duration = zones_[0].total_duration;
+    zones_[0].num_calls = 0;
+    zones_[0].total_duration = 0;
+    JXL_CHECK(depth_ == 0);
+    num_zones_ = 0;
+    return duration;
+  }
+
+  void SetSelfOverhead(const uint64_t self_overhead) {
+    self_overhead_ = self_overhead;
+  }
+
+  void SetChildOverhead(const uint64_t child_overhead) {
+    child_overhead_ = child_overhead;
+  }
+
+  // Draw all required information from the packets, which can be discarded
+  // afterwards. Called whenever this thread's storage is full.
+  void AnalyzePackets(const Packet* packets, const size_t num_packets) {
+    const uint64_t t0 = TicksBefore();
+
+    for (size_t i = 0; i < num_packets; ++i) {
+      const Packet p = packets[i];
+      // Entering a zone
+      if (p.BiasedOffset() != Packet::kOffsetBias) {
+        JXL_ASSERT(depth_ < kMaxDepth);
+        nodes_[depth_].packet = p;
+        nodes_[depth_].child_total = 0;
+        ++depth_;
+        continue;
+      }
+
+      JXL_ASSERT(depth_ != 0);
+      const ProfilerNode& node = nodes_[depth_ - 1];
+      // Masking correctly handles unsigned wraparound.
+      const uint64_t duration =
+          (p.Timestamp() - node.packet.Timestamp()) & Packet::kTimestampMask;
+      const uint64_t self_duration = ClampedSubtract(
+          duration, self_overhead_ + child_overhead_ + node.child_total);
+
+      UpdateOrAdd(node.packet.BiasedOffset(), 1, self_duration);
+      --depth_;
+
+      // Deduct this nested node's time from its parent's self_duration.
+      if (depth_ != 0) {
+        nodes_[depth_ - 1].child_total += duration + child_overhead_;
+      }
+    }
+
+    const uint64_t t1 = TicksAfter();
+    analyze_elapsed_ += t1 - t0;
+  }
+
+  // Incorporates results from another thread. Call after all threads have
+  // exited any zones.
+  void Assimilate(const Results& other) {
+    const uint64_t t0 = TicksBefore();
+    JXL_ASSERT(depth_ == 0);
+    JXL_ASSERT(other.depth_ == 0);
+
+    for (size_t i = 0; i < other.num_zones_; ++i) {
+      const Accumulator& zone = other.zones_[i];
+      UpdateOrAdd(zone.BiasedOffset(), zone.NumCalls(), zone.total_duration);
+    }
+    const uint64_t t1 = TicksAfter();
+    analyze_elapsed_ += t1 - t0 + other.analyze_elapsed_;
+  }
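[Editor's note] A worked example of the masked subtraction in AnalyzePackets above may help: timestamps are truncated to kTimestampBits = 64 - 27 = 37 bits (see the Packet class in profiler.h), and unsigned wraparound makes the duration come out right even when the truncated counter wraps between entry and exit. Illustrative only:

```cpp
#include <cstdint>

constexpr uint64_t kTimestampMask = (1ULL << 37) - 1;

// Identical in shape to the expression in AnalyzePackets.
constexpr uint64_t Duration(uint64_t exit_ts, uint64_t entry_ts) {
  return (exit_ts - entry_ts) & kTimestampMask;
}

// Entry just before the 37-bit counter wraps, exit just after: the masked
// difference still yields the true 16-tick duration.
static_assert(Duration(10, kTimestampMask - 5) == 16, "wraparound-safe");
```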
+
+  // Single-threaded.
+  void Print() {
+    const uint64_t t0 = TicksBefore();
+    MergeDuplicates();
+
+    // Sort by decreasing total (self) cost.
+    std::sort(zones_, zones_ + num_zones_,
+              [](const Accumulator& r1, const Accumulator& r2) {
+                return r1.total_duration > r2.total_duration;
+              });
+
+    const uintptr_t string_origin = StringOrigin();
+    uint64_t total_visible_duration = 0;
+    for (size_t i = 0; i < num_zones_; ++i) {
+      const Accumulator& r = zones_[i];
+      const uint64_t num_calls = r.NumCalls();
+      const char* name =
+          reinterpret_cast<const char*>(string_origin + r.BiasedOffset());
+      if (name[0] != '@') {
+        total_visible_duration += r.total_duration;
+        printf("%-40s: %10" PRIu64 " x %15" PRIu64 "= %15" PRIu64 "\n", name,
+               num_calls, r.total_duration / num_calls, r.total_duration);
+      }
+    }
+
+    const uint64_t t1 = TicksAfter();
+    analyze_elapsed_ += t1 - t0;
+    printf("Total clocks during analysis: %" PRIu64 "\n", analyze_elapsed_);
+    printf("Total clocks measured: %" PRIu64 "\n", total_visible_duration);
+  }
+
+  // Single-threaded. Clears all results as if no zones had been recorded.
+  void Reset() {
+    analyze_elapsed_ = 0;
+    JXL_CHECK(depth_ == 0);
+    num_zones_ = 0;
+    memset(nodes_, 0, sizeof(nodes_));
+    memset(zones_, 0, sizeof(zones_));
+  }
+
+ private:
+#if JXL_ARCH_X64
+  static bool SameOffset(const __m128i zone, const uint64_t biased_offset) {
+    const uint64_t num_calls = _mm_cvtsi128_si64(zone);
+    return (num_calls >> Accumulator::kNumCallBits) == biased_offset;
+  }
+#endif
+
+  // Updates an existing Accumulator (uniquely identified by biased_offset) or
+  // adds one if this is the first time this thread analyzed that zone.
+  // Uses a self-organizing list data structure, which avoids dynamic memory
+  // allocations and is far faster than unordered_map. Loads, updates and
+  // stores the entire Accumulator with vector instructions.
+  void UpdateOrAdd(const uint64_t biased_offset, const uint64_t num_calls,
+                   const uint64_t duration) {
+    JXL_ASSERT(biased_offset < (1ULL << Packet::kOffsetBits));
+
+#if JXL_ARCH_X64
+    const __m128i num_calls_64 = _mm_cvtsi64_si128(num_calls);
+    const __m128i duration_64 = _mm_cvtsi64_si128(duration);
+    const __m128i add_duration_call =
+        _mm_unpacklo_epi64(num_calls_64, duration_64);
+
+    __m128i* const JXL_RESTRICT zones = reinterpret_cast<__m128i*>(zones_);
+
+    // Special case for first zone: (maybe) update, without swapping.
+    __m128i prev = _mm_load_si128(zones);
+    if (SameOffset(prev, biased_offset)) {
+      prev = _mm_add_epi64(prev, add_duration_call);
+      JXL_ASSERT(SameOffset(prev, biased_offset));
+      _mm_store_si128(zones, prev);
+      return;
+    }
+
+    // Look for a zone with the same offset.
+    for (size_t i = 1; i < num_zones_; ++i) {
+      __m128i zone = _mm_load_si128(zones + i);
+      if (SameOffset(zone, biased_offset)) {
+        zone = _mm_add_epi64(zone, add_duration_call);
+        JXL_ASSERT(SameOffset(zone, biased_offset));
+        // Swap with predecessor (more conservative than move to front,
+        // but at least as successful).
+        _mm_store_si128(zones + i - 1, zone);
+        _mm_store_si128(zones + i, prev);
+        return;
+      }
+      prev = zone;
+    }
+
+    // Not found; create a new Accumulator.
+    const __m128i biased_offset_64 = _mm_slli_epi64(
+        _mm_cvtsi64_si128(biased_offset), Accumulator::kNumCallBits);
+    const __m128i zone = _mm_add_epi64(biased_offset_64, add_duration_call);
+    JXL_ASSERT(SameOffset(zone, biased_offset));
+
+    JXL_ASSERT(num_zones_ < kMaxZones);
+    _mm_store_si128(zones + num_zones_, zone);
+    ++num_zones_;
+#else
+    // Special case for first zone: (maybe) update, without swapping.
+    if (zones_[0].BiasedOffset() == biased_offset) {
+      zones_[0].total_duration += duration;
+      zones_[0].num_calls += num_calls;
+      JXL_ASSERT(zones_[0].BiasedOffset() == biased_offset);
+      return;
+    }
+
+    // Look for a zone with the same offset.
+    for (size_t i = 1; i < num_zones_; ++i) {
+      if (zones_[i].BiasedOffset() == biased_offset) {
+        zones_[i].total_duration += duration;
+        zones_[i].num_calls += num_calls;
+        JXL_ASSERT(zones_[i].BiasedOffset() == biased_offset);
+        // Swap with predecessor (more conservative than move to front,
+        // but at least as successful).
+        const Accumulator prev = zones_[i - 1];
+        zones_[i - 1] = zones_[i];
+        zones_[i] = prev;
+        return;
+      }
+    }
+
+    // Not found; create a new Accumulator.
+    JXL_ASSERT(num_zones_ < kMaxZones);
+    Accumulator* JXL_RESTRICT zone = zones_ + num_zones_;
+    zone->num_calls = (biased_offset << Accumulator::kNumCallBits) + num_calls;
+    zone->total_duration = duration;
+    JXL_ASSERT(zone->BiasedOffset() == biased_offset);
+    ++num_zones_;
+#endif
+  }
+
+  // Each instantiation of a function template seems to get its own copy of
+  // __func__ and GCC doesn't merge them. An N^2 search for duplicates is
+  // acceptable because we only expect a few dozen zones.
+  void MergeDuplicates() {
+    const uintptr_t string_origin = StringOrigin();
+    for (size_t i = 0; i < num_zones_; ++i) {
+      const uint64_t biased_offset = zones_[i].BiasedOffset();
+      const char* name =
+          reinterpret_cast<const char*>(string_origin + biased_offset);
+      // Separate num_calls from biased_offset so we can add them together.
+      uint64_t num_calls = zones_[i].NumCalls();
+
+      // Add any subsequent duplicates to num_calls and total_duration.
+      for (size_t j = i + 1; j < num_zones_;) {
+        if (!strcmp(name, reinterpret_cast<const char*>(
+                              string_origin + zones_[j].BiasedOffset()))) {
+          num_calls += zones_[j].NumCalls();
+          zones_[i].total_duration += zones_[j].total_duration;
+          // Fill hole with last item.
+          zones_[j] = zones_[--num_zones_];
+        } else {  // Name differed, try next Accumulator.
+          ++j;
+        }
+      }
+
+      JXL_ASSERT(num_calls < (1ULL << Accumulator::kNumCallBits));
+
+      // Re-pack regardless of whether any duplicates were found.
+      zones_[i].num_calls =
+          (biased_offset << Accumulator::kNumCallBits) + num_calls;
+    }
+  }
+
+  uint64_t analyze_elapsed_ = 0;
+  uint64_t self_overhead_ = 0;
+  uint64_t child_overhead_ = 0;
+
+  size_t depth_ = 0;      // Number of active zones.
+  size_t num_zones_ = 0;  // Number of retired zones.
+
+  // After other members to avoid large pointer offsets.
+  alignas(64) ProfilerNode nodes_[kMaxDepth];  // Stack
+  alignas(64) Accumulator zones_[kMaxZones];   // Self-organizing list
+};
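[Editor's note] Both UpdateOrAdd paths above implement the same "swap with predecessor" policy: a matched zone moves one slot toward the front, so hot zones cluster near the start without the churn of full move-to-front. A generic sketch of the policy (hypothetical helper, not from the patch):

```cpp
#include <cstddef>
#include <utility>

// Returns the matching item's new index, or n if absent (caller appends).
template <typename T, typename Matches>
size_t FindAndPromote(T* items, size_t n, const Matches& matches) {
  for (size_t i = 0; i < n; ++i) {
    if (matches(items[i])) {
      if (i == 0) return 0;               // already at the front
      std::swap(items[i - 1], items[i]);  // one step toward the front
      return i - 1;
    }
  }
  return n;
}
```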
+
+// `zone_name` is used to sanity-check offsets fit in kOffsetBits.
+ThreadSpecific::ThreadSpecific(const char* zone_name)
+    : packets_(static_cast<Packet*>(
+          CacheAligned::Allocate(PROFILER_THREAD_STORAGE << 20))),
+      num_packets_(0),
+      max_packets_(PROFILER_THREAD_STORAGE << 17),
+      string_origin_(StringOrigin()),
+      results_(static_cast<Results*>(
+          CacheAligned::Allocate(sizeof(Results)))) {
+  new (results_) Results();
+  // Even in optimized builds (with NDEBUG), verify that this zone's name
+  // offset fits within the allotted space. If not, UpdateOrAdd is likely to
+  // overrun zones_[]. We also JXL_ASSERT(), but users often do not run debug
+  // builds. Checking here on the cold path (only reached once per thread)
+  // is cheap, but it only covers one zone.
+  const uint64_t biased_offset =
+      reinterpret_cast<uintptr_t>(zone_name) - string_origin_;
+  JXL_CHECK(biased_offset <= (1ULL << Packet::kOffsetBits));
+}
+
+ThreadSpecific::~ThreadSpecific() {
+  results_->~Results();
+  CacheAligned::Free(packets_);
+  CacheAligned::Free(results_);
+}
+
+void ThreadSpecific::FlushStorage() {
+  results_->AnalyzePackets(packets_, num_packets_);
+  num_packets_ = 0;
+}
+
+#if PROFILER_BUFFER
+void ThreadSpecific::FlushBuffer() {
+  if (num_packets_ + kBufferCapacity > max_packets_) {
+    FlushStorage();
+  }
+  // This buffering halves observer overhead and decreases the overall
+  // runtime by about 3%.
+  HWY_STATIC_DISPATCH(StreamCacheLine)(buffer_, packets_ + num_packets_);
+  num_packets_ += kBufferCapacity;
+  buffer_size_ = 0;
+}
+#endif  // PROFILER_BUFFER
+
+void ThreadSpecific::AnalyzeRemainingPackets() {
+#if PROFILER_BUFFER
+  // Ensures prior weakly-ordered streaming stores are globally visible.
+  hwy::StoreFence();
+
+  // Storage full => empty it.
+  if (num_packets_ + buffer_size_ > max_packets_) {
+    results_->AnalyzePackets(packets_, num_packets_);
+    num_packets_ = 0;
+  }
+  memcpy(packets_ + num_packets_, buffer_, buffer_size_ * sizeof(Packet));
+  num_packets_ += buffer_size_;
+  buffer_size_ = 0;
+#endif  // PROFILER_BUFFER
+
+  results_->AnalyzePackets(packets_, num_packets_);
+  num_packets_ = 0;
+}
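[Editor's note] The shift pair in the constructor above encodes a unit conversion: PROFILER_THREAD_STORAGE is in MiB, so the byte count is `MiB << 20`, and since sizeof(Packet) == 8 the packet capacity is `MiB << 17`. With the default of 32 that is 32 MiB, or about 4.2 million packets (roughly 2.1 million zones at 16 bytes per entry/exit pair). A compile-time check of the equivalence:

```cpp
#include <cstdint>

constexpr uint64_t kMiB = 32;  // default PROFILER_THREAD_STORAGE
static_assert((kMiB << 20) / 8 == (kMiB << 17), "bytes / sizeof(Packet)");
static_assert((kMiB << 17) == 4194304, "packet capacity for 32 MiB");
```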
+
+void ThreadSpecific::ComputeOverhead() {
+  // Delay after capturing timestamps before/after the actual zone runs. Even
+  // with frequency throttling disabled, this has a multimodal distribution,
+  // including 32, 34, 48, 52, 59, 62.
+  uint64_t self_overhead;
+  {
+    const size_t kNumSamples = 32;
+    uint32_t samples[kNumSamples];
+    for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) {
+      const size_t kNumDurations = 1024;
+      uint32_t durations[kNumDurations];
+
+      for (size_t idx_duration = 0; idx_duration < kNumDurations;
+           ++idx_duration) {
+        {
+          PROFILER_ZONE("Dummy Zone (never shown)");
+        }
+#if PROFILER_BUFFER
+        const uint64_t duration = results_->ZoneDuration(buffer_);
+        buffer_size_ = 0;
+#else
+        const uint64_t duration = results_->ZoneDuration(packets_);
+        num_packets_ = 0;
+#endif
+        durations[idx_duration] = static_cast<uint32_t>(duration);
+        JXL_CHECK(num_packets_ == 0);
+      }
+      CountingSort(durations, durations + kNumDurations);
+      samples[idx_sample] = HalfSampleMode()(durations, kNumDurations);
+    }
+    // Median.
+    CountingSort(samples, samples + kNumSamples);
+    self_overhead = samples[kNumSamples / 2];
+#if PROFILER_PRINT_OVERHEAD
+    printf("Overhead: %zu\n", self_overhead);
+#endif
+    results_->SetSelfOverhead(self_overhead);
+  }
+
+  // Delay before capturing start timestamp / after end timestamp.
+  const size_t kNumSamples = 32;
+  uint32_t samples[kNumSamples];
+  for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) {
+    const size_t kNumDurations = 16;
+    uint32_t durations[kNumDurations];
+    for (size_t idx_duration = 0; idx_duration < kNumDurations;
+         ++idx_duration) {
+      const size_t kReps = 10000;
+      // Analysis time should not be included => must fit within buffer.
+      JXL_CHECK(kReps * 2 < max_packets_);
+#if JXL_ARCH_X64
+      _mm_mfence();
+#endif
+      const uint64_t t0 = TicksBefore();
+      for (size_t i = 0; i < kReps; ++i) {
+        PROFILER_ZONE("Dummy");
+      }
+      hwy::StoreFence();
+      const uint64_t t1 = TicksAfter();
+#if PROFILER_BUFFER
+      JXL_CHECK(num_packets_ + buffer_size_ == kReps * 2);
+      buffer_size_ = 0;
+#else
+      JXL_CHECK(num_packets_ == kReps * 2);
+#endif
+      num_packets_ = 0;
+      const uint64_t avg_duration = (t1 - t0 + kReps / 2) / kReps;
+      durations[idx_duration] =
+          static_cast<uint32_t>(ClampedSubtract(avg_duration, self_overhead));
+    }
+    CountingSort(durations, durations + kNumDurations);
+    samples[idx_sample] = HalfSampleMode()(durations, kNumDurations);
+  }
+  CountingSort(samples, samples + kNumSamples);
+  const uint64_t child_overhead = samples[9 * kNumSamples / 10];
+#if PROFILER_PRINT_OVERHEAD
+  printf("Child overhead: %zu\n", child_overhead);
+#endif
+  results_->SetChildOverhead(child_overhead);
+}
+
+namespace {
+
+class ThreadList {
+ public:
+  // Thread-safe.
+  void Add(ThreadSpecific* const ts) {
+    const uint32_t index = num_threads_.fetch_add(1, std::memory_order_relaxed);
+    JXL_CHECK(index < kMaxThreads);
+    threads_[index] = ts;
+  }
+
+  // Single-threaded.
+  void PrintResults() {
+    const uint32_t num_threads = num_threads_.load(std::memory_order_relaxed);
+    for (uint32_t i = 0; i < num_threads; ++i) {
+      threads_[i]->AnalyzeRemainingPackets();
+    }
+
+    // Combine all threads into a single Result.
+    for (uint32_t i = 1; i < num_threads; ++i) {
+      threads_[0]->GetResults().Assimilate(threads_[i]->GetResults());
+    }
+
+    if (num_threads != 0) {
+      threads_[0]->GetResults().Print();
+
+      for (uint32_t i = 0; i < num_threads; ++i) {
+        threads_[i]->GetResults().Reset();
+      }
+    }
+  }
+
+ private:
+  // Owning pointers.
+  alignas(64) ThreadSpecific* threads_[kMaxThreads];
+  std::atomic<uint32_t> num_threads_{0};
+};
+
+ThreadList& GetThreadList() {
+  static ThreadList threads_;
+  return threads_;
+}
+
+}  // namespace
+
+ThreadSpecific* Zone::InitThreadSpecific(const char* zone_name) {
+  void* mem = CacheAligned::Allocate(sizeof(ThreadSpecific));
+  ThreadSpecific* thread_specific = new (mem) ThreadSpecific(zone_name);
+  // Must happen before ComputeOverhead, which re-enters this ctor.
+  GetThreadList().Add(thread_specific);
+  GetThreadSpecific() = thread_specific;
+  thread_specific->ComputeOverhead();
+  return thread_specific;
+}
+
+/*static*/ void Zone::PrintResults() { GetThreadList().PrintResults(); }
+
+}  // namespace jxl
+
+#endif  // PROFILER_ENABLED
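[Editor's note] Taken together, the machinery above is driven entirely through the macros declared in the header that follows. A hypothetical end-to-end sketch, assuming the build defines PROFILER_ENABLED to 1 (DecodeOneFrame is a placeholder name):

```cpp
#include "lib/jxl/base/profiler.h"

void DecodeOneFrame() {
  PROFILER_FUNC;  // zone named via __func__
  {
    PROFILER_ZONE("entropy decode");
    // ... work measured as a nested zone ...
  }
}

int main() {
  DecodeOneFrame();
  // After all threads have exited all zones:
  PROFILER_PRINT_RESULTS();
  return 0;
}
```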
diff --git a/third_party/jpeg-xl/lib/profiler/profiler.h b/third_party/jpeg-xl/lib/profiler/profiler.h
new file mode 100644
index 000000000000..ff97f1e2a72e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/profiler/profiler.h
@@ -0,0 +1,222 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_PROFILER_PROFILER_H_
+#define LIB_PROFILER_PROFILER_H_
+
+// High precision, low overhead time measurements. Returns exact call counts
+// and total elapsed time for user-defined 'zones' (code regions, i.e. C++
+// scopes).
+//
+// Usage: add this header to BUILD srcs; instrument regions of interest:
+// { PROFILER_ZONE("name"); /*code*/ } or
+// void FuncToMeasure() { PROFILER_FUNC; /*code*/ }.
+// After all threads have exited any zones, invoke PROFILER_PRINT_RESULTS() to
+// print call counts and average durations [CPU cycles] to stdout, sorted in
+// descending order of total duration.
+
+// If zero, this file has no effect and no measurements will be recorded.
+#ifndef PROFILER_ENABLED
+#define PROFILER_ENABLED 0
+#endif
+#if PROFILER_ENABLED
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <hwy/highway.h>
+
+#include "lib/jxl/base/arch_macros.h"  // for JXL_ARCH_*
+#include "lib/jxl/base/cache_aligned.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/status.h"
+#include "lib/profiler/tsc_timer.h"
+
+#if JXL_ARCH_X64 && HWY_TARGET != HWY_SCALAR
+#define PROFILER_BUFFER 1
+#else
+#define PROFILER_BUFFER 0
+#endif
+
+namespace jxl {
+
+// Represents zone entry/exit events. Stores a full-resolution timestamp plus
+// an offset (representing zone name or identifying exit packets). POD.
+class Packet {
+ public:
+  // If offsets do not fit, UpdateOrAdd will overrun our heap allocation
+  // (governed by kMaxZones). We have seen ~100 MiB static binaries.
+  static constexpr size_t kOffsetBits = 27;
+  static constexpr uintptr_t kOffsetBias = 1ULL << (kOffsetBits - 1);
+
+  // We need full-resolution timestamps; at an effective rate of 4 GHz,
+  // this permits 34 second zone durations (for longer durations, split into
+  // multiple zones). Wraparound is handled by masking.
+  static constexpr size_t kTimestampBits = 64 - kOffsetBits;
+  static constexpr uint64_t kTimestampMask = (1ULL << kTimestampBits) - 1;
+
+  static Packet Make(const uint64_t biased_offset, const uint64_t timestamp) {
+    JXL_DASSERT(biased_offset < (1ULL << kOffsetBits));
+
+    Packet packet;
+    packet.bits_ =
+        (biased_offset << kTimestampBits) + (timestamp & kTimestampMask);
+    return packet;
+  }
+
+  uint64_t Timestamp() const { return bits_ & kTimestampMask; }
+
+  uintptr_t BiasedOffset() const { return (bits_ >> kTimestampBits); }
+
+ private:
+  uint64_t bits_;
+};
+static_assert(sizeof(Packet) == 8, "Wrong Packet size");
+
+class Results;
+
+// Per-thread packet storage, allocated via CacheAligned.
+class ThreadSpecific {
+  static constexpr size_t kBufferCapacity =
+      CacheAligned::kCacheLineSize / sizeof(Packet);
+
+ public:
+  // `zone_name` is used to sanity-check offsets fit in kOffsetBits.
+  explicit ThreadSpecific(const char* zone_name);
+  ~ThreadSpecific();
+
+  // Depends on Zone => defined out of line.
+  void ComputeOverhead();
+
+  void WriteEntry(const char* name, const uint64_t timestamp) {
+    const uint64_t biased_offset =
+        reinterpret_cast<uintptr_t>(name) - string_origin_;
+    Write(Packet::Make(biased_offset, timestamp));
+  }
+
+  void WriteExit(const uint64_t timestamp) {
+    const uint64_t biased_offset = Packet::kOffsetBias;
+    Write(Packet::Make(biased_offset, timestamp));
+  }
+
+  void AnalyzeRemainingPackets();
+
+  Results& GetResults() { return *results_; }
+
+ private:
+  void FlushStorage();
+#if PROFILER_BUFFER
+  void FlushBuffer();
+#endif
+
+  // Write packet to buffer/storage, emptying them as needed.
+ void Write(const Packet packet) { +#if PROFILER_BUFFER + if (buffer_size_ == kBufferCapacity) { // Full + FlushBuffer(); + } + buffer_[buffer_size_] = packet; + ++buffer_size_; +#else + if (num_packets_ >= max_packets_) { // Full + FlushStorage(); + } + packets_[num_packets_] = packet; + ++num_packets_; +#endif // PROFILER_BUFFER + } + + // Write-combining buffer to avoid cache pollution. Must be the first + // non-static member to ensure cache-line alignment. +#if PROFILER_BUFFER + Packet buffer_[kBufferCapacity]; + size_t buffer_size_ = 0; +#endif + + // Contiguous storage for zone enter/exit packets. + Packet* const JXL_RESTRICT packets_; + size_t num_packets_; + const size_t max_packets_; + + // Cached here because we already read this cache line on zone entry/exit. + uintptr_t string_origin_; + + Results* results_; +}; + +// RAII zone enter/exit recorder constructed by the ZONE macro; also +// responsible for initializing ThreadSpecific. +class Zone { + public: + // "name" must be a string literal (see StringOrigin). + JXL_NOINLINE explicit Zone(const char* name) { + JXL_COMPILER_FENCE; + ThreadSpecific* JXL_RESTRICT thread_specific = GetThreadSpecific(); + if (JXL_UNLIKELY(thread_specific == nullptr)) { + thread_specific = InitThreadSpecific(name); + } + + // (Capture timestamp ASAP, not inside WriteEntry.) + JXL_COMPILER_FENCE; + const uint64_t timestamp = TicksBefore(); + thread_specific->WriteEntry(name, timestamp); + } + + JXL_NOINLINE ~Zone() { + JXL_COMPILER_FENCE; + const uint64_t timestamp = TicksAfter(); + GetThreadSpecific()->WriteExit(timestamp); + JXL_COMPILER_FENCE; + } + + // Call exactly once after all threads have exited all zones. + static void PrintResults(); + + private: + // Returns reference to the thread's ThreadSpecific pointer (initially null). + // Function-local static avoids needing a separate definition. + static ThreadSpecific*& GetThreadSpecific() { + static thread_local ThreadSpecific* thread_specific; + return thread_specific; + } + + // Non time-critical. + ThreadSpecific* InitThreadSpecific(const char* zone_name); +}; + +// Creates a zone starting from here until the end of the current scope. +// Timestamps will be recorded when entering and exiting the zone. +// "name" must be a string literal, which is ensured by merging with "". +#define PROFILER_ZONE(name) \ + JXL_COMPILER_FENCE; \ + const ::jxl::Zone zone("" name); \ + JXL_COMPILER_FENCE + +// Creates a zone for an entire function (when placed at its beginning). +// Shorter/more convenient than ZONE. +#define PROFILER_FUNC \ + JXL_COMPILER_FENCE; \ + const ::jxl::Zone zone(__func__); \ + JXL_COMPILER_FENCE + +#define PROFILER_PRINT_RESULTS ::jxl::Zone::PrintResults + +} // namespace jxl + +#else // !PROFILER_ENABLED +#define PROFILER_ZONE(name) +#define PROFILER_FUNC +#define PROFILER_PRINT_RESULTS() +#endif + +#endif // LIB_PROFILER_PROFILER_H_ diff --git a/third_party/jpeg-xl/lib/profiler/tsc_timer.h b/third_party/jpeg-xl/lib/profiler/tsc_timer.h new file mode 100644 index 000000000000..e28511d23417 --- /dev/null +++ b/third_party/jpeg-xl/lib/profiler/tsc_timer.h @@ -0,0 +1,142 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_PROFILER_TSC_TIMER_H_
+#define LIB_PROFILER_TSC_TIMER_H_
+
+// High-resolution (~10 ns) timestamps, using fences to prevent reordering and
+// ensure exactly the desired regions are measured.
+
+#include <stdint.h>
+
+#include <ctime>
+#include <hwy/cache_control.h>  // LoadFence
+
+#include "lib/jxl/base/arch_macros.h"  // for JXL_ARCH_*
+#include "lib/jxl/base/compiler_specific.h"
+
+namespace jxl {
+
+// TicksBefore/After return absolute timestamps and must be placed immediately
+// before and after the region to measure. The functions are distinct because
+// they use different fences.
+//
+// Background: RDTSC is not 'serializing'; earlier instructions may complete
+// after it, and/or later instructions may complete before it. 'Fences' ensure
+// regions' elapsed times are independent of such reordering. The only
+// documented unprivileged serializing instruction is CPUID, which acts as a
+// full fence (no reordering across it in either direction). Unfortunately
+// the latency of CPUID varies wildly (perhaps made worse by not initializing
+// its EAX input). Because it cannot reliably be deducted from the region's
+// elapsed time, it must not be included in the region to measure (i.e.
+// between the two RDTSC).
+//
+// The newer RDTSCP is sometimes described as serializing, but it actually
+// only serves as a half-fence with release semantics. Although all
+// instructions in the region will complete before the final timestamp is
+// captured, subsequent instructions may leak into the region and increase the
+// elapsed time. Inserting another fence after the final RDTSCP would prevent
+// such reordering without affecting the measured region.
+//
+// Fortunately, such a fence exists. The LFENCE instruction is only documented
+// to delay later loads until earlier loads are visible. However, Intel's
+// reference manual says it acts as a full fence (waiting until all earlier
+// instructions have completed, and delaying later instructions until it
+// completes). AMD assigns the same behavior to MFENCE.
+//
+// We need a fence before the initial RDTSC to prevent earlier instructions
+// from leaking into the region, and arguably another after RDTSC to avoid
+// region instructions from completing before the timestamp is recorded.
+// When surrounded by fences, the additional RDTSCP half-fence provides no
+// benefit, so the initial timestamp can be recorded via RDTSC, which has
+// lower overhead than RDTSCP because it does not read TSC_AUX. In summary,
+// we define Before = LFENCE/RDTSC/LFENCE; After = RDTSCP/LFENCE.
+//
+// Using Before+Before leads to higher variance and overhead than After+After.
+// However, After+After includes an LFENCE in the region measurements, which
+// adds a delay dependent on earlier loads. The combination of Before+After
+// is faster than Before+Before and more consistent than After+After because
+// the first LFENCE already delayed subsequent loads before the measured
+// region. This combination seems not to have been considered in prior work:
+// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c
+//
+// Note: performance counters can measure 'exact' instructions-retired or
+// (unhalted) cycle counts. The RDPMC instruction is not serializing and also
+// requires fences. Unfortunately, it is not accessible on all OSes and we
+// prefer to avoid kernel-mode drivers. Performance counters are also affected
+// by several under/over-count errata, so we use the TSC instead.
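[Editor's note] In practice the Before/After pairing described above is used like this (a sketch; Workload is a placeholder, and InvariantTicksPerSecond is the tick-to-seconds conversion the next comment refers to, defined elsewhere in the tree):

```cpp
#include <cstdint>

#include "lib/profiler/tsc_timer.h"

void Workload();  // placeholder for the region of interest

void MeasureRegion() {
  const uint64_t t0 = jxl::TicksBefore();  // LFENCE; RDTSC; LFENCE on x64
  Workload();
  const uint64_t t1 = jxl::TicksAfter();   // RDTSCP; LFENCE on x64
  const uint64_t elapsed_ticks = t1 - t0;  // divide by InvariantTicksPerSecond
  (void)elapsed_ticks;
}
```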
+
+// Returns a 64-bit timestamp in units of 'ticks'; to convert to seconds,
+// divide by InvariantTicksPerSecond. Although 32-bit ticks are faster to read,
+// they overflow too quickly to measure long regions.
+static inline uint64_t TicksBefore() {
+  uint64_t t;
+#if JXL_ARCH_PPC
+  asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268));
+#elif JXL_ARCH_X64 && JXL_COMPILER_MSVC
+  hwy::LoadFence();
+  JXL_COMPILER_FENCE;
+  t = __rdtsc();
+  hwy::LoadFence();
+  JXL_COMPILER_FENCE;
+#elif JXL_ARCH_X64 && (JXL_COMPILER_CLANG || JXL_COMPILER_GCC)
+  asm volatile(
+      "lfence\n\t"
+      "rdtsc\n\t"
+      "shl $32, %%rdx\n\t"
+      "or %%rdx, %0\n\t"
+      "lfence"
+      : "=a"(t)
+      :
+      // "memory" avoids reordering. rdx = TSC >> 32.
+      // "cc" = flags modified by SHL.
+      : "rdx", "memory", "cc");
+#else
+  // Fall back to OS - unsure how to reliably query cntvct_el0 frequency.
+  timespec ts;
+  clock_gettime(CLOCK_MONOTONIC, &ts);
+  t = ts.tv_sec * 1000000000LL + ts.tv_nsec;
+#endif
+  return t;
+}
+
+static inline uint64_t TicksAfter() {
+  uint64_t t;
+#if JXL_ARCH_X64 && JXL_COMPILER_MSVC
+  JXL_COMPILER_FENCE;
+  unsigned aux;
+  t = __rdtscp(&aux);
+  hwy::LoadFence();
+  JXL_COMPILER_FENCE;
+#elif JXL_ARCH_X64 && (JXL_COMPILER_CLANG || JXL_COMPILER_GCC)
+  // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx).
+  asm volatile(
+      "rdtscp\n\t"
+      "shl $32, %%rdx\n\t"
+      "or %%rdx, %0\n\t"
+      "lfence"
+      : "=a"(t)
+      :
+      // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32.
+      // "cc" = flags modified by SHL.
+      : "rcx", "rdx", "memory", "cc");
+#else
+  t = TicksBefore();  // no difference on other platforms.
+#endif
+  return t;
+}
+
+}  // namespace jxl
+
+#endif  // LIB_PROFILER_TSC_TIMER_H_
diff --git a/third_party/jpeg-xl/lib/threads/libjxl_threads.pc.in b/third_party/jpeg-xl/lib/threads/libjxl_threads.pc.in
new file mode 100644
index 000000000000..8a3275cf1c1b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/threads/libjxl_threads.pc.in
@@ -0,0 +1,12 @@
+prefix=@CMAKE_INSTALL_PREFIX@
+exec_prefix=${prefix}
+libdir=${exec_prefix}/@CMAKE_INSTALL_LIBDIR@
+includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@
+
+Name: libjxl_threads
+Description: JPEG XL multi-thread runner using std::thread.
+Version: @JPEGXL_LIBRARY_VERSION@
+Requires.private: @JPEGXL_THREADS_LIBRARY_REQUIRES@
+Libs: -L${libdir} -ljxl_threads
+Libs.private: -lm
+Cflags: -I${includedir}
diff --git a/third_party/jpeg-xl/lib/threads/thread_parallel_runner.cc b/third_party/jpeg-xl/lib/threads/thread_parallel_runner.cc
new file mode 100644
index 000000000000..3982a90a8669
--- /dev/null
+++ b/third_party/jpeg-xl/lib/threads/thread_parallel_runner.cc
@@ -0,0 +1,110 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "jxl/thread_parallel_runner.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include "lib/threads/thread_parallel_runner_internal.h"
+
+namespace {
+
+// Default JxlMemoryManager using malloc and free for the jpegxl_threads
+// library. Same as the default JxlMemoryManager for the jpegxl library
+// itself.
+
+// Default alloc and free functions.
+void* ThreadMemoryManagerDefaultAlloc(void* opaque, size_t size) {
+  return malloc(size);
+}
+
+void ThreadMemoryManagerDefaultFree(void* opaque, void* address) {
+  free(address);
+}
+
+// Initializes the memory manager instance with the passed one. The
+// MemoryManager passed in |memory_manager| may be NULL or contain NULL
+// functions which will be initialized with the default ones. If either alloc
+// or free are NULL, then both must be NULL, otherwise this function returns an
+// error.
+bool ThreadMemoryManagerInit(JxlMemoryManager* self,
+                             const JxlMemoryManager* memory_manager) {
+  if (memory_manager) {
+    *self = *memory_manager;
+  } else {
+    memset(self, 0, sizeof(*self));
+  }
+  if (!self->alloc != !self->free) {
+    return false;
+  }
+  if (!self->alloc) self->alloc = ThreadMemoryManagerDefaultAlloc;
+  if (!self->free) self->free = ThreadMemoryManagerDefaultFree;
+
+  return true;
+}
+
+void* ThreadMemoryManagerAlloc(const JxlMemoryManager* memory_manager,
+                               size_t size) {
+  return memory_manager->alloc(memory_manager->opaque, size);
+}
+
+void ThreadMemoryManagerFree(const JxlMemoryManager* memory_manager,
+                             void* address) {
+  return memory_manager->free(memory_manager->opaque, address);
+}
+
+}  // namespace
+
+JxlParallelRetCode JxlThreadParallelRunner(
+    void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
+    JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) {
+  return jpegxl::ThreadParallelRunner::Runner(
+      runner_opaque, jpegxl_opaque, init, func, start_range, end_range);
+}
+
+/// Starts the given number of worker threads and blocks until they are ready.
+/// "num_worker_threads" defaults to one per hyperthread. If zero, all tasks
+/// run on the main thread.
+void* JxlThreadParallelRunnerCreate(const JxlMemoryManager* memory_manager,
+                                    size_t num_worker_threads) {
+  JxlMemoryManager local_memory_manager;
+  if (!ThreadMemoryManagerInit(&local_memory_manager, memory_manager))
+    return nullptr;
+
+  void* alloc = ThreadMemoryManagerAlloc(&local_memory_manager,
+                                         sizeof(jpegxl::ThreadParallelRunner));
+  if (!alloc) return nullptr;
+  // Placement new constructor on allocated memory
+  jpegxl::ThreadParallelRunner* runner =
+      new (alloc) jpegxl::ThreadParallelRunner(num_worker_threads);
+  runner->memory_manager = local_memory_manager;
+
+  return runner;
+}
+
+void JxlThreadParallelRunnerDestroy(void* runner_opaque) {
+  jpegxl::ThreadParallelRunner* runner =
+      reinterpret_cast<jpegxl::ThreadParallelRunner*>(runner_opaque);
+  if (runner) {
+    // Call destructor directly since custom free function is used.
+    runner->~ThreadParallelRunner();
+    ThreadMemoryManagerFree(&runner->memory_manager, runner);
+  }
+}
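[Editor's note] For context, a typical consumer of the C API above pairs Create/Destroy around a decode session and hands the runner to the decoder as an opaque pointer plus the JxlThreadParallelRunner function. A hedged sketch (JxlDecoderSetParallelRunner belongs to the public libjxl decode API; error handling elided):

```cpp
#include "jxl/thread_parallel_runner.h"

void RunWithThreads() {
  void* runner = JxlThreadParallelRunnerCreate(
      /*memory_manager=*/nullptr,
      JxlThreadParallelRunnerDefaultNumWorkerThreads());
  if (!runner) return;
  // e.g. JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner);
  JxlThreadParallelRunnerDestroy(runner);
}
```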
+
+// Get default value for num_worker_threads parameter of
+// InitJxlThreadParallelRunner.
+size_t JxlThreadParallelRunnerDefaultNumWorkerThreads() {
+  return std::thread::hardware_concurrency();
+}
diff --git a/third_party/jpeg-xl/lib/threads/thread_parallel_runner_internal.cc b/third_party/jpeg-xl/lib/threads/thread_parallel_runner_internal.cc
new file mode 100644
index 000000000000..e03813e524db
--- /dev/null
+++ b/third_party/jpeg-xl/lib/threads/thread_parallel_runner_internal.cc
@@ -0,0 +1,226 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lib/threads/thread_parallel_runner_internal.h"
+
+#include <algorithm>
+
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+#include "sanitizer/common_interface_defs.h"  // __sanitizer_print_stack_trace
+#endif                                        // defined(*_SANITIZER)
+
+#include "jxl/thread_parallel_runner.h"
+#include "lib/jxl/base/profiler.h"
+
+namespace {
+
+// Exits the program after printing a stack trace when possible.
+bool Abort() {
+#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
+    defined(THREAD_SANITIZER)
+  // If compiled with any sanitizer, print a stack trace. This call doesn't
+  // crash the program; instead, the trap below crashes it, which also allows
+  // gdb to break there.
+  __sanitizer_print_stack_trace();
+#endif  // defined(*_SANITIZER)
+
+#ifdef _MSC_VER
+  __debugbreak();
+  abort();
+#else
+  __builtin_trap();
+#endif
+}
+
+// Does not guarantee running the code, use only for debug mode checks.
+#if JXL_ENABLE_ASSERT
+#define JXL_ASSERT(condition) \
+  do {                        \
+    if (!(condition)) {       \
+      Abort();                \
+    }                         \
+  } while (0)
+#else
+#define JXL_ASSERT(condition) \
+  do {                        \
+  } while (0)
+#endif
+}  // namespace
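[Editor's note] Before reading Runner below, it helps to see the command encoding it relies on: begin/end are packed into one 64-bit word, and the reserved values (~1ULL, ~2ULL, ~3ULL, declared in the header) can never collide with a real range because they decode to begin >= end. Illustrative helpers, not part of the patch:

```cpp
#include <cstdint>

using WorkerCommand = uint64_t;

constexpr WorkerCommand Encode(uint32_t begin, uint32_t end) {
  return (static_cast<WorkerCommand>(begin) << 32) + end;
}
constexpr uint32_t Begin(WorkerCommand c) { return c >> 32; }
constexpr uint32_t End(WorkerCommand c) { return c & 0xFFFFFFFF; }

// ~1ULL decodes to begin 0xFFFFFFFF, end 0xFFFFFFFE: an empty range.
static_assert(Begin(~1ULL) >= End(~1ULL), "reserved commands are no-ops");
static_assert(Begin(Encode(3, 10)) == 3 && End(Encode(3, 10)) == 10, "");
```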
+
+namespace jpegxl {
+
+// static
+JxlParallelRetCode ThreadParallelRunner::Runner(
+    void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init,
+    JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) {
+  ThreadParallelRunner* self =
+      static_cast<ThreadParallelRunner*>(runner_opaque);
+  if (start_range > end_range) return -1;
+  if (start_range == end_range) return 0;
+
+  int ret = init(jpegxl_opaque, std::max<size_t>(self->num_worker_threads_, 1));
+  if (ret != 0) return ret;
+
+  // Use a sequential run when num_worker_threads_ is zero since we have no
+  // worker threads.
+  if (self->num_worker_threads_ == 0) {
+    const size_t thread = 0;
+    for (uint32_t task = start_range; task < end_range; ++task) {
+      func(jpegxl_opaque, task, thread);
+    }
+    return 0;
+  }
+
+  if (self->depth_.fetch_add(1, std::memory_order_acq_rel) != 0) {
+    return -1;  // Must not re-enter.
+  }
+
+  const WorkerCommand worker_command =
+      (static_cast<WorkerCommand>(start_range) << 32) + end_range;
+  // Ensure the inputs do not result in a reserved command.
+  JXL_ASSERT(worker_command != kWorkerWait);
+  JXL_ASSERT(worker_command != kWorkerOnce);
+  JXL_ASSERT(worker_command != kWorkerExit);
+
+  self->data_func_ = func;
+  self->jpegxl_opaque_ = jpegxl_opaque;
+  self->num_reserved_.store(0, std::memory_order_relaxed);
+
+  self->StartWorkers(worker_command);
+  self->WorkersReadyBarrier();
+
+  if (self->depth_.fetch_add(-1, std::memory_order_acq_rel) != 1) {
+    return -1;
+  }
+  return 0;
+}
+
+// static
+void ThreadParallelRunner::RunRange(ThreadParallelRunner* self,
+                                    const WorkerCommand command,
+                                    const int thread) {
+  const uint32_t begin = command >> 32;
+  const uint32_t end = command & 0xFFFFFFFF;
+  const uint32_t num_tasks = end - begin;
+  const uint32_t num_worker_threads = self->num_worker_threads_;
+
+  // OpenMP introduced several "schedule" strategies:
+  // "single" (static assignment of exactly one chunk per thread): slower.
+  // "dynamic" (allocates k tasks at a time): competitive for well-chosen k.
+  // "guided" (allocates k tasks, decreases k): computing k = remaining/n
+  //   is faster than halving k each iteration. We prefer this strategy
+  //   because it avoids user-specified parameters.
+
+  for (;;) {
+#if 0
+    // dynamic
+    const uint32_t my_size = std::max(num_tasks / (num_worker_threads * 4), 1);
+#else
+    // guided
+    const uint32_t num_reserved =
+        self->num_reserved_.load(std::memory_order_relaxed);
+    const uint32_t num_remaining = num_tasks - num_reserved;
+    const uint32_t my_size =
+        std::max(num_remaining / (num_worker_threads * 4), 1u);
+#endif
+    const uint32_t my_begin = begin + self->num_reserved_.fetch_add(
+                                          my_size, std::memory_order_relaxed);
+    const uint32_t my_end = std::min(my_begin + my_size, begin + num_tasks);
+    // Another thread already reserved the last task.
+    if (my_begin >= my_end) {
+      break;
+    }
+    for (uint32_t task = my_begin; task < my_end; ++task) {
+      self->data_func_(self->jpegxl_opaque_, task, thread);
+    }
+  }
+}
+
+// static
+void ThreadParallelRunner::ThreadFunc(ThreadParallelRunner* self,
+                                      const int thread) {
+  // Until kWorkerExit command received:
+  for (;;) {
+    std::unique_lock<std::mutex> lock(self->mutex_);
+    // Notify main thread that this thread is ready.
+    if (++self->workers_ready_ == self->num_threads_) {
+      self->workers_ready_cv_.notify_one();
+    }
+  RESUME_WAIT:
+    // Wait for a command.
+    self->worker_start_cv_.wait(lock);
+    const WorkerCommand command = self->worker_start_command_;
+    switch (command) {
+      case kWorkerWait:    // spurious wakeup:
+        goto RESUME_WAIT;  // lock still held, avoid incrementing ready.
+      case kWorkerOnce:
+        lock.unlock();
+        self->data_func_(self->jpegxl_opaque_, thread, thread);
+        break;
+      case kWorkerExit:
+        return;  // exits thread
+      default:
+        lock.unlock();
+        RunRange(self, command, thread);
+        break;
+    }
+  }
+}
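[Editor's note] To make the "guided" strategy in RunRange concrete: each reservation grabs remaining/(workers*4) tasks, so chunk sizes taper off as the range drains, e.g. 31, 30, 29, ... down to 1 for 1000 tasks on 8 workers. A single-threaded simulation of the chunk sequence:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_tasks = 1000, workers = 8;
  uint32_t reserved = 0;
  while (reserved < num_tasks) {
    const uint32_t chunk =
        std::max((num_tasks - reserved) / (workers * 4), 1u);
    printf("%u ", chunk);  // 31 30 29 ... 1
    reserved += chunk;
  }
  return 0;
}
```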
+
+ThreadParallelRunner::ThreadParallelRunner(const int num_worker_threads)
+#if defined(__EMSCRIPTEN__)
+    : num_worker_threads_(0), num_threads_(1) {
+  // TODO(eustas): find out if pthreads would work for us.
+  (void)num_worker_threads;
+#else
+    : num_worker_threads_(num_worker_threads),
+      num_threads_(std::max(num_worker_threads, 1)) {
+#endif
+  PROFILER_ZONE("ThreadParallelRunner ctor");
+
+  threads_.reserve(num_worker_threads_);
+
+  // Suppress "unused-private-field" warning.
+  (void)padding1;
+  (void)padding2;
+
+  // Safely handle spurious worker wakeups.
+  worker_start_command_ = kWorkerWait;
+
+  for (uint32_t i = 0; i < num_worker_threads_; ++i) {
+    threads_.emplace_back(ThreadFunc, this, i);
+  }
+
+  if (num_worker_threads_ != 0) {
+    WorkersReadyBarrier();
+  }
+
+  // Warm up profiler on worker threads so its expensive initialization
+  // doesn't count towards other timer measurements.
+  RunOnEachThread(
+      [](const int task, const int thread) { PROFILER_ZONE("@InitWorkers"); });
+}
+
+ThreadParallelRunner::~ThreadParallelRunner() {
+  if (num_worker_threads_ != 0) {
+    StartWorkers(kWorkerExit);
+  }
+
+  for (std::thread& thread : threads_) {
+    JXL_ASSERT(thread.joinable());
+    thread.join();
+  }
+}
+}  // namespace jpegxl
diff --git a/third_party/jpeg-xl/lib/threads/thread_parallel_runner_internal.h b/third_party/jpeg-xl/lib/threads/thread_parallel_runner_internal.h
new file mode 100644
index 000000000000..a38e160ba68a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/threads/thread_parallel_runner_internal.h
@@ -0,0 +1,181 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// C++ implementation using std::thread of a ::JxlParallelRunner.
+
+// The main class in this module, ThreadParallelRunner, implements a static
+// method ThreadParallelRunner::Runner that can be passed as a
+// JxlParallelRunner when using the JPEG XL library. This uses std::thread
+// internally and related synchronization functions. The number of threads
+// created is fixed at construction time and the threads are re-used for every
+// ThreadParallelRunner::Runner call. Only one concurrent Runner() call per
+// instance is allowed at a time.
+//
+// This is a scalable, lower-overhead thread pool runner, especially suitable
+// for data-parallel computations in the fork-join model, where clients need to
+// know when all tasks have completed.
+//
+// This thread pool can efficiently load-balance millions of tasks using an
+// atomic counter, thus avoiding per-task virtual or system calls. With 48
+// hyperthreads and 1M tasks that add to an atomic counter, overall runtime is
+// 10-20x higher when using std::async, and ~200x for a queue-based thread
+// pool.
+//
+// Usage:
+//   ThreadParallelRunner runner;
+//   JxlDecode(
+//       ... , &ThreadParallelRunner::Runner, static_cast<void*>(&runner));
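[Editor's note] The Runner entry point described in the usage comment can also be exercised directly, which makes the contract easy to see: init is called once with the thread count, then func is called once per task. A hypothetical sketch (captureless lambdas decay to the C callback types from jxl/parallel_runner.h; the assertions about those signatures are the editor's):

```cpp
#include "lib/threads/thread_parallel_runner_internal.h"

int RunHundredTasks() {
  jpegxl::ThreadParallelRunner runner(/*num_worker_threads=*/4);
  return jpegxl::ThreadParallelRunner::Runner(
      &runner, /*jpegxl_opaque=*/nullptr,
      [](void*, size_t /*num_threads*/) -> JxlParallelRetCode { return 0; },
      [](void*, uint32_t task, size_t thread) { /* process task */ },
      /*start_range=*/0, /*end_range=*/100);
}
```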
+ // "num_worker_threads" defaults to one per hyperthread. If zero, all tasks + // run on the main thread. + explicit ThreadParallelRunner( + int num_worker_threads = std::thread::hardware_concurrency()); + + // Waits for all threads to exit. + ~ThreadParallelRunner(); + + // Returns number of worker threads created (some may be sleeping and never + // wake up in time to participate in Run). Useful for characterizing + // performance; 0 means "run on main thread". + size_t NumWorkerThreads() const { return num_worker_threads_; } + + // Returns maximum number of main/worker threads that may call Func. Useful + // for allocating per-thread storage. + size_t NumThreads() const { return num_threads_; } + + // Runs func(thread, thread) on all thread(s) that may participate in Run. + // If NumThreads() == 0, runs on the main thread with thread == 0, otherwise + // concurrently called by each worker thread in [0, NumThreads()). + template + void RunOnEachThread(const Func& func) { + if (num_worker_threads_ == 0) { + const int thread = 0; + func(thread, thread); + return; + } + + data_func_ = reinterpret_cast(&CallClosure); + jpegxl_opaque_ = const_cast(static_cast(&func)); + StartWorkers(kWorkerOnce); + WorkersReadyBarrier(); + } + + JxlMemoryManager memory_manager; + + private: + // After construction and between calls to Run, workers are "ready", i.e. + // waiting on worker_start_cv_. They are "started" by sending a "command" + // and notifying all worker_start_cv_ waiters. (That is why all workers + // must be ready/waiting - otherwise, the notification will not reach all of + // them and the main thread waits in vain for them to report readiness.) + using WorkerCommand = uint64_t; + + // Special values; all others encode the begin/end parameters. Note that all + // these are no-op ranges (begin >= end) and therefore never used to encode + // ranges. + static constexpr WorkerCommand kWorkerWait = ~1ULL; + static constexpr WorkerCommand kWorkerOnce = ~2ULL; + static constexpr WorkerCommand kWorkerExit = ~3ULL; + + // Calls f(task, thread). Used for type erasure of Func arguments. The + // signature must match JxlParallelRunFunction, hence a void* argument. + template + static void CallClosure(void* f, const uint32_t task, const size_t thread) { + (*reinterpret_cast(f))(task, thread); + } + + void WorkersReadyBarrier() { + std::unique_lock lock(mutex_); + // Typically only a single iteration. + while (workers_ready_ != threads_.size()) { + workers_ready_cv_.wait(lock); + } + workers_ready_ = 0; + + // Safely handle spurious worker wakeups. + worker_start_command_ = kWorkerWait; + } + + // Precondition: all workers are ready. + void StartWorkers(const WorkerCommand worker_command) { + mutex_.lock(); + worker_start_command_ = worker_command; + // Workers will need this lock, so release it before they wake up. + mutex_.unlock(); + worker_start_cv_.notify_all(); + } + + // Attempts to reserve and perform some work from the global range of tasks, + // which is encoded within "command". Returns after all tasks are reserved. + static void RunRange(ThreadParallelRunner* self, const WorkerCommand command, + const int thread); + + static void ThreadFunc(ThreadParallelRunner* self, int thread); + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector threads_; + + const uint32_t num_worker_threads_; // == threads_.size() + const uint32_t num_threads_; + + std::atomic depth_{0}; // detects if Run is re-entered (not supported). 
+
+  std::mutex mutex_;  // guards both cv and their variables.
+  std::condition_variable workers_ready_cv_;
+  uint32_t workers_ready_ = 0;
+  std::condition_variable worker_start_cv_;
+  WorkerCommand worker_start_command_;
+
+  // Written by main thread, read by workers (after mutex lock/unlock).
+  JxlParallelRunFunction data_func_;
+  void* jpegxl_opaque_;
+
+  // Updated by workers; padding avoids false sharing.
+  uint8_t padding1[64];
+  std::atomic<uint32_t> num_reserved_{0};
+  uint8_t padding2[64];
+};
+
+}  // namespace jpegxl
+
+#endif  // LIB_THREADS_THREAD_PARALLEL_RUNNER_INTERNAL_H_
diff --git a/third_party/jpeg-xl/lib/threads/thread_parallel_runner_test.cc b/third_party/jpeg-xl/lib/threads/thread_parallel_runner_test.cc
new file mode 100644
index 000000000000..507edf50cb4e
--- /dev/null
+++ b/third_party/jpeg-xl/lib/threads/thread_parallel_runner_test.cc
@@ -0,0 +1,124 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+
+namespace jpegxl {
+namespace {
+
+int PopulationCount(uint64_t bits) {
+  int num_set = 0;
+  while (bits != 0) {
+    num_set += bits & 1;
+    bits >>= 1;
+  }
+  return num_set;
+}
+
+// Ensures task parameter is in bounds, every parameter is reached,
+// pool can be reused (multiple consecutive Run calls), pool can be destroyed
+// (joining with its threads), num_threads=0 works (runs on current thread).
+TEST(ThreadParallelRunnerTest, TestPool) {
+  for (int num_threads = 0; num_threads <= 18; ++num_threads) {
+    jxl::ThreadPoolInternal pool(num_threads);
+    for (int num_tasks = 0; num_tasks < 32; ++num_tasks) {
+      std::vector<int> mementos(num_tasks);
+      for (int begin = 0; begin < 32; ++begin) {
+        std::fill(mementos.begin(), mementos.end(), 0);
+        pool.Run(
+            begin, begin + num_tasks, jxl::ThreadPool::SkipInit(),
+            [begin, num_tasks, &mementos](const int task, const int thread) {
+              // Parameter is in the given range
+              EXPECT_GE(task, begin);
+              EXPECT_LT(task, begin + num_tasks);
+
+              // Store mementos to be sure we visited each task.
+              mementos.at(task - begin) = 1000 + task;
+            });
+        for (int task = begin; task < begin + num_tasks; ++task) {
+          EXPECT_EQ(1000 + task, mementos.at(task - begin));
+        }
+      }
+    }
+  }
+}
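[Editor's note] The next test records which worker threads actually ran by OR-ing one bit per thread id into an atomic mask; compare_exchange_weak simply retries when another worker updated the mask concurrently. The pattern in isolation:

```cpp
#include <atomic>
#include <cstdint>

std::atomic<uint64_t> id_bits{0};

void MarkParticipant(int thread) {
  uint64_t bits = id_bits.load(std::memory_order_relaxed);
  while (!id_bits.compare_exchange_weak(bits, bits | (1ULL << thread))) {
    // bits was reloaded with the current value; retry.
  }
}
```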
+  const int kMaxThreads = 8;
+  for (int num_threads = 1; num_threads <= kMaxThreads; ++num_threads) {
+    jxl::ThreadPoolInternal pool(num_threads);
+
+    // (Avoid mutex because it may perturb the worker thread scheduling)
+    std::atomic<uint64_t> id_bits{0};
+    std::atomic<int> num_calls{0};
+
+    pool.Run(
+        0, num_threads, jxl::ThreadPool::SkipInit(),
+        [&num_calls, num_threads, &id_bits](const int task, const int thread) {
+          num_calls.fetch_add(1, std::memory_order_relaxed);
+
+          EXPECT_LT(thread, num_threads);
+          uint64_t bits = id_bits.load(std::memory_order_relaxed);
+          while (
+              !id_bits.compare_exchange_weak(bits, bits | (1ULL << thread))) {
+          }
+        });
+
+    // Correct number of tasks.
+    EXPECT_EQ(num_threads, num_calls.load());
+
+    const int num_participants = PopulationCount(id_bits.load());
+    // Can't expect equality because other workers may have woken up too late.
+    EXPECT_LE(num_participants, num_threads);
+  }
+}
+
+struct Counter {
+  Counter() {
+    // Suppress "unused-field" warning.
+    (void)padding;
+  }
+  void Assimilate(const Counter& victim) { counter += victim.counter; }
+  int counter = 0;
+  int padding[31];
+};
+
+TEST(ThreadParallelRunnerTest, TestCounter) {
+  const int kNumThreads = 12;
+  jxl::ThreadPoolInternal pool(kNumThreads);
+  alignas(128) Counter counters[kNumThreads];
+
+  const int kNumTasks = kNumThreads * 19;
+  pool.Run(0, kNumTasks, jxl::ThreadPool::SkipInit(),
+           [&counters](const int task, const int thread) {
+             counters[thread].counter += task;
+           });
+
+  int expected = 0;
+  for (int i = 0; i < kNumTasks; ++i) {
+    expected += i;
+  }
+
+  for (int i = 1; i < kNumThreads; ++i) {
+    counters[0].Assimilate(counters[i]);
+  }
+  EXPECT_EQ(expected, counters[0].counter);
+}
+
+}  // namespace
+}  // namespace jpegxl
diff --git a/third_party/jpeg-xl/plugins/CMakeLists.txt b/third_party/jpeg-xl/plugins/CMakeLists.txt
new file mode 100644
index 000000000000..c7bc5db74305
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT WIN32)
+  add_subdirectory(gdk-pixbuf)
+endif()
+
+add_subdirectory(gimp)
+
+add_subdirectory(mime)
diff --git a/third_party/jpeg-xl/plugins/gdk-pixbuf/CMakeLists.txt b/third_party/jpeg-xl/plugins/gdk-pixbuf/CMakeLists.txt
new file mode 100644
index 000000000000..baabf65d9dce
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gdk-pixbuf/CMakeLists.txt
@@ -0,0 +1,33 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
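+
+# The loader is only built when the gdk-pixbuf development files are present;
+# pkg_check_modules() below probes for them, and the plugin is skipped with a
+# warning otherwise.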
+
+find_package(PkgConfig)
+pkg_check_modules(Gdk-Pixbuf IMPORTED_TARGET gdk-pixbuf-2.0>=2.38)
+
+if (NOT Gdk-Pixbuf_FOUND)
+  message(WARNING "GDK Pixbuf development libraries not found, \
+          the Gdk-Pixbuf plugin will not be built")
+  return ()
+endif ()
+
+add_library(pixbufloader-jxl SHARED pixbufloader-jxl.c)
+target_link_libraries(pixbufloader-jxl jxl-static jxl_threads-static PkgConfig::Gdk-Pixbuf)
+
+pkg_get_variable(GDK_PIXBUF_MODULEDIR gdk-pixbuf-2.0 gdk_pixbuf_moduledir)
+install(TARGETS pixbufloader-jxl LIBRARY DESTINATION "${GDK_PIXBUF_MODULEDIR}")
+
+# Instead of installing the thumbnailer file below, we could add the
+# mime type image/jxl to
+# /usr/share/thumbnailers/gdk-pixbuf-thumbnailer.thumbnailer
+install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/jxl.thumbnailer DESTINATION share/thumbnailers/)
diff --git a/third_party/jpeg-xl/plugins/gdk-pixbuf/README.md b/third_party/jpeg-xl/plugins/gdk-pixbuf/README.md
new file mode 100644
index 000000000000..ffba6d3554e1
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gdk-pixbuf/README.md
@@ -0,0 +1,36 @@
+## JPEG XL GDK Pixbuf
+
+
+If the plugin was already installed by following the [Installing section of README.md](../../README.md#installing), it should be in the correct place, e.g.
+
+```/usr/lib/x86_64-linux-gnu/gdk-pixbuf-2.0/2.10.0/loaders/libpixbufloader-jxl.so```
+
+Otherwise, copy it there manually:
+
+```bash
+sudo cp $your_build_directory/plugins/gdk-pixbuf/libpixbufloader-jxl.so /usr/lib/x86_64-linux-gnu/gdk-pixbuf-2.0/2.10.0/loaders/libpixbufloader-jxl.so
+```
+
+
+Then update the loader cache, for example with:
+
+```bash
+sudo /usr/lib/x86_64-linux-gnu/gdk-pixbuf-2.0/gdk-pixbuf-query-loaders --update-cache
+```
+
+To get thumbnails with this plugin, first add the jxl MIME type; see [../mime/README.md](../mime/README.md).
+
+Update the MIME database with
+```bash
+update-mime --local
+```
+or
+```bash
+sudo update-desktop-database
+```
+
+Then, if necessary, delete the thumbnail cache with
+```bash
+rm -r ~/.cache/thumbnails
+```
+and restart the application that displays the thumbnails, e.g. `nautilus -q`.
\ No newline at end of file
diff --git a/third_party/jpeg-xl/plugins/gdk-pixbuf/jxl.thumbnailer b/third_party/jpeg-xl/plugins/gdk-pixbuf/jxl.thumbnailer
new file mode 100644
index 000000000000..1bcaab61fcbd
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gdk-pixbuf/jxl.thumbnailer
@@ -0,0 +1,4 @@
+[Thumbnailer Entry]
+TryExec=/usr/bin/gdk-pixbuf-thumbnailer
+Exec=/usr/bin/gdk-pixbuf-thumbnailer -s %s %u %o
+MimeType=image/jxl;
diff --git a/third_party/jpeg-xl/plugins/gdk-pixbuf/pixbufloader-jxl.c b/third_party/jpeg-xl/plugins/gdk-pixbuf/pixbufloader-jxl.c
new file mode 100644
index 000000000000..d61f7fede578
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gdk-pixbuf/pixbufloader-jxl.c
@@ -0,0 +1,283 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
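+
+// GdkPixbuf loader module for JPEG XL: the incremental-load callbacks buffer
+// the whole stream in memory, decode it to 8-bit RGBA with the libjxl
+// JxlDecoder API, and hand the pixels to GdkPixbuf in one step.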
+
+#define GDK_PIXBUF_ENABLE_BACKEND
+#include <gdk-pixbuf/gdk-pixbuf.h>
+#undef GDK_PIXBUF_ENABLE_BACKEND
+
+#include "jxl/decode.h"
+
+static void DestroyPixels(guchar *pixels, gpointer data) { free(pixels); }
+
+typedef struct {
+  GdkPixbufModuleSizeFunc size_func;
+  GdkPixbufModuleUpdatedFunc update_func;
+  GdkPixbufModulePreparedFunc prepare_func;
+  gpointer user_data;
+  GdkPixbuf *pixbuf;
+  GError **error;
+
+  FILE *increment_buffer;
+  char *increment_buffer_ptr;
+  size_t increment_buffer_size;
+
+} JxlContext;
+
+uint8_t *JxlMemoryToPixels(const uint8_t *next_in, size_t size, size_t *stride,
+                           size_t *xsize, size_t *ysize, int *has_alpha) {
+  JxlDecoder *dec = JxlDecoderCreate(NULL);
+  *has_alpha = 1;
+  uint8_t *pixels = NULL;
+  if (!dec) {
+    fprintf(stderr, "JxlDecoderCreate failed\n");
+    return 0;
+  }
+  if (JXL_DEC_SUCCESS !=
+      JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE)) {
+    fprintf(stderr, "JxlDecoderSubscribeEvents failed\n");
+    JxlDecoderDestroy(dec);
+    return 0;
+  }
+
+  JxlBasicInfo info;
+  int success = 0;
+  JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
+  JxlDecoderSetInput(dec, next_in, size);
+
+  for (;;) {
+    JxlDecoderStatus status = JxlDecoderProcessInput(dec);
+
+    if (status == JXL_DEC_ERROR) {
+      fprintf(stderr, "Decoder error\n");
+      break;
+    } else if (status == JXL_DEC_NEED_MORE_INPUT) {
+      fprintf(stderr, "Error, already provided all input\n");
+      break;
+    } else if (status == JXL_DEC_BASIC_INFO) {
+      if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec, &info)) {
+        fprintf(stderr, "JxlDecoderGetBasicInfo failed\n");
+        break;
+      }
+      *xsize = info.xsize;
+      *ysize = info.ysize;
+      *stride = info.xsize * 4;
+    } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) {
+      size_t buffer_size;
+      if (JXL_DEC_SUCCESS !=
+          JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)) {
+        fprintf(stderr, "JxlDecoderImageOutBufferSize failed\n");
+        break;
+      }
+      if (buffer_size != *stride * *ysize) {
+        fprintf(stderr, "Invalid out buffer size %zu %zu\n", buffer_size,
+                *stride * *ysize);
+        break;
+      }
+      size_t pixels_buffer_size = buffer_size * sizeof(uint8_t);
+      pixels = malloc(pixels_buffer_size);
+      void *pixels_buffer = (void *)pixels;
+      if (JXL_DEC_SUCCESS != JxlDecoderSetImageOutBuffer(dec, &format,
+                                                         pixels_buffer,
+                                                         pixels_buffer_size)) {
+        fprintf(stderr, "JxlDecoderSetImageOutBuffer failed\n");
+        break;
+      }
+    } else if (status == JXL_DEC_FULL_IMAGE) {
+      // This means the decoder has decoded all pixels into the buffer.
+      success = 1;
+      break;
+    } else if (status == JXL_DEC_SUCCESS) {
+      fprintf(stderr, "Decoding finished before receiving pixel data\n");
+      break;
+    } else {
+      fprintf(stderr, "Unexpected decoder status: %d\n", status);
+      break;
+    }
+  }
+  if (success) {
+    return pixels;
+  } else {
+    free(pixels);
+    return NULL;
+  }
+}
+
+static GdkPixbuf *gdk_pixbuf__jxl_image_load(FILE *f, GError **error) {
+  size_t data_size;
+  int status;
+  gpointer data;
+
+  // Get data size
+  status = fseek(f, 0, SEEK_END);
+  if (status) {
+    g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED,
+                "Failed to find end of file");
+  }
+  data_size = ftell(f);
+  status = fseek(f, 0, SEEK_SET);
+  if (status) {
+    g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED,
+                "Failed to set pointer to beginning of file");
+  }
+
+  // Get data
+  data = g_malloc(data_size);
+  status = (fread(data, data_size, 1, f) == 1);
+  if (!status) {
+    g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED,
+                "Failed to read file");
+    g_free(data);
+    return NULL;
+  }
+  size_t xsize, ysize, stride;
+  int has_alpha;
+  uint8_t *decoded =
+      JxlMemoryToPixels(data, data_size, &stride, &xsize, &ysize, &has_alpha);
+  g_free(data);
+  if (!decoded) {
+    g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED,
+                "Failed to decode data");
+    return NULL;
+  }
+
+  GdkPixbuf *pixbuf =
+      gdk_pixbuf_new_from_data(decoded, GDK_COLORSPACE_RGB, has_alpha, 8, xsize,
+                               ysize, stride, &DestroyPixels, NULL);
+
+  if (!pixbuf) {
+    g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED,
+                "Failed to create output pixbuf");
+    free(decoded);
+    return NULL;
+  }
+
+  return pixbuf;
+}
+
+static gpointer gdk_pixbuf__jxl_image_begin_load(
+    GdkPixbufModuleSizeFunc size_func, GdkPixbufModulePreparedFunc prepare_func,
+    GdkPixbufModuleUpdatedFunc update_func, gpointer user_data,
+    GError **error) {
+  JxlContext *context = g_new(JxlContext, 1);
+  context->size_func = size_func;
+  context->prepare_func = prepare_func;
+  context->update_func = update_func;
+  context->user_data = user_data;
+  context->error = error;
+
+  context->increment_buffer = open_memstream(&context->increment_buffer_ptr,
+                                             &context->increment_buffer_size);
+
+  if (!context->increment_buffer) {
+    perror("Cannot create increment buffer.");
+    g_free(context);
+    return NULL;
+  }
+
+  return context;
+}
+
+static gboolean gdk_pixbuf__jxl_image_stop_load(gpointer user_context,
+                                                GError **error) {
+  JxlContext *context = (JxlContext *)user_context;
+
+  int status = fflush(context->increment_buffer);
+  status |= fseek(context->increment_buffer, 0L, SEEK_SET);
+
+  if (status != 0) {
+    perror("Cannot flush and rewind increment buffer.");
+    fclose(context->increment_buffer);
+    free(context->increment_buffer_ptr);
+    g_free(context);
+    return FALSE;
+  }
+
+  context->pixbuf =
+      gdk_pixbuf__jxl_image_load(context->increment_buffer, error);
+
+  gint width = gdk_pixbuf_get_width(context->pixbuf);
+  gint height = gdk_pixbuf_get_height(context->pixbuf);
+  if (context->size_func) {
+    context->size_func(&width, &height, context->user_data);
+  }
+
+  if (context->prepare_func) {
+    (*context->prepare_func)(context->pixbuf, NULL, context->user_data);
+  }
+
+  if (context->update_func) {
+    (*context->update_func)(
+        context->pixbuf, 0, 0, gdk_pixbuf_get_width(context->pixbuf),
+        gdk_pixbuf_get_height(context->pixbuf), context->user_data);
+  }
+
+  fclose(context->increment_buffer);
+
+  free(context->increment_buffer_ptr);
+
+  g_object_unref(context->pixbuf);
+  g_free(context);
+
+  return
TRUE; +} + +static gboolean gdk_pixbuf__jxl_image_load_increment(gpointer user_context, + const guchar *buf, + guint size, + GError **error) { + JxlContext *context = (JxlContext *)user_context; + + int status = fwrite(buf, size, sizeof(guchar), context->increment_buffer); + + if (status != sizeof(guchar)) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Can't write to increment buffer."); + return FALSE; + } + + status = fflush(context->increment_buffer); + + if (status != 0) { + g_set_error(error, GDK_PIXBUF_ERROR, GDK_PIXBUF_ERROR_FAILED, + "Can't flush the increment buffer."); + return FALSE; + } + + return TRUE; +} + +void fill_vtable(GdkPixbufModule *module) { + module->load = gdk_pixbuf__jxl_image_load; + module->begin_load = gdk_pixbuf__jxl_image_begin_load; + module->stop_load = gdk_pixbuf__jxl_image_stop_load; + module->load_increment = gdk_pixbuf__jxl_image_load_increment; +} + +void fill_info(GdkPixbufFormat *info) { + static GdkPixbufModulePattern signature[] = { + {"\xd7\x4c\x4d\x0a", " ", 100}, {NULL, NULL, 0}}; + + static gchar *mime_types[] = {"image/jxl", NULL}; + + static gchar *extensions[] = {"jxl", NULL}; + + info->name = "JPEG XL"; + info->signature = signature; + info->description = "JPEG XL image"; + info->mime_types = mime_types; + info->extensions = extensions; + info->flags = GDK_PIXBUF_FORMAT_THREADSAFE; + info->license = "Apache 2"; +} diff --git a/third_party/jpeg-xl/plugins/gimp/CMakeLists.txt b/third_party/jpeg-xl/plugins/gimp/CMakeLists.txt new file mode 100644 index 000000000000..e33ae783dfab --- /dev/null +++ b/third_party/jpeg-xl/plugins/gimp/CMakeLists.txt @@ -0,0 +1,38 @@ +# Copyright (c) the JPEG XL Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +find_package(PkgConfig) +pkg_check_modules(Gimp IMPORTED_TARGET gimp-2.0>=2.10 gimpui-2.0>=2.10) + +if (NOT Gimp_FOUND) + message(WARNING "Gimp development libraries not found, the Gimp plugin will not be built") + return () +endif () + +option(JPEGXL_ENABLE_GIMP_SAVING "Enable saving to JPEG XL from GIMP" OFF) + +add_executable(file-jxl WIN32 + common.h + file-jxl-load.cc + file-jxl-load.h + file-jxl-save.cc + file-jxl-save.h + file-jxl.cc) +if (JPEGXL_ENABLE_GIMP_SAVING) + target_compile_definitions(file-jxl PRIVATE -DJPEGXL_ENABLE_GIMP_SAVING=1) +endif () +target_link_libraries(file-jxl jxl-static jxl_threads-static PkgConfig::Gimp) + +pkg_get_variable(GIMP_LIB_DIR gimp-2.0 gimplibdir) +install(TARGETS file-jxl RUNTIME DESTINATION "${GIMP_LIB_DIR}/plug-ins/file-jxl/") diff --git a/third_party/jpeg-xl/plugins/gimp/common.h b/third_party/jpeg-xl/plugins/gimp/common.h new file mode 100644 index 000000000000..acd8503181fa --- /dev/null +++ b/third_party/jpeg-xl/plugins/gimp/common.h @@ -0,0 +1,65 @@ +// Copyright (c) the JPEG XL Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PLUGINS_GIMP_COMMON_H_
+#define PLUGINS_GIMP_COMMON_H_
+
+#include <libgimp/gimp.h>
+
+namespace jxl {
+
+// FromFloat expects a value between 0 and 1, and ToFloat returns such values
+// from GIMP values.
+template <GimpPrecision>
+struct BufferFormat;
+template <>
+struct BufferFormat<GIMP_PRECISION_U8_GAMMA> {
+  using Sample = uint8_t;
+  static Sample FromFloat(const float x) {
+    return static_cast<Sample>(std::round(x * 255.f));
+  }
+  static float ToFloat(const Sample s) { return s * (1.f / 255.f); }
+};
+template <>
+struct BufferFormat<GIMP_PRECISION_U16_GAMMA> {
+  using Sample = uint16_t;
+  static Sample FromFloat(const float x) {
+    return static_cast<Sample>(std::round(x * 65535.f));
+  }
+  static float ToFloat(const Sample s) { return s * (1.f / 65535.f); }
+};
+template <>
+struct BufferFormat<GIMP_PRECISION_U32_GAMMA> {
+  using Sample = uint32_t;
+  static Sample FromFloat(const float x) {
+    return static_cast<Sample>(std::round(x * 4294967295.f));
+  }
+  static float ToFloat(const Sample s) { return s * (1.f / 4294967295.f); }
+};
+template <>
+struct BufferFormat<GIMP_PRECISION_HALF_GAMMA> {
+  using Sample = float;
+  static Sample FromFloat(const float x) { return x; }
+  static float ToFloat(const Sample s) { return s; }
+};
+template <>
+struct BufferFormat<GIMP_PRECISION_FLOAT_GAMMA> {
+  using Sample = float;
+  static Sample FromFloat(const float x) { return x; }
+  static float ToFloat(const Sample s) { return s; }
+};
+
+}  // namespace jxl
+
+#endif  // PLUGINS_GIMP_COMMON_H_
diff --git a/third_party/jpeg-xl/plugins/gimp/file-jxl-load.cc b/third_party/jpeg-xl/plugins/gimp/file-jxl-load.cc
new file mode 100644
index 000000000000..183ccdf5216b
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gimp/file-jxl-load.cc
@@ -0,0 +1,182 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "plugins/gimp/file-jxl-load.h"
+
+// Defined by both FUIF and glib.
+#undef MAX
+#undef MIN
+#undef CLAMP
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/base/file_io.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/dec_file.h"
+#include "plugins/gimp/common.h"
+
+namespace jxl {
+
+namespace {
+
+template <GimpPrecision precision, bool has_alpha, size_t num_channels>
+void FillBuffer(
+    const CodecInOut& io,
+    std::vector<typename BufferFormat<precision>::Sample>* const pixel_data) {
+  pixel_data->reserve(io.xsize() * io.ysize() * (num_channels + has_alpha));
+  for (size_t y = 0; y < io.ysize(); ++y) {
+    const float* rows[num_channels];
+    for (size_t c = 0; c < num_channels; ++c) {
+      rows[c] = io.Main().color().ConstPlaneRow(c, y);
+    }
+    const float* const alpha_row =
+        has_alpha ? io.Main().alpha().ConstRow(y) : nullptr;
+    for (size_t x = 0; x < io.xsize(); ++x) {
+      const float alpha = has_alpha ? alpha_row[x] : 1.f;
+      const float alpha_multiplier =
+          has_alpha && io.Main().AlphaIsPremultiplied()
+              ? 1.f / std::max(kSmallAlpha, alpha)
+              : 1.f;
+      for (const float* const row : rows) {
+        pixel_data->push_back(BufferFormat<precision>::FromFloat(
+            std::max(0.f, std::min(1.f, alpha_multiplier * row[x]))));
+      }
+      if (has_alpha) {
+        pixel_data->push_back(
+            BufferFormat<precision>::FromFloat(255.f * alpha));
+      }
+    }
+  }
+}
+
+template <GimpPrecision precision>
+Status FillGimpLayer(const gint32 layer, const CodecInOut& io,
+                     GimpImageType layer_type) {
+  std::vector<typename BufferFormat<precision>::Sample> pixel_data;
+  switch (layer_type) {
+    case GIMP_GRAY_IMAGE:
+      FillBuffer<precision, /*has_alpha=*/false, /*num_channels=*/1>(
+          io, &pixel_data);
+      break;
+    case GIMP_GRAYA_IMAGE:
+      FillBuffer<precision, /*has_alpha=*/true, /*num_channels=*/1>(
+          io, &pixel_data);
+      break;
+    case GIMP_RGB_IMAGE:
+      FillBuffer<precision, /*has_alpha=*/false, /*num_channels=*/3>(
+          io, &pixel_data);
+      break;
+    case GIMP_RGBA_IMAGE:
+      FillBuffer<precision, /*has_alpha=*/true, /*num_channels=*/3>(
+          io, &pixel_data);
+      break;
+    default:
+      return false;
+  }
+
+  GeglBuffer* buffer = gimp_drawable_get_buffer(layer);
+  gegl_buffer_set(buffer, GEGL_RECTANGLE(0, 0, io.xsize(), io.ysize()), 0,
+                  nullptr, pixel_data.data(), GEGL_AUTO_ROWSTRIDE);
+  g_clear_object(&buffer);
+  return true;
+}
+
+}  // namespace
+
+Status LoadJpegXlImage(const gchar* const filename, gint32* const image_id) {
+  PaddedBytes compressed;
+  JXL_RETURN_IF_ERROR(ReadFile(filename, &compressed));
+
+  // TODO(deymo): Use C API instead of the ThreadPoolInternal.
+  ThreadPoolInternal pool;
+  DecompressParams dparams;
+  CodecInOut io;
+  JXL_RETURN_IF_ERROR(DecodeFile(dparams, compressed, &io, &pool));
+
+  const ColorEncoding& color_encoding = io.metadata.m.color_encoding;
+  JXL_RETURN_IF_ERROR(io.TransformTo(color_encoding, &pool));
+
+  GimpColorProfile* profile = nullptr;
+  if (color_encoding.IsSRGB()) {
+    profile = gimp_color_profile_new_rgb_srgb();
+  } else if (color_encoding.IsLinearSRGB()) {
+    profile = gimp_color_profile_new_rgb_srgb_linear();
+  } else {
+    profile = gimp_color_profile_new_from_icc_profile(
+        color_encoding.ICC().data(), color_encoding.ICC().size(),
+        /*error=*/nullptr);
+  }
+  if (profile == nullptr) {
+    return JXL_FAILURE(
+        "Failed to create GIMP color profile from %zu bytes of ICC data",
+        color_encoding.ICC().size());
+  }
+
+  GimpImageBaseType image_type;
+  GimpImageType layer_type;
+
+  if (io.Main().IsGray()) {
+    image_type = GIMP_GRAY;
+    if (io.Main().HasAlpha()) {
+      layer_type = GIMP_GRAYA_IMAGE;
+    } else {
+      layer_type = GIMP_GRAY_IMAGE;
+    }
+  } else {
+    image_type = GIMP_RGB;
+    if (io.Main().HasAlpha()) {
+      layer_type = GIMP_RGBA_IMAGE;
+    } else {
+      layer_type = GIMP_RGB_IMAGE;
+    }
+  }
+
+  GimpPrecision precision;
+  Status (*fill_layer)(gint32 layer, const CodecInOut& io, GimpImageType);
+  if (io.metadata.m.bit_depth.floating_point_sample) {
+    if (io.metadata.m.bit_depth.bits_per_sample <= 16) {
+      precision = GIMP_PRECISION_HALF_GAMMA;
+      fill_layer = &FillGimpLayer<GIMP_PRECISION_HALF_GAMMA>;
+    } else {
+      precision = GIMP_PRECISION_FLOAT_GAMMA;
+      fill_layer = &FillGimpLayer<GIMP_PRECISION_FLOAT_GAMMA>;
+    }
+  } else {
+    if (io.metadata.m.bit_depth.bits_per_sample <= 8) {
+      precision = GIMP_PRECISION_U8_GAMMA;
+      fill_layer = &FillGimpLayer<GIMP_PRECISION_U8_GAMMA>;
+    } else if (io.metadata.m.bit_depth.bits_per_sample <= 16) {
+      precision = GIMP_PRECISION_U16_GAMMA;
+      fill_layer = &FillGimpLayer<GIMP_PRECISION_U16_GAMMA>;
+    } else {
+      precision = GIMP_PRECISION_U32_GAMMA;
+      fill_layer = &FillGimpLayer<GIMP_PRECISION_U32_GAMMA>;
+    }
+  }
+
+  *image_id = gimp_image_new_with_precision(io.xsize(), io.ysize(), image_type,
+                                            precision);
+  gimp_image_set_color_profile(*image_id, profile);
+  g_clear_object(&profile);
+  const gint32 layer = gimp_layer_new(
+      *image_id, "image", io.xsize(), io.ysize(), layer_type, /*opacity=*/100,
+      gimp_image_get_default_new_layer_mode(*image_id));
+  gimp_image_set_filename(*image_id, filename);
+  gimp_image_insert_layer(*image_id, layer, /*parent_id=*/-1, /*position=*/0);
+
+  JXL_RETURN_IF_ERROR(fill_layer(layer, io, layer_type));
+
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/plugins/gimp/file-jxl-load.h b/third_party/jpeg-xl/plugins/gimp/file-jxl-load.h
new file mode 100644
index 000000000000..f8bd68a13e2d
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gimp/file-jxl-load.h
@@ -0,0 +1,28 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PLUGINS_GIMP_FILE_JXL_LOAD_H_
+#define PLUGINS_GIMP_FILE_JXL_LOAD_H_
+
+#include <libgimp/gimp.h>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+Status LoadJpegXlImage(const gchar* filename, gint32* image_id);
+
+}
+
+#endif  // PLUGINS_GIMP_FILE_JXL_LOAD_H_
diff --git a/third_party/jpeg-xl/plugins/gimp/file-jxl-save.cc b/third_party/jpeg-xl/plugins/gimp/file-jxl-save.cc
new file mode 100644
index 000000000000..82a61d5349b5
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gimp/file-jxl-save.cc
@@ -0,0 +1,141 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "plugins/gimp/file-jxl-save.h"
+
+// Defined by both FUIF and glib.
+#undef MAX
+#undef MIN
+#undef CLAMP
+
+#include "lib/jxl/alpha.h"
+#include "lib/jxl/base/file_io.h"
+#include "lib/jxl/base/thread_pool_internal.h"
+#include "lib/jxl/enc_file.h"
+#include "plugins/gimp/common.h"
+
+namespace jxl {
+
+namespace {
+
+template <bool has_alpha, size_t alpha_bits>
+Status ReadBuffer(const size_t xsize, const size_t ysize,
+                  const std::vector<float>& pixel_data, PaddedBytes icc,
+                  CodecInOut* const io) {
+  constexpr float alpha_multiplier = has_alpha ? 1.f / 255 : 0.f;
+  Image3F image(xsize, ysize);
+  ImageF alpha;
+  if (has_alpha) {
+    alpha = ImageF(xsize, ysize);
+  }
+  const float* current_sample = pixel_data.data();
+  for (size_t y = 0; y < ysize; ++y) {
+    float* rows[3];
+    for (size_t c = 0; c < 3; ++c) {
+      rows[c] = image.PlaneRow(c, y);
+    }
+    float* const alpha_row = has_alpha ? alpha.Row(y) : nullptr;
+    for (size_t x = 0; x < xsize; ++x) {
+      for (float* const row : rows) {
+        row[x] = BufferFormat<GIMP_PRECISION_FLOAT_GAMMA>::ToFloat(
+            *current_sample++);
+      }
+      if (has_alpha) {
+        alpha_row[x] = alpha_multiplier *
+                       BufferFormat<GIMP_PRECISION_FLOAT_GAMMA>::ToFloat(
+                           *current_sample++);
+      }
+    }
+  }
+
+  ColorEncoding color_encoding;
+  JXL_RETURN_IF_ERROR(color_encoding.SetICC(std::move(icc)));
+  io->metadata.m.color_encoding = color_encoding;
+  io->SetFromImage(std::move(image), color_encoding);
+  if (has_alpha) {
+    io->metadata.m.SetAlphaBits(alpha_bits);
+    io->Main().SetAlpha(std::move(alpha), /*alpha_is_premultiplied=*/false);
+  }
+  return true;
+}
+
+}  // namespace
+
+Status SaveJpegXlImage(const gint32 image_id, const gint32 drawable_id,
+                       const gint32 orig_image_id,
+                       const gchar* const filename) {
+  GimpColorProfile* profile =
+      gimp_image_get_effective_color_profile(orig_image_id);
+  gsize icc_size;
+  const guint8* const icc_bytes =
+      gimp_color_profile_get_icc_profile(profile, &icc_size);
+  PaddedBytes icc;
+  icc.assign(icc_bytes, icc_bytes + icc_size);
+  g_clear_object(&profile);
+
+  const Babl* format;
+  Status (*read_buffer)(size_t, size_t, const std::vector<float>&, PaddedBytes,
+                        CodecInOut*);
+  const bool has_alpha = gimp_drawable_has_alpha(drawable_id);
+  if (has_alpha) {
+    format = babl_format("R'G'B'A float");
+    read_buffer = &ReadBuffer</*has_alpha=*/true, /*alpha_bits=*/8>;
+  } else {
+    format = babl_format("R'G'B' float");
+    read_buffer = &ReadBuffer</*has_alpha=*/false, /*alpha_bits=*/8>;
+  }
+
+  CodecInOut io;
+
+  GeglBuffer* gegl_buffer = gimp_drawable_get_buffer(drawable_id);
+
+  // TODO(lode): is there a way to query whether the data type is float or int
+  // from gegl_buffer_get_format instead?
+  GimpPrecision precision = gimp_image_get_precision(image_id);
+  if (precision == GIMP_PRECISION_HALF_GAMMA) {
+    io.metadata.m.bit_depth.bits_per_sample = 16;
+    io.metadata.m.bit_depth.exponent_bits_per_sample = 5;
+  } else if (precision == GIMP_PRECISION_FLOAT_GAMMA) {
+    io.metadata.m.SetFloat32Samples();
+  } else {  // unsigned integer
+    // TODO(lode): handle GIMP_PRECISION_DOUBLE_GAMMA. 64-bit per channel is
+    // not supported by io.metadata.m.
+    const Babl* native_format = gegl_buffer_get_format(gegl_buffer);
+    uint32_t bits_per_sample = 8 *
+                               babl_format_get_bytes_per_pixel(native_format) /
+                               babl_format_get_n_components(native_format);
+    io.metadata.m.SetUintSamples(bits_per_sample);
+  }
+
+  const GeglRectangle rect = *gegl_buffer_get_extent(gegl_buffer);
+  std::vector<float> pixel_data(rect.width * rect.height * (3 + has_alpha));
+  gegl_buffer_get(gegl_buffer, &rect, 1., format, pixel_data.data(),
+                  GEGL_AUTO_ROWSTRIDE, GEGL_ABYSS_NONE);
+  g_clear_object(&gegl_buffer);
+
+  JXL_RETURN_IF_ERROR(
+      read_buffer(rect.width, rect.height, pixel_data, std::move(icc), &io));
+  CompressParams params;
+  PassesEncoderState encoder_state;
+  PaddedBytes compressed;
+  ThreadPoolInternal pool;
+  params.butteraugli_distance = 1.f;
+  JXL_RETURN_IF_ERROR(EncodeFile(params, &io, &encoder_state, &compressed,
+                                 /*aux_out=*/nullptr, &pool));
+  JXL_RETURN_IF_ERROR(WriteFile(compressed, filename));
+
+  return true;
+}
+
+}  // namespace jxl
diff --git a/third_party/jpeg-xl/plugins/gimp/file-jxl-save.h b/third_party/jpeg-xl/plugins/gimp/file-jxl-save.h
new file mode 100644
index 000000000000..0da21ccf09dc
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gimp/file-jxl-save.h
@@ -0,0 +1,29 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PLUGINS_GIMP_FILE_JXL_SAVE_H_
+#define PLUGINS_GIMP_FILE_JXL_SAVE_H_
+
+#include <libgimp/gimp.h>
+
+#include "lib/jxl/base/status.h"
+
+namespace jxl {
+
+Status SaveJpegXlImage(gint32 image_id, gint32 drawable_id,
+                       gint32 orig_image_id, const gchar* filename);
+
+}
+
+#endif  // PLUGINS_GIMP_FILE_JXL_SAVE_H_
diff --git a/third_party/jpeg-xl/plugins/gimp/file-jxl.cc b/third_party/jpeg-xl/plugins/gimp/file-jxl.cc
new file mode 100644
index 000000000000..51a9594ca6fc
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/gimp/file-jxl.cc
@@ -0,0 +1,179 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gegl.h>
+#include <glib/gi18n.h>
+#include <libgimp/gimp.h>
+#include <string.h>
+
+#include <string>
+
+#include "plugins/gimp/file-jxl-load.h"
+#include "plugins/gimp/file-jxl-save.h"
+
+namespace jxl {
+namespace {
+
+constexpr char kLoadProc[] = "file-jxl-load";
+constexpr char kSaveProc[] = "file-jxl-save";
+
+void Query() {
+  {
+    static char run_mode_name[] = "run-mode";
+    static char run_mode_description[] = "Run mode";
+    static char filename_name[] = "filename";
+    static char filename_description[] = "The name of the file to load";
+    static char raw_filename_name[] = "raw-filename";
+    static char raw_filename_description[] =
+        "The name of the file, as entered by the user";
+    static const GimpParamDef load_args[] = {
+        {GIMP_PDB_INT32, run_mode_name, run_mode_description},
+        {GIMP_PDB_STRING, filename_name, filename_description},
+        {GIMP_PDB_STRING, raw_filename_name, raw_filename_description},
+    };
+    static char image_name[] = "image";
+    static char image_description[] = "Loaded image";
+    static const GimpParamDef load_return_vals[] = {
+        {GIMP_PDB_IMAGE, image_name, image_description},
+    };
+
+    gimp_install_procedure(
+        /*name=*/kLoadProc, /*blurb=*/"Loads JPEG XL image files",
+        /*help=*/"Loads JPEG XL image files", /*author=*/"JPEG XL Project",
+        /*copyright=*/"JPEG XL Project", /*date=*/"2019",
+        /*menu_label=*/"JPEG XL image", /*image_types=*/nullptr,
+        /*type=*/GIMP_PLUGIN, /*n_params=*/G_N_ELEMENTS(load_args),
+        /*n_return_vals=*/G_N_ELEMENTS(load_return_vals), /*params=*/load_args,
+        /*return_vals=*/load_return_vals);
+    gimp_register_file_handler_mime(kLoadProc, "image/jxl");
+    gimp_register_magic_load_handler(kLoadProc, "jxl", "",
+                                     "0,string,\x0A\x04\x42\xD2\xD5\x4E\x12,0,"
+                                     "string,\xFF\x0A,0,string,\xFF\x58");
+  }
+
+  {
+    static char run_mode_name[] = "run-mode";
+    static char run_mode_description[] = "Run mode";
+    static char image_name[] = "image";
+    static char image_description[] = "Input image";
+    static char drawable_name[] = "drawable";
+    static char
drawable_description[] = "Drawable to save"; + static char filename_name[] = "filename"; + static char filename_description[] = "The name of the file to save"; + static char raw_filename_name[] = "raw-filename"; + static char raw_filename_description[] = "The name of the file to save"; + static const GimpParamDef save_args[] = { + {GIMP_PDB_INT32, run_mode_name, run_mode_description}, + {GIMP_PDB_IMAGE, image_name, image_description}, + {GIMP_PDB_DRAWABLE, drawable_name, drawable_description}, + {GIMP_PDB_STRING, filename_name, filename_description}, + {GIMP_PDB_STRING, raw_filename_name, raw_filename_description}, + }; + + gimp_install_procedure( + /*name=*/kSaveProc, /*blurb=*/"Saves JPEG XL image files", + /*help=*/"Saves JPEG XL image files", /*author=*/"JPEG XL Project", + /*copyright=*/"JPEG XL Project", /*date=*/"2019", + /*menu_label=*/"JPEG XL image", /*image_types=*/"RGB*, GRAY*", + /*type=*/GIMP_PLUGIN, /*n_params=*/G_N_ELEMENTS(save_args), + /*n_return_vals=*/0, /*params=*/save_args, + /*return_vals=*/nullptr); + gimp_register_file_handler_mime(kSaveProc, "image/jxl"); + gimp_register_save_handler(kSaveProc, "jxl", ""); + } +} + +void Run(const gchar* const name, const gint nparams, + const GimpParam* const params, gint* const nreturn_vals, + GimpParam** const return_vals) { + gegl_init(nullptr, nullptr); + + static GimpParam values[2]; + + *nreturn_vals = 1; + *return_vals = values; + + values[0].type = GIMP_PDB_STATUS; + values[0].data.d_status = GIMP_PDB_EXECUTION_ERROR; + + if (strcmp(name, kLoadProc) == 0) { + if (nparams != 3) { + values[0].data.d_status = GIMP_PDB_CALLING_ERROR; + return; + } + + const gchar* const filename = params[1].data.d_string; + gint32 image_id; + if (!LoadJpegXlImage(filename, &image_id)) { + values[0].data.d_status = GIMP_PDB_EXECUTION_ERROR; + return; + } + + *nreturn_vals = 2; + values[0].data.d_status = GIMP_PDB_SUCCESS; + values[1].type = GIMP_PDB_IMAGE; + values[1].data.d_image = image_id; + } else if (strcmp(name, kSaveProc) == 0) { +#if !JPEGXL_ENABLE_GIMP_SAVING + *nreturn_vals = 2; + values[0].data.d_status = GIMP_PDB_EXECUTION_ERROR; + values[1].type = GIMP_PDB_STRING; + static gchar compatibility_message[] = + "Saving is disabled in this build of the JPEG XL plugin. 
Rebuild it "
+        "with -DJPEGXL_ENABLE_GIMP_SAVING=1 to enable the functionality, but "
+        "be aware that files created in this fashion may not work in future "
+        "versions of the decoder.";
+    values[1].data.d_string = compatibility_message;
+    return;
+#endif
+    if (nparams != 5) {
+      values[0].data.d_status = GIMP_PDB_CALLING_ERROR;
+      return;
+    }
+
+    gint32 image_id = params[1].data.d_image;
+    gint32 drawable_id = params[2].data.d_drawable;
+    const gchar* const filename = params[3].data.d_string;
+    const gint32 orig_image_id = image_id;
+    const GimpExportReturn export_result = gimp_export_image(
+        &image_id, &drawable_id, "JPEG XL",
+        static_cast<GimpExportCapabilities>(GIMP_EXPORT_CAN_HANDLE_RGB |
+                                            GIMP_EXPORT_CAN_HANDLE_ALPHA));
+    switch (export_result) {
+      case GIMP_EXPORT_CANCEL:
+        values[0].data.d_status = GIMP_PDB_CANCEL;
+        return;
+      case GIMP_EXPORT_IGNORE:
+        break;
+      case GIMP_EXPORT_EXPORT:
+        break;
+    }
+    gimp_progress_init_printf(_("Saving JPEG XL file \"%s\""), filename);
+    if (!SaveJpegXlImage(image_id, drawable_id, orig_image_id, filename)) {
+      return;
+    }
+    if (image_id != orig_image_id) {
+      gimp_image_delete(image_id);
+    }
+    values[0].data.d_status = GIMP_PDB_SUCCESS;
+  }
+}
+
+}  // namespace
+}  // namespace jxl
+
+static const GimpPlugInInfo PLUG_IN_INFO = {nullptr, nullptr, &jxl::Query,
+                                            &jxl::Run};
+
+MAIN()
diff --git a/third_party/jpeg-xl/plugins/mime/CMakeLists.txt b/third_party/jpeg-xl/plugins/mime/CMakeLists.txt
new file mode 100644
index 000000000000..831f3d692d3a
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/mime/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+install(FILES image-jxl.xml DESTINATION share/mime/packages/)
diff --git a/third_party/jpeg-xl/plugins/mime/README.md b/third_party/jpeg-xl/plugins/mime/README.md
new file mode 100644
index 000000000000..4b5373ce501e
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/mime/README.md
@@ -0,0 +1,20 @@
+## JPEG XL MIME type
+
+If the MIME type was not already installed by the [Installing section of README.md](../../README.md#installing), it can be installed manually:
+
+### Install
+```bash
+sudo xdg-mime install --novendor image-jxl.xml
+```
+
+Then run:
+```
+update-mime --local
+```
+
+
+### Uninstall
+```bash
+sudo xdg-mime uninstall image-jxl.xml
+```
+
diff --git a/third_party/jpeg-xl/plugins/mime/image-jxl.xml b/third_party/jpeg-xl/plugins/mime/image-jxl.xml
new file mode 100644
index 000000000000..cab9018c7da4
--- /dev/null
+++ b/third_party/jpeg-xl/plugins/mime/image-jxl.xml
@@ -0,0 +1,13 @@
+<?xml version="1.0" encoding="utf-8"?>
+<mime-info xmlns="http://www.freedesktop.org/standards/shared-mime-info">
+  <mime-type type="image/jxl">
+    <comment>JPEG XL image</comment>
+    <comment xml:lang="fr">image JPEG XL</comment>
+    <comment xml:lang="nl">JPEG XL afbeelding</comment>
+    <magic priority="50">
+      <match value="\xff\x0a" type="string" offset="0"/>
+      <match value="\x00\x00\x00\x0cJXL \x0d\x0a\x87\x0a" type="string" offset="0"/>
+    </magic>
+    <glob pattern="*.jxl"/>
+  </mime-type>
+</mime-info>
diff --git a/third_party/jpeg-xl/third_party/CMakeLists.txt b/third_party/jpeg-xl/third_party/CMakeLists.txt
new file mode 100644
index 000000000000..ec789ff084a7
--- /dev/null
+++ b/third_party/jpeg-xl/third_party/CMakeLists.txt
@@ -0,0 +1,223 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Enable tests in third_party/ as well.
+enable_testing()
+include(CTest)
+
+if(BUILD_TESTING)
+# Add GTest from source and alias it to what the find_package(GTest) workflow
+# defines. Omitting googletest/ directory would require it to be available in
+# the base system instead, but it would work just fine. This makes packages
+# using GTest and calling find_package(GTest) actually work.
+if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/googletest/CMakeLists.txt" AND
+    NOT JPEGXL_FORCE_SYSTEM_GTEST)
+  add_subdirectory(googletest EXCLUDE_FROM_ALL)
+
+  set(GTEST_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/googletest/googletest")
+  set(GTEST_INCLUDE_DIR "$<TARGET_PROPERTY:gtest,INTERFACE_INCLUDE_DIRECTORIES>"
+      CACHE STRING "")
+  set(GMOCK_INCLUDE_DIR "$<TARGET_PROPERTY:gmock,INTERFACE_INCLUDE_DIRECTORIES>")
+  set(GTEST_LIBRARY "$<TARGET_FILE:gtest>")
+  set(GTEST_MAIN_LIBRARY "$<TARGET_FILE:gtest_main>")
+  add_library(GTest::GTest ALIAS gtest)
+  add_library(GTest::Main ALIAS gtest_main)
+
+  set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
+  set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
+  set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
+  set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
+
+  # googletest doesn't compile clean with clang-cl (-Wundef)
+  if (WIN32 AND ${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
+    set_target_properties(gtest PROPERTIES COMPILE_FLAGS "-Wno-error")
+    set_target_properties(gmock PROPERTIES COMPILE_FLAGS "-Wno-error")
+    set_target_properties(gtest_main PROPERTIES COMPILE_FLAGS "-Wno-error")
+    set_target_properties(gmock_main PROPERTIES COMPILE_FLAGS "-Wno-error")
+  endif ()
+  configure_file("${CMAKE_CURRENT_SOURCE_DIR}/googletest/LICENSE"
+                 ${PROJECT_BINARY_DIR}/LICENSE.googletest COPYONLY)
+else()
+  if(JPEGXL_DEP_LICENSE_DIR)
+    configure_file("${JPEGXL_DEP_LICENSE_DIR}/googletest/copyright"
+                   ${PROJECT_BINARY_DIR}/LICENSE.googletest COPYONLY)
+  endif()  # JPEGXL_DEP_LICENSE_DIR
+endif()
+find_package(GTest)
+if (NOT GTEST_FOUND)
+  set(BUILD_TESTING OFF CACHE BOOL "Build tests" FORCE)
+  message(SEND_ERROR "GTest not found. Install googletest package "
+          "(libgtest-dev) in the system or download googletest to "
+          "third_party/googletest from https://github.com/google/googletest ."
+          "To disable tests instead re-run cmake with -DBUILD_TESTING=OFF.")
+endif()  # NOT GTEST_FOUND
+
+# Look for gmock in the system too.
+if (NOT DEFINED GMOCK_INCLUDE_DIR)
+  find_path(
+    GMOCK_INCLUDE_DIR "gmock/gmock.h"
+    HINTS ${GTEST_INCLUDE_DIRS})
+  if ("${GMOCK_INCLUDE_DIR}" STREQUAL "GMOCK_INCLUDE_DIR-NOTFOUND")
+    set(BUILD_TESTING OFF CACHE BOOL "Build tests" FORCE)
+    message(SEND_ERROR "GMock not found. Install googletest package "
+            "(libgmock-dev) in the system or download googletest to "
+            "third_party/googletest from https://github.com/google/googletest ."
+ "To disable tests instead re-run cmake with -DBUILD_TESTING=OFF.") + else() + message(STATUS "Found GMock: ${GMOCK_INCLUDE_DIR}") + endif() # GMOCK_INCLUDE_DIR-NOTFOUND +endif() # NOT DEFINED GMOCK_INCLUDE_DIR +endif() # BUILD_TESTING + +# Highway +set(HWY_SYSTEM_GTEST ON CACHE INTERNAL "") +if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/highway/CMakeLists.txt" AND + NOT JPEGXL_FORCE_SYSTEM_HWY) + add_subdirectory(highway) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/highway/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.highway COPYONLY) +else() + pkg_check_modules(HWY libhwy) + if (NOT HWY_FOUND) + message(FATAL_ERROR + "Highway library (hwy) not found. Install libhwy-dev or download it " + "to third_party/highway from https://github.com/google/highway . " + "Highway is required to build JPEG XL. You can run " + "${PROJECT_SOURCE_DIR}/deps.sh to download this dependency.") + endif() + add_library(hwy INTERFACE IMPORTED GLOBAL) + if(${CMAKE_VERSION} VERSION_LESS "3.13.5") + set_property(TARGET hwy PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${HWY_INCLUDE_DIR}) + target_link_libraries(hwy INTERFACE ${HWY_LDFLAGS}) + set_property(TARGET hwy PROPERTY INTERFACE_COMPILE_OPTIONS ${HWY_CFLAGS_OTHER}) + else() + target_include_directories(hwy INTERFACE ${HWY_INCLUDE_DIRS}) + target_link_libraries(hwy INTERFACE ${HWY_LINK_LIBRARIES}) + target_link_options(hwy INTERFACE ${HWY_LDFLAGS_OTHER}) + target_compile_options(hwy INTERFACE ${HWY_CFLAGS_OTHER}) + endif() + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libhwy-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.highway COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +endif() + +# lodepng +if( NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/lodepng/lodepng.h" ) + message(FATAL_ERROR "Please run ${PROJECT_SOURCE_DIR}/deps.sh to fetch the " + "build dependencies.") +endif() +include(lodepng.cmake) +configure_file("${CMAKE_CURRENT_SOURCE_DIR}/lodepng/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.lodepng COPYONLY) + +# brotli +if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/brotli/c/include/brotli/decode.h" OR + JPEGXL_FORCE_SYSTEM_BROTLI) + # Create the libbrotli* and libbrotli*-static targets. + foreach(brlib IN ITEMS brotlienc brotlidec brotlicommon) + # Use uppercase like "BROTLIENC" for the cmake variables + string(TOUPPER "${brlib}" BRPREFIX) + pkg_check_modules(${BRPREFIX} lib${brlib}) + if (${BRPREFIX}_FOUND) + if(${CMAKE_VERSION} VERSION_LESS "3.13.5") + add_library(${brlib} INTERFACE IMPORTED GLOBAL) + set_property(TARGET ${brlib} PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${${BRPREFIX}_INCLUDE_DIR}) + target_link_libraries(${brlib} INTERFACE ${${BRPREFIX}_LDFLAGS}) + set_property(TARGET ${brlib} PROPERTY INTERFACE_COMPILE_OPTIONS ${${BRPREFIX}_CFLAGS_OTHER}) + + add_library(${brlib}-static INTERFACE IMPORTED GLOBAL) + set_property(TARGET ${brlib}-static PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${${BRPREFIX}_INCLUDE_DIR}) + target_link_libraries(${brlib}-static INTERFACE ${${BRPREFIX}_LDFLAGS}) + set_property(TARGET ${brlib}-static PROPERTY INTERFACE_COMPILE_OPTIONS ${${BRPREFIX}_CFLAGS_OTHER}) + else() + add_library(${brlib} INTERFACE IMPORTED GLOBAL) + target_include_directories(${brlib} + INTERFACE ${${BRPREFIX}_INCLUDE_DIRS}) + target_link_libraries(${brlib} + INTERFACE ${${BRPREFIX}_LINK_LIBRARIES}) + target_link_options(${brlib} + INTERFACE ${${BRPREFIX}_LDFLAGS_OTHER}) + target_compile_options(${brlib} + INTERFACE ${${BRPREFIX}_CFLAGS_OTHER}) + + # TODO(deymo): Remove the -static library versions, this target is + # currently needed by brunsli.cmake. 
When importing it this way, the + # brotli*-static target is just an alias. + add_library(${brlib}-static ALIAS ${brlib}) + endif() + endif() + unset(BRPREFIX) + endforeach() + + if (BROTLIENC_FOUND AND BROTLIDEC_FOUND AND BROTLICOMMON_FOUND) + set(BROTLI_FOUND 1) + else() + set(BROTLI_FOUND 0) + endif() + + if (NOT BROTLI_FOUND) + message(FATAL_ERROR + "Brotli not found, install brotli-dev or download brotli source code to" + " third_party/brotli from https://github.com/google/brotli. You can use" + " ${PROJECT_SOURCE_DIR}/deps.sh to download this dependency.") + endif () + if(JPEGXL_DEP_LICENSE_DIR) + configure_file("${JPEGXL_DEP_LICENSE_DIR}/libbrotli-dev/copyright" + ${PROJECT_BINARY_DIR}/LICENSE.brotli COPYONLY) + endif() # JPEGXL_DEP_LICENSE_DIR +else() + # Compile brotli from sources. + set(BROTLI_DISABLE_TESTS ON CACHE STRING "Disable Brotli tests") + add_subdirectory(brotli EXCLUDE_FROM_ALL) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/brotli/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.brotli COPYONLY) + if(BROTLI_EMSCRIPTEN) + # Brotli only defines the -static targets when using emscripten. + foreach(brlib IN ITEMS brotlienc brotlidec brotlicommon) + add_library(${brlib} ALIAS ${brlib}-static) + endforeach() + endif() # BROTLI_EMSCRIPTEN +endif() + +# *cms +if (JPEGXL_ENABLE_SKCMS) + if( NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/skcms/skcms.h" ) + message(FATAL_ERROR "Please run ${PROJECT_SOURCE_DIR}/deps.sh to fetch the " + "build dependencies.") + endif() + include(skcms.cmake) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/skcms/LICENSE" + ${PROJECT_BINARY_DIR}/LICENSE.skcms COPYONLY) +endif () +if (JPEGXL_ENABLE_VIEWERS OR NOT JPEGXL_ENABLE_SKCMS) + if( NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/lcms/.git" ) + message(SEND_ERROR "Please run git submodule update --init") + endif() + include(lcms2.cmake) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/lcms/COPYING" + ${PROJECT_BINARY_DIR}/LICENSE.lcms COPYONLY) +endif() + +# sjpeg +if (JPEGXL_ENABLE_SJPEG) + if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/sjpeg/CMakeLists.txt") + message(FATAL_ERROR "Please run ${PROJECT_SOURCE_DIR}/deps.sh to fetch the " + "build dependencies.") + endif() + include(sjpeg.cmake) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/sjpeg/COPYING" + ${PROJECT_BINARY_DIR}/LICENSE.sjpeg COPYONLY) +endif () + diff --git a/third_party/jpeg-xl/third_party/HEVCSoftware/README.md b/third_party/jpeg-xl/third_party/HEVCSoftware/README.md new file mode 100644 index 000000000000..70ebaeba33f6 --- /dev/null +++ b/third_party/jpeg-xl/third_party/HEVCSoftware/README.md @@ -0,0 +1,2 @@ +This directory contains modified configuration files from the reference HEVC +encoder, the source code of which can be found at: https://hevc.hhi.fraunhofer.de/svn/svn_HEVCSoftware/ diff --git a/third_party/jpeg-xl/third_party/HEVCSoftware/cfg/LICENSE b/third_party/jpeg-xl/third_party/HEVCSoftware/cfg/LICENSE new file mode 100644 index 000000000000..a9d8844e4239 --- /dev/null +++ b/third_party/jpeg-xl/third_party/HEVCSoftware/cfg/LICENSE @@ -0,0 +1,31 @@ +The copyright in this software is being made available under the BSD +License, included below. This software may be subject to other third party +and contributor rights, including patent rights, and no such rights are +granted under this license.   + +Copyright (c) 2010-2017, ITU/ISO/IEC +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the ITU/ISO/IEC nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. diff --git a/third_party/jpeg-xl/third_party/HEVCSoftware/cfg/encoder_intra_main_scc_10.cfg b/third_party/jpeg-xl/third_party/HEVCSoftware/cfg/encoder_intra_main_scc_10.cfg new file mode 100644 index 000000000000..6247fdf48477 --- /dev/null +++ b/third_party/jpeg-xl/third_party/HEVCSoftware/cfg/encoder_intra_main_scc_10.cfg @@ -0,0 +1,136 @@ +#======== File I/O ===================== +BitstreamFile : str.bin +ReconFile : rec.yuv + +#======== Profile definition ============== +Profile : main-SCC # Profile name to use for encoding. 
Use main (for FDIS main), main10 (for FDIS main10), main-still-picture, main-RExt, high-throughput-RExt, main-SCC +Tier : main # Tier to use for interpretation of --Level (main or high only)" + +#======== Unit definition ================ +MaxCUWidth : 64 # Maximum coding unit width in pixel +MaxCUHeight : 64 # Maximum coding unit height in pixel +MaxPartitionDepth : 4 # Maximum coding unit depth +QuadtreeTULog2MaxSize : 5 # Log2 of maximum transform size for + # quadtree-based TU coding (2...6) +QuadtreeTULog2MinSize : 2 # Log2 of minimum transform size for + # quadtree-based TU coding (2...6) +QuadtreeTUMaxDepthInter : 3 +QuadtreeTUMaxDepthIntra : 3 + +#======== Coding Structure ============= +IntraPeriod : 1 # Period of I-Frame ( -1 = only first) +DecodingRefreshType : 1 # Random Accesss 0:none, 1:CRA, 2:IDR, 3:Recovery Point SEI +GOPSize : 1 # GOP Size (number of B slice = GOPSize-1) +ReWriteParamSetsFlag : 1 # Write parameter sets with every IRAP +# Type POC QPoffset QPfactor tcOffsetDiv2 betaOffsetDiv2 temporal_id #ref_pics_active #ref_pics reference pictures + +#=========== Motion Search ============= +FastSearch : 1 # 0:Full search 1:TZ search +SearchRange : 64 # (0: Search range is a Full frame) +HadamardME : 1 # Use of hadamard measure for fractional ME +FEN : 1 # Fast encoder decision +FDM : 1 # Fast Decision for Merge RD cost + +#======== Quantization ============= +QP : 32 # Quantization parameter(0-51) +MaxDeltaQP : 0 # CU-based multi-QP optimization +MaxCuDQPDepth : 0 # Max depth of a minimum CuDQP for sub-LCU-level delta QP +DeltaQpRD : 0 # Slice-based multi-QP optimization +RDOQ : 1 # RDOQ +RDOQTS : 1 # RDOQ for transform skip +CbQpOffset : 6 +CrQpOffset : 6 + +#=========== Deblock Filter ============ +LoopFilterOffsetInPPS : 1 # Dbl params: 0=varying params in SliceHeader, param = base_param + GOP_offset_param; 1 (default) =constant params in PPS, param = base_param) +LoopFilterDisable : 0 # Disable deblocking filter (0=Filter, 1=No Filter) +LoopFilterBetaOffset_div2 : 0 # base_param: -6 ~ 6 +LoopFilterTcOffset_div2 : 0 # base_param: -6 ~ 6 +DeblockingFilterMetric : 0 # blockiness metric (automatically configures deblocking parameters in bitstream). Applies slice-level loop filter offsets (LoopFilterOffsetInPPS and LoopFilterDisable must be 0) + +#=========== Misc. ============ +InternalBitDepth : 10 # codec operating bit-depth + +#=========== Coding Tools ================= +SAO : 1 # Sample adaptive offset (0: OFF, 1: ON) +AMP : 1 # Asymmetric motion partitions (0: OFF, 1: ON) +TransformSkip : 1 # Transform skipping (0: OFF, 1: ON) +TransformSkipFast : 1 # Fast Transform skipping (0: OFF, 1: ON) +SAOLcuBoundary : 0 # SAOLcuBoundary using non-deblocked pixels (0: OFF, 1: ON) + +#============ Slices ================ +SliceMode : 0 # 0: Disable all slice options. + # 1: Enforce maximum number of LCU in an slice, + # 2: Enforce maximum number of bytes in an 'slice' + # 3: Enforce maximum number of tiles in a slice +SliceArgument : 1500 # Argument for 'SliceMode'. + # If SliceMode==1 it represents max. SliceGranularity-sized blocks per slice. + # If SliceMode==2 it represents max. bytes per slice. + # If SliceMode==3 it represents max. tiles per slice. + +LFCrossSliceBoundaryFlag : 1 # In-loop filtering, including ALF and DB, is across or not across slice boundary. + # 0:not across, 1: across + +#============ PCM ================ +PCMEnabledFlag : 0 # 0: No PCM mode +PCMLog2MaxSize : 5 # Log2 of maximum PCM block size. +PCMLog2MinSize : 3 # Log2 of minimum PCM block size. 
+PCMInputBitDepthFlag : 1 # 0: PCM bit-depth is internal bit-depth. 1: PCM bit-depth is input bit-depth. +PCMFilterDisableFlag : 0 # 0: Enable loop filtering on I_PCM samples. 1: Disable loop filtering on I_PCM samples. + +#============ Tiles ================ +TileUniformSpacing : 0 # 0: the column boundaries are indicated by TileColumnWidth array, the row boundaries are indicated by TileRowHeight array + # 1: the column and row boundaries are distributed uniformly +NumTileColumnsMinus1 : 0 # Number of tile columns in a picture minus 1 +TileColumnWidthArray : 2 3 # Array containing tile column width values in units of CTU (from left to right in picture) +NumTileRowsMinus1 : 0 # Number of tile rows in a picture minus 1 +TileRowHeightArray : 2 # Array containing tile row height values in units of CTU (from top to bottom in picture) + +LFCrossTileBoundaryFlag : 1 # In-loop filtering is across or not across tile boundary. + # 0:not across, 1: across + +#============ WaveFront ================ +WaveFrontSynchro : 0 # 0: No WaveFront synchronisation (WaveFrontSubstreams must be 1 in this case). + # >0: WaveFront synchronises with the LCU above and to the right by this many LCUs. + +#=========== Quantization Matrix ================= +ScalingList : 0 # ScalingList 0 : off, 1 : default, 2 : file read +ScalingListFile : scaling_list.txt # Scaling List file name. If file is not exist, use Default Matrix. + +#============ Lossless ================ +TransquantBypassEnable : 0 # Value of PPS flag. +CUTransquantBypassFlagForce: 0 # Force transquant bypass mode, when transquant_bypass_enable_flag is enabled + +#=========== RExt ============ +ExtendedPrecision : 0 # Increased internal accuracies to support high bit depths (not valid in V1 profiles) +TransformSkipLog2MaxSize : 2 # Specify transform-skip maximum size. Minimum 2. (not valid in V1 profiles) +ImplicitResidualDPCM : 1 # Enable implicitly signalled residual DPCM for intra (also known as sample-adaptive intra predict) (not valid in V1 profiles) +ExplicitResidualDPCM : 1 # Enable explicitly signalled residual DPCM for inter and intra-block-copy (not valid in V1 profiles) +ResidualRotation : 1 # Enable rotation of transform-skipped and transquant-bypassed TUs through 180 degrees prior to entropy coding (not valid in V1 profiles) +SingleSignificanceMapContext : 1 # Enable, for transform-skipped and transquant-bypassed TUs, the selection of a single significance map context variable for all coefficients (not valid in V1 profiles) +IntraReferenceSmoothing : 1 # 0: Disable use of intra reference smoothing (not valid in V1 profiles). 
+                                              # 1: Enable use of intra reference smoothing (same as V1)
+GolombRiceParameterAdaptation : 1             # Enable the partial retention of the Golomb-Rice parameter value from one coefficient group to the next
+HighPrecisionPredictionWeighting : 1          # Use the high precision option for weighted prediction (not valid in V1 profiles)
+CrossComponentPrediction      : 1             # Enable the use of cross-component prediction (not valid in V1 profiles)
+
+#=========== SCC ============
+IntraBlockCopyEnabled         : 1             # Enable the use of intra block copying
+HashBasedIntraBlockCopySearchEnabled : 1      # Use hash-based search for intra block copying on 8x8 blocks
+IntraBlockCopySearchWidthInCTUs : -1          # Search range for IBC (-1: full frame search)
+IntraBlockCopyNonHashSearchWidthInCTUs : 3    # Search range for the IBC non-hash search method (i.e., fast/full search)
+MSEBasedSequencePSNR          : 1             # 0: Emit sequence PSNR only as a linear average of the frame PSNRs, 1: also emit a sequence PSNR based on an average of the frame MSEs
+PrintClippedPSNR              : 1             # 0: Print lossless PSNR values as 999.99 dB, 1: clip lossless PSNR according to resolution
+PrintFrameMSE                 : 1             # 0: Emit only bit count and PSNRs for each frame, 1: also emit MSE values
+PrintSequenceMSE              : 1             # 0: Emit only bit rate and PSNRs for the whole sequence, 1: also emit MSE values
+ColourTransform               : 1             # Enable the use of colour transform (not valid in V1 profiles)
+PaletteMode                   : 1             # Enable the use of palette mode (not valid in V1 profiles)
+PaletteMaxSize                : 63            # Supported maximum palette size (not valid in V1 profiles)
+PaletteMaxPredSize            : 128           # Supported maximum palette predictor size (not valid in V1 profiles)
+IntraBoundaryFilterDisabled   : 1             # Disable the use of intra boundary filtering (not valid in V1 profiles)
+TransquantBypassInferTUSplit  : 1             # Infer TU splitting for transquant bypass CUs
+PalettePredInSPSEnabled       : 0             # Transmit the palette predictor initializer in the SPS (not valid in V1 profiles)
+PalettePredInPPSEnabled       : 0             # Transmit the palette predictor initializer in the PPS (not valid in V1 profiles)
+SelectiveRDOQ                 : 1             # Selective RDOQ
+
+### DO NOT ADD ANYTHING BELOW THIS LINE ###
+### DO NOT DELETE THE EMPTY LINE BELOW ###
+
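A configuration file in this format is consumed by the HEVC reference encoder (HM/SCM, TAppEncoder), with the remaining parameters such as input file and dimensions supplied on the command line. The invocation below is only an illustrative sketch; the binary name, file names, and dimensions are assumptions, not values taken from this patch:

    TAppEncoder -c encoder.cfg -i input.yuv -wdt 512 -hgt 512 -fr 30 -f 1 -b out.bin

Here -c names this config, -i and -b the raw input and output bitstream, -wdt/-hgt the source dimensions, -fr the frame rate, and -f the number of frames to encode.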
diff --git a/third_party/jpeg-xl/third_party/dirent.cc b/third_party/jpeg-xl/third_party/dirent.cc
new file mode 100644
index 000000000000..81015ed0fb27
--- /dev/null
+++ b/third_party/jpeg-xl/third_party/dirent.cc
@@ -0,0 +1,142 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if defined(_WIN32) || defined(_WIN64)
+#include "third_party/dirent.h"
+
+#include "lib/jxl/base/status.h"
+
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif  // NOMINMAX
+#include <windows.h>
+
+#include <memory>
+#include <string>
+
+int mkdir(const char* path, mode_t /*mode*/) {
+  const LPSECURITY_ATTRIBUTES sec = nullptr;
+  if (!CreateDirectory(path, sec)) {
+    JXL_NOTIFY_ERROR("Failed to create directory %s", path);
+    return -1;
+  }
+  return 0;
+}
+
+// Modified from code bearing the following notice:
+// https://trac.wildfiregames.com/browser/ps/trunk/source/lib/sysdep/os/
+/* Copyright (C) 2010 Wildfire Games.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining
+ * a copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+struct DIR {
+  HANDLE hFind;
+
+  WIN32_FIND_DATA findData;  // indeterminate if hFind == INVALID_HANDLE_VALUE
+
+  // readdir will return the address of this member.
+  // (must be stored in DIR to allow multiple independent
+  // opendir/readdir sequences).
+  dirent ent;
+
+  // used by readdir to skip the first FindNextFile.
+  size_t numCalls = 0;
+};
+
+static bool IsValidDirectory(const char* path) {
+  const DWORD fileAttributes = GetFileAttributes(path);
+
+  // path not found
+  if (fileAttributes == INVALID_FILE_ATTRIBUTES) return false;
+
+  // not a directory
+  if ((fileAttributes & FILE_ATTRIBUTE_DIRECTORY) == 0) return false;
+
+  return true;
+}
+
+DIR* opendir(const char* path) {
+  if (!IsValidDirectory(path)) {
+    errno = ENOENT;
+    return nullptr;
+  }
+
+  std::unique_ptr<DIR> d(new DIR);
+
+  // NB: "c:\\path" only returns information about that directory;
+  // trailing slashes aren't allowed. Append "\\*" to retrieve its entries.
+  std::string searchPath(path);
+  if (searchPath.back() != '/' && searchPath.back() != '\\') {
+    searchPath += '\\';
+  }
+  searchPath += '*';
+
+  // (we don't defer FindFirstFile until readdir because callers
+  // expect us to return nullptr if directory reading will/did fail.)
+  d->hFind = FindFirstFile(searchPath.c_str(), &d->findData);
+  if (d->hFind != INVALID_HANDLE_VALUE) return d.release();
+  if (GetLastError() == ERROR_NO_MORE_FILES) return d.release();  // empty
+
+  JXL_NOTIFY_ERROR("Failed to open directory %s", searchPath.c_str());
+  return nullptr;
+}
+
+int closedir(DIR* dir) {
+  delete dir;
+  return 0;
+}
+
+dirent* readdir(DIR* d) {
+  // "empty" case from opendir
+  if (d->hFind == INVALID_HANDLE_VALUE) return nullptr;
+
+  // until end of directory or a valid entry was found:
+  for (;;) {
+    if (d->numCalls++ != 0)  // (skip first call to FindNextFile - see opendir)
+    {
+      if (!FindNextFile(d->hFind, &d->findData)) {
+        JXL_ASSERT(GetLastError() == ERROR_NO_MORE_FILES);
+        SetLastError(0);
+        return nullptr;  // end of directory or error
+      }
+    }
+
+    // only return non-hidden and non-system entries
+    if ((d->findData.dwFileAttributes &
+         (FILE_ATTRIBUTE_HIDDEN | FILE_ATTRIBUTE_SYSTEM)) == 0) {
+      d->ent.d_name = d->findData.cFileName;
+      return &d->ent;
+    }
+  }
+}
+
+#endif  // #if defined(_WIN32) || defined(_WIN64)
diff --git a/third_party/jpeg-xl/third_party/dirent.h b/third_party/jpeg-xl/third_party/dirent.h
new file mode 100644
index 000000000000..37a08f425be9
--- /dev/null
+++ b/third_party/jpeg-xl/third_party/dirent.h
@@ -0,0 +1,49 @@
+// Copyright (c) the JPEG XL Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIB_JXL_THIRD_PARTY_DIRENT_H_
+#define LIB_JXL_THIRD_PARTY_DIRENT_H_
+
+// Emulates POSIX readdir for Windows
+
+#if defined(_WIN32) || defined(_WIN64)
+
+#include <sys/stat.h>  // S_IFREG
+
+#ifndef _MODE_T_
+typedef unsigned int mode_t;
+#endif  // _MODE_T_
+int mkdir(const char* path, mode_t mode);
+
+struct dirent {
+  char* d_name;  // no path
+};
+
+#define stat _stat64
+
+#ifndef S_ISDIR
+#define S_ISDIR(m) (m & S_IFDIR)
+#endif  // S_ISDIR
+
+#ifndef S_ISREG
+#define S_ISREG(m) (m & S_IFREG)
+#endif  // S_ISREG
+
+struct DIR;
+DIR* opendir(const char* path);
+int closedir(DIR* dir);
+dirent* readdir(DIR* d);
+
+#endif  // #if defined(_WIN32) || defined(_WIN64)
+#endif  // LIB_JXL_THIRD_PARTY_DIRENT_H_
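Together, dirent.cc and dirent.h give Windows callers the small POSIX directory-reading surface the rest of the library expects. A minimal sketch of a hypothetical caller (not part of this patch; note that this readdir already filters out hidden and system entries):

    // Hypothetical Windows-side caller of the shim above.
    #include <cstdio>

    #include "third_party/dirent.h"

    void PrintEntries(const char* path) {
      DIR* dir = opendir(path);  // nullptr (errno == ENOENT) if not a directory
      if (dir == nullptr) return;
      while (dirent* entry = readdir(dir)) {
        std::printf("%s\n", entry->d_name);  // d_name is the bare name, no path
      }
      closedir(dir);  // releases the DIR; always returns 0
    }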
diff --git a/third_party/jpeg-xl/third_party/lcms2.cmake b/third_party/jpeg-xl/third_party/lcms2.cmake
new file mode 100644
index 000000000000..783697c3b5d3
--- /dev/null
+++ b/third_party/jpeg-xl/third_party/lcms2.cmake
@@ -0,0 +1,63 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_library(lcms2 STATIC
+  lcms/src/cmsalpha.c
+  lcms/src/cmscam02.c
+  lcms/src/cmscgats.c
+  lcms/src/cmscnvrt.c
+  lcms/src/cmserr.c
+  lcms/src/cmsgamma.c
+  lcms/src/cmsgmt.c
+  lcms/src/cmshalf.c
+  lcms/src/cmsintrp.c
+  lcms/src/cmsio0.c
+  lcms/src/cmsio1.c
+  lcms/src/cmslut.c
+  lcms/src/cmsmd5.c
+  lcms/src/cmsmtrx.c
+  lcms/src/cmsnamed.c
+  lcms/src/cmsopt.c
+  lcms/src/cmspack.c
+  lcms/src/cmspcs.c
+  lcms/src/cmsplugin.c
+  lcms/src/cmsps2.c
+  lcms/src/cmssamp.c
+  lcms/src/cmssm.c
+  lcms/src/cmstypes.c
+  lcms/src/cmsvirt.c
+  lcms/src/cmswtpnt.c
+  lcms/src/cmsxform.c
+  lcms/src/lcms2_internal.h
+)
+target_include_directories(lcms2
+  PUBLIC "${CMAKE_CURRENT_LIST_DIR}/lcms/include")
+# This warning triggers with gcc-8.
+if (${CMAKE_C_COMPILER_ID} MATCHES "GNU")
+target_compile_options(lcms2
+  PRIVATE
+    # gcc-only flags.
+    -Wno-stringop-truncation
+    -Wno-strict-aliasing
+)
+endif()
+# By default LCMS uses sizeof(void*) for memory alignment, but on 32-bit ARM
+# we can't access doubles that aren't aligned to 8 bytes. This forces the
+# alignment to 8 bytes.
+target_compile_definitions(lcms2
+  PRIVATE "-DCMS_PTR_ALIGNMENT=8")
+target_compile_definitions(lcms2
+  PUBLIC "-DCMS_NO_REGISTER_KEYWORD=1")
+
+set_property(TARGET lcms2 PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/third_party/jpeg-xl/third_party/lodepng.cmake b/third_party/jpeg-xl/third_party/lodepng.cmake
new file mode 100644
index 000000000000..1d3850eb5884
--- /dev/null
+++ b/third_party/jpeg-xl/third_party/lodepng.cmake
@@ -0,0 +1,22 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_library(lodepng STATIC
+  lodepng/lodepng.cpp
+  lodepng/lodepng.h
+)
+# This library can be included into position independent binaries.
+set_target_properties(lodepng PROPERTIES POSITION_INDEPENDENT_CODE TRUE)
+target_include_directories(lodepng
+  PUBLIC "${CMAKE_CURRENT_LIST_DIR}/lodepng")
diff --git a/third_party/jpeg-xl/third_party/sjpeg.cmake b/third_party/jpeg-xl/third_party/sjpeg.cmake
new file mode 100644
index 000000000000..152a86ebe719
--- /dev/null
+++ b/third_party/jpeg-xl/third_party/sjpeg.cmake
@@ -0,0 +1,23 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SJPEG_BUILD_EXAMPLES must be CACHEd so that it is not clobbered by the
+# option() of the same name inside SJPEG.
+set(SJPEG_BUILD_EXAMPLES NO CACHE BOOL "Examples")
+# SJPEG uses OpenGL; CMake warns when multiple OpenGL implementations are
+# installed. This setting makes it prefer the new (GLVND) one.
+set(OpenGL_GL_PREFERENCE GLVND)
+
+add_subdirectory(sjpeg EXCLUDE_FROM_ALL)
+target_include_directories(sjpeg PUBLIC "${CMAKE_CURRENT_LIST_DIR}/sjpeg/src/")
diff --git a/third_party/jpeg-xl/third_party/skcms.cmake b/third_party/jpeg-xl/third_party/skcms.cmake
new file mode 100644
index 000000000000..78e4028b80e5
--- /dev/null
+++ b/third_party/jpeg-xl/third_party/skcms.cmake
@@ -0,0 +1,24 @@
+# Copyright (c) the JPEG XL Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_library(skcms STATIC skcms/skcms.cc)
+target_include_directories(skcms PUBLIC "${CMAKE_CURRENT_LIST_DIR}/skcms/")
+
+include(CheckCXXCompilerFlag)
+check_cxx_compiler_flag("-Wno-psabi" CXX_WPSABI_SUPPORTED)
+if(CXX_WPSABI_SUPPORTED)
+  target_compile_options(skcms PRIVATE -Wno-psabi)
+endif()
+
+set_property(TARGET skcms PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/toolkit/moz.configure b/toolkit/moz.configure
index 4ab794481fce..33e09d889d30
--- a/toolkit/moz.configure
+++ b/toolkit/moz.configure
@@ -555,6 +555,20 @@ set_define("MOZ_DAV1D_ASM", dav1d_asm)
 set_config("MOZ_AV1", av1)
 set_define("MOZ_AV1", av1)
+# JXL Image Codec Support
+# ==============================================================
+option("--disable-jxl", help="Disable jxl image support")
+
+
+@depends("--disable-jxl")
+def jxl(value):
+    if value:
+        return True
+
+
+set_config("MOZ_JXL", jxl)
+set_define("MOZ_JXL", jxl)
+
 # Built-in fragmented MP4 support.
 # ==============================================================
 option(
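Because the option is declared as --disable-jxl, the jxl() depends function returns True unless the flag is passed, so JXL support defaults to on wherever this configure code runs. A developer build can opt out through the usual mozconfig mechanism; a minimal sketch (the file is whatever mozconfig the build already uses):

    # In the developer's mozconfig:
    ac_add_options --disable-jxl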