[C#, CPP] Introduce Float16/BFloat16 support and tests for C#, C++ (#16506)

### Description
Introduce `Float16/BFloat16` support for the C# and C++ APIs.
Users can convert `float` to/from `Float16/BFloat16`, compare values, and
test for `NaN`, `Infinity`, and whether a number is denormalized (subnormal).
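
For illustration, a minimal sketch of the intended C++ usage, based on the `Ort::Float16_t` additions in this PR (a sketch only; the exact include path may vary by install layout):

    // Exercises the conversion, comparison, and classification helpers
    // added by this PR via Ort::Float16_t (onnxruntime_cxx_api.h).
    #include <cassert>
    #include <cmath>
    #include "onnxruntime_cxx_api.h"

    int main() {
      Ort::Float16_t half{3.14f};               // float -> fp16 (explicit ctor)
      float back = half.ToFloat();              // fp16 -> float
      assert(std::fabs(back - 3.14f) < 1e-2f);  // fp16 keeps ~3 decimal digits

      assert(Ort::Float16_t(1.0f) < Ort::Float16_t(2.0f));     // ordered compare
      assert(Ort::Float16_t(std::nanf("")).IsNaN());           // NaN check
      assert(Ort::Float16_t::FromBits(0x0001).IsSubnormal());  // denormal check
      return 0;
    }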

### Motivation and Context
Users have filed issues requesting this, e.g.
https://github.com/microsoft/onnxruntime/issues/14303
Dmitri Smirnov 2023-07-14 10:46:52 -07:00, committed by GitHub
Parent 77b45c6503
Commit 853c4ff0a5
55 changed files: 3852 additions and 492 deletions


@@ -21,6 +21,7 @@ endif()
 set(ONNXRUNTIME_PUBLIC_HEADERS
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h"
+    "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h"
@@ -38,6 +39,7 @@ macro(get_mobile_api_headers _HEADERS)
   set(${_HEADERS}
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h"
+    "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h"
     "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h"
   )


@@ -107,6 +107,15 @@ if(NOT onnxruntime_DISABLE_ABSEIL)
   endif()
 endif()
+if (MSVC)
+  set(EIGEN_NATVIS_FILE ${eigen_SOURCE_DIR}/debug/msvc/eigen.natvis)
+  if (EXISTS ${EIGEN_NATVIS_FILE})
+    target_sources(
+      onnxruntime_common
+      INTERFACE $<BUILD_INTERFACE:${EIGEN_NATVIS_FILE}>)
+  endif()
+endif()
 onnxruntime_add_include_to_target(onnxruntime_common date_interface WIL::WIL)
 target_include_directories(onnxruntime_common
   PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS}


@@ -56,6 +56,13 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_framework_srcs})
 onnxruntime_add_static_library(onnxruntime_framework ${onnxruntime_framework_srcs})
+if (MSVC)
+  set(ORT_FRAMEWORK_NATVIS_FILE "onnxruntime_framework.natvis")
+  target_sources(
+    onnxruntime_framework
+    INTERFACE $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/${ORT_FRAMEWORK_NATVIS_FILE}>)
+endif()
 if(onnxruntime_ENABLE_INSTRUMENT)
   target_compile_definitions(onnxruntime_framework PRIVATE ONNXRUNTIME_ENABLE_INSTRUMENT)
 endif()


@@ -0,0 +1,47 @@
<?xml version="1.0" encoding="utf-8"?>
<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
<Type Name="onnxruntime::MLFloat16">
<Intrinsic Name="_negative" Expression="(val &amp; 0x8000) == 1"/>
<Intrinsic Name="_strip_sign" Expression="(val &amp; ~0x8000)"/>
<Intrinsic Name="_is_nan" Expression="(_strip_sign() &gt; 0x7C00)"/>
<Intrinsic Name="_is_finite" Expression="(_strip_sign() &lt; 0x7C00)"/>
<Intrinsic Name="_is_normal" Expression="(_is_finite() &amp;&amp; (val != 0)) &amp;&amp; ((val &amp; 0x7C00) != 0)"/>
<Intrinsic Name="_biased_exponent" Expression="(val &gt;&gt; 10) &amp; (0x7C00 &gt;&gt; 10)"/>
<Intrinsic Name="_exponent" Expression="(int16_t)(_biased_exponent() - 15)"/>
<Intrinsic Name="_significand" Expression="(val &amp; 0x03FF)"/>
<DisplayString>{{val={ val }}}</DisplayString>
<Expand>
<Item Name="[Negative]" ExcludeView="simple">_negative()</Item>
<Item Name="[IsNan]" ExcludeView="simple" Condition="_is_nan()">true</Item>
<Item Name="[IsFinite]" ExcludeView="simple">_is_finite()</Item>
<Item Name="[IsNormal]" ExcludeView="simple">_is_normal()</Item>
<Item Name="[uint16_t]" ExcludeView="simple">val</Item>
<Item Name="[Exponent]" ExcludeView="simple">_exponent()</Item>
<Item Name="[Biased Exponent]" ExcludeView="simple">_biased_exponent()</Item>
<Item Name="[Significand]" ExcludeView="simple">_significand()</Item>
</Expand>
</Type>
<Type Name="onnxruntime::BFloat16">
<Intrinsic Name="_negative" Expression="(val &amp; 0x8000) == 1"/>
<Intrinsic Name="_strip_sign" Expression="(val &amp; ~0x8000)"/>
<Intrinsic Name="_is_nan" Expression="(_strip_sign() &gt; 0x7F80)"/>
<Intrinsic Name="_is_finite" Expression="(_strip_sign() &lt; 0x7F80)"/>
<Intrinsic Name="_is_normal" Expression="(_is_finite() &amp;&amp; (val != 0)) &amp;&amp; ((val &amp; 0x7F80) != 0)"/>
<Intrinsic Name="_biased_exponent" Expression="(val &gt;&gt; 7) &amp; (0x7F80 &gt;&gt; 7)"/>
<Intrinsic Name="_exponent" Expression="(int16_t)(_biased_exponent() - 127)"/>
<Intrinsic Name="_significand" Expression="(val &amp; 0x007F)"/>
<DisplayString>{{val={ val }}}</DisplayString>
<Expand>
<Item Name="[Negative]" ExcludeView="simple">_negative()</Item>
<Item Name="[IsNormal]" ExcludeView="simple">_is_normal()</Item>
<Item Name="[IsNan]" ExcludeView="simple" Condition="_is_nan()">true</Item>
<Item Name="[IsFinite]" ExcludeView="simple">_is_finite()</Item>
<Item Name="[uint16_t]" ExcludeView="simple">val</Item>
<Item Name="[Exponent]" ExcludeView="simple">_exponent()</Item>
<Item Name="[Biased Exponent]" ExcludeView="simple">_biased_exponent()</Item>
<Item Name="[Significand]" ExcludeView="simple">_significand()</Item>
</Expand>
</Type>
</AutoVisualizer>
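
The intrinsics above simply decode the IEEE binary16 / bfloat16 bit layouts (1 sign bit, then 5 or 8 exponent bits, then 10 or 7 significand bits). A standalone sketch of the same decoding, handy for sanity-checking the masks (hypothetical helper, not part of this PR):

    #include <cstdint>
    #include <cstdio>

    // Decodes IEEE binary16 fields the same way the MLFloat16 natvis intrinsics do.
    void DecodeFp16(uint16_t val) {
      bool negative = (val & 0x8000) != 0;       // sign bit
      uint16_t stripped = val & ~0x8000;         // magnitude bits only
      bool is_nan = stripped > 0x7C00;           // exponent all ones, payload != 0
      bool is_finite = stripped < 0x7C00;
      int biased_exponent = (val >> 10) & 0x1F;  // 5 exponent bits
      int exponent = biased_exponent - 15;       // bias is 15 (127 for bfloat16)
      uint16_t significand = val & 0x03FF;       // 10 fraction bits (7 for bfloat16)
      std::printf("neg=%d nan=%d finite=%d exp=%d frac=0x%03X\n",
                  negative, is_nan, is_finite, exponent, significand);
    }

    int main() {
      DecodeFp16(0x3C00);  // 1.0
      DecodeFp16(0xFC00);  // -infinity
      DecodeFp16(0x7E00);  // quiet NaN
    }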

One file's diff is not shown because of its large size.


@@ -44,174 +44,6 @@ namespace Microsoft.ML.OnnxRuntime.Tensors
         DataTypeMax = 17
     }
-    /// <summary>
-    /// This value type represents A Float16 value
-    /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
-    /// and as such, represented the same way in managed and native memories. This means that arrays of this type
-    /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
-    /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
-    /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
-    /// </summary>
-    public struct Float16
-    {
-        /// <summary>
-        /// float16 representation bits
-        /// </summary>
-        public ushort value;
-        /// <summary>
-        /// Ctor
-        /// </summary>
-        /// <param name="v"></param>
-        public Float16(ushort v)
-        {
-            value = v;
-        }
-        /// <summary>
-        /// Converts to ushort
-        /// </summary>
-        /// <param name="f">instance of Float16</param>
-        /// <returns>value member</returns>
-        public static implicit operator ushort(Float16 f) { return f.value; }
-        /// <summary>
-        /// Converts a 16-bit unsigned integer to a Float16.
-        /// </summary>
-        /// <param name="value">A 16-bit unsigned integer.</param>
-        /// <returns>A Float16 that represents the converted 16-bit unsigned integer.</returns>
-        public static implicit operator Float16(ushort value) { return new Float16(value); }
-        /// <summary>
-        /// Compares values of two Float16 for binary equality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator ==(Float16 lhs, Float16 rhs) { return lhs.value == rhs.value; }
-        /// <summary>
-        /// Compares values of two Float16 for binary inequality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator !=(Float16 lhs, Float16 rhs) { return lhs.value != rhs.value; }
-        /// <summary>
-        /// Returns a value indicating whether this instance and other Float16 represent the same value.
-        /// </summary>
-        /// <param name="other">A Float16 object to compare to this instance.</param>
-        /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
-        public bool Equals(Float16 other)
-        {
-            return (other == this);
-        }
-        /// <summary>
-        /// Returns a value indicating whether this instance and a specified System.Object
-        /// represent the same type and value.
-        /// </summary>
-        /// <param name="obj">An System.Object.</param>
-        /// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
-        public override bool Equals(object obj)
-        {
-            bool result = false;
-            if (obj is Float16)
-            {
-                Float16 fl16 = (Float16)obj;
-                result = (fl16 == this);
-            }
-            return result;
-        }
-        /// <summary>
-        /// Returns the hash code for this instance.
-        /// </summary>
-        /// <returns>A 32-bit signed integer hash code.</returns>
-        public override int GetHashCode()
-        {
-            return value.GetHashCode();
-        }
-    }
-    /// <summary>
-    /// This value type represents A BFloat16 value
-    /// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
-    /// and as such, represented the same way in managed and native memories. This means that arrays of this type
-    /// do not have to be copied to be passed to native memory but simply pinnned and read by native code. Thus,
-    /// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
-    /// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
-    /// </summary>
-    public struct BFloat16
-    {
-        /// <summary>
-        /// bfloat16 representation bits
-        /// </summary>
-        public ushort value;
-        /// <summary>
-        /// Ctor
-        /// </summary>
-        /// <param name="v"></param>
-        public BFloat16(ushort v)
-        {
-            value = v;
-        }
-        /// <summary>
-        /// Converts to ushort
-        /// </summary>
-        /// <param name="bf">instance of BFloat16</param>
-        /// <returns>value member</returns>
-        public static implicit operator ushort(BFloat16 bf) { return bf.value; }
-        /// <summary>
-        /// Converts a 16-bit unsigned integer to a BFloat16.
-        /// </summary>
-        /// <param name="value">A 16-bit unsigned integer.</param>
-        /// <returns>A BFloat16 that represents the converted 16-bit unsigned integer.</returns>
-        public static implicit operator BFloat16(ushort value) { return new BFloat16(value); }
-        /// <summary>
-        /// Compares values of two BFloat16 for binary equality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator ==(BFloat16 lhs, BFloat16 rhs) { return lhs.value == rhs.value; }
-        /// <summary>
-        /// Compares values of two BFloat16 for binary inequality
-        /// </summary>
-        /// <param name="lhs"></param>
-        /// <param name="rhs"></param>
-        /// <returns>result of value comparisons</returns>
-        public static bool operator !=(BFloat16 lhs, BFloat16 rhs) { return lhs.value != rhs.value; }
-        /// <summary>
-        /// Returns a value indicating whether this instance and other BFloat16 represent the same value.
-        /// </summary>
-        /// <param name="other">A BFloat16 object to compare to this instance.</param>
-        /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
-        public bool Equals(BFloat16 other)
-        {
-            return (other == this);
-        }
-        /// <summary>
-        /// Returns a value indicating whether this instance and a specified System.Object
-        /// represent the same type and value.
-        /// </summary>
-        /// <param name="obj">An System.Object.</param>
-        /// <returns>true if obj is BFloat16 its value is equal to this instance; otherwise, false.</returns>
-        public override bool Equals(object obj)
-        {
-            bool result = false;
-            if (obj is BFloat16)
-            {
-                BFloat16 bfl16 = (BFloat16)obj;
-                result = (bfl16 == this);
-            }
-            return result;
-        }
-        /// <summary>
-        /// Returns the hash code for this instance.
-        /// </summary>
-        /// <returns>A 32-bit signed integer hash code.</returns>
-        public override int GetHashCode()
-        {
-            return value.GetHashCode();
-        }
-    }
     /// <summary>
     /// Helps typecasting. Holds Tensor element type traits.
     /// </summary>


@@ -1319,12 +1319,13 @@ namespace Microsoft.ML.OnnxRuntime.Tests
         private void TestModelInputFLOAT16()
         {
             // model takes 1x5 input of fixed type, echoes back
+            Float16[] modelInput = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
+            int[] inputShape = { 1, 5 };
             var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
             using (var session = new InferenceSession(model))
             {
                 var container = new List<NamedOnnxValue>();
-                var tensorIn = new DenseTensor<Float16>(
-                    new Float16[] { 15360, 16384, 16896, 17408, 17664 }, new int[] { 1, 5 });
+                var tensorIn = new DenseTensor<Float16>(modelInput, inputShape);
                 var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
                 container.Add(nov);
                 using (var res = session.Run(container))
@@ -1341,13 +1342,15 @@
         [Fact(DisplayName = "TestModelInputBFLOAT16")]
         private void TestModelInputBFLOAT16()
         {
+            BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
+                new BFloat16(16448), new BFloat16(16512), new BFloat16(16544) };
+            int[] inputShape = { 1, 5 };
             // model takes 1x5 input of fixed type, echoes back
             var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_BFLOAT16.onnx");
             using (var session = new InferenceSession(model))
             {
                 var container = new List<NamedOnnxValue>();
-                var tensorIn = new DenseTensor<BFloat16>(
-                    new BFloat16[] { 16256, 16384, 16448, 16512, 16544 }, new int[] { 1, 5 });
+                var tensorIn = new DenseTensor<BFloat16>(modelInput, inputShape);
                 var nov = NamedOnnxValue.CreateFromTensor("input", tensorIn);
                 container.Add(nov);
                 using (var res = session.Run(container))
@@ -1999,80 +2002,6 @@
             }
         }
-        internal class FloatComparer : IEqualityComparer<float>
-        {
-            private float atol = 1e-3f;
-            private float rtol = 1.7e-2f;
-            public bool Equals(float x, float y)
-            {
-                return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
-            }
-            public int GetHashCode(float x)
-            {
-                return x.GetHashCode();
-            }
-        }
-        internal class DoubleComparer : IEqualityComparer<double>
-        {
-            private double atol = 1e-3;
-            private double rtol = 1.7e-2;
-            public bool Equals(double x, double y)
-            {
-                return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
-            }
-            public int GetHashCode(double x)
-            {
-                return x.GetHashCode();
-            }
-        }
-        class ExactComparer<T> : IEqualityComparer<T>
-        {
-            public bool Equals(T x, T y)
-            {
-                return x.Equals(y);
-            }
-            public int GetHashCode(T x)
-            {
-                return x.GetHashCode();
-            }
-        }
-        /// <summary>
-        /// Use it to compare Float16
-        /// </summary>
-        internal class Float16Comparer : IEqualityComparer<Float16>
-        {
-            public ushort tolerance = 0;
-            public bool Equals(Float16 x, Float16 y)
-            {
-                return Math.Abs(x.value - y.value) <= (tolerance + y);
-            }
-            public int GetHashCode(Float16 x)
-            {
-                return x.GetHashCode();
-            }
-        }
-        /// <summary>
-        /// Use it to compare Bloat16
-        /// </summary>
-        internal class BFloat16Comparer : IEqualityComparer<BFloat16>
-        {
-            public ushort tolerance = 0;
-            public bool Equals(BFloat16 x, BFloat16 y)
-            {
-                return Math.Abs(x.value - y.value) <= (tolerance + y);
-            }
-            public int GetHashCode(BFloat16 x)
-            {
-                return x.GetHashCode();
-            }
-        }
         private class GpuFact : FactAttribute
         {
             public GpuFact()


@@ -50,9 +50,10 @@
   <ItemGroup>
     <Compile Remove="InferenceTest.cs" />
+    <Compile Remove="OrtEnvTests.cs" />
     <Compile Remove="OrtIoBindingAllocationTest.cs" />
-    <Compile Remove="OrtEnvTests.cs" />
     <Compile Remove="OrtValueTests.cs" />
+    <Compile Remove="OrtFloat16Tests.cs" />
     <Compile Remove="Tensors\TensorTests.cs" />
     <Compile Remove="TrainingTest.cs" />
   </ItemGroup>
@@ -80,10 +81,10 @@
   <ItemGroup>
     <!-- include common files for visibility, however they're compiled directly by the target specific test projects -->
     <None Include="InferenceTest.cs" />
+    <None Include="OrtEnvTests.cs" />
     <None Include="OnnxData.cs" />
     <None Include="OrtIoBindingAllocationTest.cs" Condition=" '$(EnableDefaultCompileItems)' == 'true' " />
     <None Include="OrtValueTests.cs" />
+    <None Include="OrtFloat16Tests.cs" />
     <None Include="Tensors\TensorTests.cs" Condition=" '$(EnableDefaultCompileItems)' == 'true' " />
     <None Include="Tensors\ArrayTensorExtensionTests.cs" Condition=" '$(EnableDefaultCompileItems)' == 'true' " />
     <None Include="TrainingTest.cs" />


@@ -0,0 +1,533 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using Xunit;
namespace Microsoft.ML.OnnxRuntime.Tests
{
[Collection("Ort Float16 tests")]
public class OrtFloat16Tests
{
const float oneThird = 1 / 3.0f;
const float oneSeventh = 1 / 7.0f;
const float oneTenth = 1 / 10.0f;
[Fact(DisplayName = "ConvertFloatToFloat16")]
public void ConvertFloatToFloat16()
{
// Generate integer floats and insert between them
// fractions. This will test the rounding logic.
float start = -10;
var floatValues = new float[21 * 4];
for (int i = 0; i < floatValues.Length; i += 4)
{
floatValues[i] = start;
floatValues[i + 1] = start + oneThird;
floatValues[i + 2] = start + oneSeventh;
floatValues[i + 3] = start + oneTenth;
start += 1;
}
var f16Converted = Array.ConvertAll(floatValues, f => (Float16)f);
var backConverted = Array.ConvertAll(f16Converted, f16 => (float)f16);
Assert.Equal(floatValues, backConverted, new FloatComparer());
}
[Fact(DisplayName = "TestZeros")]
public void TestZeros()
{
var positiveZero = new Float16(0);
Assert.False(Float16.IsNegative(positiveZero));
Assert.True(Float16.IsNaNOrZero(positiveZero));
float singlePositiveZero = (float)positiveZero;
Assert.Equal(+0.0f, singlePositiveZero);
#if NET6_0_OR_GREATER
Assert.False(float.IsNegative(singlePositiveZero));
#endif
var negativeZero = Float16.Negate(positiveZero);
Assert.True(Float16.IsNegative(negativeZero));
Assert.True(Float16.IsNaNOrZero(negativeZero));
float singleNegativeZero = (float)negativeZero;
Assert.Equal(-0.0f, singleNegativeZero);
#if NET6_0_OR_GREATER
Assert.True(float.IsNegative(singleNegativeZero));
#endif
}
[Fact(DisplayName = "TestComparisonOperators")]
public void TestComparisonOperators()
{
Float16 left = (Float16)(float)-33.33f;
Float16 leftSame = (Float16)(float)-33.33f;
Float16 right = (Float16)(float)66.66f;
Float16 rightSame = (Float16)(float)66.66f;
Assert.False(Float16.IsNaNOrZero(left));
Assert.False(Float16.IsNaNOrZero(right));
Assert.True(right > Float16.Epsilon);
Assert.True(left == leftSame);
Assert.False(left == Float16.Negate(leftSame));
Assert.True(right == rightSame);
Assert.False(right == Float16.Negate(rightSame));
Assert.True(left < right);
Assert.True(left > Float16.Negate(right));
Assert.True(Float16.Negate(left) < right);
Assert.True(left <= right);
Assert.True(left >= Float16.Negate(right));
Assert.False(left > right);
Assert.False(left >= right);
Assert.True(Float16.Negate(left) <= right);
Assert.False(left == right);
Assert.False(right == left);
Assert.True(left != right);
Assert.True(right != left);
}
[Fact(DisplayName = "TestNAN")]
public void TestNAN()
{
Float16 fp16NANFromSingle = (Float16)float.NaN;
Assert.True(Float16.IsNaN(fp16NANFromSingle));
Assert.Equal(Float16.NaN, fp16NANFromSingle);
Assert.True(Float16.IsNaNOrZero(fp16NANFromSingle));
float NanFromFloat16 = fp16NANFromSingle.ToFloat();
Assert.True(float.IsNaN(NanFromFloat16));
// IEqualityComparable returns true, because it tests
// objects, not numbers.
Assert.Equal(fp16NANFromSingle, Float16.NaN);
Assert.Equal(Float16.NaN, Float16.Negate(Float16.NaN));
}
[Fact(DisplayName = "TestNANComparision")]
public void TestNANComparisionOperators()
{
// NaN is not ordered with respect to anything
// including itself
// IEqualityComparable returns true, because it tests
// objects, not numbers.
Assert.Equal(Float16.NaN, Float16.NaN);
Assert.False(Float16.NaN < Float16.NaN);
Assert.False(Float16.NaN > Float16.NaN);
Assert.False(Float16.NaN <= Float16.NaN);
Assert.False(Float16.NaN >= Float16.NaN);
Assert.False(Float16.NaN == Float16.NaN);
// IEqualityComparable returns false, because it tests
// objects, not numbers.
Assert.NotEqual(Float16.NaN, Float16.MaxValue);
Assert.False(Float16.NaN < Float16.MaxValue);
Assert.False(Float16.MaxValue < Float16.NaN);
Assert.False(Float16.NaN == Float16.MaxValue);
Assert.False(Float16.MaxValue == Float16.NaN);
Assert.False(Float16.NaN > Float16.MinValue);
Assert.False(Float16.MaxValue > Float16.NaN);
Assert.False(Float16.NaN == Float16.MinValue);
Assert.False(Float16.MaxValue == Float16.NaN);
Assert.True(Float16.MinValue < Float16.MaxValue);
}
[Fact(DisplayName = "TestInfinity")]
public void TestInfinity()
{
Assert.False(Float16.IsInfinity(Float16.MinValue));
Assert.False(Float16.IsInfinity(Float16.MaxValue));
Float16 posInfinityFromSingle = (Float16)float.PositiveInfinity;
Assert.True(Float16.IsPositiveInfinity(posInfinityFromSingle));
Assert.Equal(Float16.PositiveInfinity, posInfinityFromSingle);
Assert.False(Float16.IsFinite(posInfinityFromSingle));
Assert.True(Float16.IsInfinity(posInfinityFromSingle));
Assert.True(Float16.IsPositiveInfinity(posInfinityFromSingle));
Assert.False(Float16.IsNegativeInfinity(posInfinityFromSingle));
Assert.False(Float16.IsPositiveInfinity(Float16.MinValue));
Assert.False(Float16.IsPositiveInfinity(Float16.MaxValue));
Assert.Equal(float.PositiveInfinity < 0, Float16.IsNegative(posInfinityFromSingle));
Float16 negInfinityFromSingle = (Float16)float.NegativeInfinity;
Assert.True(Float16.IsNegativeInfinity(negInfinityFromSingle));
Assert.Equal(Float16.NegativeInfinity, negInfinityFromSingle);
Assert.False(Float16.IsFinite(negInfinityFromSingle));
Assert.True(Float16.IsInfinity(negInfinityFromSingle));
Assert.True(Float16.IsNegativeInfinity(negInfinityFromSingle));
Assert.False(Float16.IsPositiveInfinity(negInfinityFromSingle));
Assert.False(Float16.IsNegativeInfinity(Float16.MinValue));
Assert.False(Float16.IsNegativeInfinity(Float16.MaxValue));
Assert.Equal(float.NegativeInfinity < 0, Float16.IsNegative(negInfinityFromSingle));
// Convert infinity to float and test the fact
float infFromFloat16 = (float)Float16.PositiveInfinity;
Assert.True(float.IsInfinity(infFromFloat16));
Assert.True(float.IsPositiveInfinity(infFromFloat16));
}
[Fact(DisplayName = "TestNormalSubnormal")]
public void TestNormalSubnormal()
{
Float16 fp16FromSingleMaxValue = (Float16)float.MaxValue;
// Float MaxValue is outside Float16 range. This is different
// from BFloat16 that retains sufficient range.
Assert.True(Float16.IsInfinity(fp16FromSingleMaxValue));
Assert.False(Float16.IsNormal(fp16FromSingleMaxValue));
Assert.False(Float16.IsNormal(Float16.PositiveInfinity));
Assert.True(Float16.IsNormal((Float16)45.6f));
Assert.False(Float16.IsSubnormal((Float16)45.6f));
Assert.False(Float16.IsSubnormal(fp16FromSingleMaxValue));
Assert.False(Float16.IsSubnormal(Float16.PositiveInfinity));
// 0b0_00000_0000000001 => 5.9604645E-08
const ushort minSubnormalBits = 0x0001;
const float smallestF16Subnormal = 5.9604645E-08f;
Float16 smallestSubnormal = new Float16(minSubnormalBits);
Assert.True(Float16.IsSubnormal(smallestSubnormal));
Assert.False(Float16.IsNormal(smallestSubnormal));
// 0b0_00000_1111111111 => 6.09755516E-05
const float largestF16Subnormal = 6.09755516E-05f;
const ushort maxSubnormalBits = 0x03FF;
Float16 largestSubnormal = new Float16(maxSubnormalBits);
Assert.True(Float16.IsSubnormal(largestSubnormal));
Assert.False(Float16.IsNormal(largestSubnormal));
// Convert subnormal to float and see if we match
float convertedFromSmallestSubnormal = (float)smallestSubnormal;
Assert.Equal(smallestF16Subnormal, convertedFromSmallestSubnormal, 6);
float convertedFromLargestSubnormal = (float)largestSubnormal;
Assert.Equal(largestF16Subnormal, convertedFromLargestSubnormal, 6);
}
[Fact(DisplayName = "TestEqual")]
public void TestEqual()
{
// Box it
object obj_1 = Float16.MaxValue;
object obj_2 = new Float16(Float16.MaxValue.value);
Assert.True(obj_1.Equals(obj_2));
Assert.NotEqual(0, obj_1.GetHashCode());
Assert.Equal(obj_1.GetHashCode(), obj_2.GetHashCode());
Assert.True(Float16.NaN.Equals(Float16.NaN));
Float16 fp16Zero = (Float16)0.0f;
const ushort ushortZero = 0;
Float16 fp16FromUshortZero = (Float16)ushortZero;
Assert.True(fp16Zero.Equals(fp16FromUshortZero));
// Should have the same hash code constant
Assert.Equal(fp16Zero.GetHashCode(), fp16FromUshortZero.GetHashCode());
Assert.Equal(Float16.NaN.GetHashCode(), Float16.NaN.GetHashCode());
}
[Fact(DisplayName = "TestCompare")]
public void TestCompare()
{
object objMaxValue = new Float16(Float16.MaxValue.value);
Assert.Equal(0, Float16.MaxValue.CompareTo(objMaxValue));
Float16 one = (Float16)1.0f;
Assert.Equal(-1, Float16.MinValue.CompareTo(one));
Assert.Equal(1, Float16.MaxValue.CompareTo(one));
// one is bigger than NaN
Assert.Equal(-1, Float16.NaN.CompareTo(one));
// Two NaNs are equal according to CompareTo()
Assert.Equal(0, Float16.NaN.CompareTo((Float16)float.NaN));
Assert.Equal(1, one.CompareTo(Float16.NaN));
// Compare to null
Assert.Equal(1, one.CompareTo(null));
// Make sure it throws
var obj = new object();
Assert.Throws<ArgumentException>(() => one.CompareTo(obj));
}
}
[Collection("Ort BFloat16 tests")]
public class OrtBFloat16Tests
{
const float oneThird = 1 / 3.0f;
const float oneSeventh = 1 / 7.0f;
const float oneTenth = 1 / 10.0f;
[Fact(DisplayName = "ConvertFloatToBFloat16")]
public void ConvertFloatToBFloat16()
{
// Generate integer floats and insert between them
// fractions. This will test the rounding logic.
float start = -10;
var floatValues = new float[21 * 4];
for (int i = 0; i < floatValues.Length; i += 4)
{
floatValues[i] = start;
floatValues[i + 1] = start + oneThird;
floatValues[i + 2] = start + oneSeventh;
floatValues[i + 3] = start + oneTenth;
start += 1;
}
var f16Converted = Array.ConvertAll(floatValues, f => (BFloat16)f);
var backConverted = Array.ConvertAll(f16Converted, f16 => (float)f16);
Assert.Equal(floatValues, backConverted, new FloatComparer());
}
[Fact(DisplayName = "TestZeros")]
public void TestZeros()
{
var positiveZero = new BFloat16(0);
Assert.False(BFloat16.IsNegative(positiveZero));
Assert.True(BFloat16.IsNaNOrZero(positiveZero));
float singlePositiveZero = (float)positiveZero;
Assert.Equal(+0.0f, singlePositiveZero);
#if NET6_0_OR_GREATER
Assert.False(float.IsNegative(singlePositiveZero));
#endif
var negativeZero = BFloat16.Negate(positiveZero);
Assert.True(BFloat16.IsNegative(negativeZero));
Assert.True(BFloat16.IsNaNOrZero(negativeZero));
float singleNegativeZero = (float)negativeZero;
Assert.Equal(-0.0f, singleNegativeZero);
#if NET6_0_OR_GREATER
Assert.True(float.IsNegative(singleNegativeZero));
#endif
}
[Fact(DisplayName = "TestComparisonOperators")]
public void TestComparisionOperators()
{
BFloat16 left = (BFloat16)(float)-33.33f;
BFloat16 leftSame = (BFloat16)(float)-33.33f;
BFloat16 right = (BFloat16)(float)66.66f;
BFloat16 rightSame = (BFloat16)(float)66.66f;
Assert.False(BFloat16.IsNaNOrZero(left));
Assert.False(BFloat16.IsNaNOrZero(right));
Assert.True(right > BFloat16.Epsilon);
Assert.True(left == leftSame);
Assert.False(left == BFloat16.Negate(leftSame));
Assert.True(right == rightSame);
Assert.False(right == BFloat16.Negate(rightSame));
Assert.True(left < right);
Assert.True(left > BFloat16.Negate(right));
Assert.True(BFloat16.Negate(left) < right);
Assert.True(left <= right);
Assert.True(left >= BFloat16.Negate(right));
Assert.False(left > right);
Assert.False(left >= right);
Assert.True(BFloat16.Negate(left) <= right);
Assert.False(left == right);
Assert.False(right == left);
Assert.True(left != right);
Assert.True(right != left);
}
[Fact(DisplayName = "TestNAN")]
public void TestNAN()
{
BFloat16 fp16NANFromSingle = (BFloat16)float.NaN;
Assert.True(BFloat16.IsNaN(fp16NANFromSingle));
Assert.Equal(BFloat16.NaN, fp16NANFromSingle);
Assert.True(BFloat16.IsNaNOrZero(fp16NANFromSingle));
float NanFromBFloat16 = fp16NANFromSingle.ToFloat();
Assert.True(float.IsNaN(NanFromBFloat16));
// IEqualityComparable returns true, because it tests
// objects, not numbers.
Assert.Equal(fp16NANFromSingle, BFloat16.NaN);
Assert.Equal(BFloat16.NaN, BFloat16.Negate(BFloat16.NaN));
Assert.False(BFloat16.IsNaN(BFloat16.MaxValue));
}
[Fact(DisplayName = "TestNANComparision")]
public void TestNANComparisionOperators()
{
// NaN is not ordered with respect to anything
// including itself
// IEqualityComparable returns true, because it tests
// objects, not numbers.
Assert.Equal(BFloat16.NaN, BFloat16.NaN);
Assert.False(BFloat16.NaN < BFloat16.NaN);
Assert.False(BFloat16.NaN > BFloat16.NaN);
Assert.False(BFloat16.NaN <= BFloat16.NaN);
Assert.False(BFloat16.NaN >= BFloat16.NaN);
Assert.False(BFloat16.NaN == BFloat16.NaN);
// IEqualityComparable returns false, because it tests
// objects, not numbers.
Assert.NotEqual(BFloat16.NaN, BFloat16.MaxValue);
Assert.False(BFloat16.NaN < BFloat16.MaxValue);
Assert.False(BFloat16.MaxValue < BFloat16.NaN);
Assert.False(BFloat16.NaN == BFloat16.MaxValue);
Assert.False(BFloat16.MaxValue == BFloat16.NaN);
Assert.False(BFloat16.NaN > BFloat16.MinValue);
Assert.False(BFloat16.MaxValue > BFloat16.NaN);
Assert.False(BFloat16.NaN == BFloat16.MinValue);
Assert.False(BFloat16.MaxValue == BFloat16.NaN);
Assert.True(BFloat16.MinValue < BFloat16.MaxValue);
}
[Fact(DisplayName = "TestInfinity")]
public void TestInfinity()
{
Assert.False(BFloat16.IsInfinity(BFloat16.MinValue));
Assert.False(BFloat16.IsInfinity(BFloat16.MaxValue));
BFloat16 posInfinityFromSingle = (BFloat16)float.PositiveInfinity;
Assert.True(BFloat16.IsPositiveInfinity(posInfinityFromSingle));
Assert.Equal(BFloat16.PositiveInfinity, posInfinityFromSingle);
Assert.False(BFloat16.IsFinite(posInfinityFromSingle));
Assert.True(BFloat16.IsInfinity(posInfinityFromSingle));
Assert.True(BFloat16.IsPositiveInfinity(posInfinityFromSingle));
Assert.False(BFloat16.IsNegativeInfinity(posInfinityFromSingle));
Assert.False(BFloat16.IsPositiveInfinity(BFloat16.MinValue));
Assert.False(BFloat16.IsPositiveInfinity(BFloat16.MaxValue));
Assert.Equal(float.PositiveInfinity < 0, BFloat16.IsNegative(posInfinityFromSingle));
BFloat16 negInfinityFromSingle = (BFloat16)float.NegativeInfinity;
Assert.True(BFloat16.IsNegativeInfinity(negInfinityFromSingle));
Assert.Equal(BFloat16.NegativeInfinity, negInfinityFromSingle);
Assert.False(BFloat16.IsFinite(negInfinityFromSingle));
Assert.True(BFloat16.IsInfinity(negInfinityFromSingle));
Assert.True(BFloat16.IsNegativeInfinity(negInfinityFromSingle));
Assert.False(BFloat16.IsPositiveInfinity(negInfinityFromSingle));
Assert.False(BFloat16.IsNegativeInfinity(BFloat16.MinValue));
Assert.False(BFloat16.IsNegativeInfinity(BFloat16.MaxValue));
Assert.True(BFloat16.IsNegative(negInfinityFromSingle));
// Convert infinity to float and test the fact
float infFromBFloat16 = (float)BFloat16.PositiveInfinity;
Assert.True(float.IsInfinity(infFromBFloat16));
Assert.True(float.IsPositiveInfinity(infFromBFloat16));
}
[Fact(DisplayName = "TestNormalSubnormal")]
public void TestNormalSubnormal()
{
BFloat16 fp16FromSingleMaxValue = (BFloat16)float.MaxValue;
Assert.True(BFloat16.IsInfinity(fp16FromSingleMaxValue));
Assert.False(BFloat16.IsNormal(fp16FromSingleMaxValue));
Assert.False(BFloat16.IsNormal(BFloat16.PositiveInfinity));
Assert.True(BFloat16.IsNormal((BFloat16)45.6f));
Assert.False(BFloat16.IsSubnormal((BFloat16)45.6f));
Assert.False(BFloat16.IsSubnormal(fp16FromSingleMaxValue));
Assert.False(BFloat16.IsSubnormal(BFloat16.PositiveInfinity));
// 0b0_0000_0000_000_0001
const ushort minSubnormalBits = 0x0001;
BFloat16 smallestSubnormal = new BFloat16(minSubnormalBits);
Assert.True(BFloat16.IsSubnormal(smallestSubnormal));
Assert.False(BFloat16.IsNormal(smallestSubnormal));
#if NET6_0_OR_GREATER
float singleSmallestSubnormal = (float)smallestSubnormal;
Assert.True(float.IsSubnormal(singleSmallestSubnormal));
#endif
const ushort maxSubnormalBits = 0x007F; // 0b0_0000_0000_111_1111;
BFloat16 largestSubnormal = new BFloat16(maxSubnormalBits);
Assert.True(BFloat16.IsSubnormal(largestSubnormal));
Assert.False(BFloat16.IsNormal(largestSubnormal));
#if NET6_0_OR_GREATER
float singleLargestSubnornal = (float)largestSubnormal;
Assert.True(float.IsSubnormal(singleLargestSubnornal));
#endif
}
[Fact(DisplayName = "TestEqual")]
public void TestEqual()
{
// Box it
object obj_1 = BFloat16.MaxValue;
object obj_2 = new BFloat16(BFloat16.MaxValue.value);
Assert.True(obj_1.Equals(obj_2));
Assert.NotEqual(0, obj_1.GetHashCode());
Assert.Equal(obj_1.GetHashCode(), obj_2.GetHashCode());
Assert.True(BFloat16.NaN.Equals(BFloat16.NaN));
BFloat16 fp16Zero = (BFloat16)0.0f;
const ushort ushortZero = 0;
BFloat16 fp16FromUshortZero = (BFloat16)ushortZero;
Assert.True(fp16Zero.Equals(fp16FromUshortZero));
// Should have the same hash code constant
Assert.Equal(fp16Zero.GetHashCode(), fp16FromUshortZero.GetHashCode());
Assert.Equal(BFloat16.NaN.GetHashCode(), BFloat16.NaN.GetHashCode());
}
[Fact(DisplayName = "TestCompare")]
public void TestCompare()
{
object objMaxValue = new BFloat16(BFloat16.MaxValue.value);
Assert.Equal(0, BFloat16.MaxValue.CompareTo(objMaxValue));
BFloat16 one = (BFloat16)1.0f;
Assert.Equal(-1, BFloat16.MinValue.CompareTo(one));
Assert.Equal(1, BFloat16.MaxValue.CompareTo(one));
// one is bigger than NaN
Assert.Equal(-1, BFloat16.NaN.CompareTo(one));
// Two NaNs are equal according to CompareTo()
Assert.Equal(0, BFloat16.NaN.CompareTo((BFloat16)float.NaN));
Assert.Equal(1, one.CompareTo(BFloat16.NaN));
// Compare to null
Assert.Equal(1, one.CompareTo(null));
// Make sure it throws
var obj = new object();
Assert.Throws<ArgumentException>(() => one.CompareTo(obj));
}
}
}


@@ -242,13 +242,13 @@ namespace Microsoft.ML.OnnxRuntime.Tests
            int[] int_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
            ushort[] ushort_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
            double[] dbl_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
-           Float16[] fl16_data = { 1, 2, 3, 4, 5, 6, 7, 8 };
+           var fp16_data = Array.ConvertAll(ushort_data, sh => new Float16(sh));
            PopulateAndCheck(float_data);
            PopulateAndCheck(int_data);
            PopulateAndCheck(ushort_data);
            PopulateAndCheck(dbl_data);
-           PopulateAndCheck(fl16_data);
+           PopulateAndCheck(fp16_data);
        }
        private static readonly long[] ml_data_1 = { 1, 2 };


@@ -619,4 +619,78 @@ namespace Microsoft.ML.OnnxRuntime.Tests
             return tensorData.ToArray();
         }
     }
+    internal class FloatComparer : IEqualityComparer<float>
+    {
+        private float atol = 1e-3f;
+        private float rtol = 1.7e-2f;
+        public bool Equals(float x, float y)
+        {
+            return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
+        }
+        public int GetHashCode(float x)
+        {
+            return x.GetHashCode();
+        }
+    }
+    internal class DoubleComparer : IEqualityComparer<double>
+    {
+        private double atol = 1e-3;
+        private double rtol = 1.7e-2;
+        public bool Equals(double x, double y)
+        {
+            return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
+        }
+        public int GetHashCode(double x)
+        {
+            return x.GetHashCode();
+        }
+    }
+    class ExactComparer<T> : IEqualityComparer<T>
+    {
+        public bool Equals(T x, T y)
+        {
+            return x.Equals(y);
+        }
+        public int GetHashCode(T x)
+        {
+            return x.GetHashCode();
+        }
+    }
+    /// <summary>
+    /// Use it to compare Float16
+    /// </summary>
+    internal class Float16Comparer : IEqualityComparer<Float16>
+    {
+        public ushort tolerance = 0;
+        public bool Equals(Float16 x, Float16 y)
+        {
+            return Math.Abs(x.value - y.value) <= (tolerance + y.value);
+        }
+        public int GetHashCode(Float16 x)
+        {
+            return x.GetHashCode();
+        }
+    }
+    /// <summary>
+    /// Use it to compare BFloat16
+    /// </summary>
+    internal class BFloat16Comparer : IEqualityComparer<BFloat16>
+    {
+        public ushort tolerance = 0;
+        public bool Equals(BFloat16 x, BFloat16 y)
+        {
+            return Math.Abs(x.value - y.value) <= (tolerance + y.value);
+        }
+        public int GetHashCode(BFloat16 x)
+        {
+            return x.GetHashCode();
+        }
+    }
 }


@@ -119,6 +119,9 @@
   <!-- NOTE: The xUnit framework doesn't pickup the tests defined within the referenced Microsoft.ML.OnnxRuntime.Tests.Common project -->
   <ItemGroup>
+    <Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OrtFloat16Tests.cs">
+      <Link>OrtFloat16Tests.cs</Link>
+    </Compile>
     <Compile Include="..\Microsoft.ML.OnnxRuntime.Tests.Common\OrtEnvTests.cs">
      <Link>OrtEnvTests.cs</Link>
    </Compile>


@@ -2,6 +2,8 @@
 // Licensed under the MIT License.
 #pragma once
+#include <math.h>
 #include "endian.h"
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 #include "cuda_bf16.h"
@@ -13,6 +15,8 @@
 #include "core/common/common.h"
+#include "core/session/onnxruntime_float16.h"
 namespace onnxruntime {
 #if defined(__CUDACC__) || defined(__HIPCC__)
@@ -22,25 +26,69 @@ namespace onnxruntime {
 #endif
 // MLFloat16
-struct MLFloat16 {
-  uint16_t val{0};
+struct MLFloat16 : onnxruntime_float16::Float16Impl<MLFloat16> {
+ private:
+  explicit constexpr MLFloat16(uint16_t x) noexcept { val = x; }
+ public:
+  using Base = onnxruntime_float16::Float16Impl<MLFloat16>;
   MLFloat16() = default;
-  explicit constexpr MLFloat16(uint16_t x) : val(x) {}
-  explicit MLFloat16(float f);
-  float ToFloat() const;
-  operator float() const { return ToFloat(); }
+  constexpr static MLFloat16 FromBits(uint16_t x) noexcept { return MLFloat16(x); }
+  // Using inherited implementation instead of math floatToHalf allows us to use this
+  // in other shared providers without having to implement the bridge
+  explicit MLFloat16(float v) noexcept { val = Base::ToUint16Impl(v); }
+  static const MLFloat16 NaN;
+  static const MLFloat16 NegativeNaN;
+  static const MLFloat16 Infinity;
+  static const MLFloat16 NegativeInfinity;
+  static const MLFloat16 Epsilon;
+  static const MLFloat16 MinValue;
+  static const MLFloat16 MaxValue;
+  static const MLFloat16 Zero;
+  static const MLFloat16 One;
+  static const MLFloat16 MinusOne;
+  // Using inherited implementation instead of math halfToFloat allows us to use this
+  // in other shared providers without having to implement the bridge
+  float ToFloat() const noexcept { return Base::ToFloatImpl(); }
+  using Base::IsNegative;
+  using Base::IsNaN;
+  using Base::IsFinite;
+  using Base::IsPositiveInfinity;
+  using Base::IsNegativeInfinity;
+  using Base::IsInfinity;
+  using Base::IsNaNOrZero;
+  using Base::IsNormal;
+  using Base::IsSubnormal;
+  using Base::Abs;
+  using Base::Negate;
+  operator float() const noexcept { return ToFloat(); }
+  using Base::operator==;
+  using Base::operator!=;
+  using Base::operator<;
 };
-inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
-inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }
-inline bool operator<(const MLFloat16& left, const MLFloat16& right) { return left.val < right.val; }
 // BFloat16
-struct BFloat16 {
-  uint16_t val{0};
+struct BFloat16 : onnxruntime_float16::BFloat16Impl<BFloat16> {
+  using Base = onnxruntime_float16::BFloat16Impl<BFloat16>;
 #if defined(__HIP__)
   ORT_HOST_DEVICE BFloat16() = default;
 #else
@@ -48,10 +96,14 @@ struct BFloat16 {
 #endif
   struct FromBitsT {};
-  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
-  constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits) {}
-  inline ORT_HOST_DEVICE BFloat16(float v) {
+  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() noexcept { return FromBitsT(); }
+  constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) noexcept { val = bits; }
+  static constexpr ORT_HOST_DEVICE BFloat16 FromBits(uint16_t bits) noexcept {
+    return BFloat16(bits, FromBits());
+  }
+  inline ORT_HOST_DEVICE BFloat16(float v) noexcept {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     val = __bfloat16_as_ushort(__float2bfloat16(v));
 #elif defined(__HIP__)
@@ -69,15 +121,34 @@ struct BFloat16 {
     val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
   }
 #else
-    if constexpr (endian::native == endian::little) {
-      std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
+    // Use C isnan to work both in host and device
+    if (::isnan(v)) {
+      val = kPositiveQNaNBits;
     } else {
-      std::memcpy(&val, &v, sizeof(uint16_t));
+      auto get_msb_half = [](float fl) {
+        uint16_t result;
+        if constexpr (onnxruntime_float16::detail::endian::native == onnxruntime_float16::detail::endian::little) {
+          std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
+        } else {
+          std::memcpy(&result, &fl, sizeof(uint16_t));
+        }
+        return result;
+      };
+      uint16_t upper_bits = get_msb_half(v);
+      union {
+        uint32_t U32;
+        float F32;
+      };
+      F32 = v;
+      U32 += (upper_bits & 1) + kRoundToNearest;
+      val = get_msb_half(F32);
     }
 #endif
   }
-  inline ORT_HOST_DEVICE float ToFloat() const {
+  inline ORT_HOST_DEVICE float ToFloat() const noexcept {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
     return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
 #elif defined(__HIP__)
@@ -89,55 +160,129 @@ struct BFloat16 {
     result = *tempRes;
     return result;
 #else
-    float result;
+    if (IsNaNHostDevice()) {
+      return std::numeric_limits<float>::quiet_NaN();
+    }
+    float result = 0;
     char* const first = reinterpret_cast<char*>(&result);
-    char* const second = first + sizeof(uint16_t);
     if constexpr (endian::native == endian::little) {
-      std::memset(first, 0, sizeof(uint16_t));
+      char* const second = first + sizeof(uint16_t);
       std::memcpy(second, &val, sizeof(uint16_t));
     } else {
       std::memcpy(first, &val, sizeof(uint16_t));
-      std::memset(second, 0, sizeof(uint16_t));
     }
     return result;
 #endif
   }
-  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
+  static const BFloat16 NaN;
+  static const BFloat16 NegativeNaN;
+  static const BFloat16 Infinity;
+  static const BFloat16 NegativeInfinity;
+  static const BFloat16 Epsilon;
+  static const BFloat16 MinValue;
+  static const BFloat16 MaxValue;
+  static const BFloat16 Zero;
+  static const BFloat16 One;
+  static const BFloat16 MinusOne;
+  using Base::IsNegative;
+  using Base::IsNaN;
+  using Base::IsFinite;
+  using Base::IsPositiveInfinity;
+  using Base::IsNegativeInfinity;
+  using Base::IsInfinity;
+  using Base::IsNaNOrZero;
+  using Base::IsNormal;
+  using Base::IsSubnormal;
+  using Base::Abs;
+  using Base::Negate;
+  ORT_HOST_DEVICE operator float() const noexcept { return ToFloat(); }
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
   ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
   explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
 #endif
-};
-inline ORT_HOST_DEVICE bool operator==(const BFloat16& left, const BFloat16& right) { return left.val == right.val; }
-inline ORT_HOST_DEVICE bool operator!=(const BFloat16& left, const BFloat16& right) { return left.val != right.val; }
-inline ORT_HOST_DEVICE bool operator<(const BFloat16& left, const BFloat16& right) { return left.val < right.val; }
+  ORT_HOST_DEVICE bool operator==(const BFloat16& rhs) const noexcept {
+    if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) {
+      // IEEE defines that NaN is not equal to anything, including itself.
+      return false;
+    }
+    return val == rhs.val;
+  }
+  ORT_HOST_DEVICE bool operator!=(const BFloat16& rhs) const noexcept {
+    return !(*this == rhs);
+  }
+  ORT_HOST_DEVICE bool operator<(const BFloat16& rhs) const noexcept {
+    if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) {
+      // IEEE defines that NaN is unordered with respect to everything, including itself.
+      return false;
+    }
+    const bool left_is_negative = IsNegativeHostDevice();
+    if (left_is_negative != rhs.IsNegativeHostDevice()) {
+      // When the signs of left and right differ, we know that left is less than right if it is
+      // the negative value. The exception to this is if both values are zero, in which case IEEE
+      // says they should be equal, even if the signs differ.
+      return left_is_negative && !AreZeroHostDevice(*this, rhs);
+    }
+    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
+  }
+  ORT_HOST_DEVICE bool IsNegativeHostDevice() const noexcept {
+    return (val & kSignMask) != 0;
+  }
+  ORT_HOST_DEVICE bool IsNaNHostDevice() const noexcept {
+    return static_cast<uint16_t>(val & ~kSignMask) > kPositiveInfinityBits;
+  }
+  ORT_HOST_DEVICE static bool AreZeroHostDevice(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
+    // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+    // for two values by or'ing the private bits together and stripping the sign. They are both zero,
+    // and therefore equivalent, if the resulting value is still zero.
+    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
+  }
+};
 // User defined suffixes to make it easier to declare
 // initializers with MLFloat16 and BFloat16 from unsigned short
 // E.g 10_f16 or 10_b16
 #if !defined(__CUDACC__) && !defined(__HIPCC__)
-inline MLFloat16 operator"" _f16(unsigned long long int v) {
-  return MLFloat16(narrow<uint16_t>(v));
+inline MLFloat16 operator"" _f16(unsigned long long int v) noexcept {
+  return MLFloat16::FromBits(narrow<uint16_t>(v));
 }
-inline MLFloat16 operator"" _fp16(long double v) {
+inline MLFloat16 operator"" _fp16(long double v) noexcept {
   return MLFloat16(static_cast<float>(v));
 }
-inline BFloat16 operator"" _b16(unsigned long long int v) {
-  return BFloat16(narrow<uint16_t>(v), BFloat16::FromBits());
+inline BFloat16 operator"" _b16(unsigned long long int v) noexcept {
+  return BFloat16::FromBits((narrow<uint16_t>(v)));
 }
-inline BFloat16 operator"" _bfp16(long double v) {
+inline BFloat16 operator"" _bfp16(long double v) noexcept {
   return BFloat16(static_cast<float>(v));
 }
 #endif
-inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
+inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) noexcept {
   auto src = blf;
   auto d = flt;
   for (; size != 0; ++src, ++d, --size) {
@@ -149,7 +294,7 @@ inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
   auto src = flt;
   auto d = blf;
   for (; size != 0; ++src, ++d, --size) {
-    new (d) BFloat16(*src);
+    *d = BFloat16(*src);
   }
 }
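
The float-to-bfloat16 constructor above rounds to nearest-even: it keeps the upper 16 bits of the float after adding `kRoundToNearest` plus the low bit of the kept half, so that ties round toward the even representation. A standalone mirror of that logic (sketch only; it assumes `kRoundToNearest` is `0x7FFF`, the usual constant for this scheme, a little-endian host, and it omits the NaN special case the real constructor handles first):

    #include <cstdint>
    #include <cstring>

    // Sketch: truncate a float to bfloat16 bits with round-to-nearest-even.
    uint16_t FloatToBF16Bits(float f) {
      uint32_t u32;
      std::memcpy(&u32, &f, sizeof(u32));       // type-pun without UB
      uint32_t lsb = (u32 >> 16) & 1;           // low bit of the kept half
      u32 += 0x7FFF + lsb;                      // bias so ties go to even
      return static_cast<uint16_t>(u32 >> 16);  // keep the upper half
    }

For example, 1.00390625f (bits 0x3F808000, exactly halfway between two bfloat16 neighbors) rounds to 0x3F80 rather than 0x3F81, because the kept half is already even.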


@@ -24,6 +24,8 @@
 #pragma once
 #include "onnxruntime_c_api.h"
+#include "onnxruntime_float16.h"
 #include <cstddef>
 #include <cstdio>
 #include <array>
@@ -142,70 +144,286 @@ std::string GetBuildInfoString();
 std::vector<std::string> GetAvailableProviders();
 /** \brief IEEE 754 half-precision floating point data type
- * \details It is necessary for type dispatching to make use of C++ API
- * The type is implicitly convertible to/from uint16_t.
+ *
+ * \details This struct is used for converting float to float16 and back
+ * so the user could feed inputs and fetch outputs using this type.
+ *
  * The size of the structure should align with uint16_t and one can freely cast
  * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
  *
+ * Generally, you can feed any of your types as float16/bfloat16 data to create a tensor
+ * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
+ * And you can also feed an array of uint16_t elements directly. For example,
+ *
  * \code{.unparsed}
- * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
- * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
- * std::vector<int64_t> dims = {values_length};  // one dimensional example
- * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
- * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
- * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
- *                                                dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
+ * // This example demonstrates conversion from float to float16
+ * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
+ * std::vector<Ort::Float16_t> fp16_values;
+ * fp16_values.reserve(std::size(values));
+ * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
+ *                [](float value) { return Ort::Float16_t(value); });
+ *
  * \endcode
+ *
+ * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
+ * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
+ * template specialization.
+ *
+ * \code{.unparsed}
+ * namespace yours { struct half {}; }  // assume this is your type, define this:
+ * namespace Ort {
+ * template<>
+ * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
+ * }  // namespace Ort
+ *
+ * std::vector<yours::half> values;
+ * std::vector<int64_t> dims = {values.size()};  // one dimensional example
+ * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+ * // Here we are passing element count -> values.size()
+ * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
+ *
+ * \endcode
  */
struct Float16_t { struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
uint16_t value; private:
constexpr Float16_t() noexcept : value(0) {} /// <summary>
constexpr Float16_t(uint16_t v) noexcept : value(v) {} /// Constructor from a 16-bit representation of a float16 value
constexpr operator uint16_t() const noexcept { return value; } /// No conversion is done here.
constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }; /// </summary>
constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }; /// <param name="v">16-bit representation</param>
constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
public:
using Base = onnxruntime_float16::Float16Impl<Float16_t>;
/// <summary>
/// Default constructor
/// </summary>
Float16_t() = default;
/// <summary>
/// Explicit conversion to uint16_t representation of bfloat16.
/// </summary>
/// <param name="v">uint16_t bit representation of bfloat16</param>
/// <returns>new instance of Float16_t</returns>
constexpr static Float16_t FromBits(uint16_t x) noexcept { return Float16_t(x); }
/// <summary>
/// __ctor from float. Float is converted into float16 16-bit representation.
/// </summary>
/// <param name="v">float value</param>
explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
/// <summary>
/// Converts bfloat16 to float
/// </summary>
/// <returns>float representation of bfloat16 value</returns>
float ToFloat() const noexcept { return Base::ToFloatImpl(); }
/// <summary>
/// Checks if the value is negative
/// </summary>
/// <returns>true if negative</returns>
using Base::IsNegative;
/// <summary>
/// Tests if the value is NaN
/// </summary>
/// <returns>true if NaN</returns>
using Base::IsNaN;
/// <summary>
/// Tests if the value is finite
/// </summary>
/// <returns>true if finite</returns>
using Base::IsFinite;
/// <summary>
/// Tests if the value represents positive infinity.
/// </summary>
/// <returns>true if positive infinity</returns>
using Base::IsPositiveInfinity;
/// <summary>
/// Tests if the value represents negative infinity
/// </summary>
/// <returns>true if negative infinity</returns>
using Base::IsNegativeInfinity;
/// <summary>
/// Tests if the value is either positive or negative infinity.
/// </summary>
/// <returns>True if absolute value is infinity</returns>
using Base::IsInfinity;
/// <summary>
/// Tests if the value is NaN or zero. Useful for comparisons.
/// </summary>
/// <returns>True if NaN or zero.</returns>
using Base::IsNaNOrZero;
/// <summary>
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
/// </summary>
/// <returns>True if so</returns>
using Base::IsNormal;
/// <summary>
/// Tests if the value is subnormal (denormal).
/// </summary>
/// <returns>True if so</returns>
using Base::IsSubnormal;
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
using Base::Abs;
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
using Base::Negate;
/// <summary>
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
/// and therefore equivalent, if the resulting value is still zero.
/// </summary>
/// <param name="lhs">first value</param>
/// <param name="rhs">second value</param>
/// <returns>True if both arguments represent zero</returns>
using Base::AreZero;
/// <summary>
/// User defined conversion operator. Converts Float16_t to float.
/// </summary>
explicit operator float() const noexcept { return ToFloat(); }
using Base::operator==;
using Base::operator!=;
using Base::operator<;
};

static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
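// A minimal usage sketch (assuming onnxruntime_cxx_api.h is included): construct
// from float, inspect the raw bits, and round-trip through FromBits.
//
//   Ort::Float16_t half(3.5f);           // 3.5f is exactly representable in float16
//   const uint16_t bits = half.val;      // raw IEEE binary16 bit pattern
//   const auto same = Ort::Float16_t::FromBits(bits);
//   assert(same == half && same.ToFloat() == 3.5f);
//   assert(half.IsFinite() && half.IsNormal() && !half.IsNaN());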
/** \brief bfloat16 (Brain Floating Point) data type
 *
 * \details This struct is used for converting float to bfloat16 and back,
 * so the user can feed inputs and fetch outputs using this type.
 *
 * The size of the structure should align with uint16_t and one can freely cast
 * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
 *
 * \code{.unparsed}
 * // This example demonstrates conversion from float to bfloat16
* constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
* std::vector<Ort::BFloat16_t> bfp16_values;
* bfp16_values.reserve(std::size(values));
* std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
* [](float value) { return Ort::BFloat16_t(value); });
*
* \endcode
 */
struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
 private:
  /// <summary>
  /// Constructor from a uint16_t representation of bfloat16,
  /// used in FromBits() to escape an overload resolution issue with
  /// the constructor from float. No conversion is done.
/// </summary>
/// <param name="v">16-bit bfloat16 value</param>
constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
public:
using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
BFloat16_t() = default;
  /// <summary>
  /// Creates a new instance of BFloat16_t from a uint16_t bit representation of bfloat16.
  /// No conversion is performed.
  /// </summary>
  /// <param name="v">uint16_t bit representation of bfloat16</param>
  /// <returns>new instance of BFloat16_t</returns>
  static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
  /// <summary>
  /// Constructor from float. The value is converted into the 16-bit bfloat16 representation.
  /// </summary>
  /// <param name="v">float value</param>
  explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
/// <summary>
/// Converts bfloat16 to float
/// </summary>
/// <returns>float representation of bfloat16 value</returns>
float ToFloat() const noexcept { return Base::ToFloatImpl(); }
/// <summary>
/// Checks if the value is negative
/// </summary>
/// <returns>true if negative</returns>
using Base::IsNegative;
/// <summary>
/// Tests if the value is NaN
/// </summary>
/// <returns>true if NaN</returns>
using Base::IsNaN;
/// <summary>
/// Tests if the value is finite
/// </summary>
/// <returns>true if finite</returns>
using Base::IsFinite;
/// <summary>
/// Tests if the value represents positive infinity.
/// </summary>
/// <returns>true if positive infinity</returns>
using Base::IsPositiveInfinity;
/// <summary>
/// Tests if the value represents negative infinity
/// </summary>
/// <returns>true if negative infinity</returns>
using Base::IsNegativeInfinity;
/// <summary>
/// Tests if the value is either positive or negative infinity.
/// </summary>
/// <returns>True if absolute value is infinity</returns>
using Base::IsInfinity;
/// <summary>
/// Tests if the value is NaN or zero. Useful for comparisons.
/// </summary>
/// <returns>True if NaN or zero.</returns>
using Base::IsNaNOrZero;
/// <summary>
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
/// </summary>
/// <returns>True if so</returns>
using Base::IsNormal;
/// <summary>
/// Tests if the value is subnormal (denormal).
/// </summary>
/// <returns>True if so</returns>
using Base::IsSubnormal;
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
using Base::Abs;
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
using Base::Negate;
/// <summary>
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
/// and therefore equivalent, if the resulting value is still zero.
/// </summary>
/// <param name="lhs">first value</param>
/// <param name="rhs">second value</param>
/// <returns>True if both arguments represent zero</returns>
using Base::AreZero;
/// <summary>
/// User defined conversion operator. Converts BFloat16_t to float.
/// </summary>
explicit operator float() const noexcept { return ToFloat(); }
// We do not have an inherited impl for the below operators
// as the internal class implements them a little differently
bool operator==(const BFloat16_t& rhs) const noexcept;
bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
bool operator<(const BFloat16_t& rhs) const noexcept;
};

static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
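// Because the layout matches uint16_t exactly (see the static_assert above), a raw
// bfloat16 buffer can be reinterpreted in place. A sketch, assuming the same
// MemoryInfo/dims setup as the Float16_t example earlier and a hypothetical
// GetRawBFloat16Data() source of bits:
//
//   std::vector<uint16_t> raw = GetRawBFloat16Data();
//   auto* typed = reinterpret_cast<Ort::BFloat16_t*>(raw.data());
//   auto tensor = Ort::Value::CreateTensor<Ort::BFloat16_t>(
//       info, typed, raw.size(), dims.data(), dims.size());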


@@ -7,6 +7,8 @@
// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
// the main C++ file with implementation details.

#include <cstring>

namespace Ort {
namespace detail {

@@ -131,6 +133,30 @@ struct TypeToTensorType<Float8E5M2FNUZ_t> {
  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
};
inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
if (IsNaN() || rhs.IsNaN()) {
// IEEE defines that NaN is not equal to anything, including itself.
return false;
}
return val == rhs.val;
}
inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
if (IsNaN() || rhs.IsNaN()) {
// IEEE defines that NaN is unordered with respect to everything, including itself.
return false;
}
const bool left_is_negative = IsNegative();
if (left_is_negative != rhs.IsNegative()) {
// When the signs of left and right differ, we know that left is less than right if it is
// the negative value. The exception to this is if both values are zero, in which case IEEE
// says they should be equal, even if the signs differ.
return left_is_negative && !AreZero(*this, rhs);
}
return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
}
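// A note on the final line above: for non-NaN values of the same sign, IEEE-754
// bit patterns order like sign-magnitude integers, so `val < rhs.val` is correct
// for two positives and exactly inverted for two negatives; XOR-ing with
// `left_is_negative` undoes the inversion. For example, -1.0f is 0xBF80 and
// -2.0f is 0xC000 in bfloat16: the integer compare says 0xBF80 < 0xC000, yet
// -1.0 > -2.0, and the XOR flips the result accordingly.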
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
    : allocator_(allocator), p_(p), size_(size) {
}


@@ -0,0 +1,531 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include <cmath>
#include <cstring>
#include <limits>
namespace onnxruntime_float16 {
namespace detail {
enum class endian {
#if defined(_WIN32)
little = 0,
big = 1,
native = little,
#elif defined(__GNUC__) || defined(__clang__)
little = __ORDER_LITTLE_ENDIAN__,
big = __ORDER_BIG_ENDIAN__,
native = __BYTE_ORDER__,
#else
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif
};
static_assert(
endian::native == endian::little || endian::native == endian::big,
"Only little-endian or big-endian native byte orders are supported.");
} // namespace detail
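// This mirrors C++20 std::endian; where <bit> is available the two can be
// cross-checked, e.g.:
//   static_assert((detail::endian::native == detail::endian::little) ==
//                 (std::endian::native == std::endian::little));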
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// </summary>
template <class Derived>
struct Float16Impl {
protected:
  /// <summary>
  /// Converts from float to uint16_t float16 representation
  /// </summary>
  /// <param name="v">float value</param>
  /// <returns>uint16_t bit representation of float16</returns>
  constexpr static uint16_t ToUint16Impl(float v) noexcept;
/// <summary>
/// Converts float16 to float
/// </summary>
/// <returns>float representation of float16 value</returns>
float ToFloatImpl() const noexcept;
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
uint16_t AbsImpl() const noexcept {
return static_cast<uint16_t>(val & ~kSignMask);
}
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
uint16_t NegateImpl() const noexcept {
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
}
public:
// uint16_t special values
static constexpr uint16_t kSignMask = 0x8000U;
static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
static constexpr uint16_t kEpsilonBits = 0x4170U;
static constexpr uint16_t kMinValueBits = 0xFBFFU; // Minimum normal number
static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
static constexpr uint16_t kOneBits = 0x3C00U;
static constexpr uint16_t kMinusOneBits = 0xBC00U;
uint16_t val{0};
Float16Impl() = default;
/// <summary>
/// Checks if the value is negative
/// </summary>
/// <returns>true if negative</returns>
bool IsNegative() const noexcept {
return static_cast<int16_t>(val) < 0;
}
/// <summary>
/// Tests if the value is NaN
/// </summary>
/// <returns>true if NaN</returns>
bool IsNaN() const noexcept {
return AbsImpl() > kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is finite
/// </summary>
/// <returns>true if finite</returns>
bool IsFinite() const noexcept {
return AbsImpl() < kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents positive infinity.
/// </summary>
/// <returns>true if positive infinity</returns>
bool IsPositiveInfinity() const noexcept {
return val == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents negative infinity
/// </summary>
/// <returns>true if negative infinity</returns>
bool IsNegativeInfinity() const noexcept {
return val == kNegativeInfinityBits;
}
/// <summary>
/// Tests if the value is either positive or negative infinity.
/// </summary>
/// <returns>True if absolute value is infinity</returns>
bool IsInfinity() const noexcept {
return AbsImpl() == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is NaN or zero. Useful for comparisons.
/// </summary>
/// <returns>True if NaN or zero.</returns>
bool IsNaNOrZero() const noexcept {
auto abs = AbsImpl();
return (abs == 0 || abs > kPositiveInfinityBits);
}
/// <summary>
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
/// </summary>
/// <returns>True if so</returns>
bool IsNormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
}
/// <summary>
/// Tests if the value is subnormal (denormal).
/// </summary>
/// <returns>True if so</returns>
bool IsSubnormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
}
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
/// <summary>
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
/// and therefore equivalent, if the resulting value is still zero.
/// </summary>
/// <param name="lhs">first value</param>
/// <param name="rhs">second value</param>
/// <returns>True if both arguments represent zero</returns>
static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
}
bool operator==(const Float16Impl& rhs) const noexcept {
if (IsNaN() || rhs.IsNaN()) {
// IEEE defines that NaN is not equal to anything, including itself.
return false;
}
return val == rhs.val;
}
bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
bool operator<(const Float16Impl& rhs) const noexcept {
if (IsNaN() || rhs.IsNaN()) {
// IEEE defines that NaN is unordered with respect to everything, including itself.
return false;
}
const bool left_is_negative = IsNegative();
if (left_is_negative != rhs.IsNegative()) {
// When the signs of left and right differ, we know that left is less than right if it is
// the negative value. The exception to this is if both values are zero, in which case IEEE
// says they should be equal, even if the signs differ.
return left_is_negative && !AreZero(*this, rhs);
}
return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
}
};
// The following Float16_t conversions are based on the code from
// Eigen library.
// The conversion routines are Copyright (c) Fabian Giesen, 2016.
// The original license follows:
//
// Copyright (c) Fabian Giesen, 2016
// All rights reserved.
// Redistribution and use in source and binary forms, with or without
// modification, are permitted.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
namespace detail {
union float32_bits {
unsigned int u;
float f;
};
} // namespace detail
template <class Derived>
inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
detail::float32_bits f;
f.f = v;
constexpr detail::float32_bits f32infty = {255 << 23};
constexpr detail::float32_bits f16max = {(127 + 16) << 23};
constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
constexpr unsigned int sign_mask = 0x80000000u;
uint16_t val = static_cast<uint16_t>(0x0u);
unsigned int sign = f.u & sign_mask;
f.u ^= sign;
// NOTE all the integer compares in this function can be safely
// compiled into signed compares since all operands are below
// 0x80000000. Important if you want fast straight SSE2 code
// (since there's no unsigned PCMPGTD).
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
} else { // (De)normalized number or zero
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
// use a magic value to align our 10 mantissa bits at the bottom of
// the float. as long as FP addition is round-to-nearest-even this
// just works.
f.f += denorm_magic.f;
// and one integer subtract of the bias later, we have our final float!
val = static_cast<uint16_t>(f.u - denorm_magic.u);
} else {
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
// update exponent, rounding bias part 1
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
// without arithmetic overflow.
f.u += 0xc8000fffU;
// rounding bias part 2
f.u += mant_odd;
// take the bits!
val = static_cast<uint16_t>(f.u >> 13);
}
}
val |= static_cast<uint16_t>(sign >> 16);
return val;
}
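// A few worked values for the conversion above (a sketch; Ort::Float16_t is one
// Derived type that reaches this code through its float constructor):
//   Float16_t(1.0f).val == 0x3C00      // kOneBits: biased exponent 15, zero mantissa
//   Float16_t(65504.0f).val == 0x7BFF  // kMaxValueBits: largest finite float16
//   Float16_t(65536.0f).IsPositiveInfinity()  // values >= 2^16 saturate to +inf
//   Float16_t(1e-7f).IsSubnormal()     // rounds to a nonzero subnormal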
template <class Derived>
inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
constexpr detail::float32_bits magic = {113 << 23};
constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
detail::float32_bits o;
o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
unsigned int exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust
// handle exponent special cases
if (exp == shifted_exp) { // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust
} else if (exp == 0) { // Zero/Denormal?
o.u += 1 << 23; // extra exp adjust
o.f -= magic.f; // re-normalize
}
o.u |= (val & 0x8000) << 16; // sign bit
return o.f;
}
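// Decoding examples for the branches above (a sketch using FromBits on a Derived
// type such as Ort::Float16_t):
//   FromBits(0x3C00).ToFloat() == 1.0f
//   FromBits(0x0001).ToFloat() == 0x1p-24f  // smallest float16 subnormal,
//                                           // re-normalized via `magic`
//   FromBits(0x7C00) decodes to +infinity; the same exponent with a nonzero
//   mantissa decodes to NaN.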
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// </summary>
template <class Derived>
struct BFloat16Impl {
protected:
  /// <summary>
  /// Converts from float to uint16_t bfloat16 representation
  /// </summary>
  /// <param name="v">float value</param>
  /// <returns>uint16_t bit representation of bfloat16</returns>
  static uint16_t ToUint16Impl(float v) noexcept;
/// <summary>
/// Converts bfloat16 to float
/// </summary>
/// <returns>float representation of bfloat16 value</returns>
float ToFloatImpl() const noexcept;
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
uint16_t AbsImpl() const noexcept {
return static_cast<uint16_t>(val & ~kSignMask);
}
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
uint16_t NegateImpl() const noexcept {
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
}
public:
// uint16_t special values
static constexpr uint16_t kSignMask = 0x8000U;
static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
static constexpr uint16_t kEpsilonBits = 0x0080U;
static constexpr uint16_t kMinValueBits = 0xFF7FU;
static constexpr uint16_t kMaxValueBits = 0x7F7FU;
static constexpr uint16_t kRoundToNearest = 0x7FFFU;
static constexpr uint16_t kOneBits = 0x3F80U;
static constexpr uint16_t kMinusOneBits = 0xBF80U;
uint16_t val{0};
BFloat16Impl() = default;
/// <summary>
/// Checks if the value is negative
/// </summary>
/// <returns>true if negative</returns>
bool IsNegative() const noexcept {
return static_cast<int16_t>(val) < 0;
}
/// <summary>
/// Tests if the value is NaN
/// </summary>
/// <returns>true if NaN</returns>
bool IsNaN() const noexcept {
return AbsImpl() > kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is finite
/// </summary>
/// <returns>true if finite</returns>
bool IsFinite() const noexcept {
return AbsImpl() < kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents positive infinity.
/// </summary>
/// <returns>true if positive infinity</returns>
bool IsPositiveInfinity() const noexcept {
return val == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents negative infinity
/// </summary>
/// <returns>true if negative infinity</returns>
bool IsNegativeInfinity() const noexcept {
return val == kNegativeInfinityBits;
}
/// <summary>
/// Tests if the value is either positive or negative infinity.
/// </summary>
/// <returns>True if absolute value is infinity</returns>
bool IsInfinity() const noexcept {
return AbsImpl() == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is NaN or zero. Useful for comparisons.
/// </summary>
/// <returns>True if NaN or zero.</returns>
bool IsNaNOrZero() const noexcept {
auto abs = AbsImpl();
return (abs == 0 || abs > kPositiveInfinityBits);
}
/// <summary>
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
/// </summary>
/// <returns>True if so</returns>
bool IsNormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
}
/// <summary>
/// Tests if the value is subnormal (denormal).
/// </summary>
/// <returns>True if so</returns>
bool IsSubnormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
}
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
/// <summary>
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
/// and therefore equivalent, if the resulting value is still zero.
/// </summary>
/// <param name="lhs">first value</param>
/// <param name="rhs">second value</param>
/// <returns>True if both arguments represent zero</returns>
static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
// for two values by or'ing the private bits together and stripping the sign. They are both zero,
// and therefore equivalent, if the resulting value is still zero.
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
}
};
template <class Derived>
inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
uint16_t result;
if (std::isnan(v)) {
result = kPositiveQNaNBits;
} else {
auto get_msb_half = [](float fl) {
uint16_t result;
#ifdef __cpp_if_constexpr
if constexpr (detail::endian::native == detail::endian::little) {
#else
if (detail::endian::native == detail::endian::little) {
#endif
std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
} else {
std::memcpy(&result, &fl, sizeof(uint16_t));
}
return result;
};
uint16_t upper_bits = get_msb_half(v);
union {
uint32_t U32;
float F32;
};
F32 = v;
U32 += (upper_bits & 1) + kRoundToNearest;
result = get_msb_half(F32);
}
return result;
}
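// The `(upper_bits & 1) + kRoundToNearest` addition rounds the 16 discarded low
// bits to nearest, ties to even. Worked values (a sketch via the public
// Ort::BFloat16_t wrapper, which forwards to this implementation):
//   BFloat16_t(1.0f).val == 0x3F80         // 0x3F800000 truncates exactly
//   BFloat16_t(1.00390625f).val == 0x3F80  // exact halfway case; even pattern kept
//   BFloat16_t(1.01171875f).val == 0x3F82  // halfway case above odd 0x3F81; rounds up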
template <class Derived>
inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
if (IsNaN()) {
return std::numeric_limits<float>::quiet_NaN();
}
float result;
char* const first = reinterpret_cast<char*>(&result);
char* const second = first + sizeof(uint16_t);
#ifdef __cpp_if_constexpr
if constexpr (detail::endian::native == detail::endian::little) {
#else
if (detail::endian::native == detail::endian::little) {
#endif
std::memset(first, 0, sizeof(uint16_t));
std::memcpy(second, &val, sizeof(uint16_t));
} else {
std::memcpy(first, &val, sizeof(uint16_t));
std::memset(second, 0, sizeof(uint16_t));
}
return result;
}
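// Going back to float is exact for every non-NaN bfloat16: the 16 stored bits
// simply become the high half of the float, so, for example,
// FromBits(0x3F80).ToFloat() == 1.0f and FromBits(0xC000).ToFloat() == -2.0f.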
} // namespace onnxruntime_float16


@@ -81,10 +81,10 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
  auto* length_penalty_tensor = context->Input<Tensor>(5);
  if (length_penalty_tensor) {
    if (length_penalty_tensor->IsDataType<float>()) {
      length_penalty = *length_penalty_tensor->Data<float>();
    } else {
      length_penalty = static_cast<float>(*length_penalty_tensor->Data<MLFloat16>());
    }
  } else {
    length_penalty = 1.0f;

@@ -92,10 +92,10 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
  auto* repetition_penalty_tensor = context->Input<Tensor>(6);
  if (repetition_penalty_tensor) {
    if (repetition_penalty_tensor->IsDataType<float>()) {
      repetition_penalty = *repetition_penalty_tensor->Data<float>();
    } else {
      repetition_penalty = static_cast<float>(*repetition_penalty_tensor->Data<MLFloat16>());
    }
  } else {
    repetition_penalty = 1.0f;


@@ -11,7 +11,6 @@
#include "core/framework/tensor.h"
#include "core/framework/TensorSeq.h"
#include "core/graph/onnx_protobuf.h"
-#include "core/util/math.h"

#ifdef __GNUC__
#pragma GCC diagnostic push

@@ -27,18 +26,34 @@ using namespace ONNX_NAMESPACE;
namespace onnxruntime {

-MLFloat16::MLFloat16(float f) : val{math::floatToHalf(f)} {}
-
-float MLFloat16::ToFloat() const {
-  return math::halfToFloat(val);
-}

// Return the MLDataType used for a generic Tensor
template <>
MLDataType DataTypeImpl::GetType<Tensor>() {
  return TensorTypeBase::Type();
}
const MLFloat16 MLFloat16::NaN(MLFloat16::FromBits(MLFloat16::kPositiveQNaNBits));
const MLFloat16 MLFloat16::NegativeNaN(MLFloat16::FromBits(MLFloat16::kNegativeQNaNBits));
const MLFloat16 MLFloat16::Infinity(MLFloat16::FromBits(MLFloat16::kPositiveInfinityBits));
const MLFloat16 MLFloat16::NegativeInfinity(MLFloat16::FromBits(MLFloat16::kNegativeInfinityBits));
const MLFloat16 MLFloat16::Epsilon(MLFloat16::FromBits(MLFloat16::kEpsilonBits));
const MLFloat16 MLFloat16::MinValue(MLFloat16::FromBits(MLFloat16::kMinValueBits));
const MLFloat16 MLFloat16::MaxValue(MLFloat16::FromBits(MLFloat16::kMaxValueBits));
const MLFloat16 MLFloat16::Zero(MLFloat16::FromBits(0));
const MLFloat16 MLFloat16::One(MLFloat16::FromBits(MLFloat16::kOneBits));
const MLFloat16 MLFloat16::MinusOne(MLFloat16::FromBits(MLFloat16::kMinusOneBits));
const BFloat16 BFloat16::NaN(BFloat16::FromBits(BFloat16::kPositiveQNaNBits));
const BFloat16 BFloat16::NegativeNaN(BFloat16::FromBits(BFloat16::kNegativeQNaNBits));
const BFloat16 BFloat16::Infinity(BFloat16::FromBits(BFloat16::kPositiveInfinityBits));
const BFloat16 BFloat16::NegativeInfinity(BFloat16::FromBits(BFloat16::kNegativeInfinityBits));
const BFloat16 BFloat16::Epsilon(BFloat16::FromBits(BFloat16::kEpsilonBits));
const BFloat16 BFloat16::MinValue(BFloat16::FromBits(BFloat16::kMinValueBits));
const BFloat16 BFloat16::MaxValue(BFloat16::FromBits(BFloat16::kMaxValueBits));
const BFloat16 BFloat16::Zero(BFloat16::FromBits(0));
const BFloat16 BFloat16::One(BFloat16::FromBits(BFloat16::kOneBits));
const BFloat16 BFloat16::MinusOne(BFloat16::FromBits(BFloat16::kMinusOneBits));
}  // namespace onnxruntime

// This conflicts with the above GetType<>() specialization


@@ -370,7 +370,7 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d
    if (v < 0 || v > max_value) {
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "data overflow");
    }
    p_data[i] = MLFloat16::FromBits(static_cast<uint16_t>(v));
  }

  return Status::OK();


@@ -1849,7 +1849,7 @@ void BroadCastMLFloat16FMod(OpKernelContext* context) {
        std::transform(Y.begin(), Y.end(), output.begin(),
                       [X_fl = math::halfToFloat(X.val)](const MLFloat16& y) {
                         return MLFloat16(std::fmod(X_fl, y.ToFloat()));
                       });
      },
      [](BroadcastHelper& per_iter_bh) {

@@ -1859,7 +1859,7 @@ void BroadCastMLFloat16FMod(OpKernelContext* context) {
        std::transform(X.begin(), X.end(), output.begin(),
                       [Y_fl = math::halfToFloat(Y.val)](const MLFloat16& x) {
                         return MLFloat16(std::fmod(x.ToFloat(), Y_fl));
                       });
      },
      [](BroadcastHelper& per_iter_bh) {

@@ -1869,9 +1869,9 @@ void BroadCastMLFloat16FMod(OpKernelContext* context) {
        std::transform(X.begin(), X.end(), Y.begin(), output.begin(),
                       [](const MLFloat16& x, const MLFloat16& y) {
                         auto x_fl = x.ToFloat();
                         auto y_fl = y.ToFloat();
                         return MLFloat16(std::fmod(x_fl, y_fl));
                       });
      }};


@@ -36,7 +36,7 @@ Status Round<MLFloat16>::Compute(OpKernelContext* ctx) const {
  auto* output = Y.MutableData<MLFloat16>();
  const auto size = X.Shape().Size();
  for (int64_t i = 0; i < size; ++i, ++output, ++input) {
    *output = MLFloat16(static_cast<float>(::rint(input->ToFloat())));
  }
  return Status::OK();
}


@@ -55,41 +55,27 @@ struct CallSignImpl {
  }
};

-// The spec does not specify how NaN is
-// treated but we have to treat it somehow. We choose
-// to return 0 for NaN as TF does.
-template <class T>
-inline T FloatingImpl(T val) {
-  if (std::isnan(val) || val == T(0)) {
-    return T(0);
-  }
-  if (val > T(0)) {
-    return T(1);
-  } else {
-    return T(-1);
-  }
-}

template <>
struct CallSignImpl<MLFloat16> {
  void operator()(const Tensor* input, Tensor* output) const {
    ConstEigenVectorMap<Eigen::half> input_data(
        reinterpret_cast<const Eigen::half*>(input->Data<MLFloat16>()),
        narrow<ptrdiff_t>(input->Shape().Size()));

    EigenVectorMap<Eigen::half>(reinterpret_cast<Eigen::half*>(output->MutableData<MLFloat16>()),
                                narrow<ptrdiff_t>(output->Shape().Size())) = input_data.array().cwiseSign();
  }
};

template <>
struct CallSignImpl<BFloat16> {
  void operator()(const Tensor* input, Tensor* output) const {
    auto span = input->DataAsSpan<BFloat16>();
    auto output_data = output->MutableData<BFloat16>();
    std::transform(span.begin(), span.end(), output_data, [](const BFloat16& val) {
      // Return 0 as TF does for NaN.
      if (val.IsNaNOrZero()) return BFloat16::Zero;
      return (val.IsNegative()) ? BFloat16::MinusOne : BFloat16::One;
    });
  }
};


@@ -55,18 +55,18 @@ Status ShrinkImpl(const Tensor* input, Tensor* output, float bias, float lambd)
template <>
Status ShrinkImpl<MLFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
  const auto span = input->DataAsSpan<MLFloat16>();
  auto* output_data = output->MutableData<MLFloat16>();
  std::transform(span.begin(), span.end(), output_data, [bias, lambd](const MLFloat16& val) {
    float fl = val.ToFloat();
    return MLFloat16(ShrinkCore<float>(fl, bias, lambd));
  });
  return Status::OK();
}

template <>
Status ShrinkImpl<BFloat16>(const Tensor* input, Tensor* output, float bias, float lambd) {
  const auto span = input->DataAsSpan<BFloat16>();
  auto* output_data = output->MutableData<BFloat16>();
  std::transform(span.begin(), span.end(), output_data, [bias, lambd](const BFloat16& val) {
    float fl = val.ToFloat();


@@ -146,20 +146,20 @@ __device__ __forceinline__ BFloat16& operator/=(BFloat16& a, const BFloat16& b)
/// Arithmetic with floats

__device__ __forceinline__ float operator+(BFloat16 a, float b) { return a + b; }
__device__ __forceinline__ float operator-(BFloat16 a, float b) { return a - b; }
__device__ __forceinline__ float operator*(BFloat16 a, float b) { return a * b; }
__device__ __forceinline__ float operator/(BFloat16 a, float b) { return a / b; }

__device__ __forceinline__ float operator+(float a, BFloat16 b) { return a + b; }
__device__ __forceinline__ float operator-(float a, BFloat16 b) { return a - b; }
__device__ __forceinline__ float operator*(float a, BFloat16 b) { return a * b; }
__device__ __forceinline__ float operator/(float a, BFloat16 b) { return a / b; }

__device__ __forceinline__ float& operator+=(float& a, const BFloat16& b) { return a += b; }
__device__ __forceinline__ float& operator-=(float& a, const BFloat16& b) { return a -= b; }
__device__ __forceinline__ float& operator*=(float& a, const BFloat16& b) { return a *= b; }
__device__ __forceinline__ float& operator/=(float& a, const BFloat16& b) { return a /= b; }

/// Arithmetic with doubles


@@ -73,10 +73,10 @@ struct LowMax {
template <>
struct LowMax<MLFloat16> {
  static MLFloat16 low() {
    return MLFloat16::FromBits(math::floatToHalf(std::numeric_limits<float>::lowest()));
  }
  static MLFloat16 max() {
    return MLFloat16::FromBits(math::floatToHalf(std::numeric_limits<float>::max()));
  }
};
}  // namespace clip_internal


@@ -460,10 +460,6 @@ Status DenseTensorToSparseCoo(const DataTransferManager& data_manager, const Ten
}  // namespace sparse_utils

-float MLFloat16::ToFloat() const {
-  return math::halfToFloat(val);
-}

std::vector<std::string> GetStackTrace() { return g_host->GetStackTrace(); }

void LogRuntimeError(uint32_t session_id, const common::Status& status,


@@ -76,7 +76,7 @@ class RandomValueGenerator {
    std::vector<TFloat16> val(detail::SizeFromDims(dims));
    std::uniform_real_distribution<float> distribution(min, max);
    for (size_t i = 0; i < val.size(); ++i) {
      val[i] = TFloat16(static_cast<float>(distribution(generator_)));
    }
    return val;
  }


@@ -151,7 +151,7 @@ inline std::vector<MLFloat16> ToFloat16(const std::vector<float>& data) {
  std::vector<MLFloat16> result;
  result.reserve(data.size());
  for (size_t i = 0; i < data.size(); i++) {
    result.push_back(MLFloat16(data[i]));
  }
  return result;
}

@@ -245,7 +245,7 @@ class ParallelRandomValueGenerator {
      RandomEngine generator{seed};
      std::uniform_real_distribution<float> distribution(min, max);
      for (std::ptrdiff_t di = begin; di != end; ++di) {
        val[di] = TFloat16(static_cast<float>(distribution(generator)));
      }
    });


@@ -109,8 +109,8 @@ void RunBiasDropoutTest(const bool use_mask, const std::vector<int64_t>& input_s
    ratio = 0.5f;
  } else {
    if (use_float16_ratio) {
      t.AddInput("ratio", {}, {MLFloat16(ratio)});
      t_bitmask.AddInput("ratio", {}, {MLFloat16(ratio)});
    } else {
      t.AddInput("ratio", {}, {ratio});
      t_bitmask.AddInput("ratio", {}, {ratio});


@@ -253,17 +253,17 @@ TEST(MathOpTest, ComplexMul_fp16) {
  if (DefaultCudaExecutionProvider() == nullptr) return;

  std::vector<MLFloat16> input_a_data = {
      MLFloat16(-0.5f), MLFloat16(0.6f)};

  std::vector<MLFloat16> input_b_data = {
      MLFloat16(0.8f), MLFloat16(-0.5f), MLFloat16(0.0f), MLFloat16(1.f),
      MLFloat16(0.5f), MLFloat16(0.2f), MLFloat16(0.3f), MLFloat16(-0.6f)};

  std::vector<MLFloat16> output_data = {
      MLFloat16(-0.10f), MLFloat16(0.73f),
      MLFloat16(-0.60f), MLFloat16(-0.50f),
      MLFloat16(-0.37f), MLFloat16(0.20f),
      MLFloat16(0.21f), MLFloat16(0.48f)};

  OpTester tester("ComplexMul", 1, onnxruntime::kMSDomain);
  tester.AddInput<MLFloat16>("A", {1, 2}, input_a_data);

@@ -279,17 +279,17 @@ TEST(MathOpTest, ComplexMulConj_fp16) {
  if (DefaultCudaExecutionProvider() == nullptr) return;

  std::vector<MLFloat16> input_a_data = {
      MLFloat16(-0.5f), MLFloat16(0.6f)};

  std::vector<MLFloat16> input_b_data = {
      MLFloat16(0.8f), MLFloat16(-0.5f), MLFloat16(0.0f), MLFloat16(1.f),
      MLFloat16(0.5f), MLFloat16(0.2f), MLFloat16(0.3f), MLFloat16(-0.6f)};

  std::vector<MLFloat16> output_data = {
      MLFloat16(-0.70f), MLFloat16(0.23f),
      MLFloat16(0.60f), MLFloat16(0.50f),
      MLFloat16(-0.13f), MLFloat16(0.40f),
      MLFloat16(-0.51f), MLFloat16(-0.12f)};

  OpTester tester("ComplexMulConj", 1, onnxruntime::kMSDomain);
  tester.AddInput<MLFloat16>("A", {1, 2}, input_a_data);


@@ -30,14 +30,14 @@ TEST(InverseContribOpTest, two_by_two_float16) {
  std::transform(
      input_float.begin(), input_float.end(), std::back_inserter(input),
      [](float v) {
        return MLFloat16(v);
      });

  auto output_float = {0.6f, -0.7f, -0.2f, 0.4f};
  std::vector<MLFloat16> output;
  std::transform(
      output_float.begin(), output_float.end(), std::back_inserter(output), [](float v) {
        return MLFloat16(v);
      });

  test.AddInput<MLFloat16>("X", {2, 2}, input);


@@ -420,13 +420,159 @@ TEST_F(DataTypeTest, VectorMapInt64ToFloatTest) {
}
#endif  // !defined(DISABLE_ML_OPS)

TEST_F(DataTypeTest, MlFloat16ConvertFloatToMLFloat16) {
  // Test data type
  {
    constexpr float sample = 1.0f;
    const MLFloat16 flt16(sample);
    auto int_rep = flt16.val;
    const auto flt_from_int = MLFloat16::FromBits(int_rep);
const double diff = std::fabs(sample - flt_from_int.ToFloat());
if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
EXPECT_TRUE(false);
}
}
// Test bulk conversion
{
float sample[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
std::vector<MLFloat16> converted;
std::transform(std::begin(sample), std::end(sample), std::back_inserter(converted),
[](float fl) { return MLFloat16(fl); });
for (size_t i = 0; i < sizeof(sample) / sizeof(float); ++i) {
const double diff = std::fabs(sample[i] - converted[i].ToFloat());
if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
FAIL();
}
}
std::vector<float> back_converted;
std::transform(converted.cbegin(), converted.cend(), std::back_inserter(back_converted),
[](const MLFloat16 ml) { return (float)ml; });
for (size_t i = 0; i < sizeof(sample) / sizeof(float); ++i) {
const double diff = std::fabs(sample[i] - back_converted[i]);
if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
FAIL();
}
}
}
}
TEST_F(DataTypeTest, MLFloat16Zeros) {
const auto positive_zero = MLFloat16::FromBits(0U);
EXPECT_FALSE(positive_zero.IsNegative());
const float float_positive_zero = static_cast<float>(positive_zero);
EXPECT_EQ(+0.0f, float_positive_zero);
EXPECT_FALSE(std::signbit(float_positive_zero));
const auto negative_zero = positive_zero.Negate();
EXPECT_TRUE(negative_zero.IsNegative());
const float float_positive_negzero = static_cast<float>(negative_zero);
EXPECT_EQ(-0.0f, float_positive_negzero);
EXPECT_TRUE(std::signbit(float_positive_negzero));
EXPECT_TRUE(positive_zero.IsNaNOrZero());
EXPECT_TRUE(negative_zero.IsNaNOrZero());
}
TEST_F(DataTypeTest, MLFloat16Comparision) {
const MLFloat16 left = MLFloat16(-33.33f);
const MLFloat16 left_same = MLFloat16(-33.33f);
const MLFloat16 right = MLFloat16(66.66f);
const MLFloat16 right_same = MLFloat16(66.66f);
EXPECT_TRUE(MLFloat16::Epsilon < right);
EXPECT_EQ(left, left_same);
EXPECT_NE(left, left_same.Negate());
EXPECT_EQ(right, right_same);
EXPECT_NE(right, right_same.Negate());
EXPECT_LT(left, right);
EXPECT_LT(right.Negate(), left);
EXPECT_LT(left.Negate(), right);
}
TEST_F(DataTypeTest, MLFloat16TestNAN) {
const MLFloat16 fp16NANFromSingle(std::numeric_limits<float>::quiet_NaN());
EXPECT_TRUE(fp16NANFromSingle.IsNaN());
EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero());
// NaN are not equal to each other
EXPECT_NE(MLFloat16::NaN, fp16NANFromSingle);
  const float nan_from_fp16 = fp16NANFromSingle.ToFloat();
  EXPECT_TRUE(std::isnan(nan_from_fp16));
EXPECT_FALSE(MLFloat16::FromBits(MLFloat16::kMaxValueBits).IsNaN());
}
TEST_F(DataTypeTest, MLFloat16NaNComparision) {
EXPECT_FALSE(MLFloat16::NaN < MLFloat16::NaN);
EXPECT_FALSE(MLFloat16::NaN == MLFloat16::NaN);
EXPECT_FALSE(MLFloat16::MaxValue < MLFloat16::NaN);
EXPECT_FALSE(MLFloat16::MaxValue == MLFloat16::NaN);
EXPECT_FALSE(MLFloat16::MinValue < MLFloat16::NaN);
EXPECT_FALSE(MLFloat16::NaN < MLFloat16::MaxValue);
EXPECT_TRUE(MLFloat16::MinValue < MLFloat16::MaxValue);
}
TEST_F(DataTypeTest, MLFloat16Infinity) {
EXPECT_FALSE(MLFloat16::MinValue.IsInfinity());
EXPECT_FALSE(MLFloat16::MaxValue.IsInfinity());
EXPECT_TRUE(MLFloat16::MaxValue.IsFinite());
const MLFloat16 pos_infinity_from_float(std::numeric_limits<float>::infinity());
EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
EXPECT_FALSE(pos_infinity_from_float.IsFinite());
EXPECT_FALSE(pos_infinity_from_float.IsNegative());
const MLFloat16 neg_infinity_from_float(-std::numeric_limits<float>::infinity());
EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
EXPECT_FALSE(neg_infinity_from_float.IsFinite());
EXPECT_TRUE(neg_infinity_from_float.IsNegative());
  const float pos_infinity_from_fp16 = static_cast<float>(MLFloat16::Infinity);
  EXPECT_TRUE(std::isinf(pos_infinity_from_fp16));
  EXPECT_TRUE(!std::signbit(pos_infinity_from_fp16));
}
TEST_F(DataTypeTest, MLFloat16NormalSubnormal) {
EXPECT_FALSE(MLFloat16::Infinity.IsNormal());
EXPECT_TRUE(MLFloat16(45.6f).IsNormal());
EXPECT_FALSE(MLFloat16(45.6f).IsSubnormal());
  // 0b0_00000_0000000001; ~5.9604645E-8, the smallest fp16 subnormal
  constexpr uint16_t min_subnormal_bits = 0x0001;
const MLFloat16 smallest_subnormal = MLFloat16::FromBits(min_subnormal_bits);
EXPECT_TRUE(smallest_subnormal.IsSubnormal());
EXPECT_FALSE(smallest_subnormal.IsNormal());
  // The smallest fp16 subnormal (~5.96E-8) is well within float's normal range,
  // so the same value converted to float is a normal float.
const float float_from_smallest_subnormal = static_cast<float>(smallest_subnormal);
EXPECT_TRUE(std::isnormal(float_from_smallest_subnormal));
  // 0b0_00000_1111111111; ~6.0975552E-5, the largest fp16 subnormal
  constexpr uint16_t max_subnormal_bits = 0x03FF;
const MLFloat16 largest_subnormal = MLFloat16::FromBits(max_subnormal_bits);
EXPECT_TRUE(largest_subnormal.IsSubnormal());
EXPECT_FALSE(largest_subnormal.IsNormal());
// However, in float the same number above would be normal
const float float_from_largest_subnormal = static_cast<float>(largest_subnormal);
EXPECT_TRUE(std::isnormal(float_from_largest_subnormal));
}
TEST_F(DataTypeTest, BFloat16ConvertFloatToBFloat16) {
// Test data type
{
constexpr float sample = 1.0f;
const BFloat16 flt16(sample);
auto int_rep = flt16.val;
const auto flt_from_int = BFloat16::FromBits(int_rep);
    const double diff = std::fabs(sample - flt_from_int.ToFloat());
    if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
      EXPECT_TRUE(false);

@@ -456,6 +602,112 @@
  }
}
TEST_F(DataTypeTest, BFloat16Zeros) {
const auto positive_zero = BFloat16::FromBits(0U);
EXPECT_FALSE(positive_zero.IsNegative());
const float float_positive_zero = static_cast<float>(positive_zero);
EXPECT_EQ(+0.0f, float_positive_zero);
EXPECT_FALSE(std::signbit(float_positive_zero));
const auto negative_zero = positive_zero.Negate();
EXPECT_TRUE(negative_zero.IsNegative());
const float float_positive_negzero = static_cast<float>(negative_zero);
EXPECT_EQ(-0.0f, float_positive_negzero);
EXPECT_TRUE(std::signbit(float_positive_negzero));
EXPECT_TRUE(positive_zero.IsNaNOrZero());
EXPECT_TRUE(negative_zero.IsNaNOrZero());
}
TEST_F(DataTypeTest, BFloat16Comparision) {
const BFloat16 left = BFloat16(-33.33f);
const BFloat16 left_same = BFloat16(-33.33f);
const BFloat16 right = BFloat16(66.66f);
const BFloat16 right_same = BFloat16(66.66f);
EXPECT_TRUE(BFloat16::Epsilon < right);
EXPECT_EQ(left, left_same);
EXPECT_NE(left, left_same.Negate());
EXPECT_EQ(right, right_same);
EXPECT_NE(right, right_same.Negate());
EXPECT_LT(left, right);
EXPECT_LT(right.Negate(), left);
EXPECT_LT(left.Negate(), right);
}
TEST_F(DataTypeTest, BFloat16TestNAN) {
const BFloat16 fp16NANFromSingle = std::numeric_limits<float>::quiet_NaN();
EXPECT_TRUE(fp16NANFromSingle.IsNaN());
EXPECT_TRUE(fp16NANFromSingle.IsNaNOrZero());
// NaN are not equal to each other
EXPECT_NE(BFloat16::NaN, fp16NANFromSingle);
float NanFromBFloat16 = fp16NANFromSingle.ToFloat();
EXPECT_TRUE(std::isnan(NanFromBFloat16));
EXPECT_FALSE(BFloat16::FromBits(BFloat16::kMaxValueBits).IsNaN());
}
TEST_F(DataTypeTest, BFloat16NaNComparision) {
EXPECT_FALSE(BFloat16::NaN < BFloat16::NaN);
EXPECT_FALSE(BFloat16::NaN == BFloat16::NaN);
EXPECT_FALSE(BFloat16::MaxValue < BFloat16::NaN);
EXPECT_FALSE(BFloat16::MaxValue == BFloat16::NaN);
EXPECT_FALSE(BFloat16::MinValue < BFloat16::NaN);
EXPECT_FALSE(BFloat16::NaN < BFloat16::MaxValue);
EXPECT_TRUE(BFloat16::MinValue < BFloat16::MaxValue);
}
TEST_F(DataTypeTest, BFloat16Infinity) {
EXPECT_FALSE(BFloat16::MinValue.IsInfinity());
EXPECT_FALSE(BFloat16::MaxValue.IsInfinity());
EXPECT_TRUE(BFloat16::MaxValue.IsFinite());
const BFloat16 pos_infinity_from_float = std::numeric_limits<float>::infinity();
EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
EXPECT_FALSE(pos_infinity_from_float.IsFinite());
EXPECT_FALSE(pos_infinity_from_float.IsNegative());
const BFloat16 neg_infinity_from_float = -std::numeric_limits<float>::infinity();
EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
EXPECT_FALSE(neg_infinity_from_float.IsFinite());
EXPECT_TRUE(neg_infinity_from_float.IsNegative());
EXPECT_TRUE(std::signbit(neg_infinity_from_float.ToFloat()));
const float pos_infinity_from_bfloat16 = static_cast<float>(BFloat16::Infinity);
EXPECT_TRUE(std::isinf(pos_infinity_from_bfloat16));
EXPECT_TRUE(!std::signbit(pos_infinity_from_bfloat16));
}
TEST_F(DataTypeTest, BFloat16NormalSubnormal) {
EXPECT_FALSE(BFloat16::Infinity.IsNormal());
EXPECT_TRUE(BFloat16(45.6f).IsNormal());
EXPECT_FALSE(BFloat16(45.6f).IsSubnormal());
// 0b0_0000_0000_000_0001
constexpr uint16_t min_subnormal_bits = 0x0001;
const BFloat16 smallest_subnormal = BFloat16::FromBits(min_subnormal_bits);
EXPECT_TRUE(smallest_subnormal.IsSubnormal());
EXPECT_FALSE(smallest_subnormal.IsNormal());
const float float_from_smallest_subnormal = (float)smallest_subnormal;
EXPECT_FALSE(std::isnormal(float_from_smallest_subnormal));
// 0b0_0000_0000_111_1111;
constexpr uint16_t max_subnormal_bits = 0x007F;
const BFloat16 largest_subnormal = BFloat16::FromBits(max_subnormal_bits);
EXPECT_TRUE(largest_subnormal.IsSubnormal());
EXPECT_FALSE(largest_subnormal.IsNormal());
const float float_from_largest_subnormal = (float)largest_subnormal;
EXPECT_FALSE(std::isnormal(float_from_largest_subnormal));
}
TEST_F(DataTypeTest, DataUtilsTest) {
  using namespace ONNX_NAMESPACE::Utils;
  // Test simple seq

@@ -697,7 +949,7 @@ TEST(InlinedVectorTests, TestDefaultInlinedCapacity) {

TEST(TypeLiterals, Tests) {
  {
    // uint16_t test
    MLFloat16 mlfloat = MLFloat16::FromBits(static_cast<uint16_t>(16));
    auto mlfloat_literal = 16_f16;
    ASSERT_EQ(mlfloat, mlfloat_literal);


@@ -159,7 +159,7 @@ TEST(Float16_Tests, Mul_16_Test) {
  std::vector<float> values_x_32 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  std::vector<MLFloat16> values_x;
  for (float i : values_x_32) {
    values_x.push_back(MLFloat16(i));
  }

  // prepare expected inputs and outputs

@@ -168,7 +168,7 @@ TEST(Float16_Tests, Mul_16_Test) {
  std::vector<float> expected_values_y_32 = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
  std::vector<MLFloat16> expected_values_y;
  for (float i : expected_values_y_32) {
    expected_values_y.push_back(MLFloat16(i));
  }

  // Now run


@@ -189,7 +189,7 @@ void UnpackTensor(const onnx::TensorProto& tensor, const void* raw_data, size_t
      ORT_CXX_API_THROW(
          "data overflow", OrtErrorCode::ORT_FAIL);
    }
-   p_data[i] = MLFloat16(static_cast<uint16_t>(v));
+   p_data[i] = MLFloat16::FromBits(static_cast<uint16_t>(v));
  }
  return;
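The switch to MLFloat16::FromBits here is deliberate: this path copies raw uint16_t storage out of a TensorProto, so the bits must be reinterpreted, not converted numerically. The named factory keeps that distinction explicit, where the old one-argument uint16_t constructor was easy to confuse with a value conversion. A sketch of the difference, assuming the IEEE-754 binary16 encoding MLFloat16 uses and the ORT float16 header in scope:

#include <cassert>
// assumes: onnxruntime_float16.h is available

void FromBitsVersusConstructor() {
  using onnxruntime::MLFloat16;
  // 0x3C00 is the binary16 encoding of 1.0, so the two expressions agree here.
  assert(MLFloat16::FromBits(0x3C00) == MLFloat16(1.0f));
  // FromBits(5) is the tiny subnormal 5 * 2^-24, nothing like the value 5.0f.
  assert(MLFloat16::FromBits(5) != MLFloat16(5.0f));
}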

@@ -2348,7 +2348,7 @@ TEST_F(GraphTransformationTests, FuseConvBnAddMulFloat16) {
  run_options.run_tag = "one session/one tag";
  OrtValue ml_value_x;
- auto x_f = MLFloat16(math::floatToHalf(1.0));
+ auto x_f = MLFloat16(1.0f);
  std::vector<int64_t> dims_x = {1, 1, 3, 3};
  std::vector<MLFloat16> values_x;
  for (int i = 0; i < 9; ++i) {
@@ -2364,7 +2364,7 @@ TEST_F(GraphTransformationTests, FuseConvBnAddMulFloat16) {
  ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
- auto prod_f = MLFloat16(math::floatToHalf(6.0));
+ auto prod_f = MLFloat16(6.0f);
  std::vector<int64_t> expected_dims_prod = {1, 1, 2, 2};
  std::vector<MLFloat16> expected_values_prod;
  for (int i = 0; i < 4; ++i) {
@@ -5477,7 +5477,7 @@ void BuildConstantSharingDivMulGraph(ModelTestBuilder& builder) {
  for (size_t i = 0; i < 12; ++i) {
    NodeArg* mul_initializer = nullptr;
    if (std::is_same<T, MLFloat16>::value) {
-     mul_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(math::floatToHalf(1.0f)));
+     mul_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(1.0f));
    } else if (std::is_same<T, float>::value) {
      mul_initializer = builder.MakeScalarInitializer<float>(1.0f);
    } else {
@@ -5593,7 +5593,7 @@ void BuildConstantSharingDivMulGraphFor2DInitializer(ModelTestBuilder& builder)
  values_float16.reserve(values.size());
  if (std::is_same<T, MLFloat16>::value) {
    for (auto v : values) {
-     values_float16.push_back(MLFloat16(math::floatToHalf(v)));
+     values_float16.push_back(MLFloat16(v));
    }
  }
@@ -5804,7 +5804,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatAndHalfTypedInitializ
  builder.AddNode("Cast", {div_out}, {cast_out})
      .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16));
  for (size_t i = 0; i < 3; ++i) {
-   NodeArg* add_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(math::floatToHalf(1.0f)));
+   NodeArg* add_initializer = builder.MakeScalarInitializer<MLFloat16>(MLFloat16(1.0f));
    auto* add_out = builder.MakeOutput();
    builder.AddNode("Add", {cast_out, add_initializer}, {add_out});
  }
@@ -5930,7 +5930,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatAndHalfTypedInitial
  std::vector<MLFloat16> values_float16;
  values_float16.reserve(values.size());
  for (auto v : values) {
-   values_float16.push_back(MLFloat16(math::floatToHalf(v)));
+   values_float16.push_back(MLFloat16(v));
  }
  auto build_test_case_float = [&values, &values_float16](ModelTestBuilder& builder) {

@@ -31,7 +31,7 @@ void RandomFillHalfVector(const TensorShapeVector& shape, std::vector<MLFloat16>
  std::vector<float> data_float(TensorShape(shape).Size());
  RandomFillFloatVector(shape, data_float);
  std::transform(data_float.begin(), data_float.end(), data.begin(),
-                [](float value) { return MLFloat16(math::floatToHalf(value)); });
+                [](float value) { return MLFloat16(value); });
}

void RandomMasks(int64_t batch, int64_t sequence_length, std::vector<int64_t>& data) {

@@ -332,7 +332,7 @@ struct TensorCheck<MLFloat16> {
          << "i:" << i;
    }
    if (has_rel_err) {
-     EXPECT_NEAR(f_expected[i], f_actual[i], *(params.relative_error) * std::abs(cur_expected[i]))
+     EXPECT_NEAR(f_expected[i], f_actual[i], *(params.relative_error) * std::abs(static_cast<float>(cur_expected[i])))
          << "i:" << i;
    }
  }

@@ -143,7 +143,7 @@ TEST(ConstantOfShape, TypeTests) {
  RunTypedTest(TensorProto::INT8, int8_t(8));
  RunTypedTest(TensorProto::INT16, int16_t(16));
  RunTypedTest(TensorProto::FLOAT, 1.f);
- RunTypedTest(TensorProto::FLOAT16, MLFloat16(static_cast<uint16_t>(5)));
+ RunTypedTest(TensorProto::FLOAT16, MLFloat16::FromBits(static_cast<uint16_t>(5)));
  RunTypedTest(TensorProto::DOUBLE, 1.0);
  RunTypedTest(TensorProto::INT32, int32_t(32));
  RunTypedTest(TensorProto::INT64, int64_t(64));

@@ -24,9 +24,9 @@ TEST(CumSumTest, _1DTest) {
TEST(CumSumTest, _1DTestFloat16) {
  if (DefaultCudaExecutionProvider().get() != nullptr) {
    OpTester test("CumSum", 14, onnxruntime::kOnnxDomain);
-   test.AddInput<MLFloat16>("x", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+   test.AddInput<MLFloat16>("x", {3}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
    test.AddInput<int32_t>("axis", {}, {0});
-   test.AddOutput<MLFloat16>("y", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(6.0f))});
+   test.AddOutput<MLFloat16>("y", {3}, {MLFloat16(1.0f), MLFloat16(3.0f), MLFloat16(6.0f)});
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kCpuExecutionProvider});
  }
}

@@ -15,7 +15,7 @@ namespace test {
std::vector<MLFloat16> MakeMLFloat16(const std::initializer_list<float>& input) {
  std::vector<MLFloat16> output;
  std::transform(input.begin(), input.end(), std::back_inserter(output),
-                [](float fl) { return MLFloat16(math::floatToHalf(fl)); });
+                [](float fl) { return MLFloat16(fl); });
  return output;
}
@@ -2652,8 +2652,8 @@ void TrigFloat16Test(OpTester& test, std::initializer_list<float> input) {
  std::vector<MLFloat16> float16_input;
  std::vector<MLFloat16> float16_output;
  for (auto v : input) {
-   float16_input.push_back(MLFloat16(math::floatToHalf(v)));
-   float16_output.push_back(MLFloat16(math::floatToHalf(op(v))));
+   float16_input.push_back(MLFloat16(v));
+   float16_output.push_back(MLFloat16(op(v)));
  }
  test.AddInput<MLFloat16>("X", dims, float16_input);

@@ -25,8 +25,8 @@ TEST(RoundTest, SimpleTestDouble) {
TEST(RoundTest, SimpleTestFloat16) {
  OpTester test("Round", 11, onnxruntime::kOnnxDomain);
- test.AddInput<MLFloat16>("x", {5}, {MLFloat16(math::floatToHalf(0.9f)), MLFloat16(math::floatToHalf(2.5f)), MLFloat16(math::floatToHalf(2.3f)), MLFloat16(math::floatToHalf(1.5f)), MLFloat16(math::floatToHalf(-4.5f))});
- test.AddOutput<MLFloat16>("y", {5}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(-4.0f))});
+ test.AddInput<MLFloat16>("x", {5}, {MLFloat16(0.9f), MLFloat16(2.5f), MLFloat16(2.3f), MLFloat16(1.5f), MLFloat16(-4.5f)});
+ test.AddOutput<MLFloat16>("y", {5}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(-4.0f)});
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
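The expected outputs encode the half-to-even tie-breaking specified for the ONNX Round operator: 2.5 and 1.5 both round to 2, and -4.5 to -4. std::nearbyint reproduces this under the default FE_TONEAREST mode; a standalone check:

#include <cfenv>
#include <cmath>
#include <iostream>

int main() {
  std::fesetround(FE_TONEAREST);  // round to nearest, ties to even (the usual default)
  for (float v : {0.9f, 2.5f, 2.3f, 1.5f, -4.5f}) {
    std::cout << v << " -> " << std::nearbyint(v) << "\n";  // 1, 2, 2, 2, -4
  }
}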

@@ -22,7 +22,7 @@ struct make_type {
template <class A>
struct make_type<MLFloat16, A> {
  static MLFloat16 make(A v) {
-   return MLFloat16(math::floatToHalf(float(v)));
+   return MLFloat16(float(v));
  }
};
@@ -34,7 +34,9 @@ struct make_type<BFloat16, A> {
};

template <class T, class OutputIter>
-typename std::enable_if<!std::numeric_limits<T>::is_signed>::type
+typename std::enable_if<!std::numeric_limits<T>::is_signed &&
+                        !std::is_same<T, MLFloat16>::value &&
+                        !std::is_same<T, BFloat16>::value>::type
GenerateSequence(OutputIter out) {
  for (int i = 0; i < 7; ++i) {
    *out = make_type<T, int>::make(i);
@@ -43,7 +45,9 @@ GenerateSequence(OutputIter out) {
}

template <class T, class OutputIter>
-typename std::enable_if<std::numeric_limits<T>::is_signed>::type
+typename std::enable_if<std::numeric_limits<T>::is_signed ||
+                        std::is_same<T, MLFloat16>::value ||
+                        std::is_same<T, BFloat16>::value>::type
GenerateSequence(OutputIter out) {
  for (int i = -5; i < 2; ++i) {
    *out = make_type<T, int>::make(i);
@@ -61,7 +65,7 @@ struct ToTestableType {
template <>
struct ToTestableType<MLFloat16> {
  static float to_type(MLFloat16 v) {
-   return math::halfToFloat(v.val);
+   return v.ToFloat();
  }
};
@@ -186,6 +190,23 @@ TEST(MathOpTest, Sign_MLFloat16) {
  test.Run(OpTester::ExpectResult::kExpectSuccess);
}

+// Currently BFloat16 is not enabled for Sign kernel
+// TEST(MathOpTest, Sign_BFloat16) {
+//   using namespace test_sign_internal;
+//   OpTester test("Sign", 9);
+//
+//   std::vector<int64_t> input_dims{7};
+//   std::vector<BFloat16> input;
+//   GenerateSequence<BFloat16>(std::back_inserter(input));
+//   ASSERT_EQ(input.size(), 7U);
+//   test.AddInput<BFloat16>("input", input_dims, input);
+//
+//   std::vector<BFloat16> output;
+//   TestImpl<BFloat16>(input.cbegin(), input.cend(), std::back_inserter(output));
+//   test.AddOutput<BFloat16>("output", input_dims, output);
+//   test.Run(OpTester::ExpectResult::kExpectSuccess);
+//}
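A few hunks up, the widened enable_if conditions exist because std::numeric_limits<T>::is_signed is false for MLFloat16 and BFloat16 (the primary numeric_limits template applies to class types unless explicitly specialized), so without the extra terms both half-precision types would be routed to the unsigned 0..6 sequence instead of the signed -5..1 sequence the Sign test needs. The same dispatch can be phrased as a single trait; a sketch, assuming the ORT float16 header is in scope:

#include <limits>
#include <type_traits>
// assumes: onnxruntime_float16.h is available

// True for types whose generated test sequence should include negative values.
template <class T>
constexpr bool kUseSignedSequence =
    std::numeric_limits<T>::is_signed ||
    std::is_same<T, onnxruntime::MLFloat16>::value ||
    std::is_same<T, onnxruntime::BFloat16>::value;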
#if defined(USE_DNNL)
TEST(MathOpTest, Sign_bfloat16) {
#ifdef USE_DNNL

@@ -90,7 +90,7 @@ void RunShrinkTest(const std::vector<ShrinkTestData<T>>& test_cases,
const std::vector<MLFloat16> ConvertFloatToMLFloat16(const std::vector<float>& float_data) {
  std::vector<MLFloat16> new_data;
  for (const auto& f : float_data) {
-   new_data.push_back(MLFloat16(math::floatToHalf(f)));
+   new_data.push_back(MLFloat16(f));
  }
  return new_data;
}

@@ -124,34 +124,34 @@ TEST(ExpandOpTest, Expand_3x1x3x1_int64) {
TEST(ExpandOpTest, Expand_3x3_float16) {
  OpTester test("Expand", 8);
- test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))});
+ test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(1.0f)});
  test.AddInput<int64_t>("data_1", {2}, {3, 3});
  test.AddOutput<MLFloat16>("result", {3, 3},
-                           {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
-                            MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
-                            MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f))});
+                           {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
+                            MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
+                            MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)});
  test.Run();
}

TEST(ExpandOpTest, Expand_3x1_float16) {
  OpTester test("Expand", 8);
- test.AddInput<MLFloat16>("data_0", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+ test.AddInput<MLFloat16>("data_0", {3}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
  test.AddInput<int64_t>("data_1", {2}, {3, 1});
  test.AddOutput<MLFloat16>("result", {3, 3},
-                           {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
-                            MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
-                            MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+                           {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
+                            MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f),
+                            MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
  test.Run();
}

TEST(ExpandOpTest, Expand_1x3_float16) {
  OpTester test("Expand", 8);
- test.AddInput<MLFloat16>("data_0", {3, 1}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
+ test.AddInput<MLFloat16>("data_0", {3, 1}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)});
  test.AddInput<int64_t>("data_1", {2}, {1, 3});
  test.AddOutput<MLFloat16>("result", {3, 3},
-                           {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
-                            MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)),
-                            MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))});
+                           {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f),
+                            MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(2.0f),
+                            MLFloat16(3.0f), MLFloat16(3.0f), MLFloat16(3.0f)});
  test.Run();
}

@@ -20,7 +20,7 @@ TEST(IsNaNOpTest, IsNaNFloat) {
TEST(IsNaNOpTest, IsNaNFloat16) {
  OpTester test("IsNaN", 9, kOnnxDomain);
  std::vector<int64_t> dims{2, 2};
- test.AddInput<MLFloat16>("X", dims, std::initializer_list<MLFloat16>({MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(NAN)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(NAN))}));
+ test.AddInput<MLFloat16>("X", dims, std::initializer_list<MLFloat16>({MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}));
  test.AddOutput<bool>("Y", dims, {false, true, false, true});
  test.Run();
}

@@ -147,14 +147,14 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) {
  std::vector<int64_t> input_shape({2, 3});
  std::vector<MLFloat16> input_vals;
  for (uint16_t i = 0; i < 6; ++i)
-   input_vals.push_back(MLFloat16(i));
+   input_vals.push_back(MLFloat16::FromBits(static_cast<uint16_t>(i)));

  std::vector<int64_t> perm = {1, 0};
  std::vector<int64_t> expected_shape({3, 2});
  std::initializer_list<MLFloat16> expected_vals =
-     {MLFloat16{static_cast<uint16_t>(1)}, MLFloat16{static_cast<uint16_t>(4)},
-      MLFloat16{static_cast<uint16_t>(2)}, MLFloat16{static_cast<uint16_t>(5)},
-      MLFloat16{static_cast<uint16_t>(3)}, MLFloat16{static_cast<uint16_t>(6)}};
+     {MLFloat16::FromBits(static_cast<uint16_t>(1)), MLFloat16::FromBits(static_cast<uint16_t>(4)),
+      MLFloat16::FromBits(static_cast<uint16_t>(2)), MLFloat16::FromBits(static_cast<uint16_t>(5)),
+      MLFloat16::FromBits(static_cast<uint16_t>(3)), MLFloat16::FromBits(static_cast<uint16_t>(6))};

  TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
}
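One subtlety worth noting: the FromBits arguments in this test are raw bit patterns, so the tensor actually holds tiny fp16 subnormals (n * 2^-24, with 0 mapping to +0.0) rather than the integers they resemble. That is fine for Transpose, which only permutes elements; any distinct bit patterns will do. A quick check of what the patterns decode to:

#include <cmath>
#include <iostream>

int main() {
  for (int n = 0; n <= 6; ++n) {
    // An fp16 subnormal with mantissa n has the value n * 2^-24.
    std::cout << "FromBits(" << n << ") == " << std::ldexp(static_cast<float>(n), -24) << "\n";
  }
}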

@@ -2022,46 +2022,91 @@ TEST(CApiTest, create_tensor_with_data_float16) {
  // Example with C++. However, what we are feeding underneath is really
  // a continuous buffer of uint16_t
  // Use 3rd party libraries such as Eigen to convert floats and doubles to float16 types.
- Ort::Float16_t values[] = {15360, 16384, 16896, 17408, 17664};  // 1.f, 2.f, 3.f, 4.f, 5.f
- constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
- std::vector<int64_t> dims = {static_cast<int64_t>(values_length)};
+ constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
+ constexpr size_t values_length = std::size(values);
+ constexpr uint16_t expected_values[values_length] = {15360, 16384, 16896, 17408, 17664};
+
+ std::vector<Ort::Float16_t> fp16_values;
+ fp16_values.reserve(values_length);
+ std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
+                [](float fl) { return Ort::Float16_t(fl); });
+
+ for (size_t i = 0; i < values_length; ++i) {
+   ASSERT_EQ(expected_values[i], fp16_values[i].val);
+ }
+
+ constexpr int64_t dims = static_cast<int64_t>(values_length);

  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
- Ort::Value tensor = Ort::Value::CreateTensor<Ort::Float16_t>(info, values, values_length, dims.data(), dims.size());
+ Ort::Value tensor = Ort::Value::CreateTensor<Ort::Float16_t>(info, fp16_values.data(), values_length, &dims, 1u);
  const auto* new_pointer = tensor.GetTensorData<Ort::Float16_t>();
- ASSERT_EQ(new_pointer, values);
+ ASSERT_EQ(new_pointer, fp16_values.data());
  auto type_info = tensor.GetTypeInfo();
  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+ const auto element_count = tensor_info.GetElementCount();
+ ASSERT_EQ(values_length, element_count);
  ASSERT_NE(tensor_info, nullptr);
  ASSERT_EQ(1u, tensor_info.GetDimensionsCount());
  ASSERT_EQ(tensor_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
- Ort::Float16_t value_at_1 = tensor.At<Ort::Float16_t>({1});
- ASSERT_EQ(values[1], value_at_1);
+ const Ort::Float16_t& value_at_1 = tensor.At<Ort::Float16_t>({1});
+ ASSERT_EQ(expected_values[1], value_at_1.val);
+
+ std::vector<float> output_values;
+ output_values.reserve(values_length);
+ const auto data_span = gsl::make_span(tensor.GetTensorData<Ort::Float16_t>(), element_count);
+ std::transform(data_span.begin(), data_span.end(), std::back_inserter(output_values),
+                [](const Ort::Float16_t& fp16) { return static_cast<float>(fp16); });
+
+ for (size_t i = 0; i < values_length; ++i) {
+   ASSERT_NEAR(values[i], output_values[i], 1e-3);
+ }
}
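The expected_values patterns can be checked by hand from the binary16 layout: 1 sign bit, 5 exponent bits with bias 15, 10 mantissa bits. A standalone decoder for normal numbers (names illustrative, not ORT API):

#include <cmath>
#include <cstdint>
#include <iostream>

// Decode a binary16 bit pattern, assuming a normal (non-zero exponent) value.
float DecodeHalfNormal(uint16_t bits) {
  const int sign = (bits >> 15) & 0x1;
  const int exponent = (bits >> 10) & 0x1F;  // 5-bit exponent, bias 15
  const int mantissa = bits & 0x3FF;         // 10-bit fraction
  const float magnitude = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
  return sign ? -magnitude : magnitude;
}

int main() {
  const uint16_t patterns[] = {15360, 16384, 16896, 17408, 17664};
  for (uint16_t bits : patterns) {
    std::cout << bits << " -> " << DecodeHalfNormal(bits) << "\n";  // 1, 2, 3, 4, 5
  }
}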
TEST(CApiTest, create_tensor_with_data_bfloat16) {
  // Example with C++. However, what we are feeding underneath is really
  // a continuous buffer of uint16_t
  // Conversion from float to bfloat16 is simple. Strip off half of the bytes from float.
- Ort::BFloat16_t values[] = {16256, 16384, 16448, 16512, 16544};  // 1.f, 2.f, 3.f, 4.f, 5.f
- constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
- std::vector<int64_t> dims = {static_cast<int64_t>(values_length)};
+ constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
+ constexpr size_t values_length = std::size(values);
+ constexpr uint16_t expected_values[] = {16256, 16384, 16448, 16512, 16544};  // 1.f, 2.f, 3.f, 4.f, 5.f
+ constexpr int64_t dims = static_cast<int64_t>(values_length);
+
+ std::vector<Ort::BFloat16_t> b16_values;
+ b16_values.reserve(values_length);
+ std::transform(std::begin(values), std::end(values), std::back_inserter(b16_values),
+                [](float fl) { return Ort::BFloat16_t(fl); });
+
+ for (size_t i = 0; i < values_length; ++i) {
+   ASSERT_EQ(expected_values[i], b16_values[i].val);
+ }

  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
- Ort::Value tensor = Ort::Value::CreateTensor<Ort::BFloat16_t>(info, values, values_length, dims.data(), dims.size());
+ Ort::Value tensor = Ort::Value::CreateTensor<Ort::BFloat16_t>(info, b16_values.data(), values_length, &dims, 1u);
  const auto* new_pointer = tensor.GetTensorData<Ort::BFloat16_t>();
- ASSERT_EQ(new_pointer, values);
+ ASSERT_EQ(new_pointer, b16_values.data());
  auto type_info = tensor.GetTypeInfo();
- auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+ const auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+ const auto element_count = tensor_info.GetElementCount();
  ASSERT_NE(tensor_info, nullptr);
  ASSERT_EQ(1u, tensor_info.GetDimensionsCount());
  ASSERT_EQ(tensor_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16);
- Ort::BFloat16_t value_at_1 = tensor.At<Ort::BFloat16_t>({1});
- ASSERT_EQ(values[1], value_at_1);
+ const Ort::BFloat16_t& value_at_1 = tensor.At<Ort::BFloat16_t>({1});
+ ASSERT_EQ(expected_values[1], value_at_1.val);
+
+ std::vector<float> output_values;
+ output_values.reserve(values_length);
+ const auto data_span = gsl::make_span(tensor.GetTensorData<Ort::BFloat16_t>(), element_count);
+ std::transform(data_span.begin(), data_span.end(), std::back_inserter(output_values),
+                [](const Ort::BFloat16_t& b16) { return static_cast<float>(b16); });
+
+ for (size_t i = 0; i < values_length; ++i) {
+   ASSERT_NEAR(values[i], output_values[i], 1e-3);
+ }
}
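The comment about stripping half of the bytes is literal: a bfloat16 holds the top 16 bits of a float32, so 1.0f (0x3F800000) becomes 0x3F80 = 16256, matching expected_values[0]. A truncating sketch follows; real conversions, presumably including the Ort::BFloat16_t constructor, round to nearest-even instead, but the two agree on exactly representable values like 1.0 through 5.0:

#include <cstdint>
#include <cstring>
#include <iostream>

// Truncating float32 -> bfloat16: keep the sign, exponent and top 7 mantissa bits.
uint16_t FloatToBFloat16Truncate(float value) {
  static_assert(sizeof(float) == sizeof(uint32_t), "assumes 32-bit float");
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  for (float v : {1.f, 2.f, 3.f, 4.f, 5.f}) {
    std::cout << v << " -> " << FloatToBFloat16Truncate(v) << "\n";  // 16256, 16384, 16448, 16512, 16544
  }
}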
#if !defined(DISABLE_FLOAT8_TYPES)
@@ -2336,7 +2381,6 @@ TEST(CApiTest, get_version_string_cpp) {
TEST(CApiTest, get_build_info_string) {
  auto build_info_string = Ort::GetBuildInfoString();
  ASSERT_FALSE(build_info_string.empty());
- ASSERT_EQ(build_info_string, std::string(ORT_BUILD_INFO));
}

TEST(CApiTest, TestSharedAllocators) {

@ -307,6 +307,323 @@ TEST(CApiTest, TypeInfoSequence) {
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
} }
TEST(CPPApi, ConvertFloatToFloat16) {
// Test data type
{
constexpr float sample = 1.0f;
Ort::Float16_t flt16(sample);
EXPECT_FALSE(flt16.IsNaN());
auto int_rep = flt16.val;
const Ort::Float16_t flt_from_int = Ort::Float16_t::FromBits(int_rep);
EXPECT_FALSE(flt_from_int.IsNaN());
EXPECT_EQ(flt16, flt_from_int);
const double diff = std::fabs(sample - flt_from_int.ToFloat());
if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
EXPECT_TRUE(false);
}
}
// Test bulk conversion
{
const float sample[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
std::vector<Ort::Float16_t> converted;
converted.reserve(std::size(sample));
// Run conversion
std::transform(std::begin(sample), std::end(sample), std::back_inserter(converted),
[](float v) { return Ort::Float16_t(v); });
for (size_t i = 0; i < std::size(sample); ++i) {
EXPECT_FALSE(converted[i].IsNaN());
const double diff = std::fabs(sample[i] - converted[i].ToFloat());
if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
EXPECT_TRUE(false);
}
}
std::vector<float> back_converted;
back_converted.reserve(std::size(sample));
std::transform(converted.cbegin(), converted.cend(), std::back_inserter(back_converted),
[](const Ort::Float16_t& bf) { return static_cast<float>(bf); });
for (size_t i = 0; i < std::size(sample); ++i) {
EXPECT_FALSE(std::isnan(back_converted[i]));
const double diff = std::fabs(sample[i] - back_converted[i]);
if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
EXPECT_TRUE(false);
}
}
}
}
TEST(CPPApi, Float16Zeros) {
const auto positive_zero = Ort::Float16_t::FromBits(0x0000);
EXPECT_FALSE(positive_zero.IsNegative());
const float float_positive_zero = static_cast<float>(positive_zero);
EXPECT_EQ(+0.0f, float_positive_zero);
EXPECT_FALSE(std::signbit(float_positive_zero));
const auto negative_zero = Ort::Float16_t::FromBits(0x8000);
EXPECT_TRUE(negative_zero.IsNegative());
  const float float_negative_zero = static_cast<float>(negative_zero);
  EXPECT_EQ(-0.0f, float_negative_zero);
  EXPECT_TRUE(std::signbit(float_negative_zero));
}
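The two encodings above differ only in bit 15, the sign bit; IEEE-754 keeps the patterns distinct yet numerically equal, which is exactly what the EXPECT_EQ plus std::signbit pair is probing. The same holds for plain floats:

#include <cassert>
#include <cmath>

int main() {
  assert(0.0f == -0.0f);                               // +0 and -0 compare equal
  assert(!std::signbit(0.0f) && std::signbit(-0.0f));  // but the sign bit differs
}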
namespace {
const auto EpsilonFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kEpsilonBits);
const auto NaNFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kPositiveQNaNBits);
const auto MinValueFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kMinValueBits);
const auto MaxValueFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kMaxValueBits);
const auto InfinityFl16 = Ort::Float16_t::FromBits(Ort::Float16_t::kPositiveInfinityBits);
} // namespace
TEST(CPPApi, Float16Comparision) {
const auto left = Ort::Float16_t(-33.33f);
const auto left_same = Ort::Float16_t(-33.33f);
const auto right = Ort::Float16_t(66.66f);
const auto right_same = Ort::Float16_t(66.66f);
EXPECT_LT(EpsilonFl16, right);
EXPECT_EQ(left, left_same);
EXPECT_NE(left, left_same.Negate());
EXPECT_EQ(right, right_same);
EXPECT_NE(right, right_same.Negate());
EXPECT_LT(left, right);
EXPECT_LT(right.Negate(), left);
EXPECT_LT(left.Negate(), right);
}
TEST(CPPApi, Float16TestNAN) {
  const Ort::Float16_t fp16NANFromSingle(std::numeric_limits<float>::quiet_NaN());
  EXPECT_TRUE(fp16NANFromSingle.IsNaN());
  // NaNs never compare equal to each other
  EXPECT_NE(NaNFl16, fp16NANFromSingle);
  const float nan_from_fp16 = fp16NANFromSingle.ToFloat();
  EXPECT_TRUE(std::isnan(nan_from_fp16));
EXPECT_FALSE(MaxValueFl16.IsNaN());
}
TEST(CPPApi, Float16NaNComparision) {
EXPECT_FALSE(NaNFl16 < NaNFl16);
EXPECT_TRUE(NaNFl16 != NaNFl16);
EXPECT_FALSE(NaNFl16 == NaNFl16);
EXPECT_FALSE(MaxValueFl16 < NaNFl16);
EXPECT_FALSE(MaxValueFl16 == NaNFl16);
EXPECT_FALSE(NaNFl16 < MinValueFl16);
EXPECT_LT(MinValueFl16, MaxValueFl16);
}
TEST(CPPApi, Float16Infinity) {
EXPECT_FALSE(MinValueFl16.IsInfinity());
EXPECT_FALSE(MaxValueFl16.IsInfinity());
EXPECT_TRUE(MaxValueFl16.IsFinite());
const Ort::Float16_t pos_infinity_from_float(std::numeric_limits<float>::infinity());
EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
EXPECT_FALSE(pos_infinity_from_float.IsFinite());
EXPECT_FALSE(pos_infinity_from_float.IsNegative());
const Ort::Float16_t neg_infinity_from_float(-std::numeric_limits<float>::infinity());
EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
EXPECT_FALSE(neg_infinity_from_float.IsFinite());
EXPECT_TRUE(neg_infinity_from_float.IsNegative());
  const float pos_infinity_from_float16 = static_cast<float>(InfinityFl16);
  EXPECT_TRUE(std::isinf(pos_infinity_from_float16));
}
TEST(CPPApi, Float16NormalSubnormal) {
EXPECT_FALSE(InfinityFl16.IsNormal());
EXPECT_TRUE(Ort::Float16_t(45.6f).IsNormal());
EXPECT_FALSE(Ort::Float16_t(45.6f).IsSubnormal());
  // 0b0_00000_0000000001, ~0.000000059604645
constexpr uint16_t min_subnormal_bits = 0x0001;
const Ort::Float16_t smallest_subnormal = Ort::Float16_t::FromBits(min_subnormal_bits);
EXPECT_TRUE(smallest_subnormal.IsSubnormal());
EXPECT_FALSE(smallest_subnormal.IsNormal());
  // The smallest positive float subnormal is ~1.40129846432481707092E-45,
  // so the fp16 subnormal above widens to a normal float value.
const float float_from_smallest_subnormal = static_cast<float>(smallest_subnormal);
EXPECT_TRUE(std::isnormal(float_from_smallest_subnormal));
  // 0b0_00000_0001111111, ~0.0000075697899 (the largest fp16 subnormal would be 0x03FF, ~0.000060975552)
constexpr uint16_t max_subnormal_bits = 0x007F;
const Ort::Float16_t largest_subnormal = Ort::Float16_t::FromBits(max_subnormal_bits);
EXPECT_TRUE(largest_subnormal.IsSubnormal());
EXPECT_FALSE(largest_subnormal.IsNormal());
// However, in float the same number above would be normal
const float float_from_largest_subnormal = static_cast<float>(largest_subnormal);
EXPECT_TRUE(std::isnormal(float_from_largest_subnormal));
}
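The approximate values in the comments above follow from the fp16 subnormal formula value = mantissa * 2^-24 (the minimum exponent -14 combined with the 10 fraction bits). A quick numeric check:

#include <cmath>
#include <iostream>

int main() {
  std::cout << std::ldexp(1.0, -24) << "\n";     // 0x0001 -> ~5.9604645e-08
  std::cout << std::ldexp(127.0, -24) << "\n";   // 0x007F -> ~7.5697899e-06
  std::cout << std::ldexp(1023.0, -24) << "\n";  // 0x03FF, largest fp16 subnormal, ~6.0975552e-05
}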
TEST(CPPApi, BFloat16ConvertFloatToBFloat16) {
// Test data type
{
constexpr float sample = 1.0f;
Ort::BFloat16_t flt16(sample);
EXPECT_FALSE(flt16.IsNaN());
auto int_rep = flt16.val;
const Ort::BFloat16_t flt_from_int = Ort::BFloat16_t::FromBits(int_rep);
EXPECT_FALSE(flt_from_int.IsNaN());
EXPECT_EQ(flt16, flt_from_int);
const double diff = std::fabs(sample - flt_from_int.ToFloat());
if (diff > FLT_EPSILON || (std::isnan(diff) && !std::isnan(sample))) {
EXPECT_TRUE(false);
}
}
// Test bulk conversion
{
const float sample[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
std::vector<Ort::BFloat16_t> converted;
converted.reserve(std::size(sample));
// Run conversion
std::transform(std::begin(sample), std::end(sample), std::back_inserter(converted),
[](float v) { return Ort::BFloat16_t(v); });
for (size_t i = 0; i < std::size(sample); ++i) {
EXPECT_FALSE(converted[i].IsNaN());
const double diff = std::fabs(sample[i] - converted[i].ToFloat());
if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
EXPECT_TRUE(false);
}
}
std::vector<float> back_converted;
back_converted.reserve(std::size(sample));
std::transform(converted.cbegin(), converted.cend(), std::back_inserter(back_converted),
[](const Ort::BFloat16_t& bf) { return static_cast<float>(bf); });
for (size_t i = 0; i < std::size(sample); ++i) {
EXPECT_FALSE(std::isnan(back_converted[i]));
const double diff = std::fabs(sample[i] - back_converted[i]);
if ((std::isnan(diff) && !std::isnan(sample[i])) || diff > FLT_EPSILON) {
EXPECT_TRUE(false);
}
}
}
}
TEST(CPPApi, BFloat16Zeros) {
const auto positive_zero = Ort::BFloat16_t::FromBits(0x0000);
EXPECT_FALSE(positive_zero.IsNegative());
const float float_positive_zero = static_cast<float>(positive_zero);
EXPECT_EQ(+0.0f, float_positive_zero);
EXPECT_FALSE(std::signbit(float_positive_zero));
const auto negative_zero = Ort::BFloat16_t::FromBits(0x8000);
EXPECT_TRUE(negative_zero.IsNegative());
  const float float_negative_zero = static_cast<float>(negative_zero);
  EXPECT_EQ(-0.0f, float_negative_zero);
  EXPECT_TRUE(std::signbit(float_negative_zero));
}
namespace {
const auto EpsilonBfl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kEpsilonBits);
const auto NaNBfl15 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kPositiveQNaNBits);
const auto MinValueBfl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kMinValueBits);
const auto MaxValueBfl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kMaxValueBits);
const auto InfinityBFl16 = Ort::BFloat16_t::FromBits(Ort::BFloat16_t::kPositiveInfinityBits);
} // namespace
TEST(CPPApi, BFloat16Comparision) {
const auto left = Ort::BFloat16_t(-33.33f);
const auto left_same = Ort::BFloat16_t(-33.33f);
const auto right = Ort::BFloat16_t(66.66f);
const auto right_same = Ort::BFloat16_t(66.66f);
EXPECT_LT(EpsilonBfl16, right);
EXPECT_EQ(left, left_same);
EXPECT_NE(left, left_same.Negate());
EXPECT_EQ(right, right_same);
EXPECT_NE(right, right_same.Negate());
EXPECT_LT(left, right);
EXPECT_LT(right.Negate(), left);
EXPECT_LT(left.Negate(), right);
}
TEST(CPPApi, BFloat16TestNAN) {
  const Ort::BFloat16_t bf16NANFromSingle(std::numeric_limits<float>::quiet_NaN());
  EXPECT_TRUE(bf16NANFromSingle.IsNaN());
  // NaNs never compare equal to each other
  EXPECT_NE(NaNBfl15, bf16NANFromSingle);
  const float NanFromBFloat16 = bf16NANFromSingle.ToFloat();
  EXPECT_TRUE(std::isnan(NanFromBFloat16));
EXPECT_FALSE(MaxValueBfl16.IsNaN());
}
TEST(CPPApi, BFloat16NaNComparision) {
EXPECT_FALSE(NaNBfl15 < NaNBfl15);
EXPECT_TRUE(NaNBfl15 != NaNBfl15);
EXPECT_FALSE(NaNBfl15 == NaNBfl15);
EXPECT_FALSE(MaxValueBfl16 < NaNBfl15);
EXPECT_FALSE(MaxValueBfl16 == NaNBfl15);
EXPECT_FALSE(NaNBfl15 < MinValueBfl16);
EXPECT_LT(MinValueBfl16, MaxValueBfl16);
}
TEST(CPPApi, BFloat16Infinity) {
EXPECT_FALSE(MinValueBfl16.IsInfinity());
EXPECT_FALSE(MaxValueBfl16.IsInfinity());
EXPECT_TRUE(MaxValueBfl16.IsFinite());
const Ort::BFloat16_t pos_infinity_from_float(std::numeric_limits<float>::infinity());
EXPECT_TRUE(pos_infinity_from_float.IsInfinity());
EXPECT_FALSE(pos_infinity_from_float.IsFinite());
EXPECT_FALSE(pos_infinity_from_float.IsNegative());
const Ort::BFloat16_t neg_infinity_from_float(-std::numeric_limits<float>::infinity());
EXPECT_TRUE(neg_infinity_from_float.IsInfinity());
EXPECT_FALSE(neg_infinity_from_float.IsFinite());
EXPECT_TRUE(neg_infinity_from_float.IsNegative());
const float pos_infinity_from_bfloat16 = static_cast<float>(InfinityBFl16);
EXPECT_TRUE(std::isinf(pos_infinity_from_bfloat16));
}
TEST(CPPApi, BFloat16NormalSubnormal) {
EXPECT_FALSE(InfinityBFl16.IsNormal());
EXPECT_TRUE(Ort::BFloat16_t(45.6f).IsNormal());
EXPECT_FALSE(Ort::BFloat16_t(45.6f).IsSubnormal());
// 0b0_0000_0000_000_0001
constexpr uint16_t min_subnormal_bits = 0x0001;
const Ort::BFloat16_t smallest_subnormal = Ort::BFloat16_t::FromBits(min_subnormal_bits);
EXPECT_TRUE(smallest_subnormal.IsSubnormal());
EXPECT_FALSE(smallest_subnormal.IsNormal());
const float float_from_smallest_subnormal = static_cast<float>(smallest_subnormal);
EXPECT_FALSE(std::isnormal(float_from_smallest_subnormal));
// 0b0_0000_0000_111_1111;
constexpr uint16_t max_subnormal_bits = 0x007F;
const Ort::BFloat16_t largest_subnormal = Ort::BFloat16_t::FromBits(max_subnormal_bits);
EXPECT_TRUE(largest_subnormal.IsSubnormal());
EXPECT_FALSE(largest_subnormal.IsNormal());
const float float_from_largest_subnormal = static_cast<float>(largest_subnormal);
EXPECT_FALSE(std::isnormal(float_from_largest_subnormal));
}
#if !defined(DISABLE_SPARSE_TENSORS)
TEST(CApiTest, SparseTensorUsingAPI) {
  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);

@@ -495,7 +495,7 @@ float ParseValueToFloat(MLFloat16 data_value) {
template <>
float ParseValueToFloat(float data_value) {
  // Convert float to half and then back to float to simulate rounding to half
- return ParseValueToFloat(MLFloat16(math::floatToHalf(data_value)));
+ return ParseValueToFloat(MLFloat16(data_value));
}

template <typename RealT, typename ExpectT>

@@ -392,7 +392,7 @@ AttributeProto GradientBuilderBase::AttributeDefinitionToAttributeProto(
             elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16);
  float float_value = value.get<float>();
  if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-   tensor_proto = ScalarTensorProto(MLFloat16(math::floatToHalf(float_value)), {1});
+   tensor_proto = ScalarTensorProto(MLFloat16(float_value), {1});
  } else if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
    tensor_proto = ScalarTensorProto(BFloat16(float_value), {1});
  } else {

@@ -262,7 +262,7 @@ class GradientBuilderBase {
  // We only support FP32, FP16 and BF16 for these constant nodes for now.
  static NodeDef ConstantScalarNode(float value, const std::string& arg_name, int elem_type) {
    if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-     return ConstantScalarNode(MLFloat16(math::floatToHalf(value)), {1}, arg_name);
+     return ConstantScalarNode(MLFloat16(value), {1}, arg_name);
    }

    if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
@@ -294,7 +294,7 @@ class GradientBuilderBase {
  static ONNX_NAMESPACE::TensorProto ScalarTensorProtoByElemType(float value, int elem_type) {
    if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
-     return ScalarTensorProto(MLFloat16(math::floatToHalf(value)), {1});
+     return ScalarTensorProto(MLFloat16(value), {1});
    }

    if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {

@@ -97,7 +97,7 @@ Status AdamOptimizerBuilder::Build(
    ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, element_type, weight_dims));
    moment_tensor_proto = utils::TensorToTensorProto(init_tensor, gradient_moment_name);
  } else if (opt_configs[i].use_mixed_precision_moments) {
-   moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(math::floatToHalf(0.f)), weight_dims);
+   moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(0.f), weight_dims);
  } else {
    moment_tensor_proto = CreateTensorProto<float>(gradient_moment_name, 0.f, weight_dims);
  }

@@ -231,7 +231,7 @@ Status LambOptimizerBuilder::Build(
    ORT_THROW_IF_ERROR(IsMatchingTypeAndShape(init_tensor, element_type, weight_dims));
    moment_tensor_proto = utils::TensorToTensorProto(init_tensor, gradient_moment_name);
  } else if (opt_configs[i].use_mixed_precision_moments) {
-   moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(math::floatToHalf(0.f)), weight_dims);
+   moment_tensor_proto = CreateTensorProto<MLFloat16>(gradient_moment_name, MLFloat16(0.f), weight_dims);
  } else {
    moment_tensor_proto = CreateTensorProto<float>(gradient_moment_name, 0.f, weight_dims);
  }

@@ -35,7 +35,7 @@ static void RunTrainingSessionLoadOptimTests(std::string optim_name, bool mixed_
  TrainingSession::OptimizerState init_optimizer_state{};

  if (mixed_precision_moments) {
-   GenerateOptimizerInitialState<MLFloat16>(optim_name, MLFloat16(math::floatToHalf(2.5)), init_optimizer_state);
+   GenerateOptimizerInitialState<MLFloat16>(optim_name, MLFloat16(2.5f), init_optimizer_state);
  } else {
    GenerateOptimizerInitialState<float>(optim_name, 2.5f, init_optimizer_state);
  }

@@ -54,7 +54,7 @@ void RunDropoutTest(const bool use_mask, const std::vector<int64_t>& input_shape
    ratio = 0.5f;
  } else {
    if (use_float16_ratio) {
-     t.AddInput("ratio", {}, {MLFloat16(math::floatToHalf(ratio))});
+     t.AddInput("ratio", {}, {MLFloat16(ratio)});
    } else {
      t.AddInput("ratio", {}, {ratio});
    }