Merge pull request #483 from hebinhuang/ManagedRioSock

Add RIOSocketWrapper and SaeaSocketWrapper
This commit is contained in:
Hebin Huang 2016-07-10 07:32:20 -07:00 коммит произвёл GitHub
Родитель e75eb189e0 b6ec24119e
Коммит ffe223c2f9
52 изменённых файлов: 7175 добавлений и 754 удалений

5
.gitignore поставляемый
Просмотреть файл

@ -4,6 +4,7 @@
*.user
*.csproj.user
*.iml
*.VC.*
# Compiled source #
###################
@ -27,6 +28,10 @@
**/packages/
**/target/
**/.idea/
**/x64/
**/debug/
**/release/
**/TestResults/
scala/dependency-reduced-pom.xml
build/runtime/
build/tools/

Просмотреть файл

@ -106,6 +106,29 @@ if EXIST "%CMDHOME%\lib" (
)
)
:buildCpp
Set CppSkipped=0
@echo Assemble Mobius C++ components (RIOSOCK)
pushd %CMDHOME%\..\cpp
call Clean.cmd
call Build.cmd
if %ERRORLEVEL% EQU 2 (
set CppSkipped=1
popd
goto :buildCSharp
)
if %ERRORLEVEL% NEQ 0 (
@echo Build Mobius C++ RIOSOCK failed, stop building.
popd
goto :eof
)
@echo Mobius C++ RIOSOCK binaries
copy /y x64\Release\*.dll "%SPARKCLR_HOME%\bin\"
copy /y x64\Release\*.pdb "%SPARKCLR_HOME%\bin\"
popd
:buildCSharp
@echo Assemble Mobius C# components
pushd "%CMDHOME%\..\csharp"
@ -215,3 +238,16 @@ copy /Y "%CMDHOME%\..\notes\mobius-release-info.md"
:distdone
popd
if %CppSkipped% EQU 1 (
@echo.
@echo ============================================================================================
@echo.
@echo Note!!! Skipped to build Mobius C++ components due to missing VC++ Build Toolset.
@echo If you want to compile C++ components, please enalble VC++ language from
@echo Visual Studio. You can either download "Visual C++ Build Tools" availabe at
@echo "http://landinghub.visualstudio.com/visual-cpp-build-tools"
@echo.
@echo ============================================================================================
@echo.
)

83
cpp/Build.cmd Normal file
Просмотреть файл

@ -0,0 +1,83 @@
@setlocal
@ECHO off
SET CMDHOME=%~dp0
@REM Remove trailing backslash \
set CMDHOME=%CMDHOME:~0,-1%
set PROJ_NAME=Riosock
set PROJ=%CMDHOME%\%PROJ_NAME%.sln
@REM Set msbuild location.
SET VisualStudioVersion=12.0
if EXIST "%VS140COMNTOOLS%" SET VisualStudioVersion=14.0
SET VCBuildTool="%VS120COMNTOOLS:~0,-14%VC\bin\cl.exe"
if EXIST "%VS140COMNTOOLS%" SET VCBuildTool="%VS140COMNTOOLS:~0,-14%VC\bin\cl.exe"
if NOT EXIST %VCBuildTool% GOTO :ErrorNoCLEXE
SET MSBUILDEXEDIR=%programfiles(x86)%\MSBuild\%VisualStudioVersion%\Bin
if NOT EXIST "%MSBUILDEXEDIR%\." SET MSBUILDEXEDIR=%programfiles%\MSBuild\%VisualStudioVersion%\Bin
if NOT EXIST "%MSBUILDEXEDIR%\." GOTO :ErrorMSBUILD
SET MSBUILDEXE=%MSBUILDEXEDIR%\MSBuild.exe
SET MSBUILDOPT=/verbosity:minimal
if "%builduri%" == "" set builduri=Build.cmd
cd "%CMDHOME%"
@cd
@echo ===== Building %PROJ% =====
@echo Build Debug ==============================
SET STEP=Debug
SET CONFIGURATION=%STEP%
SET STEP=%CONFIGURATION%
"%MSBUILDEXE%" /p:Configuration=%CONFIGURATION% %MSBUILDOPT% "%PROJ%"
@if ERRORLEVEL 1 GOTO :ErrorStop
@echo BUILD ok for %CONFIGURATION% %PROJ%
@echo Build Release ============================
SET STEP=Release
SET CONFIGURATION=%STEP%
"%MSBUILDEXE%" /p:Configuration=%CONFIGURATION% %MSBUILDOPT% "%PROJ%"
@if ERRORLEVEL 1 GOTO :ErrorStop
@echo BUILD ok for %CONFIGURATION% %PROJ%
if EXIST %PROJ_NAME%.nuspec (
@echo ===== Build NuGet package for %PROJ% =====
SET STEP=NuGet-Pack
powershell -f %CMDHOME%\..\build\localmode\nugetpack.ps1
@if ERRORLEVEL 1 GOTO :ErrorStop
@echo NuGet package ok for %PROJ%
)
@echo ===== Build succeeded for %PROJ% =====
@GOTO :EOF
:ErrorNoCLEXE
set RC=2
@echo ===== WARNING: Build skipped due to missing VC++ Build Toolset. =====
@echo ===== Build SKIPPED for %PROJ% =====
exit /B %RC%
:ErrorMSBUILD
set RC=1
@echo ===== Build FAILED due to missing MSBUILD.EXE. =====
@echo ===== Mobius requires "Developer Command Prompt for VS2013" and above =====
exit /B %RC%
:ErrorStop
set RC=%ERRORLEVEL%
if "%STEP%" == "" set STEP=%CONFIGURATION%
@echo ===== Build FAILED for %PROJ% -- %STEP% with error %RC% - CANNOT CONTINUE =====
exit /B %RC%
:EOF

4
cpp/Clean.cmd Normal file
Просмотреть файл

@ -0,0 +1,4 @@
@ECHO OFF
FOR /D /R . %%G IN (bin) DO @IF EXIST "%%G" (@echo RDMR /S /Q "%%G" & rd /s /q "%%G")
FOR /D /R . %%G IN (obj) DO @IF EXIST "%%G" (@echo RDMR /S /Q "%%G" & rd /s /q "%%G")
FOR /D /R . %%G IN (x64) DO @IF EXIST "%%G" (@echo RDMR /S /Q "%%G" & rd /s /q "%%G")

22
cpp/Riosock.sln Normal file
Просмотреть файл

@ -0,0 +1,22 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 2013
VisualStudioVersion = 12.0.30501.0
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Riosock", "Riosock\Riosock.vcxproj", "{9572BB3A-F6C5-4E02-BEB7-282E53F5E75C}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|x64 = Debug|x64
Release|x64 = Release|x64
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{9572BB3A-F6C5-4E02-BEB7-282E53F5E75C}.Debug|x64.ActiveCfg = Debug|x64
{9572BB3A-F6C5-4E02-BEB7-282E53F5E75C}.Debug|x64.Build.0 = Debug|x64
{9572BB3A-F6C5-4E02-BEB7-282E53F5E75C}.Release|x64.ActiveCfg = Release|x64
{9572BB3A-F6C5-4E02-BEB7-282E53F5E75C}.Release|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
EndGlobal

109
cpp/Riosock/Locks.h Normal file
Просмотреть файл

@ -0,0 +1,109 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#pragma once
#include <windows.h>
class PrioritizedLock {
public:
PrioritizedLock() throw()
{
InitializeSRWLock(&srwlock);
InitializeCriticalSectionEx(&cs, 4000, 0);
}
~PrioritizedLock() throw()
{
DeleteCriticalSection(&cs);
}
// taking an *exclusive* lock to interrupt the *shared* lock taken by the deque IO path
// - we want to interrupt the IO path so we can initiate more IO if we need to grow the CQ
_Acquires_exclusive_lock_(this->srwlock)
_Acquires_lock_(this->cs)
void PriorityLock() throw()
{
AcquireSRWLockExclusive(&srwlock);
EnterCriticalSection(&cs);
}
_Releases_lock_(this->cs)
_Releases_exclusive_lock_(this->srwlock)
void PriorityRelease() throw()
{
LeaveCriticalSection(&cs);
ReleaseSRWLockExclusive(&srwlock);
}
_Acquires_shared_lock_(this->srwlock)
_Acquires_lock_(this->cs)
void DefaultLock() throw()
{
AcquireSRWLockShared(&srwlock);
EnterCriticalSection(&cs);
}
_Releases_lock_(this->cs)
_Releases_shared_lock_(this->srwlock)
void DefaultRelease() throw()
{
LeaveCriticalSection(&cs);
ReleaseSRWLockShared(&srwlock);
}
/// not copyable
PrioritizedLock(const PrioritizedLock&) = delete;
PrioritizedLock& operator=(const PrioritizedLock&) = delete;
private:
SRWLOCK srwlock;
CRITICAL_SECTION cs;
};
class AutoReleasePriorityLock {
public:
explicit AutoReleasePriorityLock(PrioritizedLock &priorityLock) throw()
: prioritizedLock(priorityLock)
{
prioritizedLock.PriorityLock();
}
~AutoReleasePriorityLock() throw()
{
prioritizedLock.PriorityRelease();
}
/// no default ctor
AutoReleasePriorityLock() = delete;
/// non-copyable
AutoReleasePriorityLock(const AutoReleasePriorityLock&) = delete;
AutoReleasePriorityLock operator=(const AutoReleasePriorityLock&) = delete;
private:
PrioritizedLock &prioritizedLock;
};
class AutoReleaseDefaultLock {
public:
explicit AutoReleaseDefaultLock(PrioritizedLock &priorityLock) throw()
: prioritizedLock(priorityLock)
{
prioritizedLock.DefaultLock();
}
~AutoReleaseDefaultLock() throw()
{
prioritizedLock.DefaultRelease();
}
/// no default ctor
AutoReleaseDefaultLock() = delete;
/// non-copyable
AutoReleaseDefaultLock(const AutoReleaseDefaultLock&) = delete;
AutoReleaseDefaultLock operator=(const AutoReleaseDefaultLock&) = delete;
private:
PrioritizedLock &prioritizedLock;
};

871
cpp/Riosock/Riosock.cpp Normal file
Просмотреть файл

