Add iOS SocketTransport implementation (#542)
This is rather crude, but it seems to work. Protocol and Transport are blocking APIs; in Android (and Java generally) that's okay, but Apple strongly pushes you towards async networking in iOS. So strongly in fact that all high-level APIs that are not yet deprecated are async-only, including Network.framework. This framework is also the only non-deprecated game in town when it comes to TLS. In order to bridge the gap between Network.framework and our blocking APIs, this PR makes extensive use of dispatch semaphores - essentially, we block the calling thread until a completion handler signals the semaphore. In the next major version of Thrifty, we should see about making the core APIs suspend, with blocking shims for migration. In that version, we can drop this charade and just use ktor or something. Until that glorious day, we get NwSocket.
This commit is contained in:
Родитель
436e7f2ec7
Коммит
188faba9e2
|
@ -14,7 +14,7 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix:
|
||||
jvm-version: [17]
|
||||
os: [ubuntu-latest, windows-latest]
|
||||
os: [ubuntu-latest, macos-13]
|
||||
env:
|
||||
JDK_VERSION: ${{ matrix.jvm-version }}
|
||||
GRADLE_OPTS: -Dorg.gradle.daemon=false
|
||||
|
@ -34,6 +34,7 @@ jobs:
|
|||
path: |
|
||||
~/.gradle/caches/
|
||||
~/.gradle/wrapper/
|
||||
~/.konan/
|
||||
.build-cache/
|
||||
key: cache-gradle-${{ matrix.os }}-${{ matrix.jvm-version }}-${{ hashFiles('settings.gradle') }}-${{ hashFiles('**/build.gradle') }}
|
||||
restore-keys: |
|
||||
|
@ -56,6 +57,16 @@ jobs:
|
|||
shell: bash
|
||||
run: ./gradlew check codeCoverageReport --parallel --no-daemon
|
||||
|
||||
- name: Boot simulator
|
||||
if: matrix.os == 'macos-13'
|
||||
shell: bash
|
||||
run: xcrun simctl boot 'iPhone 14 Pro Max' || true
|
||||
|
||||
- name: Run simulator tests
|
||||
if: matrix.os == 'macos-13'
|
||||
shell: bash
|
||||
run: ./gradlew -PiosDevice="iPhone 14 Pro Max" iosTest
|
||||
|
||||
- name: Upload coverage stats
|
||||
if: success() && matrix.os == 'ubuntu-latest' && matrix.jvm-version == '8'
|
||||
uses: codecov/codecov-action@v3
|
||||
|
|
|
@ -23,3 +23,5 @@ org.gradle.jvmargs=-XX:MaxMetaspaceSize=512m
|
|||
|
||||
# Build cache is helpful
|
||||
org.gradle.caching=true
|
||||
|
||||
kotlin.mpp.enableCInteropCommonization=true
|
||||
|
|
|
@ -47,6 +47,10 @@ kotlin {
|
|||
baseName = "Thrifty"
|
||||
}
|
||||
}
|
||||
|
||||
compilations.main.cinterops {
|
||||
KT62102Workaround {}
|
||||
}
|
||||
}
|
||||
|
||||
iosX64 {
|
||||
|
@ -55,6 +59,10 @@ kotlin {
|
|||
baseName = "Thrifty"
|
||||
}
|
||||
}
|
||||
|
||||
compilations.main.cinterops {
|
||||
KT62102Workaround {}
|
||||
}
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
|
@ -98,6 +106,12 @@ kotlin {
|
|||
|
||||
iosTest {
|
||||
dependsOn commonTest
|
||||
|
||||
dependencies {
|
||||
implementation libs.kotlin.test.common
|
||||
implementation libs.kotest.assertions.common
|
||||
implementation libs.kotest.assertions.core
|
||||
}
|
||||
}
|
||||
|
||||
iosArm64Main {
|
||||
|
@ -122,6 +136,21 @@ jvmTest {
|
|||
useJUnitPlatform()
|
||||
}
|
||||
|
||||
tasks.register("iosTest") {
|
||||
def device = project.findProperty("iosDevice")?.toString() ?: "iPhone 15 Pro Max"
|
||||
dependsOn 'linkDebugTestIosX64'
|
||||
group = JavaBasePlugin.VERIFICATION_GROUP
|
||||
description = "Runs tests for target 'ios' on an iOS simulator"
|
||||
|
||||
doLast {
|
||||
def binary = kotlin.targets.iosX64.binaries.getTest('DEBUG').outputFile
|
||||
println("muh binary: ${binary.absolutePath}")
|
||||
exec {
|
||||
commandLine 'xcrun', 'simctl', 'spawn', device, binary.absolutePath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// What have I gotten myself in to
|
||||
configurations {
|
||||
jvmApiElements {
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Thrifty
|
||||
*
|
||||
* Copyright (c) Microsoft Corporation
|
||||
*
|
||||
* All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
|
||||
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
*
|
||||
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
package com.microsoft.thrifty.transport
|
||||
|
||||
expect class SocketTransport internal constructor(
|
||||
builder: Builder
|
||||
) : Transport {
|
||||
class Builder(host: String, port: Int) {
|
||||
/**
|
||||
* The number of milliseconds to wait for a connection to be established.
|
||||
*/
|
||||
fun connectTimeout(connectTimeout: Int): Builder
|
||||
|
||||
/**
|
||||
* The number of milliseconds a read operation should wait for completion.
|
||||
*/
|
||||
fun readTimeout(readTimeout: Int): Builder
|
||||
|
||||
/**
|
||||
* Enable TLS for this connection.
|
||||
*/
|
||||
fun enableTls(enableTls: Boolean): Builder
|
||||
|
||||
fun build(): SocketTransport
|
||||
}
|
||||
|
||||
@Throws(okio.IOException::class)
|
||||
fun connect()
|
||||
}
|
|
@ -0,0 +1,344 @@
|
|||
/*
|
||||
* Thrifty
|
||||
*
|
||||
* Copyright (c) Microsoft Corporation
|
||||
*
|
||||
* All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
|
||||
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
*
|
||||
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
package com.microsoft.thrifty.transport
|
||||
|
||||
import KT62102Workaround.dispatch_get_target_default_queue
|
||||
import KT62102Workaround.nw_connection_send_with_default_context
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.cinterop.ExperimentalForeignApi
|
||||
import kotlinx.cinterop.Pinned
|
||||
import kotlinx.cinterop.addressOf
|
||||
import kotlinx.cinterop.convert
|
||||
import kotlinx.cinterop.usePinned
|
||||
import okio.Closeable
|
||||
import okio.IOException
|
||||
import platform.Network.nw_connection_cancel
|
||||
import platform.Network.nw_connection_create
|
||||
import platform.Network.nw_connection_receive
|
||||
import platform.Network.nw_connection_set_queue
|
||||
import platform.Network.nw_connection_set_state_changed_handler
|
||||
import platform.Network.nw_connection_start
|
||||
import platform.Network.nw_connection_state_cancelled
|
||||
import platform.Network.nw_connection_state_failed
|
||||
import platform.Network.nw_connection_state_invalid
|
||||
import platform.Network.nw_connection_state_preparing
|
||||
import platform.Network.nw_connection_state_ready
|
||||
import platform.Network.nw_connection_state_t
|
||||
import platform.Network.nw_connection_state_waiting
|
||||
import platform.Network.nw_connection_t
|
||||
import platform.Network.nw_endpoint_create_host
|
||||
import platform.Network.nw_error_domain_dns
|
||||
import platform.Network.nw_error_domain_invalid
|
||||
import platform.Network.nw_error_domain_posix
|
||||
import platform.Network.nw_error_domain_t
|
||||
import platform.Network.nw_error_domain_tls
|
||||
import platform.Network.nw_error_get_error_code
|
||||
import platform.Network.nw_error_get_error_domain
|
||||
import platform.Network.nw_error_t
|
||||
import platform.Network.nw_parameters_copy_default_protocol_stack
|
||||
import platform.Network.nw_parameters_create
|
||||
import platform.Network.nw_protocol_stack_prepend_application_protocol
|
||||
import platform.Network.nw_protocol_stack_set_transport_protocol
|
||||
import platform.Network.nw_tcp_create_options
|
||||
import platform.Network.nw_tcp_options_set_connection_timeout
|
||||
import platform.Network.nw_tcp_options_set_no_delay
|
||||
import platform.Network.nw_tls_create_options
|
||||
import platform.darwin.DISPATCH_TIME_FOREVER
|
||||
import platform.darwin.DISPATCH_TIME_NOW
|
||||
import platform.darwin.dispatch_data_apply
|
||||
import platform.darwin.dispatch_data_create
|
||||
import platform.darwin.dispatch_get_global_queue
|
||||
import platform.darwin.dispatch_semaphore_create
|
||||
import platform.darwin.dispatch_semaphore_signal
|
||||
import platform.darwin.dispatch_semaphore_t
|
||||
import platform.darwin.dispatch_semaphore_wait
|
||||
import platform.darwin.dispatch_time
|
||||
import platform.darwin.dispatch_time_t
|
||||
import platform.posix.QOS_CLASS_DEFAULT
|
||||
import platform.posix.intptr_t
|
||||
import platform.posix.memcpy
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
|
||||
@OptIn(ExperimentalForeignApi::class)
|
||||
class NwSocket(
|
||||
private val conn: nw_connection_t,
|
||||
private val readWriteTimeoutMillis: Long,
|
||||
) : Closeable {
|
||||
private val isConnected = atomic(true)
|
||||
private val lastError = atomic<nw_error_t>(null)
|
||||
|
||||
init {
|
||||
nw_connection_set_state_changed_handler(conn, this::handleStateChange)
|
||||
}
|
||||
|
||||
fun read(buffer: ByteArray, offset: Int = 0, count: Int = buffer.size): Int {
|
||||
require(offset >= 0)
|
||||
require(count >= 0)
|
||||
require(offset + count <= buffer.size)
|
||||
|
||||
check(isConnected.value) { "Socket not connected" }
|
||||
|
||||
buffer.usePinned { pinned ->
|
||||
var totalRead = 0
|
||||
while (totalRead < count) {
|
||||
val numRead = readOneChunk(pinned, offset + totalRead, count - totalRead)
|
||||
|
||||
if (numRead == 0) {
|
||||
break
|
||||
}
|
||||
|
||||
totalRead += numRead
|
||||
}
|
||||
|
||||
return totalRead
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun readOneChunk(pinned: Pinned<ByteArray>, offset: Int, count: Int): Int {
|
||||
val sem = dispatch_semaphore_create(0)
|
||||
var networkError: nw_error_t = null
|
||||
var numRead = 0
|
||||
|
||||
nw_connection_receive(
|
||||
connection = conn,
|
||||
minimum_incomplete_length = 0.convert(),
|
||||
maximum_length = count.convert()
|
||||
) { contents, _, _, error ->
|
||||
dispatch_data_apply(contents) { _, _, dataPtr, size ->
|
||||
memcpy(pinned.addressOf(offset + numRead), dataPtr, size)
|
||||
numRead += size.toInt()
|
||||
true // keep going
|
||||
}
|
||||
|
||||
networkError = error
|
||||
|
||||
dispatch_semaphore_signal(sem)
|
||||
}
|
||||
|
||||
if (!sem.waitWithTimeout(readWriteTimeoutMillis)) {
|
||||
throw IOException("Timed out waiting for read")
|
||||
}
|
||||
|
||||
networkError?.throwError()
|
||||
|
||||
return numRead
|
||||
}
|
||||
|
||||
fun write(buffer: ByteArray, offset: Int = 0, count: Int = buffer.size) {
|
||||
require(offset >= 0)
|
||||
require(count >= 0)
|
||||
require(offset + count <= buffer.size)
|
||||
|
||||
check(isConnected.value) { "Socket not connected" }
|
||||
|
||||
buffer.usePinned { pinned ->
|
||||
val sem = dispatch_semaphore_create(0)
|
||||
val toWrite = dispatch_data_create(
|
||||
buffer = pinned.addressOf(offset),
|
||||
size = count.convert(),
|
||||
queue = dispatch_get_target_default_queue(), // Our own method, see KT62102Workaround
|
||||
destructor = ::noopDispatchBlock
|
||||
)
|
||||
|
||||
var err: nw_error_t = null
|
||||
nw_connection_send_with_default_context(
|
||||
connection = conn,
|
||||
content = toWrite,
|
||||
is_complete = false
|
||||
) { networkError ->
|
||||
err = networkError
|
||||
dispatch_semaphore_signal(sem)
|
||||
}
|
||||
|
||||
if (!sem.waitWithTimeout(readWriteTimeoutMillis)) {
|
||||
throw IOException("Timed out waiting for write")
|
||||
}
|
||||
|
||||
if (err != null) {
|
||||
err.throwError()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun flush() {
|
||||
// no-op?
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
nw_connection_cancel(conn)
|
||||
}
|
||||
|
||||
private fun handleStateChange(state: nw_connection_state_t, networkError: nw_error_t) {
|
||||
// If there isn't a last-error value already, set this one.
|
||||
lastError.compareAndSet(null, networkError)
|
||||
|
||||
when (state) {
|
||||
nw_connection_state_invalid -> { }
|
||||
|
||||
nw_connection_state_waiting -> { }
|
||||
|
||||
nw_connection_state_preparing -> { }
|
||||
|
||||
nw_connection_state_ready -> {
|
||||
isConnected.value = true
|
||||
}
|
||||
|
||||
nw_connection_state_failed -> {
|
||||
isConnected.value = false
|
||||
nw_connection_set_state_changed_handler(conn, null)
|
||||
}
|
||||
|
||||
nw_connection_state_cancelled -> {
|
||||
isConnected.value = false
|
||||
nw_connection_set_state_changed_handler(conn, null)
|
||||
}
|
||||
|
||||
else -> {
|
||||
println("Unexpected nw_connection_state_t value: $state")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val INTPTR_ZERO = 0.convert<intptr_t>()
|
||||
|
||||
fun connect(
|
||||
host: String,
|
||||
port: Int,
|
||||
enableTls: Boolean,
|
||||
sendTimeoutMillis: Long = 0,
|
||||
connectTimeoutMillis: Long = 0,
|
||||
): NwSocket {
|
||||
// Network.framework, at the C level, is a little tedious to use.
|
||||
// Rather than a sockaddr_t and a socket descriptor, there are relatively
|
||||
// more "things". We've got to set up an endpoint, then connection parameters,
|
||||
// then TCP options, then TLS options, and finally a connection.
|
||||
// The remainder of the weirdness is ours, since we're using semaphores
|
||||
// to make this asynchronous API into a synchronous one.
|
||||
require(connectTimeoutMillis >= 0L) { "negative connect timeouts are not supported" }
|
||||
require(sendTimeoutMillis >= 0L) { "negative send timeouts are not supported" }
|
||||
|
||||
val endpoint = nw_endpoint_create_host(host, "$port") ?: error("Invalid host/port: $host:$port")
|
||||
|
||||
val parameters = nw_parameters_create()
|
||||
val stack = nw_parameters_copy_default_protocol_stack(parameters)
|
||||
|
||||
val tcpOptions = nw_tcp_create_options()
|
||||
if (connectTimeoutMillis != 0L) {
|
||||
nw_tcp_options_set_connection_timeout(
|
||||
tcpOptions,
|
||||
maxOf(1, connectTimeoutMillis / 1000).convert()
|
||||
)
|
||||
}
|
||||
nw_tcp_options_set_no_delay(tcpOptions, true)
|
||||
nw_protocol_stack_set_transport_protocol(stack, tcpOptions)
|
||||
|
||||
if (enableTls) {
|
||||
val tlsOptions = nw_tls_create_options()
|
||||
nw_protocol_stack_prepend_application_protocol(stack, tlsOptions)
|
||||
}
|
||||
|
||||
val connection = nw_connection_create(endpoint, parameters) ?: error("Unable to create connection")
|
||||
val globalQueue = dispatch_get_global_queue(QOS_CLASS_DEFAULT.convert(), 0.convert())
|
||||
nw_connection_set_queue(connection, globalQueue)
|
||||
|
||||
val sem = dispatch_semaphore_create(0)
|
||||
val didConnect = atomic(false)
|
||||
val connectionError = atomic<nw_error_t>(null)
|
||||
|
||||
nw_connection_set_state_changed_handler(connection) { state, error ->
|
||||
if (error != null) {
|
||||
connectionError.value = error
|
||||
}
|
||||
|
||||
if (state == nw_connection_state_ready) {
|
||||
didConnect.value = true
|
||||
}
|
||||
|
||||
if (state in setOf(nw_connection_state_ready, nw_connection_state_failed, nw_connection_state_cancelled)) {
|
||||
dispatch_semaphore_signal(sem)
|
||||
}
|
||||
}
|
||||
|
||||
nw_connection_start(connection)
|
||||
val finishedInTime = sem.waitWithTimeout(connectTimeoutMillis)
|
||||
|
||||
if (connectionError.value != null) {
|
||||
nw_connection_cancel(connection)
|
||||
connectionError.value.throwError("Error connecting to $host:$port")
|
||||
}
|
||||
|
||||
if (!finishedInTime) {
|
||||
nw_connection_cancel(connection)
|
||||
throw IOException("Timed out connecting to $host:$port")
|
||||
}
|
||||
|
||||
if (didConnect.value) {
|
||||
return NwSocket(connection, sendTimeoutMillis)
|
||||
}
|
||||
|
||||
throw IOException("Failed to connect, but got no error")
|
||||
}
|
||||
|
||||
/**
|
||||
* A function, usable as a [platform.darwin.dispatch_block_t], that does nothing.
|
||||
*
|
||||
* When used with [dispatch_data_create], this block causes the data
|
||||
* *not* to be copied. This is what we want, since we're using semaphores
|
||||
* to wait for write completion, and we can guarantee that our memory
|
||||
* outlives the dispatch_data_t that wraps it.
|
||||
*/
|
||||
private fun noopDispatchBlock() {}
|
||||
|
||||
/**
|
||||
* Returns true if the semaphore was signaled, false if it timed out.
|
||||
*/
|
||||
private fun dispatch_semaphore_t.waitWithTimeout(timeoutMillis: Long): Boolean {
|
||||
return dispatch_semaphore_wait(this, computeTimeout(timeoutMillis)) == INTPTR_ZERO
|
||||
}
|
||||
|
||||
private fun computeTimeout(timeoutMillis: Long): dispatch_time_t {
|
||||
return if (timeoutMillis == 0L) {
|
||||
DISPATCH_TIME_FOREVER
|
||||
} else {
|
||||
val nanos = timeoutMillis.milliseconds.inWholeNanoseconds
|
||||
dispatch_time(DISPATCH_TIME_NOW, nanos)
|
||||
}
|
||||
}
|
||||
|
||||
private fun nw_error_t.throwError(message: String? = null): Nothing {
|
||||
val domain = nw_error_get_error_domain(this)
|
||||
val code = nw_error_get_error_code(this)
|
||||
val errorBody = message ?: "Network error"
|
||||
throw IOException("$errorBody: $this (domain=${domain.name} code=$code)")
|
||||
}
|
||||
|
||||
private val nw_error_domain_t.name: String
|
||||
get() = when (this) {
|
||||
nw_error_domain_dns -> "dns"
|
||||
nw_error_domain_tls -> "tls"
|
||||
nw_error_domain_posix -> "posix"
|
||||
nw_error_domain_invalid -> "invalid"
|
||||
else -> "$this"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
/*
|
||||
* Thrifty
|
||||
*
|
||||
* Copyright (c) Microsoft Corporation
|
||||
*
|
||||
* All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
|
||||
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
*
|
||||
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
package com.microsoft.thrifty.transport
|
||||
|
||||
import okio.IOException
|
||||
import platform.Network.nw_connection_t
|
||||
|
||||
actual class SocketTransport actual constructor(
|
||||
builder: Builder,
|
||||
) : Transport {
|
||||
private val host = builder.host
|
||||
private val port = builder.port
|
||||
private val connectTimeout: Long = builder.connectTimeout.toLong()
|
||||
private val readTimeout: Long = builder.readTimeout.toLong()
|
||||
private val tls = builder.useTransportSecurity
|
||||
|
||||
private var socket: NwSocket? = null
|
||||
|
||||
/**
|
||||
* Just here for testing
|
||||
*/
|
||||
internal constructor(connection: nw_connection_t) : this(Builder("", 0)) {
|
||||
socket = NwSocket(connection, 0L)
|
||||
}
|
||||
|
||||
actual class Builder actual constructor(
|
||||
val host: String,
|
||||
val port: Int,
|
||||
) {
|
||||
var connectTimeout: Int = 0
|
||||
var readTimeout: Int = 0
|
||||
var useTransportSecurity: Boolean = false
|
||||
|
||||
actual fun connectTimeout(connectTimeout: Int): Builder {
|
||||
this.connectTimeout = maxOf(connectTimeout, 0)
|
||||
return this
|
||||
}
|
||||
|
||||
actual fun readTimeout(readTimeout: Int): Builder {
|
||||
this.readTimeout = maxOf(readTimeout, 0)
|
||||
return this
|
||||
}
|
||||
|
||||
actual fun enableTls(enableTls: Boolean): Builder {
|
||||
this.useTransportSecurity = enableTls
|
||||
return this
|
||||
}
|
||||
|
||||
actual fun build(): SocketTransport {
|
||||
return SocketTransport(this)
|
||||
}
|
||||
}
|
||||
|
||||
override fun read(buffer: ByteArray, offset: Int, count: Int): Int {
|
||||
return socket!!.read(buffer, offset, count)
|
||||
}
|
||||
|
||||
override fun write(data: ByteArray) {
|
||||
write(data, 0, data.size)
|
||||
}
|
||||
|
||||
override fun write(buffer: ByteArray, offset: Int, count: Int) {
|
||||
require(offset >= 0)
|
||||
require(count >= 0)
|
||||
require(count <= buffer.size - offset)
|
||||
socket!!.write(buffer, offset, count)
|
||||
}
|
||||
|
||||
override fun flush() {
|
||||
// no-op?
|
||||
socket?.flush()
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
socket?.close()
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
actual fun connect() {
|
||||
socket = NwSocket.connect(
|
||||
host = host,
|
||||
port = port,
|
||||
enableTls = tls,
|
||||
sendTimeoutMillis = readTimeout,
|
||||
connectTimeoutMillis = connectTimeout
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,161 @@
|
|||
/*
|
||||
* Thrifty
|
||||
*
|
||||
* Copyright (c) Microsoft Corporation
|
||||
*
|
||||
* All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the License);
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
|
||||
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
|
||||
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
*
|
||||
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
package com.microsoft.thrifty.transport
|
||||
|
||||
import com.microsoft.thrifty.protocol.BinaryProtocol
|
||||
import com.microsoft.thrifty.protocol.Xtruct
|
||||
import io.kotest.matchers.shouldBe
|
||||
import kotlinx.cinterop.ExperimentalForeignApi
|
||||
import kotlinx.cinterop.convert
|
||||
import okio.use
|
||||
import platform.Network.nw_connection_set_queue
|
||||
import platform.Network.nw_connection_set_state_changed_handler
|
||||
import platform.Network.nw_connection_start
|
||||
import platform.Network.nw_connection_state_cancelled
|
||||
import platform.Network.nw_connection_state_failed
|
||||
import platform.Network.nw_connection_state_ready
|
||||
import platform.Network.nw_listener_cancel
|
||||
import platform.Network.nw_listener_create
|
||||
import platform.Network.nw_listener_get_port
|
||||
import platform.Network.nw_listener_set_new_connection_handler
|
||||
import platform.Network.nw_listener_set_queue
|
||||
import platform.Network.nw_listener_set_state_changed_handler
|
||||
import platform.Network.nw_listener_start
|
||||
import platform.Network.nw_listener_state_cancelled
|
||||
import platform.Network.nw_listener_state_failed
|
||||
import platform.Network.nw_listener_state_ready
|
||||
import platform.Network.nw_parameters_copy_default_protocol_stack
|
||||
import platform.Network.nw_parameters_create
|
||||
import platform.Network.nw_protocol_stack_set_transport_protocol
|
||||
import platform.Network.nw_tcp_create_options
|
||||
import platform.Network.nw_tcp_options_set_connection_timeout
|
||||
import platform.Network.nw_tcp_options_set_enable_keepalive
|
||||
import platform.darwin.DISPATCH_TIME_FOREVER
|
||||
import platform.darwin.dispatch_async
|
||||
import platform.darwin.dispatch_get_global_queue
|
||||
import platform.darwin.dispatch_queue_create
|
||||
import platform.darwin.dispatch_semaphore_create
|
||||
import platform.darwin.dispatch_semaphore_signal
|
||||
import platform.darwin.dispatch_semaphore_wait
|
||||
import platform.posix.QOS_CLASS_DEFAULT
|
||||
import kotlin.test.Test
|
||||
|
||||
@OptIn(ExperimentalForeignApi::class)
|
||||
class NwSocketTest {
|
||||
@Test
|
||||
fun canRoundTripStructs() {
|
||||
val xtruct = Xtruct.Builder()
|
||||
.bool_thing(true)
|
||||
.byte_thing(1)
|
||||
.i32_thing(2)
|
||||
.i64_thing(3)
|
||||
.double_thing(4.0)
|
||||
.string_thing("five")
|
||||
.build()
|
||||
|
||||
val globalQueue = dispatch_get_global_queue(QOS_CLASS_DEFAULT.convert(), 0.convert())
|
||||
|
||||
// For some reason, NW_PARAMETERS_DISABLE_PROTOCOL wasn't actually disabling TLS
|
||||
// on the listener; we'd see "handshake failed" errors. Who even knows.
|
||||
// Manually creating parameters, and not even touching TLS, seems to work.
|
||||
//val parameters = nw_parameters_create_secure_tcp(NW_PARAMETERS_DISABLE_PROTOCOL, NW_PARAMETERS_DEFAULT_CONFIGURATION)
|
||||
|
||||
val tcpOptions = nw_tcp_create_options()
|
||||
nw_tcp_options_set_enable_keepalive(tcpOptions, true)
|
||||
nw_tcp_options_set_connection_timeout(tcpOptions, 60.convert())
|
||||
|
||||
val parameters = nw_parameters_create()
|
||||
val stack = nw_parameters_copy_default_protocol_stack(parameters)
|
||||
nw_protocol_stack_set_transport_protocol(stack, tcpOptions)
|
||||
|
||||
val serverListener = nw_listener_create(parameters)
|
||||
nw_listener_set_queue(serverListener, globalQueue)
|
||||
nw_listener_set_new_connection_handler(serverListener) { connection ->
|
||||
nw_connection_set_state_changed_handler(connection) { state, err ->
|
||||
if (state == nw_connection_state_ready) {
|
||||
val transport = SocketTransport(connection)
|
||||
val protocol = BinaryProtocol(transport)
|
||||
xtruct.write(protocol)
|
||||
} else if (state in listOf(
|
||||
nw_connection_state_failed,
|
||||
nw_connection_state_cancelled
|
||||
)
|
||||
) {
|
||||
println("server: I AM NOT READY")
|
||||
}
|
||||
}
|
||||
|
||||
nw_connection_set_queue(connection, globalQueue)
|
||||
nw_connection_start(connection)
|
||||
}
|
||||
|
||||
val readySem = dispatch_semaphore_create(0)
|
||||
var ready = false
|
||||
nw_listener_set_state_changed_handler(serverListener) { state, err ->
|
||||
if (state == nw_listener_state_ready) {
|
||||
ready = true
|
||||
}
|
||||
|
||||
if (state in listOf(
|
||||
nw_listener_state_ready,
|
||||
nw_listener_state_failed,
|
||||
nw_listener_state_cancelled
|
||||
)
|
||||
) {
|
||||
dispatch_semaphore_signal(readySem)
|
||||
}
|
||||
}
|
||||
nw_listener_start(serverListener)
|
||||
dispatch_semaphore_wait(readySem, DISPATCH_TIME_FOREVER)
|
||||
|
||||
if (!ready) {
|
||||
nw_listener_cancel(serverListener)
|
||||
throw AssertionError("Failed to set up a listener")
|
||||
}
|
||||
|
||||
val clientSem = dispatch_semaphore_create(0)
|
||||
val clientQueue = dispatch_queue_create("client", null)
|
||||
var matched = false
|
||||
dispatch_async(clientQueue) {
|
||||
try {
|
||||
val port = nw_listener_get_port(serverListener)
|
||||
SocketTransport.Builder("127.0.0.1", port.toInt()).readTimeout(100).build()
|
||||
.use { transport ->
|
||||
transport.connect()
|
||||
val protocol = BinaryProtocol(transport)
|
||||
val readXtruct = Xtruct.ADAPTER.read(protocol)
|
||||
|
||||
if (readXtruct == xtruct) {
|
||||
// Assertion errors don't make it out of dispatch queues,
|
||||
// so we'll just set a flag and check it later.
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
nw_listener_cancel(serverListener)
|
||||
dispatch_semaphore_signal(clientSem)
|
||||
}
|
||||
}
|
||||
dispatch_semaphore_wait(clientSem, DISPATCH_TIME_FOREVER)
|
||||
|
||||
matched shouldBe true
|
||||
}
|
||||
}
|
|
@ -26,48 +26,63 @@ import java.io.OutputStream
|
|||
import java.net.InetSocketAddress
|
||||
import java.net.Socket
|
||||
import javax.net.SocketFactory
|
||||
import javax.net.ssl.SSLSocketFactory
|
||||
|
||||
class SocketTransport internal constructor(
|
||||
actual class SocketTransport actual constructor(
|
||||
builder: Builder
|
||||
) : Transport {
|
||||
private val host = builder.host
|
||||
private val port = builder.port
|
||||
private val readTimeout = builder.readTimeout
|
||||
private val connectTimeout = builder.connectTimeout
|
||||
private val socketFactory = builder.socketFactory ?: SocketFactory.getDefault()
|
||||
private val socketFactory = builder.socketFactory ?: builder.getDefaultSocketFactory()
|
||||
|
||||
private var socket: Socket? = null
|
||||
private var inputStream: InputStream? = null
|
||||
private var outputStream: OutputStream? = null
|
||||
|
||||
class Builder(host: String, port: Int) {
|
||||
actual class Builder actual constructor(host: String, port: Int) {
|
||||
internal val host: String
|
||||
internal val port: Int
|
||||
internal var readTimeout = 0
|
||||
internal var connectTimeout = 0
|
||||
internal var socketFactory: SocketFactory? = null
|
||||
internal var enableTls = false
|
||||
|
||||
fun readTimeout(readTimeout: Int): Builder {
|
||||
actual fun readTimeout(readTimeout: Int): Builder {
|
||||
require(readTimeout >= 0) { "readTimeout cannot be negative" }
|
||||
this.readTimeout = readTimeout
|
||||
return this
|
||||
}
|
||||
|
||||
fun connectTimeout(connectTimeout: Int): Builder {
|
||||
actual fun connectTimeout(connectTimeout: Int): Builder {
|
||||
require(connectTimeout >= 0) { "connectTimeout cannot be negative" }
|
||||
this.connectTimeout = connectTimeout
|
||||
return this
|
||||
}
|
||||
|
||||
actual fun enableTls(enableTls: Boolean): Builder {
|
||||
this.enableTls = enableTls
|
||||
return this
|
||||
}
|
||||
|
||||
fun socketFactory(socketFactory: SocketFactory?): Builder {
|
||||
this.socketFactory = requireNotNull(socketFactory) { "socketFactory" }
|
||||
return this
|
||||
}
|
||||
|
||||
fun build(): SocketTransport {
|
||||
actual fun build(): SocketTransport {
|
||||
return SocketTransport(this)
|
||||
}
|
||||
|
||||
fun getDefaultSocketFactory(): SocketFactory {
|
||||
return if (enableTls) {
|
||||
SSLSocketFactory.getDefault()
|
||||
} else {
|
||||
SocketFactory.getDefault()
|
||||
}
|
||||
}
|
||||
|
||||
init {
|
||||
require(host.isNotBlank()) { "host must not be null or empty" }
|
||||
require(port in 0..0xFFFF) { "Invalid port number: $port" }
|
||||
|
@ -98,7 +113,7 @@ class SocketTransport internal constructor(
|
|||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
fun connect() {
|
||||
actual fun connect() {
|
||||
if (socket == null) {
|
||||
socket = socketFactory.createSocket()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
language=Objective-C
|
||||
---
|
||||
#import <Network/connection.h>
|
||||
|
||||
void nw_connection_send_with_default_context(
|
||||
nw_connection_t connection,
|
||||
_Nullable dispatch_data_t content,
|
||||
bool is_complete,
|
||||
nw_connection_send_completion_t completion
|
||||
) {
|
||||
nw_connection_send(connection, content, NW_CONNECTION_DEFAULT_MESSAGE_CONTEXT, is_complete, completion);
|
||||
}
|
||||
|
||||
// Not related to KT-62102, but this is a good place to put it.
|
||||
//
|
||||
// As of kt 1.9.10, DISPATCH_DATA_DESTRUCTOR_DEFAULT is erroneously mapped
|
||||
// as a COpaquePointer, and not as a dispatch_block_t, rendering it unusable
|
||||
// with dispatch_data_create. This function works around that deficiency.
|
||||
|
||||
dispatch_block_t dispatch_data_default_destructor() {
|
||||
return DISPATCH_DATA_DESTRUCTOR_DEFAULT;
|
||||
}
|
||||
|
||||
dispatch_queue_t dispatch_get_target_default_queue() {
|
||||
return DISPATCH_TARGET_QUEUE_DEFAULT;
|
||||
}
|
Загрузка…
Ссылка в новой задаче