[C#, CPP] Introduce Float16/BFloat16 support and tests for C#, C++ (#16506)
### Description
Introduce `Float16`/`BFloat16` support for the C# and C++ APIs. Users can now convert `float` to and from `Float16`/`BFloat16`, compare values, and test for `NaN`, infinity, and whether a number is denormalized.

### Motivation and Context
Users filed issues such as: https://github.com/microsoft/onnxruntime/issues/14303
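A minimal sketch of the new C# surface, pieced together from the tests added in this PR (names as they appear in the diff below):

```csharp
using Microsoft.ML.OnnxRuntime.Tensors;

// Round-trip a float through Float16 (conversion rounds to nearest).
Float16 half = (Float16)1.5f;
float back = (float)half;                                       // 1.5f

// Classification helpers.
bool nan = Float16.IsNaN((Float16)float.NaN);                   // true
bool inf = Float16.IsInfinity((Float16)float.PositiveInfinity); // true
bool sub = Float16.IsSubnormal(new Float16(0x0001));            // true: smallest fp16 subnormal

// Comparison operators follow IEEE 754 ordering: NaN is unequal to everything.
bool lt = (Float16)(-33.33f) < (Float16)66.66f;                 // true
bool nanEq = Float16.NaN == Float16.NaN;                        // false
```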
This commit is contained in:
Parent
77b45c6503
Commit
853c4ff0a5
@@ -21,6 +21,7 @@ endif()
 set(ONNXRUNTIME_PUBLIC_HEADERS
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h"
+    "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h"
@@ -38,6 +39,7 @@ macro(get_mobile_api_headers _HEADERS)
  set(${_HEADERS}
      "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h"
      "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h"
+     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h"
      "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h"
  )
@@ -107,6 +107,15 @@ if(NOT onnxruntime_DISABLE_ABSEIL)
   endif()
 endif()

+if (MSVC)
+  set(EIGEN_NATVIS_FILE ${eigen_SOURCE_DIR}/debug/msvc/eigen.natvis)
+  if (EXISTS ${EIGEN_NATVIS_FILE})
+    target_sources(
+        onnxruntime_common
+        INTERFACE $<BUILD_INTERFACE:${EIGEN_NATVIS_FILE}>)
+  endif()
+endif()
+
 onnxruntime_add_include_to_target(onnxruntime_common date_interface WIL::WIL)
 target_include_directories(onnxruntime_common
   PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}
@@ -56,6 +56,13 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_framework_srcs})

 onnxruntime_add_static_library(onnxruntime_framework ${onnxruntime_framework_srcs})

+if (MSVC)
+  set(ORT_FRAMEWORK_NATVIS_FILE "onnxruntime_framework.natvis")
+  target_sources(
+      onnxruntime_framework
+      INTERFACE $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/${ORT_FRAMEWORK_NATVIS_FILE}>)
+endif()
+
 if(onnxruntime_ENABLE_INSTRUMENT)
   target_compile_definitions(onnxruntime_framework PRIVATE ONNXRUNTIME_ENABLE_INSTRUMENT)
 endif()
@@ -0,0 +1,47 @@
+<?xml version="1.0" encoding="utf-8"?>
+
+<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
+  <Type Name="onnxruntime::MLFloat16">
+    <Intrinsic Name="_negative" Expression="(val &amp; 0x8000) != 0"/>
+    <Intrinsic Name="_strip_sign" Expression="(val &amp; ~0x8000)"/>
+    <Intrinsic Name="_is_nan" Expression="(_strip_sign() &gt; 0x7C00)"/>
+    <Intrinsic Name="_is_finite" Expression="(_strip_sign() &lt; 0x7C00)"/>
+    <Intrinsic Name="_is_normal" Expression="(_is_finite() &amp;&amp; (val != 0)) &amp;&amp; ((val &amp; 0x7C00) != 0)"/>
+    <Intrinsic Name="_biased_exponent" Expression="(val &gt;&gt; 10) &amp; (0x7C00 &gt;&gt; 10)"/>
+    <Intrinsic Name="_exponent" Expression="(int16_t)(_biased_exponent() - 15)"/>
+    <Intrinsic Name="_significand" Expression="(val &amp; 0x03FF)"/>
+    <DisplayString>{{val={ val }}}</DisplayString>
+    <Expand>
+      <Item Name="[Negative]" ExcludeView="simple">_negative()</Item>
+      <Item Name="[IsNan]" ExcludeView="simple" Condition="_is_nan()">true</Item>
+      <Item Name="[IsFinite]" ExcludeView="simple">_is_finite()</Item>
+      <Item Name="[IsNormal]" ExcludeView="simple">_is_normal()</Item>
+      <Item Name="[uint16_t]" ExcludeView="simple">val</Item>
+      <Item Name="[Exponent]" ExcludeView="simple">_exponent()</Item>
+      <Item Name="[Biased Exponent]" ExcludeView="simple">_biased_exponent()</Item>
+      <Item Name="[Significand]" ExcludeView="simple">_significand()</Item>
+    </Expand>
+  </Type>
+
+  <Type Name="onnxruntime::BFloat16">
+    <Intrinsic Name="_negative" Expression="(val &amp; 0x8000) != 0"/>
+    <Intrinsic Name="_strip_sign" Expression="(val &amp; ~0x8000)"/>
+    <Intrinsic Name="_is_nan" Expression="(_strip_sign() &gt; 0x7F80)"/>
+    <Intrinsic Name="_is_finite" Expression="(_strip_sign() &lt; 0x7F80)"/>
+    <Intrinsic Name="_is_normal" Expression="(_is_finite() &amp;&amp; (val != 0)) &amp;&amp; ((val &amp; 0x7F80) != 0)"/>
+    <Intrinsic Name="_biased_exponent" Expression="(val &gt;&gt; 7) &amp; (0x7F80 &gt;&gt; 7)"/>
+    <Intrinsic Name="_exponent" Expression="(int16_t)(_biased_exponent() - 127)"/>
+    <Intrinsic Name="_significand" Expression="(val &amp; 0x007F)"/>
+    <DisplayString>{{val={ val }}}</DisplayString>
+    <Expand>
+      <Item Name="[Negative]" ExcludeView="simple">_negative()</Item>
+      <Item Name="[IsNormal]" ExcludeView="simple">_is_normal()</Item>
+      <Item Name="[IsNan]" ExcludeView="simple" Condition="_is_nan()">true</Item>
+      <Item Name="[IsFinite]" ExcludeView="simple">_is_finite()</Item>
+      <Item Name="[uint16_t]" ExcludeView="simple">val</Item>
+      <Item Name="[Exponent]" ExcludeView="simple">_exponent()</Item>
+      <Item Name="[Biased Exponent]" ExcludeView="simple">_biased_exponent()</Item>
+      <Item Name="[Significand]" ExcludeView="simple">_significand()</Item>
+    </Expand>
+  </Type>
+</AutoVisualizer>
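The natvis intrinsics above encode the IEEE bit layouts directly. The same checks in C# form, as a sketch for reference (these helpers are illustrative, not part of the change):

```csharp
// fp16: 1 sign bit, 5 exponent bits (mask 0x7C00), 10 significand bits (mask 0x03FF).
// bf16: 1 sign bit, 8 exponent bits (mask 0x7F80), 7 significand bits (mask 0x007F).
static ushort StripSign(ushort bits) => (ushort)(bits & 0x7FFF);
static bool Fp16IsNaN(ushort bits) => StripSign(bits) > 0x7C00;     // exponent all ones, significand != 0
static bool Fp16IsFinite(ushort bits) => StripSign(bits) < 0x7C00;  // exponent not all ones
static bool Fp16IsSubnormal(ushort bits) =>
    (bits & 0x7C00) == 0 && (bits & 0x03FF) != 0;                   // zero exponent, nonzero significand
static int Fp16Exponent(ushort bits) => ((bits >> 10) & 0x1F) - 15; // unbiased exponent
```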
(The diff for one file is not shown because of its size.)
@@ -44,174 +44,6 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
         DataTypeMax = 17
     }

-    /// <summary>
-    /// This value type represents A Float16 value
-    /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
-    /// and as such, represented the same way in managed and native memories. This means that arrays of this type
-    /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
-    /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
-    /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
-    /// </summary>
-    public struct Float16
-    {
-        /// <summary>
-        /// float16 representation bits
-        /// </summary>
-        public ushort value;
-
-        /// <summary>
-        /// Ctor
-        /// </summary>
-        /// <param name="v"></param>
-        public Float16(ushort v)
-        {
-            value = v;
-        }
-
-        /// <summary>
-        /// Converts to ushort
-        /// </summary>
-        /// <param name="f">instance of Float16</param>
-        /// <returns>value member</returns>
-        public static implicit operator ushort(Float16 f) { return f.value; }
-
-        /// <summary>
-        /// Converts a 16-bit unsigned integer to a Float16.
-        /// </summary>
-        /// <param name="value">A 16-bit unsigned integer.</param>
-        /// <returns>A Float16 that represents the converted 16-bit unsigned integer.</returns>
-        public static implicit operator Float16(ushort value) { return new Float16(value); }
-
-        /// <summary>
-        /// Compares values of two Float16 for binary equality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator ==(Float16 lhs, Float16 rhs) { return lhs.value == rhs.value; }
-
-        /// <summary>
-        /// Compares values of two Float16 for binary inequality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator !=(Float16 lhs, Float16 rhs) { return lhs.value != rhs.value; }
-
-        /// <summary>
-        /// Returns a value indicating whether this instance and other Float16 represent the same value.
-        /// </summary>
-        /// <param name="other">A Float16 object to compare to this instance.</param>
-        /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
-        public bool Equals(Float16 other)
-        {
-            return (other == this);
-        }
-
-        /// <summary>
-        /// Returns a value indicating whether this instance and a specified System.Object
-        /// represent the same type and value.
-        /// </summary>
-        /// <param name="obj">An System.Object.</param>
-        /// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
-        public override bool Equals(object obj)
-        {
-            bool result = false;
-            if (obj is Float16)
-            {
-                Float16 fl16 = (Float16)obj;
-                result = (fl16 == this);
-            }
-            return result;
-        }
-
-        /// <summary>
-        /// Returns the hash code for this instance.
-        /// </summary>
-        /// <returns>A 32-bit signed integer hash code.</returns>
-        public override int GetHashCode()
-        {
-            return value.GetHashCode();
-        }
-    }
-
-    /// <summary>
-    /// This value type represents A BFloat16 value
-    /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
-    /// and as such, represented the same way in managed and native memories. This means that arrays of this type
-    /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
-    /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
-    /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
-    /// </summary>
-    public struct BFloat16
-    {
-        /// <summary>
-        /// bfloat16 representation bits
-        /// </summary>
-        public ushort value;
-
-        /// <summary>
-        /// Ctor
-        /// </summary>
-        /// <param name="v"></param>
-        public BFloat16(ushort v)
-        {
-            value = v;
-        }
-
-        /// <summary>
-        /// Converts to ushort
-        /// </summary>
-        /// <param name="bf">instance of BFloat16</param>
-        /// <returns>value member</returns>
-        public static implicit operator ushort(BFloat16 bf) { return bf.value; }
-
-        /// <summary>
-        /// Converts a 16-bit unsigned integer to a BFloat16.
-        /// </summary>
-        /// <param name="value">A 16-bit unsigned integer.</param>
-        /// <returns>A BFloat16 that represents the converted 16-bit unsigned integer.</returns>
-        public static implicit operator BFloat16(ushort value) { return new BFloat16(value); }
-
-        /// <summary>
-        /// Compares values of two BFloat16 for binary equality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator ==(BFloat16 lhs, BFloat16 rhs) { return lhs.value == rhs.value; }
-
-        /// <summary>
-        /// Compares values of two BFloat16 for binary inequality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator !=(BFloat16 lhs, BFloat16 rhs) { return lhs.value != rhs.value; }
-
-        /// <summary>
-        /// Returns a value indicating whether this instance and other BFloat16 represent the same value.
-        /// </summary>
-        /// <param name="other">A BFloat16 object to compare to this instance.</param>
-        /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
-        public bool Equals(BFloat16 other)
-        {
-            return (other == this);
-        }
-
-        /// <summary>
-        /// Returns a value indicating whether this instance and a specified System.Object
-        /// represent the same type and value.
-        /// </summary>
-        /// <param name="obj">An System.Object.</param>
-        /// <returns>true if obj is BFloat16 its value is equal to this instance; otherwise, false.</returns>
-        public override bool Equals(object obj)
-        {
-            bool result = false;
-            if (obj is BFloat16)
-            {
-                BFloat16 bfl16 = (BFloat16)obj;
-                result = (bfl16 == this);
-            }
-            return result;
-        }
-
-        /// <summary>
-        /// Returns the hash code for this instance.
-        /// </summary>
-        /// <returns>A 32-bit signed integer hash code.</returns>
-        public override int GetHashCode()
-        {
-            return value.GetHashCode();
-        }
-    }
-
     /// <summary>
     /// Helps typecasting. Holds Tensor element type traits.
     /// </summary>
@@ -1319,12 +1319,13 @@ namespace Microsoft.ML.OnnxRuntime.Tests
         private void TestModelInputFLOAT16()
         {
             // model takes 1x5 input of fixed type, echoes back
+            Float16[] modelInput = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
+            int[] inputShape = { 1, 5 };
             var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
             using (var session = new InferenceSession(model))
             {
                 var container = new List<NamedOnnxValue>();
-                var tensorIn = new DenseTensor<Float16>(
-                    new Float16[] { 15360, 16384, 16896, 17408, 17664 }, new int[] { 1, 5 });
+                var tensorIn = new DenseTensor<Float16>(modelInput, inputShape);
                 var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
                 container.Add(nov);
                 using (var res = session.Run(container))
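The raw ushort constants in this test and in the BFloat16 test that follows are just the 16-bit encodings of 1.0 through 5.0; a decoding note, for reference:

```csharp
// fp16:  15360 = 0x3C00 -> 1.0,  16384 = 0x4000 -> 2.0,  16896 = 0x4200 -> 3.0,
//        17408 = 0x4400 -> 4.0,  17664 = 0x4500 -> 5.0
// bf16:  16256 = 0x3F80 -> 1.0,  16384 = 0x4000 -> 2.0,  16448 = 0x4040 -> 3.0,
//        16512 = 0x4080 -> 4.0,  16544 = 0x40A0 -> 5.0
float three = (float)new Float16(16896); // 3.0f
```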
@@ -1341,13 +1342,15 @@ namespace Microsoft.ML.OnnxRuntime.Tests
         [Fact(DisplayName = "TestModelInputBFLOAT16")]
         private void TestModelInputBFLOAT16()
         {
+            BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
+                                      new BFloat16(16448), new BFloat16(16512), new BFloat16(16544) };
+            int[] inputShape = { 1, 5 };
             // model takes 1x5 input of fixed type, echoes back
             var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_BFLOAT16.onnx");
             using (var session = new InferenceSession(model))
             {
                 var container = new List<NamedOnnxValue>();
-                var tensorIn = new DenseTensor<BFloat16>(
-                    new BFloat16[] { 16256, 16384, 16448, 16512, 16544 }, new int[] { 1, 5 });
+                var tensorIn = new DenseTensor<BFloat16>(modelInput, inputShape);
                 var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
                 container.Add(nov);
                 using (var res = session.Run(container))
@@ -1999,80 +2002,6 @@ namespace Microsoft.ML.OnnxRuntime.Tests
             }
         }

-        internal class FloatComparer : IEqualityComparer<float>
-        {
-            private float atol = 1e-3f;
-            private float rtol = 1.7e-2f;
-
-            public bool Equals(float x, float y)
-            {
-                return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
-            }
-            public int GetHashCode(float x)
-            {
-                return x.GetHashCode();
-            }
-        }
-
-        internal class DoubleComparer : IEqualityComparer<double>
-        {
-            private double atol = 1e-3;
-            private double rtol = 1.7e-2;
-
-            public bool Equals(double x, double y)
-            {
-                return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
-            }
-            public int GetHashCode(double x)
-            {
-                return x.GetHashCode();
-            }
-        }
-
-        class ExactComparer<T> : IEqualityComparer<T>
-        {
-            public bool Equals(T x, T y)
-            {
-                return x.Equals(y);
-            }
-            public int GetHashCode(T x)
-            {
-                return x.GetHashCode();
-            }
-        }
-
-        /// <summary>
-        /// Use it to compare Float16
-        /// </summary>
-        internal class Float16Comparer : IEqualityComparer<Float16>
-        {
-            public ushort tolerance = 0;
-            public bool Equals(Float16 x, Float16 y)
-            {
-                return Math.Abs(x.value - y.value) <= (tolerance + y);
-            }
-            public int GetHashCode(Float16 x)
-            {
-                return x.GetHashCode();
-            }
-        }
-
-        /// <summary>
-        /// Use it to compare Bloat16
-        /// </summary>
-        internal class BFloat16Comparer : IEqualityComparer<BFloat16>
-        {
-            public ushort tolerance = 0;
-            public bool Equals(BFloat16 x, BFloat16 y)
-            {
-                return Math.Abs(x.value - y.value) <= (tolerance + y);
-            }
-            public int GetHashCode(BFloat16 x)
-            {
-                return x.GetHashCode();
-            }
-        }
-
         private class GpuFact : FactAttribute
         {
             public GpuFact()
@@ -50,9 +50,10 @@

  <ItemGroup>
    <Compile Remove="InferenceTest.cs" />
-   <Compile Remove="OrtEnvTests.cs" />
    <Compile Remove="OrtIoBindingAllocationTest.cs" />
+   <Compile Remove="OrtEnvTests.cs" />
    <Compile Remove="OrtValueTests.cs" />
+   <Compile Remove="OrtFloat16Tests.cs" />
    <Compile Remove="Tensors\TensorTests.cs" />
    <Compile Remove="TrainingTest.cs" />
  </ItemGroup>
@@ -80,10 +81,10 @@
  <ItemGroup>
    <!-- include common files for visibility, however they're compiled directly by the target specific test projects -->
    <None Include="InferenceTest.cs" />
-   <None Include="OrtEnvTests.cs" />
    <None Include="OnnxData.cs" />
    <None Include="OrtIoBindingAllocationTest.cs" Condition=" '$(EnableDefaultCompileItems)' == 'true' " />
    <None Include="OrtValueTests.cs" />
+   <None Include="OrtFloat16Tests.cs" />
    <None Include="Tensors\TensorTests.cs" Condition=" '$(EnableDefaultCompileItems)' == 'true' " />
    <None Include="Tensors\ArrayTensorExtensionTests.cs" Condition=" '$(EnableDefaultCompileItems)' == 'true' " />
    <None Include="TrainingTest.cs" />
@@ -0,0 +1,533 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using Xunit;
+
+namespace Microsoft.ML.OnnxRuntime.Tests
+{
+    [Collection("Ort Float16 tests")]
+    public class OrtFloat16Tests
+    {
+        const float oneThird = 1 / 3.0f;
+        const float oneSeventh = 1 / 7.0f;
+        const float oneTenth = 1 / 10.0f;
+
+        [Fact(DisplayName = "ConvertFloatToFloat16")]
+        public void ConvertFloatToFloat16()
+        {
+            // Generate integer floats and insert between them
+            // fractions. This will test the rounding logic.
+            float start = -10;
+
+            var floatValues = new float[21 * 4];
+            for (int i = 0; i < floatValues.Length; i += 4)
+            {
+                floatValues[i] = start;
+                floatValues[i + 1] = start + oneThird;
+                floatValues[i + 2] = start + oneSeventh;
+                floatValues[i + 3] = start + oneTenth;
+                start += 1;
+            }
+
+            var f16Converted = Array.ConvertAll(floatValues, f => (Float16)f);
+            var backConverted = Array.ConvertAll(f16Converted, f16 => (float)f16);
+            Assert.Equal(floatValues, backConverted, new FloatComparer());
+        }
+
+        [Fact(DisplayName = "TestZeros")]
+        public void TestZeros()
+        {
+            var positiveZero = new Float16(0);
+            Assert.False(Float16.IsNegative(positiveZero));
+            Assert.True(Float16.IsNaNOrZero(positiveZero));
+
+            float singlePositiveZero = (float)positiveZero;
+            Assert.Equal(+0.0f, singlePositiveZero);
+#if NET6_0_OR_GREATER
+            Assert.False(float.IsNegative(singlePositiveZero));
+#endif
+
+            var negativeZero = Float16.Negate(positiveZero);
+            Assert.True(Float16.IsNegative(negativeZero));
+            Assert.True(Float16.IsNaNOrZero(negativeZero));
+
+            float singleNegativeZero = (float)negativeZero;
+            Assert.Equal(-0.0f, singleNegativeZero);
+#if NET6_0_OR_GREATER
+            Assert.True(float.IsNegative(singleNegativeZero));
+#endif
+        }
+
+        [Fact(DisplayName = "TestComparisonOperators")]
+        public void TestComparisonOperators()
+        {
+            Float16 left = (Float16)(float)-33.33f;
+            Float16 leftSame = (Float16)(float)-33.33f;
+            Float16 right = (Float16)(float)66.66f;
+            Float16 rightSame = (Float16)(float)66.66f;
+
+            Assert.False(Float16.IsNaNOrZero(left));
+            Assert.False(Float16.IsNaNOrZero(right));
+
+            Assert.True(right > Float16.Epsilon);
+
+            Assert.True(left == leftSame);
+            Assert.False(left == Float16.Negate(leftSame));
+
+            Assert.True(right == rightSame);
+            Assert.False(right == Float16.Negate(rightSame));
+
+            Assert.True(left < right);
+            Assert.True(left > Float16.Negate(right));
+            Assert.True(Float16.Negate(left) < right);
+
+            Assert.True(left <= right);
+            Assert.True(left >= Float16.Negate(right));
+            Assert.False(left > right);
+            Assert.False(left >= right);
+            Assert.True(Float16.Negate(left) <= right);
+            Assert.False(left == right);
+            Assert.False(right == left);
+            Assert.True(left != right);
+            Assert.True(right != left);
+        }
+
+        [Fact(DisplayName = "TestNAN")]
+        public void TestNAN()
+        {
+            Float16 fp16NANFromSingle = (Float16)float.NaN;
+            Assert.True(Float16.IsNaN(fp16NANFromSingle));
+            Assert.Equal(Float16.NaN, fp16NANFromSingle);
+            Assert.True(Float16.IsNaNOrZero(fp16NANFromSingle));
+
+            float NanFromFloat16 = fp16NANFromSingle.ToFloat();
+            Assert.True(float.IsNaN(NanFromFloat16));
+
+            // IEqualityComparable returns true, because it tests
+            // objects, not numbers.
+            Assert.Equal(fp16NANFromSingle, Float16.NaN);
+
+            Assert.Equal(Float16.NaN, Float16.Negate(Float16.NaN));
+        }
+
+        [Fact(DisplayName = "TestNANComparision")]
+        public void TestNANComparisionOperators()
+        {
+            // NaN is not ordered with respect to anything
+            // including itself
+
+            // IEqualityComparable returns true, because it tests
+            // objects, not numbers.
+            Assert.Equal(Float16.NaN, Float16.NaN);
+            Assert.False(Float16.NaN < Float16.NaN);
+            Assert.False(Float16.NaN > Float16.NaN);
+            Assert.False(Float16.NaN <= Float16.NaN);
+            Assert.False(Float16.NaN >= Float16.NaN);
+            Assert.False(Float16.NaN == Float16.NaN);
+
+            // IEqualityComparable returns false, because it tests
+            // objects, not numbers.
+            Assert.NotEqual(Float16.NaN, Float16.MaxValue);
+
+            Assert.False(Float16.NaN < Float16.MaxValue);
+            Assert.False(Float16.MaxValue < Float16.NaN);
+            Assert.False(Float16.NaN == Float16.MaxValue);
+            Assert.False(Float16.MaxValue == Float16.NaN);
+            Assert.False(Float16.NaN > Float16.MinValue);
+            Assert.False(Float16.MaxValue > Float16.NaN);
+            Assert.False(Float16.NaN == Float16.MinValue);
+            Assert.False(Float16.MaxValue == Float16.NaN);
+            Assert.True(Float16.MinValue < Float16.MaxValue);
+        }
+
+        [Fact(DisplayName = "TestInfinity")]
+        public void TestInfinity()
+        {
+            Assert.False(Float16.IsInfinity(Float16.MinValue));
+            Assert.False(Float16.IsInfinity(Float16.MaxValue));
+
+            Float16 posInfinityFromSingle = (Float16)float.PositiveInfinity;
+            Assert.True(Float16.IsPositiveInfinity(posInfinityFromSingle));
+            Assert.Equal(Float16.PositiveInfinity, posInfinityFromSingle);
+            Assert.False(Float16.IsFinite(posInfinityFromSingle));
+            Assert.True(Float16.IsInfinity(posInfinityFromSingle));
+            Assert.True(Float16.IsPositiveInfinity(posInfinityFromSingle));
+            Assert.False(Float16.IsNegativeInfinity(posInfinityFromSingle));
+
+            Assert.False(Float16.IsPositiveInfinity(Float16.MinValue));
+            Assert.False(Float16.IsPositiveInfinity(Float16.MaxValue));
+
+            Assert.Equal(float.PositiveInfinity < 0, Float16.IsNegative(posInfinityFromSingle));
+
+            Float16 negInfinityFromSingle = (Float16)float.NegativeInfinity;
+            Assert.True(Float16.IsNegativeInfinity(negInfinityFromSingle));
+            Assert.Equal(Float16.NegativeInfinity, negInfinityFromSingle);
+            Assert.False(Float16.IsFinite(negInfinityFromSingle));
+            Assert.True(Float16.IsInfinity(negInfinityFromSingle));
+            Assert.True(Float16.IsNegativeInfinity(negInfinityFromSingle));
+            Assert.False(Float16.IsPositiveInfinity(negInfinityFromSingle));
+
+            Assert.False(Float16.IsNegativeInfinity(Float16.MinValue));
+            Assert.False(Float16.IsNegativeInfinity(Float16.MaxValue));
+
+            Assert.Equal(float.NegativeInfinity < 0, Float16.IsNegative(negInfinityFromSingle));
+
+            // Convert infinity to float and test the fact
+            float infFromFloat16 = (float)Float16.PositiveInfinity;
+            Assert.True(float.IsInfinity(infFromFloat16));
+            Assert.True(float.IsPositiveInfinity(infFromFloat16));
+        }
+
+        [Fact(DisplayName = "TestNormalSubnormal")]
+        public void TestNormalSubnormal()
+        {
+            Float16 fp16FromSingleMaxValue = (Float16)float.MaxValue;
+
+            // Float MaxValue is outside Float16 range. This is different
+            // from BFloat16 that retains sufficient range.
+            Assert.True(Float16.IsInfinity(fp16FromSingleMaxValue));
+            Assert.False(Float16.IsNormal(fp16FromSingleMaxValue));
+
+            Assert.False(Float16.IsNormal(Float16.PositiveInfinity));
+            Assert.True(Float16.IsNormal((Float16)45.6f));
+            Assert.False(Float16.IsSubnormal((Float16)45.6f));
+
+            Assert.False(Float16.IsSubnormal(fp16FromSingleMaxValue));
+            Assert.False(Float16.IsSubnormal(Float16.PositiveInfinity));
+
+            // 0b0_00000_0000000001 => 5.9604645E-08
+            const ushort minSubnormalBits = 0x0001;
+            const float smallestF16Subnormal = 5.9604645E-08f;
+            Float16 smallestSubnormal = new Float16(minSubnormalBits);
+            Assert.True(Float16.IsSubnormal(smallestSubnormal));
+            Assert.False(Float16.IsNormal(smallestSubnormal));
+
+            // 0b0_00000_1111111111 => 6.09755516E-05
+            const float largestF16Subnormal = 6.09755516E-05f;
+            const ushort maxSubnormalBits = 0x03FF;
+            Float16 largestSubnormal = new Float16(maxSubnormalBits);
+            Assert.True(Float16.IsSubnormal(largestSubnormal));
+            Assert.False(Float16.IsNormal(largestSubnormal));
+
+            // Convert subnormal to float and see if we match
+            float convertedFromSmallestSubnormal = (float)smallestSubnormal;
+            Assert.Equal(smallestF16Subnormal, convertedFromSmallestSubnormal, 6);
+
+            float convertedFromLargestSubnormal = (float)largestSubnormal;
+            Assert.Equal(largestF16Subnormal, convertedFromLargestSubnormal, 6);
+        }
+
+        [Fact(DisplayName = "TestEqual")]
+        public void TestEqual()
+        {
+            // Box it
+            object obj_1 = Float16.MaxValue;
+            object obj_2 = new Float16(Float16.MaxValue.value);
+            Assert.True(obj_1.Equals(obj_2));
+
+            Assert.NotEqual(0, obj_1.GetHashCode());
+            Assert.Equal(obj_1.GetHashCode(), obj_2.GetHashCode());
+            Assert.True(Float16.NaN.Equals(Float16.NaN));
+
+            Float16 fp16Zero = (Float16)0.0f;
+            const ushort ushortZero = 0;
+            Float16 fp16FromUshortZero = (Float16)ushortZero;
+
+            Assert.True(fp16Zero.Equals(fp16FromUshortZero));
+
+            // Should have the same hash code constant
+            Assert.Equal(fp16Zero.GetHashCode(), fp16FromUshortZero.GetHashCode());
+            Assert.Equal(Float16.NaN.GetHashCode(), Float16.NaN.GetHashCode());
+        }
+
+        [Fact(DisplayName = "TestCompare")]
+        public void TestCompare()
+        {
+            object objMaxValue = new Float16(Float16.MaxValue.value);
+            Assert.Equal(0, Float16.MaxValue.CompareTo(objMaxValue));
+
+            Float16 one = (Float16)1.0f;
+            Assert.Equal(-1, Float16.MinValue.CompareTo(one));
+            Assert.Equal(1, Float16.MaxValue.CompareTo(one));
+
+            // one is bigger than NaN
+            Assert.Equal(-1, Float16.NaN.CompareTo(one));
+            // Two NaNs are equal according to CompareTo()
+            Assert.Equal(0, Float16.NaN.CompareTo((Float16)float.NaN));
+            Assert.Equal(1, one.CompareTo(Float16.NaN));
+
+            // Compare to null
+            Assert.Equal(1, one.CompareTo(null));
+
+            // Make sure it throws
+            var obj = new object();
+            Assert.Throws<ArgumentException>(() => one.CompareTo(obj));
+        }
+    }
+
+    [Collection("Ort BFloat16 tests")]
+    public class OrtBFloat16Tests
+    {
+        const float oneThird = 1 / 3.0f;
+        const float oneSeventh = 1 / 7.0f;
+        const float oneTenth = 1 / 10.0f;
+
+        [Fact(DisplayName = "ConvertFloatToBFloat16")]
+        public void ConvertFloatToBFloat16()
+        {
+            // Generate integer floats and insert between them
+            // fractions. This will test the rounding logic.
+            float start = -10;
+
+            var floatValues = new float[21 * 4];
+            for (int i = 0; i < floatValues.Length; i += 4)
+            {
+                floatValues[i] = start;
+                floatValues[i + 1] = start + oneThird;
+                floatValues[i + 2] = start + oneSeventh;
+                floatValues[i + 3] = start + oneTenth;
+                start += 1;
+            }
+
+            var f16Converted = Array.ConvertAll(floatValues, f => (BFloat16)f);
+            var backConverted = Array.ConvertAll(f16Converted, f16 => (float)f16);
+            Assert.Equal(floatValues, backConverted, new FloatComparer());
+        }
+
+        [Fact(DisplayName = "TestZeros")]
+        public void TestZeros()
+        {
+            var positiveZero = new BFloat16(0);
+            Assert.False(BFloat16.IsNegative(positiveZero));
+            Assert.True(BFloat16.IsNaNOrZero(positiveZero));
+            float singlePositiveZero = (float)positiveZero;
+            Assert.Equal(+0.0f, singlePositiveZero);
+#if NET6_0_OR_GREATER
+            Assert.False(float.IsNegative(singlePositiveZero));
+#endif
+
+            var negativeZero = BFloat16.Negate(positiveZero);
+            Assert.True(BFloat16.IsNegative(negativeZero));
+            Assert.True(BFloat16.IsNaNOrZero(negativeZero));
+
+            float singleNegativeZero = (float)negativeZero;
+            Assert.Equal(-0.0f, singleNegativeZero);
+#if NET6_0_OR_GREATER
+            Assert.True(float.IsNegative(singleNegativeZero));
+#endif
+        }
+
+        [Fact(DisplayName = "TestComparisonOperators")]
+        public void TestComparisionOperators()
+        {
+            BFloat16 left = (BFloat16)(float)-33.33f;
+            BFloat16 leftSame = (BFloat16)(float)-33.33f;
+            BFloat16 right = (BFloat16)(float)66.66f;
+            BFloat16 rightSame = (BFloat16)(float)66.66f;
+
+            Assert.False(BFloat16.IsNaNOrZero(left));
+            Assert.False(BFloat16.IsNaNOrZero(right));
+
+            Assert.True(right > BFloat16.Epsilon);
+
+            Assert.True(left == leftSame);
+            Assert.False(left == BFloat16.Negate(leftSame));
+
+            Assert.True(right == rightSame);
+            Assert.False(right == BFloat16.Negate(rightSame));
+
+            Assert.True(left < right);
+            Assert.True(left > BFloat16.Negate(right));
+            Assert.True(BFloat16.Negate(left) < right);
+
+            Assert.True(left <= right);
+            Assert.True(left >= BFloat16.Negate(right));
+            Assert.False(left > right);
+            Assert.False(left >= right);
+            Assert.True(BFloat16.Negate(left) <= right);
+            Assert.False(left == right);
+            Assert.False(right == left);
+            Assert.True(left != right);
+            Assert.True(right != left);
+        }
+
+        [Fact(DisplayName = "TestNAN")]
+        public void TestNAN()
+        {
+            BFloat16 fp16NANFromSingle = (BFloat16)float.NaN;
+            Assert.True(BFloat16.IsNaN(fp16NANFromSingle));
+            Assert.Equal(BFloat16.NaN, fp16NANFromSingle);
+            Assert.True(BFloat16.IsNaNOrZero(fp16NANFromSingle));
+
+            float NanFromBFloat16 = fp16NANFromSingle.ToFloat();
+            Assert.True(float.IsNaN(NanFromBFloat16));
+
+            // IEqualityComparable returns true, because it tests
+            // objects, not numbers.
+            Assert.Equal(fp16NANFromSingle, BFloat16.NaN);
+            Assert.Equal(BFloat16.NaN, BFloat16.Negate(BFloat16.NaN));
+
+            Assert.False(BFloat16.IsNaN(BFloat16.MaxValue));
+        }
+
+        [Fact(DisplayName = "TestNANComparision")]
+        public void TestNANComparisionOperators()
+        {
+            // NaN is not ordered with respect to anything
+            // including itself
+
+            // IEqualityComparable returns true, because it tests
+            // objects, not numbers.
+            Assert.Equal(BFloat16.NaN, BFloat16.NaN);
+            Assert.False(BFloat16.NaN < BFloat16.NaN);
+            Assert.False(BFloat16.NaN > BFloat16.NaN);
+            Assert.False(BFloat16.NaN <= BFloat16.NaN);
+            Assert.False(BFloat16.NaN >= BFloat16.NaN);
+            Assert.False(BFloat16.NaN == BFloat16.NaN);
+
+            // IEqualityComparable returns false, because it tests
+            // objects, not numbers.
+            Assert.NotEqual(BFloat16.NaN, BFloat16.MaxValue);
+
+            Assert.False(BFloat16.NaN < BFloat16.MaxValue);
+            Assert.False(BFloat16.MaxValue < BFloat16.NaN);
+            Assert.False(BFloat16.NaN == BFloat16.MaxValue);
+            Assert.False(BFloat16.MaxValue == BFloat16.NaN);
+            Assert.False(BFloat16.NaN > BFloat16.MinValue);
+            Assert.False(BFloat16.MaxValue > BFloat16.NaN);
+            Assert.False(BFloat16.NaN == BFloat16.MinValue);
+            Assert.False(BFloat16.MaxValue == BFloat16.NaN);
+            Assert.True(BFloat16.MinValue < BFloat16.MaxValue);
+        }
+
+        [Fact(DisplayName = "TestInfinity")]
+        public void TestInfinity()
+        {
+            Assert.False(BFloat16.IsInfinity(BFloat16.MinValue));
+            Assert.False(BFloat16.IsInfinity(BFloat16.MaxValue));
+
+            BFloat16 posInfinityFromSingle = (BFloat16)float.PositiveInfinity;
+            Assert.True(BFloat16.IsPositiveInfinity(posInfinityFromSingle));
+            Assert.Equal(BFloat16.PositiveInfinity, posInfinityFromSingle);
+            Assert.False(BFloat16.IsFinite(posInfinityFromSingle));
+            Assert.True(BFloat16.IsInfinity(posInfinityFromSingle));
+            Assert.True(BFloat16.IsPositiveInfinity(posInfinityFromSingle));
+            Assert.False(BFloat16.IsNegativeInfinity(posInfinityFromSingle));
+
+            Assert.False(BFloat16.IsPositiveInfinity(BFloat16.MinValue));
+            Assert.False(BFloat16.IsPositiveInfinity(BFloat16.MaxValue));
+
+            Assert.Equal(float.PositiveInfinity < 0, BFloat16.IsNegative(posInfinityFromSingle));
+
+            BFloat16 negInfinityFromSingle = (BFloat16)float.NegativeInfinity;
+            Assert.True(BFloat16.IsNegativeInfinity(negInfinityFromSingle));
+            Assert.Equal(BFloat16.NegativeInfinity, negInfinityFromSingle);
+            Assert.False(BFloat16.IsFinite(negInfinityFromSingle));
+            Assert.True(BFloat16.IsInfinity(negInfinityFromSingle));
+            Assert.True(BFloat16.IsNegativeInfinity(negInfinityFromSingle));
+            Assert.False(BFloat16.IsPositiveInfinity(negInfinityFromSingle));
+
+            Assert.False(BFloat16.IsNegativeInfinity(BFloat16.MinValue));
+            Assert.False(BFloat16.IsNegativeInfinity(BFloat16.MaxValue));
+
+            Assert.True(BFloat16.IsNegative(negInfinityFromSingle));
+
+            // Convert infinity to float and test the fact
+            float infFromBFloat16 = (float)BFloat16.PositiveInfinity;
+            Assert.True(float.IsInfinity(infFromBFloat16));
+            Assert.True(float.IsPositiveInfinity(infFromBFloat16));
+        }
+
+        [Fact(DisplayName = "TestNormalSubnormal")]
+        public void TestNormalSubnormal()
+        {
+            BFloat16 fp16FromSingleMaxValue = (BFloat16)float.MaxValue;
+
+            Assert.True(BFloat16.IsInfinity(fp16FromSingleMaxValue));
+            Assert.False(BFloat16.IsNormal(fp16FromSingleMaxValue));
+
+            Assert.False(BFloat16.IsNormal(BFloat16.PositiveInfinity));
+            Assert.True(BFloat16.IsNormal((BFloat16)45.6f));
+            Assert.False(BFloat16.IsSubnormal((BFloat16)45.6f));
+
+            Assert.False(BFloat16.IsSubnormal(fp16FromSingleMaxValue));
+            Assert.False(BFloat16.IsSubnormal(BFloat16.PositiveInfinity));
+
+            // 0b0_0000_0000_000_0001
+            const ushort minSubnormalBits = 0x0001;
+            BFloat16 smallestSubnormal = new BFloat16(minSubnormalBits);
+            Assert.True(BFloat16.IsSubnormal(smallestSubnormal));
+            Assert.False(BFloat16.IsNormal(smallestSubnormal));
+#if NET6_0_OR_GREATER
+            float singleSmallestSubnormal = (float)smallestSubnormal;
+            Assert.True(float.IsSubnormal(singleSmallestSubnormal));
+#endif
+
+            const ushort maxSubnormalBits = 0x007F; // 0b0_0000_0000_111_1111
+            BFloat16 largestSubnormal = new BFloat16(maxSubnormalBits);
+            Assert.True(BFloat16.IsSubnormal(largestSubnormal));
+            Assert.False(BFloat16.IsNormal(largestSubnormal));
+#if NET6_0_OR_GREATER
+            float singleLargestSubnormal = (float)largestSubnormal;
+            Assert.True(float.IsSubnormal(singleLargestSubnormal));
+#endif
+        }
+
+        [Fact(DisplayName = "TestEqual")]
+        public void TestEqual()
+        {
+            // Box it
+            object obj_1 = BFloat16.MaxValue;
+            object obj_2 = new BFloat16(BFloat16.MaxValue.value);
+            Assert.True(obj_1.Equals(obj_2));
+
+            Assert.NotEqual(0, obj_1.GetHashCode());
+            Assert.Equal(obj_1.GetHashCode(), obj_2.GetHashCode());
+            Assert.True(BFloat16.NaN.Equals(BFloat16.NaN));
+
+            BFloat16 fp16Zero = (BFloat16)0.0f;
+            const ushort ushortZero = 0;
+            BFloat16 fp16FromUshortZero = (BFloat16)ushortZero;
+
+            Assert.True(fp16Zero.Equals(fp16FromUshortZero));
+
+            // Should have the same hash code constant
+            Assert.Equal(fp16Zero.GetHashCode(), fp16FromUshortZero.GetHashCode());
+            Assert.Equal(BFloat16.NaN.GetHashCode(), BFloat16.NaN.GetHashCode());
+        }
+
+        [Fact(DisplayName = "TestCompare")]
+        public void TestCompare()
+        {
+            object objMaxValue = new BFloat16(BFloat16.MaxValue.value);
+            Assert.Equal(0, BFloat16.MaxValue.CompareTo(objMaxValue));
+
+            BFloat16 one = (BFloat16)1.0f;
+            Assert.Equal(-1, BFloat16.MinValue.CompareTo(one));
+            Assert.Equal(1, BFloat16.MaxValue.CompareTo(one));
+
+            // one is bigger than NaN
+            Assert.Equal(-1, BFloat16.NaN.CompareTo(one));
+            // Two NaNs are equal according to CompareTo()
+            Assert.Equal(0, BFloat16.NaN.CompareTo((BFloat16)float.NaN));
+            Assert.Equal(1, one.CompareTo(BFloat16.NaN));
+
+            // Compare to null
+            Assert.Equal(1, one.CompareTo(null));
+
+            // Make sure it throws
+            var obj = new object();
+            Assert.Throws<ArgumentException>(() => one.CompareTo(obj));
+        }
+    }
+}
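The subnormal constants used in TestNormalSubnormal follow from the fp16 encoding (when the exponent field is zero, value = significand/2^10 · 2^-14); as a check:

```latex
\text{smallest} = \tfrac{1}{2^{10}} \cdot 2^{-14} = 2^{-24} \approx 5.9604645\times 10^{-8},
\qquad
\text{largest} = \tfrac{1023}{2^{10}} \cdot 2^{-14} \approx 6.09755516\times 10^{-5}.
```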
@@ -242,13 +242,13 @@ namespace Microsoft.ML.OnnxRuntime.Tests
         int[] int_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
         ushort[] ushort_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
         double[] dbl_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
-        Float16[] fl16_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
+        var fp16_data = Array.ConvertAll(ushort_data, sh => new Float16(sh));

        PopulateAndCheck(float_data);
        PopulateAndCheck(int_data);
        PopulateAndCheck(ushort_data);
        PopulateAndCheck(dbl_data);
-       PopulateAndCheck(fl16_data);
+       PopulateAndCheck(fp16_data);
     }

     private static readonly long[] ml_data_1 = { 1, 2 };
@@ -619,4 +619,78 @@ namespace Microsoft.ML.OnnxRuntime.Tests
             return tensorData.ToArray();
         }
     }
+
+    internal class FloatComparer : IEqualityComparer<float>
+    {
+        private float atol = 1e-3f;
+        private float rtol = 1.7e-2f;
+
+        public bool Equals(float x, float y)
+        {
+            return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
+        }
+        public int GetHashCode(float x)
+        {
+            return x.GetHashCode();
+        }
+    }
+
+    internal class DoubleComparer : IEqualityComparer<double>
+    {
+        private double atol = 1e-3;
+        private double rtol = 1.7e-2;
+
+        public bool Equals(double x, double y)
+        {
+            return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
+        }
+        public int GetHashCode(double x)
+        {
+            return x.GetHashCode();
+        }
+    }
+
+    class ExactComparer<T> : IEqualityComparer<T>
+    {
+        public bool Equals(T x, T y)
+        {
+            return x.Equals(y);
+        }
+        public int GetHashCode(T x)
+        {
+            return x.GetHashCode();
+        }
+    }
+
+    /// <summary>
+    /// Use it to compare Float16
+    /// </summary>
+    internal class Float16Comparer : IEqualityComparer<Float16>
+    {
+        public ushort tolerance = 0;
+        public bool Equals(Float16 x, Float16 y)
+        {
+            return Math.Abs(x.value - y.value) <= (tolerance + y.value);
+        }
+        public int GetHashCode(Float16 x)
+        {
+            return x.GetHashCode();
+        }
+    }
+
+    /// <summary>
+    /// Use it to compare BFloat16
+    /// </summary>
+    internal class BFloat16Comparer : IEqualityComparer<BFloat16>
+    {
+        public ushort tolerance = 0;
+        public bool Equals(BFloat16 x, BFloat16 y)
+        {
+            return Math.Abs(x.value - y.value) <= (tolerance + y.value);
+        }
+        public int GetHashCode(BFloat16 x)
+        {
+            return x.GetHashCode();
+        }
+    }
 }
@@ -119,6 +119,9 @@

  <!-- NOTE: The xUnit framework doesn't pickup the tests defined within the referenced Microsoft.ML.OnnxRuntime.Tests.Common project -->
  <ItemGroup>
+   <Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OrtFloat16Tests.cs">
+     <Link>OrtFloat16Tests.cs</Link>
+   </Compile>
    <Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OrtEnvTests.cs">
      <Link>OrtEnvTests.cs</Link>
    </Compile>
@@ -2,6 +2,8 @@
 // Licensed under the MIT License.
 #pragma once

+#include <math.h>
+
 #include "endian.h"
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 #include "cuda_bf16.h"

@@ -13,6 +15,8 @@

 #include "core/common/common.h"

+#include "core/session/onnxruntime_float16.h"
+
 namespace onnxruntime {

 #if defined(__CUDACC__) || defined(__HIPCC__)
@@ -22,25 +26,69 @@ namespace onnxruntime {
 #endif

 // MLFloat16
-struct MLFloat16 {
-  uint16_t val{0};
+struct MLFloat16 : onnxruntime_float16::Float16Impl<MLFloat16> {
+ private:
+  explicit constexpr MLFloat16(uint16_t x) noexcept { val = x; }
+
+ public:
+  using Base = onnxruntime_float16::Float16Impl<MLFloat16>;

   MLFloat16() = default;
-  explicit constexpr MLFloat16(uint16_t x) : val(x) {}
-  explicit MLFloat16(float f);

-  float ToFloat() const;
+  constexpr static MLFloat16 FromBits(uint16_t x) noexcept { return MLFloat16(x); }

-  operator float() const { return ToFloat(); }
+  // Using inherited implementation instead of math floatToHalf allows us to use this
+  // in other shared providers without having to implement the bridge
+  explicit MLFloat16(float v) noexcept { val = Base::ToUint16Impl(v); }
+
+  static const MLFloat16 NaN;
+  static const MLFloat16 NegativeNaN;
+  static const MLFloat16 Infinity;
+  static const MLFloat16 NegativeInfinity;
+  static const MLFloat16 Epsilon;
+  static const MLFloat16 MinValue;
+  static const MLFloat16 MaxValue;
+  static const MLFloat16 Zero;
+  static const MLFloat16 One;
+  static const MLFloat16 MinusOne;
+
+  // Using inherited implementation instead of math halfToFloat allows us to use this
+  // in other shared providers without having to implement the bridge
+  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
+
+  using Base::IsNegative;
+
+  using Base::IsNaN;
+
+  using Base::IsFinite;
+
+  using Base::IsPositiveInfinity;
+
+  using Base::IsNegativeInfinity;
+
+  using Base::IsInfinity;
+
+  using Base::IsNaNOrZero;
+
+  using Base::IsNormal;
+
+  using Base::IsSubnormal;
+
+  using Base::Abs;
+
+  using Base::Negate;
+
+  operator float() const noexcept { return ToFloat(); }
+
+  using Base::operator==;
+  using Base::operator!=;
+  using Base::operator<;
 };

-inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
-inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }
-inline bool operator<(const MLFloat16& left, const MLFloat16& right) { return left.val < right.val; }
-
 // BFloat16
-struct BFloat16 {
-  uint16_t val{0};
+struct BFloat16 : onnxruntime_float16::BFloat16Impl<BFloat16> {
+  using Base = onnxruntime_float16::BFloat16Impl<BFloat16>;

 #if defined(__HIP__)
   ORT_HOST_DEVICE BFloat16() = default;
 #else
@@ -48,10 +96,14 @@ struct BFloat16 {
 #endif

   struct FromBitsT {};
-  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
-  constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits) {}
+  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() noexcept { return FromBitsT(); }
+  constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) noexcept { val = bits; }

-  inline ORT_HOST_DEVICE BFloat16(float v) {
+  static constexpr ORT_HOST_DEVICE BFloat16 FromBits(uint16_t bits) noexcept {
+    return BFloat16(bits, FromBits());
+  }
+
+  inline ORT_HOST_DEVICE BFloat16(float v) noexcept {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   val = __bfloat16_as_ushort(__float2bfloat16(v));
 #elif defined(__HIP__)
@@ -69,15 +121,34 @@ struct BFloat16 {
     val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
   }
 #else
-    if constexpr (endian::native == endian::little) {
-      std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
+    // Use C isnan to work both in host and device
+    if (::isnan(v)) {
+      val = kPositiveQNaNBits;
     } else {
-      std::memcpy(&val, &v, sizeof(uint16_t));
+      auto get_msb_half = [](float fl) {
+        uint16_t result;
+        if constexpr (onnxruntime_float16::detail::endian::native == onnxruntime_float16::detail::endian::little) {
+          std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
+        } else {
+          std::memcpy(&result, &fl, sizeof(uint16_t));
+        }
+        return result;
+      };
+
+      uint16_t upper_bits = get_msb_half(v);
+      union {
+        uint32_t U32;
+        float F32;
+      };
+      F32 = v;
+      U32 += (upper_bits & 1) + kRoundToNearest;
+      val = get_msb_half(F32);
     }
 #endif
   }

-  inline ORT_HOST_DEVICE float ToFloat() const {
+  inline ORT_HOST_DEVICE float ToFloat() const noexcept {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
     return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
 #elif defined(__HIP__)
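The `#else` branch above implements round-to-nearest-even on the upper 16 bits of the float. A C# sketch of the same rounding (the NaN bit pattern `0x7FC1` and the `kRoundToNearest` value of `0x7FFF` are assumptions about the new header, not confirmed by this diff):

```csharp
// Convert float bits to bfloat16 with round-to-nearest-even.
static ushort FloatToBFloat16Bits(float v)
{
    if (float.IsNaN(v)) return 0x7FC1;             // assumed positive quiet-NaN pattern
    uint u32 = BitConverter.SingleToUInt32Bits(v); // .NET 6+
    uint keep = u32 >> 16;                         // the 16 bits bfloat16 retains
    // Adding 0x7FFF plus the LSB of the kept half rounds ties to even.
    u32 += (keep & 1) + 0x7FFF;
    return (ushort)(u32 >> 16);
}
```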
@ -89,55 +160,129 @@ struct BFloat16 {
|
||||||
result = *tempRes;
|
result = *tempRes;
|
||||||
return result;
|
return result;
|
||||||
#else
|
#else
|
||||||
float result;
|
|
||||||
|
if (IsNaNHostDevice()) {
|
||||||
|
return std::numeric_limits<float>::quiet_NaN();
|
||||||
|
}
|
||||||
|
|
||||||
|
float result = 0;
|
||||||
char* const first = reinterpret_cast<char*>(&result);
|
char* const first = reinterpret_cast<char*>(&result);
|
||||||
char* const second = first + sizeof(uint16_t);
|
|
||||||
if constexpr (endian::native == endian::little) {
|
if constexpr (endian::native == endian::little) {
|
||||||
std::memset(first, 0, sizeof(uint16_t));
|
char* const second = first + sizeof(uint16_t);
|
||||||
std::memcpy(second, &val, sizeof(uint16_t));
|
std::memcpy(second, &val, sizeof(uint16_t));
|
||||||
} else {
|
} else {
|
||||||
std::memcpy(first, &val, sizeof(uint16_t));
|
std::memcpy(first, &val, sizeof(uint16_t));
|
||||||
std::memset(second, 0, sizeof(uint16_t));
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
-  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
+  static const BFloat16 NaN;
+  static const BFloat16 NegativeNaN;
+  static const BFloat16 Infinity;
+  static const BFloat16 NegativeInfinity;
+  static const BFloat16 Epsilon;
+  static const BFloat16 MinValue;
+  static const BFloat16 MaxValue;
+  static const BFloat16 Zero;
+  static const BFloat16 One;
+  static const BFloat16 MinusOne;
+
+  using Base::IsNegative;
+  using Base::IsNaN;
+  using Base::IsFinite;
+  using Base::IsPositiveInfinity;
+  using Base::IsNegativeInfinity;
+  using Base::IsInfinity;
+  using Base::IsNaNOrZero;
+  using Base::IsNormal;
+  using Base::IsSubnormal;
+  using Base::Abs;
+  using Base::Negate;
+
+  ORT_HOST_DEVICE operator float() const noexcept { return ToFloat(); }

 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
   ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
   explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
 #endif
-};

-inline ORT_HOST_DEVICE bool operator==(const BFloat16& left, const BFloat16& right) { return left.val == right.val; }
-inline ORT_HOST_DEVICE bool operator!=(const BFloat16& left, const BFloat16& right) { return left.val != right.val; }
-inline ORT_HOST_DEVICE bool operator<(const BFloat16& left, const BFloat16& right) { return left.val < right.val; }
+  ORT_HOST_DEVICE bool operator==(const BFloat16& rhs) const noexcept {
+    if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) {
+      // IEEE defines that NaN is not equal to anything, including itself.
+      return false;
+    }
+    return val == rhs.val;
+  }
+
+  ORT_HOST_DEVICE bool operator!=(const BFloat16& rhs) const noexcept {
+    return !(*this == rhs);
+  }
+
+  ORT_HOST_DEVICE bool operator<(const BFloat16& rhs) const noexcept {
+    if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) {
+      // IEEE defines that NaN is unordered with respect to everything, including itself.
+      return false;
+    }
+
+    const bool left_is_negative = IsNegativeHostDevice();
+    if (left_is_negative != rhs.IsNegativeHostDevice()) {
+      // When the signs of left and right differ, we know that left is less than right if it is
+      // the negative value. The exception to this is if both values are zero, in which case IEEE
+      // says they should be equal, even if the signs differ.
+      return left_is_negative && !AreZeroHostDevice(*this, rhs);
+    }
+    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
+  }
+
+  ORT_HOST_DEVICE bool IsNegativeHostDevice() const noexcept {
+    return (val & kSignMask) != 0;
+  }
+
+  ORT_HOST_DEVICE bool IsNaNHostDevice() const noexcept {
+    return static_cast<uint16_t>(val & ~kSignMask) > kPositiveInfinityBits;
+  }
+
+  ORT_HOST_DEVICE static bool AreZeroHostDevice(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
+    // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+    // for two values by or'ing the private bits together and stripping the sign. They are both zero,
+    // and therefore equivalent, if the resulting value is still zero.
+    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
+  }
+};

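With the member operators in place, BFloat16 follows IEEE ordering rules rather than raw bit order. A usage sketch (assumes this header plus the constant definitions added later in this diff):

```cpp
#include <cassert>

void CheckBFloat16Ordering() {
  using onnxruntime::BFloat16;
  assert(BFloat16::NaN != BFloat16::NaN);      // NaN never equals anything
  assert(!(BFloat16::NaN < BFloat16::One));    // NaN is unordered
  assert(BFloat16::MinusOne < BFloat16::One);  // ordering works across signs
  // Negative zero is not "less than" positive zero: AreZeroHostDevice makes
  // the sign-difference shortcut treat the pair as equal.
  assert(!(BFloat16::Zero.Negate() < BFloat16::Zero));
}
```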
 // User defined suffixes to make it easier to declare
 // initializers with MLFloat16 and BFloat16 from unsigned short
 // E.g 10_f16 or 10_b16
 #if !defined(__CUDACC__) && !defined(__HIPCC__)
-inline MLFloat16 operator"" _f16(unsigned long long int v) {
-  return MLFloat16(narrow<uint16_t>(v));
+inline MLFloat16 operator"" _f16(unsigned long long int v) noexcept {
+  return MLFloat16::FromBits(narrow<uint16_t>(v));
 }

-inline MLFloat16 operator"" _fp16(long double v) {
+inline MLFloat16 operator"" _fp16(long double v) noexcept {
   return MLFloat16(static_cast<float>(v));
 }

-inline BFloat16 operator"" _b16(unsigned long long int v) {
-  return BFloat16(narrow<uint16_t>(v), BFloat16::FromBits());
+inline BFloat16 operator"" _b16(unsigned long long int v) noexcept {
+  return BFloat16::FromBits((narrow<uint16_t>(v)));
 }

-inline BFloat16 operator"" _bfp16(long double v) {
+inline BFloat16 operator"" _bfp16(long double v) noexcept {
   return BFloat16(static_cast<float>(v));
 }

 #endif

-inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
+inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) noexcept {
   auto src = blf;
   auto d = flt;
   for (; size != 0; ++src, ++d, --size) {
@@ -149,7 +294,7 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
   auto src = flt;
   auto d = blf;
   for (; size != 0; ++src, ++d, --size) {
-    new (d) BFloat16(*src);
+    *d = BFloat16(*src);
   }
 }

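Usage sketch for the literal suffixes (host-only code; assumes the operators are visible, e.g. from within namespace onnxruntime):

```cpp
void LiteralSuffixes() {
  using namespace onnxruntime;
  MLFloat16 one_fp16 = 0x3C00_f16;   // raw IEEE fp16 bits for 1.0
  MLFloat16 pi_fp16 = 3.14159_fp16;  // converted from a floating literal
  BFloat16 one_bf16 = 0x3F80_b16;    // raw bfloat16 bits for 1.0
  BFloat16 e_bf16 = 2.71828_bfp16;   // converted from a floating literal
  (void)one_fp16; (void)pi_fp16; (void)one_bf16; (void)e_bf16;
}
```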
@@ -24,6 +24,8 @@
 #pragma once
 #include "onnxruntime_c_api.h"
+#include "onnxruntime_float16.h"

 #include <cstddef>
 #include <cstdio>
 #include <array>
@@ -142,70 +144,286 @@ std::string GetBuildInfoString();
 std::vector<std::string> GetAvailableProviders();

 /** \brief IEEE 754 half-precision floating point data type
- * \details It is necessary for type dispatching to make use of C++ API
- * The type is implicitly convertible to/from uint16_t.
+ *
+ * \details This struct is used for converting float to float16 and back
+ * so the user could feed inputs and fetch outputs using these types.
+ *
  * The size of the structure should align with uint16_t and one can freely cast
  * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
  *
- * Generally, you can feed any of your types as float16/blfoat16 data to create a tensor
- * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
- * And you can also feed a array of uint16_t elements directly. For example,
- *
  * \code{.unparsed}
- * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
- * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
- * std::vector<int64_t> dims = {values_length};  // one dimensional example
- * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
- * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
- * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
- *                                                dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
+ * // This example demonstrates conversion from float to float16
+ * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
+ * std::vector<Ort::Float16_t> fp16_values;
+ * fp16_values.reserve(std::size(values));
+ * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
+ *                [](float value) { return Ort::Float16_t(value); });
+ *
  * \endcode
- *
- * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
- * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
- * template specialization.
- *
- * \code{.unparsed}
- * namespace yours { struct half {}; }  // assume this is your type, define this:
- * namespace Ort {
- * template<>
- * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
- * }  //namespace Ort
- *
- * std::vector<yours::half> values;
- * std::vector<int64_t> dims = {values.size()};  // one dimensional example
- * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
- * // Here we are passing element count -> values.size()
- * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
- *
- * \endcode
 */
-struct Float16_t {
-  uint16_t value;
-  constexpr Float16_t() noexcept : value(0) {}
-  constexpr Float16_t(uint16_t v) noexcept : value(v) {}
-  constexpr operator uint16_t() const noexcept { return value; }
-  constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
-  constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
+struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
+ private:
+  /// <summary>
+  /// Constructor from a 16-bit representation of a float16 value
+  /// No conversion is done here.
+  /// </summary>
+  /// <param name="v">16-bit representation</param>
+  constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
+
+ public:
+  using Base = onnxruntime_float16::Float16Impl<Float16_t>;
+
+  /// <summary>
+  /// Default constructor
+  /// </summary>
+  Float16_t() = default;
+
+  /// <summary>
+  /// Explicit conversion to uint16_t representation of float16.
+  /// </summary>
+  /// <param name="v">uint16_t bit representation of float16</param>
+  /// <returns>new instance of Float16_t</returns>
+  constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
+
+  /// <summary>
+  /// __ctor from float. Float is converted into float16 16-bit representation.
+  /// </summary>
+  /// <param name="v">float value</param>
+  explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
+
+  /// <summary>
+  /// Converts float16 to float
+  /// </summary>
+  /// <returns>float representation of float16 value</returns>
+  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
+
+  /// <summary>
+  /// Checks if the value is negative
+  /// </summary>
+  /// <returns>true if negative</returns>
+  using Base::IsNegative;
+
+  /// <summary>
+  /// Tests if the value is NaN
+  /// </summary>
+  /// <returns>true if NaN</returns>
+  using Base::IsNaN;
+
+  /// <summary>
+  /// Tests if the value is finite
+  /// </summary>
+  /// <returns>true if finite</returns>
+  using Base::IsFinite;
+
+  /// <summary>
+  /// Tests if the value represents positive infinity.
+  /// </summary>
+  /// <returns>true if positive infinity</returns>
+  using Base::IsPositiveInfinity;
+
+  /// <summary>
+  /// Tests if the value represents negative infinity
+  /// </summary>
+  /// <returns>true if negative infinity</returns>
+  using Base::IsNegativeInfinity;
+
+  /// <summary>
+  /// Tests if the value is either positive or negative infinity.
+  /// </summary>
+  /// <returns>True if absolute value is infinity</returns>
+  using Base::IsInfinity;
+
+  /// <summary>
+  /// Tests if the value is NaN or zero. Useful for comparisons.
+  /// </summary>
+  /// <returns>True if NaN or zero.</returns>
+  using Base::IsNaNOrZero;
+
+  /// <summary>
+  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
+  /// </summary>
+  /// <returns>True if so</returns>
+  using Base::IsNormal;
+
+  /// <summary>
+  /// Tests if the value is subnormal (denormal).
+  /// </summary>
+  /// <returns>True if so</returns>
+  using Base::IsSubnormal;
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  using Base::Abs;
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  using Base::Negate;
+
+  /// <summary>
+  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
+  /// and therefore equivalent, if the resulting value is still zero.
+  /// </summary>
+  /// <param name="lhs">first value</param>
+  /// <param name="rhs">second value</param>
+  /// <returns>True if both arguments represent zero</returns>
+  using Base::AreZero;
+
+  /// <summary>
+  /// User defined conversion operator. Converts Float16_t to float.
+  /// </summary>
+  explicit operator float() const noexcept { return ToFloat(); }
+
+  using Base::operator==;
+  using Base::operator!=;
+  using Base::operator<;
 };

 static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
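A quick sketch of how the expanded Ort::Float16_t surface is meant to be used (assumes onnxruntime_cxx_api.h is included):

```cpp
#include <cassert>

void Float16Usage() {
  Ort::Float16_t h(3.5f);              // explicit float -> fp16 conversion
  float back = h.ToFloat();            // fp16 -> float
  assert(back == 3.5f);                // 3.5 is exactly representable
  assert(h.IsNormal() && !h.IsNaN());  // classification helpers
  Ort::Float16_t one = Ort::Float16_t::FromBits(0x3C00);  // raw bits for 1.0
  assert(one < h);                     // IEEE-aware comparison from the base
}
```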

 /** \brief bfloat16 (Brain Floating Point) data type
- * \details It is necessary for type dispatching to make use of C++ API
- * The type is implicitly convertible to/from uint16_t.
+ *
+ * \details This struct is used for converting float to bfloat16 and back
+ * so the user could feed inputs and fetch outputs using these types.
+ *
  * The size of the structure should align with uint16_t and one can freely cast
  * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
  *
- * See also code examples for Float16_t above.
+ * \code{.unparsed}
+ * // This example demonstrates conversion from float to bfloat16
+ * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
+ * std::vector<Ort::BFloat16_t> bfp16_values;
+ * bfp16_values.reserve(std::size(values));
+ * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
+ *                [](float value) { return Ort::BFloat16_t(value); });
+ *
+ * \endcode
 */
-struct BFloat16_t {
-  uint16_t value;
-  constexpr BFloat16_t() noexcept : value(0) {}
-  constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
-  constexpr operator uint16_t() const noexcept { return value; }
-  constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
-  constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
+struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
+ private:
+  /// <summary>
+  /// Constructor from a uint16_t representation of bfloat16
+  /// used in FromBits() to escape overload resolution issue with
+  /// constructor from float.
+  /// No conversion is done.
+  /// </summary>
+  /// <param name="v">16-bit bfloat16 value</param>
+  constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
+
+ public:
+  using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
+
+  BFloat16_t() = default;
+
+  /// <summary>
+  /// Explicit conversion to uint16_t representation of bfloat16.
+  /// </summary>
+  /// <param name="v">uint16_t bit representation of bfloat16</param>
+  /// <returns>new instance of BFloat16_t</returns>
+  static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
+
+  /// <summary>
+  /// __ctor from float. Float is converted into bfloat16 16-bit representation.
+  /// </summary>
+  /// <param name="v">float value</param>
+  explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
+
+  /// <summary>
+  /// Converts bfloat16 to float
+  /// </summary>
+  /// <returns>float representation of bfloat16 value</returns>
+  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
+
+  /// <summary>
+  /// Checks if the value is negative
+  /// </summary>
+  /// <returns>true if negative</returns>
+  using Base::IsNegative;
+
+  /// <summary>
+  /// Tests if the value is NaN
+  /// </summary>
+  /// <returns>true if NaN</returns>
+  using Base::IsNaN;
+
+  /// <summary>
+  /// Tests if the value is finite
+  /// </summary>
+  /// <returns>true if finite</returns>
+  using Base::IsFinite;
+
+  /// <summary>
+  /// Tests if the value represents positive infinity.
+  /// </summary>
+  /// <returns>true if positive infinity</returns>
+  using Base::IsPositiveInfinity;
+
+  /// <summary>
+  /// Tests if the value represents negative infinity
+  /// </summary>
+  /// <returns>true if negative infinity</returns>
+  using Base::IsNegativeInfinity;
+
+  /// <summary>
+  /// Tests if the value is either positive or negative infinity.
+  /// </summary>
+  /// <returns>True if absolute value is infinity</returns>
+  using Base::IsInfinity;
+
+  /// <summary>
+  /// Tests if the value is NaN or zero. Useful for comparisons.
+  /// </summary>
+  /// <returns>True if NaN or zero.</returns>
+  using Base::IsNaNOrZero;
+
+  /// <summary>
+  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
+  /// </summary>
+  /// <returns>True if so</returns>
+  using Base::IsNormal;
+
+  /// <summary>
+  /// Tests if the value is subnormal (denormal).
+  /// </summary>
+  /// <returns>True if so</returns>
+  using Base::IsSubnormal;
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  using Base::Abs;
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  using Base::Negate;
+
+  /// <summary>
+  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
+  /// and therefore equivalent, if the resulting value is still zero.
+  /// </summary>
+  /// <param name="lhs">first value</param>
+  /// <param name="rhs">second value</param>
+  /// <returns>True if both arguments represent zero</returns>
+  using Base::AreZero;
+
+  /// <summary>
+  /// User defined conversion operator. Converts BFloat16_t to float.
+  /// </summary>
+  explicit operator float() const noexcept { return ToFloat(); }
+
+  // We do not have an inherited impl for the below operators
+  // as the internal class implements them a little differently
+  bool operator==(const BFloat16_t& rhs) const noexcept;
+  bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
+  bool operator<(const BFloat16_t& rhs) const noexcept;
 };

 static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
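The static_assert is what licenses the pattern the doc comment describes: viewing an existing uint16_t buffer as BFloat16_t without copying. A sketch of that use (buffer names are illustrative):

```cpp
#include <cstddef>
#include <cstdint>

float FirstAsFloat(const uint16_t* raw_bfloat16, size_t count) {
  // Same size and alignment as uint16_t, so this cast is the intended use.
  const Ort::BFloat16_t* typed = reinterpret_cast<const Ort::BFloat16_t*>(raw_bfloat16);
  return count != 0 ? typed[0].ToFloat() : 0.0f;
}
```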

@@ -7,6 +7,8 @@
 // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
 // the main C++ file with implementation details.

+#include <cstring>
+
 namespace Ort {

 namespace detail {
@@ -131,6 +133,30 @@ struct TypeToTensorType<Float8E5M2FNUZ_t> {
   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
 };

+inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
+  if (IsNaN() || rhs.IsNaN()) {
+    // IEEE defines that NaN is not equal to anything, including itself.
+    return false;
+  }
+  return val == rhs.val;
+}
+
+inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
+  if (IsNaN() || rhs.IsNaN()) {
+    // IEEE defines that NaN is unordered with respect to everything, including itself.
+    return false;
+  }
+
+  const bool left_is_negative = IsNegative();
+  if (left_is_negative != rhs.IsNegative()) {
+    // When the signs of left and right differ, we know that left is less than right if it is
+    // the negative value. The exception to this is if both values are zero, in which case IEEE
+    // says they should be equal, even if the signs differ.
+    return left_is_negative && !AreZero(*this, rhs);
+  }
+  return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
+}
+
 inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
     : allocator_(allocator), p_(p), size_(size) {
 }

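The `(val < rhs.val) ^ left_is_negative` step in operator< works because IEEE floats are sign-magnitude: for operands of equal sign, unsigned bit order matches numeric order for positives and is exactly reversed for negatives. A plain-integer check of that identity (not ORT code):

```cpp
#include <cassert>
#include <cstdint>

void SignMagnitudeOrdering() {
  // bfloat16 bit patterns: 1.0=0x3F80, 2.0=0x4000, -1.0=0xBF80, -2.0=0xC000
  uint16_t one = 0x3F80, two = 0x4000, neg_one = 0xBF80, neg_two = 0xC000;

  bool negative = false;
  assert(((one < two) ^ negative) == true);  // 1.0 < 2.0
  negative = true;
  // Bitwise, 0xC000 > 0xBF80, but XOR-ing with the sign restores -2.0 < -1.0.
  assert(((neg_two < neg_one) ^ negative) == true);
  assert(((neg_one < neg_two) ^ negative) == false);
}
```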

@@ -0,0 +1,531 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <stdint.h>
+#include <cmath>
+#include <cstring>
+#include <limits>
+
+namespace onnxruntime_float16 {
+
+namespace detail {
+
+enum class endian {
+#if defined(_WIN32)
+  little = 0,
+  big = 1,
+  native = little,
+#elif defined(__GNUC__) || defined(__clang__)
+  little = __ORDER_LITTLE_ENDIAN__,
+  big = __ORDER_BIG_ENDIAN__,
+  native = __BYTE_ORDER__,
+#else
+#error onnxruntime_float16::detail::endian is not implemented in this environment.
+#endif
+};
+
+static_assert(
+    endian::native == endian::little || endian::native == endian::big,
+    "Only little-endian or big-endian native byte orders are supported.");
+
+}  // namespace detail
+
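On C++20 toolchains the hand-rolled endian enum should agree with std::endian; a compile-time cross-check one could add (a sketch, not part of the commit):

```cpp
#include <version>
#if defined(__cpp_lib_endian)
#include <bit>
static_assert((std::endian::native == std::endian::little) ==
                  (onnxruntime_float16::detail::endian::native ==
                   onnxruntime_float16::detail::endian::little),
              "endian detection disagrees with std::endian");
#endif
```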
+/// <summary>
+/// Shared implementation between public and internal classes. CRTP pattern.
+/// </summary>
+template <class Derived>
+struct Float16Impl {
+ protected:
+  /// <summary>
+  /// Converts from float to uint16_t float16 representation
+  /// </summary>
+  /// <param name="v"></param>
+  /// <returns></returns>
+  constexpr static uint16_t ToUint16Impl(float v) noexcept;
+
+  /// <summary>
+  /// Converts float16 to float
+  /// </summary>
+  /// <returns>float representation of float16 value</returns>
+  float ToFloatImpl() const noexcept;
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  uint16_t AbsImpl() const noexcept {
+    return static_cast<uint16_t>(val & ~kSignMask);
+  }
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  uint16_t NegateImpl() const noexcept {
+    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
+  }
+
+ public:
+  // uint16_t special values
+  static constexpr uint16_t kSignMask = 0x8000U;
+  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
+  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
+  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
+  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
+  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
+  static constexpr uint16_t kEpsilonBits = 0x4170U;
+  static constexpr uint16_t kMinValueBits = 0xFBFFU;  // Minimum normal number
+  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest normal number
+  static constexpr uint16_t kOneBits = 0x3C00U;
+  static constexpr uint16_t kMinusOneBits = 0xBC00U;
+
+  uint16_t val{0};
+
+  Float16Impl() = default;
+
+  /// <summary>
+  /// Checks if the value is negative
+  /// </summary>
+  /// <returns>true if negative</returns>
+  bool IsNegative() const noexcept {
+    return static_cast<int16_t>(val) < 0;
+  }
+
+  /// <summary>
+  /// Tests if the value is NaN
+  /// </summary>
+  /// <returns>true if NaN</returns>
+  bool IsNaN() const noexcept {
+    return AbsImpl() > kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is finite
+  /// </summary>
+  /// <returns>true if finite</returns>
+  bool IsFinite() const noexcept {
+    return AbsImpl() < kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value represents positive infinity.
+  /// </summary>
+  /// <returns>true if positive infinity</returns>
+  bool IsPositiveInfinity() const noexcept {
+    return val == kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value represents negative infinity
+  /// </summary>
+  /// <returns>true if negative infinity</returns>
+  bool IsNegativeInfinity() const noexcept {
+    return val == kNegativeInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is either positive or negative infinity.
+  /// </summary>
+  /// <returns>True if absolute value is infinity</returns>
+  bool IsInfinity() const noexcept {
+    return AbsImpl() == kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is NaN or zero. Useful for comparisons.
+  /// </summary>
+  /// <returns>True if NaN or zero.</returns>
+  bool IsNaNOrZero() const noexcept {
+    auto abs = AbsImpl();
+    return (abs == 0 || abs > kPositiveInfinityBits);
+  }
+
+  /// <summary>
+  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
+  /// </summary>
+  /// <returns>True if so</returns>
+  bool IsNormal() const noexcept {
+    auto abs = AbsImpl();
+    return (abs < kPositiveInfinityBits)           // is finite
+           && (abs != 0)                           // is not zero
+           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
+  }
+
+  /// <summary>
+  /// Tests if the value is subnormal (denormal).
+  /// </summary>
+  /// <returns>True if so</returns>
+  bool IsSubnormal() const noexcept {
+    auto abs = AbsImpl();
+    return (abs < kPositiveInfinityBits)           // is finite
+           && (abs != 0)                           // is not zero
+           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
+  }
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
+
+  /// <summary>
+  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
+  /// and therefore equivalent, if the resulting value is still zero.
+  /// </summary>
+  /// <param name="lhs">first value</param>
+  /// <param name="rhs">second value</param>
+  /// <returns>True if both arguments represent zero</returns>
+  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
+    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
+  }
+
+  bool operator==(const Float16Impl& rhs) const noexcept {
+    if (IsNaN() || rhs.IsNaN()) {
+      // IEEE defines that NaN is not equal to anything, including itself.
+      return false;
+    }
+    return val == rhs.val;
+  }
+
+  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
+
+  bool operator<(const Float16Impl& rhs) const noexcept {
+    if (IsNaN() || rhs.IsNaN()) {
+      // IEEE defines that NaN is unordered with respect to everything, including itself.
+      return false;
+    }
+
+    const bool left_is_negative = IsNegative();
+    if (left_is_negative != rhs.IsNegative()) {
+      // When the signs of left and right differ, we know that left is less than right if it is
+      // the negative value. The exception to this is if both values are zero, in which case IEEE
+      // says they should be equal, even if the signs differ.
+      return left_is_negative && !AreZero(*this, rhs);
+    }
+    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
+  }
+};
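The bit constants above decode directly from the IEEE 754 binary16 layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits. A standalone decoder for normal values, showing why kOneBits is 0x3C00 and kMaxValueBits is 65504 (illustrative only):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

void DecodeHalf(uint16_t bits) {
  int sign = (bits >> 15) & 0x1;
  int exponent = (bits >> 10) & 0x1F;  // 5 bits, bias 15
  int mantissa = bits & 0x3FF;         // 10 bits
  // Normal numbers only (exponent neither 0 nor 31) in this sketch.
  double value = (sign ? -1.0 : 1.0) * (1.0 + mantissa / 1024.0) *
                 std::ldexp(1.0, exponent - 15);
  std::printf("0x%04X -> %g\n", bits, value);
}

int main() {
  DecodeHalf(0x3C00);  // kOneBits      -> 1
  DecodeHalf(0xBC00);  // kMinusOneBits -> -1
  DecodeHalf(0x7BFF);  // kMaxValueBits -> 65504
  return 0;
}
```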

+// The following Float16_t conversions are based on the code from
+// Eigen library.
+
+// The conversion routines are Copyright (c) Fabian Giesen, 2016.
+// The original license follows:
+//
+// Copyright (c) Fabian Giesen, 2016
+// All rights reserved.
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted.
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+namespace detail {
+union float32_bits {
+  unsigned int u;
+  float f;
+};
+}  // namespace detail
+
+template <class Derived>
+inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
+  detail::float32_bits f;
+  f.f = v;
+
+  constexpr detail::float32_bits f32infty = {255 << 23};
+  constexpr detail::float32_bits f16max = {(127 + 16) << 23};
+  constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
+  constexpr unsigned int sign_mask = 0x80000000u;
+  uint16_t val = static_cast<uint16_t>(0x0u);
+
+  unsigned int sign = f.u & sign_mask;
+  f.u ^= sign;
+
+  // NOTE all the integer compares in this function can be safely
+  // compiled into signed compares since all operands are below
+  // 0x80000000. Important if you want fast straight SSE2 code
+  // (since there's no unsigned PCMPGTD).
+
+  if (f.u >= f16max.u) {                         // result is Inf or NaN (all exponent bits set)
+    val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;  // NaN->qNaN and Inf->Inf
+  } else {                                       // (De)normalized number or zero
+    if (f.u < (113 << 23)) {                     // resulting FP16 is subnormal or zero
+      // use a magic value to align our 10 mantissa bits at the bottom of
+      // the float. as long as FP addition is round-to-nearest-even this
+      // just works.
+      f.f += denorm_magic.f;
+
+      // and one integer subtract of the bias later, we have our final float!
+      val = static_cast<uint16_t>(f.u - denorm_magic.u);
+    } else {
+      unsigned int mant_odd = (f.u >> 13) & 1;  // resulting mantissa is odd
+
+      // update exponent, rounding bias part 1
+      // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
+      // without arithmetic overflow.
+      f.u += 0xc8000fffU;
+      // rounding bias part 2
+      f.u += mant_odd;
+      // take the bits!
+      val = static_cast<uint16_t>(f.u >> 13);
+    }
+  }
+
+  val |= static_cast<uint16_t>(sign >> 16);
+  return val;
+}
+
+template <class Derived>
+inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
+  constexpr detail::float32_bits magic = {113 << 23};
+  constexpr unsigned int shifted_exp = 0x7c00 << 13;  // exponent mask after shift
+  detail::float32_bits o;
+
+  o.u = (val & 0x7fff) << 13;            // exponent/mantissa bits
+  unsigned int exp = shifted_exp & o.u;  // just the exponent
+  o.u += (127 - 15) << 23;               // exponent adjust
+
+  // handle exponent special cases
+  if (exp == shifted_exp) {   // Inf/NaN?
+    o.u += (128 - 16) << 23;  // extra exp adjust
+  } else if (exp == 0) {      // Zero/Denormal?
+    o.u += 1 << 23;           // extra exp adjust
+    o.f -= magic.f;           // re-normalize
+  }
+
+  o.u |= (val & 0x8000) << 16;  // sign bit
+  return o.f;
+}
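Because fp16 has only 65536 bit patterns, conversion routines like the Giesen code above are commonly validated by an exhaustive round trip. A sketch, assuming a wrapper type exposing FromBits/ToFloat/IsNaN as in this header:

```cpp
#include <cstdint>

template <typename F16>
bool RoundTripsExactly() {
  for (uint32_t bits = 0; bits <= 0xFFFF; ++bits) {
    const F16 h = F16::FromBits(static_cast<uint16_t>(bits));
    if (h.IsNaN()) continue;  // NaN payloads need not be preserved
    // half -> float is a widening conversion, so converting back must be exact.
    if (F16(h.ToFloat()) != h) return false;
  }
  return true;
}
```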

+/// Shared implementation between public and internal classes. CRTP pattern.
+template <class Derived>
+struct BFloat16Impl {
+ protected:
+  /// <summary>
+  /// Converts from float to uint16_t bfloat16 representation
+  /// </summary>
+  /// <param name="v"></param>
+  /// <returns></returns>
+  static uint16_t ToUint16Impl(float v) noexcept;
+
+  /// <summary>
+  /// Converts bfloat16 to float
+  /// </summary>
+  /// <returns>float representation of bfloat16 value</returns>
+  float ToFloatImpl() const noexcept;
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  uint16_t AbsImpl() const noexcept {
+    return static_cast<uint16_t>(val & ~kSignMask);
+  }
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  uint16_t NegateImpl() const noexcept {
+    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
+  }
+
+ public:
+  // uint16_t special values
+  static constexpr uint16_t kSignMask = 0x8000U;
+  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
+  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
+  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
+  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
+  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
+  static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
+  static constexpr uint16_t kEpsilonBits = 0x0080U;
+  static constexpr uint16_t kMinValueBits = 0xFF7FU;
+  static constexpr uint16_t kMaxValueBits = 0x7F7FU;
+  static constexpr uint16_t kRoundToNearest = 0x7FFFU;
+  static constexpr uint16_t kOneBits = 0x3F80U;
+  static constexpr uint16_t kMinusOneBits = 0xBF80U;
+
+  uint16_t val{0};
+
+  BFloat16Impl() = default;
+
+  /// <summary>
+  /// Checks if the value is negative
+  /// </summary>
+  /// <returns>true if negative</returns>
+  bool IsNegative() const noexcept {
+    return static_cast<int16_t>(val) < 0;
+  }
+
+  /// <summary>
+  /// Tests if the value is NaN
+  /// </summary>
+  /// <returns>true if NaN</returns>
+  bool IsNaN() const noexcept {
+    return AbsImpl() > kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is finite
+  /// </summary>
+  /// <returns>true if finite</returns>
+  bool IsFinite() const noexcept {
+    return AbsImpl() < kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value represents positive infinity.
+  /// </summary>
+  /// <returns>true if positive infinity</returns>
+  bool IsPositiveInfinity() const noexcept {
+    return val == kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value represents negative infinity
+  /// </summary>
+  /// <returns>true if negative infinity</returns>
+  bool IsNegativeInfinity() const noexcept {
+    return val == kNegativeInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is either positive or negative infinity.
+  /// </summary>
+  /// <returns>True if absolute value is infinity</returns>
+  bool IsInfinity() const noexcept {
+    return AbsImpl() == kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is NaN or zero. Useful for comparisons.
+  /// </summary>
+  /// <returns>True if NaN or zero.</returns>
+  bool IsNaNOrZero() const noexcept {
+    auto abs = AbsImpl();
+    return (abs == 0 || abs > kPositiveInfinityBits);
+  }
+
+  /// <summary>
+  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
+  /// </summary>
+  /// <returns>True if so</returns>
+  bool IsNormal() const noexcept {
+    auto abs = AbsImpl();
+    return (abs < kPositiveInfinityBits)           // is finite
+           && (abs != 0)                           // is not zero
+           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
+  }
+
+  /// <summary>
+  /// Tests if the value is subnormal (denormal).
+  /// </summary>
+  /// <returns>True if so</returns>
+  bool IsSubnormal() const noexcept {
+    auto abs = AbsImpl();
+    return (abs < kPositiveInfinityBits)           // is finite
+           && (abs != 0)                           // is not zero
+           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
+  }
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
+
+  /// <summary>
+  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
+  /// and therefore equivalent, if the resulting value is still zero.
+  /// </summary>
+  /// <param name="lhs">first value</param>
+  /// <param name="rhs">second value</param>
+  /// <returns>True if both arguments represent zero</returns>
+  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
+    // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+    // for two values by or'ing the private bits together and stripping the sign. They are both zero,
+    // and therefore equivalent, if the resulting value is still zero.
+    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
+  }
+};
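Unlike the fp16 path, the bfloat16 conversions declared here (and defined just below) canonicalize NaN: every float NaN is stored as kPositiveQNaNBits, and reading a NaN back yields the standard quiet NaN. A small check sketch, parameterized over a wrapper type with the surface above:

```cpp
#include <cmath>
#include <limits>

template <typename BF16>
bool NaNIsCanonicalized() {
  const BF16 a(std::nanf("0x1234"));  // arbitrary payload collapses to 0x7FC1
  const BF16 b(std::numeric_limits<float>::quiet_NaN());
  return a.IsNaN() && a.val == b.val && std::isnan(a.ToFloat());
}
```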
+template <class Derived>
+inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
+  uint16_t result;
+  if (std::isnan(v)) {
+    result = kPositiveQNaNBits;
+  } else {
+    auto get_msb_half = [](float fl) {
+      uint16_t result;
+#ifdef __cpp_if_constexpr
+      if constexpr (detail::endian::native == detail::endian::little) {
+#else
+      if (detail::endian::native == detail::endian::little) {
+#endif
+        std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
+      } else {
+        std::memcpy(&result, &fl, sizeof(uint16_t));
+      }
+      return result;
+    };
+
+    uint16_t upper_bits = get_msb_half(v);
+    union {
+      uint32_t U32;
+      float F32;
+    };
+    F32 = v;
+    U32 += (upper_bits & 1) + kRoundToNearest;
+    result = get_msb_half(F32);
+  }
+  return result;
+}
+
+template <class Derived>
+inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
+  if (IsNaN()) {
+    return std::numeric_limits<float>::quiet_NaN();
+  }
+  float result;
+  char* const first = reinterpret_cast<char*>(&result);
+  char* const second = first + sizeof(uint16_t);
+#ifdef __cpp_if_constexpr
+  if constexpr (detail::endian::native == detail::endian::little) {
+#else
+  if (detail::endian::native == detail::endian::little) {
+#endif
+    std::memset(first, 0, sizeof(uint16_t));
+    std::memcpy(second, &val, sizeof(uint16_t));
+  } else {
+    std::memcpy(first, &val, sizeof(uint16_t));
+    std::memset(second, 0, sizeof(uint16_t));
+  }
+  return result;
+}
+
+}  // namespace onnxruntime_float16

@@ -81,10 +81,10 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {

   auto* length_penalty_tensor = context->Input<Tensor>(5);
   if (length_penalty_tensor) {
-    if (length_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
-      length_penalty = static_cast<float>(*length_penalty_tensor->Data<float>());
+    if (length_penalty_tensor->IsDataType<float>()) {
+      length_penalty = *length_penalty_tensor->Data<float>();
     } else {
-      length_penalty = static_cast<MLFloat16>(*length_penalty_tensor->Data<MLFloat16>());
+      length_penalty = static_cast<float>(*length_penalty_tensor->Data<MLFloat16>());
     }
   } else {
     length_penalty = 1.0f;
@@ -92,10 +92,10 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {

   auto* repetition_penalty_tensor = context->Input<Tensor>(6);
   if (repetition_penalty_tensor) {
-    if (repetition_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
-      repetition_penalty = static_cast<float>(*repetition_penalty_tensor->Data<float>());
+    if (repetition_penalty_tensor->IsDataType<float>()) {
+      repetition_penalty = *repetition_penalty_tensor->Data<float>();
     } else {
-      repetition_penalty = static_cast<MLFloat16>(*repetition_penalty_tensor->Data<MLFloat16>());
+      repetition_penalty = static_cast<float>(*repetition_penalty_tensor->Data<MLFloat16>());
     }
   } else {
     repetition_penalty = 1.0f;

@@ -11,7 +11,6 @@
 #include "core/framework/tensor.h"
 #include "core/framework/TensorSeq.h"
 #include "core/graph/onnx_protobuf.h"
-#include "core/util/math.h"

 #ifdef __GNUC__
 #pragma GCC diagnostic push
@@ -27,18 +26,34 @@ using namespace ONNX_NAMESPACE;

 namespace onnxruntime {

-MLFloat16::MLFloat16(float f) : val{math::floatToHalf(f)} {}
-
-float MLFloat16::ToFloat() const {
-  return math::halfToFloat(val);
-}
-
 // Return the MLDataType used for a generic Tensor
 template <>
 MLDataType DataTypeImpl::GetType<Tensor>() {
   return TensorTypeBase::Type();
 }

+const MLFloat16 MLFloat16::NaN(MLFloat16::FromBits(MLFloat16::kPositiveQNaNBits));
+const MLFloat16 MLFloat16::NegativeNaN(MLFloat16::FromBits(MLFloat16::kNegativeQNaNBits));
+const MLFloat16 MLFloat16::Infinity(MLFloat16::FromBits(MLFloat16::kPositiveInfinityBits));
+const MLFloat16 MLFloat16::NegativeInfinity(MLFloat16::FromBits(MLFloat16::kNegativeInfinityBits));
+const MLFloat16 MLFloat16::Epsilon(MLFloat16::FromBits(MLFloat16::kEpsilonBits));
+const MLFloat16 MLFloat16::MinValue(MLFloat16::FromBits(MLFloat16::kMinValueBits));
+const MLFloat16 MLFloat16::MaxValue(MLFloat16::FromBits(MLFloat16::kMaxValueBits));
+const MLFloat16 MLFloat16::Zero(MLFloat16::FromBits(0));
+const MLFloat16 MLFloat16::One(MLFloat16::FromBits(MLFloat16::kOneBits));
+const MLFloat16 MLFloat16::MinusOne(MLFloat16::FromBits(MLFloat16::kMinusOneBits));
+
+const BFloat16 BFloat16::NaN(BFloat16::FromBits(BFloat16::kPositiveQNaNBits));
+const BFloat16 BFloat16::NegativeNaN(BFloat16::FromBits(BFloat16::kNegativeQNaNBits));
+const BFloat16 BFloat16::Infinity(BFloat16::FromBits(BFloat16::kPositiveInfinityBits));
+const BFloat16 BFloat16::NegativeInfinity(BFloat16::FromBits(BFloat16::kNegativeInfinityBits));
+const BFloat16 BFloat16::Epsilon(BFloat16::FromBits(BFloat16::kEpsilonBits));
+const BFloat16 BFloat16::MinValue(BFloat16::FromBits(BFloat16::kMinValueBits));
+const BFloat16 BFloat16::MaxValue(BFloat16::FromBits(BFloat16::kMaxValueBits));
+const BFloat16 BFloat16::Zero(BFloat16::FromBits(0));
+const BFloat16 BFloat16::One(BFloat16::FromBits(BFloat16::kOneBits));
+const BFloat16 BFloat16::MinusOne(BFloat16::FromBits(BFloat16::kMinusOneBits));
+
 }  // namespace onnxruntime

 // This conflicts with the above GetType<>() specialization
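With the named constants defined above, kernel code can avoid magic bit patterns entirely. An illustrative helper (assumes MLFloat16 exposes the same predicate surface as BFloat16 shown earlier in this diff):

```cpp
onnxruntime::MLFloat16 ClampToFinite(onnxruntime::MLFloat16 x) {
  using onnxruntime::MLFloat16;
  if (x.IsNaN()) return MLFloat16::Zero;                   // policy: NaN -> 0
  if (x.IsPositiveInfinity()) return MLFloat16::MaxValue;  // 65504
  if (x.IsNegativeInfinity()) return MLFloat16::MinValue;  // -65504
  return x;
}
```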
@@ -370,7 +370,7 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d
     if (v < 0 || v > max_value) {
       return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "data overflow");
     }
-    p_data[i] = MLFloat16(static_cast<uint16_t>(v));
+    p_data[i] = MLFloat16::FromBits(static_cast<uint16_t>(v));
   }

   return Status::OK();

@@ -1849,7 +1849,7 @@ void BroadCastMLFloat16FMod(OpKernelContext* context) {

         std::transform(Y.begin(), Y.end(), output.begin(),
                        [X_fl = math::halfToFloat(X.val)](const MLFloat16& y) {
-                         return MLFloat16(math::floatToHalf(std::fmod(X_fl, math::halfToFloat(y.val))));
+                         return MLFloat16(std::fmod(X_fl, y.ToFloat()));
                        });
       },
       [](BroadcastHelper& per_iter_bh) {
@@ -1859,7 +1859,7 @@ void BroadCastMLFloat16FMod(OpKernelContext* context) {

         std::transform(X.begin(), X.end(), output.begin(),
                        [Y_fl = math::halfToFloat(Y.val)](const MLFloat16& x) {
-                         return MLFloat16(math::floatToHalf(std::fmod(math::halfToFloat(x.val), Y_fl)));
+                         return MLFloat16(std::fmod(x.ToFloat(), Y_fl));
                        });
       },
       [](BroadcastHelper& per_iter_bh) {
@@ -1869,9 +1869,9 @@ void BroadCastMLFloat16FMod(OpKernelContext* context) {

         std::transform(X.begin(), X.end(), Y.begin(), output.begin(),
                        [](const MLFloat16& x, const MLFloat16& y) {
-                         auto x_fl = math::halfToFloat(x.val);
-                         auto y_fl = math::halfToFloat(y.val);
-                         return MLFloat16(math::floatToHalf(std::fmod(x_fl, y_fl)));
+                         auto x_fl = x.ToFloat();
+                         auto y_fl = y.ToFloat();
+                         return MLFloat16(std::fmod(x_fl, y_fl));
                        });
       }};

@@ -36,7 +36,7 @@ Status Round<MLFloat16>::Compute(OpKernelContext* ctx) const {
   auto* output = Y.MutableData<MLFloat16>();
   const auto size = X.Shape().Size();
   for (int64_t i = 0; i < size; ++i, ++output, ++input) {
-    *output = MLFloat16(math::floatToHalf(::rint(math::halfToFloat(input->val))));
+    *output = MLFloat16(static_cast<float>(::rint(input->ToFloat())));
   }
   return Status::OK();
 }

@@ -55,41 +55,27 @@ struct CallSignImpl {
   }
 };

-// The spec does not specify how NaN is
-// treated but we have to treat it somehow. We choose
-// to return 0 for NaN as TF does.
-template <class T>
-inline T FloatingImpl(T val) {
-  if (std::isnan(val) || val == T(0)) {
-    return T(0);
-  }
-  if (val > T(0)) {
-    return T(1);
-  } else {
-    return T(-1);
-  }
-}
-
 template <>
 struct CallSignImpl<MLFloat16> {
   void operator()(const Tensor* input, Tensor* output) const {
-    auto span = gsl::make_span(input->Data<MLFloat16>(), onnxruntime::narrow<size_t>(input->Shape().Size()));
-    auto output_data = output->MutableData<MLFloat16>();
-    std::transform(span.begin(), span.end(), output_data, [](const MLFloat16& val) {
-      float fl = math::halfToFloat(val.val);
-      return MLFloat16(math::floatToHalf(FloatingImpl(fl)));
-    });
+    ConstEigenVectorMap<Eigen::half> input_data(
+        reinterpret_cast<const Eigen::half*>(input->Data<MLFloat16>()),
+        narrow<ptrdiff_t>(input->Shape().Size()));
+
+    EigenVectorMap<Eigen::half>(reinterpret_cast<Eigen::half*>(output->MutableData<MLFloat16>()),
+                                narrow<ptrdiff_t>(output->Shape().Size())) = input_data.array().cwiseSign();
   }
 };

 template <>
 struct CallSignImpl<BFloat16> {
   void operator()(const Tensor* input, Tensor* output) const {
-    auto span = gsl::make_span(input->Data<BFloat16>(), onnxruntime::narrow<size_t>(input->Shape().Size()));
+    auto span = input->DataAsSpan<BFloat16>();
     auto output_data = output->MutableData<BFloat16>();
     std::transform(span.begin(), span.end(), output_data, [](const BFloat16& val) {
-      float fl = val.ToFloat();
-      return BFloat16(FloatingImpl(fl));
+      // Return 0 as TF does for NaN.
+      if (val.IsNaNOrZero()) return BFloat16::Zero;
+      return (val.IsNegative()) ? BFloat16::MinusOne : BFloat16::One;
     });
   }
 };

|
|
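Note: both specializations above keep the convention spelled out in the removed comment: the sign of NaN is defined to be 0, as TensorFlow does (the `BFloat16` branch still enforces it explicitly via `IsNaNOrZero`). A standalone restatement of the removed `FloatingImpl` helper in plain `float`, for reference:

    #include <cmath>
    #include <iostream>

    // Equivalent of the deleted FloatingImpl<T>: NaN and +/-0 map to 0.
    float sign_tf_style(float v) {
      if (std::isnan(v) || v == 0.0f) return 0.0f;
      return (v > 0.0f) ? 1.0f : -1.0f;
    }

    int main() {
      std::cout << sign_tf_style(-3.5f) << ' '   // -1
                << sign_tf_style(0.0f) << ' '    // 0
                << sign_tf_style(NAN) << '\n';   // 0
    }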
@@ -55,18 +55,18 @@ Status ShrinkImpl(const Tensor* input, Tensor* output, float bias, float lambd)

template <>
Status ShrinkImpl<MLFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
-  const auto& span = gsl::make_span(input->Data<MLFloat16>(), onnxruntime::narrow<size_t>(input->Shape().Size()));
+  const auto span = input->DataAsSpan<MLFloat16>();
  auto* output_data = output->MutableData<MLFloat16>();
  std::transform(span.begin(), span.end(), output_data, [bias, lambd](const MLFloat16& val) {
-    float fl = math::halfToFloat(val.val);
-    return MLFloat16(math::floatToHalf(ShrinkCore<float>(fl, bias, lambd)));
+    float fl = val.ToFloat();
+    return MLFloat16(ShrinkCore<float>(fl, bias, lambd));
  });
  return Status::OK();
}

template <>
Status ShrinkImpl<BFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
-  const auto& span = gsl::make_span(input->Data<BFloat16>(), onnxruntime::narrow<size_t>(input->Shape().Size()));
+  const auto span = input->DataAsSpan<BFloat16>();
  auto* output_data = output->MutableData<BFloat16>();
  std::transform(span.begin(), span.end(), output_data, [bias, lambd](const BFloat16& val) {
    float fl = val.ToFloat();

@@ -146,20 +146,20 @@ __device__ __forceinline__ BFloat16& operator/=(BFloat16& a, const BFloat16& b)

/// Arithmetic with floats

-__device__ __forceinline__ float operator+(BFloat16 a, float b) { return static_cast<float>(a) + b; }
-__device__ __forceinline__ float operator-(BFloat16 a, float b) { return static_cast<float>(a) - b; }
-__device__ __forceinline__ float operator*(BFloat16 a, float b) { return static_cast<float>(a) * b; }
-__device__ __forceinline__ float operator/(BFloat16 a, float b) { return static_cast<float>(a) / b; }
+__device__ __forceinline__ float operator+(BFloat16 a, float b) { return a + b; }
+__device__ __forceinline__ float operator-(BFloat16 a, float b) { return a - b; }
+__device__ __forceinline__ float operator*(BFloat16 a, float b) { return a * b; }
+__device__ __forceinline__ float operator/(BFloat16 a, float b) { return a / b; }

-__device__ __forceinline__ float operator+(float a, BFloat16 b) { return a + static_cast<float>(b); }
-__device__ __forceinline__ float operator-(float a, BFloat16 b) { return a - static_cast<float>(b); }
-__device__ __forceinline__ float operator*(float a, BFloat16 b) { return a * static_cast<float>(b); }
-__device__ __forceinline__ float operator/(float a, BFloat16 b) { return a / static_cast<float>(b); }
+__device__ __forceinline__ float operator+(float a, BFloat16 b) { return a + b; }
+__device__ __forceinline__ float operator-(float a, BFloat16 b) { return a - b; }
+__device__ __forceinline__ float operator*(float a, BFloat16 b) { return a * b; }
+__device__ __forceinline__ float operator/(float a, BFloat16 b) { return a / b; }

-__device__ __forceinline__ float& operator+=(float& a, const BFloat16& b) { return a += static_cast<float>(b); }
-__device__ __forceinline__ float& operator-=(float& a, const BFloat16& b) { return a -= static_cast<float>(b); }
-__device__ __forceinline__ float& operator*=(float& a, const BFloat16& b) { return a *= static_cast<float>(b); }
-__device__ __forceinline__ float& operator/=(float& a, const BFloat16& b) { return a /= static_cast<float>(b); }
+__device__ __forceinline__ float& operator+=(float& a, const BFloat16& b) { return a += b; }
+__device__ __forceinline__ float& operator-=(float& a, const BFloat16& b) { return a -= b; }
+__device__ __forceinline__ float& operator*=(float& a, const BFloat16& b) { return a *= b; }
+__device__ __forceinline__ float& operator/=(float& a, const BFloat16& b) { return a /= b; }

/// Arithmetic with doubles

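Note: the `static_cast<float>` calls could be dropped because `BFloat16` now converts to `float` implicitly, so mixed expressions promote on their own. A toy model of that design (not the real `BFloat16`, which also carries the `__device__` qualifiers and a real 16-bit payload):

    #include <iostream>

    struct Toy16 {
      float stored;                              // stand-in for the 16-bit payload
      operator float() const { return stored; }  // implicit widening conversion
    };

    int main() {
      Toy16 a{1.5f};
      float b = 2.0f;
      std::cout << a + b << ' ' << b / a << '\n';  // both promote Toy16 to float
    }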
@@ -73,10 +73,10 @@ struct LowMax {
template <>
struct LowMax<MLFloat16> {
  static MLFloat16 low() {
-    return MLFloat16(math::floatToHalf(std::numeric_limits<float>::lowest()));
+    return MLFloat16::FromBits(math::floatToHalf(std::numeric_limits<float>::lowest()));
  }
  static MLFloat16 max() {
-    return MLFloat16(math::floatToHalf(std::numeric_limits<float>::max()));
+    return MLFloat16::FromBits(math::floatToHalf(std::numeric_limits<float>::max()));
  }
};
}  // namespace clip_internal

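Note: `math::floatToHalf` returns the raw IEEE fp16 bit pattern as a `uint16_t`, which is why it must now be wrapped in `FromBits` rather than fed to the float constructor. A hedged sketch of the equivalence one would expect against the new float constructor, assuming both conversion paths round identically (include paths illustrative):

    #include <cassert>
    #include <limits>

    #include "core/framework/float16.h"  // assumed location of onnxruntime::MLFloat16
    #include "core/util/math.h"          // assumed location of math::floatToHalf

    void check_equivalence() {
      using onnxruntime::MLFloat16;
      const float lowest = std::numeric_limits<float>::lowest();
      // floatToHalf yields raw fp16 bits, hence FromBits; the float constructor
      // converts the value directly. If both paths round the same way, the two
      // results compare equal.
      const MLFloat16 via_bits = MLFloat16::FromBits(onnxruntime::math::floatToHalf(lowest));
      const MLFloat16 via_float(lowest);
      assert(via_bits == via_float);
    }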
@@ -460,10 +460,6 @@ Status DenseTensorToSparseCoo(const DataTransferManager& data_manager, const Ten

}  // namespace sparse_utils

-float MLFloat16::ToFloat() const {
-  return math::halfToFloat(val);
-}
-
std::vector<std::string> GetStackTrace() { return g_host->GetStackTrace(); }

void LogRuntimeError(uint32_t session_id, const common::Status& status,
@@ -76,7 +76,7 @@ class RandomValueGenerator {
    std::vector<TFloat16> val(detail::SizeFromDims(dims));
    std::uniform_real_distribution<float> distribution(min, max);
    for (size_t i = 0; i < val.size(); ++i) {
-      val[i] = TFloat16(math::floatToHalf(distribution(generator_)));
+      val[i] = TFloat16(static_cast<float>(distribution(generator_)));
    }
    return val;
  }
@@ -151,7 +151,7 @@ inline std::vector<MLFloat16> ToFloat16(const std::vector<float>& data) {
  std::vector<MLFloat16> result;
  result.reserve(data.size());
  for (size_t i = 0; i < data.size(); i++) {
-    result.push_back(MLFloat16(math::floatToHalf(data[i])));
+    result.push_back(MLFloat16(data[i]));
  }
  return result;
}
@@ -245,7 +245,7 @@ class ParallelRandomValueGenerator {
      RandomEngine generator{seed};
      std::uniform_real_distribution<float> distribution(min, max);
      for (std::ptrdiff_t di = begin; di != end; ++di) {
-        val[di] = TFloat16(math::floatToHalf(distribution(generator)));
+        val[di] = TFloat16(static_cast<float>(distribution(generator)));
        ;
      }
    });
@@ -109,8 +109,8 @@ void RunBiasDropoutTest(const bool use_mask, const std::vector<int64_t>& input_s
    ratio = 0.5f;
  } else {
    if (use_float16_ratio) {
-      t.AddInput("ratio", {}, {MLFloat16(math::floatToHalf(ratio))});
-      t_bitmask.AddInput("ratio", {}, {MLFloat16(math::floatToHalf(ratio))});
+      t.AddInput("ratio", {}, {MLFloat16(ratio)});
+      t_bitmask.AddInput("ratio", {}, {MLFloat16(ratio)});
    } else {
      t.AddInput("ratio", {}, {ratio});
      t_bitmask.AddInput("ratio", {}, {ratio});
@@ -253,17 +253,17 @@ TEST(MathOpTest, ComplexMul_fp16) {
  if (DefaultCudaExecutionProvider() == nullptr) return;

  std::vector<MLFloat16> input_a_data = {
-      MLFloat16(math::floatToHalf(-0.5f)), MLFloat16(math::floatToHalf(0.6f))};
+      MLFloat16(-0.5f), MLFloat16(0.6f)};

  std::vector<MLFloat16> input_b_data = {
-      MLFloat16(math::floatToHalf(0.8f)), MLFloat16(math::floatToHalf(-0.5f)), MLFloat16(math::floatToHalf(0.0f)), MLFloat16(math::floatToHalf(1.f)),
-      MLFloat16(math::floatToHalf(0.5f)), MLFloat16(math::floatToHalf(0.2f)), MLFloat16(math::floatToHalf(0.3f)), MLFloat16(math::floatToHalf(-0.6f))};
+      MLFloat16(0.8f), MLFloat16(-0.5f), MLFloat16(0.0f), MLFloat16(1.f),
+      MLFloat16(0.5f), MLFloat16(0.2f), MLFloat16(0.3f), MLFloat16(-0.6f)};

  std::vector<MLFloat16> output_data = {
-      MLFloat16(math::floatToHalf(-0.10f)), MLFloat16(math::floatToHalf(0.73f)),
-      MLFloat16(math::floatToHalf(-0.60f)), MLFloat16(math::floatToHalf(-0.50f)),
-      MLFloat16(math::floatToHalf(-0.37f)), MLFloat16(math::floatToHalf(0.20f)),
-      MLFloat16(math::floatToHalf(0.21f)), MLFloat16(math::floatToHalf(0.48f))};
+      MLFloat16(-0.10f), MLFloat16(0.73f),
+      MLFloat16(-0.60f), MLFloat16(-0.50f),
+      MLFloat16(-0.37f), MLFloat16(0.20f),
+      MLFloat16(0.21f), MLFloat16(0.48f)};

  OpTester tester("ComplexMul", 1, onnxruntime::kMSDomain);
  tester.AddInput<MLFloat16>("A", {1, 2}, input_a_data);
@@ -279,17 +279,17 @@ TEST(MathOpTest, ComplexMulConj_fp16) {
  if (DefaultCudaExecutionProvider() == nullptr) return;

  std::vector<MLFloat16> input_a_data = {
-      MLFloat16(math::floatToHalf(-0.5f)), MLFloat16(math::floatToHalf(0.6f))};
+      MLFloat16(-0.5f), MLFloat16(0.6f)};

  std::vector<MLFloat16> input_b_data = {
-      MLFloat16(math::floatToHalf(0.8f)), MLFloat16(math::floatToHalf(-0.5f)), MLFloat16(math::floatToHalf(0.0f)), MLFloat16(math::floatToHalf(1.f)),
-      MLFloat16(math::floatToHalf(0.5f)), MLFloat16(math::floatToHalf(0.2f)), MLFloat16(math::floatToHalf(0.3f)), MLFloat16(math::floatToHalf(-0.6f))};
+      MLFloat16(0.8f), MLFloat16(-0.5f), MLFloat16(0.0f), MLFloat16(1.f),
+      MLFloat16(0.5f), MLFloat16(0.2f), MLFloat16(0.3f), MLFloat16(-0.6f)};

  std::vector<MLFloat16> output_data = {
-      MLFloat16(math::floatToHalf(-0.70f)), MLFloat16(math::floatToHalf(0.23f)),
-      MLFloat16(math::floatToHalf(0.60f)), MLFloat16(math::floatToHalf(0.50f)),
-      MLFloat16(math::floatToHalf(-0.13f)), MLFloat16(math::floatToHalf(0.40f)),
-      MLFloat16(math::floatToHalf(-0.51f)), MLFloat16(math::floatToHalf(-0.12f))};
+      MLFloat16(-0.70f), MLFloat16(0.23f),
+      MLFloat16(0.60f), MLFloat16(0.50f),
+      MLFloat16(-0.13f), MLFloat16(0.40f),
+      MLFloat16(-0.51f), MLFloat16(-0.12f)};

  OpTester tester("ComplexMulConj", 1, onnxruntime::kMSDomain);
  tester.AddInput<MLFloat16>("A", {1, 2}, input_a_data);
@@ -30,14 +30,14 @@ TEST(InverseContribOpTest, two_by_two_float16) {
  std::transform(
      input_float.begin(), input_float.end(), std::back_inserter(input),
      [](float v) {
-        return MLFloat16(math::floatToHalf(v));
+        return MLFloat16(v);
      });

  auto output_float = {0.6f, -0.7f, -0.2f, 0.4f};
  std::vector<MLFloat16> output;
  std::transform(
      output_float.begin(), output_float.end(), std::back_inserter(output), [](float v) {
-        return MLFloat16(math::floatToHalf(v));
+        return MLFloat16(v);
      });

  test.AddInput<MLFloat16>("X", {2, 2}, input);
@@ -420,13 +420,159 @@ TEST_F(DataTypeTest, VectorMapInt64ToFloatTest) {
}
#endif  // !defined(DISABLE_ML_OPS)

-TEST_F(DataTypeTest, BFloat16Test) {
+TEST_F(DataTypeTest, MlFloat16ConvertFloatToMLFloat16) {
  // Test data type
  {
    constexpr float sample = 1.0f;
-    BFloat16 flt16(sample);
+    const MLFloat16 flt16(sample);
    auto int_rep = flt16.val;
-    BFloat16 flt_from_int(int_rep, BFloat16::FromBits());
+    const auto flt_from_int = MLFloat16::FromBits(int_rep);
+    const double diff = std::fabs(sample - flt_from_int.ToFloat());
+    if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
+      EXPECT_TRUE(false);
+    }
+  }
+  // Test bulk conversion
+  {
+    float sample[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+    std::vector<MLFloat16> converted;
+    std::transform(std::begin(sample), std::end(sample), std::back_inserter(converted),
+                   [](float fl) { return MLFloat16(fl); });
+    for (size_t i = 0; i < sizeof(sample) / sizeof(float); ++i) {
+      const double diff = std::fabs(sample[i] - converted[i].ToFloat());
+      if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
+        FAIL();
+      }
+    }
+
+    std::vector<float> back_converted;
+    std::transform(converted.cbegin(), converted.cend(), std::back_inserter(back_converted),
+                   [](const MLFloat16 ml) { return (float)ml; });
+    for (size_t i = 0; i < sizeof(sample) / sizeof(float); ++i) {
+      const double diff = std::fabs(sample[i] - back_converted[i]);
+      if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
+        FAIL();
+      }
+    }
+  }
+}
+
+TEST_F(DataTypeTest, MLFloat16Zeros) {
+  const auto positive_zero = MLFloat16::FromBits(0U);
+  EXPECT_FALSE(positive_zero.IsNegative());
+  const float float_positive_zero = static_cast<float>(positive_zero);
+  EXPECT_EQ(+0.0f, float_positive_zero);
+  EXPECT_FALSE(std::signbit(float_positive_zero));
+
+  const auto negative_zero = positive_zero.Negate();
+  EXPECT_TRUE(negative_zero.IsNegative());
+  const float float_positive_negzero = static_cast<float>(negative_zero);
+  EXPECT_EQ(-0.0f, float_positive_negzero);
+  EXPECT_TRUE(std::signbit(float_positive_negzero));
+
+  EXPECT_TRUE(positive_zero.IsNaNOrZero());
+  EXPECT_TRUE(negative_zero.IsNaNOrZero());
+}
+
+TEST_F(DataTypeTest, MLFloat16Comparision) {
+  const MLFloat16 left = MLFloat16(-33.33f);
+  const MLFloat16 left_same = MLFloat16(-33.33f);
+  const MLFloat16 right = MLFloat16(66.66f);
+  const MLFloat16 right_same = MLFloat16(66.66f);
+
+  EXPECT_TRUE(MLFloat16::Epsilon < right);
+
+  EXPECT_EQ(left, left_same);
+  EXPECT_NE(left, left_same.Negate());
+
+  EXPECT_EQ(right, right_same);
+  EXPECT_NE(right, right_same.Negate());
+
+  EXPECT_LT(left, right);
+  EXPECT_LT(right.Negate(), left);
+  EXPECT_LT(left.Negate(), right);
+}
+
+TEST_F(DataTypeTest, MLFloat16TestNAN) {
+  const MLFloat16 fp16NANFromSingle(std::numeric_limits<float>::quiet_NaN());
+  EXPECT_TRUE(fp16NANFromSingle.IsNaN());
+  EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero());
+
+  // NaN are not equal to each other
+  EXPECT_NE(MLFloat16::NaN, fp16NANFromSingle);
+
+  const float NanFromBFloat16 = fp16NANFromSingle.ToFloat();
+  EXPECT_TRUE(std::isnan(NanFromBFloat16));
+
+  EXPECT_FALSE(MLFloat16::FromBits(MLFloat16::kMaxValueBits).IsNaN());
+}
+
+TEST_F(DataTypeTest, MLFloat16NaNComparision) {
+  EXPECT_FALSE(MLFloat16::NaN < MLFloat16::NaN);
+  EXPECT_FALSE(MLFloat16::NaN == MLFloat16::NaN);
+
+  EXPECT_FALSE(MLFloat16::MaxValue < MLFloat16::NaN);
+  EXPECT_FALSE(MLFloat16::MaxValue == MLFloat16::NaN);
+  EXPECT_FALSE(MLFloat16::MinValue < MLFloat16::NaN);
+  EXPECT_FALSE(MLFloat16::NaN < MLFloat16::MaxValue);
+
+  EXPECT_TRUE(MLFloat16::MinValue < MLFloat16::MaxValue);
+}
+
+TEST_F(DataTypeTest, MLFloat16Infinity) {
+  EXPECT_FALSE(MLFloat16::MinValue.IsInfinity());
+  EXPECT_FALSE(MLFloat16::MaxValue.IsInfinity());
+  EXPECT_TRUE(MLFloat16::MaxValue.IsFinite());
+
+  const MLFloat16 pos_infinity_from_float(std::numeric_limits<float>::infinity());
+  EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(pos_infinity_from_float.IsFinite());
+  EXPECT_FALSE(pos_infinity_from_float.IsNegative());
+
+  const MLFloat16 neg_infinity_from_float(-std::numeric_limits<float>::infinity());
+  EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(neg_infinity_from_float.IsFinite());
+  EXPECT_TRUE(neg_infinity_from_float.IsNegative());
+
+  const float pos_infinity_from_bfloat16 = static_cast<float>(MLFloat16::Infinity);
+  EXPECT_TRUE(std::isinf(pos_infinity_from_bfloat16));
+  EXPECT_TRUE(!std::signbit(pos_infinity_from_bfloat16));
+}
+
+TEST_F(DataTypeTest, MLFloat16NormalSubnormal) {
+  EXPECT_FALSE(MLFloat16::Infinity.IsNormal());
+  EXPECT_TRUE(MLFloat16(45.6f).IsNormal());
+  EXPECT_FALSE(MLFloat16(45.6f).IsSubnormal());
+
+  // 0b0_0000_0000_000_0001 ~0.000000059604645
+  constexpr uint16_t min_subnormal_bits = 0x0001;
+  const MLFloat16 smallest_subnormal = MLFloat16::FromBits(min_subnormal_bits);
+  EXPECT_TRUE(smallest_subnormal.IsSubnormal());
+  EXPECT_FALSE(smallest_subnormal.IsNormal());
+
+  // float smallest positive subnormal is ~1.40129846432481707092E-45, and
+  // in float the same number above would be normal
+  const float float_from_smallest_subnormal = static_cast<float>(smallest_subnormal);
+  EXPECT_TRUE(std::isnormal(float_from_smallest_subnormal));
+
+  // 0b0_0000_0000_111_1111; ~0.000060975552
+  constexpr uint16_t max_subnormal_bits = 0x007F;
+  const MLFloat16 largest_subnormal = MLFloat16::FromBits(max_subnormal_bits);
+  EXPECT_TRUE(largest_subnormal.IsSubnormal());
+  EXPECT_FALSE(largest_subnormal.IsNormal());
+
+  // However, in float the same number above would be normal
+  const float float_from_largest_subnormal = static_cast<float>(largest_subnormal);
+  EXPECT_TRUE(std::isnormal(float_from_largest_subnormal));
+}
+
+TEST_F(DataTypeTest, BFloat16ConvertFloatToBFloat16) {
+  // Test data type
+  {
+    constexpr float sample = 1.0f;
+    const BFloat16 flt16(sample);
+    auto int_rep = flt16.val;
+    const auto flt_from_int = BFloat16::FromBits(int_rep);
    const double diff = std::fabs(sample - flt_from_int.ToFloat());
    if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
      EXPECT_TRUE(false);
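Note: the subnormal bit patterns exercised by the new tests decode by the usual rule for fp16: when the exponent field is zero, the value is mantissa * 2^-24. A standalone check of the two patterns used above:

    #include <cmath>
    #include <cstdint>
    #include <iostream>

    // Value of an fp16 subnormal (exponent field == 0): mantissa * 2^-24.
    float fp16_subnormal_value(uint16_t bits) {
      const uint16_t mantissa = bits & 0x03FF;   // low 10 bits of fp16
      return mantissa * std::ldexp(1.0f, -24);   // mantissa * 2^-24
    }

    int main() {
      std::cout << fp16_subnormal_value(0x0001) << '\n';  // ~5.9604645e-08
      std::cout << fp16_subnormal_value(0x007F) << '\n';  // ~7.5697899e-06
    }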
@@ -456,6 +602,112 @@ TEST_F(DataTypeTest, BFloat16Test) {
  }
}

+TEST_F(DataTypeTest, BFloat16Zeros) {
+  const auto positive_zero = BFloat16::FromBits(0U);
+  EXPECT_FALSE(positive_zero.IsNegative());
+  const float float_positive_zero = static_cast<float>(positive_zero);
+  EXPECT_EQ(+0.0f, float_positive_zero);
+  EXPECT_FALSE(std::signbit(float_positive_zero));
+
+  const auto negative_zero = positive_zero.Negate();
+  EXPECT_TRUE(negative_zero.IsNegative());
+  const float float_positive_negzero = static_cast<float>(negative_zero);
+  EXPECT_EQ(-0.0f, float_positive_negzero);
+  EXPECT_TRUE(std::signbit(float_positive_negzero));
+
+  EXPECT_TRUE(positive_zero.IsNaNOrZero());
+  EXPECT_TRUE(negative_zero.IsNaNOrZero());
+}
+
+TEST_F(DataTypeTest, BFloat16Comparision) {
+  const BFloat16 left = BFloat16(-33.33f);
+  const BFloat16 left_same = BFloat16(-33.33f);
+  const BFloat16 right = BFloat16(66.66f);
+  const BFloat16 right_same = BFloat16(66.66f);
+
+  EXPECT_TRUE(BFloat16::Epsilon < right);
+
+  EXPECT_EQ(left, left_same);
+  EXPECT_NE(left, left_same.Negate());
+
+  EXPECT_EQ(right, right_same);
+  EXPECT_NE(right, right_same.Negate());
+
+  EXPECT_LT(left, right);
+  EXPECT_LT(right.Negate(), left);
+  EXPECT_LT(left.Negate(), right);
+}
+
+TEST_F(DataTypeTest, BFloat16TestNAN) {
+  const BFloat16 fp16NANFromSingle = std::numeric_limits<float>::quiet_NaN();
+  EXPECT_TRUE(fp16NANFromSingle.IsNaN());
+  EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero());
+  // NaN are not equal to each other
+  EXPECT_NE(BFloat16::NaN, fp16NANFromSingle);
+
+  float NanFromBFloat16 = fp16NANFromSingle.ToFloat();
+  EXPECT_TRUE(std::isnan(NanFromBFloat16));
+
+  EXPECT_FALSE(BFloat16::FromBits(BFloat16::kMaxValueBits).IsNaN());
+}
+
+TEST_F(DataTypeTest, BFloat16NaNComparision) {
+  EXPECT_FALSE(BFloat16::NaN < BFloat16::NaN);
+  EXPECT_FALSE(BFloat16::NaN == BFloat16::NaN);
+
+  EXPECT_FALSE(BFloat16::MaxValue < BFloat16::NaN);
+  EXPECT_FALSE(BFloat16::MaxValue == BFloat16::NaN);
+  EXPECT_FALSE(BFloat16::MinValue < BFloat16::NaN);
+  EXPECT_FALSE(BFloat16::NaN < BFloat16::MaxValue);
+
+  EXPECT_TRUE(BFloat16::MinValue < BFloat16::MaxValue);
+}
+
+TEST_F(DataTypeTest, BFloat16Infinity) {
+  EXPECT_FALSE(BFloat16::MinValue.IsInfinity());
+  EXPECT_FALSE(BFloat16::MaxValue.IsInfinity());
+  EXPECT_TRUE(BFloat16::MaxValue.IsFinite());
+
+  const BFloat16 pos_infinity_from_float = std::numeric_limits<float>::infinity();
+  EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(pos_infinity_from_float.IsFinite());
+  EXPECT_FALSE(pos_infinity_from_float.IsNegative());
+
+  const BFloat16 neg_infinity_from_float = -std::numeric_limits<float>::infinity();
+  EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(neg_infinity_from_float.IsFinite());
+  EXPECT_TRUE(neg_infinity_from_float.IsNegative());
+  EXPECT_TRUE(std::signbit(neg_infinity_from_float.ToFloat()));
+
+  const float pos_infinity_from_bfloat16 = static_cast<float>(BFloat16::Infinity);
+  EXPECT_TRUE(std::isinf(pos_infinity_from_bfloat16));
+  EXPECT_TRUE(!std::signbit(pos_infinity_from_bfloat16));
+}
+
+TEST_F(DataTypeTest, BFloat16NormalSubnormal) {
+  EXPECT_FALSE(BFloat16::Infinity.IsNormal());
+  EXPECT_TRUE(BFloat16(45.6f).IsNormal());
+  EXPECT_FALSE(BFloat16(45.6f).IsSubnormal());
+
+  // 0b0_0000_0000_000_0001
+  constexpr uint16_t min_subnormal_bits = 0x0001;
+  const BFloat16 smallest_subnormal = BFloat16::FromBits(min_subnormal_bits);
+  EXPECT_TRUE(smallest_subnormal.IsSubnormal());
+  EXPECT_FALSE(smallest_subnormal.IsNormal());
+
+  const float float_from_smallest_subnormal = (float)smallest_subnormal;
+  EXPECT_FALSE(std::isnormal(float_from_smallest_subnormal));
+
+  // 0b0_0000_0000_111_1111;
+  constexpr uint16_t max_subnormal_bits = 0x007F;
+  const BFloat16 largest_subnormal = BFloat16::FromBits(max_subnormal_bits);
+  EXPECT_TRUE(largest_subnormal.IsSubnormal());
+  EXPECT_FALSE(largest_subnormal.IsNormal());
+
+  const float float_from_largest_subnormal = (float)largest_subnormal;
+  EXPECT_FALSE(std::isnormal(float_from_largest_subnormal));
+}
+
TEST_F(DataTypeTest, DataUtilsTest) {
  using namespace ONNX_NAMESPACE::Utils;
  // Test simple seq
@@ -697,7 +949,7 @@ TEST(InlinedVectorTests, TestDefaultInlinedCapacity) {
TEST(TypeLiterals, Tests) {
  {
    // uint16_t test
-    MLFloat16 mlfloat{static_cast<uint16_t>(16)};
+    MLFloat16 mlfloat = MLFloat16::FromBits(static_cast<uint16_t>(16));
    auto mlfloat_literal = 16_f16;
    ASSERT_EQ(mlfloat, mlfloat_literal);

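Note: the updated literal test pairs `16_f16` with `FromBits(16)`, which implies the `_f16` user-defined literal is bit-pattern based rather than value based. A sketch of how such a literal can be declared (names illustrative, not the repository's exact definition):

    #include <cstdint>

    struct Half { uint16_t bits; };

    // Forward the integer operand as raw bits; no float rounding happens.
    constexpr Half operator"" _f16(unsigned long long v) {
      return Half{static_cast<uint16_t>(v)};
    }

    static_assert(sizeof(Half) == 2, "fp16 payload is two bytes");

    int main() {
      constexpr Half h = 16_f16;  // bits == 16, a subnormal value, not 16.0
      return h.bits == 16 ? 0 : 1;
    }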
@@ -159,7 +159,7 @@ TEST(Float16_Tests, Mul_16_Test) {
  std::vector<float> values_x_32 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  std::vector<MLFloat16> values_x;
  for (float i : values_x_32) {
-    values_x.push_back(MLFloat16(math::floatToHalf(i)));
+    values_x.push_back(MLFloat16(i));
  }

  // prepare expected inputs and outputs
|
||||||
std::vector<float> expected_values_y_32 = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
|
std::vector<float> expected_values_y_32 = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
|
||||||
std::vector<MLFloat16> expected_values_y;
|
std::vector<MLFloat16> expected_values_y;
|
||||||
for (float i : expected_values_y_32) {
|
for (float i : expected_values_y_32) {
|
||||||
expected_values_y.push_back(MLFloat16(math::floatToHalf(i)));
|
expected_values_y.push_back(MLFloat16(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now run
|
// Now run
|
||||||
|
|
|
@@ -189,7 +189,7 @@ void UnpackTensor(const onnx::TensorProto& tensor, const void* raw_data, size_t
      ORT_CXX_API_THROW(
          "data overflow", OrtErrorCode::ORT_FAIL);
    }
-    p_data[i] = MLFloat16(static_cast<uint16_t>(v));
+    p_data[i] = MLFloat16::FromBits(static_cast<uint16_t>(v));
  }

  return;
@@ -2348,7 +2348,7 @@ TEST_F(GraphTransformationTests, FuseConvBnAddMulFloat16) {
  run_options.run_tag = "one session/one tag";
  OrtValue ml_value_x;

-  auto x_f = MLFloat16(math::floatToHalf(1.0));
+  auto x_f = MLFloat16(1.0f);
  std::vector<int64_t> dims_x = {1, 1, 3, 3};
  std::vector<MLFloat16> values_x;
  for (int i = 0; i < 9; ++i) {
@@ -2364,7 +2364,7 @@ TEST_F(GraphTransformationTests, FuseConvBnAddMulFloat16) {

  ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));

-  auto prod_f = MLFloat16(math::floatToHalf(6.0));
+  auto prod_f = MLFloat16(6.0f);
  std::vector<int64_t> expected_dims_prod = {1, 1, 2, 2};
  std::vector<MLFloat16> expected_values_prod;
  for (int i = 0; i < 4; ++i) {
|
||||||
for (size_t i = 0; i < 12; ++i) {
|
for (size_t i = 0; i < 12; ++i) {
|
||||||
NodeArg* mul_initializer = nullptr;
|
NodeArg* mul_initializer = nullptr;
|
||||||
if (std::is_same<T, MLFloat16>::value) {
|
if (std::is_same<T, MLFloat16>::value) {
|
||||||
mul_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(math::floatToHalf(1.0f)));
|
mul_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(1.0f));
|
||||||
} else if (std::is_same<T, float>::value) {
|
} else if (std::is_same<T, float>::value) {
|
||||||
mul_initializer = builder.MakeScalarInitializer<float>(1.0f);
|
mul_initializer = builder.MakeScalarInitializer<float>(1.0f);
|
||||||
} else {
|
} else {
|
||||||
|
@@ -5593,7 +5593,7 @@ void BuildConstantSharingDivMulGraphFor2DInitializer(ModelTestBuilder& builder)
  values_float16.reserve(values.size());
  if (std::is_same<T, MLFloat16>::value) {
    for (auto v : values) {
-      values_float16.push_back(MLFloat16(math::floatToHalf(v)));
+      values_float16.push_back(MLFloat16(v));
    }
  }

@@ -5804,7 +5804,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatAndHalfTypedInitializ
  builder.AddNode("Cast", {div_out}, {cast_out})
      .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
  for (size_t i = 0; i < 3; ++i) {
-    NodeArg* add_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(math::floatToHalf(1.0f)));
+    NodeArg* add_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(1.0f));
    auto* add_out = builder.MakeOutput();
    builder.AddNode("Add", {cast_out, add_initializer}, {add_out});
  }
|
@ -5930,7 +5930,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatAndHalfTypedInitial
|
||||||
std::vector<MLFloat16> values_float16;
|
std::vector<MLFloat16> values_float16;
|
||||||
values_float16.reserve(values.size());
|
values_float16.reserve(values.size());
|
||||||
for (auto v : values) {
|
for (auto v : values) {
|
||||||
values_float16.push_back(MLFloat16(math::floatToHalf(v)));
|
values_float16.push_back(MLFloat16(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto build_test_case_float = [&values, &values_float16](ModelTestBuilder& builder) {
|
auto build_test_case_float = [&values, &values_float16](ModelTestBuilder& builder) {
|
||||||
|
|
|
@@ -31,7 +31,7 @@ void RandomFillHalfVector(const TensorShapeVector& shape, std::vector<MLFloat16>
  std::vector<float> data_float(TensorShape(shape).Size());
  RandomFillFloatVector(shape, data_float);
  std::transform(data_float.begin(), data_float.end(), data.begin(),
-                 [](float value) { return MLFloat16(math::floatToHalf(value)); });
+                 [](float value) { return MLFloat16(value); });
}

void RandomMasks(int64_t batch, int64_t sequence_length, std::vector<int64_t>& data) {
@@ -332,7 +332,7 @@ struct TensorCheck<MLFloat16> {
          << "i:" << i;
    }
    if (has_rel_err) {
-      EXPECT_NEAR(f_expected[i], f_actual[i], *(params.relative_error) * std::abs(cur_expected[i]))
+      EXPECT_NEAR(f_expected[i], f_actual[i], *(params.relative_error) * std::abs(static_cast<float>(cur_expected[i])))
          << "i:" << i;
    }
  }
@@ -143,7 +143,7 @@ TEST(ConstantOfShape, TypeTests) {
  RunTypedTest(TensorProto::INT8, int8_t(8));
  RunTypedTest(TensorProto::INT16, int16_t(16));
  RunTypedTest(TensorProto::FLOAT, 1.f);
-  RunTypedTest(TensorProto::FLOAT16, MLFloat16(static_cast<uint16_t>(5)));
+  RunTypedTest(TensorProto::FLOAT16, MLFloat16::FromBits(static_cast<uint16_t>(5)));
  RunTypedTest(TensorProto::DOUBLE, 1.0);
  RunTypedTest(TensorProto::INT32, int32_t(32));
  RunTypedTest(TensorProto::INT64, int64_t(64));
@@ -24,9 +24,9 @@ TEST(CumSumTest, _1DTest) {
TEST(CumSumTest, _1DTestFloat16) {
  if (DefaultCudaExecutionProvider().get() != nullptr) {
    OpTester test("CumSum", 14, onnxruntime::kOnnxDomain);
-    test.AddInput<MLFloat16>("x", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+    test.AddInput<MLFloat16>("x", {3}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
    test.AddInput<int32_t>("axis", {}, {0});
-    test.AddOutput<MLFloat16>("y", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(6.0f))});
+    test.AddOutput<MLFloat16>("y", {3}, {MLFloat16(1.0f), MLFloat16(3.0f), MLFloat16(6.0f)});
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kCpuExecutionProvider});
  }
}
@@ -15,7 +15,7 @@ namespace test {
std::vector<MLFloat16> MakeMLFloat16(const std::initializer_list<float>& input) {
  std::vector<MLFloat16> output;
  std::transform(input.begin(), input.end(), std::back_inserter(output),
-                 [](float fl) { return MLFloat16(math::floatToHalf(fl)); });
+                 [](float fl) { return MLFloat16(fl); });
  return output;
}

@@ -2652,8 +2652,8 @@ void TrigFloat16Test(OpTester& test, std::initializer_list<float> input) {
  std::vector<MLFloat16> float16_input;
  std::vector<MLFloat16> float16_output;
  for (auto v : input) {
-    float16_input.push_back(MLFloat16(math::floatToHalf(v)));
-    float16_output.push_back(MLFloat16(math::floatToHalf(op(v))));
+    float16_input.push_back(MLFloat16(v));
+    float16_output.push_back(MLFloat16(op(v)));
  }

  test.AddInput<MLFloat16>("X", dims, float16_input);
@@ -25,8 +25,8 @@ TEST(RoundTest, SimpleTestDouble) {

TEST(RoundTest, SimpleTestFloat16) {
  OpTester test("Round", 11, onnxruntime::kOnnxDomain);
-  test.AddInput<MLFloat16>("x", {5}, {MLFloat16(math::floatToHalf(0.9f)), MLFloat16(math::floatToHalf(2.5f)), MLFloat16(math::floatToHalf(2.3f)), MLFloat16(math::floatToHalf(1.5f)), MLFloat16(math::floatToHalf(-4.5f))});
-  test.AddOutput<MLFloat16>("y", {5}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(-4.0f))});
+  test.AddInput<MLFloat16>("x", {5}, {MLFloat16(0.9f), MLFloat16(2.5f), MLFloat16(2.3f), MLFloat16(1.5f), MLFloat16(-4.5f)});
+  test.AddOutput<MLFloat16>("y", {5}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(-4.0f)});
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

@@ -22,7 +22,7 @@ struct make_type {
template <class A>
struct make_type<MLFloat16, A> {
  static MLFloat16 make(A v) {
-    return MLFloat16(math::floatToHalf(float(v)));
+    return MLFloat16(float(v));
  }
};

@@ -34,7 +34,9 @@ struct make_type<BFloat16, A> {
};

template <class T, class OutputIter>
-typename std::enable_if<!std::numeric_limits<T>::is_signed>::type
+typename std::enable_if<!std::numeric_limits<T>::is_signed &&
+                        !std::is_same<T, MLFloat16>::value &&
+                        !std::is_same<T, BFloat16>::value>::type
GenerateSequence(OutputIter out) {
  for (int i = 0; i < 7; ++i) {
    *out = make_type<T, int>::make(i);
@@ -43,7 +45,9 @@ GenerateSequence(OutputIter out) {
}

template <class T, class OutputIter>
-typename std::enable_if<std::numeric_limits<T>::is_signed>::type
+typename std::enable_if<std::numeric_limits<T>::is_signed ||
+                        std::is_same<T, MLFloat16>::value ||
+                        std::is_same<T, BFloat16>::value>::type
GenerateSequence(OutputIter out) {
  for (int i = -5; i < 2; ++i) {
    *out = make_type<T, int>::make(i);
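Note on the two `GenerateSequence` hunks above: the primary `std::numeric_limits` template, which is what applies to types without their own specialization such as these half types, reports `is_signed == false`, so `MLFloat16`/`BFloat16` previously selected the unsigned sequence. The added `is_same` tests steer them to the signed one. A reduced model of the dispatch:

    #include <iostream>
    #include <limits>
    #include <type_traits>

    struct HalfLike {};  // stands in for MLFloat16/BFloat16: no numeric_limits specialization

    template <class T>
    typename std::enable_if<!std::numeric_limits<T>::is_signed &&
                            !std::is_same<T, HalfLike>::value>::type
    which() { std::cout << "unsigned sequence\n"; }

    template <class T>
    typename std::enable_if<std::numeric_limits<T>::is_signed ||
                            std::is_same<T, HalfLike>::value>::type
    which() { std::cout << "signed sequence\n"; }

    int main() {
      which<unsigned int>();  // unsigned sequence
      which<int>();           // signed sequence
      which<HalfLike>();      // signed sequence, via the is_same escape hatch
    }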
@@ -61,7 +65,7 @@ struct ToTestableType {
template <>
struct ToTestableType<MLFloat16> {
  static float to_type(MLFloat16 v) {
-    return math::halfToFloat(v.val);
+    return v.ToFloat();
  }
};

@@ -186,6 +190,23 @@ TEST(MathOpTest, Sign_MLFloat16) {
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}

+// Currently BFloat16 is not enabled for Sign kernel
+// TEST(MathOpTest, Sign_BFloat16) {
+//  using namespace test_sign_internal;
+//  OpTester test("Sign", 9);
+//
+//  std::vector<int64_t> input_dims{7};
+//  std::vector<BFloat16> input;
+//  GenerateSequence<BFloat16>(std::back_inserter(input));
+//  ASSERT_EQ(input.size(), 7U);
+//  test.AddInput<BFloat16>("input", input_dims, input);
+//
+//  std::vector<BFloat16> output;
+//  TestImpl<BFloat16>(input.cbegin(), input.cend(), std::back_inserter(output));
+//  test.AddOutput<BFloat16>("output", input_dims, output);
+//  test.Run(OpTester::ExpectResult::kExpectSuccess);
+//}
+
#if defined(USE_DNNL)
TEST(MathOpTest, Sign_bfloat16) {
#ifdef USE_DNNL
@@ -90,7 +90,7 @@ void RunShrinkTest(const std::vector<ShrinkTestData<T>>& test_cases,
const std::vector<MLFloat16> ConvertFloatToMLFloat16(const std::vector<float>& float_data) {
  std::vector<MLFloat16> new_data;
  for (const auto& f : float_data) {
-    new_data.push_back(MLFloat16(math::floatToHalf(f)));
+    new_data.push_back(MLFloat16(f));
  }
  return new_data;
}
@@ -124,34 +124,34 @@ TEST(ExpandOpTest, Expand_3x1x3x1_int64) {

TEST(ExpandOpTest, Expand_3x3_float16) {
  OpTester test("Expand", 8);
-  test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))});
+  test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(1.0f)});
  test.AddInput<int64_t>("data_1", {2}, {3, 3});
  test.AddOutput<MLFloat16>("result", {3, 3},
-                            {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
-                             MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
-                             MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f))});
+                            {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
+                             MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
+                             MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)});
  test.Run();
}

TEST(ExpandOpTest, Expand_3x1_float16) {
  OpTester test("Expand", 8);
-  test.AddInput<MLFloat16>("data_0", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+  test.AddInput<MLFloat16>("data_0", {3}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
  test.AddInput<int64_t>("data_1", {2}, {3, 1});
  test.AddOutput<MLFloat16>("result", {3, 3},
-                            {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
-                             MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
-                             MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+                            {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
+                             MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
+                             MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
  test.Run();
}

TEST(ExpandOpTest, Expand_1x3_float16) {
  OpTester test("Expand", 8);
-  test.AddInput<MLFloat16>("data_0", {3, 1}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+  test.AddInput<MLFloat16>("data_0", {3, 1}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
  test.AddInput<int64_t>("data_1", {2}, {1, 3});
  test.AddOutput<MLFloat16>("result", {3, 3},
-                            {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
-                             MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)),
-                             MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))});
+                            {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
+                             MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(2.0f),
+                             MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(3.0f)});
  test.Run();
}

@@ -20,7 +20,7 @@ TEST(IsNaNOpTest, IsNaNFloat) {
TEST(IsNaNOpTest, IsNaNFloat16) {
  OpTester test("IsNaN", 9, kOnnxDomain);
  std::vector<int64_t> dims{2, 2};
-  test.AddInput<MLFloat16>("X", dims, std::initializer_list<MLFloat16>({MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(NAN)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(NAN))}));
+  test.AddInput<MLFloat16>("X", dims, std::initializer_list<MLFloat16>({MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}));
  test.AddOutput<bool>("Y", dims, {false, true, false, true});
  test.Run();
}
@@ -147,14 +147,14 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) {
  std::vector<int64_t> input_shape({2, 3});
  std::vector<MLFloat16> input_vals;
  for (uint16_t i = 0; i < 6; ++i)
-    input_vals.push_back(MLFloat16(i));
+    input_vals.push_back(MLFloat16::FromBits(static_cast<uint16_t>(i)));

  std::vector<int64_t> perm = {1, 0};
  std::vector<int64_t> expected_shape({3, 2});
  std::initializer_list<MLFloat16> expected_vals =
-      {MLFloat16{static_cast<uint16_t>(1)}, MLFloat16{static_cast<uint16_t>(4)},
-       MLFloat16{static_cast<uint16_t>(2)}, MLFloat16{static_cast<uint16_t>(5)},
-       MLFloat16{static_cast<uint16_t>(3)}, MLFloat16{static_cast<uint16_t>(6)}};
+      {MLFloat16::FromBits(static_cast<uint16_t>(1)), MLFloat16::FromBits(static_cast<uint16_t>(4)),
+       MLFloat16::FromBits(static_cast<uint16_t>(2)), MLFloat16::FromBits(static_cast<uint16_t>(5)),
+       MLFloat16::FromBits(static_cast<uint16_t>(3)), MLFloat16::FromBits(static_cast<uint16_t>(6))};

  TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
}
@ -2022,46 +2022,91 @@ TEST(CApiTest, create_tensor_with_data_float16) {
|
||||||
// Example with C++. However, what we are feeding underneath is really
|
// Example with C++. However, what we are feeding underneath is really
|
||||||
// a continuous buffer of uint16_t
|
// a continuous buffer of uint16_t
|
||||||
// Use 3rd party libraries such as Eigen to convert floats and doubles to float16 types.
|
// Use 3rd party libraries such as Eigen to convert floats and doubles to float16 types.
|
||||||
Ort::Float16_t values[] = {15360, 16384, 16896, 17408, 17664}; // 1.f, 2.f, 3.f, 4.f, 5.f
|
constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
|
||||||
constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
|
constexpr size_t values_length = std::size(values);
|
||||||
|
constexpr uint16_t expected_values[values_length] = {15360, 16384, 16896, 17408, 17664};
|
||||||
|
|
||||||
std::vector<int64_t> dims = {static_cast<int64_t>(values_length)};
|
std::vector<Ort::Float16_t> fp16_values;
|
||||||
|
fp16_values.reserve(values_length);
|
||||||
|
std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
|
||||||
|
[](float fl) { return Ort::Float16_t(fl); });
|
||||||
|
|
||||||
|
for (size_t i = 0; i < values_length; ++i) {
|
||||||
|
ASSERT_EQ(expected_values[i], fp16_values[i].val);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int64_t dims = static_cast<int64_t>(values_length);
|
||||||
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||||
|
|
||||||
Ort::Value tensor = Ort::Value::CreateTensor<Ort::Float16_t>(info, values, values_length, dims.data(), dims.size());
|
Ort::Value tensor = Ort::Value::CreateTensor<Ort::Float16_t>(info, fp16_values.data(), values_length, &dims, 1u);
|
||||||
const auto* new_pointer = tensor.GetTensorData<Ort::Float16_t>();
|
const auto* new_pointer = tensor.GetTensorData<Ort::Float16_t>();
|
||||||
ASSERT_EQ(new_pointer, values);
|
|
||||||
|
ASSERT_EQ(new_pointer, fp16_values.data());
|
||||||
auto type_info = tensor.GetTypeInfo();
|
auto type_info = tensor.GetTypeInfo();
|
||||||
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
|
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
|
||||||
|
const auto element_count = tensor_info.GetElementCount();
|
||||||
|
ASSERT_EQ(values_length, element_count);
|
||||||
ASSERT_NE(tensor_info, nullptr);
|
ASSERT_NE(tensor_info, nullptr);
|
||||||
ASSERT_EQ(1u, tensor_info.GetDimensionsCount());
|
ASSERT_EQ(1u, tensor_info.GetDimensionsCount());
|
||||||
ASSERT_EQ(tensor_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
|
ASSERT_EQ(tensor_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
|
||||||
|
|
||||||
Ort::Float16_t value_at_1 = tensor.At<Ort::Float16_t>({1});
|
const Ort::Float16_t& value_at_1 = tensor.At<Ort::Float16_t>({1});
|
||||||
ASSERT_EQ(values[1], value_at_1);
|
ASSERT_EQ(expected_values[1], value_at_1.val);
|
||||||
|
|
||||||
|
std::vector<float> output_values;
|
||||||
|
output_values.reserve(values_length);
|
||||||
|
const auto data_span = gsl::make_span(tensor.GetTensorData<Ort::Float16_t>(), element_count);
|
||||||
|
std::transform(data_span.begin(), data_span.end(), std::back_inserter(output_values),
|
||||||
|
[](const Ort::Float16_t& fp16) { return static_cast<float>(fp16); });
|
||||||
|
|
||||||
|
for (size_t i = 0; i < values_length; ++i) {
|
||||||
|
ASSERT_NEAR(values[i], output_values[i], 1e-3);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
 TEST(CApiTest, create_tensor_with_data_bfloat16) {
   // Example with C++. However, what we are feeding underneath is really
   // a continuous buffer of uint16_t
   // Conversion from float to bfloat16 is simple. Strip off half of the bytes from float.
-  Ort::BFloat16_t values[] = {16256, 16384, 16448, 16512, 16544};  // 1.f, 2.f, 3.f, 4.f, 5.f
-  constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
-  std::vector<int64_t> dims = {static_cast<int64_t>(values_length)};
+  constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
+  constexpr size_t values_length = std::size(values);
+  constexpr uint16_t expected_values[] = {16256, 16384, 16448, 16512, 16544};  // 1.f, 2.f, 3.f, 4.f, 5.f
+
+  constexpr int64_t dims = static_cast<int64_t>(values_length);
+
+  std::vector<Ort::BFloat16_t> b16_values;
+  b16_values.reserve(values_length);
+  std::transform(std::begin(values), std::end(values), std::back_inserter(b16_values),
+                 [](float fl) { return Ort::BFloat16_t(fl); });
+
+  for (size_t i = 0; i < values_length; ++i) {
+    ASSERT_EQ(expected_values[i], b16_values[i].val);
+  }
+
   Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
-  Ort::Value tensor = Ort::Value::CreateTensor<Ort::BFloat16_t>(info, values, values_length, dims.data(), dims.size());
+  Ort::Value tensor = Ort::Value::CreateTensor<Ort::BFloat16_t>(info, b16_values.data(), values_length, &dims, 1u);
   const auto* new_pointer = tensor.GetTensorData<Ort::BFloat16_t>();
-  ASSERT_EQ(new_pointer, values);
+  ASSERT_EQ(new_pointer, b16_values.data());
   auto type_info = tensor.GetTypeInfo();
-  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+  const auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+  const auto element_count = tensor_info.GetElementCount();
   ASSERT_NE(tensor_info, nullptr);
   ASSERT_EQ(1u, tensor_info.GetDimensionsCount());
   ASSERT_EQ(tensor_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16);

-  Ort::BFloat16_t value_at_1 = tensor.At<Ort::BFloat16_t>({1});
-  ASSERT_EQ(values[1], value_at_1);
+  const Ort::BFloat16_t& value_at_1 = tensor.At<Ort::BFloat16_t>({1});
+  ASSERT_EQ(expected_values[1], value_at_1.val);
+
+  std::vector<float> output_values;
+  output_values.reserve(values_length);
+  const auto data_span = gsl::make_span(tensor.GetTensorData<Ort::BFloat16_t>(), element_count);
+  std::transform(data_span.begin(), data_span.end(), std::back_inserter(output_values),
+                 [](const Ort::BFloat16_t& b16) { return static_cast<float>(b16); });
+
+  for (size_t i = 0; i < values_length; ++i) {
+    ASSERT_NEAR(values[i], output_values[i], 1e-3);
+  }
 }

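The comment above treats bfloat16 as the top half of an IEEE-754 float, which is why the expected uint16_t values read straight off the float bit patterns. A minimal sketch of that truncating view (illustration only, with a hypothetical helper name; the actual Ort::BFloat16_t constructor may round to nearest-even rather than truncate):

    #include <cstdint>
    #include <cstring>

    // Truncating float -> bfloat16: keep the sign, exponent, and top 7 mantissa bits.
    inline uint16_t FloatToBFloat16Bits(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));      // bit-copy avoids aliasing UB
      return static_cast<uint16_t>(bits >> 16);  // high half of the float
    }

    // FloatToBFloat16Bits(1.0f) == 16256 (0x3F80), matching expected_values[0] above.

For 1.f through 5.f, truncation and rounding agree, since the discarded low 16 bits are all zero.
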
 #if !defined(DISABLE_FLOAT8_TYPES)

@@ -2336,7 +2381,6 @@ TEST(CApiTest, get_version_string_cpp) {
 TEST(CApiTest, get_build_info_string) {
   auto build_info_string = Ort::GetBuildInfoString();
   ASSERT_FALSE(build_info_string.empty());
-  ASSERT_EQ(build_info_string, std::string(ORT_BUILD_INFO));
 }

 TEST(CApiTest, TestSharedAllocators) {

@@ -307,6 +307,323 @@ TEST(CApiTest, TypeInfoSequence) {
                     ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
 }
+
+TEST(CPPApi, ConvertFloatToFloat16) {
+  // Test data type
+  {
+    constexpr float sample = 1.0f;
+    Ort::Float16_t flt16(sample);
+    EXPECT_FALSE(flt16.IsNaN());
+    auto int_rep = flt16.val;
+    const Ort::Float16_t flt_from_int = Ort::Float16_t::FromBits(int_rep);
+    EXPECT_FALSE(flt_from_int.IsNaN());
+    EXPECT_EQ(flt16, flt_from_int);
+    const double diff = std::fabs(sample - flt_from_int.ToFloat());
+    if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
+      EXPECT_TRUE(false);
+    }
+  }
+  // Test bulk conversion
+  {
+    const float sample[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+    std::vector<Ort::Float16_t> converted;
+    converted.reserve(std::size(sample));
+
+    // Run conversion
+    std::transform(std::begin(sample), std::end(sample), std::back_inserter(converted),
+                   [](float v) { return Ort::Float16_t(v); });
+
+    for (size_t i = 0; i < std::size(sample); ++i) {
+      EXPECT_FALSE(converted[i].IsNaN());
+      const double diff = std::fabs(sample[i] - converted[i].ToFloat());
+      if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
+        EXPECT_TRUE(false);
+      }
+    }
+
+    std::vector<float> back_converted;
+    back_converted.reserve(std::size(sample));
+    std::transform(converted.cbegin(), converted.cend(), std::back_inserter(back_converted),
+                   [](const Ort::Float16_t& bf) { return static_cast<float>(bf); });
+
+    for (size_t i = 0; i < std::size(sample); ++i) {
+      EXPECT_FALSE(std::isnan(back_converted[i]));
+      const double diff = std::fabs(sample[i] - back_converted[i]);
+      if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
+        EXPECT_TRUE(false);
+      }
+    }
+  }
+}
+
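An aside on the FLT_EPSILON checks in this test: 1.0f through 5.0f are exactly representable in fp16 (10 explicit mantissa bits cover all integers up to 2^11), so the round trip is exact and the diff is zero. A tiny sketch of where fp16 integer precision ends, using the Ort::Float16_t API this PR introduces (include path and a round-to-nearest-even conversion are assumed; the function name is illustrative):

    #include <onnxruntime_cxx_api.h>  // adjust to core/session/... in the ORT tree
    #include <cassert>

    void Float16IntegerPrecisionDemo() {
      assert(Ort::Float16_t(2048.0f).ToFloat() == 2048.0f);  // 2^11: still exact
      assert(Ort::Float16_t(2049.0f).ToFloat() == 2048.0f);  // halfway case, ties-to-even rounds down
    }
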
+TEST(CPPApi, Float16Zeros) {
+  const auto positive_zero = Ort::Float16_t::FromBits(0x0000);
+  EXPECT_FALSE(positive_zero.IsNegative());
+  const float float_positive_zero = static_cast<float>(positive_zero);
+  EXPECT_EQ(+0.0f, float_positive_zero);
+  EXPECT_FALSE(std::signbit(float_positive_zero));
+
+  const auto negative_zero = Ort::Float16_t::FromBits(0x8000);
+  EXPECT_TRUE(negative_zero.IsNegative());
+  const float float_positive_negzero = static_cast<float>(negative_zero);
+  EXPECT_EQ(-0.0f, float_positive_negzero);
+  EXPECT_TRUE(std::signbit(float_positive_negzero));
+}
+
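One subtlety the zeros test leans on: IEEE-754 defines +0 and -0 as equal, so EXPECT_EQ on the widened floats passes for both; only the sign bit (std::signbit, or IsNegative on the half types) tells them apart. In plain C++ (demo name is illustrative):

    #include <cassert>
    #include <cmath>

    void SignedZeroDemo() {
      const float pz = +0.0f, nz = -0.0f;
      assert(pz == nz);                               // comparison cannot distinguish them
      assert(!std::signbit(pz) && std::signbit(nz));  // the sign bit can
    }
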
+namespace {
+const auto EpsilonFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kEpsilonBits);
+const auto NaNFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kPositiveQNaNBits);
+const auto MinValueFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kMinValueBits);
+const auto MaxValueFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kMaxValueBits);
+const auto InfinityFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kPositiveInfinityBits);
+}  // namespace
+
+TEST(CPPApi, Float16Comparision) {
+  const auto left = Ort::Float16_t(-33.33f);
+  const auto left_same = Ort::Float16_t(-33.33f);
+  const auto right = Ort::Float16_t(66.66f);
+  const auto right_same = Ort::Float16_t(66.66f);
+
+  EXPECT_LT(EpsilonFl16, right);
+
+  EXPECT_EQ(left, left_same);
+  EXPECT_NE(left, left_same.Negate());
+
+  EXPECT_EQ(right, right_same);
+  EXPECT_NE(right, right_same.Negate());
+
+  EXPECT_LT(left, right);
+  EXPECT_LT(right.Negate(), left);
+  EXPECT_LT(left.Negate(), right);
+}
+
+TEST(CPPApi, Float16TestNAN) {
+  const Ort::Float16_t fp16NANFromSingle(std::numeric_limits<float>::quiet_NaN());
+  EXPECT_TRUE(fp16NANFromSingle.IsNaN());
+
+  // NaNs are not equal to each other
+  EXPECT_NE(NaNFl16, fp16NANFromSingle);
+
+  const float NanFromBFloat16 = fp16NANFromSingle.ToFloat();
+  EXPECT_TRUE(std::isnan(NanFromBFloat16));
+
+  EXPECT_FALSE(MaxValueFl16.IsNaN());
+}
+
+TEST(CPPApi, Float16NaNComparision) {
+  EXPECT_FALSE(NaNFl16 < NaNFl16);
+  EXPECT_TRUE(NaNFl16 != NaNFl16);
+  EXPECT_FALSE(NaNFl16 == NaNFl16);
+
+  EXPECT_FALSE(MaxValueFl16 < NaNFl16);
+  EXPECT_FALSE(MaxValueFl16 == NaNFl16);
+  EXPECT_FALSE(NaNFl16 < MinValueFl16);
+
+  EXPECT_LT(MinValueFl16, MaxValueFl16);
+}
+
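The expectations above are the IEEE-754 unordered-comparison rules carried over to the half types: every ordered comparison involving a NaN is false, and only != holds. Plain float behaves identically, as a sanity sketch (demo name is illustrative):

    #include <cassert>
    #include <limits>

    void NanCompareDemo() {
      const float nan = std::numeric_limits<float>::quiet_NaN();
      assert(!(nan == nan) && (nan != nan));   // NaN equals nothing, itself included
      assert(!(nan < 1.0f) && !(1.0f < nan));  // ordered comparisons are all false
    }
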
+TEST(CPPApi, Float16Infinity) {
+  EXPECT_FALSE(MinValueFl16.IsInfinity());
+  EXPECT_FALSE(MaxValueFl16.IsInfinity());
+  EXPECT_TRUE(MaxValueFl16.IsFinite());
+
+  const Ort::Float16_t pos_infinity_from_float(std::numeric_limits<float>::infinity());
+  EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(pos_infinity_from_float.IsFinite());
+  EXPECT_FALSE(pos_infinity_from_float.IsNegative());
+
+  const Ort::Float16_t neg_infinity_from_float(-std::numeric_limits<float>::infinity());
+  EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(neg_infinity_from_float.IsFinite());
+  EXPECT_TRUE(neg_infinity_from_float.IsNegative());
+
+  const float pos_infinity_from_bfloat16 = static_cast<float>(InfinityFl16);
+  EXPECT_TRUE(std::isinf(pos_infinity_from_bfloat16));
+}
+
+TEST(CPPApi, Float16NormalSubnormal) {
+  EXPECT_FALSE(InfinityFl16.IsNormal());
+  EXPECT_TRUE(Ort::Float16_t(45.6f).IsNormal());
+  EXPECT_FALSE(Ort::Float16_t(45.6f).IsSubnormal());
+
+  // 0b0_00000_0000000001, ~0.000000059604645 (1 * 2^-24)
+  constexpr uint16_t min_subnormal_bits = 0x0001;
+  const Ort::Float16_t smallest_subnormal = Ort::Float16_t::FromBits(min_subnormal_bits);
+  EXPECT_TRUE(smallest_subnormal.IsSubnormal());
+  EXPECT_FALSE(smallest_subnormal.IsNormal());
+
+  // float's smallest positive subnormal is ~1.40129846432481707092E-45;
+  // the fp16 subnormal above is far larger, so as a float it is normal
+  const float float_from_smallest_subnormal = static_cast<float>(smallest_subnormal);
+  EXPECT_TRUE(std::isnormal(float_from_smallest_subnormal));
+
+  // 0b0_00000_0001111111, ~0.0000075698 (127 * 2^-24)
+  constexpr uint16_t max_subnormal_bits = 0x007F;
+  const Ort::Float16_t largest_subnormal = Ort::Float16_t::FromBits(max_subnormal_bits);
+  EXPECT_TRUE(largest_subnormal.IsSubnormal());
+  EXPECT_FALSE(largest_subnormal.IsNormal());
+
+  // However, in float the same number is normal
+  const float float_from_largest_subnormal = static_cast<float>(largest_subnormal);
+  EXPECT_TRUE(std::isnormal(float_from_largest_subnormal));
+}
+
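For reference, an fp16 subnormal (exponent field all zeros) decodes as mantissa x 2^-24: there is no implicit leading 1, the exponent is pinned at 2^-14, and the 10 mantissa bits contribute another factor of 2^-10. That is where 0x0001 -> ~5.9604645e-8 comes from. A minimal decoder sketch assuming the standard IEEE-754 binary16 layout (helper name is illustrative):

    #include <cstdint>

    // Value of an fp16 subnormal: mantissa * 2^-10 * 2^-14 = mantissa * 2^-24.
    inline float Fp16SubnormalValue(uint16_t bits) {
      const uint16_t mantissa = bits & 0x03FF;  // exponent field assumed zero
      return mantissa * 0x1p-24f;               // 0x0001 -> ~5.9604645e-8
    }
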
+TEST(CPPApi, BFloat16ConvertFloatToBFloat16) {
+  // Test data type
+  {
+    constexpr float sample = 1.0f;
+    Ort::BFloat16_t flt16(sample);
+    EXPECT_FALSE(flt16.IsNaN());
+    auto int_rep = flt16.val;
+    const Ort::BFloat16_t flt_from_int = Ort::BFloat16_t::FromBits(int_rep);
+    EXPECT_FALSE(flt_from_int.IsNaN());
+    EXPECT_EQ(flt16, flt_from_int);
+    const double diff = std::fabs(sample - flt_from_int.ToFloat());
+    if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
+      EXPECT_TRUE(false);
+    }
+  }
+  // Test bulk conversion
+  {
+    const float sample[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+    std::vector<Ort::BFloat16_t> converted;
+    converted.reserve(std::size(sample));
+
+    // Run conversion
+    std::transform(std::begin(sample), std::end(sample), std::back_inserter(converted),
+                   [](float v) { return Ort::BFloat16_t(v); });
+
+    for (size_t i = 0; i < std::size(sample); ++i) {
+      EXPECT_FALSE(converted[i].IsNaN());
+      const double diff = std::fabs(sample[i] - converted[i].ToFloat());
+      if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
+        EXPECT_TRUE(false);
+      }
+    }
+
+    std::vector<float> back_converted;
+    back_converted.reserve(std::size(sample));
+    std::transform(converted.cbegin(), converted.cend(), std::back_inserter(back_converted),
+                   [](const Ort::BFloat16_t& bf) { return static_cast<float>(bf); });
+
+    for (size_t i = 0; i < std::size(sample); ++i) {
+      EXPECT_FALSE(std::isnan(back_converted[i]));
+      const double diff = std::fabs(sample[i] - back_converted[i]);
+      if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
+        EXPECT_TRUE(false);
+      }
+    }
+  }
+}
+
+TEST(CPPApi, BFloat16Zeros) {
+  const auto positive_zero = Ort::BFloat16_t::FromBits(0x0000);
+  EXPECT_FALSE(positive_zero.IsNegative());
+  const float float_positive_zero = static_cast<float>(positive_zero);
+  EXPECT_EQ(+0.0f, float_positive_zero);
+  EXPECT_FALSE(std::signbit(float_positive_zero));
+
+  const auto negative_zero = Ort::BFloat16_t::FromBits(0x8000);
+  EXPECT_TRUE(negative_zero.IsNegative());
+  const float float_positive_negzero = static_cast<float>(negative_zero);
+  EXPECT_EQ(-0.0f, float_positive_negzero);
+  EXPECT_TRUE(std::signbit(float_positive_negzero));
+}
+
+namespace {
+const auto EpsilonBfl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kEpsilonBits);
+const auto NaNBfl15 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kPositiveQNaNBits);
+const auto MinValueBfl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kMinValueBits);
+const auto MaxValueBfl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kMaxValueBits);
+const auto InfinityBFl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kPositiveInfinityBits);
+}  // namespace
+
+TEST(CPPApi, BFloat16Comparision) {
+  const auto left = Ort::BFloat16_t(-33.33f);
+  const auto left_same = Ort::BFloat16_t(-33.33f);
+  const auto right = Ort::BFloat16_t(66.66f);
+  const auto right_same = Ort::BFloat16_t(66.66f);
+
+  EXPECT_LT(EpsilonBfl16, right);
+
+  EXPECT_EQ(left, left_same);
+  EXPECT_NE(left, left_same.Negate());
+
+  EXPECT_EQ(right, right_same);
+  EXPECT_NE(right, right_same.Negate());
+
+  EXPECT_LT(left, right);
+  EXPECT_LT(right.Negate(), left);
+  EXPECT_LT(left.Negate(), right);
+}
+
+TEST(CPPApi, BFloat16TestNAN) {
+  const Ort::BFloat16_t fp16NANFromSingle(std::numeric_limits<float>::quiet_NaN());
+  EXPECT_TRUE(fp16NANFromSingle.IsNaN());
+
+  // NaNs are not equal to each other
+  EXPECT_NE(NaNBfl15, fp16NANFromSingle);
+
+  const float NanFromBFloat16 = fp16NANFromSingle.ToFloat();
+  EXPECT_TRUE(std::isnan(NanFromBFloat16));
+
+  EXPECT_FALSE(MaxValueBfl16.IsNaN());
+}
+
+TEST(CPPApi, BFloat16NaNComparision) {
+  EXPECT_FALSE(NaNBfl15 < NaNBfl15);
+  EXPECT_TRUE(NaNBfl15 != NaNBfl15);
+  EXPECT_FALSE(NaNBfl15 == NaNBfl15);
+
+  EXPECT_FALSE(MaxValueBfl16 < NaNBfl15);
+  EXPECT_FALSE(MaxValueBfl16 == NaNBfl15);
+  EXPECT_FALSE(NaNBfl15 < MinValueBfl16);
+
+  EXPECT_LT(MinValueBfl16, MaxValueBfl16);
+}
+
+TEST(CPPApi, BFloat16Infinity) {
+  EXPECT_FALSE(MinValueBfl16.IsInfinity());
+  EXPECT_FALSE(MaxValueBfl16.IsInfinity());
+  EXPECT_TRUE(MaxValueBfl16.IsFinite());
+
+  const Ort::BFloat16_t pos_infinity_from_float(std::numeric_limits<float>::infinity());
+  EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(pos_infinity_from_float.IsFinite());
+  EXPECT_FALSE(pos_infinity_from_float.IsNegative());
+
+  const Ort::BFloat16_t neg_infinity_from_float(-std::numeric_limits<float>::infinity());
+  EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
+  EXPECT_FALSE(neg_infinity_from_float.IsFinite());
+  EXPECT_TRUE(neg_infinity_from_float.IsNegative());
+
+  const float pos_infinity_from_bfloat16 = static_cast<float>(InfinityBFl16);
+  EXPECT_TRUE(std::isinf(pos_infinity_from_bfloat16));
+}
+
+TEST(CPPApi, BFloat16NormalSubnormal) {
+  EXPECT_FALSE(InfinityBFl16.IsNormal());
+  EXPECT_TRUE(Ort::BFloat16_t(45.6f).IsNormal());
+  EXPECT_FALSE(Ort::BFloat16_t(45.6f).IsSubnormal());
+
+  // 0b0_0000_0000_000_0001, the smallest bfloat16 subnormal (~9.18e-41)
+  constexpr uint16_t min_subnormal_bits = 0x0001;
+  const Ort::BFloat16_t smallest_subnormal = Ort::BFloat16_t::FromBits(min_subnormal_bits);
+  EXPECT_TRUE(smallest_subnormal.IsSubnormal());
+  EXPECT_FALSE(smallest_subnormal.IsNormal());
+
+  const float float_from_smallest_subnormal = static_cast<float>(smallest_subnormal);
+  EXPECT_FALSE(std::isnormal(float_from_smallest_subnormal));
+
+  // 0b0_0000_0000_111_1111, the largest bfloat16 subnormal (~1.166e-38)
+  constexpr uint16_t max_subnormal_bits = 0x007F;
+  const Ort::BFloat16_t largest_subnormal = Ort::BFloat16_t::FromBits(max_subnormal_bits);
+  EXPECT_TRUE(largest_subnormal.IsSubnormal());
+  EXPECT_FALSE(largest_subnormal.IsNormal());
+
+  const float float_from_largest_subnormal = static_cast<float>(largest_subnormal);
+  EXPECT_FALSE(std::isnormal(float_from_largest_subnormal));
+}
+
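Note the contrast with the Float16 version of this test: here std::isnormal is expected to be false after widening. The reason is that bfloat16 shares float's 8-bit exponent, so its subnormals decode as mantissa x 2^-133, below float's smallest normal value 2^-126; widening cannot lift them into the normal range. The analogous decoder sketch (helper name is illustrative):

    #include <cmath>
    #include <cstdint>

    // Value of a bfloat16 subnormal: mantissa * 2^-7 * 2^-126 = mantissa * 2^-133.
    // Even the largest one (127 * 2^-133 ~ 1.166e-38) stays below FLT_MIN
    // (~1.175e-38), so every bfloat16 subnormal is still subnormal as a float.
    inline float Bf16SubnormalValue(uint16_t bits) {
      const uint16_t mantissa = bits & 0x007F;                 // exponent field assumed zero
      return std::ldexp(static_cast<float>(mantissa), -133);  // mantissa * 2^-133
    }
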
 #if !defined(DISABLE_SPARSE_TENSORS)
 TEST(CApiTest, SparseTensorUsingAPI) {
   Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);

@@ -495,7 +495,7 @@ float ParseValueToFloat(MLFloat16 data_value) {
 template <>
 float ParseValueToFloat(float data_value) {
   // Convert float to half and then convert back to float to simulate rounding to half
-  return ParseValueToFloat(MLFloat16(math::floatToHalf(data_value)));
+  return ParseValueToFloat(MLFloat16(data_value));
 }

 template <typename RealT, typename ExpectT>

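The hunks below are the mechanical consequence of this PR's API change: MLFloat16 now converts from float directly, so the old two-step math::floatToHalf(value) (which produced raw uint16_t bits for a bits-taking constructor) collapses to MLFloat16(value). The rounding idiom the comment above describes looks like this against the new public helper (sketch only; include path assumed, function name illustrative):

    #include <onnxruntime_cxx_api.h>  // adjust to core/session/... in the ORT tree

    // Round a float to the nearest fp16-representable value and widen back --
    // handy for computing test expectations that must match half-precision math.
    inline float RoundTripThroughHalf(float v) {
      return Ort::Float16_t(v).ToFloat();
    }

    // RoundTripThroughHalf(0.1f) != 0.1f, since 0.1 has no exact fp16 encoding.
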
@@ -392,7 +392,7 @@ AttributeProto GradientBuilderBase::AttributeDefinitionToAttributeProto(
              elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16);
   float float_value = value.get<float>();
   if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-    tensor_proto = ScalarTensorProto(MLFloat16(math::floatToHalf(float_value)), {1});
+    tensor_proto = ScalarTensorProto(MLFloat16(float_value), {1});
   } else if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
     tensor_proto = ScalarTensorProto(BFloat16(float_value), {1});
   } else {

@@ -262,7 +262,7 @@ class GradientBuilderBase {
   // We only support FP32, FP16 and BF16 for these constant nodes for now.
   static NodeDef ConstantScalarNode(float value, const std::string& arg_name, int elem_type) {
     if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-      return ConstantScalarNode(MLFloat16(math::floatToHalf(value)), {1}, arg_name);
+      return ConstantScalarNode(MLFloat16(value), {1}, arg_name);
     }

     if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {

@@ -294,7 +294,7 @@ class GradientBuilderBase {
   static ONNX_NAMESPACE::TensorProto ScalarTensorProtoByElemType(float value, int elem_type) {
     if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-      return ScalarTensorProto(MLFloat16(math::floatToHalf(value)), {1});
+      return ScalarTensorProto(MLFloat16(value), {1});
     }

     if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {

@@ -97,7 +97,7 @@ Status AdamOptimizerBuilder::Build(
       ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, element_type, weight_dims));
       moment_tensor_proto = utils::TensorToTensorProto(init_tensor, gradient_moment_name);
     } else if (opt_configs[i].use_mixed_precision_moments) {
-      moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(math::floatToHalf(0.f)), weight_dims);
+      moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(0.f), weight_dims);
     } else {
       moment_tensor_proto = CreateTensorProto<float>(gradient_moment_name, 0.f, weight_dims);
     }

@@ -231,7 +231,7 @@ Status LambOptimizerBuilder::Build(
       ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, element_type, weight_dims));
       moment_tensor_proto = utils::TensorToTensorProto(init_tensor, gradient_moment_name);
     } else if (opt_configs[i].use_mixed_precision_moments) {
-      moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(math::floatToHalf(0.f)), weight_dims);
+      moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(0.f), weight_dims);
     } else {
       moment_tensor_proto = CreateTensorProto<float>(gradient_moment_name, 0.f, weight_dims);
     }

@@ -35,7 +35,7 @@ static void RunTrainingSessionLoadOptimTests(std::string optim_name, bool mixed_
   TrainingSession::OptimizerState init_optimizer_state{};
   if (mixed_precision_moments) {
-    GenerateOptimizerInitialState<MLFloat16>(optim_name, MLFloat16(math::floatToHalf(2.5)), init_optimizer_state);
+    GenerateOptimizerInitialState<MLFloat16>(optim_name, MLFloat16(2.5f), init_optimizer_state);
   } else {
     GenerateOptimizerInitialState<float>(optim_name, 2.5f, init_optimizer_state);
   }

@@ -54,7 +54,7 @@ void RunDropoutTest(const bool use_mask, const std::vector<int64_t>& input_shape
     ratio = 0.5f;
   } else {
     if (use_float16_ratio) {
-      t.AddInput("ratio", {}, {MLFloat16(math::floatToHalf(ratio))});
+      t.AddInput("ratio", {}, {MLFloat16(ratio)});
     } else {
       t.AddInput("ratio", {}, {ratio});
     }