@ -0,0 +1,871 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#define WIN32_LEAN_AND_MEAN
//#include "stdafx.h"
#include <windows.h>
#include <winsock2.h>
#include <mswsock.h>
#include <mstcpip.h>
#include "RIOSock.h"
#include <cstdlib>
#include "Locks.h"
#include <new>
#include <cassert>
// Need to link with Ws2_32.lib
#pragma comment (lib, "Ws2_32.lib")
const DWORD DefaultRIOCQSize = 256;
DWORD CQSize = 0;
DWORD CQUsed = 0;
LONG RIOSockRef;
PrioritizedLock *CQAccessLock = nullptr;
RIO_CQ CompletionQueue = nullptr;
RIO_RQ RequestQueue = nullptr;
RIO_NOTIFICATION_COMPLETION CompletionType;
RIO_EXTENSION_FUNCTION_TABLE RIOFuncs = { 0 };
//
// Local Functions
//
HRESULT EnsureWinSockMethods(_In_ SOCKET socket);
RIO_CQ CreateRIOCompletionQueue(_In_ DWORD queueSize, _In_opt_ PRIO_NOTIFICATION_COMPLETION pNotificationCompletion);
void CloseRIOCompletionQueue(_In_ RIO_CQ cq);
ULONG DequeueRIOCompletion(_In_ RIO_CQ cq, _Out_writes_to_(arraySize, return) PRIORESULT array, _In_ ULONG arraySize);
BOOL ResizeRIOCompletionQueue(_In_ RIO_CQ cq, _In_ ULONG queueSize);
//+
// Function:
// EnsureWinSockMethods()
//
// Description:
// Static function only to be called locally to ensure WSAStartup is held
// for the function pointers to remain accurate
//
// Result:
// Returns a registered buffer descriptor, if no errors occurs.
// Otherwise, a value of RIO_INVALID_BUFFERID is returned.
//-
LONG WinSockMethodsLock = 0;
HRESULT EnsureWinSockMethods(
_In_ SOCKET socket
)
{
static const LONG LockUninitialized = 0;
static const LONG LockInitialized = 1;
static const LONG LockInitializing = 2;
LONG lastState;
while((lastState = ::InterlockedCompareExchange(
&WinSockMethodsLock,
LockInitializing,
LockUninitialized)) == LockInitializing)
{
Sleep(0);
}
if (lastState == LockInitialized)
{
return S_OK;
}
WSADATA wsaData;
auto err = WSAStartup(WINSOCK_VERSION, &wsaData);
if (err != 0) {
// Reset lock to uninitialized
::InterlockedExchange(&WinSockMethodsLock, LockUninitialized);
// WSAStartup does not set LastWin32Error
SetLastError(err);
return HRESULT_FROM_WIN32(err);
}
// Check to see if we need to create a temp socket
auto localSocket = socket;
if (INVALID_SOCKET == localSocket)
{
DWORD dwFlags = WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED | WSA_FLAG_REGISTERED_IO;
localSocket = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, dwFlags);
if (INVALID_SOCKET == localSocket)
{
DWORD errorCode = WSAGetLastError();
// Reset lock to uninitialized
WSACleanup();
::InterlockedExchange(&WinSockMethodsLock, LockUninitialized);
SetLastError(errorCode);
return HRESULT_FROM_WIN32(errorCode);
}
}
GUID funcGuid = WSAID_MULTIPLE_RIO;
DWORD dwBytes = 0;
if (WSAIoctl(
localSocket,
SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
&funcGuid, sizeof(GUID),
&RIOFuncs, sizeof(RIOFuncs),
&dwBytes,nullptr, nullptr) != 0)
{
DWORD errorCode = WSAGetLastError();
if (localSocket != socket)
{
closesocket(localSocket);
}
WSACleanup();
// Reset lock to uninitialized
::InterlockedExchange(&WinSockMethodsLock, LockUninitialized);
SetLastError(errorCode);
return HRESULT_FROM_WIN32(errorCode);
}
// Update lock to fully Initialized
::InterlockedExchange(&WinSockMethodsLock, LockInitialized);
if (localSocket != socket) {
closesocket(localSocket);
}
return S_OK;
}
//+
// Function:
// CreateRIOCompletionQueue()
//
// Description:
// Internally, this function calls RIOCreateCompletionQueue to
// create a completion queue.
//
// Result:
// Returns a RIO_CQ
//-
FORCEINLINE
RIO_CQ CreateRIOCompletionQueue(
_In_ DWORD queueSize,
_In_opt_ PRIO_NOTIFICATION_COMPLETION pNotificationCompletion
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return RIO_INVALID_CQ;
}
return RIOFuncs.RIOCreateCompletionQueue(queueSize, pNotificationCompletion);
}
//+
// Function:
// CloseRIOCompletionQueue()
//
// Description:
// Internally, this function calls RIOCloseCompletionQueue to
// close a completion queue.
//
// Result:
// None.
//-
FORCEINLINE
void CloseRIOCompletionQueue(
_In_ RIO_CQ cq
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return;
}
RIOFuncs.RIOCloseCompletionQueue(cq);
}
//+
// Function:
// DequeueRIOCompletion()
//
// Description:
// Internally, this function calls RIODequeueCompletion to
// remove entries from an I/O completion queue.
//
// Result:
// Returns the number of completion entries removed from the specified completion queue.
//-
FORCEINLINE
ULONG DequeueRIOCompletion(
_In_ RIO_CQ cq,
_Out_writes_to_(arraySize, return) PRIORESULT array,
_In_ ULONG arraySize
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return RIO_CORRUPT_CQ;
}
return RIOFuncs.RIODequeueCompletion(cq, array, arraySize);
}
//+
// Function:
// ResizeRIOCompletionQueue()
//
// Description:
// Internally, this function calls RIOResizeCompletionQueue to
// resizes the I/O completion queue.
//
// Result:
// None.
//-
FORCEINLINE
BOOL ResizeRIOCompletionQueue(
_In_ RIO_CQ cq,
_In_ ULONG queueSize
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return FALSE;
}
return RIOFuncs.RIOResizeCompletionQueue(cq, queueSize);
}
//
// Global APIs
//
//+
// Function:
// RIOSockInitialize()
//
// Description:
// This function is global initializer for RIOSock.dll and must be called
// before any RIOSock APIs are invoked.
//
// Result:
// HRESULT codes.
//-
HRESULT RIOSOCKAPI RIOSockInitialize()
{
// Return if already initialized
if (RIOSockRef > 0) {
InterlockedIncrement(&RIOSockRef);
return S_OK;
}
// Create lock for Completion Queue access
CQAccessLock = new (std::nothrow) PrioritizedLock;
if (nullptr == CQAccessLock) {
return E_OUTOFMEMORY;
}
// Create IOCP handle
auto iocpHandle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 0);
if (iocpHandle == nullptr)
{
DWORD errorCode = GetLastError();
delete CQAccessLock;
SetLastError(errorCode);
return HRESULT_FROM_WIN32(errorCode);
}
// Create OVERLAPPED
auto overLapped = calloc(1, sizeof(OVERLAPPED));
if (nullptr == overLapped) {
// Close IOCP handle
CloseHandle(iocpHandle);
delete CQAccessLock;
SetLastError(WSAENOBUFS);
return HRESULT_FROM_WIN32(WSAENOBUFS);
}
// With RIO, we don't associate the IOCP handle with the socket like 'typical' sockets
// - Instead we directly pass the IOCP handle through RIOCreateCompletionQueue
::ZeroMemory(&CompletionType, sizeof(CompletionType));
CompletionType.Type = RIO_IOCP_COMPLETION;
CompletionType.Iocp.CompletionKey = reinterpret_cast<void*>(1);
CompletionType.Iocp.Overlapped = overLapped;
CompletionType.Iocp.IocpHandle = iocpHandle;
// Create a completion queue
CompletionQueue = CreateRIOCompletionQueue(DefaultRIOCQSize, &CompletionType);
if (RIO_INVALID_CQ == CompletionQueue) {
DWORD errorCode = WSAGetLastError();
CloseHandle(iocpHandle);
free(overLapped);
delete CQAccessLock;
SetLastError(errorCode);
return HRESULT_FROM_WIN32(errorCode);
}
// now that the CQ is created, update info
CQSize = DefaultRIOCQSize;
CQUsed = 0;
return S_OK;
}
//+
// Function:
// RIOSockUninitialize()
//
// Description:
// This function cleans up resources allocated by RIOSockInitialize.
//
// Result:
// None.
//-
void RIOSOCKAPI RIOSockUninitialize()
{
InterlockedDecrement(&RIOSockRef);
if (RIOSockRef > 0) return;
if (CompletionQueue != RIO_INVALID_CQ) {
CloseRIOCompletionQueue(CompletionQueue);
CompletionQueue = RIO_INVALID_CQ;
}
if (CompletionType.Iocp.IocpHandle != nullptr) {
CloseHandle(CompletionType.Iocp.IocpHandle);
CompletionType.Iocp.IocpHandle = nullptr;
}
free(CompletionType.Iocp.Overlapped);
CompletionType.Iocp.Overlapped = nullptr;
delete CQAccessLock;
CQAccessLock = nullptr;
}
//+
// Function:
// CreateRIOSocket()
//
// Description:
// This function creates a socket that bound to a local loop-back for use with RIO.
//
// Parameters:
// localAddr - A pointer to the beginning of the memory buffer to register.
// localAddrLen - The length, in bytes, in the buffer to register.
//
// Result:
// Returns a new socket, if no errors occurs. Otherwise, a value of INVALID_SOCKET is returned.
//-
SOCKET RIOSOCKAPI CreateRIOSocket(
_Out_ SOCKADDR *localAddr,
_Inout_ int *localAddrLen
)
{
// DWORD dwFlags = WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED | WSA_FLAG_REGISTERED_IO;
auto socket = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_REGISTERED_IO);
if (INVALID_SOCKET == socket)
{
DWORD errorCode = WSAGetLastError();
SetLastError(errorCode);
return INVALID_SOCKET;
}
// Enables SIO_LOOPBACK_FAST_PATH
auto OptionValue = 1;
DWORD NumberOfBytesReturned = 0;
if (WSAIoctl(socket,
SIO_LOOPBACK_FAST_PATH,
&OptionValue,
sizeof(OptionValue),
nullptr, 0,
&NumberOfBytesReturned,
nullptr, nullptr) == SOCKET_ERROR)
{
DWORD errorCode = WSAGetLastError();
closesocket(socket);
SetLastError(errorCode);
return INVALID_SOCKET;
}
// Bind socket for exclusive access
const BOOL bindExclUse = 1;
if (setsockopt(socket,
SOL_SOCKET,
SO_EXCLUSIVEADDRUSE,
reinterpret_cast<const char *>(&bindExclUse),
sizeof(bindExclUse)) == SOCKET_ERROR)
{
// Unexpected failure: report it then close our socket
DWORD errorCode = WSAGetLastError();
closesocket(socket);
SetLastError(errorCode);
return INVALID_SOCKET;
}
// Bind the socket to the loop-back address
SOCKADDR_IN sockAddr;
sockAddr.sin_family = AF_INET;
sockAddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
sockAddr.sin_port = htons(0);
if (bind(socket, reinterpret_cast<SOCKADDR *>(&sockAddr), sizeof(sockAddr)) == SOCKET_ERROR)
{
DWORD errorCode = WSAGetLastError();
closesocket(socket);
SetLastError(errorCode);
return INVALID_SOCKET;
}
// Retrieve the local name (addr) of the socket
if (getsockname(socket, localAddr, localAddrLen) == SOCKET_ERROR)
{
DWORD errorCode = WSAGetLastError();
closesocket(socket);
SetLastError(errorCode);
return INVALID_SOCKET;
}
return socket;
}
//+
// Function:
// PostRIOReceive()
//
// Description:
// This function posts a receive operation to receives data on a connected RIO socket.
//
// Parameters:
// socketQueue - The request queue that identifies a connected RIO socket.
// pData - The portion of the registered buffer in which to receive data.
// dataBufferCount - The data buffer count of the buffer pointed to by the pData parameter.
// flags - A set of flags that modify the behavior of the RIOReceive function.
// requestContext - The request context to associate with this receive operation.
//
// Result:
// Returns true if no error occurs. Otherwise, a value of false is returned.
//-
FORCEINLINE
BOOL RIOSOCKAPI PostRIOReceive(
_In_ RIO_RQ socketQueue,
_In_ PRIO_BUF pData,
_In_ ULONG dataBufferCount,
_In_ DWORD flags,
_In_ PVOID requestContext
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return FALSE;
}
return RIOFuncs.RIOReceive(socketQueue, pData, dataBufferCount, flags, requestContext);
}
//+
// Function:
// PostRIOSend()
//
// Description:
// This function posts a send operation to send data on a connected RIO socket.
//
// Parameters:
// socketQueue - The request queue that identifies a connected RIO socket.
// pData - The portion of the registered buffer in which to receive data.
// dataBufferCount - The data buffer count of the buffer pointed to by the pData parameter.
// flags - A set of flags that modify the behavior of the RIOReceive function.
// requestContext - The request context to associate with this receive operation.
//
// Result:
// Returns TRUE if no error occurs. Otherwise, a value of FALSE is returned.
//-
FORCEINLINE
BOOL RIOSOCKAPI PostRIOSend(
_In_ RIO_RQ socketQueue,
_In_ PRIO_BUF pData,
_In_ ULONG dataBufferCount,
_In_ DWORD flags,
_In_ PVOID requestContext
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return FALSE;
}
return RIOFuncs.RIOSend(socketQueue, pData, dataBufferCount, flags, requestContext);
}
//+
// Function:
// RegisterRIONotify()
//
// Description:
// This function registers the method to use for notification behavior.
//
// Result:
// Returns TRUE if no error occurs. Otherwise, a value of FALSE is returned.
//-
FORCEINLINE
BOOL RIOSOCKAPI RegisterRIONotify()
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return FALSE;
}
auto notify = RIOFuncs.RIONotify(CompletionQueue);
if (notify != ERROR_SUCCESS) {
SetLastError(notify);
return FALSE;
}
return TRUE;
}
//+
// Function:
// RegisterRIOBuffer()
//
// Description:
// This function registers a specified buffer for use with RIO Socket.
//
// Parameters:
// dataBuffer - A pointer to the beginning of the memory buffer to register.
// dataLength - The length, in bytes, in the buffer to register.
//
// Result:
// Returns a registered buffer descriptor, if no errors occurs.
// Otherwise, a value of RIO_INVALID_BUFFERID is returned.
//-
FORCEINLINE
RIO_BUFFERID RIOSOCKAPI RegisterRIOBuffer(
_In_ PCHAR dataBuffer,
_In_ DWORD dataLength
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return RIO_INVALID_BUFFERID;
}
return RIOFuncs.RIORegisterBuffer(dataBuffer, dataLength);
}
//+
// Function:
// DeregisterRIOBuffer()
//
// Description:
// This function deregisters a registered buffer used with RIO socket.
//
// Parameters:
// bufferId - A descriptor identifying a registered buffer.
//
// Result:
// None.
//-
FORCEINLINE
void RIOSOCKAPI DeregisterRIOBuffer(
_In_ RIO_BUFFERID bufferId
)
{
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return;
}
return RIOFuncs.RIODeregisterBuffer(bufferId);
}
//+
// Function:
// AllocateRIOCompletion()
//
// Description:
// This function makes rooms in the CQ for a new IO.
// - check if there is room in the CQ for the new IO
// - if not, take a writer lock around the CS to halt readers and to write to cq_used
// then take the CS over the CQ, and resize the CQ by 1.5 times current size
//
// Parameters:
// numCompltetion - The number of completions that need to be allocated from the completion queue.
//
// Result:
// return TRUE, if no error occurs; otherwise, FALSE.
//-
FORCEINLINE
BOOL RIOSOCKAPI AllocateRIOCompletion(
_In_ DWORD numCompltetion
)
{
// Taking an priority lock to interrupt the general lock taken by the deque IO path
// We want to interrupt the IO path so we can initiate more IO if we need to grow the CQ
AutoReleasePriorityLock priorityLock(*CQAccessLock);
auto newCQUsed = CQUsed + numCompltetion;
auto newCQSize = CQSize; // not yet resized
if (CQSize < newCQUsed) {
if (RIO_MAX_CQ_SIZE == CQSize || newCQUsed > RIO_MAX_CQ_SIZE)
{
// fail hard if we are already at the max CQ size and can't grow it for more IO
return FALSE;
}
// multiply newCQUsed by 1.25 for better growth patterns
newCQSize = static_cast<DWORD>(newCQUsed * 1.25);
if (newCQSize > RIO_MAX_CQ_SIZE) {
static_assert(MAXLONG / 1.25 > RIO_MAX_CQ_SIZE, "CQSize can overflow");
newCQSize = RIO_MAX_CQ_SIZE;
}
if (!ResizeRIOCompletionQueue(CompletionQueue, newCQSize))
{
return FALSE;
}
}
// update CQUsed and CQSize on the success path
CQUsed = newCQUsed;
CQSize = newCQSize;
return TRUE;
}
//+
// Function:
// ReleaseRIOCompletion()
//
// Description:
// This function release rooms back to the CQ
//
// Parameters:
// numCompltetion - The number of completions that need to be released.
//
// Result:
// return FALSE, if numCompltetion > CQUsed; otherwise, TRUE.
//-
FORCEINLINE
BOOL RIOSOCKAPI ReleaseRIOCompletion(
_In_ DWORD numCompletion
)
{
AutoReleasePriorityLock priorityLock(*CQAccessLock);
if (CQUsed < numCompletion)
{
return FALSE;
}
CQUsed -= numCompletion;
return TRUE;
}
//+
// Function:
// DequeueRIOResults()
//
// Description:
// This function dequeue RIO results from the I/O completion queue used with RIO socket.
// It will always post a Notify with proper synchronization.
//
// Parameters:
// rioResults - An array of RIORESULT structures to receive the description of the completions dequeued.
// rioResultSize - The maximum number of entries in the rioResults to write.
//
// Result:
// If no error occurs, it returns the number of RIO results retrieved from the completion queue.
// Otherwise, a value of RIO_CORRUPT_CQ is returned to indicate that the state of the completion
// queue has become corrupt due to memory corruption or misuse of the RIO functions.
//-
FORCEINLINE
DWORD RIOSOCKAPI DequeueRIOResults(
_Out_ PRIORESULT rioResults,
_In_ DWORD rioResultSize
)
{
// Taking a lower-priority lock, to allow the priority lock to interrupt
// dequeuing. So it can add space to the CQ
AutoReleaseDefaultLock defaultLock(*CQAccessLock);
auto resultCount = DequeueRIOCompletion(CompletionQueue, rioResults, rioResultSize);
if (0 == resultCount || RIO_CORRUPT_CQ == resultCount)
{
// We were notified there were completions, but we can't dequeue any IO
// Something has gone horribly wrong - likely our CQ is corrupt.
return resultCount;
}
// Immediately after invoking Dequeue, post another Notify
auto notifyResult = RegisterRIONotify();
if (notifyResult == FALSE)
{
// if notify fails, we can't reliably know when the next IO completes
// this will cause everything to come to a grinding halt
return RIO_CORRUPT_CQ;
}
return resultCount;
}
//+
// Function:
// CreateRIORequestQueue()
//
// Description:
// This function creates a request queue by calling RIOCreateRequestQueue.
//
// Parameters:
// socket - A socket to for the new request queue.
// maxOutstandingReceive - The maximum number of outstanding receives allowed on the socket.
// maxOutstandingSend - The maximum number of outstanding sends allowed on the socket.
// socketContext - The socket context to associate with this request queue.
//
// Result:
// If no error occurs, it returns a new request queue. Otherwise, a value of RIO_INVALID_RQ is returned.
//-
FORCEINLINE
RIO_RQ RIOSOCKAPI CreateRIORequestQueue(
_In_ SOCKET socket,
_In_ ULONG maxOutstandingReceive,
_In_ ULONG maxOutstandingSend,
_In_ PVOID socketContext
)
{
// A request queue is associated with a socket, ensure that the client passed us a valid socket
assert(socket != INVALID_SOCKET);
auto hr = EnsureWinSockMethods(socket);
if (FAILED(hr))
{
return RIO_INVALID_RQ;
}
return RIOFuncs.RIOCreateRequestQueue(
socket,
maxOutstandingReceive,
1,
maxOutstandingSend,
1,
CompletionQueue,
CompletionQueue,
socketContext
);
}
//+
// Function:
// ResizeRIORequestQueue()
//
// Description:
// This function resizes a request queue by calling RIOResizeRequestQueue.
//
// Parameters:
// rq - A request queue to be resize.
// maxOutstandingReceive - The maximum number of outstanding receives allowed on the socket.
// maxOutstandingSend - The maximum number of outstanding sends allowed on the socket.
//
// Result:
// If no error occurs, it returns TRUE. Otherwise, a value of FALSE is returned.
//-
BOOL RIOSOCKAPI ResizeRIORequestQueue(
_In_ RIO_RQ rq,
_In_ DWORD maxOutstandingReceive,
_In_ DWORD maxOutstandingSend
)
{
// ensure that the client passed us a valid RQ
assert(rq != RIO_INVALID_RQ);
auto hr = EnsureWinSockMethods(INVALID_SOCKET);
if (FAILED(hr))
{
return FALSE;
}
return RIOFuncs.RIOResizeRequestQueue(rq, maxOutstandingReceive, maxOutstandingSend);
}
//+
// Function:
// GetRIOCompletionStatus()
//
// Description:
// This function calls GetQueuedCompletionStatus() internally to dequeue an IO completion packet.
// If there is no completion packet queued, the function blocks the thread.
//
// Result:
// If no error occurs, it returns a new request queue. Otherwise, a value of RIO_INVALID_RQ is returned.
//-
FORCEINLINE
BOOL RIOSOCKAPI GetRIOCompletionStatus()
{
DWORD bytesTransferred;
ULONG_PTR completionKey;
OVERLAPPED *pov = nullptr;
if (!GetQueuedCompletionStatus(
CompletionType.Iocp.IocpHandle,
&bytesTransferred,
&completionKey,
&pov,
INFINITE))
{
auto lastError = GetLastError();
SetLastError(lastError);
return FALSE;
}
return TRUE;
}
//////////////////////////////////////////////////////////////////////////
//+
// DLL Entry
//-
BOOL APIENTRY DllMain(HMODULE hModule, DWORD dwReason, LPVOID lpReserved)
{
UNREFERENCED_PARAMETER(lpReserved);
if (dwReason == DLL_PROCESS_ATTACH)
{
// Initializes use of Winsock 2 DLL
WSADATA wsaData;
if (WSAStartup(WINSOCK_VERSION, &wsaData) != 0)
{
return FALSE;
}
// Disables the DLL_THREAD_ATTACH and DLL_THREAD_DETACH notifications
DisableThreadLibraryCalls(hModule);
return TRUE;
}
if (dwReason == DLL_PROCESS_DETACH)
{
// Terminates use of the Winsock 2 DLL
WSACleanup();
}
return TRUE;
}

104
cpp/Riosock/Riosock.h Normal file
Просмотреть файл

@ -0,0 +1,104 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#ifdef RIOSOCK_DLL
#define RIOSOCKAPI __declspec(dllexport) __stdcall
#else
#define RIOSOCKAPI __declspec(dllimport) __stdcall
#endif
//
// Global Routings
//
// Global initializer for RIOSock.dll
HRESULT RIOSOCKAPI RIOSockInitialize();
// Cleans up resources allocated by RIOSockInitialize.
void RIOSOCKAPI RIOSockUninitialize();
//
// Operations
//
// Creates a socket that bound to a local loop-back for use with RIO.
SOCKET RIOSOCKAPI CreateRIOSocket(
_Out_ SOCKADDR *localAddr,
_Inout_ int *addrLen
);
// Posts a receive operation to receives data on a connected RIO socket.
BOOL RIOSOCKAPI PostRIOReceive(
_In_ RIO_RQ socketQueue,
_In_ PRIO_BUF pData,
_In_ ULONG dataBufferCount,
_In_ DWORD flags,
_In_ PVOID requestContext
);
// Posts a send operation to send data on a connected RIO socket.
BOOL RIOSOCKAPI PostRIOSend(
_In_ RIO_RQ socketQueue,
_In_ PRIO_BUF pData,
_In_ ULONG dataBufferCount,
_In_ DWORD flags,
_In_ PVOID requestContext
);
// Registers the method to use for notification behavior.
BOOL RIOSOCKAPI RegisterRIONotify();
// Registers a specified buffer for use with RIO Socket
RIO_BUFFERID RIOSOCKAPI RegisterRIOBuffer(
_In_ PCHAR dataBuffer,
_In_ DWORD dataLength
);
// Deregisters a registered buffer used with RIO socket.
void RIOSOCKAPI DeregisterRIOBuffer(
_In_ RIO_BUFFERID bufferId
);
// Makes rooms in the completion queue for a new IO.
BOOL RIOSOCKAPI AllocateRIOCompletion(
_In_ DWORD numCompltetion
);
// Release rooms in the completion queue
BOOL RIOSOCKAPI ReleaseRIOCompletion(
_In_ DWORD numCompletion
);
// Dequeues RIO results from the I/O completion queue used with RIO socket.
DWORD RIOSOCKAPI DequeueRIOResults(
_Out_ PRIORESULT rioResults,
_In_ DWORD rioResultSize
);
// Creates a request queue
RIO_RQ RIOSOCKAPI CreateRIORequestQueue(
_In_ SOCKET socket,
_In_ ULONG maxOutstandingReceive,
_In_ ULONG maxOutstandingSend,
_In_ PVOID socketContext
);
// Resizes a request queue
BOOL RIOSOCKAPI ResizeRIORequestQueue(
_In_ RIO_RQ rq,
_In_ DWORD maxOutstandingReceive,
_In_ DWORD maxOutstandingSend
);
// Dequeues an IO completion packet
BOOL RIOSOCKAPI GetRIOCompletionStatus();
#ifdef __cplusplus
}
#endif

Просмотреть файл

@ -0,0 +1,91 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{9572BB3A-F6C5-4E02-BEB7-282E53F5E75C}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>Riosock</RootNamespace>
<WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Label="Shared">
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>
</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>RIOSOCK_DLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>
</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>NDEBUG;_WINDOWS;_USRDLL;RIOSOCK_DLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="Locks.h" />
<ClInclude Include="Riosock.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="Riosock.cpp" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

Просмотреть файл

@ -11,6 +11,7 @@
<AssemblyName>Microsoft.Spark.CSharp.Adapter</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<CppDll Condition="Exists('..\..\..\cpp\x64')">HasCpp</CppDll>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
@ -23,6 +24,7 @@
<WarningLevel>4</WarningLevel>
<Prefer32Bit>false</Prefer32Bit>
<DocumentationFile>..\documentation\Microsoft.Spark.CSharp.Adapter.Doc.XML</DocumentationFile>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
@ -84,8 +86,17 @@
<Compile Include="Interop\Ipc\JvmObjectReference.cs" />
<Compile Include="Interop\Ipc\PayloadHelper.cs" />
<Compile Include="Interop\Ipc\SerDe.cs" />
<Compile Include="Network\ByteBuf.cs" />
<Compile Include="Network\ByteBufChunk.cs" />
<Compile Include="Network\ByteBufChunkList.cs" />
<Compile Include="Network\ByteBufPool.cs" />
<Compile Include="Network\DefaultSocketWrapper.cs" />
<Compile Include="Network\ISocketWrapper.cs" />
<Compile Include="Network\RioNative.cs" />
<Compile Include="Network\RioSocketWrapper.cs" />
<Compile Include="Network\SaeaSocketWrapper.cs" />
<Compile Include="Network\SocketStream.cs" />
<Compile Include="Network\SockDataToken.cs" />
<Compile Include="Network\SocketFactory.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Proxy\IDataFrameNaFunctionsProxy.cs" />
@ -141,7 +152,18 @@
<Compile Include="Streaming\StreamingContext.cs" />
<Compile Include="Streaming\TransformedDStream.cs" />
</ItemGroup>
<ItemGroup />
<ItemGroup Condition=" '$(CppDll)' == 'HasCpp' ">
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.dll</Link>
<TargetPath>Riosock.dll</TargetPath>
</ContentWithTargetPath>
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.pdb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.dll</Link>
<TargetPath>Riosock.pdb</TargetPath>
</ContentWithTargetPath>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>

Просмотреть файл

