Make MF a part of ML.NET (#1263)
Based on LIBMF, a new matrix factorization module is added to ML.NET. LIBMF is included as a submodule in the ML.NET repo and is compiled into a NuGet package for release. Please see LIBMF's official page (https://www.csie.ntu.edu.tw/~cjlin/libmf/) for mathematical details.
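For orientation, the prediction rule behind the module can be sketched as below. This is a minimal illustration, not the ML.NET or LIBMF API: the class `FactorScoreSketch`, its method names, and its parameters are hypothetical. It assumes both factor matrices are packed into flat arrays with one latent vector of length `rank` per matrix row (for P) or per matrix column (for Q), and that row and column keys are 1-based with 0 meaning "missing", as in the predictor added by this commit.

```csharp
using System;

// Hypothetical illustration only; not part of ML.NET or LIBMF.
public static class FactorScoreSketch
{
    // p holds numberOfRows * rank values, one latent vector per matrix row.
    // q holds numberOfColumns * rank values, one latent vector per matrix column.
    // Both indices here are 0-based.
    public static float Score(float[] p, float[] q, int rank, int rowIndex, int columnIndex)
    {
        int rowOffset = rowIndex * rank;
        int columnOffset = columnIndex * rank;
        float score = 0;
        // Inner product of the two latent vectors.
        for (int i = 0; i < rank; i++)
            score += p[rowOffset + i] * q[columnOffset + i];
        return score;
    }

    // Keys are 1-based; a key of 0 or a key beyond the matrix size yields NaN.
    public static float Predict(float[] p, float[] q, int rank, uint rowKey, uint columnKey)
    {
        int numberOfRows = p.Length / rank;
        int numberOfColumns = q.Length / rank;
        if (rowKey == 0 || rowKey > numberOfRows || columnKey == 0 || columnKey > numberOfColumns)
            return float.NaN;
        return Score(p, q, rank, (int)(rowKey - 1), (int)(columnKey - 1));
    }
}
```

With `p = { 1, 2, 3, 4 }`, `q = { 5, 6, 7, 8 }` and `rank = 2`, the call `FactorScoreSketch.Predict(p, q, 2, 2, 1)` evaluates 3*5 + 4*6 = 39, while a key of 0 or a key larger than the matrix dimension gives NaN.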
Parent
b791d66787
Commit
eae76959e6
|
@ -0,0 +1,3 @@
|
||||||
|
[submodule "src/Native/LIBMFNative/libmf"]
|
||||||
|
path = src/Native/MatrixFactorizationNative/libmf
|
||||||
|
url = https://github.com/cjlin1/libmf.git
|
|
@ -133,6 +133,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Samples", "doc
|
||||||
EndProject
|
EndProject
|
||||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.SamplesUtils", "src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj", "{11A5210E-2EA7-42F1-80DB-827762E9C781}"
|
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.SamplesUtils", "src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj", "{11A5210E-2EA7-42F1-80DB-827762E9C781}"
|
||||||
EndProject
|
EndProject
|
||||||
|
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Recommender", "src\Microsoft.ML.Recommender\Microsoft.ML.Recommender.csproj", "{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}"
|
||||||
|
EndProject
|
||||||
Global
|
Global
|
||||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||||
Debug|Any CPU = Debug|Any CPU
|
Debug|Any CPU = Debug|Any CPU
|
||||||
|
@ -501,6 +503,14 @@ Global
|
||||||
{11A5210E-2EA7-42F1-80DB-827762E9C781}.Release|Any CPU.Build.0 = Release|Any CPU
|
{11A5210E-2EA7-42F1-80DB-827762E9C781}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||||
{11A5210E-2EA7-42F1-80DB-827762E9C781}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
|
{11A5210E-2EA7-42F1-80DB-827762E9C781}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
|
||||||
{11A5210E-2EA7-42F1-80DB-827762E9C781}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
|
{11A5210E-2EA7-42F1-80DB-827762E9C781}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
|
||||||
EndGlobalSection
|
EndGlobalSection
|
||||||
GlobalSection(SolutionProperties) = preSolution
|
GlobalSection(SolutionProperties) = preSolution
|
||||||
HideSolutionNode = FALSE
|
HideSolutionNode = FALSE
|
||||||
|
@ -556,6 +566,7 @@ Global
|
||||||
{4B101D58-E7E4-4877-A536-A9B41E2E82A3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
|
{4B101D58-E7E4-4877-A536-A9B41E2E82A3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
|
||||||
{ECB71297-9DF1-48CE-B93A-CD969221F9B6} = {DA452A53-2E94-4433-B08C-041EDEC729E6}
|
{ECB71297-9DF1-48CE-B93A-CD969221F9B6} = {DA452A53-2E94-4433-B08C-041EDEC729E6}
|
||||||
{11A5210E-2EA7-42F1-80DB-827762E9C781} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
|
{11A5210E-2EA7-42F1-80DB-827762E9C781} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
|
||||||
|
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
|
||||||
EndGlobalSection
|
EndGlobalSection
|
||||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||||
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
|
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
ML.NET uses third-party libraries or other resources that may be
|
||||||
|
distributed under licenses different than the ML.NET software.
|
||||||
|
|
||||||
|
In the event that we accidentally failed to list a required notice, please
|
||||||
|
bring it to our attention. Post an issue or email us:
|
||||||
|
|
||||||
|
dotnet@microsoft.com
|
||||||
|
|
||||||
|
The attached notices are provided for information only.
|
||||||
|
|
||||||
|
License notice for LIBMF
|
||||||
|
------------------------
|
||||||
|
|
||||||
|
https://github.com/cjlin1/libmf
|
||||||
|
|
||||||
|
Copyright (c) 2014-2015 The LIBMF Project.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions
|
||||||
|
are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer in the
|
||||||
|
documentation and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither name of copyright holders nor the names of its contributors
|
||||||
|
may be used to endorse or promote products derived from this software
|
||||||
|
without specific prior written permission.
|
||||||
|
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
|
||||||
|
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||||
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||||
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,13 @@
|
||||||
|
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">
|
||||||
|
|
||||||
|
<PropertyGroup>
|
||||||
|
<TargetFramework>netstandard2.0</TargetFramework>
|
||||||
|
<PackageDescription>LIBMF, the core computation library for matrix factorization in ML.NET</PackageDescription>
|
||||||
|
</PropertyGroup>
|
||||||
|
|
||||||
|
<ItemGroup>
|
||||||
|
<Content Include="..\common\CommonPackage.props" Pack="true" PackagePath="build\netstandard2.0\$(MSBuildProjectName).props" />
|
||||||
|
<Content Include="$(SourceDir)Native\MatrixFactorizationNative\libmf\COPYRIGHT" Pack="true" PackagePath=".\" />
|
||||||
|
</ItemGroup>
|
||||||
|
|
||||||
|
</Project>
|
|
@ -0,0 +1,5 @@
|
||||||
|
<Project DefaultTargets="Pack">
|
||||||
|
|
||||||
|
<Import Project="Microsoft.ML.MatrixFactorization.nupkgproj" />
|
||||||
|
|
||||||
|
</Project>
|
|
@ -22,6 +22,7 @@
|
||||||
<ProjectReference Include="..\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
|
<ProjectReference Include="..\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
|
||||||
<ProjectReference Include="..\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
|
<ProjectReference Include="..\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
|
||||||
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
|
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
|
||||||
|
<ProjectReference Include="..\Microsoft.ML.Recommender\Microsoft.ML.Recommender.csproj" />
|
||||||
<ProjectReference Include="..\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
|
<ProjectReference Include="..\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
|
||||||
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
|
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
|
||||||
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
|
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
|
||||||
|
|
|
@ -0,0 +1,503 @@
|
||||||
|
// Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.IO;
|
||||||
|
using Microsoft.ML.Runtime;
|
||||||
|
using Microsoft.ML.Runtime.Data;
|
||||||
|
using Microsoft.ML.Runtime.Data.IO;
|
||||||
|
using Microsoft.ML.Runtime.Internal.Internallearn;
|
||||||
|
using Microsoft.ML.Runtime.Internal.Utilities;
|
||||||
|
using Microsoft.ML.Runtime.Model;
|
||||||
|
using Microsoft.ML.Runtime.Recommender;
|
||||||
|
using Microsoft.ML.Runtime.Recommender.Internal;
|
||||||
|
using Microsoft.ML.Trainers;
|
||||||
|
|
||||||
|
[assembly: LoadableClass(typeof(MatrixFactorizationPredictor), null, typeof(SignatureLoadModel), "Matrix Factorization Predictor Executor", MatrixFactorizationPredictor.LoaderSignature)]
|
||||||
|
|
||||||
|
[assembly: LoadableClass(typeof(MatrixFactorizationPredictionTransformer), typeof(MatrixFactorizationPredictionTransformer),
|
||||||
|
null, typeof(SignatureLoadModel), "", MatrixFactorizationPredictionTransformer.LoaderSignature)]
|
||||||
|
|
||||||
|
namespace Microsoft.ML.Runtime.Recommender
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// <see cref="MatrixFactorizationPredictor"/> stores two factor matrices, P and Q, for approximating the training matrix, R, by P * Q,
|
||||||
|
/// where * is a matrix multiplication. This predictor expects two inputs, row index and column index, and produces the (approximated)
|
||||||
|
/// value at the location specified by the two inputs in R. More specifically, if the input row and column indices are u and v, respectively,
|
||||||
|
/// the output (a scalar) is the inner product of the u-th row in P and the v-th column in Q.
|
||||||
|
/// </summary>
|
||||||
|
public sealed class MatrixFactorizationPredictor : IPredictor, ICanSaveModel, ICanSaveInTextFormat, ISchemaBindableMapper
|
||||||
|
{
|
||||||
|
internal const string LoaderSignature = "MFPredictor";
|
||||||
|
internal const string RegistrationName = "MatrixFactorizationPredictor";
|
||||||
|
|
||||||
|
private static VersionInfo GetVersionInfo()
|
||||||
|
{
|
||||||
|
return new VersionInfo(
|
||||||
|
modelSignature: "FAFAMAPD",
|
||||||
|
verWrittenCur: 0x00010001,
|
||||||
|
verReadableCur: 0x00010001,
|
||||||
|
verWeCanReadBack: 0x00010001,
|
||||||
|
loaderSignature: LoaderSignature,
|
||||||
|
loaderAssemblyName: typeof(MatrixFactorizationPredictor).Assembly.FullName);
|
||||||
|
}
|
||||||
|
|
||||||
|
private readonly IHost _host;
|
||||||
|
// The number of rows.
|
||||||
|
private readonly int _numberOfRows;
|
||||||
|
// The number of columns.
|
||||||
|
private readonly int _numberofColumns;
|
||||||
|
// The rank of the factor matrices.
|
||||||
|
private readonly int _approximationRank;
|
||||||
|
// Packed _numberOfRows by _approximationRank matrix.
|
||||||
|
private readonly float[] _leftFactorMatrix;
|
||||||
|
// Packed _approximationRank by _numberofColumns matrix.
|
||||||
|
private readonly float[] _rightFactorMatrix;
|
||||||
|
|
||||||
|
public PredictionKind PredictionKind
|
||||||
|
{
|
||||||
|
get { return PredictionKind.Recommendation; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public ColumnType OutputType { get { return NumberType.Float; } }
|
||||||
|
|
||||||
|
public ColumnType MatrixColumnIndexType { get; }
|
||||||
|
public ColumnType MatrixRowIndexType { get; }
|
||||||
|
|
||||||
|
internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyType matrixColumnIndexType, KeyType matrixRowIndexType)
|
||||||
|
{
|
||||||
|
Contracts.CheckValue(env, nameof(env));
|
||||||
|
_host = env.Register(RegistrationName);
|
||||||
|
_host.CheckValue(buffer, nameof(buffer));
|
||||||
|
_host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType));
|
||||||
|
_host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType));
|
||||||
|
_host.Assert(matrixColumnIndexType.RawKind == DataKind.U4);
|
||||||
|
_host.Assert(matrixRowIndexType.RawKind == DataKind.U4);
|
||||||
|
|
||||||
|
buffer.Get(out _numberOfRows, out _numberofColumns, out _approximationRank, out _leftFactorMatrix, out _rightFactorMatrix);
|
||||||
|
_host.Assert(_numberofColumns == matrixColumnIndexType.Count);
|
||||||
|
_host.Assert(_numberOfRows == matrixRowIndexType.Count);
|
||||||
|
_host.Assert(_leftFactorMatrix.Length == _numberOfRows * _approximationRank);
|
||||||
|
_host.Assert(_rightFactorMatrix.Length == _numberofColumns * _approximationRank);
|
||||||
|
|
||||||
|
MatrixColumnIndexType = matrixColumnIndexType;
|
||||||
|
MatrixRowIndexType = matrixRowIndexType;
|
||||||
|
}
|
||||||
|
|
||||||
|
private MatrixFactorizationPredictor(IHostEnvironment env, ModelLoadContext ctx)
|
||||||
|
{
|
||||||
|
Contracts.CheckValue(env, nameof(env));
|
||||||
|
_host = env.Register(RegistrationName);
|
||||||
|
// *** Binary format ***
|
||||||
|
// int: number of rows (m), the limit on row
|
||||||
|
// ulong: Minimum value of the row key-type
|
||||||
|
// int: number of columns (n), the limit on column
|
||||||
|
// ulong: Minimum value of the column key-type
|
||||||
|
// int: rank of factor matrices (k)
|
||||||
|
// float[m * k]: the left factor matrix
|
||||||
|
// float[k * n]: the right factor matrix
|
||||||
|
|
||||||
|
_numberOfRows = ctx.Reader.ReadInt32();
|
||||||
|
_host.CheckDecode(_numberOfRows > 0);
|
||||||
|
ulong mMin = ctx.Reader.ReadUInt64();
|
||||||
|
_host.CheckDecode((ulong)_numberOfRows <= ulong.MaxValue - mMin);
|
||||||
|
_numberofColumns = ctx.Reader.ReadInt32();
|
||||||
|
_host.CheckDecode(_numberofColumns > 0);
|
||||||
|
ulong nMin = ctx.Reader.ReadUInt64();
|
||||||
|
_host.CheckDecode((ulong)_numberofColumns <= ulong.MaxValue - nMin);
|
||||||
|
_approximationRank = ctx.Reader.ReadInt32();
|
||||||
|
_host.CheckDecode(_approximationRank > 0);
|
||||||
|
|
||||||
|
_leftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(_numberOfRows * _approximationRank));
|
||||||
|
_rightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(_numberofColumns * _approximationRank));
|
||||||
|
|
||||||
|
MatrixColumnIndexType = new KeyType(DataKind.U4, nMin, _numberofColumns);
|
||||||
|
MatrixRowIndexType = new KeyType(DataKind.U4, mMin, _numberOfRows);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Load model from the given context
|
||||||
|
/// </summary>
|
||||||
|
public static MatrixFactorizationPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||||
|
{
|
||||||
|
Contracts.CheckValue(env, nameof(env));
|
||||||
|
env.CheckValue(ctx, nameof(ctx));
|
||||||
|
ctx.CheckAtModel(GetVersionInfo());
|
||||||
|
return new MatrixFactorizationPredictor(env, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Save model to the given context
|
||||||
|
/// </summary>
|
||||||
|
public void Save(ModelSaveContext ctx)
|
||||||
|
{
|
||||||
|
ctx.CheckAtModel();
|
||||||
|
ctx.SetVersionInfo(GetVersionInfo());
|
||||||
|
|
||||||
|
// *** Binary format ***
|
||||||
|
// int: number of rows (m), the limit on row
|
||||||
|
// ulong: Minimum value of the row key-type
|
||||||
|
// int: number of columns (n), the limit on column
|
||||||
|
// ulong: Minimum value of the column key-type
|
||||||
|
// int: rank of factor matrices (k)
|
||||||
|
// float[m * k]: the left factor matrix
|
||||||
|
// float[k * n]: the right factor matrix
|
||||||
|
|
||||||
|
_host.Check(_numberOfRows > 0, "Number of rows must be positive");
|
||||||
|
_host.Check(_numberofColumns > 0, "Number of columns must be positive");
|
||||||
|
_host.Check(_approximationRank > 0, "Number of latent factors must be positive");
|
||||||
|
ctx.Writer.Write(_numberOfRows);
|
||||||
|
ctx.Writer.Write((MatrixRowIndexType as KeyType).Min);
|
||||||
|
ctx.Writer.Write(_numberofColumns);
|
||||||
|
ctx.Writer.Write((MatrixColumnIndexType as KeyType).Min);
|
||||||
|
ctx.Writer.Write(_approximationRank);
|
||||||
|
_host.Check(Utils.Size(_leftFactorMatrix) == _numberOfRows * _approximationRank, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)");
|
||||||
|
_host.Check(Utils.Size(_rightFactorMatrix) == _numberofColumns * _approximationRank, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)");
|
||||||
|
Utils.WriteSinglesNoCount(ctx.Writer, _leftFactorMatrix, _numberOfRows * _approximationRank);
|
||||||
|
Utils.WriteSinglesNoCount(ctx.Writer, _rightFactorMatrix, _numberofColumns * _approximationRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Save the trained matrix factorization model (two factor matrices) in text format
|
||||||
|
/// </summary>
|
||||||
|
public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
|
||||||
|
{
|
||||||
|
writer.WriteLine("# Imputed matrix is P * Q'");
|
||||||
|
writer.WriteLine("# P in R^({0} x {1}), rows correspond to Y item", _numberOfRows, _approximationRank);
|
||||||
|
for (int i = 0; i < _leftFactorMatrix.Length; ++i)
|
||||||
|
{
|
||||||
|
writer.Write(_leftFactorMatrix[i].ToString("G"));
|
||||||
|
if (i % _approximationRank == _approximationRank - 1)
|
||||||
|
writer.WriteLine();
|
||||||
|
else
|
||||||
|
writer.Write('\t');
|
||||||
|
}
|
||||||
|
writer.WriteLine("# Q in R^({0} x {1}), rows correspond to X item", _numberofColumns, _approximationRank);
|
||||||
|
for (int i = 0; i < _rightFactorMatrix.Length; ++i)
|
||||||
|
{
|
||||||
|
writer.Write(_rightFactorMatrix[i].ToString("G"));
|
||||||
|
if (i % _approximationRank == _approximationRank - 1)
|
||||||
|
writer.WriteLine();
|
||||||
|
else
|
||||||
|
writer.Write('\t');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ValueGetter<float> GetGetter(ValueGetter<uint> matrixColumnIndexGetter, ValueGetter<uint> matrixRowIndexGetter)
|
||||||
|
{
|
||||||
|
_host.AssertValue(matrixColumnIndexGetter);
|
||||||
|
_host.AssertValue(matrixRowIndexGetter);
|
||||||
|
|
||||||
|
uint matrixColumnIndex = 0;
|
||||||
|
uint matrixRowIndex = 0;
|
||||||
|
|
||||||
|
var mapper = GetMapper<uint, uint, float>();
|
||||||
|
ValueGetter<float> del =
|
||||||
|
(ref float value) =>
|
||||||
|
{
|
||||||
|
matrixColumnIndexGetter(ref matrixColumnIndex);
|
||||||
|
matrixRowIndexGetter(ref matrixRowIndex);
|
||||||
|
mapper(ref matrixColumnIndex, ref matrixRowIndex, ref value);
|
||||||
|
};
|
||||||
|
return del;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Create the mapper required by matrix factorization's predictor. That mapper maps two
|
||||||
|
/// index inputs (e.g., row index and column index) to an approximated value located by the
|
||||||
|
/// two indexes in the training matrix. In a recommender system where the training matrix stores
|
||||||
|
/// ratings from users to items, the mapper maps a user ID and an item ID to the rating of that
|
||||||
|
/// item given by the user.
|
||||||
|
/// </summary>
|
||||||
|
public ValueMapper<TMatrixColumnIndexIn, TMatrixRowIndexIn, TOut> GetMapper<TMatrixColumnIndexIn, TMatrixRowIndexIn, TOut>()
|
||||||
|
{
|
||||||
|
string msg = null;
|
||||||
|
msg = "Invalid " + nameof(TMatrixColumnIndexIn) + " in GetMapper: " + typeof(TMatrixColumnIndexIn);
|
||||||
|
_host.Check(typeof(TMatrixColumnIndexIn) == typeof(uint), msg);
|
||||||
|
|
||||||
|
msg = "Invalid " + nameof(TMatrixRowIndexIn) + " in GetMapper: " + typeof(TMatrixRowIndexIn);
|
||||||
|
_host.Check(typeof(TMatrixRowIndexIn) == typeof(uint), msg);
|
||||||
|
|
||||||
|
msg = "Invalid " + nameof(TOut) + " in GetMapper: " + typeof(TOut);
|
||||||
|
_host.Check(typeof(TOut) == typeof(float), msg);
|
||||||
|
|
||||||
|
ValueMapper<uint, uint, float> mapper = MapperCore;
|
||||||
|
return mapper as ValueMapper<TMatrixColumnIndexIn, TMatrixRowIndexIn, TOut>;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void MapperCore(ref uint srcCol, ref uint srcRow, ref float dst)
|
||||||
|
{
|
||||||
|
// REVIEW: The key-type version is a bit more "strict" than the predictor
|
||||||
|
// version, since the predictor version can't know the maximum bound during
|
||||||
|
// training. For higher-than-expected values, the predictor version would return
|
||||||
|
// 0, rather than NaN as we do here. It is in my mind an open question as to what
|
||||||
|
// is actually correct.
|
||||||
|
if (srcRow == 0 || srcRow > _numberOfRows || srcCol == 0 || srcCol > _numberofColumns)
|
||||||
|
{
|
||||||
|
dst = float.NaN;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dst = Score((int)(srcCol - 1), (int)(srcRow - 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
private float Score(int columnIndex, int rowIndex)
|
||||||
|
{
|
||||||
|
_host.Assert(0 <= rowIndex && rowIndex < _numberOfRows);
|
||||||
|
_host.Assert(0 <= columnIndex && columnIndex < _numberofColumns);
|
||||||
|
float score = 0;
|
||||||
|
// Starting position of the rowIndex-th row in the left factor matrix
|
||||||
|
int rowOffset = rowIndex * _approximationRank;
|
||||||
|
// Starting position of the columnIndex-th column in the right factor matrix
|
||||||
|
int columnOffset = columnIndex * _approximationRank;
|
||||||
|
for (int i = 0; i < _approximationRank; i++)
|
||||||
|
score += _leftFactorMatrix[rowOffset + i] * _rightFactorMatrix[columnOffset + i];
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Create a row mapper based on regression scorer. Because matrix factorization predictor maps a tuple of a row ID (u) and a column ID (v)
|
||||||
|
/// to the expected numerical value at the u-th row and the v-th column in the considered matrix, it is essentially a regressor.
|
||||||
|
/// </summary>
|
||||||
|
public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
|
||||||
|
{
|
||||||
|
Contracts.AssertValue(env);
|
||||||
|
env.AssertValue(schema);
|
||||||
|
return new RowMapper(env, this, schema, Schema.Create(new ScoreMapperSchema(OutputType, MetadataUtils.Const.ScoreColumnKind.Regression)));
|
||||||
|
}
|
||||||
|
|
||||||
|
private sealed class RowMapper : ISchemaBoundRowMapper
|
||||||
|
{
|
||||||
|
private readonly MatrixFactorizationPredictor _parent;
|
||||||
|
// The tail "ColumnIndex" means the column index in IDataView
|
||||||
|
private readonly int _matrixColumnIndexColumnIndex;
|
||||||
|
private readonly int _matrixRowIndexCololumnIndex;
|
||||||
|
// The tail "ColumnName" means the column name in IDataView
|
||||||
|
private readonly string _matrixColumnIndexColumnName;
|
||||||
|
private readonly string _matrixRowIndexColumnName;
|
||||||
|
private IHostEnvironment _env;
|
||||||
|
public Schema Schema { get; }
|
||||||
|
public Schema InputSchema => InputRoleMappedSchema.Schema;
|
||||||
|
|
||||||
|
public RoleMappedSchema InputRoleMappedSchema { get; }
|
||||||
|
|
||||||
|
public RowMapper(IHostEnvironment env, MatrixFactorizationPredictor parent, RoleMappedSchema schema, Schema outputSchema)
|
||||||
|
{
|
||||||
|
Contracts.AssertValue(parent);
|
||||||
|
_env = env;
|
||||||
|
_parent = parent;
|
||||||
|
|
||||||
|
// Check role of matrix column index
|
||||||
|
var matrixColumnList = schema.GetColumns(RecommenderUtils.MatrixColumnIndexKind);
|
||||||
|
string msg = $"'{RecommenderUtils.MatrixColumnIndexKind}' column doesn't exist or is not unique";
|
||||||
|
_env.Check(Utils.Size(matrixColumnList) == 1, msg);
|
||||||
|
|
||||||
|
// Check role of matrix row index
|
||||||
|
var matrixRowList = schema.GetColumns(RecommenderUtils.MatrixRowIndexKind);
|
||||||
|
msg = $"'{RecommenderUtils.MatrixRowIndexKind}' column doesn't exist or is not unique";
|
||||||
|
_env.Check(Utils.Size(matrixRowList) == 1, msg);
|
||||||
|
|
||||||
|
_matrixColumnIndexColumnName = matrixColumnList[0].Name;
|
||||||
|
_matrixColumnIndexColumnIndex = matrixColumnList[0].Index;
|
||||||
|
|
||||||
|
_matrixRowIndexColumnName = matrixRowList[0].Name;
|
||||||
|
_matrixRowIndexCololumnIndex = matrixRowList[0].Index;
|
||||||
|
|
||||||
|
CheckInputSchema(schema.Schema, _matrixColumnIndexColumnIndex, _matrixRowIndexCololumnIndex);
|
||||||
|
InputRoleMappedSchema = schema;
|
||||||
|
Schema = outputSchema;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Func<int, bool> GetDependencies(Func<int, bool> predicate)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < Schema.ColumnCount; i++)
|
||||||
|
{
|
||||||
|
if (predicate(i))
|
||||||
|
return col => (col == _matrixColumnIndexColumnIndex || col == _matrixRowIndexCololumnIndex);
|
||||||
|
}
|
||||||
|
return col => false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
|
||||||
|
{
|
||||||
|
yield return RecommenderUtils.MatrixColumnIndexKind.Bind(_matrixColumnIndexColumnName);
|
||||||
|
yield return RecommenderUtils.MatrixRowIndexKind.Bind(_matrixRowIndexColumnName);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void CheckInputSchema(ISchema schema, int matrixColumnIndexCol, int matrixRowIndexCol)
|
||||||
|
{
|
||||||
|
// See if matrix-column-index role's type matches the one expected in the trained predictor
|
||||||
|
var type = schema.GetColumnType(matrixColumnIndexCol);
|
||||||
|
string msg = string.Format("Input column index type '{0}' incompatible with predictor's column index type '{1}'", type, _parent.MatrixColumnIndexType);
|
||||||
|
_env.CheckParam(type.Equals(_parent.MatrixColumnIndexType), nameof(schema), msg);
|
||||||
|
|
||||||
|
// See if matrix-row-index role's type matches the one expected in the trained predictor
|
||||||
|
type = schema.GetColumnType(matrixRowIndexCol);
|
||||||
|
msg = string.Format("Input row index type '{0}' incompatible with predictor's row index type '{1}'", type, _parent.MatrixRowIndexType);
|
||||||
|
_env.CheckParam(type.Equals(_parent.MatrixRowIndexType), nameof(schema), msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Delegate[] CreateGetter(IRow input, bool[] active)
|
||||||
|
{
|
||||||
|
_env.CheckValue(input, nameof(input));
|
||||||
|
_env.Assert(Utils.Size(active) == Schema.ColumnCount);
|
||||||
|
|
||||||
|
var getters = new Delegate[1];
|
||||||
|
if (active[0])
|
||||||
|
{
|
||||||
|
CheckInputSchema(input.Schema, _matrixColumnIndexColumnIndex, _matrixRowIndexCololumnIndex);
|
||||||
|
var matrixColumnIndexGetter = input.GetGetter<uint>(_matrixColumnIndexColumnIndex);
|
||||||
|
var matrixRowIndexGetter = input.GetGetter<uint>(_matrixRowIndexCololumnIndex);
|
||||||
|
getters[0] = _parent.GetGetter(matrixColumnIndexGetter, matrixRowIndexGetter);
|
||||||
|
}
|
||||||
|
return getters;
|
||||||
|
}
|
||||||
|
|
||||||
|
public IRow GetRow(IRow input, Func<int, bool> predicate, out Action disposer)
|
||||||
|
{
|
||||||
|
var active = Utils.BuildArray(Schema.ColumnCount, predicate);
|
||||||
|
var getters = CreateGetter(input, active);
|
||||||
|
disposer = null;
|
||||||
|
return new SimpleRow(Schema, input, getters);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ISchemaBindableMapper Bindable { get { return _parent; } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase<MatrixFactorizationPredictor, GenericScorer>, ICanSaveModel
|
||||||
|
{
|
||||||
|
public const string LoaderSignature = "MaFactPredXf";
|
||||||
|
public string MatrixColumnIndexColumnName { get; }
|
||||||
|
public string MatrixRowIndexColumnName { get; }
|
||||||
|
public ColumnType MatrixColumnIndexColumnType { get; }
|
||||||
|
public ColumnType MatrixRowIndexColumnType { get; }
|
||||||
|
protected override GenericScorer Scorer { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Build a transformer based on matrix factorization predictor (model) and the input schema (trainSchema). The created
|
||||||
|
/// transformer can only transform IDataView objects compatible with the input schema; that is, the IDataView must contain
|
||||||
|
/// columns specified by <see cref="MatrixColumnIndexColumnName"/>, <see cref="MatrixColumnIndexColumnType"/>, <see cref="MatrixRowIndexColumnName"/>, and <see cref="MatrixRowIndexColumnType"></see>.
|
||||||
|
/// The output column is "Score" by default, but the user can append a suffix to it.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="env">Environment object for showing information</param>
|
||||||
|
/// <param name="model">The model trained by one of the training functions in <see cref="MatrixFactorizationTrainer"/></param>
|
||||||
|
/// <param name="trainSchema">Target schema that contains the columns named by <paramref name="matrixColumnIndexColumnName"/> and <paramref name="matrixRowIndexColumnName"/></param>
|
||||||
|
/// <param name="matrixColumnIndexColumnName">The name of the column used as role <see cref="RecommenderUtils.MatrixColumnIndexKind"/> in matrix factorization world</param>
|
||||||
|
/// <param name="matrixRowIndexColumnName">The name of the column used as role <see cref="RecommenderUtils.MatrixRowIndexKind"/> in matrix factorization world</param>
|
||||||
|
/// <param name="scoreColumnNameSuffix">A string attached to the output column name of this transformer</param>
|
||||||
|
public MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationPredictor model, Schema trainSchema,
|
||||||
|
string matrixColumnIndexColumnName, string matrixRowIndexColumnName, string scoreColumnNameSuffix = "")
|
||||||
|
:base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MatrixFactorizationPredictionTransformer)), model, trainSchema)
|
||||||
|
{
|
||||||
|
Host.CheckNonEmpty(matrixColumnIndexColumnName, nameof(matrixColumnIndexColumnName));
|
||||||
|
Host.CheckNonEmpty(matrixRowIndexColumnName, nameof(matrixRowIndexColumnName));
|
||||||
|
|
||||||
|
MatrixColumnIndexColumnName = matrixColumnIndexColumnName;
|
||||||
|
MatrixRowIndexColumnName = matrixRowIndexColumnName;
|
||||||
|
|
||||||
|
if (!trainSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int xCol))
|
||||||
|
throw Host.ExceptSchemaMismatch(nameof(MatrixColumnIndexColumnName), RecommenderUtils.MatrixColumnIndexKind.Value, MatrixColumnIndexColumnName);
|
||||||
|
MatrixColumnIndexColumnType = trainSchema.GetColumnType(xCol);
|
||||||
|
if (!trainSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int yCol))
|
||||||
|
throw Host.ExceptSchemaMismatch(nameof(MatrixRowIndexColumnName), RecommenderUtils.MatrixRowIndexKind.Value, MatrixRowIndexColumnName);
|
||||||
|
|
||||||
|
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
|
||||||
|
|
||||||
|
var schema = GetSchema();
|
||||||
|
var args = new GenericScorer.Arguments { Suffix = scoreColumnNameSuffix };
|
||||||
|
Scorer = new GenericScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
private RoleMappedSchema GetSchema()
|
||||||
|
{
|
||||||
|
var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
|
||||||
|
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixColumnIndexKind, MatrixColumnIndexColumnName));
|
||||||
|
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixRowIndexKind, MatrixRowIndexColumnName));
|
||||||
|
var schema = new RoleMappedSchema(TrainSchema, roles);
|
||||||
|
return schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The counterpart constructor that re-creates a <see cref="MatrixFactorizationPredictionTransformer"/> from the context where
|
||||||
|
/// the original transform was saved.
|
||||||
|
/// </summary>
|
||||||
|
public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
|
||||||
|
:base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx)
|
||||||
|
{
|
||||||
|
// *** Binary format ***
|
||||||
|
// <base info>
|
||||||
|
// string: the column name of matrix's column ids.
|
||||||
|
// string: the column name of matrix's row ids.
|
||||||
|
|
||||||
|
MatrixColumnIndexColumnName = ctx.LoadString();
|
||||||
|
MatrixRowIndexColumnName = ctx.LoadString();
|
||||||
|
|
||||||
|
if (!TrainSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int xCol))
|
||||||
|
throw Host.ExceptSchemaMismatch(nameof(MatrixColumnIndexColumnName), RecommenderUtils.MatrixColumnIndexKind.Value, MatrixColumnIndexColumnName);
|
||||||
|
MatrixColumnIndexColumnType = TrainSchema.GetColumnType(xCol);
|
||||||
|
|
||||||
|
if (!TrainSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int yCol))
|
||||||
|
throw Host.ExceptSchemaMismatch(nameof(MatrixRowIndexColumnName), RecommenderUtils.MatrixRowIndexKind.Value, MatrixRowIndexColumnName);
|
||||||
|
MatrixRowIndexColumnType = TrainSchema.GetColumnType(yCol);
|
||||||
|
|
||||||
|
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
|
||||||
|
|
||||||
|
var schema = GetSchema();
|
||||||
|
var args = new GenericScorer.Arguments { Suffix = "" };
|
||||||
|
Scorer = new GenericScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
public override Schema GetOutputSchema(Schema inputSchema)
|
||||||
|
{
|
||||||
|
if (!inputSchema.TryGetColumnIndex(MatrixColumnIndexColumnName, out int xCol))
|
||||||
|
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RecommenderUtils.MatrixColumnIndexKind.Value, MatrixColumnIndexColumnName);
|
||||||
|
if (!inputSchema.TryGetColumnIndex(MatrixRowIndexColumnName, out int yCol))
|
||||||
|
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RecommenderUtils.MatrixRowIndexKind.Value, MatrixRowIndexColumnName);
|
||||||
|
|
||||||
|
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Save(ModelSaveContext ctx)
|
||||||
|
{
|
||||||
|
Host.CheckValue(ctx, nameof(ctx));
|
||||||
|
ctx.CheckAtModel();
|
||||||
|
ctx.SetVersionInfo(GetVersionInfo());
|
||||||
|
|
||||||
|
// *** Binary format ***
|
||||||
|
// model: prediction model.
|
||||||
|
// stream: empty data view that contains train schema.
|
||||||
|
// string: the column name of matrix's column ids.
|
||||||
|
// string: the column name of matrix's row ids.
|
||||||
|
|
||||||
|
ctx.SaveModel(Model, DirModel);
|
||||||
|
ctx.SaveBinaryStream(DirTransSchema, writer =>
|
||||||
|
{
|
||||||
|
using (var ch = Host.Start("Saving train schema"))
|
||||||
|
{
|
||||||
|
var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
|
||||||
|
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.SaveString(MatrixColumnIndexColumnName);
|
||||||
|
ctx.SaveString(MatrixRowIndexColumnName);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static VersionInfo GetVersionInfo()
|
||||||
|
{
|
||||||
|
return new VersionInfo(
|
||||||
|
modelSignature: "MAFAPRED", // "MA"trix "FA"ctorization "PRED"iction
|
||||||
|
verWrittenCur: 0x00010001, // Initial
|
||||||
|
verReadableCur: 0x00010001,
|
||||||
|
verWeCanReadBack: 0x00010001,
|
||||||
|
loaderSignature: LoaderSignature,
|
||||||
|
loaderAssemblyName: typeof(MatrixFactorizationPredictionTransformer).Assembly.FullName);
|
||||||
|
}
|
||||||
|
private static MatrixFactorizationPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||||
|
=> new MatrixFactorizationPredictionTransformer(env, ctx);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,328 @@
|
||||||
|
// Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Linq;
|
||||||
|
using Microsoft.ML.Core.Data;
|
||||||
|
using Microsoft.ML.Core.Prediction;
|
||||||
|
using Microsoft.ML.Runtime;
|
||||||
|
using Microsoft.ML.Runtime.CommandLine;
|
||||||
|
using Microsoft.ML.Runtime.Data;
|
||||||
|
using Microsoft.ML.Runtime.EntryPoints;
|
||||||
|
using Microsoft.ML.Runtime.Internal.Internallearn;
|
||||||
|
using Microsoft.ML.Runtime.Recommender;
|
||||||
|
using Microsoft.ML.Runtime.Recommender.Internal;
|
||||||
|
using Microsoft.ML.Runtime.Training;
|
||||||
|
using Microsoft.ML.Trainers;
|
||||||
|
|
||||||
|
[assembly: LoadableClass(MatrixFactorizationTrainer.Summary, typeof(MatrixFactorizationTrainer), typeof(MatrixFactorizationTrainer.Arguments),
|
||||||
|
new Type[] { typeof(SignatureTrainer), typeof(SignatureMatrixRecommendingTrainer) },
|
||||||
|
"Matrix Factorization", MatrixFactorizationTrainer.LoadNameValue, "libmf", "mf")]
|
||||||
|
|
||||||
|
namespace Microsoft.ML.Trainers
|
||||||
|
{
|
||||||
|
public sealed class MatrixFactorizationTrainer : TrainerBase<MatrixFactorizationPredictor>,
|
||||||
|
IEstimator<MatrixFactorizationPredictionTransformer>
|
||||||
|
{
|
||||||
|
public sealed class Arguments
|
||||||
|
{
|
||||||
|
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularization parameter")]
|
||||||
|
[TGUI(SuggestedSweeps = "0.01,0.05,0.1,0.5,1")]
|
||||||
|
[TlcModule.SweepableDiscreteParam("Lambda", new object[] { 0.01f, 0.05f, 0.1f, 0.5f, 1f })]
|
||||||
|
public Double Lambda = 0.1;
|
||||||
|
|
||||||
|
[Argument(ArgumentType.AtMostOnce, HelpText = "Latent space dimension")]
|
||||||
|
[TGUI(SuggestedSweeps = "8,16,64,128")]
|
||||||
|
[TlcModule.SweepableDiscreteParam("K", new object[] { 8, 16, 64, 128 })]
|
||||||
|
public int K = 8;
|
||||||
|
|
||||||
|
[Argument(ArgumentType.AtMostOnce, HelpText = "Training iterations", ShortName = "iter")]
|
||||||
|
[TGUI(SuggestedSweeps = "10,20,40")]
|
||||||
|
[TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 10, 20, 40 })]
|
||||||
|
public int NumIterations = 20;
|
||||||
|
|
||||||
|
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate")]
|
||||||
|
[TGUI(SuggestedSweeps = "0.001,0.01,0.1")]
|
||||||
|
[TlcModule.SweepableDiscreteParam("Eta", new object[] { 0.001f, 0.01f, 0.1f })]
|
||||||
|
public Double Eta = 0.1;
|
||||||
|
|
||||||
|
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads", ShortName = "t")]
|
||||||
|
public int? NumThreads;
|
||||||
|
|
||||||
|
[Argument(ArgumentType.AtMostOnce, HelpText = "Suppress writing additional information to output")]
|
||||||
|
public bool Quiet;
|
||||||
|
|
||||||
|
[Argument(ArgumentType.AtMostOnce, HelpText = "Force the matrix factorization P and Q to be non-negative", ShortName = "nn")]
|
||||||
|
public bool NonNegative;
|
||||||
|
};
|
||||||
|
|
||||||
|
internal const string Summary = "From pairs of row/column indices and a value of a matrix, this trains a predictor capable of filling in unknown entries of the matrix, "
|
||||||
|
+ "utilizing a low-rank matrix factorization. This technique is often used in recommender systems, where the row and column indices indicate users and items, "
|
||||||
|
+ "and the value of the matrix is some rating. ";
|
||||||
|
|
||||||
|
private readonly Double _lambda;
|
||||||
|
private readonly int _k;
|
||||||
|
private readonly int _iter;
|
||||||
|
private readonly Double _eta;
|
||||||
|
private readonly int _threads;
|
||||||
|
private readonly bool _quiet;
|
||||||
|
private readonly bool _doNmf;
|
||||||
|
|
||||||
|
public override PredictionKind PredictionKind => PredictionKind.Recommendation;
|
||||||
|
public const string LoadNameValue = "MatrixFactorization";
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The row, column, and label columns that the trainer expects. This module uses tuples of (row index, column index, label value) to specify a matrix.
|
||||||
|
/// For example, a 2-by-2 matrix
|
||||||
|
/// [9, 4]
|
||||||
|
/// [8, 7]
|
||||||
|
/// can be encoded as tuples (0, 0, 9), (0, 1, 4), (1, 0, 8), and (1, 1, 7). It means that the row/column/label column contains [0, 0, 1, 1]/
|
||||||
|
/// [0, 1, 0, 1]/[9, 4, 8, 7].
|
||||||
|
/// </summary>
|
||||||
|
public readonly SchemaShape.Column MatrixColumnIndexColumn; // column indices of the training matrix
|
||||||
|
public readonly SchemaShape.Column MatrixRowIndexColumn; // row indices of the training matrix
|
||||||
|
public readonly SchemaShape.Column LabelColumn;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The <see cref="TrainerInfo"/> contains general parameters for this trainer.
|
||||||
|
/// </summary>
|
||||||
|
public override TrainerInfo Info { get; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Extra information the trainer can use. For example, its validation set (if not null) can be used to evaluate the
|
||||||
|
/// training progress made at each training iteration.
|
||||||
|
/// </summary>
|
||||||
|
public readonly TrainerEstimatorContext Context;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Legacy constructor initializing a new instance of <see cref="MatrixFactorizationTrainer"/> through the legacy
|
||||||
|
/// <see cref="Arguments"/> class.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||||
|
/// <param name="args">An instance of the legacy <see cref="Arguments"/> to apply advanced parameters to the algorithm.</param>
|
||||||
|
public MatrixFactorizationTrainer(IHostEnvironment env, Arguments args) : base(env, LoadNameValue)
|
||||||
|
{
|
||||||
|
const string posError = "Parameter must be positive";
|
||||||
|
Host.CheckValue(args, nameof(args));
|
||||||
|
Host.CheckUserArg(args.K > 0, nameof(args.K), posError);
|
||||||
|
Host.CheckUserArg(!args.NumThreads.HasValue || args.NumThreads > 0, nameof(args.NumThreads), posError);
|
||||||
|
Host.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), posError);
|
||||||
|
Host.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), posError);
|
||||||
|
Host.CheckUserArg(args.Eta > 0, nameof(args.Eta), posError);
|
||||||
|
|
||||||
|
_lambda = args.Lambda;
|
||||||
|
_k = args.K;
|
||||||
|
_iter = args.NumIterations;
|
||||||
|
_eta = args.Eta;
|
||||||
|
_threads = args.NumThreads ?? Environment.ProcessorCount;
|
||||||
|
_quiet = args.Quiet;
|
||||||
|
_doNmf = args.NonNegative;
|
||||||
|
|
||||||
|
Info = new TrainerInfo(normalization: false, caching: false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Initializing a new instance of <see cref="MatrixFactorizationTrainer"/>.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||||
|
/// <param name="labelColumn">The name of the label column.</param>
|
||||||
|
/// <param name="matrixColumnIndexColumnName">The name of the column hosting the matrix's column IDs.</param>
|
||||||
|
/// <param name="matrixRowIndexColumnName">The name of the column hosting the matrix's row IDs.</param>
|
||||||
|
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
|
||||||
|
/// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param>
|
||||||
|
public MatrixFactorizationTrainer(IHostEnvironment env, string labelColumn, string matrixColumnIndexColumnName, string matrixRowIndexColumnName,
|
||||||
|
TrainerEstimatorContext context = null, Action<Arguments> advancedSettings = null)
|
||||||
|
: base(env, LoadNameValue)
|
||||||
|
{
|
||||||
|
var args = new Arguments();
|
||||||
|
advancedSettings?.Invoke(args);
|
||||||
|
|
||||||
|
_lambda = args.Lambda;
|
||||||
|
_k = args.K;
|
||||||
|
_iter = args.NumIterations;
|
||||||
|
_eta = args.Eta;
|
||||||
|
_threads = args.NumThreads ?? Environment.ProcessorCount;
|
||||||
|
_quiet = args.Quiet;
|
||||||
|
_doNmf = args.NonNegative;
|
||||||
|
|
||||||
|
Info = new TrainerInfo(normalization: false, caching: false);
|
||||||
|
Context = context;
|
||||||
|
|
||||||
|
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
|
||||||
|
MatrixColumnIndexColumn = new SchemaShape.Column(matrixColumnIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
|
||||||
|
MatrixRowIndexColumn = new SchemaShape.Column(matrixRowIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
|
||||||
|
}
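|
||||||
|
|
||||||
|
// Usage sketch (illustrative only, not part of the original change): constructing the trainer with
|
||||||
|
// advanced settings and fitting it. It assumes `env` is an existing IHostEnvironment and `trainingData`
|
||||||
|
// is an IDataView whose index columns are U4 keys of known cardinality and whose label is floating point.
|
||||||
|
// The column names "Label", "MatrixColumnIndex" and "MatrixRowIndex" are hypothetical.
|
||||||
|
//
|
||||||
|
//     var trainer = new MatrixFactorizationTrainer(env, "Label", "MatrixColumnIndex", "MatrixRowIndex",
|
||||||
|
//         advancedSettings: s => { s.K = 16; s.NumIterations = 40; s.Eta = 0.05; });
|
||||||
|
//     MatrixFactorizationPredictionTransformer transformer = trainer.Fit(trainingData);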
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Train a matrix factorization model based on training data, validation data, and so on in the given context.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="context">The information collection needed for training. See <see cref="TrainContext"/> for details.</param>
|
||||||
|
public override MatrixFactorizationPredictor Train(TrainContext context)
|
||||||
|
{
|
||||||
|
Host.CheckValue(context, nameof(context));
|
||||||
|
|
||||||
|
using (var ch = Host.Start("Training"))
|
||||||
|
{
|
||||||
|
return TrainCore(ch, context.TrainingSet, context.ValidationSet);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, RoleMappedData validData)
|
||||||
|
{
|
||||||
|
Host.AssertValue(ch);
|
||||||
|
ch.AssertValue(data);
|
||||||
|
ch.AssertValueOrNull(validData);
|
||||||
|
|
||||||
|
ColumnInfo matrixColumnIndexColInfo;
|
||||||
|
ColumnInfo matrixRowIndexColInfo;
|
||||||
|
ColumnInfo validMatrixColumnIndexColInfo = null;
|
||||||
|
ColumnInfo validMatrixRowIndexColInfo = null;
|
||||||
|
|
||||||
|
ch.CheckValue(data.Schema.Label, nameof(data), "Input data did not have a unique label");
|
||||||
|
RecommenderUtils.CheckAndGetMatrixIndexColumns(data, out matrixColumnIndexColInfo, out matrixRowIndexColInfo, isDecode: false);
|
||||||
|
if (data.Schema.Label.Type != NumberType.R4 && data.Schema.Label.Type != NumberType.R8)
|
||||||
|
throw ch.Except("Column '{0}' for label should be floating point, but is instead {1}", data.Schema.Label.Name, data.Schema.Label.Type);
|
||||||
|
MatrixFactorizationPredictor predictor;
|
||||||
|
if (validData != null)
|
||||||
|
{
|
||||||
|
ch.CheckValue(validData, nameof(validData));
|
||||||
|
ch.CheckValue(validData.Schema.Label, nameof(validData), "Input validation data did not have a unique label");
|
||||||
|
RecommenderUtils.CheckAndGetMatrixIndexColumns(validData, out validMatrixColumnIndexColInfo, out validMatrixRowIndexColInfo, isDecode: false);
|
||||||
|
if (validData.Schema.Label.Type != NumberType.R4 && validData.Schema.Label.Type != NumberType.R8)
|
||||||
|
throw ch.Except("Column '{0}' for validation label should be floating point, but is instead {1}", validData.Schema.Label.Name, validData.Schema.Label.Type);
|
||||||
|
|
||||||
|
if (!matrixColumnIndexColInfo.Type.Equals(validMatrixColumnIndexColInfo.Type))
|
||||||
|
{
|
||||||
|
throw ch.ExceptParam(nameof(validData), "Train and validation sets' matrix-column types differed, {0} vs. {1}",
|
||||||
|
matrixColumnIndexColInfo.Type, validMatrixColumnIndexColInfo.Type);
|
||||||
|
}
|
||||||
|
if (!matrixRowIndexColInfo.Type.Equals(validMatrixRowIndexColInfo.Type))
|
||||||
|
{
|
||||||
|
throw ch.ExceptParam(nameof(validData), "Train and validation sets' matrix-row types differed, {0} vs. {1}",
|
||||||
|
matrixRowIndexColInfo.Type, validMatrixRowIndexColInfo.Type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int colCount = matrixColumnIndexColInfo.Type.KeyCount;
|
||||||
|
int rowCount = matrixRowIndexColInfo.Type.KeyCount;
|
||||||
|
ch.Assert(rowCount > 0);
|
||||||
|
ch.Assert(colCount > 0);
|
||||||
|
// Checks for equality on the validation set ensure it is correct here.
|
||||||
|
|
||||||
|
using (var cursor = data.Data.GetRowCursor(c => c == matrixColumnIndexColInfo.Index || c == matrixRowIndexColInfo.Index || c == data.Schema.Label.Index))
|
||||||
|
{
|
||||||
|
// LibMF works only over single precision floats, but we want to be able to consume either.
|
||||||
|
ValueGetter<Single> labGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, data.Schema.Label.Index);
|
||||||
|
var matrixColumnIndexGetter = cursor.GetGetter<uint>(matrixColumnIndexColInfo.Index);
|
||||||
|
var matrixRowIndexGetter = cursor.GetGetter<uint>(matrixRowIndexColInfo.Index);
|
||||||
|
|
||||||
|
if (validData == null)
|
||||||
|
{
|
||||||
|
// Have the trainer do its work.
|
||||||
|
using (var buffer = PrepareBuffer())
|
||||||
|
{
|
||||||
|
buffer.Train(ch, rowCount, colCount,
|
||||||
|
cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter);
|
||||||
|
predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
using (var validCursor = validData.Data.GetRowCursor(
|
||||||
|
c => c == validMatrixColumnIndexColInfo.Index || c == validMatrixRowIndexColInfo.Index || c == validData.Schema.Label.Index))
|
||||||
|
{
|
||||||
|
ValueGetter<Single> validLabGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, validCursor, validData.Schema.Label.Index);
|
||||||
|
var validXGetter = validCursor.GetGetter<uint>(validMatrixColumnIndexColInfo.Index);
|
||||||
|
var validYGetter = validCursor.GetGetter<uint>(validMatrixRowIndexColInfo.Index);
|
||||||
|
|
||||||
|
// Have the trainer do its work.
|
||||||
|
using (var buffer = PrepareBuffer())
|
||||||
|
{
|
||||||
|
buffer.TrainWithValidation(ch, rowCount, colCount,
|
||||||
|
cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter,
|
||||||
|
validCursor, validLabGetter, validYGetter, validXGetter);
|
||||||
|
predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return predictor;
|
||||||
|
}
|
||||||
|
|
||||||
|
private SafeTrainingAndModelBuffer PrepareBuffer()
|
||||||
|
{
|
||||||
|
return new SafeTrainingAndModelBuffer(Host, _k, Math.Max(20, 2 * _threads),
|
||||||
|
_threads, _iter, _lambda, _eta, _doNmf, _quiet, copyData: false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Train a matrix factorization model based on the input <see cref="IDataView"/> using the roles specified by <see cref="MatrixColumnIndexColumn"/> and <see cref="MatrixRowIndexColumn"/> in <see cref="MatrixFactorizationTrainer"/>.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="input">The training data set.</param>
|
||||||
|
public MatrixFactorizationPredictionTransformer Fit(IDataView input)
|
||||||
|
{
|
||||||
|
MatrixFactorizationPredictor model = null;
|
||||||
|
|
||||||
|
var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
|
||||||
|
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, LabelColumn.Name));
|
||||||
|
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixColumnIndexKind.Value, MatrixColumnIndexColumn.Name));
|
||||||
|
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixRowIndexKind.Value, MatrixRowIndexColumn.Name));
|
||||||
|
|
||||||
|
var trainingData = new RoleMappedData(input, roles);
|
||||||
|
var validData = Context?.ValidationSet == null ? null : new RoleMappedData(Context.ValidationSet, roles);
|
||||||
|
|
||||||
|
using (var ch = Host.Start("Training"))
|
||||||
|
using (var pch = Host.StartProgressChannel("Training"))
|
||||||
|
{
|
||||||
|
model = TrainCore(ch, trainingData, validData);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new MatrixFactorizationPredictionTransformer(Host, model, input.Schema, MatrixColumnIndexColumn.Name, MatrixRowIndexColumn.Name);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
|
||||||
|
{
|
||||||
|
Host.CheckValue(inputSchema, nameof(inputSchema));
|
||||||
|
|
||||||
|
void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string expectedColumnName)
|
||||||
|
{
|
||||||
|
if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
|
||||||
|
throw Host.ExceptSchemaMismatch(nameof(inputSchema), expectedColumnName, expectedColumnName);
|
||||||
|
|
||||||
|
if (!cachedColumn.IsCompatibleWith(col))
|
||||||
|
throw Host.Except($"{expectedColumnName} column '{cachedColumn.Name}' is not compatible");
|
||||||
|
}
|
||||||
|
|
||||||
|
// In prediction phase, no label column is expected.
|
||||||
|
if (LabelColumn != null)
|
||||||
|
CheckColumnsCompatible(LabelColumn, LabelColumn.Name);
|
||||||
|
|
||||||
|
// In both the training and prediction phases, we need the matrix column-index and row-index columns.
|
||||||
|
CheckColumnsCompatible(MatrixColumnIndexColumn, MatrixColumnIndexColumn.Name);
|
||||||
|
CheckColumnsCompatible(MatrixRowIndexColumn, MatrixRowIndexColumn.Name);
|
||||||
|
|
||||||
|
// Input columns just pass through so that output column dictionary contains all input columns.
|
||||||
|
var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
|
||||||
|
|
||||||
|
// Add columns produced by this estimator.
|
||||||
|
foreach (var col in GetOutputColumnsCore(inputSchema))
|
||||||
|
outColumns[col.Name] = col;
|
||||||
|
|
||||||
|
return new SchemaShape(outColumns.Values);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
|
||||||
|
{
|
||||||
|
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
|
||||||
|
Contracts.Assert(success);
|
||||||
|
|
||||||
|
return new[]
|
||||||
|
{
|
||||||
|
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
<Project Sdk="Microsoft.NET.Sdk">
|
||||||
|
|
||||||
|
<PropertyGroup>
|
||||||
|
<TargetFramework>netstandard2.0</TargetFramework>
|
||||||
|
<IncludeInPackage>Microsoft.ML.MatrixFactorization</IncludeInPackage>
|
||||||
|
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
|
||||||
|
</PropertyGroup>
|
||||||
|
|
||||||
|
<ItemGroup>
|
||||||
|
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
|
||||||
|
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
|
||||||
|
</ItemGroup>
|
||||||
|
|
||||||
|
</Project>
|
|
@ -0,0 +1,87 @@
|
||||||
|
// Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System.Threading;
|
||||||
|
using Microsoft.ML.Runtime.Data;
|
||||||
|
using Microsoft.ML.Runtime.Internal.Utilities;
|
||||||
|
|
||||||
|
namespace Microsoft.ML.Runtime.Recommender
|
||||||
|
{
|
||||||
|
internal static class RecommenderUtils
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Check if the considered data, <see cref="RoleMappedData"/>, contains column roles specified by <see cref="MatrixColumnIndexKind"/> and <see cref="MatrixRowIndexKind"/>.
|
||||||
|
/// If the column roles, <see cref="MatrixColumnIndexKind"/> and <see cref="MatrixRowIndexKind"/>, uniquely exist in data, their <see cref="ColumnInfo"/> would be assigned
|
||||||
|
/// to the two out parameters below.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="data">The considered data being checked</param>
|
||||||
|
/// <param name="matrixColumnIndexColumn">The column with the column-index role in the input data</param>
|
||||||
|
/// <param name="matrixRowIndexColumn">The column with the row-index role in the input data</param>
|
||||||
|
/// <param name="isDecode">Whether a non-user error should be thrown as a decode</param>
|
||||||
|
public static void CheckAndGetMatrixIndexColumns(RoleMappedData data, out ColumnInfo matrixColumnIndexColumn, out ColumnInfo matrixRowIndexColumn, bool isDecode)
|
||||||
|
{
|
||||||
|
Contracts.AssertValue(data);
|
||||||
|
CheckRowColumnType(data, MatrixColumnIndexKind, out matrixColumnIndexColumn, isDecode);
|
||||||
|
CheckRowColumnType(data, MatrixRowIndexKind, out matrixRowIndexColumn, isDecode);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns whether a type is a U4 key of known cardinality, and if so, sets
|
||||||
|
/// <paramref name="keyType"/> to a non-null value.
|
||||||
|
/// </summary>
|
||||||
|
private static bool TryMarshalGoodRowColumnType(ColumnType type, out KeyType keyType)
|
||||||
|
{
|
||||||
|
keyType = type as KeyType;
|
||||||
|
return type.KeyCount > 0 && type.RawKind == DataKind.U4 &&
|
||||||
|
keyType != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Checks whether a column kind in a RoleMappedData is unique, and its type
|
||||||
|
/// is a U4 key of known cardinality.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="data">The training examples</param>
|
||||||
|
/// <param name="role">The column role to try to extract</param>
|
||||||
|
/// <param name="info">The extracted column info</param>
|
||||||
|
/// <param name="isDecode">Whether a non-user error should be thrown as a decode</param>
|
||||||
|
/// <returns>The type cast to a key-type</returns>
|
||||||
|
private static KeyType CheckRowColumnType(RoleMappedData data, RoleMappedSchema.ColumnRole role, out ColumnInfo info, bool isDecode)
|
||||||
|
{
|
||||||
|
Contracts.AssertValue(data);
|
||||||
|
Contracts.AssertValue(role.Value);
|
||||||
|
|
||||||
|
const string format2 = "There should be exactly one column with role {0}, but {1} were found instead";
|
||||||
|
if (!data.Schema.HasUnique(role))
|
||||||
|
{
|
||||||
|
int kindCount = Utils.Size(data.Schema.GetColumns(role));
|
||||||
|
if (isDecode)
|
||||||
|
throw Contracts.ExceptDecode(format2, role.Value, kindCount);
|
||||||
|
throw Contracts.Except(format2, role.Value, kindCount);
|
||||||
|
}
|
||||||
|
info = data.Schema.GetColumns(role)[0];
|
||||||
|
|
||||||
|
// REVIEW tfinley: Should we be a bit less restrictive? This doesn't seem like
|
||||||
|
// too terrible of a restriction.
|
||||||
|
const string format = "Column '{0}' with role {1} should be a known cardinality U4 key, but is instead '{2}'";
|
||||||
|
KeyType keyType;
|
||||||
|
if (!TryMarshalGoodRowColumnType(info.Type, out keyType))
|
||||||
|
{
|
||||||
|
if (isDecode)
|
||||||
|
throw Contracts.ExceptDecode(format, info.Name, role.Value, info.Type);
|
||||||
|
throw Contracts.Except(format, info.Name, role.Value, info.Type);
|
||||||
|
}
|
||||||
|
return keyType;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The column role that is treated as column index in matrix factorization problem
|
||||||
|
/// </summary>
|
||||||
|
public static RoleMappedSchema.ColumnRole MatrixColumnIndexKind => "MatrixColumnIndex";
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The column role that is treated as row index in matrix factorization problem
|
||||||
|
/// </summary>
|
||||||
|
public static RoleMappedSchema.ColumnRole MatrixRowIndexKind => "MatrixRowIndex";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,281 @@
|
||||||
|
// Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Runtime.InteropServices;
|
||||||
|
using System.Security;
|
||||||
|
using Microsoft.ML.Runtime.Data;
|
||||||
|
using Microsoft.ML.Runtime.Internal.Utilities;
|
||||||
|
|
||||||
|
namespace Microsoft.ML.Runtime.Recommender.Internal
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Contains mirrors of the unmanaged structs and the imported extern functions from mf.h / mf.cpp, which implement matrix factorization in native C++.
|
||||||
|
/// It also wraps the train, train-with-validation, and cross-validation interfaces for use by the ML.NET infrastructure.
|
||||||
|
/// </summary>
|
||||||
|
internal sealed class SafeTrainingAndModelBuffer : IDisposable
|
||||||
|
{
|
||||||
|
[StructLayout(LayoutKind.Explicit)]
|
||||||
|
private struct MFNode
|
||||||
|
{
|
||||||
|
[FieldOffset(0)]
|
||||||
|
public int U;
|
||||||
|
[FieldOffset(4)]
|
||||||
|
public int V;
|
||||||
|
[FieldOffset(8)]
|
||||||
|
public float R;
|
||||||
|
}
|
||||||
|
|
||||||
|
[StructLayout(LayoutKind.Explicit)]
|
||||||
|
private unsafe struct MFProblem
|
||||||
|
{
|
||||||
|
[FieldOffset(0)]
|
||||||
|
public int M;
|
||||||
|
[FieldOffset(4)]
|
||||||
|
public int N;
|
||||||
|
[FieldOffset(8)]
|
||||||
|
public long Nnz;
|
||||||
|
[FieldOffset(16)]
|
||||||
|
public MFNode* R;
|
||||||
|
}
|
||||||
|
|
||||||
|
[StructLayout(LayoutKind.Explicit)]
|
||||||
|
private struct MFParameter
|
||||||
|
{
|
||||||
|
[FieldOffset(0)]
|
||||||
|
public int K;
|
||||||
|
[FieldOffset(4)]
|
||||||
|
public int NrThreads;
|
||||||
|
[FieldOffset(8)]
|
||||||
|
public int NrBins;
|
||||||
|
[FieldOffset(12)]
|
||||||
|
public int NrIters;
|
||||||
|
[FieldOffset(16)]
|
||||||
|
public float Lambda;
|
||||||
|
[FieldOffset(20)]
|
||||||
|
public float Eta;
|
||||||
|
[FieldOffset(24)]
|
||||||
|
public int DoNmf;
|
||||||
|
[FieldOffset(28)]
|
||||||
|
public int Quiet;
|
||||||
|
[FieldOffset(32)]
|
||||||
|
public int CopyData;
|
||||||
|
}
|
||||||
|
|
||||||
|
[StructLayout(LayoutKind.Explicit)]
|
||||||
|
private unsafe struct MFModel
|
||||||
|
{
|
||||||
|
[FieldOffset(0)]
|
||||||
|
public int M;
|
||||||
|
[FieldOffset(4)]
|
||||||
|
public int N;
|
||||||
|
[FieldOffset(8)]
|
||||||
|
public int K;
|
||||||
|
[FieldOffset(16)]
|
||||||
|
public float* P;
|
||||||
|
[FieldOffset(24)]
|
||||||
|
public float* Q;
|
||||||
|
}
|
||||||
|
|
||||||
|
private const string DllPath = "MatrixFactorizationNative";
|
||||||
|
|
||||||
|
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
|
||||||
|
private static unsafe extern void MFDestroyModel(ref MFModel* model);
|
||||||
|
|
||||||
|
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
|
||||||
|
private static unsafe extern MFModel* MFTrain(MFProblem* prob, MFParameter* param);
|
||||||
|
|
||||||
|
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
|
||||||
|
private static unsafe extern MFModel* MFTrainWithValidation(MFProblem* tr, MFProblem* va, MFParameter* param);
|
||||||
|
|
||||||
|
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
|
||||||
|
private static unsafe extern float MFCrossValidation(MFProblem* prob, int nrFolds, MFParameter* param);
|
||||||
|
|
||||||
|
[DllImport(DllPath), SuppressUnmanagedCodeSecurity]
|
||||||
|
private static unsafe extern float MFPredict(MFModel* model, int pIdx, int qIdx);
|
||||||
|
|
||||||
|
private MFParameter _mfParam;
|
||||||
|
private unsafe MFModel* _pMFModel;
|
||||||
|
private readonly IHost _host;
|
||||||
|
|
||||||
|
public SafeTrainingAndModelBuffer(IHostEnvironment env, int k, int nrBins, int nrThreads, int nrIters, double lambda, double eta,
|
||||||
|
bool doNmf, bool quiet, bool copyData)
|
||||||
|
{
|
||||||
|
_host = env.Register("SafeTrainingAndModelBuffer");
|
||||||
|
_mfParam.K = k;
|
||||||
|
_mfParam.NrBins = nrBins;
|
||||||
|
_mfParam.NrThreads = nrThreads;
|
||||||
|
_mfParam.NrIters = nrIters;
|
||||||
|
_mfParam.Lambda = (float)lambda;
|
||||||
|
_mfParam.Eta = (float)eta;
|
||||||
|
_mfParam.DoNmf = doNmf ? 1 : 0;
|
||||||
|
_mfParam.Quiet = quiet ? 1 : 0;
|
||||||
|
_mfParam.CopyData = copyData ? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
~SafeTrainingAndModelBuffer()
|
||||||
|
{
|
||||||
|
Dispose(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Dispose()
|
||||||
|
{
|
||||||
|
Dispose(true);
|
||||||
|
GC.SuppressFinalize(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
private unsafe void Dispose(bool disposing)
|
||||||
|
{
|
||||||
|
// Free unmanaged resources.
|
||||||
|
if (_pMFModel != null)
|
||||||
|
{
|
||||||
|
MFDestroyModel(ref _pMFModel);
|
||||||
|
_host.Assert(_pMFModel == null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MFNode[] ConstructLabeledNodesFrom(IChannel ch, ICursor cursor, ValueGetter<float> labGetter,
|
||||||
|
ValueGetter<uint> rowGetter, ValueGetter<uint> colGetter,
|
||||||
|
int rowCount, int colCount)
|
||||||
|
{
|
||||||
|
long numSkipped = 0;
|
||||||
|
uint row = 0;
|
||||||
|
uint col = 0;
|
||||||
|
float label = 0;
|
||||||
|
|
||||||
|
List<MFNode> nodes = new List<MFNode>();
|
||||||
|
long i = 0;
|
||||||
|
using (var pch = _host.StartProgressChannel("Create matrix"))
|
||||||
|
{
|
||||||
|
pch.SetHeader(new ProgressHeader(new[] { "processed rows", "created nodes" }),
|
||||||
|
e => { e.SetProgress(0, i); e.SetProgress(1, nodes.Count); });
|
||||||
|
while (cursor.MoveNext())
|
||||||
|
{
|
||||||
|
i++;
|
||||||
|
labGetter(ref label);
|
||||||
|
if (!FloatUtils.IsFinite(label))
|
||||||
|
{
|
||||||
|
numSkipped++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
rowGetter(ref row);
|
||||||
|
// REVIEW: Instead of ignoring, should I throw in the row > rowCount case?
|
||||||
|
if (row == 0 || row > (uint)rowCount)
|
||||||
|
{
|
||||||
|
numSkipped++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
colGetter(ref col);
|
||||||
|
if (col == 0 || col > (uint)colCount)
|
||||||
|
{
|
||||||
|
numSkipped++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
MFNode node;
|
||||||
|
node.U = (int)(row - 1);
|
||||||
|
node.V = (int)(col - 1);
|
||||||
|
node.R = label;
|
||||||
|
nodes.Add(node);
|
||||||
|
}
|
||||||
|
pch.Checkpoint(i, nodes.Count);
|
||||||
|
}
|
||||||
|
if (numSkipped > 0)
|
||||||
|
ch.Warning("Skipped {0} instances with non-finite labels or missing/out-of-range row or column indices during data loading", numSkipped);
|
||||||
|
ch.Check(nodes.Count > 0, "No valid instances encountered during data loading");
|
||||||
|
|
||||||
|
return nodes.ToArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
public unsafe void Train(IChannel ch, int rowCount, int colCount,
|
||||||
|
ICursor cursor, ValueGetter<float> labGetter,
|
||||||
|
ValueGetter<uint> rowGetter, ValueGetter<uint> colGetter)
|
||||||
|
{
|
||||||
|
if (_pMFModel != null)
|
||||||
|
{
|
||||||
|
MFDestroyModel(ref _pMFModel);
|
||||||
|
_host.Assert(_pMFModel == null);
|
||||||
|
}
|
||||||
|
|
||||||
|
MFProblem prob = new MFProblem();
|
||||||
|
MFNode[] nodes = ConstructLabeledNodesFrom(ch, cursor, labGetter, rowGetter, colGetter, rowCount, colCount);
|
||||||
|
|
||||||
|
fixed (MFNode* nodesPtr = &nodes[0])
|
||||||
|
{
|
||||||
|
prob.R = nodesPtr;
|
||||||
|
prob.M = rowCount;
|
||||||
|
prob.N = colCount;
|
||||||
|
prob.Nnz = nodes.Length;
|
||||||
|
|
||||||
|
ch.Info("Training {0} by {1} problem on {2} examples",
|
||||||
|
prob.M, prob.N, prob.Nnz);
|
||||||
|
|
||||||
|
fixed (MFParameter* pParam = &_mfParam)
|
||||||
|
{
|
||||||
|
_pMFModel = MFTrain(&prob, pParam);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public unsafe void TrainWithValidation(IChannel ch, int rowCount, int colCount,
|
||||||
|
ICursor cursor, ValueGetter<float> labGetter,
|
||||||
|
ValueGetter<uint> rowGetter, ValueGetter<uint> colGetter,
|
||||||
|
ICursor validCursor, ValueGetter<float> validLabGetter,
|
||||||
|
ValueGetter<uint> validRowGetter, ValueGetter<uint> validColGetter)
|
||||||
|
{
|
||||||
|
if (_pMFModel != null)
|
||||||
|
{
|
||||||
|
MFDestroyModel(ref _pMFModel);
|
||||||
|
_host.Assert(_pMFModel == null);
|
||||||
|
}
|
||||||
|
|
||||||
|
MFNode[] nodes = ConstructLabeledNodesFrom(ch, cursor, labGetter, rowGetter, colGetter, rowCount, colCount);
|
||||||
|
MFNode[] validNodes = ConstructLabeledNodesFrom(ch, validCursor, validLabGetter, validRowGetter, validColGetter, rowCount, colCount);
|
||||||
|
MFProblem prob = new MFProblem();
|
||||||
|
MFProblem validProb = new MFProblem();
|
||||||
|
fixed (MFNode* nodesPtr = &nodes[0])
|
||||||
|
fixed (MFNode* validNodesPtrs = &validNodes[0])
|
||||||
|
{
|
||||||
|
prob.R = nodesPtr;
|
||||||
|
prob.M = rowCount;
|
||||||
|
prob.N = colCount;
|
||||||
|
prob.Nnz = nodes.Length;
|
||||||
|
|
||||||
|
validProb.R = validNodesPtrs;
|
||||||
|
validProb.M = rowCount;
|
||||||
|
validProb.N = colCount;
|
||||||
|
validProb.Nnz = validNodes.Length;
|
||||||
|
|
||||||
|
ch.Info("Training {0} by {1} problem on {2} examples with a {3} by {4} validation set including {5} examples",
|
||||||
|
prob.M, prob.N, prob.Nnz, validProb.M, validProb.N, validProb.Nnz);
|
||||||
|
|
||||||
|
fixed (MFParameter* pParam = &_mfParam)
|
||||||
|
{
|
||||||
|
_pMFModel = MFTrainWithValidation(&prob, &validProb, pParam);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public unsafe void Get(out int m, out int n, out int k, out float[] p, out float[] q)
|
||||||
|
{
|
||||||
|
_host.Check(_pMFModel != null, "Attempted to get predictor before training");
|
||||||
|
m = _pMFModel->M;
|
||||||
|
_host.Check(m > 0, "Number of rows should have been positive but was not");
|
||||||
|
n = _pMFModel->N;
|
||||||
|
_host.Check(n > 0, "Number of columns should have been positive but was not");
|
||||||
|
k = _pMFModel->K;
|
||||||
|
_host.Check(k > 0, "Internal dimension should have been positive but was not");
|
||||||
|
|
||||||
|
p = new float[m * k];
|
||||||
|
q = new float[n * k];
|
||||||
|
|
||||||
|
unsafe
|
||||||
|
{
|
||||||
|
Marshal.Copy((IntPtr)_pMFModel->P, p, 0, p.Length);
|
||||||
|
Marshal.Copy((IntPtr)_pMFModel->Q, q, 0, q.Length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
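Review note on ConstructLabeledNodesFrom: the loop treats row and column keys as one-based, with 0 meaning "missing", and shifts them to zero-based indices before handing them to the native trainer. The sketch below restates that filtering rule in C++ so it can be checked in isolation. The types `Node` and `RawExample` and the function `build_nodes` are hypothetical names made up for this illustration; they are not LIBMF or ML.NET types.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one observed matrix entry (not LIBMF's own type).
struct Node {
    int u;    // zero-based row index
    int v;    // zero-based column index
    float r;  // observed value
};

// Hypothetical raw example: keys are one-based, and 0 encodes "missing".
struct RawExample {
    std::uint32_t row_key;
    std::uint32_t col_key;
    float label;
};

// Converts raw examples into zero-based nodes and counts every skipped example.
std::vector<Node> build_nodes(const std::vector<RawExample>& raw,
                              int row_count, int col_count,
                              long& num_skipped) {
    assert(row_count > 0 && col_count > 0);
    std::vector<Node> nodes;
    num_skipped = 0;
    for (const RawExample& e : raw) {
        const bool bad_label = !std::isfinite(e.label);
        const bool bad_row = e.row_key == 0 ||
                             e.row_key > static_cast<std::uint32_t>(row_count);
        const bool bad_col = e.col_key == 0 ||
                             e.col_key > static_cast<std::uint32_t>(col_count);
        if (bad_label || bad_row || bad_col) {
            ++num_skipped;  // same policy as the diff: skip, do not throw
            continue;
        }
        nodes.push_back(Node{static_cast<int>(e.row_key - 1),
                             static_cast<int>(e.col_key - 1),
                             e.label});
    }
    return nodes;
}
```

With a 20 by 40 problem, a key pair (20, 40) maps to the last cell (19, 39), while (0, 3) and (2, 41) are both skipped.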
|
|
@ -181,6 +181,7 @@ endfunction()
|
||||||
add_subdirectory(CpuMathNative)
|
add_subdirectory(CpuMathNative)
|
||||||
add_subdirectory(FastTreeNative)
|
add_subdirectory(FastTreeNative)
|
||||||
add_subdirectory(LdaNative)
|
add_subdirectory(LdaNative)
|
||||||
|
add_subdirectory(MatrixFactorizationNative)
|
||||||
add_subdirectory(FactorizationMachineNative)
|
add_subdirectory(FactorizationMachineNative)
|
||||||
add_subdirectory(SymSgdNative)
|
add_subdirectory(SymSgdNative)
|
||||||
add_subdirectory(MklProxyNative)
|
add_subdirectory(MklProxyNative)
|
|
@ -0,0 +1,16 @@
project (MatrixFactorizationNative)

include_directories(libmf)

set(SOURCES
    UnmanagedMemory.cpp
    libmf/mf.cpp
)

if(NOT WIN32)
    list(APPEND SOURCES ${VERSION_FILE_PATH})
endif()

add_library(MatrixFactorizationNative SHARED ${SOURCES} ${RESOURCES})

install_library_and_symbols (MatrixFactorizationNative)
|
|
@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#include <stdlib.h>

#include "UnmanagedMemory.h"
#include "mf.h"

using namespace mf;

EXPORT_API(void) MFDestroyModel(mf_model *&model)
{
    return mf_destroy_model(&model);
}

EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param)
{
    return mf_train(prob, *param);
}

EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param)
{
    return mf_train_with_validation(tr, va, *param);
}


EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter *param)
{
    return mf_cross_validation(prob, nr_folds, *param);
}

EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx)
{
    return mf_predict(model, p_idx, q_idx);
}
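Review note on MFDestroyModel: the export takes the model pointer by reference (`mf_model *&`), and the managed caller declares the matching parameter as `ref MFModel*` and then asserts that its pointer is null. That only works if the native side resets the caller's pointer after freeing. The sketch below shows the general pattern with a hypothetical `Model` type and hypothetical `create_model` / `destroy_model` helpers; it is not LIBMF's implementation.

```cpp
#include <cassert>

// Hypothetical model type standing in for a natively allocated model.
struct Model {
    int k;
    float* factors;
};

// Allocates a model with k zero-initialized factors.
Model* create_model(int k) {
    assert(k > 0);
    Model* m = new Model;
    m->k = k;
    m->factors = new float[k]();
    return m;
}

// Takes the caller's pointer by reference so it can be reset after freeing.
// Calling it again with the now-null pointer is a harmless no-op.
void destroy_model(Model*& model) {
    if (model == nullptr) {
        return;
    }
    delete[] model->factors;
    delete model;
    model = nullptr;
}
```

Because the pointer is nulled in place, the `if (_pMFModel != null)` guards in Dispose and Train cannot free the same model twice.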
|
|
@ -0,0 +1,19 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#pragma once
#include "mf.h"
#include "../Stdafx.h"

using namespace mf;

EXPORT_API(void) MFDestroyModel(mf_model *&model);

EXPORT_API(mf_model*) MFTrain(const mf_problem *prob, const mf_parameter *param);

EXPORT_API(mf_model*) MFTrainWithValidation(const mf_problem *tr, const mf_problem *va, const mf_parameter *param);

EXPORT_API(float) MFCrossValidation(const mf_problem *prob, int nr_folds, const mf_parameter *param);

EXPORT_API(float) MFPredict(const mf_model *model, int p_idx, int q_idx);
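Review note on struct layout: the managed MFModel in this commit pins its fields at byte offsets 0, 4, 8, 16 and 24, so it silently assumes a 64-bit native layout. The sketch below checks that a struct of three 32-bit integers followed by two pointers lands on exactly those offsets when pointers are 8 bytes wide. `ModelView` and `matches_managed_offsets` are hypothetical names for this illustration; whether the real `mf_model` has an extra 4-byte field in the gap at offset 12 is not visible in this diff.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the five fields the managed MFModel declares.
struct ModelView {
    std::int32_t m;  // number of rows
    std::int32_t n;  // number of columns
    std::int32_t k;  // latent dimension
    float* p;        // row factors
    float* q;        // column factors
};

// True when the native layout matches the offsets hard-coded on the managed side.
constexpr bool matches_managed_offsets() {
    return offsetof(ModelView, m) == 0 &&
           offsetof(ModelView, n) == 4 &&
           offsetof(ModelView, k) == 8 &&
           offsetof(ModelView, p) == 16 &&
           offsetof(ModelView, q) == 24;
}

// Only meaningful where pointers are 8 bytes; on a 32-bit build the pointer
// fields would sit at different offsets and the managed struct would be wrong.
static_assert(sizeof(void*) != 8 || matches_managed_offsets(),
              "64-bit layout should match the managed FieldOffset values");
```

A compile-time check of this kind next to the exports would turn a layout drift in the submodule into a build error instead of a corrupted read.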
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 1ecc365249e5cac5e72c66317a141298dc52f6e3
|
|
@ -95,6 +95,8 @@
|
||||||
RelativePath="Microsoft.ML.HalLearners\runtimes\$(PackageRid)\native" />
|
RelativePath="Microsoft.ML.HalLearners\runtimes\$(PackageRid)\native" />
|
||||||
<NativePackageAsset Include="$(NativeAssetsBuiltPath)\$(NativeLibPrefix)MklProxyNative$(NativeLibExtension)"
|
<NativePackageAsset Include="$(NativeAssetsBuiltPath)\$(NativeLibPrefix)MklProxyNative$(NativeLibExtension)"
|
||||||
RelativePath="Microsoft.ML.Mkl.Redist\runtimes\$(PackageRid)\native" />
|
RelativePath="Microsoft.ML.Mkl.Redist\runtimes\$(PackageRid)\native" />
|
||||||
|
<NativePackageAsset Include="$(NativeAssetsBuiltPath)\$(NativeLibPrefix)MatrixFactorizationNative$(NativeLibExtension)"
|
||||||
|
RelativePath="Microsoft.ML.MatrixFactorization\runtimes\$(PackageRid)\native" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
|
|
@ -676,5 +676,13 @@ namespace Microsoft.ML.Runtime.RunTests
|
||||||
testFilename = @"..\V3\Data\OCR\train.tsv",
|
testFilename = @"..\V3\Data\OCR\train.tsv",
|
||||||
loaderSettings = "loader=Text{col=Label:U1[0-25]:1 col=GroupId:U4[1-*]:3 col=Features:Num:4-*}"
|
loaderSettings = "loader=Text{col=Label:U1[0-25]:1 col=GroupId:U4[1-*]:3 col=Features:Num:4-*}"
|
||||||
};
|
};
|
||||||
|
|
||||||
|
public static TestDataset trivialMatrixFactorization = new TestDataset()
|
||||||
|
{
|
||||||
|
name = "trivialMatrixFactorization",
|
||||||
|
trainFilename = @"trivial-train.tsv",
|
||||||
|
testFilename = @"trivial-test.tsv",
|
||||||
|
loaderSettings = "loader=Text{col=Label:R4:0 col=User:U4[0-19]:1 col=Item:U4[0-39]:2 header+}"
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,9 +5,9 @@
|
||||||
using Microsoft.ML.Runtime;
|
using Microsoft.ML.Runtime;
|
||||||
using Microsoft.ML.Runtime.Data;
|
using Microsoft.ML.Runtime.Data;
|
||||||
using Microsoft.ML.Runtime.Ensemble;
|
using Microsoft.ML.Runtime.Ensemble;
|
||||||
|
using Microsoft.ML.Runtime.Learners;
|
||||||
using Microsoft.ML.Trainers.FastTree;
|
using Microsoft.ML.Trainers.FastTree;
|
||||||
using Microsoft.ML.Trainers.KMeans;
|
using Microsoft.ML.Trainers.KMeans;
|
||||||
using Microsoft.ML.Runtime.Learners;
|
|
||||||
using Microsoft.ML.Trainers.PCA;
|
using Microsoft.ML.Trainers.PCA;
|
||||||
|
|
||||||
namespace Microsoft.ML.TestFramework
|
namespace Microsoft.ML.TestFramework
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.Parquet\Microsoft.ML.Parquet.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.Parquet\Microsoft.ML.Parquet.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
|
||||||
|
<ProjectReference Include="..\..\src\Microsoft.ML.Recommender\Microsoft.ML.Recommender.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.Legacy\Microsoft.ML.Legacy.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.Legacy\Microsoft.ML.Legacy.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
|
||||||
|
@ -24,5 +25,6 @@
|
||||||
<NativeAssemblyReference Include="CpuMathNative" />
|
<NativeAssemblyReference Include="CpuMathNative" />
|
||||||
<NativeAssemblyReference Include="MklProxyNative" />
|
<NativeAssemblyReference Include="MklProxyNative" />
|
||||||
<NativeAssemblyReference Include="FactorizationMachineNative" />
|
<NativeAssemblyReference Include="FactorizationMachineNative" />
|
||||||
|
<NativeAssemblyReference Include="MatrixFactorizationNative" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
</Project>
|
</Project>
|
|
@ -19,6 +19,7 @@
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
|
||||||
|
<ProjectReference Include="..\..\src\Microsoft.ML.Recommender\Microsoft.ML.Recommender.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
|
||||||
<ProjectReference Include="..\..\src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj" />
|
<ProjectReference Include="..\..\src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj" />
|
||||||
|
@ -32,6 +33,7 @@
|
||||||
<NativeAssemblyReference Include="CpuMathNative" />
|
<NativeAssemblyReference Include="CpuMathNative" />
|
||||||
<NativeAssemblyReference Include="FastTreeNative" />
|
<NativeAssemblyReference Include="FastTreeNative" />
|
||||||
<NativeAssemblyReference Include="FactorizationMachineNative" />
|
<NativeAssemblyReference Include="FactorizationMachineNative" />
|
||||||
|
<NativeAssemblyReference Include="MatrixFactorizationNative" />
|
||||||
<NativeAssemblyReference Include="LdaNative" />
|
<NativeAssemblyReference Include="LdaNative" />
|
||||||
<NativeAssemblyReference Include="SymSgdNative" />
|
<NativeAssemblyReference Include="SymSgdNative" />
|
||||||
<NativeAssemblyReference Include="MklProxyNative" />
|
<NativeAssemblyReference Include="MklProxyNative" />
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
// Licensed to the .NET Foundation under one or more agreements.
|
||||||
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using Microsoft.ML.Runtime.Data;
|
||||||
|
using Microsoft.ML.Runtime.RunTests;
|
||||||
|
using Microsoft.ML.Trainers;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.ML.Tests.TrainerEstimators
|
||||||
|
{
|
||||||
|
public partial class TrainerEstimators : TestDataPipeBase
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public void MatrixFactorization_Estimator()
|
||||||
|
{
|
||||||
|
string labelColumnName = "Label";
|
||||||
|
string matrixColumnIndexColumnName = "Col";
|
||||||
|
string matrixRowIndexColumnName = "Row";
|
||||||
|
|
||||||
|
// This data contains three columns, Label, Col, and Row where Col and Row will be treated as the expected input names
|
||||||
|
// of the trained matrix factorization model.
|
||||||
|
var data = new TextLoader(Env, GetLoaderArgs(labelColumnName, matrixColumnIndexColumnName, matrixRowIndexColumnName))
|
||||||
|
.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.trainFilename)));
|
||||||
|
|
||||||
|
// "invalidData" is incompatible with "data" because its columns are Label, ColRenamed, and RowRenamed (no column is named Col or Row).
|
||||||
|
var invalidData = new TextLoader(Env, GetLoaderArgs(labelColumnName, matrixColumnIndexColumnName + "Renamed", matrixRowIndexColumnName + "Renamed"))
|
||||||
|
.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.testFilename)));
|
||||||
|
|
||||||
|
var est = new MatrixFactorizationTrainer(Env, labelColumnName, matrixColumnIndexColumnName, matrixRowIndexColumnName,
|
||||||
|
advancedSettings: s =>
|
||||||
|
{
|
||||||
|
s.NumIterations = 3;
|
||||||
|
s.NumThreads = 1;
|
||||||
|
s.K = 4;
|
||||||
|
});
|
||||||
|
|
||||||
|
TestEstimatorCore(est, data, invalidInput: invalidData);
|
||||||
|
|
||||||
|
Done();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void MatrixFactorizationSimpleTrainAndPredict()
|
||||||
|
{
|
||||||
|
using (var env = new LocalEnvironment(seed: 1, conc: 1))
|
||||||
|
{
|
||||||
|
// Specific column names of the considered data set
|
||||||
|
string labelColumnName = "Label";
|
||||||
|
string userColumnName = "User";
|
||||||
|
string itemColumnName = "Item";
|
||||||
|
string scoreColumnName = "Score";
|
||||||
|
|
||||||
|
// Create a reader for both the training and test data sets
|
||||||
|
var reader = new TextLoader(env, GetLoaderArgs(labelColumnName, userColumnName, itemColumnName));
|
||||||
|
|
||||||
|
// Read training data as an IDataView object
|
||||||
|
var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.trainFilename)));
|
||||||
|
|
||||||
|
// Create a pipeline with a single operator.
|
||||||
|
var pipeline = new MatrixFactorizationTrainer(env, labelColumnName, userColumnName, itemColumnName,
|
||||||
|
advancedSettings: s =>
|
||||||
|
{
|
||||||
|
s.NumIterations = 3;
|
||||||
|
s.NumThreads = 1; // To eliminate randomness, # of threads must be 1.
|
||||||
|
s.K = 7;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Train a matrix factorization model.
|
||||||
|
var model = pipeline.Fit(data);
|
||||||
|
|
||||||
|
// Read the test data set as an IDataView
|
||||||
|
var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.testFilename)));
|
||||||
|
|
||||||
|
// Apply the trained model to the test set
|
||||||
|
var prediction = model.Transform(testData);
|
||||||
|
|
||||||
|
// Get output schema and check its column names
|
||||||
|
var outputSchema = model.GetOutputSchema(data.Schema);
|
||||||
|
var expectedOutputNames = new string[] { labelColumnName, userColumnName, itemColumnName, scoreColumnName };
|
||||||
|
foreach (var (i, col) in outputSchema.GetColumns())
|
||||||
|
Assert.True(col.Name == expectedOutputNames[i]);
|
||||||
|
|
||||||
|
// Retrieve label column's index from the test IDataView
|
||||||
|
testData.Schema.TryGetColumnIndex(labelColumnName, out int labelColumnId);
|
||||||
|
|
||||||
|
// Retrieve score column's index from the IDataView produced by the trained model
|
||||||
|
prediction.Schema.TryGetColumnIndex(scoreColumnName, out int scoreColumnId);
|
||||||
|
|
||||||
|
// Compute prediction errors
|
||||||
|
var mlContext = new MLContext();
|
||||||
|
var metrices = mlContext.Regression.Evaluate(prediction, label: labelColumnName, score: scoreColumnName);
|
||||||
|
|
||||||
|
// Determine if the selected metric is reasonable for different platforms
|
||||||
|
var expectedWindowsL2Error = 0.61528733643754685; // Windows baseline
|
||||||
|
var expectedMacL2Error = 0.61192207960271; // Mac baseline
|
||||||
|
var expectedLinuxL2Error = 0.616821448679879; // Linux baseline
|
||||||
|
double tolerance = System.Math.Pow(10, -DigitsOfPrecision);
|
||||||
|
bool inWindowsRange = expectedWindowsL2Error - tolerance < metrices.L2 && metrices.L2 < expectedWindowsL2Error + tolerance;
|
||||||
|
bool inMacRange = expectedMacL2Error - tolerance < metrices.L2 && metrices.L2 < expectedMacL2Error + tolerance;
|
||||||
|
bool inLinuxRange = expectedLinuxL2Error - tolerance < metrices.L2 && metrices.L2 < expectedLinuxL2Error + tolerance;
|
||||||
|
Assert.True(inWindowsRange || inMacRange || inLinuxRange);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private TextLoader.Arguments GetLoaderArgs(string labelColumnName, string matrixColumnIndexColumnName, string matrixRowIndexColumnName)
|
||||||
|
{
|
||||||
|
return new TextLoader.Arguments()
|
||||||
|
{
|
||||||
|
Separator = "\t",
|
||||||
|
HasHeader = true,
|
||||||
|
Column = new[]
|
||||||
|
{
|
||||||
|
new TextLoader.Column(labelColumnName, DataKind.R4, new [] { new TextLoader.Range(0) }),
|
||||||
|
new TextLoader.Column(matrixColumnIndexColumnName, DataKind.U4, new [] { new TextLoader.Range(1) }, new KeyRange(0, 19)),
|
||||||
|
new TextLoader.Column(matrixRowIndexColumnName, DataKind.U4, new [] { new TextLoader.Range(2) }, new KeyRange(0, 39)),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
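Review note on the baseline check: the test accepts the L2 error if it falls strictly inside a window of width 10^-DigitsOfPrecision around any of the three per-platform baselines. The sketch below expresses the same rule as a reusable predicate. The function names are hypothetical, and the precision is passed in as a parameter because the value of DigitsOfPrecision is not shown in this diff.

```cpp
#include <cassert>
#include <cmath>
#include <initializer_list>

// True when actual lies strictly within +/- tolerance of expected.
bool within_tolerance(double actual, double expected, double tolerance) {
    return std::fabs(actual - expected) < tolerance;
}

// True when actual matches at least one baseline to the given number of digits.
bool matches_any_baseline(double actual,
                          std::initializer_list<double> baselines,
                          int digits_of_precision) {
    assert(digits_of_precision >= 0);
    const double tolerance = std::pow(10.0, -digits_of_precision);
    for (double baseline : baselines) {
        if (within_tolerance(actual, baseline, tolerance)) {
            return true;
        }
    }
    return false;
}
```

Collapsing the three hand-written range checks into one helper keeps the per-platform baselines in a single list, which is easier to extend if another platform produces a different value.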
|
|
@ -0,0 +1,623 @@
|
||||||
|
# This is the same matrix A as in the trivial training set, except it is
|
||||||
|
# all of the entries that were dropped from the training set.
|
||||||
|
Label Row Column
|
||||||
|
1 0 0
|
||||||
|
1 0 1
|
||||||
|
1 0 3
|
||||||
|
1 0 4
|
||||||
|
1 0 5
|
||||||
|
1 0 6
|
||||||
|
1 0 7
|
||||||
|
1 0 8
|
||||||
|
2 0 10
|
||||||
|
2 0 11
|
||||||
|
2 0 12
|
||||||
|
2 0 13
|
||||||
|
2 0 14
|
||||||
|
2 0 15
|
||||||
|
2 0 16
|
||||||
|
2 0 17
|
||||||
|
2 0 18
|
||||||
|
2 0 19
|
||||||
|
2 0 20
|
||||||
|
2 0 21
|
||||||
|
2 0 22
|
||||||
|
2 0 23
|
||||||
|
2 0 24
|
||||||
|
2 0 25
|
||||||
|
2 0 26
|
||||||
|
2 0 27
|
||||||
|
2 0 28
|
||||||
|
2 0 29
|
||||||
|
2 0 31
|
||||||
|
2 0 32
|
||||||
|
2 0 33
|
||||||
|
2 0 34
|
||||||
|
2 0 35
|
||||||
|
2 0 36
|
||||||
|
2 0 37
|
||||||
|
2 0 39
|
||||||
|
1 1 0
|
||||||
|
1 1 1
|
||||||
|
1 1 3
|
||||||
|
1 1 5
|
||||||
|
1 1 6
|
||||||
|
1 1 7
|
||||||
|
1 1 9
|
||||||
|
2 1 10
|
||||||
|
2 1 12
|
||||||
|
2 1 13
|
||||||
|
2 1 15
|
||||||
|
2 1 18
|
||||||
|
2 1 20
|
||||||
|
2 1 22
|
||||||
|
2 1 24
|
||||||
|
2 1 25
|
||||||
|
2 1 26
|
||||||
|
2 1 28
|
||||||
|
2 1 29
|
||||||
|
2 1 30
|
||||||
|
2 1 33
|
||||||
|
2 1 34
|
||||||
|
2 1 35
|
||||||
|
2 1 37
|
||||||
|
2 1 39
|
||||||
|
1 2 0
|
||||||
|
1 2 2
|
||||||
|
1 2 3
|
||||||
|
1 2 4
|
||||||
|
1 2 5
|
||||||
|
1 2 6
|
||||||
|
1 2 7
|
||||||
|
1 2 8
|
||||||
|
1 2 9
|
||||||
|
2 2 10
|
||||||
|
2 2 12
|
||||||
|
2 2 13
|
||||||
|
2 2 14
|
||||||
|
2 2 15
|
||||||
|
2 2 16
|
||||||
|
2 2 17
|
||||||
|
2 2 18
|
||||||
|
2 2 19
|
||||||
|
2 2 21
|
||||||
|
2 2 23
|
||||||
|
2 2 24
|
||||||
|
2 2 25
|
||||||
|
2 2 26
|
||||||
|
2 2 28
|
||||||
|
2 2 29
|
||||||
|
2 2 30
|
||||||
|
2 2 31
|
||||||
|
2 2 32
|
||||||
|
2 2 33
|
||||||
|
2 2 34
|
||||||
|
2 2 35
|
||||||
|
2 2 36
|
||||||
|
2 2 37
|
||||||
|
2 2 38
|
||||||
|
2 2 39
|
||||||
|
1 3 0
|
||||||
|
1 3 1
|
||||||
|
1 3 3
|
||||||
|
1 3 4
|
||||||
|
1 3 6
|
||||||
|
1 3 7
|
||||||
|
1 3 8
|
||||||
|
1 3 9
|
||||||
|
2 3 10
|
||||||
|
2 3 11
|
||||||
|
2 3 12
|
||||||
|
2 3 13
|
||||||
|
2 3 15
|
||||||
|
2 3 16
|
||||||
|
2 3 17
|
||||||
|
2 3 19
|
||||||
|
2 3 20
|
||||||
|
2 3 21
|
||||||
|
2 3 22
|
||||||
|
2 3 23
|
||||||
|
2 3 24
|
||||||
|
2 3 26
|
||||||
|
2 3 27
|
||||||
|
2 3 29
|
||||||
|
2 3 30
|
||||||
|
2 3 31
|
||||||
|
2 3 32
|
||||||
|
2 3 35
|
||||||
|
2 3 36
|
||||||
|
2 3 37
|
||||||
|
2 3 38
|
||||||
|
1 4 0
|
||||||
|
1 4 1
|
||||||
|
1 4 2
|
||||||
|
1 4 3
|
||||||
|
1 4 4
|
||||||
|
1 4 5
|
||||||
|
1 4 6
|
||||||
|
1 4 8
|
||||||
|
1 4 9
|
||||||
|
2 4 10
|
||||||
|
2 4 11
|
||||||
|
2 4 12
|
||||||
|
2 4 13
|
||||||
|
2 4 14
|
||||||
|
2 4 15
|
||||||
|
2 4 16
|
||||||
|
2 4 17
|
||||||
|
2 4 18
|
||||||
|
2 4 19
|
||||||
|
2 4 20
|
||||||
|
2 4 21
|
||||||
|
2 4 22
|
||||||
|
2 4 24
|
||||||
|
2 4 25
|
||||||
|
2 4 26
|
||||||
|
2 4 27
|
||||||
|
2 4 28
|
||||||
|
2 4 29
|
||||||
|
2 4 30
|
||||||
|
2 4 31
|
||||||
|
2 4 32
|
||||||
|
2 4 33
|
||||||
|
2 4 34
|
||||||
|
2 4 36
|
||||||
|
2 4 37
|
||||||
|
2 4 38
|
||||||
|
2 4 39
|
||||||
|
1 5 0
|
||||||
|
1 5 2
|
||||||
|
1 5 3
|
||||||
|
1 5 5
|
||||||
|
1 5 6
|
||||||
|
1 5 7
|
||||||
|
1 5 8
|
||||||
|
1 5 9
|
||||||
|
2 5 10
|
||||||
|
2 5 12
|
||||||
|
2 5 13
|
||||||
|
2 5 14
|
||||||
|
2 5 15
|
||||||
|
2 5 16
|
||||||
|
2 5 17
|
||||||
|
2 5 18
|
||||||
|
2 5 22
|
||||||
|
2 5 23
|
||||||
|
2 5 24
|
||||||
|
2 5 25
|
||||||
|
2 5 26
|
||||||
|
2 5 28
|
||||||
|
2 5 29
|
||||||
|
2 5 30
|
||||||
|
2 5 31
|
||||||
|
2 5 32
|
||||||
|
2 5 33
|
||||||
|
2 5 34
|
||||||
|
2 5 35
|
||||||
|
2 5 36
|
||||||
|
2 5 38
|
||||||
|
1 6 0
|
||||||
|
1 6 1
|
||||||
|
1 6 2
|
||||||
|
1 6 4
|
||||||
|
1 6 5
|
||||||
|
1 6 6
|
||||||
|
1 6 7
|
||||||
|
1 6 8
|
||||||
|
1 6 9
|
||||||
|
2 6 12
|
||||||
|
2 6 13
|
||||||
|
2 6 14
|
||||||
|
2 6 16
|
||||||
|
2 6 17
|
||||||
|
2 6 18
|
||||||
|
2 6 22
|
||||||
|
2 6 23
|
||||||
|
2 6 24
|
||||||
|
2 6 25
|
||||||
|
2 6 27
|
||||||
|
2 6 28
|
||||||
|
2 6 29
|
||||||
|
2 6 31
|
||||||
|
2 6 32
|
||||||
|
2 6 33
|
||||||
|
2 6 34
|
||||||
|
2 6 35
|
||||||
|
2 6 36
|
||||||
|
2 6 37
|
||||||
|
2 6 38
|
||||||
|
2 6 39
|
||||||
|
1 7 0
|
||||||
|
1 7 1
|
||||||
|
1 7 2
|
||||||
|
1 7 4
|
||||||
|
1 7 9
|
||||||
|
2 7 10
|
||||||
|
2 7 11
|
||||||
|
2 7 13
|
||||||
|
2 7 14
|
||||||
|
2 7 16
|
||||||
|
2 7 17
|
||||||
|
2 7 18
|
||||||
|
2 7 19
|
||||||
|
2 7 20
|
||||||
|
2 7 21
|
||||||
|
2 7 22
|
||||||
|
2 7 23
|
||||||
|
2 7 24
|
||||||
|
2 7 25
|
||||||
|
2 7 26
|
||||||
|
2 7 27
|
||||||
|
2 7 28
|
||||||
|
2 7 29
|
||||||
|
2 7 30
|
||||||
|
2 7 31
|
||||||
|
2 7 32
|
||||||
|
2 7 34
|
||||||
|
2 7 36
|
||||||
|
2 7 37
|
||||||
|
1 8 2
|
||||||
|
1 8 3
|
||||||
|
1 8 4
|
||||||
|
1 8 5
|
||||||
|
1 8 6
|
||||||
|
1 8 7
|
||||||
|
1 8 8
|
||||||
|
1 8 9
|
||||||
|
2 8 12
|
||||||
|
2 8 13
|
||||||
|
2 8 14
|
||||||
|
2 8 15
|
||||||
|
2 8 16
|
||||||
|
2 8 17
|
||||||
|
2 8 18
|
||||||
|
2 8 20
|
||||||
|
2 8 21
|
||||||
|
2 8 22
|
||||||
|
2 8 23
|
||||||
|
2 8 24
|
||||||
|
2 8 25
|
||||||
|
2 8 26
|
||||||
|
2 8 27
|
||||||
|
2 8 29
|
||||||
|
2 8 31
|
||||||
|
2 8 33
|
||||||
|
2 8 34
|
||||||
|
2 8 35
|
||||||
|
2 8 37
|
||||||
|
2 8 38
|
||||||
|
1 9 0
|
||||||
|
1 9 2
|
||||||
|
1 9 4
|
||||||
|
1 9 6
|
||||||
|
1 9 8
|
||||||
|
1 9 9
|
||||||
|
2 9 10
|
||||||
|
2 9 11
|
||||||
|
2 9 13
|
||||||
|
2 9 14
|
||||||
|
2 9 15
|
||||||
|
2 9 16
|
||||||
|
2 9 17
|
||||||
|
2 9 18
|
||||||
|
2 9 19
|
||||||
|
2 9 20
|
||||||
|
2 9 21
|
||||||
|
2 9 22
|
||||||
|
2 9 23
|
||||||
|
2 9 25
|
||||||
|
2 9 26
|
||||||
|
2 9 27
|
||||||
|
2 9 28
|
||||||
|
2 9 29
|
||||||
|
2 9 31
|
||||||
|
2 9 32
|
||||||
|
2 9 33
|
||||||
|
2 9 34
|
||||||
|
2 9 36
|
||||||
|
2 9 37
|
||||||
|
2 9 38
|
||||||
|
2 9 39
|
||||||
|
3 10 0
|
||||||
|
3 10 3
|
||||||
|
3 10 5
|
||||||
|
3 10 8
|
||||||
|
3 10 9
|
||||||
|
1 10 11
|
||||||
|
1 10 12
|
||||||
|
1 10 13
|
||||||
|
1 10 14
|
||||||
|
1 10 15
|
||||||
|
1 10 16
|
||||||
|
1 10 17
|
||||||
|
1 10 18
|
||||||
|
1 10 19
|
||||||
|
1 10 21
|
||||||
|
1 10 23
|
||||||
|
1 10 24
|
||||||
|
1 10 25
|
||||||
|
1 10 27
|
||||||
|
1 10 28
|
||||||
|
1 10 29
|
||||||
|
1 10 31
|
||||||
|
1 10 32
|
||||||
|
1 10 33
|
||||||
|
1 10 35
|
||||||
|
1 10 38
|
||||||
|
1 10 39
|
||||||
|
3 11 1
|
||||||
|
3 11 2
|
||||||
|
3 11 3
|
||||||
|
3 11 5
|
||||||
|
3 11 6
|
||||||
|
3 11 7
|
||||||
|
3 11 8
|
||||||
|
3 11 9
|
||||||
|
1 11 10
|
||||||
|
1 11 13
|
||||||
|
1 11 14
|
||||||
|
1 11 15
|
||||||
|
1 11 17
|
||||||
|
1 11 18
|
||||||
|
1 11 19
|
||||||
|
1 11 20
|
||||||
|
1 11 21
|
||||||
|
1 11 22
|
||||||
|
1 11 23
|
||||||
|
1 11 25
|
||||||
|
1 11 26
|
||||||
|
1 11 27
|
||||||
|
1 11 28
|
||||||
|
1 11 29
|
||||||
|
1 11 30
|
||||||
|
1 11 31
|
||||||
|
1 11 32
|
||||||
|
1 11 34
|
||||||
|
1 11 35
|
||||||
|
1 11 36
|
||||||
|
1 11 37
|
||||||
|
1 11 38
|
||||||
|
3 12 0
|
||||||
|
3 12 1
|
||||||
|
3 12 2
|
||||||
|
3 12 3
|
||||||
|
3 12 5
|
||||||
|
3 12 9
|
||||||
|
1 12 11
|
||||||
|
1 12 12
|
||||||
|
1 12 14
|
||||||
|
1 12 16
|
||||||
|
1 12 17
|
||||||
|
1 12 18
|
||||||
|
1 12 19
|
||||||
|
1 12 20
|
||||||
|
1 12 21
|
||||||
|
1 12 23
|
||||||
|
1 12 24
|
||||||
|
1 12 25
|
||||||
|
1 12 27
|
||||||
|
1 12 29
|
||||||
|
1 12 31
|
||||||
|
1 12 32
|
||||||
|
1 12 34
|
||||||
|
1 12 35
|
||||||
|
1 12 36
|
||||||
|
1 12 37
|
||||||
|
1 12 38
|
||||||
|
1 12 39
|
||||||
|
3 13 0
|
||||||
|
3 13 1
|
||||||
|
3 13 2
|
||||||
|
3 13 3
|
||||||
|
3 13 4
|
||||||
|
3 13 5
|
||||||
|
3 13 6
|
||||||
|
3 13 7
|
||||||
|
1 13 14
|
||||||
|
1 13 15
|
||||||
|
1 13 16
|
||||||
|
1 13 17
|
||||||
|
1 13 18
|
||||||
|
1 13 19
|
||||||
|
1 13 21
|
||||||
|
1 13 22
|
||||||
|
1 13 23
|
||||||
|
1 13 24
|
||||||
|
1 13 25
|
||||||
|
1 13 27
|
||||||
|
1 13 28
|
||||||
|
1 13 29
|
||||||
|
1 13 30
|
||||||
|
1 13 31
|
||||||
|
1 13 32
|
||||||
|
1 13 33
|
||||||
|
1 13 34
|
||||||
|
1 13 35
|
||||||
|
1 13 36
|
||||||
|
1 13 37
|
||||||
|
1 13 38
|
||||||
|
1 13 39
|
||||||
|
3 14 0
|
||||||
|
3 14 1
|
||||||
|
3 14 2
|
||||||
|
3 14 3
|
||||||
|
3 14 5
|
||||||
|
3 14 6
|
||||||
|
3 14 7
|
||||||
|
3 14 8
|
||||||
|
3 14 9
|
||||||
|
1 14 10
|
||||||
|
1 14 11
|
||||||
|
1 14 12
|
||||||
|
1 14 13
|
||||||
|
1 14 14
|
||||||
|
1 14 15
|
||||||
|
1 14 16
|
||||||
|
1 14 17
|
||||||
|
1 14 18
|
||||||
|
1 14 19
|
||||||
|
1 14 21
|
||||||
|
1 14 22
|
||||||
|
1 14 23
|
||||||
|
1 14 24
|
||||||
|
1 14 25
|
||||||
|
1 14 27
|
||||||
|
1 14 28
|
||||||
|
1 14 29
|
||||||
|
1 14 30
|
||||||
|
1 14 31
|
||||||
|
1 14 32
|
||||||
|
1 14 33
|
||||||
|
1 14 34
|
||||||
|
1 14 35
|
||||||
|
1 14 36
|
||||||
|
1 14 37
|
||||||
|
1 14 38
|
||||||
|
1 14 39
|
||||||
|
3 15 0
|
||||||
|
3 15 1
|
||||||
|
3 15 3
|
||||||
|
3 15 4
|
||||||
|
3 15 5
|
||||||
|
3 15 6
|
||||||
|
3 15 7
|
||||||
|
3 15 8
|
||||||
|
1 15 11
|
||||||
|
1 15 12
|
||||||
|
1 15 13
|
||||||
|
1 15 15
|
||||||
|
1 15 18
|
||||||
|
1 15 19
|
||||||
|
1 15 20
|
||||||
|
1 15 22
|
||||||
|
1 15 24
|
||||||
|
1 15 25
|
||||||
|
1 15 26
|
||||||
|
1 15 27
|
||||||
|
1 15 28
|
||||||
|
1 15 29
|
||||||
|
1 15 30
|
||||||
|
1 15 32
|
||||||
|
1 15 33
|
||||||
|
1 15 34
|
||||||
|
1 15 35
|
||||||
|
1 15 37
|
||||||
|
1 15 38
|
||||||
|
1 15 39
|
||||||
|
3 16 2
|
||||||
|
3 16 4
|
||||||
|
3 16 6
|
||||||
|
3 16 7
|
||||||
|
3 16 8
|
||||||
|
3 16 9
|
||||||
|
1 16 10
|
||||||
|
1 16 11
|
||||||
|
1 16 12
|
||||||
|
1 16 13
|
||||||
|
1 16 14
|
||||||
|
1 16 15
|
||||||
|
1 16 16
|
||||||
|
1 16 17
|
||||||
|
1 16 18
|
||||||
|
1 16 19
|
||||||
|
1 16 22
|
||||||
|
1 16 23
|
||||||
|
1 16 24
|
||||||
|
1 16 25
|
||||||
|
1 16 27
|
||||||
|
1 16 28
|
||||||
|
1 16 29
|
||||||
|
1 16 31
|
||||||
|
1 16 32
|
||||||
|
1 16 33
|
||||||
|
1 16 34
|
||||||
|
1 16 35
|
||||||
|
1 16 37
|
||||||
|
1 16 39
|
||||||
|
3 17 0
|
||||||
|
3 17 2
|
||||||
|
3 17 3
|
||||||
|
3 17 5
|
||||||
|
3 17 6
|
||||||
|
3 17 7
|
||||||
|
3 17 8
|
||||||
|
3 17 9
|
||||||
|
1 17 11
|
||||||
|
1 17 12
|
||||||
|
1 17 13
|
||||||
|
1 17 14
|
||||||
|
1 17 15
|
||||||
|
1 17 18
|
||||||
|
1 17 19
|
||||||
|
1 17 20
|
||||||
|
1 17 22
|
||||||
|
1 17 23
|
||||||
|
1 17 24
|
||||||
|
1 17 26
|
||||||
|
1 17 27
|
||||||
|
1 17 28
|
||||||
|
1 17 31
|
||||||
|
1 17 32
|
||||||
|
1 17 34
|
||||||
|
1 17 35
|
||||||
|
1 17 36
|
||||||
|
1 17 38
|
||||||
|
1 17 39
|
||||||
|
3 18 0
|
||||||
|
3 18 1
|
||||||
|
3 18 2
|
||||||
|
3 18 3
|
||||||
|
3 18 4
|
||||||
|
3 18 7
|
||||||
|
3 18 8
|
||||||
|
3 18 9
|
||||||
|
1 18 11
|
||||||
|
1 18 12
|
||||||
|
1 18 13
|
||||||
|
1 18 14
|
||||||
|
1 18 15
|
||||||
|
1 18 16
|
||||||
|
1 18 17
|
||||||
|
1 18 18
|
||||||
|
1 18 21
|
||||||
|
1 18 22
|
||||||
|
1 18 24
|
||||||
|
1 18 27
|
||||||
|
1 18 29
|
||||||
|
1 18 31
|
||||||
|
1 18 32
|
||||||
|
1 18 33
|
||||||
|
1 18 34
|
||||||
|
1 18 35
|
||||||
|
1 18 38
|
||||||
|
1 18 39
|
||||||
|
3 19 0
|
||||||
|
3 19 1
|
||||||
|
3 19 5
|
||||||
|
3 19 6
|
||||||
|
3 19 7
|
||||||
|
3 19 8
|
||||||
|
3 19 9
|
||||||
|
1 19 10
|
||||||
|
1 19 11
|
||||||
|
1 19 12
|
||||||
|
1 19 13
|
||||||
|
1 19 14
|
||||||
|
1 19 15
|
||||||
|
1 19 16
|
||||||
|
1 19 17
|
||||||
|
1 19 18
|
||||||
|
1 19 19
|
||||||
|
1 19 20
|
||||||
|
1 19 21
|
||||||
|
1 19 22
|
||||||
|
1 19 23
|
||||||
|
1 19 25
|
||||||
|
1 19 27
|
||||||
|
1 19 29
|
||||||
|
1 19 33
|
||||||
|
1 19 34
|
||||||
|
1 19 35
|
||||||
|
1 19 36
|
||||||
|
1 19 37
|
||||||
|
1 19 39
|
This file cannot be displayed because it has an incorrect number of fields in line 3.
|
|
@ -0,0 +1,187 @@
|
||||||
|
# The idea here is that this is a 20 x 40 block matrix A, where:
|
||||||
|
# A[ 0:10, 0:10] is 1
|
||||||
|
# A[ 0:10, 10:40] is 2
|
||||||
|
# A[10:20,  0:10] is 3
|
||||||
|
# A[10:20, 10:40] is 1
|
||||||
|
# In this training file each entry has a one fourth chance of getting dropped.
|
||||||
|
Label Row Column
|
||||||
|
1 14 20
|
||||||
|
1 19 26
|
||||||
|
3 17 4
|
||||||
|
1 10 20
|
||||||
|
1 3 5
|
||||||
|
1 7 5
|
||||||
|
1 18 36
|
||||||
|
2 1 36
|
||||||
|
2 1 38
|
||||||
|
3 17 1
|
||||||
|
2 6 26
|
||||||
|
2 9 30
|
||||||
|
3 13 8
|
||||||
|
2 7 33
|
||||||
|
2 8 30
|
||||||
|
3 10 1
|
||||||
|
1 18 25
|
||||||
|
1 13 12
|
||||||
|
1 3 2
|
||||||
|
2 8 28
|
||||||
|
1 11 24
|
||||||
|
2 3 28
|
||||||
|
2 1 16
|
||||||
|
1 9 7
|
||||||
|
1 15 16
|
||||||
|
3 19 4
|
||||||
|
1 1 8
|
||||||
|
1 8 0
|
||||||
|
1 10 34
|
||||||
|
1 18 37
|
||||||
|
2 1 17
|
||||||
|
2 8 39
|
||||||
|
1 17 30
|
||||||
|
2 1 27
|
||||||
|
2 0 38
|
||||||
|
1 11 16
|
||||||
|
3 19 3
|
||||||
|
1 7 8
|
||||||
|
1 13 13
|
||||||
|
1 19 31
|
||||||
|
3 16 1
|
||||||
|
1 5 1
|
||||||
|
2 6 11
|
||||||
|
1 9 5
|
||||||
|
3 10 6
|
||||||
|
1 1 2
|
||||||
|
2 6 30
|
||||||
|
2 7 15
|
||||||
|
1 17 21
|
||||||
|
1 18 23
|
||||||
|
3 10 7
|
||||||
|
2 5 39
|
||||||
|
2 2 27
|
||||||
|
3 12 6
|
||||||
|
3 11 4
|
||||||
|
1 9 3
|
||||||
|
1 12 22
|
||||||
|
2 8 19
|
||||||
|
2 1 14
|
||||||
|
1 11 11
|
||||||
|
1 10 36
|
||||||
|
3 12 4
|
||||||
|
1 15 21
|
||||||
|
1 17 37
|
||||||
|
1 6 3
|
||||||
|
2 3 18
|
||||||
|
1 10 10
|
||||||
|
1 11 33
|
||||||
|
1 18 19
|
||||||
|
2 7 35
|
||||||
|
3 10 2
|
||||||
|
1 12 30
|
||||||
|
1 12 26
|
||||||
|
2 1 31
|
||||||
|
2 5 21
|
||||||
|
2 1 11
|
||||||
|
1 7 3
|
||||||
|
2 8 36
|
||||||
|
3 10 4
|
||||||
|
1 18 26
|
||||||
|
2 8 10
|
||||||
|
1 10 22
|
||||||
|
1 15 14
|
||||||
|
3 16 0
|
||||||
|
2 0 30
|
||||||
|
2 3 34
|
||||||
|
3 13 9
|
||||||
|
1 0 2
|
||||||
|
1 15 36
|
||||||
|
1 15 23
|
||||||
|
1 10 30
|
||||||
|
2 6 20
|
||||||
|
2 9 24
|
||||||
|
2 9 35
|
||||||
|
1 7 6
|
||||||
|
2 7 39
|
||||||
|
2 5 20
|
||||||
|
3 12 8
|
||||||
|
2 9 12
|
||||||
|
1 17 25
|
||||||
|
1 12 33
|
||||||
|
2 6 19
|
||||||
|
1 17 10
|
||||||
|
2 4 35
|
||||||
|
1 15 31
|
||||||
|
3 12 7
|
||||||
|
1 17 16
|
||||||
|
2 1 19
|
||||||
|
2 3 25
|
||||||
|
1 16 30
|
||||||
|
1 19 30
|
||||||
|
1 5 4
|
||||||
|
2 6 10
|
||||||
|
1 18 20
|
||||||
|
1 13 26
|
||||||
|
2 3 39
|
||||||
|
2 2 20
|
||||||
|
1 4 7
|
||||||
|
2 3 33
|
||||||
|
1 16 20
|
||||||
|
2 1 21
|
||||||
|
3 15 2
|
||||||
|
3 19 2
|
||||||
|
1 12 10
|
||||||
|
2 5 37
|
||||||
|
2 1 32
|
||||||
|
3 18 6
|
||||||
|
1 2 1
|
||||||
|
1 16 21
|
||||||
|
2 1 23
|
||||||
|
1 17 33
|
||||||
|
2 5 11
|
||||||
|
2 3 14
|
||||||
|
1 11 12
|
||||||
|
1 13 20
|
||||||
|
1 19 38
|
||||||
|
1 15 10
|
||||||
|
2 8 11
|
||||||
|
3 11 0
|
||||||
|
1 18 10
|
||||||
|
1 19 24
|
||||||
|
1 13 11
|
||||||
|
2 4 23
|
||||||
|
1 16 26
|
||||||
|
1 7 7
|
||||||
|
1 17 29
|
||||||
|
1 18 30
|
||||||
|
1 13 10
|
||||||
|
2 6 21
|
||||||
|
1 19 32
|
||||||
|
2 7 12
|
||||||
|
1 12 28
|
||||||
|
2 2 11
|
||||||
|
1 12 15
|
||||||
|
2 8 32
|
||||||
|
3 15 9
|
||||||
|
3 16 5
|
||||||
|
1 9 1
|
||||||
|
1 19 28
|
||||||
|
3 16 3
|
||||||
|
1 15 17
|
||||||
|
2 7 38
|
||||||
|
1 16 38
|
||||||
|
1 14 26
|
||||||
|
1 10 26
|
||||||
|
1 10 37
|
||||||
|
3 18 5
|
||||||
|
2 5 27
|
||||||
|
2 2 22
|
||||||
|
1 11 39
|
||||||
|
1 16 36
|
||||||
|
1 0 9
|
||||||
|
2 5 19
|
||||||
|
1 18 28
|
||||||
|
1 12 13
|
||||||
|
1 17 17
|
||||||
|
1 8 1
|
||||||
|
2 6 15
|
||||||
|
3 14 4
|
||||||
|
1 1 4
|
This file cannot be displayed because it has an incorrect number of fields in line 7.
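The comment header of the data file above describes a 20 x 40 block matrix whose entries are each dropped with probability one fourth. A minimal sketch of how such data could be generated is given below. It is not the script used for this commit: the function names, the seed, and the choice to return the dropped entries as a second list are all assumptions made for illustration. Judging only from the rows visible here, the shuffled entries in this file appear to be the ones missing from the preceding file, which is why the sketch keeps both sets.

```python
import random


def block_value(row, col):
    """Value of A[row, col] for the 20 x 40 block matrix in the file header."""
    if row < 10:
        return 1 if col < 10 else 2
    return 3 if col < 10 else 1


def split_entries(drop_prob=0.25, seed=0, n_rows=20, n_cols=40):
    """Split all matrix entries into (kept, dropped) lists of (label, row, col).

    Each entry is dropped independently with probability drop_prob.
    The seed and the two-list return value are hypothetical choices.
    """
    rng = random.Random(seed)
    kept, dropped = [], []
    for r in range(n_rows):
        for c in range(n_cols):
            entry = (block_value(r, c), r, c)
            (dropped if rng.random() < drop_prob else kept).append(entry)
    rng.shuffle(dropped)  # the second file lists its rows in no particular order
    return kept, dropped


def format_rows(rows):
    """Render rows as tab-separated text with the header used in the data files."""
    lines = ["Label\tRow\tColumn"]
    lines += [f"{label}\t{r}\t{c}" for label, r, c in rows]
    return "\n".join(lines)
```

Calling `format_rows(kept)` and `format_rows(dropped)` would give two tab-separated listings in the same `Label Row Column` layout as the files in this diff, although the exact rows depend on the random seed.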
|