diff --git a/.github/workflows/pre-merge.yml b/.github/workflows/pre-merge.yml index 5604b67..391eac6 100644 --- a/.github/workflows/pre-merge.yml +++ b/.github/workflows/pre-merge.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: jvm-version: [17] - os: [ubuntu-latest, windows-latest] + os: [ubuntu-latest, macos-13] env: JDK_VERSION: ${{ matrix.jvm-version }} GRADLE_OPTS: -Dorg.gradle.daemon=false @@ -34,6 +34,7 @@ jobs: path: | ~/.gradle/caches/ ~/.gradle/wrapper/ + ~/.konan/ .build-cache/ key: cache-gradle-${{ matrix.os }}-${{ matrix.jvm-version }}-${{ hashFiles('settings.gradle') }}-${{ hashFiles('**/build.gradle') }} restore-keys: | @@ -56,6 +57,16 @@ jobs: shell: bash run: ./gradlew check codeCoverageReport --parallel --no-daemon + - name: Boot simulator + if: matrix.os == 'macos-13' + shell: bash + run: xcrun simctl boot 'iPhone 14 Pro Max' || true + + - name: Run simulator tests + if: matrix.os == 'macos-13' + shell: bash + run: ./gradlew -PiosDevice="iPhone 14 Pro Max" iosTest + - name: Upload coverage stats if: success() && matrix.os == 'ubuntu-latest' && matrix.jvm-version == '8' uses: codecov/codecov-action@v3 diff --git a/gradle.properties b/gradle.properties index 8a460c7..f622ac9 100644 --- a/gradle.properties +++ b/gradle.properties @@ -23,3 +23,5 @@ org.gradle.jvmargs=-XX:MaxMetaspaceSize=512m # Build cache is helpful org.gradle.caching=true + +kotlin.mpp.enableCInteropCommonization=true diff --git a/thrifty-runtime/build.gradle b/thrifty-runtime/build.gradle index 4e4c237..672e038 100644 --- a/thrifty-runtime/build.gradle +++ b/thrifty-runtime/build.gradle @@ -47,6 +47,10 @@ kotlin { baseName = "Thrifty" } } + + compilations.main.cinterops { + KT62102Workaround {} + } } iosX64 { @@ -55,6 +59,10 @@ kotlin { baseName = "Thrifty" } } + + compilations.main.cinterops { + KT62102Workaround {} + } } sourceSets { @@ -98,6 +106,12 @@ kotlin { iosTest { dependsOn commonTest + + dependencies { + implementation libs.kotlin.test.common + implementation libs.kotest.assertions.common + implementation libs.kotest.assertions.core + } } iosArm64Main { @@ -122,6 +136,21 @@ jvmTest { useJUnitPlatform() } +tasks.register("iosTest") { + def device = project.findProperty("iosDevice")?.toString() ?: "iPhone 15 Pro Max" + dependsOn 'linkDebugTestIosX64' + group = JavaBasePlugin.VERIFICATION_GROUP + description = "Runs tests for target 'ios' on an iOS simulator" + + doLast { + def binary = kotlin.targets.iosX64.binaries.getTest('DEBUG').outputFile + println("muh binary: ${binary.absolutePath}") + exec { + commandLine 'xcrun', 'simctl', 'spawn', device, binary.absolutePath + } + } +} + // What have I gotten myself in to configurations { jvmApiElements { diff --git a/thrifty-runtime/src/commonMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt b/thrifty-runtime/src/commonMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt new file mode 100644 index 0000000..c55103f --- /dev/null +++ b/thrifty-runtime/src/commonMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt @@ -0,0 +1,47 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +package com.microsoft.thrifty.transport + +expect class SocketTransport internal constructor( + builder: Builder +) : Transport { + class Builder(host: String, port: Int) { + /** + * The number of milliseconds to wait for a connection to be established. + */ + fun connectTimeout(connectTimeout: Int): Builder + + /** + * The number of milliseconds a read operation should wait for completion. + */ + fun readTimeout(readTimeout: Int): Builder + + /** + * Enable TLS for this connection. + */ + fun enableTls(enableTls: Boolean): Builder + + fun build(): SocketTransport + } + + @Throws(okio.IOException::class) + fun connect() +} diff --git a/thrifty-runtime/src/iosMain/kotlin/com/microsoft/thrifty/transport/NwSocket.kt b/thrifty-runtime/src/iosMain/kotlin/com/microsoft/thrifty/transport/NwSocket.kt new file mode 100644 index 0000000..565c69f --- /dev/null +++ b/thrifty-runtime/src/iosMain/kotlin/com/microsoft/thrifty/transport/NwSocket.kt @@ -0,0 +1,344 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +package com.microsoft.thrifty.transport + +import KT62102Workaround.dispatch_get_target_default_queue +import KT62102Workaround.nw_connection_send_with_default_context +import kotlinx.atomicfu.atomic +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.Pinned +import kotlinx.cinterop.addressOf +import kotlinx.cinterop.convert +import kotlinx.cinterop.usePinned +import okio.Closeable +import okio.IOException +import platform.Network.nw_connection_cancel +import platform.Network.nw_connection_create +import platform.Network.nw_connection_receive +import platform.Network.nw_connection_set_queue +import platform.Network.nw_connection_set_state_changed_handler +import platform.Network.nw_connection_start +import platform.Network.nw_connection_state_cancelled +import platform.Network.nw_connection_state_failed +import platform.Network.nw_connection_state_invalid +import platform.Network.nw_connection_state_preparing +import platform.Network.nw_connection_state_ready +import platform.Network.nw_connection_state_t +import platform.Network.nw_connection_state_waiting +import platform.Network.nw_connection_t +import platform.Network.nw_endpoint_create_host +import platform.Network.nw_error_domain_dns +import platform.Network.nw_error_domain_invalid +import platform.Network.nw_error_domain_posix +import platform.Network.nw_error_domain_t +import platform.Network.nw_error_domain_tls +import platform.Network.nw_error_get_error_code +import platform.Network.nw_error_get_error_domain +import platform.Network.nw_error_t +import platform.Network.nw_parameters_copy_default_protocol_stack +import platform.Network.nw_parameters_create +import platform.Network.nw_protocol_stack_prepend_application_protocol +import platform.Network.nw_protocol_stack_set_transport_protocol +import platform.Network.nw_tcp_create_options +import platform.Network.nw_tcp_options_set_connection_timeout +import platform.Network.nw_tcp_options_set_no_delay +import platform.Network.nw_tls_create_options +import platform.darwin.DISPATCH_TIME_FOREVER +import platform.darwin.DISPATCH_TIME_NOW +import platform.darwin.dispatch_data_apply +import platform.darwin.dispatch_data_create +import platform.darwin.dispatch_get_global_queue +import platform.darwin.dispatch_semaphore_create +import platform.darwin.dispatch_semaphore_signal +import platform.darwin.dispatch_semaphore_t +import platform.darwin.dispatch_semaphore_wait +import platform.darwin.dispatch_time +import platform.darwin.dispatch_time_t +import platform.posix.QOS_CLASS_DEFAULT +import platform.posix.intptr_t +import platform.posix.memcpy +import kotlin.time.Duration.Companion.milliseconds + +@OptIn(ExperimentalForeignApi::class) +class NwSocket( + private val conn: nw_connection_t, + private val readWriteTimeoutMillis: Long, +) : Closeable { + private val isConnected = atomic(true) + private val lastError = atomic(null) + + init { + nw_connection_set_state_changed_handler(conn, this::handleStateChange) + } + + fun read(buffer: ByteArray, offset: Int = 0, count: Int = buffer.size): Int { + require(offset >= 0) + require(count >= 0) + require(offset + count <= buffer.size) + + check(isConnected.value) { "Socket not connected" } + + buffer.usePinned { pinned -> + var totalRead = 0 + while (totalRead < count) { + val numRead = readOneChunk(pinned, offset + totalRead, count - totalRead) + + if (numRead == 0) { + break + } + + totalRead += numRead + } + + return totalRead + } + } + + @Throws(IOException::class) + private fun readOneChunk(pinned: Pinned, offset: Int, count: Int): Int { + val sem = dispatch_semaphore_create(0) + var networkError: nw_error_t = null + var numRead = 0 + + nw_connection_receive( + connection = conn, + minimum_incomplete_length = 0.convert(), + maximum_length = count.convert() + ) { contents, _, _, error -> + dispatch_data_apply(contents) { _, _, dataPtr, size -> + memcpy(pinned.addressOf(offset + numRead), dataPtr, size) + numRead += size.toInt() + true // keep going + } + + networkError = error + + dispatch_semaphore_signal(sem) + } + + if (!sem.waitWithTimeout(readWriteTimeoutMillis)) { + throw IOException("Timed out waiting for read") + } + + networkError?.throwError() + + return numRead + } + + fun write(buffer: ByteArray, offset: Int = 0, count: Int = buffer.size) { + require(offset >= 0) + require(count >= 0) + require(offset + count <= buffer.size) + + check(isConnected.value) { "Socket not connected" } + + buffer.usePinned { pinned -> + val sem = dispatch_semaphore_create(0) + val toWrite = dispatch_data_create( + buffer = pinned.addressOf(offset), + size = count.convert(), + queue = dispatch_get_target_default_queue(), // Our own method, see KT62102Workaround + destructor = ::noopDispatchBlock + ) + + var err: nw_error_t = null + nw_connection_send_with_default_context( + connection = conn, + content = toWrite, + is_complete = false + ) { networkError -> + err = networkError + dispatch_semaphore_signal(sem) + } + + if (!sem.waitWithTimeout(readWriteTimeoutMillis)) { + throw IOException("Timed out waiting for write") + } + + if (err != null) { + err.throwError() + } + } + } + + fun flush() { + // no-op? + } + + override fun close() { + nw_connection_cancel(conn) + } + + private fun handleStateChange(state: nw_connection_state_t, networkError: nw_error_t) { + // If there isn't a last-error value already, set this one. + lastError.compareAndSet(null, networkError) + + when (state) { + nw_connection_state_invalid -> { } + + nw_connection_state_waiting -> { } + + nw_connection_state_preparing -> { } + + nw_connection_state_ready -> { + isConnected.value = true + } + + nw_connection_state_failed -> { + isConnected.value = false + nw_connection_set_state_changed_handler(conn, null) + } + + nw_connection_state_cancelled -> { + isConnected.value = false + nw_connection_set_state_changed_handler(conn, null) + } + + else -> { + println("Unexpected nw_connection_state_t value: $state") + } + } + } + + companion object { + private val INTPTR_ZERO = 0.convert() + + fun connect( + host: String, + port: Int, + enableTls: Boolean, + sendTimeoutMillis: Long = 0, + connectTimeoutMillis: Long = 0, + ): NwSocket { + // Network.framework, at the C level, is a little tedious to use. + // Rather than a sockaddr_t and a socket descriptor, there are relatively + // more "things". We've got to set up an endpoint, then connection parameters, + // then TCP options, then TLS options, and finally a connection. + // The remainder of the weirdness is ours, since we're using semaphores + // to make this asynchronous API into a synchronous one. + require(connectTimeoutMillis >= 0L) { "negative connect timeouts are not supported" } + require(sendTimeoutMillis >= 0L) { "negative send timeouts are not supported" } + + val endpoint = nw_endpoint_create_host(host, "$port") ?: error("Invalid host/port: $host:$port") + + val parameters = nw_parameters_create() + val stack = nw_parameters_copy_default_protocol_stack(parameters) + + val tcpOptions = nw_tcp_create_options() + if (connectTimeoutMillis != 0L) { + nw_tcp_options_set_connection_timeout( + tcpOptions, + maxOf(1, connectTimeoutMillis / 1000).convert() + ) + } + nw_tcp_options_set_no_delay(tcpOptions, true) + nw_protocol_stack_set_transport_protocol(stack, tcpOptions) + + if (enableTls) { + val tlsOptions = nw_tls_create_options() + nw_protocol_stack_prepend_application_protocol(stack, tlsOptions) + } + + val connection = nw_connection_create(endpoint, parameters) ?: error("Unable to create connection") + val globalQueue = dispatch_get_global_queue(QOS_CLASS_DEFAULT.convert(), 0.convert()) + nw_connection_set_queue(connection, globalQueue) + + val sem = dispatch_semaphore_create(0) + val didConnect = atomic(false) + val connectionError = atomic(null) + + nw_connection_set_state_changed_handler(connection) { state, error -> + if (error != null) { + connectionError.value = error + } + + if (state == nw_connection_state_ready) { + didConnect.value = true + } + + if (state in setOf(nw_connection_state_ready, nw_connection_state_failed, nw_connection_state_cancelled)) { + dispatch_semaphore_signal(sem) + } + } + + nw_connection_start(connection) + val finishedInTime = sem.waitWithTimeout(connectTimeoutMillis) + + if (connectionError.value != null) { + nw_connection_cancel(connection) + connectionError.value.throwError("Error connecting to $host:$port") + } + + if (!finishedInTime) { + nw_connection_cancel(connection) + throw IOException("Timed out connecting to $host:$port") + } + + if (didConnect.value) { + return NwSocket(connection, sendTimeoutMillis) + } + + throw IOException("Failed to connect, but got no error") + } + + /** + * A function, usable as a [platform.darwin.dispatch_block_t], that does nothing. + * + * When used with [dispatch_data_create], this block causes the data + * *not* to be copied. This is what we want, since we're using semaphores + * to wait for write completion, and we can guarantee that our memory + * outlives the dispatch_data_t that wraps it. + */ + private fun noopDispatchBlock() {} + + /** + * Returns true if the semaphore was signaled, false if it timed out. + */ + private fun dispatch_semaphore_t.waitWithTimeout(timeoutMillis: Long): Boolean { + return dispatch_semaphore_wait(this, computeTimeout(timeoutMillis)) == INTPTR_ZERO + } + + private fun computeTimeout(timeoutMillis: Long): dispatch_time_t { + return if (timeoutMillis == 0L) { + DISPATCH_TIME_FOREVER + } else { + val nanos = timeoutMillis.milliseconds.inWholeNanoseconds + dispatch_time(DISPATCH_TIME_NOW, nanos) + } + } + + private fun nw_error_t.throwError(message: String? = null): Nothing { + val domain = nw_error_get_error_domain(this) + val code = nw_error_get_error_code(this) + val errorBody = message ?: "Network error" + throw IOException("$errorBody: $this (domain=${domain.name} code=$code)") + } + + private val nw_error_domain_t.name: String + get() = when (this) { + nw_error_domain_dns -> "dns" + nw_error_domain_tls -> "tls" + nw_error_domain_posix -> "posix" + nw_error_domain_invalid -> "invalid" + else -> "$this" + } + } +} diff --git a/thrifty-runtime/src/iosMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt b/thrifty-runtime/src/iosMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt new file mode 100644 index 0000000..63305d3 --- /dev/null +++ b/thrifty-runtime/src/iosMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt @@ -0,0 +1,106 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +package com.microsoft.thrifty.transport + +import okio.IOException +import platform.Network.nw_connection_t + +actual class SocketTransport actual constructor( + builder: Builder, +) : Transport { + private val host = builder.host + private val port = builder.port + private val connectTimeout: Long = builder.connectTimeout.toLong() + private val readTimeout: Long = builder.readTimeout.toLong() + private val tls = builder.useTransportSecurity + + private var socket: NwSocket? = null + + /** + * Just here for testing + */ + internal constructor(connection: nw_connection_t) : this(Builder("", 0)) { + socket = NwSocket(connection, 0L) + } + + actual class Builder actual constructor( + val host: String, + val port: Int, + ) { + var connectTimeout: Int = 0 + var readTimeout: Int = 0 + var useTransportSecurity: Boolean = false + + actual fun connectTimeout(connectTimeout: Int): Builder { + this.connectTimeout = maxOf(connectTimeout, 0) + return this + } + + actual fun readTimeout(readTimeout: Int): Builder { + this.readTimeout = maxOf(readTimeout, 0) + return this + } + + actual fun enableTls(enableTls: Boolean): Builder { + this.useTransportSecurity = enableTls + return this + } + + actual fun build(): SocketTransport { + return SocketTransport(this) + } + } + + override fun read(buffer: ByteArray, offset: Int, count: Int): Int { + return socket!!.read(buffer, offset, count) + } + + override fun write(data: ByteArray) { + write(data, 0, data.size) + } + + override fun write(buffer: ByteArray, offset: Int, count: Int) { + require(offset >= 0) + require(count >= 0) + require(count <= buffer.size - offset) + socket!!.write(buffer, offset, count) + } + + override fun flush() { + // no-op? + socket?.flush() + } + + override fun close() { + socket?.close() + } + + @Throws(IOException::class) + actual fun connect() { + socket = NwSocket.connect( + host = host, + port = port, + enableTls = tls, + sendTimeoutMillis = readTimeout, + connectTimeoutMillis = connectTimeout + ) + } +} diff --git a/thrifty-runtime/src/iosTest/kotlin/com/microsoft/thrifty/transport/NwSocketTest.kt b/thrifty-runtime/src/iosTest/kotlin/com/microsoft/thrifty/transport/NwSocketTest.kt new file mode 100644 index 0000000..ef35053 --- /dev/null +++ b/thrifty-runtime/src/iosTest/kotlin/com/microsoft/thrifty/transport/NwSocketTest.kt @@ -0,0 +1,161 @@ +/* + * Thrifty + * + * Copyright (c) Microsoft Corporation + * + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE, + * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT. + * + * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License. + */ +package com.microsoft.thrifty.transport + +import com.microsoft.thrifty.protocol.BinaryProtocol +import com.microsoft.thrifty.protocol.Xtruct +import io.kotest.matchers.shouldBe +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.convert +import okio.use +import platform.Network.nw_connection_set_queue +import platform.Network.nw_connection_set_state_changed_handler +import platform.Network.nw_connection_start +import platform.Network.nw_connection_state_cancelled +import platform.Network.nw_connection_state_failed +import platform.Network.nw_connection_state_ready +import platform.Network.nw_listener_cancel +import platform.Network.nw_listener_create +import platform.Network.nw_listener_get_port +import platform.Network.nw_listener_set_new_connection_handler +import platform.Network.nw_listener_set_queue +import platform.Network.nw_listener_set_state_changed_handler +import platform.Network.nw_listener_start +import platform.Network.nw_listener_state_cancelled +import platform.Network.nw_listener_state_failed +import platform.Network.nw_listener_state_ready +import platform.Network.nw_parameters_copy_default_protocol_stack +import platform.Network.nw_parameters_create +import platform.Network.nw_protocol_stack_set_transport_protocol +import platform.Network.nw_tcp_create_options +import platform.Network.nw_tcp_options_set_connection_timeout +import platform.Network.nw_tcp_options_set_enable_keepalive +import platform.darwin.DISPATCH_TIME_FOREVER +import platform.darwin.dispatch_async +import platform.darwin.dispatch_get_global_queue +import platform.darwin.dispatch_queue_create +import platform.darwin.dispatch_semaphore_create +import platform.darwin.dispatch_semaphore_signal +import platform.darwin.dispatch_semaphore_wait +import platform.posix.QOS_CLASS_DEFAULT +import kotlin.test.Test + +@OptIn(ExperimentalForeignApi::class) +class NwSocketTest { + @Test + fun canRoundTripStructs() { + val xtruct = Xtruct.Builder() + .bool_thing(true) + .byte_thing(1) + .i32_thing(2) + .i64_thing(3) + .double_thing(4.0) + .string_thing("five") + .build() + + val globalQueue = dispatch_get_global_queue(QOS_CLASS_DEFAULT.convert(), 0.convert()) + + // For some reason, NW_PARAMETERS_DISABLE_PROTOCOL wasn't actually disabling TLS + // on the listener; we'd see "handshake failed" errors. Who even knows. + // Manually creating parameters, and not even touching TLS, seems to work. + //val parameters = nw_parameters_create_secure_tcp(NW_PARAMETERS_DISABLE_PROTOCOL, NW_PARAMETERS_DEFAULT_CONFIGURATION) + + val tcpOptions = nw_tcp_create_options() + nw_tcp_options_set_enable_keepalive(tcpOptions, true) + nw_tcp_options_set_connection_timeout(tcpOptions, 60.convert()) + + val parameters = nw_parameters_create() + val stack = nw_parameters_copy_default_protocol_stack(parameters) + nw_protocol_stack_set_transport_protocol(stack, tcpOptions) + + val serverListener = nw_listener_create(parameters) + nw_listener_set_queue(serverListener, globalQueue) + nw_listener_set_new_connection_handler(serverListener) { connection -> + nw_connection_set_state_changed_handler(connection) { state, err -> + if (state == nw_connection_state_ready) { + val transport = SocketTransport(connection) + val protocol = BinaryProtocol(transport) + xtruct.write(protocol) + } else if (state in listOf( + nw_connection_state_failed, + nw_connection_state_cancelled + ) + ) { + println("server: I AM NOT READY") + } + } + + nw_connection_set_queue(connection, globalQueue) + nw_connection_start(connection) + } + + val readySem = dispatch_semaphore_create(0) + var ready = false + nw_listener_set_state_changed_handler(serverListener) { state, err -> + if (state == nw_listener_state_ready) { + ready = true + } + + if (state in listOf( + nw_listener_state_ready, + nw_listener_state_failed, + nw_listener_state_cancelled + ) + ) { + dispatch_semaphore_signal(readySem) + } + } + nw_listener_start(serverListener) + dispatch_semaphore_wait(readySem, DISPATCH_TIME_FOREVER) + + if (!ready) { + nw_listener_cancel(serverListener) + throw AssertionError("Failed to set up a listener") + } + + val clientSem = dispatch_semaphore_create(0) + val clientQueue = dispatch_queue_create("client", null) + var matched = false + dispatch_async(clientQueue) { + try { + val port = nw_listener_get_port(serverListener) + SocketTransport.Builder("127.0.0.1", port.toInt()).readTimeout(100).build() + .use { transport -> + transport.connect() + val protocol = BinaryProtocol(transport) + val readXtruct = Xtruct.ADAPTER.read(protocol) + + if (readXtruct == xtruct) { + // Assertion errors don't make it out of dispatch queues, + // so we'll just set a flag and check it later. + matched = true + } + } + } finally { + nw_listener_cancel(serverListener) + dispatch_semaphore_signal(clientSem) + } + } + dispatch_semaphore_wait(clientSem, DISPATCH_TIME_FOREVER) + + matched shouldBe true + } +} diff --git a/thrifty-runtime/src/jvmMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt b/thrifty-runtime/src/jvmMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt index 6399c5a..df1960f 100644 --- a/thrifty-runtime/src/jvmMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt +++ b/thrifty-runtime/src/jvmMain/kotlin/com/microsoft/thrifty/transport/SocketTransport.kt @@ -26,48 +26,63 @@ import java.io.OutputStream import java.net.InetSocketAddress import java.net.Socket import javax.net.SocketFactory +import javax.net.ssl.SSLSocketFactory -class SocketTransport internal constructor( +actual class SocketTransport actual constructor( builder: Builder ) : Transport { private val host = builder.host private val port = builder.port private val readTimeout = builder.readTimeout private val connectTimeout = builder.connectTimeout - private val socketFactory = builder.socketFactory ?: SocketFactory.getDefault() + private val socketFactory = builder.socketFactory ?: builder.getDefaultSocketFactory() private var socket: Socket? = null private var inputStream: InputStream? = null private var outputStream: OutputStream? = null - class Builder(host: String, port: Int) { + actual class Builder actual constructor(host: String, port: Int) { internal val host: String internal val port: Int internal var readTimeout = 0 internal var connectTimeout = 0 internal var socketFactory: SocketFactory? = null + internal var enableTls = false - fun readTimeout(readTimeout: Int): Builder { + actual fun readTimeout(readTimeout: Int): Builder { require(readTimeout >= 0) { "readTimeout cannot be negative" } this.readTimeout = readTimeout return this } - fun connectTimeout(connectTimeout: Int): Builder { + actual fun connectTimeout(connectTimeout: Int): Builder { require(connectTimeout >= 0) { "connectTimeout cannot be negative" } this.connectTimeout = connectTimeout return this } + actual fun enableTls(enableTls: Boolean): Builder { + this.enableTls = enableTls + return this + } + fun socketFactory(socketFactory: SocketFactory?): Builder { this.socketFactory = requireNotNull(socketFactory) { "socketFactory" } return this } - fun build(): SocketTransport { + actual fun build(): SocketTransport { return SocketTransport(this) } + fun getDefaultSocketFactory(): SocketFactory { + return if (enableTls) { + SSLSocketFactory.getDefault() + } else { + SocketFactory.getDefault() + } + } + init { require(host.isNotBlank()) { "host must not be null or empty" } require(port in 0..0xFFFF) { "Invalid port number: $port" } @@ -98,7 +113,7 @@ class SocketTransport internal constructor( } @Throws(IOException::class) - fun connect() { + actual fun connect() { if (socket == null) { socket = socketFactory.createSocket() } diff --git a/thrifty-runtime/src/nativeInterop/cinterop/KT62102Workaround.def b/thrifty-runtime/src/nativeInterop/cinterop/KT62102Workaround.def new file mode 100644 index 0000000..d1451a9 --- /dev/null +++ b/thrifty-runtime/src/nativeInterop/cinterop/KT62102Workaround.def @@ -0,0 +1,26 @@ +language=Objective-C +--- +#import + +void nw_connection_send_with_default_context( + nw_connection_t connection, + _Nullable dispatch_data_t content, + bool is_complete, + nw_connection_send_completion_t completion +) { + nw_connection_send(connection, content, NW_CONNECTION_DEFAULT_MESSAGE_CONTEXT, is_complete, completion); +} + +// Not related to KT-62102, but this is a good place to put it. +// +// As of kt 1.9.10, DISPATCH_DATA_DESTRUCTOR_DEFAULT is erroneously mapped +// as a COpaquePointer, and not as a dispatch_block_t, rendering it unusable +// with dispatch_data_create. This function works around that deficiency. + +dispatch_block_t dispatch_data_default_destructor() { + return DISPATCH_DATA_DESTRUCTOR_DEFAULT; +} + +dispatch_queue_t dispatch_get_target_default_queue() { + return DISPATCH_TARGET_QUEUE_DEFAULT; +}