@ -20,6 +20,7 @@ namespace Microsoft.Spark.CSharp.Configuration
public const string ProcFileName = "CSharpWorker.exe";
public const string CSharpWorkerPathSettingKey = "CSharpWorkerPath";
public const string CSharpBackendPortNumberSettingKey = "CSharpBackendPortNumber";
public const string CSharpSocketTypeEnvName = "spark.mobius.CSharp.socketType";
public const string SPARKCLR_HOME = "SPARKCLR_HOME";
public const string SPARK_MASTER = "spark.master";
public const string CSHARPBACKEND_PORT = "CSHARPBACKEND_PORT";
@ -169,9 +170,11 @@ namespace Microsoft.Spark.CSharp.Configuration
private class SparkCLRDebugConfiguration : SparkCLRLocalConfiguration
{
private readonly ILoggerService logger = LoggerServiceFactory.GetLogger(typeof(SparkCLRDebugConfiguration));
internal SparkCLRDebugConfiguration(System.Configuration.Configuration configuration)
: base(configuration)
{}
{
}
internal override int GetPortNumber()
{

Просмотреть файл

@ -1,12 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.Spark.CSharp.Configuration
{
/// <summary>

Просмотреть файл

@ -44,7 +44,7 @@ namespace Microsoft.Spark.CSharp.Interop.Ipc
/// <summary>
/// adaptively control the number of weak objects that should be checked for each interval
/// <summary>
/// </summary>
internal class WeakReferenceCheckCountController
{
private static readonly ILoggerService logger = LoggerServiceFactory.GetLogger(typeof(WeakReferenceCheckCountController));

Просмотреть файл

@ -0,0 +1,401 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Runtime.InteropServices;
using System.Security;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// ByteBuf delimits a section of a ByteBufChunk.
/// It is the smallest unit to be allocated.
/// </summary>
internal class ByteBuf
{
private int readerIndex;
private int writerIndex;
/// <summary>
/// Indicates the state of that the ByteBuf as socket data transport.
/// </summary>
public int Status;
/// <summary>
/// We borrow some ideas from Netty's ByteBuf.
/// ByteBuf provides two pointer variables to support sequential read and write operations
/// - readerIndex for a read operation and writerIndex for a write operation respectively.
/// The following diagram shows how a buffer is segmented into three areas by the two
/// pointers:
///
/// +-------------------+------------------+------------------+
/// | discardable bytes | readable bytes | writable bytes |
/// | | (CONTENT) | |
/// +-------------------+------------------+------------------+
/// | | | |
/// 0 == readerIndex == writerIndex == capacity
/// </summary>
internal ByteBuf(ByteBufChunk chunk, int offset, int capacity)
{
if (offset < 0)
throw new ArgumentOutOfRangeException("offset", "Offset is less than zero.");
if (capacity < 0)
throw new ArgumentOutOfRangeException("capacity", "Count is less than zero.");
if (chunk == null)
throw new ArgumentNullException("chunk");
if (chunk.Size - offset < capacity)
throw new ArgumentException(
"Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection.");
Capacity = capacity;
Offset = offset;
ByteBufChunk = chunk;
readerIndex = writerIndex = 0;
Status = 0;
}
private ByteBuf(int errorStatus)
{
Status = errorStatus;
Capacity = 0;
Offset = 0;
ByteBufChunk = null;
readerIndex = writerIndex = 0;
}
/// <summary>
/// Gets the underlying array.
/// </summary>
public byte[] Array { get { return ByteBufChunk.Array; } }
/// <summary>
/// Gets the total number of elements in the range delimited by the ByteBuf.
/// </summary>
public int Capacity { get; private set; }
/// <summary>
/// Gets the position of the first element in the range delimited
/// by the ByteBuf, relative to the start of the original array.
/// </summary>
public int Offset { get; private set; }
/// <summary>
/// Returns the ByteBuf chunk that contains this ByteBuf.
/// </summary>
internal ByteBufChunk ByteBufChunk { get; private set; }
/// <summary>
/// Returns the number of readable bytes which is equal to (writerIndex - readerIndex).
/// </summary>
public int ReadableBytes { get { return WriterIndex - ReaderIndex; } }
/// <summary>
/// Returns the number of writable bytes which is equal to (capacity - writerIndex).
/// </summary>
public int WritableBytes { get { return Capacity - WriterIndex; } }
/// <summary>
/// Gets the underlying unsafe array.
/// </summary>
public IntPtr UnsafeArray { get { return ByteBufChunk.UnsafeArray; } }
/// <summary>
/// Gets or sets the readerIndex of this ByteBuf
/// </summary>
/// <exception cref="IndexOutOfRangeException"></exception>
public int ReaderIndex
{
get { return readerIndex; }
set
{
if (value < 0 || value > WriterIndex)
{
throw new IndexOutOfRangeException(string.Format(
"ReaderIndex: {0} (expected: 0 <= readerIndex <= writerIndex({1})", value, writerIndex));
}
readerIndex = value;
}
}
/// <summary>
/// Gets or sets the writerIndex of this ByteBuf
/// </summary>
/// <exception cref="IndexOutOfRangeException"></exception>
public int WriterIndex
{
get { return writerIndex; }
set
{
if (value < ReaderIndex || value > Capacity)
{
throw new IndexOutOfRangeException(string.Format(
"WriterIndex: {0} (expected: 0 <= readerIndex({1}) <= writerIndex <= capacity ({2})", value, ReaderIndex, Capacity));
}
writerIndex = value;
}
}
/// <summary>
/// Returns the position of the readerIndex element in the range delimited
/// by the ByteBuf, relative to the start of the original array.
/// </summary>
public int ReaderIndexOffset { get { return Idx(readerIndex); } }
/// <summary>
/// Returns the position of the readerIndex element in the range delimited
/// by the ByteBuf, relative to the start of the original array.
/// </summary>
public int WriterIndexOffset { get { return Idx(writerIndex); } }
/// <summary>
/// Sets the readerIndex and writerIndex of this buffer to 0.
/// </summary>
public void Clear()
{
readerIndex = writerIndex = 0;
}
/// <summary>
/// Is this ByteSegment readable if and only if the buffer contains equal or more than
/// the specified number of elements
/// </summary>
/// <param name="size">
/// The number of elements we would like to read,
/// The default value is 1 that is to check this ByteBuf has at least 1 byte can be read.
/// </param>
/// <returns>true, if it is readable; otherwise, false</returns>
public bool IsReadable(int size = 1)
{
if (ByteBufChunk == null || ByteBufChunk.IsDisposed)
{
return false;
}
return ReadableBytes >= size;
}
/// <summary>
/// Returns true if and only if the buffer has enough Capacity to accommodate size
/// additional bytes.
/// </summary>
/// <param name="size">
/// The number of additional elements we would like to write
/// The default value is 1 that is to check this ByteBuf has at least 1 byte can be wroten.
/// </param>
/// <returns>true, if this ByteSegment is writable; otherwise, false</returns>
public bool IsWritable(int size = 1)
{
if (ByteBufChunk == null || ByteBufChunk.IsDisposed)
{
return false;
}
return WritableBytes >= size;
}
/// <summary>
/// Gets a byte at the current readerIndex and increases the readerIndex by 1 in this buffer.
/// </summary>
public byte ReadByte()
{
CheckReadableBytes(1);
var b = ByteBufChunk.IsUnsafe
? Marshal.ReadByte(ByteBufChunk.UnsafeArray, ReaderIndexOffset)
: ByteBufChunk.Array[ReaderIndexOffset];
ReaderIndex += 1;
return b;
}
/// <summary>
/// Reads a block of bytes from the ByteBuf and writes the data to a buffer.
/// </summary>
/// <param name="buffer">
/// When this method returns, contains the specified byte array with the values
/// between offset and (offset + count - 1) replaced by the characters read
/// from the ByteBuf.
/// </param>
/// <param name="offset">The zero-based byte offset in buffer at which to begin storing data from the ByteBuf.</param>
/// <param name="count">The maximum number of bytes to read.</param>
/// <returns></returns>
public unsafe int ReadBytes(byte[] buffer, int offset, int count)
{
if (buffer == null)
throw new ArgumentNullException("buffer", "Buffer cannot be null.");
if (offset < 0)
throw new ArgumentOutOfRangeException("offset", "Offset is less than zero.");
if (count < 0)
throw new ArgumentOutOfRangeException("count", "Count is less than zero.");
if (buffer.Length - offset < count)
throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection.");
EnsureAccessible();
CheckReadableBytes(count);
int n = WriterIndex - ReaderIndex;
if (n > count) n = count;
if (n <= 0) return 0;
if (ByteBufChunk.IsUnsafe)
{
fixed (byte* pBuf = &buffer[offset])
{
memcpy((IntPtr)pBuf, ByteBufChunk.UnsafeArray + ReaderIndexOffset, (ulong)n);
}
}
else
{
Buffer.BlockCopy(ByteBufChunk.Array, ReaderIndexOffset, buffer, offset, n);
}
ReaderIndex += n;
return n;
}
/// <summary>
/// Release the ByteBuf back to the ByteBufPool
/// </summary>
public void Release()
{
if (ByteBufChunk == null || ByteBufChunk.IsDisposed)
{
return;
}
var byteBufPool = ByteBufChunk.Pool;
byteBufPool.Free(this);
ByteBufChunk = null;
}
/// <summary>
/// Writes a block of bytes to the ByteBuf using data read from a buffer.
/// </summary>
/// <param name="buffer">The buffer to write data from. </param>
/// <param name="offset">The zero-based byte offset in buffer at which to begin copying bytes to the ByteBuf.</param>
/// <param name="count">The maximum number of bytes to write.</param>
public unsafe void WriteBytes(byte[] buffer, int offset, int count)
{
// Check parameters
if (buffer == null)
throw new ArgumentNullException("buffer", "Buffer cannot be null.");
if (offset < 0)
throw new ArgumentOutOfRangeException("offset", "Offset is less than zero.");
if (count < 0)
throw new ArgumentOutOfRangeException("count", "Count is less than zero.");
if (buffer.Length - offset < count)
throw new ArgumentException(
"Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection.");
EnsureAccessible();
EnsureWritable(count);
if (ByteBufChunk.IsUnsafe)
{
fixed (byte* pBuf = &buffer[offset])
{
memcpy(ByteBufChunk.UnsafeArray + WriterIndexOffset, (IntPtr)pBuf, (ulong) count);
}
}
else
{
Buffer.BlockCopy(buffer, offset, ByteBufChunk.Array, WriterIndexOffset, count);
}
WriterIndex += count;
}
/// <summary>
/// Returns a RioBuf object for input (receive)
/// </summary>
/// <returns>A RioBuf object</returns>
internal RioBuf GetInputRioBuf()
{
EnsureAccessible();
if (!ByteBufChunk.IsUnsafe)
{
throw new InvalidOperationException("Managed ByteSegment does not support RioBuf.");
}
return new RioBuf(ByteBufChunk.BufId, (uint)WriterIndexOffset, (uint)WritableBytes);
}
/// <summary>
/// Returns a RioBuf object for output (send).
/// </summary>
/// <returns>A RioBuf object</returns>
internal RioBuf GetOutputRioBuf()
{
EnsureAccessible();
if (!ByteBufChunk.IsUnsafe)
{
throw new InvalidOperationException("Managed ByteSegment does not support RioBuf.");
}
return new RioBuf(ByteBufChunk.BufId, (uint)ReaderIndexOffset, (uint)ReadableBytes);
}
/// <summary>
/// Creates an empty ByteBuf with error status.
/// </summary>
internal static ByteBuf NewErrorStatusByteBuf(int errorCode)
{
return new ByteBuf(errorCode);
}
private void CheckReadableBytes(int minimumReadableBytes)
{
EnsureAccessible();
if (ReaderIndex > WriterIndex - minimumReadableBytes)
{
throw new IndexOutOfRangeException(string.Format(
"readerIndex({0}) + length({1}) exceeds writerIndex({2})", ReaderIndex, minimumReadableBytes, WriterIndex));
}
}
private void EnsureAccessible()
{
if (ByteBufChunk == null || ByteBufChunk.IsDisposed)
{
throw new ObjectDisposedException("ByteBufChunk");
}
}
private void EnsureWritable(int minWritableBytes)
{
EnsureAccessible();
if (minWritableBytes <= WritableBytes)
{
return;
}
if (minWritableBytes > Capacity - WriterIndex)
{
throw new IndexOutOfRangeException(string.Format(
"writerIndex({0}) + minWritableBytes({1}) exceeds Capacity({2})", WriterIndex, minWritableBytes, Capacity));
}
}
private int Idx(int index)
{
return Offset + index;
}
[DllImport("msvcrt.dll", EntryPoint = "memcpy", CallingConvention = CallingConvention.Cdecl, SetLastError = false)]
[SuppressUnmanagedCodeSecurity]
private static extern IntPtr memcpy(IntPtr dest, IntPtr src, ulong count);
}
[StructLayout(LayoutKind.Sequential)]
internal struct RioBuf
{
public RioBuf(IntPtr bufferId, uint offset, uint length)
{
BufferId = bufferId;
Offset = offset;
Length = length;
}
public readonly IntPtr BufferId;
public readonly uint Offset;
public uint Length;
}
}

Просмотреть файл

@ -0,0 +1,347 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security;
using System.Text;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// ByteBufChunk represents a memory blocks that can be allocated from
/// .Net heap (managed code) or process heap(unsafe code)
/// </summary>
internal sealed class ByteBufChunk
{
private readonly Queue<Segment> segmentQueue;
private readonly int segmentSize;
private bool disposed;
private byte[] memory;
private IntPtr unsafeMemory;
/// <summary>
/// The ByteBufChunkList that contains this ByteBufChunk
/// </summary>
public ByteBufChunkList Parent;
/// <summary>
/// The previous ByteBufChunk in linked like list
/// </summary>
public ByteBufChunk Prev;
/// <summary>
/// The next ByteBufChunk in linked like list
/// </summary>
public ByteBufChunk Next;
private ByteBufChunk(ByteBufPool pool, int segmentSize, int chunkSize)
{
Pool = pool;
FreeBytes = chunkSize;
Size = chunkSize;
this.segmentSize = segmentSize;
segmentQueue = new Queue<Segment>();
var numSegment = chunkSize / segmentSize;
for (var i = 0; i < numSegment; i++)
{
segmentQueue.Enqueue(new Segment(i * segmentSize, segmentSize));
}
}
private ByteBufChunk(ByteBufPool pool, byte[] memory, int segmentSize, int chunkSize)
: this(pool, segmentSize, chunkSize)
{
if (memory == null || chunkSize == 0)
{
throw new ArgumentNullException("memory", "Must be initialized with a valid byte array");
}
IsUnsafe = false;
this.memory = memory;
unsafeMemory = IntPtr.Zero;
BufId = IntPtr.Zero;
}
private ByteBufChunk(ByteBufPool pool, IntPtr memory, IntPtr bufId, int segmentSize, int chunkSize)
: this(pool, segmentSize, chunkSize)
{
if (memory == IntPtr.Zero || chunkSize == 0)
{
throw new ArgumentNullException("memory", "Must be initialized with a valid heap block.");
}
IsUnsafe = true;
unsafeMemory = memory;
this.memory = null;
BufId = bufId;
}
/// <summary>
/// Finalizer.
/// </summary>
~ByteBufChunk()
{
Dispose(false);
}
/// <summary>
/// Returns the underlying array that is used for managed code.
/// </summary>
public byte[] Array { get { return memory; } }
/// <summary>
/// Returns the buffer Id that registered as RIO buffer.
/// Only apply to unsafe ByteBufChunk
/// </summary>
public IntPtr BufId { get; private set; }
/// <summary>
/// Returns the unused bytes in this chunk.
/// </summary>
public int FreeBytes { get; private set; }
/// <summary>
/// Indicates whether this ByteBufChunk is disposed.
/// </summary>
public bool IsDisposed { get { return disposed; } }
/// <summary>
/// Indicates whether the underlying buffer array is a unsafe type array.
/// The unsafe array is used for PInvoke with native code.
/// </summary>
public bool IsUnsafe { get; private set; }
/// <summary>
/// Returns the ByteBufPool that this ByteBufChunk belongs to.
/// </summary>
public ByteBufPool Pool { get; private set; }
/// <summary>
/// Returns the size of the ByteBufChunk
/// </summary>
public int Size { get; private set; }
/// <summary>
/// Returns the percentage of the current usage of the chunk
/// </summary>
public int Usage
{
get
{
var bytes = FreeBytes;
if (bytes == 0)
{
return 100;
}
var freePercentage = (int)(bytes * 100L / Size);
if (freePercentage == 0)
{
return 99;
}
return 100 - freePercentage;
}
}
/// <summary>
/// Returns the IntPtr that points to beginning of the cached heap block.
/// This is used for PInvoke with native code.
/// </summary>
public IntPtr UnsafeArray { get { return unsafeMemory;} }
/// <summary>
/// Allocates a ByteBuf from this ByteChunk.
/// </summary>
/// <param name="byteBuf">The ByteBuf be allocated</param>
/// <returns>true, if succeed to allocate a ByteBuf; otherwise, false</returns>
public bool Allocate(out ByteBuf byteBuf)
{
if (segmentQueue.Count > 0)
{
var segment = segmentQueue.Dequeue();
FreeBytes -= segmentSize;
byteBuf = new ByteBuf(this, segment.Offset, segment.Count);
return true;
}
byteBuf = default(ByteBuf);
return false;
}
/// <summary>
/// Release all resources
/// </summary>
public void Dispose()
{
Dispose(true);
}
/// <summary>
/// Releases the ByteBuf back to this ByteChunk
/// </summary>
/// <param name="byteBuf">The ByteBuf to be released.</param>
public void Free(ByteBuf byteBuf)
{
segmentQueue.Enqueue(new Segment(byteBuf.Offset, byteBuf.Capacity));
FreeBytes += segmentSize;
}
/// <summary>
/// Returns a readable string for the ByteBufChunk
/// </summary>
public override string ToString()
{
return new StringBuilder()
.Append("Chunk(")
.Append(RuntimeHelpers.GetHashCode(this).ToString("X"))
.Append(": ")
.Append(Usage)
.Append("%, ")
.Append(Size - FreeBytes)
.Append('/')
.Append(Size)
.Append(')')
.ToString();
}
/// <summary>
/// Static method to create a new ByteBufChunk with given segment and chunk size.
/// If isUnsafe is true, it allocates memory from the process's heap.
/// </summary>
/// <param name="pool">The ByteBufPool that contains the new ByteChunk</param>
/// <param name="segmentSize">The segment size</param>
/// <param name="chunkSize">The chunk size to create</param>
/// <param name="isUnsafe">Indicates if it is a safe or unsafe</param>
/// <returns>The new ByteBufChunk object</returns>
[SuppressMessage("Microsoft.Security", "CA2118:ReviewSuppressUnmanagedCodeSecurityUsage")]
[SuppressUnmanagedCodeSecurity]
public static ByteBufChunk NewChunk(ByteBufPool pool, int segmentSize, int chunkSize, bool isUnsafe)
{
ByteBufChunk chunk = null;
if (!isUnsafe)
{
chunk = new ByteBufChunk(pool, new byte[chunkSize], segmentSize, chunkSize);
return chunk;
}
// allocate buffers from process heap
var token = HeapAlloc(GetProcessHeap(), 0, chunkSize);
if (token == IntPtr.Zero)
{
throw new OutOfMemoryException();
}
// register this heap buffer to RIO buffer
var bufferId = RioNative.RegisterRIOBuffer(token, (uint)chunkSize);
if (bufferId == IntPtr.Zero)
{
FreeToProcessHeap(token);
throw new Exception("Failed to register RIO buffer");
}
try
{
chunk = new ByteBufChunk(pool, token, bufferId, segmentSize, chunkSize);
token = IntPtr.Zero;
bufferId = IntPtr.Zero;
return chunk;
}
finally
{
if (chunk == null && token != IntPtr.Zero)
{
if (bufferId != IntPtr.Zero)
{
RioNative.DeregisterRIOBuffer(bufferId);
}
FreeToProcessHeap(token);
}
}
}
/// <summary>
/// Wraps HeapFree to process heap.
/// </summary>
[SuppressMessage("Microsoft.Security", "CA2118:ReviewSuppressUnmanagedCodeSecurityUsage")]
[SuppressUnmanagedCodeSecurity]
internal static void FreeToProcessHeap(IntPtr heapBlock)
{
Debug.Assert(heapBlock != IntPtr.Zero);
HeapFree(GetProcessHeap(), 0, heapBlock);
}
/// <summary>
/// Implementation of the Dispose pattern.
/// </summary>
private void Dispose(bool disposing)
{
if (disposed)
{
return;
}
if (!IsUnsafe && memory != null)
{
memory = null;
segmentQueue.Clear();
}
if (BufId != IntPtr.Zero)
{
RioNative.DeregisterRIOBuffer(BufId);
}
// If the unsafedMemory is still valid, free it.
if (unsafeMemory != IntPtr.Zero)
{
var heapBlock = unsafeMemory;
unsafeMemory = IntPtr.Zero;
FreeToProcessHeap(heapBlock);
}
if (disposing)
{
GC.SuppressFinalize(this);
}
disposed = true;
}
/// <summary>
/// Segment struct delimits a section of a byte chunk.
/// </summary>
private struct Segment
{
public Segment(int offset, int count)
{
Offset = offset;
Count = count;
}
public readonly int Count;
public readonly int Offset;
}
#region PInvoke
[SuppressUnmanagedCodeSecurity]
[DllImport("Kernel32.dll", CallingConvention = CallingConvention.Winapi, CharSet = CharSet.Unicode)]
private static extern IntPtr HeapAlloc(IntPtr heapHandle, int flags, int size);
[SuppressUnmanagedCodeSecurity]
[DllImport("Kernel32.dll", CallingConvention = CallingConvention.Winapi, CharSet = CharSet.Unicode)]
private static extern void HeapFree(IntPtr heapHandle, int flags, IntPtr freePtr);
[SuppressUnmanagedCodeSecurity]
[DllImport("Kernel32.dll", CallingConvention = CallingConvention.Winapi, CharSet = CharSet.Unicode)]
private static extern IntPtr GetProcessHeap();
#endregion
}
}

Просмотреть файл

@ -0,0 +1,213 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Diagnostics;
using System.Text;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// ByteBufChunkList class represents a simple linked like list used to store ByteBufChunk objects
/// based on its usage.
/// </summary>
internal class ByteBufChunkList
{
private readonly int maxUsage;
private readonly int minUsage;
internal readonly ByteBufChunkList NextList;
internal ByteBufChunk head;
/// <summary>
/// The previous ByteBufChunkList. This is only update once when create the linked like list
/// of ByteBufChunkList in ByteBufPool constructor.
/// </summary>
public ByteBufChunkList PrevList;
/// <summary>
/// Initializes a ByteBufChunkList instance with the next ByteBufChunkList and minUsage and maxUsage
/// </summary>
/// <param name="nextList">The next item of this ByteBufChunkList</param>
/// <param name="minUsage">The definition of minimum usage to contain ByteBufChunk</param>
/// <param name="maxUsage">The definition of maximum usage to contain ByteBufChunk</param>
public ByteBufChunkList(ByteBufChunkList nextList, int minUsage, int maxUsage)
{
NextList = nextList;
this.minUsage = minUsage;
this.maxUsage = maxUsage;
}
/// <summary>
/// Add the ByteBufChunk to this ByteBufChunkList linked-list based on ByteBufChunk's usage.
/// So it will be moved to the right ByteBufChunkList that has the correct minUsage/maxUsage.
/// </summary>
/// <param name="chunk">The ByteBufChunk to be added</param>
public void Add(ByteBufChunk chunk)
{
if (chunk.Usage >= maxUsage)
{
NextList.Add(chunk);
return;
}
AddInternal(chunk);
}
/// <summary>
/// Allocates a ByteBuf from this ByteBufChunkList if it is not empty.
/// </summary>
/// <param name="byteBuf">The allocated ByteBuf</param>
/// <returns>true, if the ByteBuf be allocated; otherwise, false.</returns>
public bool Allocate(out ByteBuf byteBuf)
{
if (head == null)
{
// This ByteBufChunkList is empty
byteBuf = default(ByteBuf);
return false;
}
for (var cur = head; ;)
{
if (!cur.Allocate(out byteBuf))
{
cur = cur.Next;
if (cur == null)
{
return false;
}
}
else
{
if (cur.Usage < maxUsage) return true;
Remove(cur);
NextList.Add(cur);
return true;
}
}
}
/// <summary>
/// Releases the segment back to its ByteBufChunk.
/// </summary>
/// <param name="chunk">The ByteBufChunk that contains the ByteBuf</param>
/// <param name="byteBuf">The ByteBuf to be released.</param>
/// <returns>
/// true, if the ByteBuf be released and NOT need to destroy the
/// ByteBufChunk (its usage is 0); otherwise, false.
/// </returns>
public bool Free(ByteBufChunk chunk, ByteBuf byteBuf)
{
chunk.Free(byteBuf);
if (chunk.Usage >= minUsage) return true;
Remove(chunk);
// Move the ByteBufChunk down the ByteBufChunkList linked-list.
return MoveInternal(chunk);
}
private bool Move(ByteBufChunk chunk)
{
if (chunk.Usage < minUsage)
{
// Move the ByteBufChunk down the ByteBufChunkList linked-list
return MoveInternal(chunk);
}
// ByteBufChunk fits into this ByteBufChunkList, adding it here.
AddInternal(chunk);
return true;
}
/// <summary>
/// Adds the ByteBufChunk to this ByteBufChunkList
/// </summary>
private void AddInternal(ByteBufChunk chunk)
{
chunk.Parent = this;
if (head == null)
{
head = chunk;
chunk.Prev = null;
chunk.Next = null;
}
else
{
chunk.Prev = null;
chunk.Next = head;
head.Prev = chunk;
head = chunk;
}
}
/// <summary>
/// Moves the ByteBufChunk down the ByteBufChunkList linked-list so it will end up in the right
/// ByteBufChunkList that has the correct minUsage/maxUsage in respect to ByteBufChunk.Usage.
/// </summary>
private bool MoveInternal(ByteBufChunk chunk)
{
if (PrevList == null)
{
// If there is no previous ByteBufChunkList so return false which result in
// having the ByteBufChunk destroyed and memory associated with the ByteBufChunk
// will be released.
Debug.Assert(chunk.Usage == 0);
return false;
}
return PrevList.Move(chunk);
}
/// <summary>
/// Remove the ByteBufChunk from this ByteBufChunkList
/// </summary>
private void Remove(ByteBufChunk chunk)
{
Debug.Assert(chunk != null);
if (chunk == head)
{
head = chunk.Next;
if (head != null)
{
head.Prev = null;
}
}
else
{
var next = chunk.Next;
chunk.Prev.Next = next;
if (next != null)
{
next.Prev = chunk.Prev;
}
}
}
/// <summary>
/// Returns a readable string for this ByteBufChunkList
/// </summary>
public override string ToString()
{
if (head == null)
{
return "none";
}
var buf = new StringBuilder();
for (var cur = head; ;)
{
buf.Append(cur);
cur = cur.Next;
if (cur == null)
{
break;
}
buf.Append(Environment.NewLine);
}
return buf.ToString();
}
}
}

Просмотреть файл

@ -0,0 +1,237 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using Microsoft.Spark.CSharp.Services;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// ByteBufPool class is used to manage the ByteBuf pool that allocate and free pooled memory buffer.
/// We borrows some ideas from Netty buffer memory management.
/// </summary>
internal sealed class ByteBufPool
{
/// <summary>
/// The chunk size is calculated with given segment size and chunk order.
/// The MaxChunkSize ensure ByteBuf chunk Size (Int32.MaxValue) does not overflow.
/// </summary>
private const int MaxChunkSize = (int)((int.MaxValue + 1L) / 2);
private static readonly Lazy<ByteBufPool> DefaultPool =
new Lazy<ByteBufPool>(() => new ByteBufPool(DefaultSegmentSize, DefaultChunkOrder, false));
private static readonly Lazy<ByteBufPool> DefaultUnsafePool =
new Lazy<ByteBufPool>(() => new ByteBufPool(DefaultSegmentSize, DefaultChunkOrder, true));
private readonly ILoggerService logger = LoggerServiceFactory.GetLogger(typeof(RioSocketWrapper));
private readonly bool isUnsafe;
private readonly ByteBufChunkList qInit, q000, q025, q050, q075, q100;
/// <summary>
/// The default segment size to delimit a byte array chunk.
/// </summary>
public const int DefaultSegmentSize = 131072; // 65536; // 64K
/// <summary>
/// The default chunk order used to calculate chunk size and ensure it is a multiple of a segment size.
/// </summary>
public const int DefaultChunkOrder = 7; // 65536 << 8 = 16 MBytes per chunk
/// <summary>
/// Initializes a new ByteBufPool instance with a given segment size and chunk order.
/// If isUnsafe is true, all memory will be allocated from process's heap by using
/// native HeapAlloc API.
/// </summary>
/// <param name="segmentSize">The size of a segment</param>
/// <param name="chunkOrder">Used to caculate chunk size and ensures it is a multiple of a segment size</param>
/// <param name="isUnsafe">Indicates whether allocates memory from process's heap</param>
public ByteBufPool(int segmentSize, int chunkOrder, bool isUnsafe)
{
SegmentSize = segmentSize;
ChunkSize = ValidateAndCalculateChunkSize(segmentSize, chunkOrder);
this.isUnsafe = isUnsafe;
q100 = new ByteBufChunkList(null, 100, int.MaxValue); // No maxUsage for this list. All 100% usage of chunks will be in here.
q075 = new ByteBufChunkList(q100, 75, 100);
q050 = new ByteBufChunkList(q075, 50, 100);
q025 = new ByteBufChunkList(q050, 25, 75);
q000 = new ByteBufChunkList(q025, 1, 50);
qInit = new ByteBufChunkList(q000, int.MinValue, 25); // No minUsage for this list. All new chunks will be inserted to here.
q100.PrevList = q075;
q075.PrevList = q050;
q050.PrevList = q025;
q025.PrevList = q000;
q000.PrevList = null;
qInit.PrevList = qInit;
}
/// <summary>
/// Gets the default byte buffer pool instance for managed memory.
/// </summary>
/// <remarks>
/// You should only be using this instance to allocate/deallocate the managed memory arena
/// if you don't want to manage memory arena on your own.
/// </remarks>
public static ByteBufPool Default
{
get { return DefaultPool.Value; }
}
/// <summary>
/// Gets the default byte arena instance for unmanaged memory.
/// </summary>
/// <remarks>
/// You should only be using this instance to allocate/deallocate the unmanaged memory arena
/// if you don't want to manage memory arena on your own.
/// </remarks>
public static ByteBufPool UnsafeDefault
{
get { return DefaultUnsafePool.Value; }
}
/// <summary>
/// Returns the size of a ByteBuf in this ByteBufPool
/// </summary>
public int SegmentSize { get; private set; }
/// <summary>
/// Returns the size of a ByteChunk in this ByteBufPool
/// </summary>
public int ChunkSize { get; private set; }
/// <summary>
/// Allocates a ByteBuf from this ByteBufPool to use.
/// </summary>
/// <returns>A ByteBuf contained in this ByteBufPool</returns>
[MethodImpl(MethodImplOptions.Synchronized)]
public ByteBuf Allocate()
{
ByteBuf byteBuf;
if (q050.Allocate(out byteBuf) || q025.Allocate(out byteBuf) ||
q000.Allocate(out byteBuf) || qInit.Allocate(out byteBuf) ||
q075.Allocate(out byteBuf))
{
return byteBuf;
}
// Add a new chunk and allocate a segment from it.
var chunk = ByteBufChunk.NewChunk(this, SegmentSize, ChunkSize, isUnsafe);
if (!chunk.Allocate(out byteBuf))
{
logger.LogError("Failed to allocate a ByteBuf from a new ByteBufChunk. {0}", chunk);
return null;
}
qInit.Add(chunk);
return byteBuf;
}
/// <summary>
/// Deallocates a ByteBuf back to this ByteBufPool.
/// </summary>
/// <param name="byteBuf">The ByteBuf to be release.</param>
public void Free(ByteBuf byteBuf)
{
if (byteBuf.ByteBufChunk == null || byteBuf.Capacity == 0 || byteBuf.ByteBufChunk.Size < byteBuf.Offset + byteBuf.Capacity)
{
throw new Exception("Attempt to free invalid byteBuf");
}
if (byteBuf.Capacity != SegmentSize)
{
throw new ArgumentException("Segment was not from the same byte arena", "byteBuf");
}
bool mustDestroyChunk;
var chunk = byteBuf.ByteBufChunk;
lock (this)
{
mustDestroyChunk = !chunk.Parent.Free(chunk, byteBuf);
}
if (!mustDestroyChunk) return;
// Destroy chunk not need to be called while holding the synchronized lock.
chunk.Parent = null;
chunk.Dispose();
}
/// <summary>
/// Gets a readable string for this ByteBufPool
/// </summary>
[MethodImpl(MethodImplOptions.Synchronized)]
public override string ToString()
{
StringBuilder buf = new StringBuilder()
.Append("Chunk(s) at 0~25%:")
.Append(Environment.NewLine)
.Append(qInit)
.Append(Environment.NewLine)
.Append("Chunk(s) at 0~50%:")
.Append(Environment.NewLine)
.Append(q000)
.Append(Environment.NewLine)
.Append("Chunk(s) at 25~75%:")
.Append(Environment.NewLine)
.Append(q025)
.Append(Environment.NewLine)
.Append("Chunk(s) at 50~100%:")
.Append(Environment.NewLine)
.Append(q050)
.Append(Environment.NewLine)
.Append("Chunk(s) at 75~100%:")
.Append(Environment.NewLine)
.Append(q075)
.Append(Environment.NewLine)
.Append("Chunk(s) at 100%:")
.Append(Environment.NewLine)
.Append(q100)
.Append(Environment.NewLine);
return buf.ToString();
}
/// <summary>
/// Returns the chunk numbers in each queue.
/// </summary>
[MethodImpl(MethodImplOptions.Synchronized)]
public int[] GetUsages()
{
var qUsage = new int[6];
var qIndex = 0;
for (var q = qInit; q != null; q = q.NextList, qIndex++)
{
int count = 0;
for (var cur = q.head; cur != null; cur = cur.Next, count++) {}
qUsage[qIndex] = count;
}
return qUsage;
}
private static int ValidateAndCalculateChunkSize(int segmentSize, int chunkOrder)
{
if (chunkOrder > 14)
{
throw new ArgumentException(string.Format(
"chunkOrder: {0} (expected: 0-14)", chunkOrder));
}
// Ensure the resulting chunkSize does not overflow.
var chunkSize = segmentSize;
for (var i = chunkOrder; i > 0; i--)
{
if (chunkSize > MaxChunkSize >> 1)
{
throw new ArgumentException(string.Format(
"segmentSize ({0}) << chunkOrder ({1}) must not exceed {2}", segmentSize, chunkOrder, MaxChunkSize));
}
chunkSize <<= 1;
}
return chunkSize;
}
}
}

Просмотреть файл

@ -11,7 +11,7 @@ namespace Microsoft.Spark.CSharp.Network
/// <summary>
/// A simple wrapper of System.Net.Sockets.Socket class.
/// </summary>
public class DefaultSocketWrapper : ISocketWrapper
internal class DefaultSocketWrapper : ISocketWrapper
{
private readonly Socket innerSocket;
@ -83,11 +83,33 @@ namespace Microsoft.Spark.CSharp.Network
/// Starts listening for incoming connections requests
/// </summary>
/// <param name="backlog">The maximum length of the pending connections queue. </param>
public void Listen(int backlog = (int)SocketOptionName.MaxConnections)
public void Listen(int backlog = 16)
{
innerSocket.Listen(backlog);
}
/// <summary>
/// Receives network data from this socket, and returns a ByteBuf that contains the received data.
///
/// The DefaultSocketWrapper does not support this function.
/// </summary>
/// <returns>A ByteBuf object that contains received data.</returns>
public ByteBuf Receive()
{
throw new NotImplementedException();
}
/// <summary>
/// Sends data to this socket with a ByteBuf object that contains data to be sent.
///
/// The DefaultSocketWrapper does not support this function.
/// </summary>
/// <param name="data">A ByteBuf object that contains data to be sent</param>
public void Send(ByteBuf data)
{
throw new NotImplementedException();
}
/// <summary>
/// Disposes the resources used by this instance of the DefaultSocket class.
/// </summary>
@ -116,15 +138,19 @@ namespace Microsoft.Spark.CSharp.Network
Dispose(false);
}
/// <summary>
/// Indicates whether there are data that has been received from the network and is available to be read.
/// </summary>
public bool HasData { get { return innerSocket.Available > 0; } }
/// <summary>
/// Returns the local endpoint.
/// </summary>
public EndPoint LocalEndPoint
{
get
{
return innerSocket.LocalEndPoint;
}
}
public EndPoint LocalEndPoint { get { return innerSocket.LocalEndPoint; } }
/// <summary>
/// Returns the remote endpoint if it has one.
/// </summary>
public EndPoint RemoteEndPoint { get { return innerSocket.RemoteEndPoint; } }
}
}

Просмотреть файл

@ -4,7 +4,6 @@
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
namespace Microsoft.Spark.CSharp.Network
{
@ -12,7 +11,7 @@ namespace Microsoft.Spark.CSharp.Network
/// ISocketWrapper interface defines the common methods to operate a socket (traditional socket or
/// Windows Registered IO socket)
/// </summary>
public interface ISocketWrapper : IDisposable
internal interface ISocketWrapper : IDisposable
{
/// <summary>
/// Accepts a incoming connection request.
@ -42,11 +41,33 @@ namespace Microsoft.Spark.CSharp.Network
/// Starts listening for incoming connections requests
/// </summary>
/// <param name="backlog">The maximum length of the pending connections queue. </param>
void Listen(int backlog = (int)SocketOptionName.MaxConnections);
void Listen(int backlog = 16);
/// <summary>
/// Receives network data from this socket, and returns a ByteBuf that contains the received data.
/// </summary>
/// <returns>A ByteBuf object that contains received data.</returns>
ByteBuf Receive();
/// <summary>
/// Sends data to this socket with a ByteBuf object that contains data to be sent.
/// </summary>
/// <param name="data">A ByteBuf object that contains data to be sent</param>
void Send(ByteBuf data);
/// <summary>
/// Indicates whether there are data that has been received from the network and is available to be read.
/// </summary>
bool HasData { get ; }
/// <summary>
/// Returns the local endpoint.
/// </summary>
EndPoint LocalEndPoint { get; }
/// <summary>
/// Returns the remote endpoint
/// </summary>
EndPoint RemoteEndPoint { get; }
}
}

Просмотреть файл

@ -0,0 +1,340 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Concurrent;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Security;
using System.Threading;
using Microsoft.Spark.CSharp.Services;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// RioNative class imports and initializes RIOSock.dll for use with RIO socket APIs.
/// It also provided a simple thread pool that retrieves the results from IO completion port.
/// </summary>
internal class RioNative : IDisposable
{
private const int DefaultResultSize = 32; // Default RIO result size that be used to dequeue RIO results from IOCP
private static readonly Lazy<RioNative> Default = new Lazy<RioNative>(() => new RioNative());
private static readonly ILoggerService Logger = LoggerServiceFactory.GetLogger(typeof(RioNative));
private static bool useThreadPool;
private readonly ConcurrentDictionary<long, RioSocketWrapper> connectedSocks =
new ConcurrentDictionary<long, RioSocketWrapper>();
private volatile bool keepRunning = true;
private bool disposed;
private Thread[] workThreadPool;
private RioNative()
{
Init();
}
/// <summary>
/// Finalizer
/// </summary>
~RioNative()
{
Dispose(false);
}
/// <summary>
/// Release all resources.
/// </summary>
public void Dispose()
{
Dispose(true);
}
internal static int GetWorkThreadNumber()
{
return Default.Value.workThreadPool.Length;
}
/// <summary>
/// Sets whether use thread pool to query RIO socket results,
/// it must be called before calling EnsureRioLoaded()
/// </summary>
internal static void SetUseThreadPool(bool toUseThreadPool)
{
useThreadPool = toUseThreadPool;
}
/// <summary>
/// Gets the connection table that contains all connections.
/// </summary>
internal static ConcurrentDictionary<long, RioSocketWrapper> ConnectionTable
{
get { return Default.Value.connectedSocks; }
}
/// <summary>
/// Ensures that the native dll of RIO socket is loaded and initialized.
/// </summary>
internal static void EnsureRioLoaded()
{
if (Default.Value == null)
{
throw new Exception("Failed to load RIOSOCK.dll and initialize it.");
}
if (Default.Value.disposed)
{
Default.Value.Init();
}
}
/// <summary>
/// Explicitly unload the native dll of RIO socket, and release resources.
/// </summary>
internal static void UnloadRio()
{
if (!Default.IsValueCreated) return;
Default.Value.Dispose(false);
}
private void Dispose(bool disposing)
{
if (disposed) return;
keepRunning = false;
RIOSockUninitialize();
disposed = true;
if (disposing)
{
GC.SuppressFinalize(this);
}
Logger.LogDebug("Disposed RioNative instance.");
}
/// <summary>
/// Initializes RIOSock native library.
/// </summary>
private void Init()
{
// Initializes the RIOSock
var lastError = RIOSockInitialize();
if (lastError < 0)
{
Logger.LogError("RIOSockInitialize() failed with error {0}.", lastError);
Marshal.ThrowExceptionForHR(lastError);
}
// Create a thread pool for RIO socket
var maxThreads = 1;
if (useThreadPool)
{
maxThreads = Environment.ProcessorCount;
}
workThreadPool = new Thread[maxThreads];
for (var i = 0; i < workThreadPool.Length; i++)
{
var worker = new Thread(WorkThreadFunc)
{
Name = "RIOThread " + i,
IsBackground = true
};
workThreadPool[i] = worker;
worker.Start();
}
// if everything succeeds, post a Notify to catch the first set of IO
var registered = RegisterRIONotify();
if (!registered)
{
// Failed to post a NOTIFY.
var socketException = new SocketException();
Logger.LogError("RegisterRIONotify() failed with error {0}.", socketException.ErrorCode);
// Stop threads and clean up resources.
keepRunning = false;
RIOSockUninitialize();
throw socketException;
}
disposed = false;
}
private unsafe void WorkThreadFunc()
{
RioResult* results = stackalloc RioResult[DefaultResultSize];
while (keepRunning)
{
if (!GetRIOCompletionStatus())
{
var socketException = new SocketException();
Logger.LogError("GetRIOCompletionStatus() with error {0}. Error Message: {1}",
socketException.ErrorCode, socketException.Message);
//this one is not normal error. might need to debug this issue.
continue;
}
var resultCount = DequeueRIOResults((IntPtr)results, DefaultResultSize);
if (resultCount == 0 || resultCount == 0xFFFFFFFF /*RIO_CORRUPT_CQ*/)
{
// We were notified there were completions, but we can't dequeue any IO
// Something has gone horribly wrong - likely our CQ is corrupt.
Logger.LogError(
"DequeueRIOResults() returned [{0}] : expected to have dequeued IO after being signaled",
resultCount);
continue;
}
for (uint i = 0; i < resultCount; ++i)
{
var result = results[i];
RioSocketWrapper socket;
if (connectedSocks.TryGetValue(result.ConnectionId, out socket))
{
socket.IoCompleted(result.RequestId, result.Status, result.BytesTransferred);
}
else
{
if (result.Status == 0 && result.BytesTransferred == 0)
{
// Already normally removed from SocketTable.
break;
}
if (result.Status == (int)SocketError.ConnectionAborted)
{
Logger.LogDebug(
"The correlated socket [{0}] already disposed and removed from SocketTable.",
result.ConnectionId);
break;
}
var socketException = new SocketException(result.Status);
Logger.LogError("Failed to lookup socket [{0}] from SocketTable with status [{1}] and BytesTransferred [{2}] - Error Message: {3}.",
result.ConnectionId, result.Status, result.BytesTransferred, socketException.Message);
}
}
}
}
#region PInvoke
private const string RioSockDll = "RIOSock.dll";
private const string Ws2Dll = "WS2_32.dll";
//
// Private functions
//
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
private static extern int RIOSockInitialize();
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
private static extern void RIOSockUninitialize();
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
private static extern bool RegisterRIONotify();
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
private static extern bool GetRIOCompletionStatus();
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
private static extern uint DequeueRIOResults([Out] IntPtr rioResults, [In] uint rioResultSize);
//
// Internal functions
//
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern IntPtr CreateRIOSocket([In, Out] IntPtr localAddr, [In, Out] ref int addrLen);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern IntPtr CreateRIORequestQueue(
[In] IntPtr socket,
[In] uint maxOutstandingReceive,
[In] uint maxOutstandingSend,
[In] long socketCorrelation);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern unsafe bool PostRIOReceive(
[In] IntPtr socketQueue,
[In] RioBuf* pData,
[In] uint dataBufferCount,
[In] uint flags,
[In] long requestCorrelation);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern unsafe bool PostRIOSend(
[In] IntPtr socketQueue,
[In] RioBuf* pData,
[In] uint dataBufferCount,
[In] uint flags,
[In] long requestCorrelation);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern bool AllocateRIOCompletion([In] uint numCompletions);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern bool ReleaseRIOCompletion([In] uint numCompletion);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern bool ResizeRIORequestQueue(
[In] IntPtr rq,
[In] uint maxOutstandingReceive,
[In] uint maxOutstandingSend);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern IntPtr RegisterRIOBuffer([In] IntPtr dataBuffer, [In] uint dataLength);
[DllImport(RioSockDll, CharSet = CharSet.Unicode, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern void DeregisterRIOBuffer([In] IntPtr bufferId);
[DllImport(Ws2Dll, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern IntPtr accept([In] IntPtr s, [In, Out] IntPtr addr, [In, Out] ref int addrlen);
[DllImport(Ws2Dll, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern int connect([In] IntPtr s, [In] byte[] addr, [In] int addrlen);
[DllImport(Ws2Dll, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern int closesocket([In] IntPtr s);
[DllImport(Ws2Dll, SetLastError = true)]
[SuppressUnmanagedCodeSecurity]
internal static extern int listen([In] IntPtr s, [In] int backlog);
#endregion
}
/// <summary>
/// The RioResult structure contains data used to indicate request completion results used with RIO socket
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct RioResult
{
public int Status;
public uint BytesTransferred;
public long ConnectionId;
public long RequestId;
}
}

Просмотреть файл

@ -0,0 +1,651 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Microsoft.Spark.CSharp.Services;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// RioSocketWrapper class is a wrapper of a socket that use Windows RIO socket with IO
/// completion ports to implement socket operations.
/// </summary>
internal class RioSocketWrapper : ISocketWrapper
{
private const int Ipv4AddressSize = 16; // Default buffer size for IP v4 address
private const int MaxDataCacheSize = 4096; // The max size of data caching in the queue.
private static readonly int InitialCqRoom = 2 * rioRqGrowthFactor; // initial room allocated from completion queue.
internal static int rioRqGrowthFactor = 2; // Growth factor used to grow the RIO request queue.
private readonly ILoggerService logger = LoggerServiceFactory.GetLogger(typeof(RioSocketWrapper));
private readonly BlockingCollection<ByteBuf> receivedDataQueue =
new BlockingCollection<ByteBuf>(new ConcurrentQueue<ByteBuf>(), MaxDataCacheSize);
private readonly ConcurrentDictionary<long, RequestContext> requestContexts =
new ConcurrentDictionary<long, RequestContext>();
private readonly BlockingCollection<int> sendStatusQueue =
new BlockingCollection<int>(new ConcurrentQueue<int>(), MaxDataCacheSize);
private long connectionId;
private bool isCleanedUp;
private bool isConnected;
private bool isListening;
private IntPtr rioRqHandle;
private uint rqReservedSize = 2 * (uint)rioRqGrowthFactor;
private uint rqUsed;
/// <summary>
/// Default ctor that creates a new instance of RioSocketWrapper class.
/// The instance binds to loop-back address with port 0.
/// </summary>
public unsafe RioSocketWrapper()
{
RioNative.EnsureRioLoaded();
// Creates a socket handle by calling native method.
var sockaddlen = Ipv4AddressSize;
var addrbuf = stackalloc byte[(sockaddlen / IntPtr.Size + 2) * IntPtr.Size];
var sockaddr = (IntPtr)addrbuf;
SockHandle = RioNative.CreateRIOSocket(sockaddr, ref sockaddlen);
if (SockHandle == new IntPtr(-1))
{
// if the native call fails we'll throw a SocketException
var socketException = new SocketException();
logger.LogError("Native CreateRIOSocket() failed with error {0}", socketException.ErrorCode);
throw socketException;
}
// Generate the local IP endpoint from the returned raw socket address data.
LocalEndPoint = CreateIpEndPoint(sockaddr);
}
/// <summary>
/// Initializes a RioSocketWrapper instance for an accepted socket.
/// </summary>
private RioSocketWrapper(IntPtr socketHandle, EndPoint localEp, EndPoint remoteEp)
{
SockHandle = socketHandle;
LocalEndPoint = localEp;
RemoteEndPoint = remoteEp;
CreateRequestQueue();
isConnected = true;
// Post a receive operation from the connected RIO socket.
DoReceive();
}
/// <summary>
/// Finalizer
/// </summary>
~RioSocketWrapper()
{
Dispose(false);
}
/// <summary>
/// Indicates whether there are data that has been received from the network and is available to be read.
/// </summary>
public bool HasData { get { return receivedDataQueue.Count > 0; } }
/// <summary>
/// Returns the local endpoint.
/// </summary>
public EndPoint LocalEndPoint { get; private set; }
/// <summary>
/// Returns the remote endpoint
/// </summary>
public EndPoint RemoteEndPoint { get; private set; }
/// <summary>
/// Returns the handle of native socket.
/// </summary>
internal IntPtr SockHandle { get; private set; }
/// <summary>
/// Accepts a incoming connection request.
/// </summary>
/// <returns>A ISocket instance used to send and receive data</returns>
public unsafe ISocketWrapper Accept()
{
EnsureAccessible();
if (!isListening)
{
throw new InvalidOperationException("You must call the Listen method before performing this operation.");
}
var sockaddrlen = Ipv4AddressSize;
var addrbuf = stackalloc byte[(sockaddrlen / IntPtr.Size + 2) * IntPtr.Size]; //sizeof DWORD
var sockaddr = (IntPtr)addrbuf;
var acceptedSockHandle = RioNative.accept(SockHandle, sockaddr, ref sockaddrlen);
if (acceptedSockHandle == new IntPtr(-1))
{
// if the native call fails we'll throw a SocketException
var socketException = new SocketException();
logger.LogError("Native accept() failed with error {0}", socketException.NativeErrorCode);
throw socketException;
}
var remoteEp = CreateIpEndPoint(sockaddr);
var socket = new RioSocketWrapper(acceptedSockHandle, LocalEndPoint, remoteEp);
logger.LogDebug("Accepted connection from {0} to {1}", socket.RemoteEndPoint, socket.LocalEndPoint);
return socket;
}
/// <summary>
/// Close the ISocket connections and releases all associated resources.
/// </summary>
public void Close()
{
Dispose(true);
}
/// <summary>
/// Establishes a connection to a remote host that is specified by an IP address and a port number
/// </summary>
/// <param name="remoteaddr">The IP address of the remote host</param>
/// <param name="port">The port number of the remote host</param>
public void Connect(IPAddress remoteaddr, int port)
{
EnsureAccessible();
var remoteEp = new IPEndPoint(remoteaddr, port);
int sockaddrlen;
var sockaddr = GetNativeSocketAddress(remoteEp, out sockaddrlen);
var errorCode = RioNative.connect(SockHandle, sockaddr, sockaddrlen);
if (errorCode != 0)
{
// if the native call fails we'll throw a SocketException
var socketException = new SocketException();
logger.LogError("Native connect() failed with error {0}", socketException.ErrorCode);
throw socketException;
}
CreateRequestQueue();
isConnected = true;
RemoteEndPoint = remoteEp;
// Post a receive operation from the connected RIO socket.
DoReceive();
}
/// <summary>
/// Releases all resources used by the current instance of the RioSocketWrapper class.
/// </summary>
public void Dispose()
{
Dispose(true);
}
/// <summary>
/// Returns a stream used to send and receive data.
/// </summary>
/// <returns>The underlying Stream instance that be used to send and receive data</returns>
public Stream GetStream()
{
EnsureAccessible();
if (!isConnected)
{
throw new InvalidOperationException("The operation is not allowed on non-connected sockets.");
}
return new SocketStream(this);
}
/// <summary>
/// Starts listening for incoming connections requests
/// </summary>
/// <param name="backlog">The maximum length of the pending connections queue. </param>
public void Listen(int backlog = 16)
{
EnsureAccessible();
if (isListening) return;
var errorCode = RioNative.listen(SockHandle, backlog);
if (errorCode != 0)
{
var socketException = new SocketException();
logger.LogError("Native listen() failed with error {0}", socketException.ErrorCode);
throw socketException;
}
isListening = true;
}
/// <summary>
/// Receives network data from this socket, and returns a ByteBuf that contains the received data.
/// </summary>
/// <returns>A ByteBuf object that contains received data.</returns>
public ByteBuf Receive()
{
EnsureAccessible();
if (!isConnected)
{
throw new InvalidOperationException("The operation is not allowed on non-connected sockets.");
}
var data = receivedDataQueue.Take();
if (data.Status == (int) SocketError.Success) return data;
// Throw exception if there is an error.
data.Release();
Dispose(true);
SocketException sockException = new SocketException(data.Status);
throw sockException;
}
/// <summary>
/// Sends data to this socket with a ByteBuf object that contains data to be sent.
/// </summary>
/// <param name="data">A ByteBuf object that contains data to be sent</param>
public void Send(ByteBuf data)
{
EnsureAccessible();
if (!isConnected)
{
throw new InvalidOperationException("The operation is not allowed on non-connected sockets.");
}
if (!data.IsReadable())
{
throw new ArgumentException("The parameter {0} must contain one or more elements.", "data");
}
var context = new RequestContext(SocketOperation.Send, data);
DoSend(GenerateUniqueKey(), context);
var status = sendStatusQueue.Take();
if (status == (int)SocketError.Success) return;
// throw a SocketException if theres is an error.
Dispose(true);
var socketException = new SocketException(status);
throw socketException;
}
/// <summary>
/// IO completion callback
/// </summary>
internal void IoCompleted(long requestId, int status, uint byteTransferred)
{
if (isCleanedUp) return;
ReleaseRequest();
RequestContext context;
if (!requestContexts.TryRemove(requestId, out context)) return;
switch (context.Operation)
{
case SocketOperation.Receive:
ProcessReceive(context, status, byteTransferred);
return;
case SocketOperation.Send:
ProcessSend(requestId, context, status, byteTransferred);
return;
default:
throw new InvalidOperationException("Invalid socket operation - not a receive / send operation.");
}
}
/// <summary>
/// Allocates a room form request queue for this next IO.
/// </summary>
/// <returns></returns>
[MethodImpl(MethodImplOptions.Synchronized)]
private int AllocateRequest()
{
var newRqUsed = rqUsed + 1;
if (newRqUsed > rqReservedSize)
{
var newRqReservedSize = rqReservedSize + rioRqGrowthFactor;
if (!RioNative.AllocateRIOCompletion((uint)rioRqGrowthFactor))
{
var errorCode = Marshal.GetLastWin32Error();
logger.LogError(
"Failed to allocate completions to this socket. AllocateRIOCompletion() returns error {0}",
errorCode);
return errorCode;
}
// Resize the RQ.
if (!RioNative.ResizeRIORequestQueue(rioRqHandle, (uint)newRqReservedSize >> 1, (uint)newRqReservedSize >> 1))
{
var errorCode = Marshal.GetLastWin32Error();
logger.LogError("Failed to resize the request queue. ResizeRIORequestQueue() returns error {0}",
errorCode);
RioNative.ReleaseRIOCompletion((uint)rioRqGrowthFactor);
return errorCode;
}
// since it succeeded, update reserved with the new size
rqReservedSize = (uint)newRqReservedSize;
}
// everything succeeded - update rqUsed with the new slots being used for this next IO
rqUsed = newRqUsed;
return 0;
}
/// <summary>
/// Creates a request queue for this socket operation
/// </summary>
private void CreateRequestQueue()
{
// Allocate completion from completion queue
if (!RioNative.AllocateRIOCompletion((uint)InitialCqRoom))
{
var socketException = new SocketException();
logger.LogError(
"AllocateRIOCompletion() failed to allocate completions to this socket. Returns error {0}",
socketException.ErrorCode);
throw socketException;
}
// Create the RQ for this socket.
connectionId = GenerateUniqueKey();
while (!RioNative.ConnectionTable.TryAdd(connectionId, this))
{
connectionId = GenerateUniqueKey();
}
rioRqHandle = RioNative.CreateRIORequestQueue(SockHandle, rqReservedSize >> 1, rqReservedSize >> 1, connectionId);
if (rioRqHandle != IntPtr.Zero) return;
// Error Handling
var sockException = new SocketException();
RioNative.ReleaseRIOCompletion((uint)InitialCqRoom);
RioSocketWrapper sock;
RioNative.ConnectionTable.TryRemove(connectionId, out sock);
logger.LogError("CreateRIORequestQueue() returns error {0}", sockException.ErrorCode);
throw sockException;
}
private void Dispose(bool disposing)
{
// Mark this as disposed before changing anything else.
var cleanedUp = isCleanedUp;
isCleanedUp = true;
if (!cleanedUp && disposing)
{
try
{
// Release room back to Completion Queue
RioNative.ReleaseRIOCompletion((uint)InitialCqRoom);
// Remove this socket from the connected socket table.
RioSocketWrapper socket;
RioNative.ConnectionTable.TryRemove(connectionId, out socket);
}
catch (Exception)
{
logger.LogDebug("RioNative default instance already disposed.");
}
// Remove all pending socket operations
if (!requestContexts.IsEmpty)
{
foreach (var keyValuePair in requestContexts)
{
// Release the data buffer of the pending operation
keyValuePair.Value.Data.Release();
}
requestContexts.Clear();
}
// Remove received Data from the queue, and release buffer back to byte pool.
while (receivedDataQueue.Count > 0)
{
ByteBuf data;
receivedDataQueue.TryTake(out data);
data.Release();
}
// Close the socket handle. No need to release Request Queue handle that
// will be gone once the socket handle be closed.
if (SockHandle != IntPtr.Zero)
{
RioNative.closesocket(SockHandle);
SockHandle = IntPtr.Zero;
}
GC.SuppressFinalize(this);
}
isConnected = false;
isListening = false;
}
private void EnsureAccessible()
{
if (isCleanedUp)
{
throw new ObjectDisposedException(GetType().FullName);
}
}
/// <summary>
/// Posts a receive operation to this socket
/// </summary>
private unsafe void DoReceive()
{
// Make a room from Request Queue of this socket for operation.
var errorStatus = AllocateRequest();
if (errorStatus != 0)
{
logger.LogError("Cannot post receive operation due to no room in Request Queue.");
receivedDataQueue.Add(ByteBuf.NewErrorStatusByteBuf(errorStatus));
return;
}
// Allocate buffer to receive incoming network data.
var dataBuffer = ByteBufPool.UnsafeDefault.Allocate();
if (dataBuffer == null)
{
logger.LogError("Failed to allocate ByteBuf at DoReceive().");
receivedDataQueue.Add(ByteBuf.NewErrorStatusByteBuf((int)SocketError.NoBufferSpaceAvailable));
return;
}
var context = new RequestContext(SocketOperation.Receive, dataBuffer);
var recvId = GenerateUniqueKey();
// Add the operation context to request table for completion callback.
while (!requestContexts.TryAdd(recvId, context))
{
// Generate another key, if the key is duplicated.
recvId = GenerateUniqueKey();
}
// Post a receive operation via native method.
var rioBuf = dataBuffer.GetInputRioBuf();
if (RioNative.PostRIOReceive(rioRqHandle, &rioBuf, 1, 0, recvId)) return;
requestContexts.TryRemove(recvId, out context);
context.Data.Release();
if (isCleanedUp)
{
logger.LogDebug("Socket is already disposed. DoReceive() do nothing.");
receivedDataQueue.Add(ByteBuf.NewErrorStatusByteBuf((int)SocketError.NetworkDown));
return;
}
// Log exception, if post receive operation failed.
var socketException = new SocketException();
logger.LogError("Failed to call DoReceive() with error code [{0}], error message: {1}",
socketException.ErrorCode, socketException.Message);
context.Data.Status = socketException.ErrorCode;
receivedDataQueue.Add(context.Data);
}
/// <summary>
/// This method is invoked by the IoCompleted method to process the receive completion.
/// </summary>
private void ProcessReceive(RequestContext context, int status, uint byteTransferred)
{
context.Data.Status = status;
var data = context.Data;
data.WriterIndex += (int)byteTransferred;
receivedDataQueue.Add(data);
if (status != (int) SocketError.Success)
{
logger.LogError("Socket receive operation failed with error {0}", status);
context.Data.Release();
return;
}
// Posts another receive operation
DoReceive();
}
/// <summary>
/// Posts a send operation to this socket.
/// </summary>
private unsafe void DoSend(long sendId, RequestContext context)
{
// Make a room from Request Queue of this socket for operation.
var errorStatus = AllocateRequest();
if ( errorStatus != 0)
{
logger.LogError("Cannot post send operation due to no room in Request Queue.");
sendStatusQueue.Add(errorStatus);
return;
}
// Add the operation context to request table for completion callback.
while (!requestContexts.TryAdd(sendId, context))
{
// Generate another key, if the key is duplicated.
sendId = GenerateUniqueKey();
}
// Post a send operation via native method.
var rioBuf = context.Data.GetOutputRioBuf();
if (RioNative.PostRIOSend(rioRqHandle, &rioBuf, 1, 0, sendId)) return;
requestContexts.TryRemove(sendId, out context);
context.Data.Release();
if (isCleanedUp)
{
logger.LogDebug("Socket is already disposed. PostSend() do nothing.");
receivedDataQueue.Add(ByteBuf.NewErrorStatusByteBuf((int)SocketError.NetworkDown));
return;
}
// Log exception, if post send operation failed.
var socketException = new SocketException();
logger.LogError("Failed to call PostRIOSend() with error code [{0}]. Error message: {1}",
socketException.ErrorCode, socketException.Message);
sendStatusQueue.Add(socketException.ErrorCode);
}
/// <summary>
/// This method is invoked by the IoCompleted method to process the send completion.
/// </summary>
private void ProcessSend(long requestId, RequestContext context, int status, uint byteTransferred)
{
sendStatusQueue.Add(status);
if (status != (int)SocketError.Success)
{
logger.LogError("Socket send operation failed with error {0}", status);
context.Data.Release();
return;
}
var data = context.Data;
data.ReaderIndex += (int)byteTransferred;
if (data.IsReadable())
{
// If some of the bytes in the message have NOT been sent,
// then we need to post another send operation.
context.Data = data;
DoSend(requestId, context);
}
else
{
// All the bytes in the data have been sent,
// release the buffer back to pool.
data.Release();
}
}
/// <summary>
/// Release room back to request queue.
/// </summary>
[MethodImpl(MethodImplOptions.Synchronized)]
private void ReleaseRequest()
{
rqUsed -= 1;
}
private static unsafe IPEndPoint CreateIpEndPoint(IntPtr addr)
{
var addrBuf = (byte*)addr;
var address = ((addrBuf[4] & 0x000000FF) |
(addrBuf[5] << 8 & 0x0000FF00) |
(addrBuf[6] << 16 & 0x00FF0000) |
(addrBuf[7] << 24)) & 0x00000000FFFFFFFF;
var ipAddr = new IPAddress(address);
var port = (addrBuf[2] << 8 & 0xFF00) | addrBuf[3];
return new IPEndPoint(ipAddr, port);
}
/// <summary>
/// Generates a unique key from a GUID.
/// </summary>
private static long GenerateUniqueKey()
{
Debug.Assert(IntPtr.Size == 8); // For x64 bits.
var buffer = Guid.NewGuid().ToByteArray();
return BitConverter.ToInt64(buffer, 0);
}
private static byte[] GetNativeSocketAddress(IPEndPoint ipEp, out int sockaddrLen)
{
var sockaddr = ipEp.Serialize();
sockaddrLen = sockaddr.Size;
var addrbuf = new byte[(sockaddrLen / IntPtr.Size + 2) * IntPtr.Size]; //sizeof DWORD
// Address Family serialization
addrbuf[0] = sockaddr[0];
addrbuf[1] = sockaddr[1];
// Port serialization
addrbuf[2] = sockaddr[2];
addrbuf[3] = sockaddr[3];
// IPv4 Address serialization
addrbuf[4] = sockaddr[4];
addrbuf[5] = sockaddr[5];
addrbuf[6] = sockaddr[6];
addrbuf[7] = sockaddr[7];
return addrbuf;
}
}
internal class RequestContext
{
public RequestContext(SocketOperation operation, ByteBuf data)
{
Operation = operation;
Data = data;
}
public SocketOperation Operation { get; private set; }
public ByteBuf Data { get; set; }
}
internal enum SocketOperation
{
None = 0,
Receive,
Send
}
}

Просмотреть файл

@ -0,0 +1,462 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Concurrent;
using System.IO;
using System.Net;
using System.Net.Sockets;
using Microsoft.Spark.CSharp.Services;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// SaeaSocketWrapper class is a wrapper of a socket that use SocketAsyncEventArgs class
/// to implement socket operations.
/// </summary>
internal class SaeaSocketWrapper : ISocketWrapper
{
private const int MaxDataCacheSize = 4096; // The max size of data caching in the queue.
private readonly ILoggerService logger = LoggerServiceFactory.GetLogger(typeof(SaeaSocketWrapper));
private readonly BlockingCollection<ByteBuf> receivedDataQueue =
new BlockingCollection<ByteBuf>(new ConcurrentQueue<ByteBuf>(), MaxDataCacheSize);
private readonly BlockingCollection<int> sendStatusQueue =
new BlockingCollection<int>(new ConcurrentQueue<int>(), MaxDataCacheSize);
private readonly ConcurrentQueue<SocketAsyncEventArgs> poolOfAcceptEvents =
new ConcurrentQueue<SocketAsyncEventArgs>();
private readonly ConcurrentQueue<SocketAsyncEventArgs> poolOfRecvSendEvents =
new ConcurrentQueue<SocketAsyncEventArgs>();
private readonly SaeaSocketWrapper parent;
private Socket innerSocket;
private bool isCleanedUp;
private bool isConnected;
/// <summary>
/// Default ctor that creates a new instance of SaeaSocketWrapper class.
/// The instance binds to loop-back address with port 0.
/// </summary>
public SaeaSocketWrapper()
{
innerSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
var localEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
innerSocket.Bind(localEndPoint);
}
/// <summary>
/// Initializes a SaeaSocketWrapper instance for an accepted socket.
/// </summary>
private SaeaSocketWrapper(SaeaSocketWrapper parent, Socket acceptedSocket)
{
this.parent = parent;
innerSocket = acceptedSocket;
isConnected = true;
}
/// <summary>
/// Finalizer.
/// </summary>
~SaeaSocketWrapper()
{
Dispose(false);
}
/// <summary>
/// Indicates whether there are data that has been received from the network and is available to be read.
/// </summary>
public bool HasData
{
get
{
EnsureAccessible();
return receivedDataQueue.Count > 0;
}
}
/// <summary>
/// Returns the local endpoint.
/// </summary>
public EndPoint LocalEndPoint { get { return innerSocket.LocalEndPoint; } }
/// <summary>
/// Returns the remote endpoint
/// </summary>
public EndPoint RemoteEndPoint { get { return innerSocket.RemoteEndPoint; } }
/// <summary>
/// Accepts a incoming connection request.
/// </summary>
/// <returns>A ISocket instance used to send and receive data</returns>
public ISocketWrapper Accept()
{
var socket = innerSocket.Accept();
var clientSock = new SaeaSocketWrapper(this, socket);
logger.LogDebug("Accepted connection from {0} to {1}", clientSock.RemoteEndPoint, clientSock.LocalEndPoint);
DoReceive(clientSock);
return clientSock;
}
/// <summary>
/// Close the ISocket connections and releases all associated resources.
/// </summary>
public void Close()
{
Dispose(true);
}
/// <summary>
/// Establishes a connection to a remote host that is specified by an IP address and a port number
/// </summary>
/// <param name="remoteaddr">The IP address of the remote host</param>
/// <param name="port">The port number of the remote host</param>
public void Connect(IPAddress remoteaddr, int port)
{
var remoteEndPoint = new IPEndPoint(remoteaddr, port);
innerSocket.Connect(remoteEndPoint);
isConnected = true;
DoReceive(this);
}
/// <summary>
/// Releases all resources used by the current instance of the SaeaSocketWrapper class.
/// </summary>
public void Dispose()
{
Dispose(true);
}
/// <summary>
/// Returns a stream used to send and receive data.
/// </summary>
/// <returns>The underlying Stream instance that be used to send and receive data</returns>
public Stream GetStream()
{
EnsureAccessible();
if (!isConnected)
{
throw new InvalidOperationException("The operation is not allowed on non-connected sockets.");
}
return new SocketStream(this);
}
/// <summary>
/// Starts listening for incoming connections requests
/// </summary>
/// <param name="backlog">The maximum length of the pending connections queue. </param>
public void Listen(int backlog = 16)
{
innerSocket.Listen(backlog);
}
/// <summary>
/// Receives network data from this socket, and returns a ByteBuf that contains the received data.
/// </summary>
/// <returns>A ByteBuf object that contains received data.</returns>
public ByteBuf Receive()
{
EnsureAccessible();
if (!isConnected)
{
throw new InvalidOperationException("The operation is not allowed on non-connected sockets.");
}
var data = receivedDataQueue.Take();
if (data.Status == (int) SocketError.Success) return data;
// throw a SocketException if theres is an error.
data.Release();
Dispose(true);
var socketException = new SocketException(data.Status);
throw socketException;
}
/// <summary>
/// Sends data to this socket with a ByteBuf object that contains data to be sent.
/// </summary>
/// <param name="data">A ByteBuf object that contains data to be sent</param>
public void Send(ByteBuf data)
{
EnsureAccessible();
if (!isConnected)
{
throw new InvalidOperationException("The operation is not allowed on non-connected sockets.");
}
if (!data.IsReadable())
{
throw new ArgumentException("The parameter {0} must contain one or more elements.", "data");
}
var dataToken = new SockDataToken(this, data);
if (parent != null)
{
parent.DoSend(dataToken);
}
else
{
DoSend(dataToken);
}
var status = sendStatusQueue.Take();
if (status == (int) SocketError.Success) return;
// throw a SocketException if theres is an error.
dataToken.Reset();
Dispose(true);
var socketException = new SocketException(status);
throw socketException;
}
/// <summary>
/// Implementation of the Dispose pattern.
/// </summary>
private void Dispose(bool disposing)
{
// Mark this as disposed before changing anything else.
var cleanedUp = isCleanedUp;
isCleanedUp = true;
if (cleanedUp || !disposing) return;
if (innerSocket != null)
{
try
{
// Gracefully shut down socket first.
innerSocket.Shutdown(SocketShutdown.Both);
}
catch (Exception)
{
// Ignore exceptions from Shutdown function
}
finally
{
innerSocket.Dispose();
}
innerSocket = null;
}
SocketAsyncEventArgs eventArgs;
while (poolOfAcceptEvents.Count > 0)
{
poolOfAcceptEvents.TryDequeue(out eventArgs);
eventArgs.Dispose();
}
while (poolOfRecvSendEvents.Count > 0)
{
poolOfRecvSendEvents.TryDequeue(out eventArgs);
if (eventArgs.UserToken != null)
{
var dataToken = (SockDataToken) eventArgs.UserToken;
if (dataToken.ClientSocket != null)
{
dataToken.ClientSocket.Dispose();
}
dataToken.Reset();
}
eventArgs.Dispose();
}
while (receivedDataQueue.Count > 0)
{
ByteBuf data;
receivedDataQueue.TryTake(out data);
data.Release();
}
GC.SuppressFinalize(this);
isCleanedUp = true;
isConnected = false;
}
private void EnsureAccessible()
{
if (isCleanedUp)
{
throw new ObjectDisposedException(GetType().FullName);
}
}
/// <summary>
/// The callback is called whenever a receive or send operation completes. However,
/// it is not be called if the receive/send operation completes synchronously.
/// </summary>
private void IoCompleted(object sender, SocketAsyncEventArgs e)
{
switch (e.LastOperation)
{
case SocketAsyncOperation.Receive:
ProcessReceive(e);
break;
case SocketAsyncOperation.Send:
ProcessSend(e);
break;
default:
throw new ArgumentException(
"The last operation completed on the socket was not a receive or send");
}
}
/// <summary>
/// Post a receive operation
/// </summary>
private void DoReceive(SaeaSocketWrapper clientSocket)
{
// Prepares the SocketAsyncEventArgs for receive operation.
SocketAsyncEventArgs recvEventArg;
if (!poolOfRecvSendEvents.TryDequeue(out recvEventArg))
{
recvEventArg = new SocketAsyncEventArgs();
recvEventArg.Completed += IoCompleted;
}
// Allocate buffer for the receive operation.
var dataBuf = ByteBufPool.Default.Allocate();
recvEventArg.UserToken = new SockDataToken(clientSocket, dataBuf);
recvEventArg.SetBuffer(dataBuf.Array, dataBuf.WriterIndexOffset, dataBuf.WritableBytes);
// Post async receive operation on the socket
var socket = clientSocket.innerSocket;
recvEventArg.AcceptSocket = socket;
if (clientSocket.isCleanedUp || socket == null)
{
// do nothing if client socket already disposed.
dataBuf.Status = (int) SocketError.NetworkDown;
dataBuf.Release();
recvEventArg.UserToken = null;
clientSocket.receivedDataQueue.Add(dataBuf);
return;
}
var willRaiseEvent = socket.ReceiveAsync(recvEventArg);
if (!willRaiseEvent)
{
// The operation completed synchronously, we need to call ProcessReceive method directly
ProcessReceive(recvEventArg);
}
}
/// <summary>
/// This method is invoked by the IoCompleted method when an asynchronous receive
/// operation completes. If the remote host closed the connection, then the socket
/// is closed. Otherwise, we process the received data.
/// </summary>
private void ProcessReceive(SocketAsyncEventArgs recvEventArg)
{
var userToken = (SockDataToken) recvEventArg.UserToken;
var recvData = userToken.DetachData();
var clientSocket = userToken.ClientSocket;
//update write index according to the BytesTransferred
recvData.WriterIndex += recvEventArg.BytesTransferred;
recvData.Status = (int)recvEventArg.SocketError;
clientSocket.receivedDataQueue.Add(recvData);
if (recvEventArg.SocketError != SocketError.Success)
{
return;
}
// Recycle the EventArgs
recvEventArg.UserToken = null;
poolOfRecvSendEvents.Enqueue(recvEventArg);
// Post another receive operation
DoReceive(clientSocket);
}
/// <summary>
/// Posts a send operation with a SockDataToken which contains the data to be sent.
/// </summary>
private void DoSend(SockDataToken dataToken)
{
// Prepare the SocketAsyncEventArgs for send operation
SocketAsyncEventArgs sendEventArg;
if (!poolOfRecvSendEvents.TryDequeue(out sendEventArg))
{
sendEventArg = new SocketAsyncEventArgs();
sendEventArg.Completed += IoCompleted;
}
sendEventArg.UserToken = dataToken;
sendEventArg.AcceptSocket = dataToken.ClientSocket.innerSocket;
DoSend(sendEventArg);
}
/// <summary>
/// Posts a send operation with a send EventArgs
/// </summary>
private void DoSend(SocketAsyncEventArgs sendEventArg)
{
// Set buffer for the SocketAsyncEventArgs
var dataToken = (SockDataToken)sendEventArg.UserToken;
var data = dataToken.Data;
sendEventArg.SetBuffer(data.Array, data.ReaderIndexOffset, data.ReadableBytes);
//post asynchronous send operation
var socket = dataToken.ClientSocket.innerSocket;
sendEventArg.AcceptSocket = socket;
if (dataToken.ClientSocket.isCleanedUp || socket == null)
{
// do nothing if client socket already disposed.
var clientSocket = dataToken.ClientSocket;
dataToken.Reset();
clientSocket.sendStatusQueue.Add((int)SocketError.NetworkDown);
return;
}
var willRaiseEvent = socket.SendAsync(sendEventArg);
if (!willRaiseEvent)
{
// The operation completed synchronously, we need to call ProcessSend method directly
ProcessSend(sendEventArg);
}
}
/// <summary>
/// This method is called by IoCompleted() when an asynchronous send completes.
/// If all of the data has NOT been sent, then it calls PostSend to send more data.
/// </summary>
private void ProcessSend(SocketAsyncEventArgs sendEventArg)
{
var sendToken = (SockDataToken)sendEventArg.UserToken;
sendToken.ClientSocket.sendStatusQueue.Add((int)sendEventArg.SocketError);
if (sendEventArg.SocketError != SocketError.Success)
{
sendToken.Reset();
sendEventArg.UserToken = null;
if (parent != null)
{
parent.poolOfRecvSendEvents.Enqueue(sendEventArg);
}
return;
}
var data = sendToken.Data;
data.ReaderIndex += sendEventArg.BytesTransferred;
if (data.IsReadable())
{
// If some of the bytes in the message have NOT been sent,
// then we will need to post another send operation.
DoSend(sendEventArg);
}
else
{
// All the bytes in the message have been sent.
sendToken.Reset();
sendEventArg.UserToken = null;
if (parent != null)
{
parent.poolOfRecvSendEvents.Enqueue(sendEventArg);
return;
}
poolOfRecvSendEvents.Enqueue(sendEventArg);
}
}
}
}

Просмотреть файл

@ -0,0 +1,53 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// SockDataToken class is used to associate with the SocketAsyncEventArgs object.
/// Primarily, it is a way to pass state to the event handler.
/// </summary>
internal class SockDataToken
{
/// <summary>
/// Initializes a SockDataToken instance with a client socket and a dataBuf
/// </summary>
/// <param name="clientSocket">The client socket</param>
/// <param name="dataBuf">A data buffer that holds the data</param>
public SockDataToken(SaeaSocketWrapper clientSocket, ByteBuf dataBuf)
{
ClientSocket = clientSocket;
Data = dataBuf;
}
/// <summary>
/// Reset this token
/// </summary>
public void Reset()
{
ClientSocket = null;
if (Data != null) Data.Release();
Data = null;
}
/// <summary>
/// Detach the data ownership.
/// </summary>
public ByteBuf DetachData()
{
var retData = Data;
Data = null;
return retData;
}
/// <summary>
/// Gets and sets the data
/// </summary>
public ByteBuf Data { get; private set; }
/// <summary>
/// Gets and sets the client socket.
/// </summary>
public SaeaSocketWrapper ClientSocket { get; private set; }
}
}

Просмотреть файл

@ -1,6 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.IO;
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.Spark.CSharp.Configuration;
[assembly: InternalsVisibleTo("CSharpWorker")]
[assembly: InternalsVisibleTo("Tests.Common")]
[assembly: InternalsVisibleTo("AdapterTest")]
[assembly: InternalsVisibleTo("WorkerTest")]
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
@ -9,8 +19,46 @@ namespace Microsoft.Spark.CSharp.Network
/// The ISocket instance can be RioSocket object, if the configuration is set to RioSocket and
/// only the application is running on a Windows OS that supports Registered IO socket.
/// </summary>
public static class SocketFactory
internal static class SocketFactory
{
private const string RiosockDll = "Riosock.dll";
private static SocketWrapperType sockWrapperType = SocketWrapperType.None;
/// <summary>
/// Set socket wrapper type only for internal use (unit test)
/// </summary>
internal static SocketWrapperType SocketWrapperType
{
get
{
if (sockWrapperType != SocketWrapperType.None)
{
return sockWrapperType;
}
sockWrapperType = SocketWrapperType.Normal;
SocketWrapperType sockType;
if (!Enum.TryParse(Environment.GetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName), out sockType))
return sockWrapperType;
switch (sockType)
{
case SocketWrapperType.Rio:
if (IsRioSockSupported())
{
sockWrapperType = SocketWrapperType.Rio;
}
break;
case SocketWrapperType.Saea:
sockWrapperType = sockType;
break;
}
return sockWrapperType;
}
set { sockWrapperType = value; }
}
/// <summary>
/// Creates a ISocket instance based on the configuration and OS version.
/// </summary>
@ -21,7 +69,66 @@ namespace Microsoft.Spark.CSharp.Network
/// </returns>
public static ISocketWrapper CreateSocket()
{
return new DefaultSocketWrapper();
switch (SocketWrapperType)
{
case SocketWrapperType.Normal:
return new DefaultSocketWrapper();
case SocketWrapperType.Rio:
return new RioSocketWrapper();
case SocketWrapperType.Saea:
return new SaeaSocketWrapper();
default:
throw new ArgumentOutOfRangeException();
}
}
/// <summary>
/// Indicates whether current OS supports RIO socket.
/// </summary>
public static bool IsRioSockSupported()
{
// Check is running on Windows
var os = Environment.OSVersion;
var p = (int) os.Platform;
var isWindows = (p != 4) && (p != 6) && (p != 128);
if (!isWindows) return false;
// Check windows version, RIO is only supported on Win8 and above
var osVersion = os.Version;
if (osVersion.Major <= 6 && (osVersion.Major != 6 || osVersion.Minor < 2)) return false;
// Check whether Riosock.dll exists, the Riosock.dll should be in the same folder with current assembly.
var localDir = Path.GetDirectoryName(new Uri(Assembly.GetExecutingAssembly().CodeBase).LocalPath) ?? ".";
return File.Exists(Path.Combine(localDir, RiosockDll));
}
}
/// <summary>
/// SocketWrapperType defines the socket wrapper type be used in transport.
/// </summary>
internal enum SocketWrapperType
{
/// <summary>
/// None
/// </summary>
None,
/// <summary>
/// Indicates CSharp code use System.Net.Sockets.Socket as transport
/// </summary>
Normal,
/// <summary>
/// Indicates CSharp code use Windows RIO socket as transport
/// </summary>
Rio,
/// <summary>
/// Indicates CSharp code use System.Net.Sockets.Socket with SocketAsyncEventArgs as transport
/// </summary>
Saea
}
}

Просмотреть файл

@ -0,0 +1,184 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Diagnostics;
using System.IO;
namespace Microsoft.Spark.CSharp.Network
{
/// <summary>
/// Provides the underlying stream of data for network access.
/// Just like a NetworkStream.
/// </summary>
internal class SocketStream: Stream
{
private readonly ByteBufPool bufPool;
private readonly ISocketWrapper streamSocket;
private ByteBuf recvDataCache;
/// <summary>
/// Initializes a SocketStream with a SaeaSocketWrapper object.
/// </summary>
/// <param name="socket">a SaeaSocketWrapper object</param>
public SocketStream(SaeaSocketWrapper socket)
{
if (socket == null)
{
throw new ArgumentNullException("socket");
}
streamSocket = socket;
bufPool = ByteBufPool.Default;
}
/// <summary>
/// Initializes a SocketStream with a RioSocketWrapper object.
/// </summary>
/// <param name="socket">a RioSocketWrapper object</param>
public SocketStream(RioSocketWrapper socket)
{
if (socket == null)
{
throw new ArgumentNullException("socket");
}
streamSocket = socket;
bufPool = ByteBufPool.UnsafeDefault;
}
/// <summary>
/// Indicates that data can be read from the stream.
/// This property always returns <see langword='true'/>
/// </summary>
public override bool CanRead { get { return true; } }
/// <summary>
/// Indicates that the stream can seek a specific location in the stream.
/// This property always returns <see langword='false'/>
/// </summary>
public override bool CanSeek { get { return false; } }
/// <summary>
/// Indicates that data can be written to the stream.
/// This property always returns <see langword='true'/>
/// </summary>
public override bool CanWrite { get { return true; } }
/// <summary>
/// The length of data available on the stream.
/// Always throws <see cref='NotSupportedException'/>.
/// </summary>
public override long Length { get{ throw new NotSupportedException("This stream does not support seek operations."); } }
/// <summary>
/// Gets or sets the position in the stream.
/// Always throws <see cref='NotSupportedException'/>.
/// </summary>
public override long Position
{
get
{
throw new NotSupportedException("This stream does not support seek operations.");
}
set
{
throw new NotSupportedException("This stream does not support seek operations.");
}
}
/// <summary>
/// Flushes data from the stream. This is meaningless for us, so it does nothing.
/// </summary>
public override void Flush()
{
}
/// <summary>
/// Seeks a specific position in the stream. This method is not supported
/// by the SocketDataStream class.
/// </summary>
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException("This stream does not support seek operations.");
}
/// <summary>
/// Sets the length of the stream. This method is not supported by the SocketDataStream class.
/// </summary>
public override void SetLength(long value)
{
throw new NotSupportedException("This stream does not support seek operations.");
}
/// <summary>
/// Reads a byte from the stream and advances the position within the stream by one byte, or returns -1 if at the end of the stream.
/// </summary>
/// <returns>
/// The unsigned byte cast to an Int32, or -1 if at the end of the stream.
/// </returns>
public override int ReadByte()
{
if (!recvDataCache.IsReadable())
{
recvDataCache = streamSocket.Receive();
}
return recvDataCache.ReadByte();
}
/// <summary>
/// Reads data from the stream.
/// </summary>
/// <param name="buffer">Buffer to read into.</param>
/// <param name="offset">Offset into the buffer where we're to read.</param>
/// <param name="count">Number of bytes to read.</param>
/// <returns>Number of bytes we read.</returns>
public override int Read(byte[] buffer, int offset, int count)
{
int bytesRemaining = count;
if (recvDataCache == null || !recvDataCache.IsReadable())
{
recvDataCache = streamSocket.Receive();
}
while (recvDataCache.IsReadable() && bytesRemaining > 0)
{
var bytesToRead = Math.Min(bytesRemaining, recvDataCache.ReadableBytes);
var n = recvDataCache.ReadBytes(buffer, offset + count - bytesRemaining, bytesToRead);
if (!recvDataCache.IsReadable())
{
recvDataCache.Release();
if (streamSocket.HasData)
{
recvDataCache = streamSocket.Receive();
}
}
bytesRemaining -= n;
}
return count - bytesRemaining;
}
/// <summary>
/// Writes data to the stream.
/// </summary>
/// <param name="buffer">Buffer to write from.</param>
/// <param name="offset">Offset into the buffer from where we'll start writing.</param>
/// <param name="count">Number of bytes to write.</param>
public override void Write(byte[] buffer, int offset, int count)
{
var remainingBytes = count;
while (0 < remainingBytes)
{
var sendBuffer = bufPool.Allocate();
var sendCount = Math.Min(sendBuffer.WritableBytes, remainingBytes);
sendBuffer.WriteBytes(buffer, offset, sendCount);
streamSocket.Send(sendBuffer);
remainingBytes -= sendCount;
}
}
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -16,6 +16,7 @@
<ReferencePath>$(ProgramFiles)\Common Files\microsoft shared\VSTT\$(VisualStudioVersion)\UITestExtensionPackages</ReferencePath>
<IsCodedUITest>False</IsCodedUITest>
<TestProjectType>UnitTest</TestProjectType>
<CppDll Condition="Exists('..\..\cpp\x64')">HasCpp</CppDll>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<DebugSymbols>true</DebugSymbols>
@ -65,6 +66,10 @@
<Otherwise />
</Choose>
<ItemGroup>
<Compile Include="ByteBufChunkListTest.cs" />
<Compile Include="ByteBufChunkTest.cs" />
<Compile Include="ByteBufPoolTest.cs" />
<Compile Include="ByteBufTest.cs" />
<Compile Include="ColumnTest.cs" />
<Compile Include="AccumulatorTest.cs" />
<Compile Include="BroadcastTest.cs" />
@ -81,6 +86,8 @@
<Compile Include="PayloadHelperTest.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="RowTest.cs" />
<Compile Include="SocketStreamTest.cs" />
<Compile Include="SocketWrapperTest.cs" />
<Compile Include="SerDeTest.cs" />
<Compile Include="HiveContextTest.cs" />
<Compile Include="StatusTrackerTest.cs" />
@ -89,9 +96,7 @@
<Compile Include="DStreamTest.cs" />
<Compile Include="Mocks\MockDStreamProxy.cs" />
<Compile Include="Mocks\MockStreamingContextProxy.cs" />
<Compile Include="SparkCLRTestEnvironment.cs">
<SubType>Code</SubType>
</Compile>
<Compile Include="SparkCLRTestEnvironment.cs" />
<Compile Include="DataFrameTest.cs" />
<Compile Include="Mocks\MockSparkCLRProxy.cs" />
<Compile Include="Mocks\MockConfigurationService.cs" />
@ -121,7 +126,18 @@
<Name>Tests.Common</Name>
</ProjectReference>
</ItemGroup>
<ItemGroup />
<ItemGroup Condition=" '$(CppDll)' == 'HasCpp' ">
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.pdb</Link>
<TargetPath>Riosock.dll</TargetPath>
</ContentWithTargetPath>
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.pdb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.pdb</Link>
<TargetPath>Riosock.pdb</TargetPath>
</ContentWithTargetPath>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>

Просмотреть файл

@ -0,0 +1,84 @@
using Microsoft.Spark.CSharp.Network;
using NUnit.Framework;
namespace AdapterTest
{
[TestFixture]
public class ByteBufChunkListTest
{
private ByteBufPool bufPool;
[OneTimeSetUp]
public void CreatePooledBuffer()
{
bufPool = new ByteBufPool(1024, 2, false);
}
[Test]
public void TestAllocateAndFree()
{
var q50 = new ByteBufChunkList(null, 50, 200);
var q00 = new ByteBufChunkList(q50, 1, 50);
q00.PrevList = null;
q50.PrevList = q00;
// This chunk can only allocate 4 ByteBufs totally
var chunk1 = ByteBufChunk.NewChunk(bufPool, bufPool.SegmentSize, bufPool.ChunkSize, false);
var chunk2 = ByteBufChunk.NewChunk(bufPool, bufPool.SegmentSize, bufPool.ChunkSize, false);
var chunk3 = ByteBufChunk.NewChunk(bufPool, bufPool.SegmentSize, bufPool.ChunkSize, false);
//
// Verify Allocate()
//
q00.Add(chunk1);
// Verify the chunk is hosted in q50.
Assert.AreEqual(chunk1.ToString(), q00.ToString());
// Verify allocation success
ByteBuf byteBuf1;
Assert.IsTrue(q00.Allocate(out byteBuf1));
// Verify the chunk be moved to next list.
ByteBuf byteBuf2;
Assert.IsTrue(q00.Allocate(out byteBuf2));
Assert.AreEqual(chunk1.ToString(), q50.ToString()); // chunk1 is in q50 now
// Allocates from q50
ByteBuf byteBuf3;
Assert.IsTrue(q50.Allocate(out byteBuf3));
ByteBuf byteBuf4;
Assert.IsTrue(q50.Allocate(out byteBuf4));
// Verify failed allocation due to chunk 100% usage
ByteBuf byteBuf5;
Assert.AreEqual(100, chunk1.Usage);
Assert.IsFalse(q50.Allocate(out byteBuf5));
// Verify allocation success from chunk2
q50.Add(chunk2);
Assert.IsTrue(q50.Allocate(out byteBuf5));
Assert.AreEqual(chunk2, byteBuf5.ByteBufChunk);
//
// Verify Free()
//
// Build chunk1's chain for Free()
chunk1.Next = chunk3;
chunk3.Prev = chunk1;
// Verify Free() returns false due to chunk2's usage is 0
// that means the chunk need to be destroyed.
Assert.IsFalse(q50.Free(chunk2, byteBuf5));
q50.Add(chunk2);
// Free chunk1 from q50 now.
Assert.IsTrue(q50.Free(chunk1, byteBuf4));
Assert.IsTrue(q50.Free(chunk1, byteBuf3));
Assert.IsTrue(q50.Free(chunk1, byteBuf2));
Assert.AreEqual(chunk1.ToString(), q00.ToString()); // chunk1 is in q00 now
Assert.AreSame(q00, chunk1.Parent);
Assert.AreSame(chunk3, chunk2.Next);
// Free chunk1 from q00 now.
Assert.IsFalse(q00.Free(chunk1, byteBuf1));
Assert.AreEqual("none", q00.ToString()); // q00 is empty now
}
}
}

Просмотреть файл

@ -0,0 +1,86 @@
using System;
using Microsoft.Spark.CSharp.Network;
using NUnit.Framework;
namespace AdapterTest
{
[TestFixture]
public class ByteBufChunkTest
{
private ByteBufPool managedBufPool;
private ByteBufPool unsafeBufPool;
[OneTimeSetUp]
public void CreatePooledBuffer()
{
managedBufPool = new ByteBufPool(1024, 2, false);
unsafeBufPool = new ByteBufPool(1024, 2, true);
}
[Test]
public void TestInvalidByteChunk()
{
// Input chunk size with 0
Assert.Throws<ArgumentNullException>(() => ByteBufChunk.NewChunk(managedBufPool, managedBufPool.SegmentSize, 0, false));
if (!SocketFactory.IsRioSockSupported()) return;
// Input chunk size with 0
Assert.Throws<ArgumentNullException>(() => ByteBufChunk.NewChunk(unsafeBufPool, unsafeBufPool.SegmentSize, 0, true));
// Input chunk size with negative value that caused HeapAlloc failed.
Assert.Throws<OutOfMemoryException>(() => ByteBufChunk.NewChunk(unsafeBufPool, unsafeBufPool.SegmentSize, -1, true));
}
private void AllocateFreeBufChunkTest(ByteBufChunk chunk)
{
// Verify allocation
ByteBuf byteBuf1;
Assert.IsTrue(chunk.Allocate(out byteBuf1));
Assert.AreEqual(25, chunk.Usage);
ByteBuf byteBuf2;
Assert.IsTrue(chunk.Allocate(out byteBuf2));
Assert.AreEqual(50, chunk.Usage);
ByteBuf byteBuf3;
Assert.IsTrue(chunk.Allocate(out byteBuf3));
Assert.AreEqual(75, chunk.Usage);
ByteBuf byteBuf4;
Assert.IsTrue(chunk.Allocate(out byteBuf4));
Assert.AreEqual(100, chunk.Usage);
ByteBuf byteBuf5;
Assert.IsFalse(chunk.Allocate(out byteBuf5)); // Usage is 100%, cannot allocate
Assert.IsNull(byteBuf5);
// Verify Free()
chunk.Free(byteBuf1);
Assert.AreEqual(75, chunk.Usage);
chunk.Free(byteBuf2);
Assert.AreEqual(50, chunk.Usage);
chunk.Free(byteBuf3);
Assert.AreEqual(25, chunk.Usage);
chunk.Free(byteBuf4);
Assert.AreEqual(0, chunk.Usage);
}
[Test]
public void TestAllocateFreeManagedBufChunk()
{
// This chunk can only allocate 4 ByteBufs totally
var chunk = ByteBufChunk.NewChunk(managedBufPool, managedBufPool.SegmentSize, managedBufPool.ChunkSize, false);
AllocateFreeBufChunkTest(chunk);
chunk.Dispose();
}
[Test]
public void TestAllocateFreeUnsafeBufChunk()
{
if (!SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
// This chunk can only allocate 4 ByteBufs totally
var chunk = ByteBufChunk.NewChunk(unsafeBufPool, unsafeBufPool.SegmentSize, unsafeBufPool.ChunkSize, true);
Assert.AreNotEqual(IntPtr.Zero, chunk.BufId);
AllocateFreeBufChunkTest(chunk);
chunk.Dispose();
}
}
}

Просмотреть файл

@ -0,0 +1,158 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.Spark.CSharp.Network;
using NUnit.Framework;
namespace AdapterTest
{
[TestFixture]
public class ByteBufPoolTest
{
private ByteBufPool bufPool;
private ByteBufPool unsafeBufPool;
[OneTimeSetUp]
public void CreatePooledBuffer()
{
bufPool = new ByteBufPool(1024, ByteBufPool.DefaultChunkOrder, false);
if (SocketFactory.IsRioSockSupported())
{
unsafeBufPool = new ByteBufPool(1024, ByteBufPool.DefaultChunkOrder, true);
}
}
[Test]
public void TestManagedBufferAllocate()
{
var byteBuf = bufPool.Allocate();
var bufChunk = byteBuf.ByteBufChunk;
Assert.AreEqual(bufChunk.FreeBytes, bufChunk.Size - byteBuf.Capacity);
byteBuf.Release();
Assert.AreEqual(bufChunk.FreeBytes, bufChunk.Size);
}
[Test]
public void TestManagedBufferPoolGrow()
{
// Verify no chunks in 100% usage queue at beginning.
var chunkNumbers = bufPool.GetUsages();
Assert.AreEqual(0, chunkNumbers[5]); // The number of chunks in 100% usage queue should be 0 at beginning.
var bufs = new List<ByteBuf>();
for (var i = 0; i < 257; i++)
{
var byteBuf = bufPool.Allocate();
bufs.Add(byteBuf);
}
var firstChunk = bufs[255].ByteBufChunk;
var secondChunk = bufs[256].ByteBufChunk;
// Verify the buffer pool got grown.
Assert.AreNotSame(firstChunk, secondChunk);
// Verify the usage of the first buffer chunk
Assert.AreEqual(firstChunk.Usage, 100);
// Verify the usage of the second buffer chunk
Assert.AreEqual(secondChunk.Usage, 1);
// Verify the chunk exhaust and should be in 100% usage queue now.
chunkNumbers = bufPool.GetUsages();
Assert.AreNotEqual(0, chunkNumbers[5]); // The number of chunks in 100% usage queue should not be 0 now.
// Verify the ToString() shows as usage string.
var usageStr = bufPool.ToString();
Assert.IsTrue(usageStr.StartsWith("Chunk(s) at 0~25%"));
// Release buffers back to pool.
foreach (var byteBuf in bufs)
{
byteBuf.Release();
}
// Verify the first buffer chunk is disposed.
Assert.IsTrue(firstChunk.IsDisposed);
// Verify the usage of the second buffer chunk is 0.
Assert.AreEqual(secondChunk.Usage, 0);
}
[Test]
public void TestUnsafeBufferAllocate()
{
if (!SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
var byteBuf = unsafeBufPool.Allocate();
var bufChunk = byteBuf.ByteBufChunk;
Assert.AreNotEqual(IntPtr.Zero, bufChunk.UnsafeArray);
Assert.AreNotEqual(bufChunk.BufId, IntPtr.Zero);
Assert.AreEqual(bufChunk.FreeBytes, bufChunk.Size - byteBuf.Capacity);
// Verify GetInputRioBuf()
var inputRioBuf = byteBuf.GetInputRioBuf();
Assert.AreNotEqual(null, inputRioBuf);
Assert.AreEqual(byteBuf.WritableBytes, inputRioBuf.Length);
// Verify GetOutputRioBuf()
const string writeStr = "Write bytes to ByteBuf.";
var writeBytes = Encoding.UTF8.GetBytes(writeStr);
byteBuf.WriteBytes(writeBytes, 0, writeBytes.Length);
var outputRioBuf = byteBuf.GetOutputRioBuf();
Assert.AreNotEqual(null, outputRioBuf);
Assert.AreEqual(byteBuf.ReadableBytes, outputRioBuf.Length);
byteBuf.Release();
Assert.AreEqual(bufChunk.FreeBytes, bufChunk.Size);
}
[Test]
public void TestUnsafeBufferPoolGrow()
{
if (!SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
var bufs = new List<ByteBuf>();
for (var i = 0; i < 257; i++)
{
var byteBuf = unsafeBufPool.Allocate();
bufs.Add(byteBuf);
}
var firstChunk = bufs[255].ByteBufChunk;
var secondChunk = bufs[256].ByteBufChunk;
// Verify the buffer pool got grown.
Assert.AreNotSame(firstChunk, secondChunk);
Assert.AreNotEqual(firstChunk.BufId, secondChunk.BufId);
// Verify the usage of the first buffer chunk
Assert.AreEqual(firstChunk.Usage, 100);
// Verify the usage of the second buffer chunk
Assert.AreEqual(secondChunk.Usage, 1);
// Release buffers back to pool.
foreach (var byteBuf in bufs)
{
byteBuf.Release();
}
// Verify the first buffer chunk is disposed.
Assert.IsTrue(firstChunk.IsDisposed);
// Verify the usage of the second buffer chunk is 0.
Assert.AreEqual(secondChunk.Usage, 0);
}
[Test]
public void TestInvalidBufPool()
{
// Verify to new a ByteBufPool with a bigger chunkOrder which valid value is 0-14.
Assert.Throws<ArgumentException>(() => new ByteBufPool(1024, 15, false));
// Verify to new a ByteBugPool with a larger segment size to hit overflow of ByteBufChunk size.
Assert.Throws<ArgumentException>(() => new ByteBufPool(int.MaxValue, 8, false));
}
}
}

Просмотреть файл

@ -0,0 +1,161 @@
using System;
using System.Text;
using Microsoft.Spark.CSharp.Network;
using NUnit.Framework;
namespace AdapterTest
{
[TestFixture]
public class ByteBufTest
{
private ByteBufPool managedBufPool;
private ByteBufPool unsafeBufPool;
[OneTimeSetUp]
public void CreatePooledBuffer()
{
managedBufPool = new ByteBufPool(1024, ByteBufPool.DefaultChunkOrder, false);
if (SocketFactory.IsRioSockSupported())
{
unsafeBufPool = new ByteBufPool(1024, ByteBufPool.DefaultChunkOrder, true);
}
}
private void WriteReadByteBufTest(ByteBuf byteBuf)
{
var initWriteIndex = byteBuf.WriterIndex;
var initReadIndex = byteBuf.WriterIndex;
Assert.AreEqual(initWriteIndex, initReadIndex);
Assert.AreEqual(byteBuf.Capacity, byteBuf.WritableBytes);
Assert.AreEqual(0, byteBuf.ReadableBytes);
Assert.IsFalse(byteBuf.IsReadable());
Assert.IsTrue(byteBuf.IsWritable());
// Verify WriteBytes() function
const string writeStr = "Write bytes to ByteBuf.";
var writeBytes = Encoding.UTF8.GetBytes(writeStr);
byteBuf.WriteBytes(writeBytes, 0, writeBytes.Length);
Assert.AreEqual(initWriteIndex + writeBytes.Length, byteBuf.WriterIndex);
Assert.AreEqual(byteBuf.Capacity - writeBytes.Length, byteBuf.WritableBytes);
Assert.AreEqual(writeBytes.Length, byteBuf.ReadableBytes);
Assert.AreEqual(initReadIndex, byteBuf.ReaderIndex);
Assert.IsTrue(byteBuf.IsReadable());
// Verify ReadBytes() function
var readBytes = new byte[writeBytes.Length];
var ret = byteBuf.ReadBytes(readBytes, 0, readBytes.Length);
Assert.AreEqual(writeBytes.Length, ret);
var readStr = Encoding.UTF8.GetString(readBytes, 0, ret);
Assert.AreEqual(writeStr, readStr);
Assert.AreEqual(initWriteIndex + writeBytes.Length, byteBuf.WriterIndex);
Assert.AreEqual(byteBuf.Capacity - writeBytes.Length, byteBuf.WritableBytes);
Assert.AreEqual(0, byteBuf.ReadableBytes);
Assert.AreEqual(initReadIndex + ret, byteBuf.ReaderIndex);
// Verify ReadByte() function
byteBuf.WriteBytes(writeBytes, 0, 1);
var retByte = byteBuf.ReadByte();
Assert.AreEqual(writeBytes[0], retByte);
// Verify clear() function
byteBuf.Clear();
Assert.AreEqual(0, byteBuf.ReaderIndex);
Assert.AreEqual(0, byteBuf.WriterIndex);
}
[Test]
public void TestWriteReadManagedBuf()
{
var byteBuf = managedBufPool.Allocate();
Assert.AreEqual(IntPtr.Zero, byteBuf.UnsafeArray); // Verify the pointer of UnsafeArray is IntPtr.Zero on managed buffer.
WriteReadByteBufTest(byteBuf);
byteBuf.Release();
}
[Test]
public void TestWriteReadUnsafeBuf()
{
if (!SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
var byteBuf = unsafeBufPool.Allocate();
Assert.AreNotEqual(IntPtr.Zero, byteBuf.UnsafeArray); // Verify the point of UnsafeArray has value on unsafe buffer.
WriteReadByteBufTest(byteBuf);
byteBuf.Release();
}
[Test]
public void TestInvalidByteBuf()
{
// Test invalid parameter to new ByteBuf.
Assert.Throws<ArgumentOutOfRangeException>(() => new ByteBuf(null, -1, 1024));
Assert.Throws<ArgumentOutOfRangeException>(() => new ByteBuf(null, 0, -1));
Assert.Throws<ArgumentNullException>(() => new ByteBuf(null, 0, 1024));
var byteBuf = managedBufPool.Allocate();
Assert.Throws<ArgumentException>(() => new ByteBuf(byteBuf.ByteBufChunk, byteBuf.ByteBufChunk.Size - 1, 1024));
byteBuf.Release();
// Test function on disposed ByteBuf.
Assert.IsFalse(byteBuf.IsWritable());
Assert.Throws<ObjectDisposedException>(() => byteBuf.ReadByte());
// Release double - nothing to do
byteBuf.Release();
}
[Test]
public void TestInvalidReadBytes()
{
var byteBuf = managedBufPool.Allocate();
var readBytes = new byte[10];
// Verify ReadBytes with invalid parameters
Assert.Throws<ArgumentNullException>(() => byteBuf.ReadBytes(null, 0, 0));
Assert.Throws<ArgumentOutOfRangeException>(() => byteBuf.ReadBytes(readBytes, -1, 1024));
Assert.Throws<ArgumentOutOfRangeException>(() => byteBuf.ReadBytes(readBytes, 0, -1));
// Verify ReadBytes with invalid boundary
Assert.Throws<ArgumentException>(() => byteBuf.ReadBytes(readBytes, readBytes.Length - 1, readBytes.Length));
Assert.Throws<IndexOutOfRangeException>(() => byteBuf.ReadBytes(readBytes, 0, readBytes.Length));
byteBuf.Release();
}
[Test]
public void TestInvalidWriteBytes()
{
var byteBuf = managedBufPool.Allocate();
var writeBytes = new byte[2048];
// Verify WriteBytes with invalid parameters
Assert.Throws<ArgumentNullException>(() => byteBuf.WriteBytes(null, 0, 0));
Assert.Throws<ArgumentOutOfRangeException>(() => byteBuf.WriteBytes(writeBytes, -1, 1024));
Assert.Throws<ArgumentOutOfRangeException>(() => byteBuf.WriteBytes(writeBytes, 0, -1));
// Verify WriteBytes with invalid boundary
Assert.Throws<ArgumentException>(() => byteBuf.WriteBytes(writeBytes, writeBytes.Length - 1, writeBytes.Length));
Assert.Throws<IndexOutOfRangeException>(() => byteBuf.WriteBytes(writeBytes, 0, 2048));
}
[Test]
public void TestInvalidWriterReaderIndex()
{
var byteBuf = managedBufPool.Allocate();
// Verify to set writer/reader index with a negative value.
Assert.Throws<IndexOutOfRangeException>(() => byteBuf.WriterIndex = -1);
Assert.Throws<IndexOutOfRangeException>(() => byteBuf.ReaderIndex = -1);
// Verify to set writer/reader index with an overflow value.
Assert.Throws<IndexOutOfRangeException>(() => byteBuf.ReaderIndex += 2);
Assert.Throws<IndexOutOfRangeException>(() => byteBuf.WriterIndex += (managedBufPool.SegmentSize + 1));
byteBuf.Release();
}
[Test]
public void TestGetRioBufFromManagedBuf()
{
var byteBuf = managedBufPool.Allocate();
// Verify to GetInputRioBuf()/GetOutputRioBuf() from a managed ByteBuf
Assert.Throws<InvalidOperationException>(() => byteBuf.GetInputRioBuf());
Assert.Throws<InvalidOperationException>(() => byteBuf.GetOutputRioBuf());
byteBuf.Release();
}
}
}

Просмотреть файл

@ -8,6 +8,7 @@ using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Network;
namespace AdapterTest.Mocks
{
@ -26,6 +27,11 @@ namespace AdapterTest.Mocks
return workerPath;
}
public SocketWrapperType GetCSharpSocketType()
{
return SocketWrapperType.Normal;
}
public int BackendPortNumber
{
get { throw new NotImplementedException(); }

Просмотреть файл

@ -0,0 +1,44 @@
using System;
using System.IO;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Network;
using NUnit.Framework;
namespace AdapterTest
{
[TestFixture]
public class SocketStreamTest
{
/// <summary>
/// Only test invalid case.
/// The positive case verified on SocketWrapperTest
/// </summary>
[Test]
public void TestInvalidSocketStream()
{
Assert.Throws<ArgumentNullException>(() => new SocketStream((SaeaSocketWrapper)null));
Assert.Throws<ArgumentNullException>(() => new SocketStream((RioSocketWrapper)null));
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, "Saea");
SocketFactory.SocketWrapperType = SocketWrapperType.None;
using (var socket = SocketFactory.CreateSocket())
using (var stream = new SocketStream((SaeaSocketWrapper)socket))
{
// Verify SocketStream
Assert.IsTrue(stream.CanRead);
Assert.IsTrue(stream.CanWrite);
Assert.IsFalse(stream.CanSeek);
long lengh = 10;
Assert.Throws<NotSupportedException>(() => lengh = stream.Length);
Assert.Throws<NotSupportedException>(() => stream.Position = lengh);
Assert.Throws<NotSupportedException>(() => lengh = stream.Position);
Assert.Throws<NotSupportedException>(() => stream.Seek(lengh, SeekOrigin.Begin));
Assert.Throws<NotSupportedException>(() => stream.SetLength(lengh));
}
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, string.Empty);
SocketFactory.SocketWrapperType = SocketWrapperType.None;
}
}
}

Просмотреть файл

@ -0,0 +1,204 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Network;
using NUnit.Framework;
namespace AdapterTest
{
/// <summary>
/// Validates SaeaSocketWrapper by creating a ISocketWrapper server to
/// simulate interactions between CSharpRDD and CSharpWorker
/// </summary>
[TestFixture]
public class SocketWrapperTest
{
private void SocketTest(ISocketWrapper serverSocket)
{
serverSocket.Listen();
if (serverSocket is RioSocketWrapper)
{
// Do nothing for second listen operation.
Assert.DoesNotThrow(() => serverSocket.Listen(int.MaxValue));
}
var port = ((IPEndPoint)serverSocket.LocalEndPoint).Port;
var clientMsg = "Hello Message from client";
var clientMsgBytes = Encoding.UTF8.GetBytes(clientMsg);
Task.Run(() =>
{
var bytes = new byte[1024];
using (var socket = serverSocket.Accept())
{
using (var s = socket.GetStream())
{
// Receive data
var bytesRec = s.Read(bytes, 0, bytes.Length);
// send echo message.
s.Write(bytes, 0, bytesRec);
s.Flush();
// Receive one byte
var oneByte = s.ReadByte();
// Send echo one byte
byte[] oneBytes = { (byte)oneByte };
s.Write(oneBytes, 0, oneBytes.Length);
Thread.SpinWait(0);
// Keep sending to ensure no memory leak
var longBytes = Encoding.UTF8.GetBytes(new string('x', 8192));
for (int i = 0; i < 1000; i++)
{
s.Write(longBytes, 0, longBytes.Length);
}
byte[] msg = Encoding.ASCII.GetBytes("This is a test<EOF>");
s.Write(msg, 0, msg.Length);
// Receive echo byte.
s.ReadByte();
}
}
});
var clientSock = SocketFactory.CreateSocket();
// Valid invalid operation
Assert.Throws<InvalidOperationException>(() => clientSock.GetStream());
Assert.Throws<InvalidOperationException>(() => clientSock.Receive());
Assert.Throws<InvalidOperationException>(() => clientSock.Send(null));
Assert.Throws<SocketException>(() => clientSock.Connect(IPAddress.Any, 1024));
clientSock.Connect(IPAddress.Loopback, port);
// Valid invalid operation
var byteBuf = ByteBufPool.Default.Allocate();
Assert.Throws<ArgumentException>(() => clientSock.Send(byteBuf));
byteBuf.Release();
Assert.Throws<SocketException>(() => clientSock.Listen());
if (clientSock is RioSocketWrapper)
{
Assert.Throws<InvalidOperationException>(() => clientSock.Accept());
}
using (var s = clientSock.GetStream())
{
// Send message
s.Write(clientMsgBytes, 0, clientMsgBytes.Length);
// Receive echo message
var bytes = new byte[1024];
var bytesRec = s.Read(bytes, 0, bytes.Length);
Assert.AreEqual(clientMsgBytes.Length, bytesRec);
var recvStr = Encoding.UTF8.GetString(bytes, 0, bytesRec);
Assert.AreEqual(clientMsg, recvStr);
// Send one byte
byte[] oneBytes = { 1 };
s.Write(oneBytes, 0, oneBytes.Length);
// Receive echo message
var oneByte = s.ReadByte();
Assert.AreEqual((byte)1, oneByte);
// Keep receiving to ensure no memory leak.
while (true)
{
bytesRec = s.Read(bytes, 0, bytes.Length);
recvStr = Encoding.UTF8.GetString(bytes, 0, bytesRec);
if (recvStr.IndexOf("<EOF>", StringComparison.OrdinalIgnoreCase) > -1)
{
break;
}
}
// send echo bytes
s.Write(oneBytes, 0, oneBytes.Length);
}
clientSock.Close();
// Verify invalid operation
Assert.Throws<ObjectDisposedException>(() => clientSock.Receive());
serverSocket.Close();
}
[Test]
public void TestSaeaSocket()
{
// Set Socket wrapper to Saea
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, "Saea");
SocketFactory.SocketWrapperType = SocketWrapperType.None;
var serverSocket = SocketFactory.CreateSocket();
Assert.IsTrue(serverSocket is SaeaSocketWrapper);
SocketTest(serverSocket);
// Reset socket wrapper type
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, string.Empty);
SocketFactory.SocketWrapperType = SocketWrapperType.None;
}
[Test]
public void TestRioSocket()
{
if (!SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
// Set Socket wrapper to Rio
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, "Rio");
SocketFactory.SocketWrapperType = SocketWrapperType.None;
RioSocketWrapper.rioRqGrowthFactor = 1;
var serverSocket = SocketFactory.CreateSocket();
Assert.IsTrue(serverSocket is RioSocketWrapper);
SocketTest(serverSocket);
// Reset socket wrapper type
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, string.Empty);
SocketFactory.SocketWrapperType = SocketWrapperType.None;
RioNative.UnloadRio();
}
[Test]
public void TestUseThreadPoolForRioNative()
{
if (!SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
RioNative.SetUseThreadPool(true);
RioNative.EnsureRioLoaded();
Assert.AreEqual(Environment.ProcessorCount, RioNative.GetWorkThreadNumber());
RioNative.UnloadRio();
RioNative.SetUseThreadPool(false);
}
[Test]
public void TestUseSingleThreadForRioNative()
{
if (!SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
RioNative.SetUseThreadPool(false);
RioNative.EnsureRioLoaded();
Assert.AreEqual(1, RioNative.GetWorkThreadNumber());
RioNative.UnloadRio();
}
}
}

Просмотреть файл

@ -2,16 +2,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AdapterTest.Mocks;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Interop;
using Microsoft.Spark.CSharp.Interop.Ipc;
using Microsoft.Spark.CSharp.Proxy;
using Microsoft.Spark.CSharp.Proxy.Ipc;
using NUnit.Framework;
namespace AdapterTest
@ -22,6 +15,7 @@ namespace AdapterTest
[OneTimeSetUp]
public static void Initialize()
{
Console.WriteLine("Running SparkCLRTestEnvironment Initialize()");
SparkCLREnvironment.SparkCLRProxy = new MockSparkCLRProxy();
SparkCLREnvironment.ConfigurationService = new MockConfigurationService();
SparkCLREnvironment.WeakObjectManager = new WeakObjectManagerImpl

Просмотреть файл

@ -7,6 +7,13 @@ set CMDHOME=%CMDHOME:~0,-1%
@REM Set msbuild location.
SET VisualStudioVersion=12.0
if EXIST "%VS140COMNTOOLS%" SET VisualStudioVersion=14.0
@REM Set Build OS
SET CppDll=HasCpp
SET VCBuildTool="%VS120COMNTOOLS:~0,-14%VC\bin\cl.exe"
if EXIST "%VS140COMNTOOLS%" SET VCBuildTool="%VS140COMNTOOLS:~0,-14%VC\bin\cl.exe"
if NOT EXIST %VCBuildTool% SET CppDll=NoCpp
SET MSBUILDEXEDIR=%programfiles(x86)%\MSBuild\%VisualStudioVersion%\Bin
if NOT EXIST "%MSBUILDEXEDIR%\." SET MSBUILDEXEDIR=%programfiles%\MSBuild\%VisualStudioVersion%\Bin
@ -39,7 +46,7 @@ SET CONFIGURATION=%STEP%
SET STEP=%CONFIGURATION%
"%MSBUILDEXE%" /p:Configuration=%CONFIGURATION% %MSBUILDOPT% "%PROJ%"
"%MSBUILDEXE%" /p:Configuration=%CONFIGURATION%;AllowUnsafeBlocks=true %MSBUILDOPT% "%PROJ%"
@if ERRORLEVEL 1 GOTO :ErrorStop
@echo BUILD ok for %CONFIGURATION% %PROJ%
@ -48,7 +55,7 @@ SET STEP=Release
SET CONFIGURATION=%STEP%
"%MSBUILDEXE%" /p:Configuration=%CONFIGURATION% %MSBUILDOPT% "%PROJ%"
"%MSBUILDEXE%" /p:Configuration=%CONFIGURATION%;AllowUnsafeBlocks=true %MSBUILDOPT% "%PROJ%"
@if ERRORLEVEL 1 GOTO :ErrorStop
@echo BUILD ok for %CONFIGURATION% %PROJ%

Просмотреть файл

@ -12,6 +12,6 @@
<add key="CSharpWorkerPath" value="C:\SparkCLR\csharp\Samples\Microsoft.Spark.CSharp\bin\Debug\CSharpWorker.exe"/>
<add key="CSharpBackendPortNumber" value="0"/>
-->
</appSettings>
</configuration>

Просмотреть файл

@ -23,6 +23,9 @@ namespace Microsoft.Spark.CSharp.PerfBenchmark
/// 6. predicate (MID)
/// 7. object (MID/Literal)
/// 8. language_code
///
/// Note: You can add an additional column with any size data, if you want to increase
/// the size for each line.
/// </summary>
class FreebaseDeletionsBenchmark
{

Просмотреть файл

@ -11,6 +11,7 @@
<AssemblyName>SparkCLRPerf</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<CppDll Condition="Exists('..\..\..\cpp\x64')">HasCpp</CppDll>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
@ -21,6 +22,7 @@
<DefineConstants>DEBUG;TRACE</DefineConstants>
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
<Prefer32Bit>false</Prefer32Bit>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
@ -42,9 +44,21 @@
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
<ItemGroup>
<None Include="App.config" />
<None Include="App.config">
<SubType>Designer</SubType>
</None>
<None Include="data\deletionbenchmarktestdata.csv" />
</ItemGroup>
<ItemGroup Condition=" '$(CppDll)' == 'HasCpp' ">
<None Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.dll</Link>
</None>
<None Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.pdb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.pdb</Link>
</None>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\Adapter\Microsoft.Spark.CSharp\Adapter.csproj">
<Project>{ce999a96-f42b-4e80-b208-709d7f49a77c}</Project>

Просмотреть файл

@ -11,6 +11,7 @@
<AssemblyName>SparkCLRSamples</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<CppDll Condition="Exists('..\..\..\cpp\x64')">HasCpp</CppDll>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
@ -78,7 +79,18 @@
</None>
<None Include="packages.config" />
</ItemGroup>
<ItemGroup />
<ItemGroup Condition=" '$(CppDll)' == 'HasCpp' ">
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.dll</Link>
<TargetPath>Riosock.dll</TargetPath>
</ContentWithTargetPath>
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.pdb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.pdb</Link>
<TargetPath>Riosock.pdb</TargetPath>
</ContentWithTargetPath>
</ItemGroup>
<ItemGroup>
<Content Include="data\csvtestlog.txt">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>

Просмотреть файл

@ -13,7 +13,7 @@ namespace Microsoft.Spark.CSharp
/// TaskRunner is used to run Spark task assigned by JVM side. It uses a TCP socket to
/// communicate with JVM side. This socket may be reused to run multiple Spark tasks.
/// </summary>
public class TaskRunner
internal class TaskRunner
{
private static ILoggerService logger = null;
private ILoggerService Logger

Просмотреть файл

@ -38,9 +38,10 @@ namespace Microsoft.Spark.CSharp
// can't initialize logger early because in MultiThreadWorker mode, JVM will read C#'s stdout via
// pipe. When initialize logger, some unwanted info will be flushed to stdout. But we can still
// use stderr
Console.Error.WriteLine("input args: [{0}]", string.Join(" ", args));
Console.Error.WriteLine("input args: [{0}] SocketWrapper: [{1}]",
string.Join(" ", args), SocketFactory.SocketWrapperType);
if (args.Count() != 2)
if (args.Length != 2)
{
Console.Error.WriteLine("Wrong number of args: {0}, will exit", args.Count());
Environment.Exit(-1);
@ -48,6 +49,14 @@ namespace Microsoft.Spark.CSharp
if ("pyspark.daemon".Equals(args[1]))
{
if (SocketFactory.SocketWrapperType == SocketWrapperType.Rio)
{
// In daemon mode, the socket will be used as server.
// Use ThreadPool to retrieve RIO socket results has good performance
// than a single thread.
RioNative.SetUseThreadPool(true);
}
var multiThreadWorker = new MultiThreadWorker();
multiThreadWorker.Run();
}

Просмотреть файл

@ -11,6 +11,7 @@
<AssemblyName>CSharpWorker</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<CppDll Condition="Exists('..\..\..\cpp\x64')">HasCpp</CppDll>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
@ -50,7 +51,18 @@
<Compile Include="TaskRunner.cs" />
<Compile Include="Worker.cs" />
</ItemGroup>
<ItemGroup />
<ItemGroup Condition=" '$(CppDll)' == 'HasCpp' ">
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.dll</Link>
<TargetPath>Riosock.dll</TargetPath>
</ContentWithTargetPath>
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.pdb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.pdb</Link>
<TargetPath>Riosock.pdb</TargetPath>
</ContentWithTargetPath>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\Adapter\Microsoft.Spark.CSharp\Adapter.csproj">
<Project>{ce999a96-f42b-4e80-b208-709d7f49a77c}</Project>

Просмотреть файл

@ -10,6 +10,7 @@ using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using System.Text;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Core;
using Microsoft.Spark.CSharp.Interop.Ipc;
using Microsoft.Spark.CSharp.Network;
@ -21,7 +22,9 @@ namespace WorkerTest
/// Validates MultiThreadWorker by creating a ISocketWrapper server to
/// simulate interactions between CSharpRDD and CSharpWorker
/// </summary>
[TestFixture]
[TestFixture("Normal")]
[TestFixture("Rio")]
[TestFixture("Saea")]
class MultiThreadWorkerTest
{
private int splitIndex = 0;
@ -31,6 +34,18 @@ namespace WorkerTest
private int numBroadcastVariables = 0;
private readonly byte[] command = SparkContext.BuildCommand(new CSharpWorkerFunc((pid, iter) => iter), SerializedMode.String, SerializedMode.String);
public MultiThreadWorkerTest(string sockType)
{
if (sockType.Equals("Rio") && !SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting TestFixture due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
// Set Socket wrapper for test
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, sockType);
SocketFactory.SocketWrapperType = SocketWrapperType.None;
}
// StringBuilder is not thread-safe, it shouldn't be used concurrently from different threads.
// http://stackoverflow.com/questions/12645351/stringbuilder-tostring-throw-an-index-out-of-range-exception
StringBuilder output = new StringBuilder();
@ -168,7 +183,10 @@ namespace WorkerTest
/// <param name="exitCode"></param>
private void AssertWorker(Process worker, int exitCode = 0, string errorMessage = null)
{
worker.WaitForExit(3000);
if (!worker.WaitForExit(3000))
{
worker.Kill();
}
string str;
lock (syncLock)
{

Просмотреть файл

@ -11,6 +11,7 @@ using System.Net;
using System.Reflection;
using System.Runtime.Serialization.Formatters.Binary;
using System.Text;
using Microsoft.Spark.CSharp.Configuration;
using Microsoft.Spark.CSharp.Core;
using Microsoft.Spark.CSharp.Sql;
using Microsoft.Spark.CSharp.Interop.Ipc;
@ -25,7 +26,9 @@ namespace WorkerTest
/// Validates CSharpWorker by creating a ISocketWrapper server to
/// simulate interactions between CSharpRDD and CSharpWorker
/// </summary>
[TestFixture]
[TestFixture("Normal")]
[TestFixture("Rio")]
[TestFixture("Saea")]
public class WorkerTest
{
private int splitIndex = 0;
@ -34,6 +37,29 @@ namespace WorkerTest
private int numberOfIncludesItems = 0;
private int numBroadcastVariables = 0;
private readonly byte[] command = SparkContext.BuildCommand(new CSharpWorkerFunc((pid, iter) => iter), SerializedMode.String, SerializedMode.String);
private readonly string socketWrapperType;
public WorkerTest(string sockType)
{
if (sockType.Equals("Rio") && !SocketFactory.IsRioSockSupported())
{
Assert.Ignore("Omitting TestFixture due to missing Riosock.dll. It might caused by no VC++ build tool or running on an OS that not supports Windows RIO socket.");
}
// Set Socket wrapper for test
socketWrapperType = sockType;
Environment.SetEnvironmentVariable(ConfigurationService.CSharpSocketTypeEnvName, socketWrapperType);
SocketFactory.SocketWrapperType = SocketWrapperType.None;
}
[OneTimeTearDown]
public void CleanUpSocketWrapper()
{
if (socketWrapperType.Equals("Rio") && SocketFactory.IsRioSockSupported())
{
RioNative.UnloadRio();
}
}
// StringBuilder is not thread-safe, it shouldn't be used concurrently from different threads.
// http://stackoverflow.com/questions/12645351/stringbuilder-tostring-throw-an-index-out-of-range-exception
@ -144,7 +170,10 @@ namespace WorkerTest
/// <param name="exitCode"></param>
private void AssertWorker(Process worker, int exitCode = 0, string assertMessage = null)
{
worker.WaitForExit(3000);
if (!worker.WaitForExit(3000))
{
worker.Kill();
}
Assert.IsTrue(worker.HasExited);
Assert.AreEqual(exitCode, worker.ExitCode);
string str;

Просмотреть файл

@ -1,6 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
@ -11,6 +10,13 @@
<AssemblyName>WorkerTest</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<ProjectTypeGuids>{3AC096D0-A1C2-E12C-1390-A8335801FDAB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}</ProjectTypeGuids>
<VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion>
<VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath>
<ReferencePath>$(ProgramFiles)\Common Files\microsoft shared\VSTT\$(VisualStudioVersion)\UITestExtensionPackages</ReferencePath>
<IsCodedUITest>False</IsCodedUITest>
<TestProjectType>UnitTest</TestProjectType>
<CppDll Condition="Exists('..\..\cpp\x64')">HasCpp</CppDll>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<DebugSymbols>true</DebugSymbols>
@ -71,6 +77,18 @@
<ItemGroup>
<Service Include="{82A7F48D-3B50-4B1E-B82E-3ADA8210C358}" />
</ItemGroup>
<ItemGroup Condition=" '$(CppDll)' == 'HasCpp' ">
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.dll</Link>
<TargetPath>Riosock.dll</TargetPath>
</ContentWithTargetPath>
<ContentWithTargetPath Include="$(SolutionDir)..\cpp\x64\$(ConfigurationName)\Riosock.pdb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>Cpp\Riosock.pdb</Link>
<TargetPath>Riosock.pdb</TargetPath>
</ContentWithTargetPath>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>

Просмотреть файл

@ -1,7 +1,7 @@
#!/bin/bash
export FWDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export CppDll=NoCpp
export XBUILDOPT=/verbosity:minimal
if [ -z $builduri ];
@ -37,7 +37,7 @@ export CONFIGURATION=$STEP
export STEP=$CONFIGURATION
xbuild /p:Configuration=$CONFIGURATION $XBUILDOPT $PROJ
xbuild "/p:Configuration=$CONFIGURATION;AllowUnsafeBlocks=true" $XBUILDOPT $PROJ
export RC=$? && [ $RC -ne 0 ] && error_exit
echo "BUILD ok for $CONFIGURATION $PROJ"
@ -46,7 +46,7 @@ export STEP=Release
export CONFIGURATION=$STEP
xbuild /p:Configuration=$CONFIGURATION $XBUILDOPT $PROJ
xbuild "/p:Configuration=$CONFIGURATION;AllowUnsafeBlocks=true" $XBUILDOPT $PROJ
export RC=$? && [ $RC -ne 0 ] && error_exit
echo "BUILD ok for $CONFIGURATION $PROJ"

Просмотреть файл

@ -7,6 +7,7 @@ set CMDHOME=%CMDHOME:~0,-1%
@REM Set msbuild location.
SET VisualStudioVersion=12.0
if EXIST "%VS140COMNTOOLS%" SET VisualStudioVersion=14.0
SET MSBUILDEXEDIR=%programfiles(x86)%\MSBuild\%VisualStudioVersion%\Bin
if NOT EXIST "%MSBUILDEXEDIR%\." SET MSBUILDEXEDIR=%programfiles%\MSBuild\%VisualStudioVersion%\Bin

Просмотреть файл

@ -7,5 +7,6 @@
|Streaming (Kafka) |spark.mobius.streaming.kafka.fetchRate |Set the number of Kafka metadata fetch operation per batch |
|Streaming (Kafka) |spark.mobius.streaming.kafka.numReceivers |Set the number of threads used to materialize the RDD created by applying the user read function to the original KafkaRDD. |
|Streaming (UpdateStateByKey) |spark.mobius.streaming.parallelJobs |Sets 0-based max number of parallel jobs for UpdateStateByKey so that next N batches can start its tasks on time even if previous batch not completed yet. default: 0, recommended: 1. It's a special version of spark.streaming.concurrentJobs which does not observe UpdateStateByKey's state ordering properly |
|Worker |spark.mobius.CSharp.socketType |Sets the socket type that will be used in IPC for csharp code. default: Normal, if no any configuration. Normal means use default .Net Socket class for IPC; Rio, use Windows RIO socket for IPC; Saea, use .Net Socket class with SocketAsyncEventArgs class for IPC. Riosocket and SaeaSocket has better performance on dealing larger data transmission than traditional .Net Socket. You can switch the socket type when you has large data transmission (we can see the performance improvement for over 4KB per transmission in average) between JVM and CLR. |

Просмотреть файл

@ -10,6 +10,9 @@ The following environment variables should be set properly in the Developer Comm
* `JAVA_HOME`
**To Be Noticed**:
Mobius on Windows includes a C++ component - RIOSock.dll. If your environment does not have VC++ Build Toolset installed, the C++ component will be skipped to compile. Offically, the C++ component is always compiled on AppVeyor.
Please enable VC++ component from Visual Studio, or you can download [Visual C++ Build Tools](http://landinghub.visualstudio.com/visual-cpp-build-tools), if you want to build C++ components.
## Instructions

Просмотреть файл

@ -72,6 +72,11 @@ class CSharpRDD(
logInfo(s"workerFactoryId: $workerFactoryId")
}
if (!CSharpRDD.csharpWorkerSocketType.isEmpty) {
envVars.put("spark.mobius.CSharp.socketType", CSharpRDD.csharpWorkerSocketType)
logInfo(s"CSharpWorker socket type: $CSharpRDD.csharpWorkerSocketType")
}
val runner = new PythonRunner(
command, envVars, cSharpIncludes, cSharpWorker.getAbsolutePath, unUsedVersionIdentifier,
broadcastVars, accumulator, bufferSize, reuse_worker)
@ -200,6 +205,8 @@ object CSharpRDD {
// long running multi-process CSharpWorker mode is enabled only when configurated explicitly
var maxCSharpWorkerProcessCount: Int = SparkEnv.get.conf.getInt("spark.mobius.CSharpWorker.maxProcessCount", -1)
// socket type for CSharpWorker
var csharpWorkerSocketType: String = SparkEnv.get.conf.get("spark.mobius.CSharp.socketType", "")
def createRDDFromArray(
sc: SparkContext,