Add HttpTransport, async clients to iOS (#543)

This PR makes the newly-added HttpTransport into expect/actual types, with a new implementation for iOS. This in turn required quite a bit of work to set up testing, and uncovered what seems to be a kotlinc type-inference bug that I've worked around here by changing the kotlin code generator a little bit.
This commit is contained in:
Ben Bader 2023-11-01 10:06:46 -06:00 коммит произвёл GitHub
Родитель 188faba9e2
Коммит ee13d65420
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
15 изменённых файлов: 943 добавлений и 29 удалений

2
.github/workflows/pre-merge.yml поставляемый
Просмотреть файл

@ -55,7 +55,7 @@ jobs:
- name: Build and test
shell: bash
run: ./gradlew check codeCoverageReport --parallel --no-daemon
run: ./gradlew check codeCoverageReport --parallel --no-daemon -x iosX64Test
- name: Boot simulator
if: matrix.os == 'macos-13'

Просмотреть файл

@ -1,4 +1,5 @@
[versions]
dokka = "1.7.20"
junit = "5.10.0"
kotest = "5.5.4"
kotlin = "1.9.10"
@ -8,7 +9,7 @@ okio = "3.3.0"
antlr = "org.antlr:antlr4:4.9.3"
apacheThrift = "org.apache.thrift:libthrift:0.19.0"
clikt = "com.github.ajalt.clikt:clikt:3.1.0"
dokka = "org.jetbrains.dokka:dokka-gradle-plugin:1.7.20"
dokka = { module = "org.jetbrains.dokka:dokka-gradle-plugin", version.ref = "dokka" }
guava = "com.google.guava:guava:31.1-jre"
javaPoet = "com.squareup:javapoet:1.13.0"
kotlin-bom = { module = "org.jetbrains.kotlin:kotlin-bom", version.ref = "kotlin" }
@ -37,7 +38,7 @@ kotlin = [ "kotlin-stdlib", "kotlin-reflect" ]
testing = ["junit", "hamcrest", "kotest-assertions-core", "kotest-assertions-coreJvm"]
[plugins]
dokka = "org.jetbrains.dokka:1.7.20"
dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" }
gradlePluginPublish = "com.gradle.plugin-publish:1.2.1"
kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" }
kotlin-mpp = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }

Просмотреть файл

@ -946,6 +946,7 @@ class KotlinCodeGenerator(
}
val structTypeName = ClassName(struct.kotlinNamespace, struct.name)
val resultVarName = nameAllocator.newName("resultVar", "resultVar")
val spec = TypeSpec.classBuilder("Builder")
.addSuperinterface(StructBuilder::class.asTypeName().parameterizedBy(structTypeName))
.addProperty(PropertySpec.builder(builderVarName, structTypeName.copy(nullable = true), KModifier.PRIVATE)
@ -961,11 +962,20 @@ class KotlinCodeGenerator(
.addFunction(FunSpec.builder("build")
.addModifiers(KModifier.OVERRIDE)
.returns(structTypeName)
.addStatement(
"return %N ?: %M(%S)",
builderVarName,
MemberName("kotlin", "error"),
"Invalid union; at least one value is required")
// We're doing some convoluted stuff to work around an apparent kotlinc bug
// where type inference and/or smart-casting fails. iOS tests got a compiler
// error about the builder var being of type "UnionType.Builder" which clearly
// it isn't. This reformulation seems to work.
.addStatement("val %N = %N", resultVarName, builderVarName)
.beginControlFlow("if (%N == null)", resultVarName)
.addStatement("%M(%S)", MemberName("kotlin", "error"), "Invalid union; at least one value is required")
.endControlFlow()
.addStatement("return %N!!", resultVarName)
// .addStatement(
// "return %N ?: %M(%S)",
// builderVarName,
// MemberName("kotlin", "error"),
// "Invalid union; at least one value is required")
.build())
.addFunction(FunSpec.builder("reset")
.addModifiers(KModifier.OVERRIDE)

Просмотреть файл

@ -616,8 +616,13 @@ class KotlinCodeGeneratorTest {
| this.value_ = source
| }
|
| public override fun build(): Union = value_ ?:
| public override fun build(): Union {
| val resultVar = value_
| if (resultVar == null) {
| error("Invalid union; at least one value is required")
| }
| return resultVar!!
| }
|
| public override fun reset(): Unit {
| value_ = null
@ -1488,6 +1493,27 @@ class KotlinCodeGeneratorTest {
files.shouldCompile()
}
@Test
fun `union with builder should compile`() {
val thrift = """
|namespace kt test.union
|
|union Union {
| 1: i32 result;
| 2: i64 bigResult;
| 3: string error;
|}
""".trimMargin()
val files = generate(thrift) {
withDataClassBuilders()
coroutineServiceClients()
}
files.shouldCompile()
println(files)
}
private fun generate(thrift: String, config: (KotlinCodeGenerator.() -> KotlinCodeGenerator)? = null): List<FileSpec> {
val configOrDefault = config ?: { emitFileComment(false) }
return KotlinCodeGenerator()

Просмотреть файл

@ -0,0 +1,340 @@
/*
* TestThrift.thrift, modified to use a separate package name.
*
* Any changes to this file *MUST* be mirrored in thrifty-test-server's copy!
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
* Contains some contributions under the Thrift Software License.
* Please see doc/old-thrift-license.txt in the Thrift distribution for
* details.
*/
namespace kt com.microsoft.thrifty.runtime.kgen.coro
/**
* Docstring!
*/
enum Numberz
{
ONE = 1,
TWO,
THREE,
FIVE = 5,
SIX,
EIGHT = 8
}
const double ActualDouble = 42
const Numberz myNumberz = Numberz.ONE;
typedef i64 UserId
struct Bonk
{
1: string message,
2: i32 type
}
struct Xtruct
{
1: string string_thing,
4: byte byte_thing,
9: i32 i32_thing,
11: i64 i64_thing,
13: double double_thing,
15: bool bool_thing
}
struct Xtruct2
{
1: byte byte_thing, // used to be byte, hence the name
2: Xtruct struct_thing,
3: i32 i32_thing
}
struct Insanity
{
1: map<Numberz, UserId> userMap,
2: list<Xtruct> xtructs
}
exception Xception {
1: i32 errorCode,
2: string message
}
exception Xception2 {
1: i32 errorCode,
2: Xtruct struct_thing
}
union TheEmptyUnion {}
union NonEmptyUnion {
1: i32 AnInt;
2: i64 ALong;
3: string AString;
4: Bonk ABonk;
}
struct HasUnion {
1: required NonEmptyUnion TheUnion;
}
union UnionWithDefault {
1: string Text;
2: i32 Int;
3: double Real = 3.14
}
union UnionWithRedactions {
1: string text;
2: string obfuscated_text (obfuscated = "true");
3: string redacted_text (redacted = "true");
4: list<i32> nums;
5: list<i32> obfuscated_nums (obfuscated = "true");
6: list<i32> redacted_nums (redacted = "true");
7: set<double> dubs;
8: set<double> obfuscated_dubs (obfuscated = "true");
9: set<double> redacted_dubs (redacted = "true");
10: map<i8, i8> bytes;
11: map<i8, i8> obfuscated_bytes (obfuscated = "true");
12: map<i8, i8> redacted_bytes (redacted = "true");
}
service ThriftTest
{
/**
* Prints "testVoid()" and returns nothing.
*/
void testVoid(),
/**
* Prints 'testString("%s")' with thing as '%s'
* @param string thing - the string to print
* @return string - returns the string 'thing'
*/
string testString(1: string thing),
/**
* Prints 'testBool("%s")' where '%s' with thing as 'true' or 'false'
* @param bool thing - the bool data to print
* @return bool - returns the bool 'thing'
*/
bool testBool(1: bool thing),
/**
* Prints 'testByte("%d")' with thing as '%d'
* The types i8 and byte are synonyms, use of i8 is encouraged, byte still exists for the sake of compatibility.
* @param byte thing - the i8/byte to print
* @return i8 - returns the i8/byte 'thing'
*/
byte testByte(1: byte thing),
/**
* Prints 'testI32("%d")' with thing as '%d'
* @param i32 thing - the i32 to print
* @return i32 - returns the i32 'thing'
*/
i32 testI32(1: i32 thing),
/**
* Prints 'testI64("%d")' with thing as '%d'
* @param i64 thing - the i64 to print
* @return i64 - returns the i64 'thing'
*/
i64 testI64(1: i64 thing),
/**
* Prints 'testDouble("%f")' with thing as '%f'
* @param double thing - the double to print
* @return double - returns the double 'thing'
*/
double testDouble(1: double thing),
/**
* Prints 'testBinary("%s")' where '%s' is a hex-formatted string of thing's data
* @param binary thing - the binary data to print
* @return binary - returns the binary 'thing'
*/
binary testBinary(1: binary thing),
/**
* Prints 'testStruct("{%s}")' where thing has been formatted into a string of comma separated values
* @param Xtruct thing - the Xtruct to print
* @return Xtruct - returns the Xtruct 'thing'
*/
Xtruct testStruct(1: Xtruct thing),
/**
* Prints 'testNest("{%s}")' where thing has been formatted into a string of the nested struct
* @param Xtruct2 thing - the Xtruct2 to print
* @return Xtruct2 - returns the Xtruct2 'thing'
*/
Xtruct2 testNest(1: Xtruct2 thing),
/**
* Prints 'testMap("{%s")' where thing has been formatted into a string of 'key => value' pairs
* separated by commas and new lines
* @param map<i32,i32> thing - the map<i32,i32> to print
* @return map<i32,i32> - returns the map<i32,i32> 'thing'
*/
map<i32,i32> testMap(1: map<i32,i32> thing),
/**
* Prints 'testStringMap("{%s}")' where thing has been formatted into a string of 'key => value' pairs
* separated by commas and new lines
* @param map<string,string> thing - the map<string,string> to print
* @return map<string,string> - returns the map<string,string> 'thing'
*/
map<string,string> testStringMap(1: map<string,string> thing),
/**
* Prints 'testSet("{%s}")' where thing has been formatted into a string of values
* separated by commas and new lines
* @param set<i32> thing - the set<i32> to print
* @return set<i32> - returns the set<i32> 'thing'
*/
set<i32> testSet(1: set<i32> thing),
/**
* Prints 'testList("{%s}")' where thing has been formatted into a string of values
* separated by commas and new lines
* @param list<i32> thing - the list<i32> to print
* @return list<i32> - returns the list<i32> 'thing'
*/
list<i32> testList(1: list<i32> thing),
/**
* Prints 'testEnum("%d")' where thing has been formatted into it's numeric value
* @param Numberz thing - the Numberz to print
* @return Numberz - returns the Numberz 'thing'
*/
Numberz testEnum(1: Numberz thing),
/**
* Prints 'testTypedef("%d")' with thing as '%d'
* @param UserId thing - the UserId to print
* @return UserId - returns the UserId 'thing'
*/
UserId testTypedef(1: UserId thing),
/**
* Prints 'testMapMap("%d")' with hello as '%d'
* @param i32 hello - the i32 to print
* @return map<i32,map<i32,i32>> - returns a dictionary with these values:
* {-4 => {-4 => -4, -3 => -3, -2 => -2, -1 => -1, }, 4 => {1 => 1, 2 => 2, 3 => 3, 4 => 4, }, }
*/
map<i32,map<i32,i32>> testMapMap(1: i32 hello),
/**
* So you think you've got this all worked, out eh?
*
* Creates a the returned map with these values and prints it out:
* { 1 => { 2 => argument,
* 3 => argument,
* },
* 2 => { 6 => <empty Insanity struct>, },
* }
* @return map<UserId, map<Numberz,Insanity>> - a map with the above values
*/
map<UserId, map<Numberz,Insanity>> testInsanity(1: Insanity argument),
/**
* Prints 'testMulti()'
* @param byte arg0 -
* @param i32 arg1 -
* @param i64 arg2 -
* @param map<i16, string> arg3 -
* @param Numberz arg4 -
* @param UserId arg5 -
* @return Xtruct - returns an Xtruct with string_thing = "Hello2, byte_thing = arg0, i32_thing = arg1
* and i64_thing = arg2
*/
Xtruct testMulti(1: byte arg0, 2: i32 arg1, 3: i64 arg2, 4: map<i16, string> arg3, 5: Numberz arg4, 6: UserId arg5),
/**
* Print 'testException(%s)' with arg as '%s'
* @param string arg - a string indication what type of exception to throw
* if arg == "Xception" throw Xception with errorCode = 1001 and message = arg
* elsen if arg == "TException" throw TException
* else do not throw anything
*/
void testException(1: string arg) throws(1: Xception err1),
/**
* Print 'testMultiException(%s, %s)' with arg0 as '%s' and arg1 as '%s'
* @param string arg - a string indication what type of exception to throw
* if arg0 == "Xception" throw Xception with errorCode = 1001 and message = "This is an Xception"
* elsen if arg0 == "Xception2" throw Xception2 with errorCode = 2002 and struct_thing.string_thing = "This is an Xception2"
* else do not throw anything
* @return Xtruct - an Xtruct with string_thing = arg1
*/
Xtruct testMultiException(1: string arg0, 2: string arg1) throws(1: Xception err1, 2: Xception2 err2)
/**
* Print 'testOneway(%d): Sleeping...' with secondsToSleep as '%d'
* sleep 'secondsToSleep'
* Print 'testOneway(%d): done sleeping!' with secondsToSleep as '%d'
* @param i32 secondsToSleep - the number of seconds to sleep
*/
oneway void testOneway(1:i32 secondsToSleep)
/**
* Prints 'testUnionArgument()' and returns the argument unmodified, wrapped in a
* HasUnion struct.
**/
HasUnion testUnionArgument(1: NonEmptyUnion arg0)
/**
* Returns the argument unaltered.
*/
UnionWithDefault testUnionWithDefault(1: UnionWithDefault theArg)
}
// Builderless unions should handle fields named "result".
// see https://github.com/microsoft/thrifty/issues/404
union UnionWithResult {
1: i32 result;
2: i64 bigResult;
3: string error;
}
const Insanity TOTAL_INSANITY = {
"userMap": {
myNumberz: 1234
},
"xtructs": [
{
"string_thing": "hello",
},
{
"i32_thing": 1,
"bool_thing": 0,
},
]
}
const Bonk A_BONK = {
"message": "foobar",
"type": 100,
}

Просмотреть файл

@ -38,6 +38,24 @@ tasks.withType(KotlinCompile).configureEach {
}
}
def compileTestThrift = tasks.register("compileTestThrift", JavaExec) { t ->
t.inputs.file("$projectDir/ClientThriftTest.thrift")
t.outputs.dir("$projectDir/build/generated-src/thrifty-kotlin/kotlin")
t.outputs.cacheIf("This task is always cacheable based on its inputs") { true }
t.classpath = project(":thrifty-compiler").sourceSets.main.runtimeClasspath
mainClass = "com.microsoft.thrifty.compiler.ThriftyCompiler"
args = [
"--out=$projectDir/build/generated-src/thrifty-kotlin/kotlin",
"--kt-struct-builders",
"--service-type=coroutine",
"$projectDir/ClientThriftTest.thrift"
]
}
kotlin {
jvm()
@ -51,6 +69,10 @@ kotlin {
compilations.main.cinterops {
KT62102Workaround {}
}
compilations.test.compileTaskProvider.configure {
dependsOn compileTestThrift
}
}
iosX64 {
@ -63,6 +85,10 @@ kotlin {
compilations.main.cinterops {
KT62102Workaround {}
}
compilations.test.compileTaskProvider.configure {
dependsOn compileTestThrift
}
}
sourceSets {
@ -107,6 +133,8 @@ kotlin {
iosTest {
dependsOn commonTest
kotlin.srcDir("$buildDir/generated-src/thrifty-kotlin/kotlin")
dependencies {
implementation libs.kotlin.test.common
implementation libs.kotest.assertions.common
@ -136,19 +164,72 @@ jvmTest {
useJUnitPlatform()
}
tasks.register("iosTest") {
def device = project.findProperty("iosDevice")?.toString() ?: "iPhone 15 Pro Max"
abstract class IosTestTask extends DefaultTask {
@InputFiles
abstract Property<FileCollection> getServerClasspath()
@InputFile
abstract RegularFileProperty getTestBinary()
@Input
abstract Property<String> getDevice()
private Process serverProcess
private int port
@TaskAction
def run() {
try {
serverProcess = startServer()
logger.quiet("Server listening on $port")
runTests()
} finally {
serverProcess?.destroyForcibly()
}
}
private void runTests() {
def thePort = port
project.exec {
environment("SIMCTL_CHILD_THRIFTY_HTTP_SERVER_PORT", thePort.toString())
commandLine 'xcrun', 'simctl', 'spawn', device.get(), testBinary.get().getAsFile().absolutePath
}
}
private Process startServer() {
Process process = new ProcessBuilder()
.command(["java", "-cp", serverClasspath.get().getAsPath(), "com.microsoft.thrifty.mains.HttpServerMain"])
.start()
def reader = new InputStreamReader(process.getInputStream())
while (true) {
String portLine = reader.readLine()
if (portLine == null) {
throw new RuntimeException("Server failed to start")
}
if (portLine.startsWith("port ")) {
port = Integer.parseInt(portLine.substring(5))
break
}
}
return process
}
}
tasks.register("iosTest", IosTestTask) {
dependsOn 'compileTestThrift'
dependsOn 'linkDebugTestIosX64'
group = JavaBasePlugin.VERIFICATION_GROUP
description = "Runs tests for target 'ios' on an iOS simulator"
doLast {
def binary = kotlin.targets.iosX64.binaries.getTest('DEBUG').outputFile
println("muh binary: ${binary.absolutePath}")
exec {
commandLine 'xcrun', 'simctl', 'spawn', device, binary.absolutePath
}
}
device = project.findProperty("iosDevice")?.toString() ?: "iPhone 15 Pro Max"
testBinary = kotlin.targets.iosX64.binaries.getTest('DEBUG').outputFile
serverClasspath = project(":thrifty-test-server").sourceSets.main.runtimeClasspath
}
// What have I gotten myself in to

Просмотреть файл

@ -0,0 +1,28 @@
/*
* Thrifty
*
* Copyright (c) Microsoft Corporation
*
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
*
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
*/
package com.microsoft.thrifty.transport
expect class HttpTransport(url: String) : Transport {
fun setConnectTimeout(timeout: Int)
fun setReadTimeout(timeout: Int)
fun setCustomHeaders(headers: Map<String, String>)
fun setCustomHeader(key: String, value: String)
}

Просмотреть файл

@ -20,10 +20,22 @@
*/
package com.microsoft.thrifty.service
import KT62102Workaround.dispatch_attr_serial
import com.microsoft.thrifty.Struct
import com.microsoft.thrifty.ThriftException
import com.microsoft.thrifty.protocol.Protocol
import kotlinx.atomicfu.atomic
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.convert
import okio.Closeable
import okio.IOException
import platform.darwin.DISPATCH_QUEUE_SERIAL
import platform.darwin.dispatch_async
import platform.darwin.dispatch_get_global_queue
import platform.darwin.dispatch_queue_create
import platform.darwin.dispatch_suspend
import platform.posix.QOS_CLASS_USER_INITIATED
import kotlin.coroutines.cancellation.CancellationException
/**
* Implements a basic service client that executes methods asynchronously.
@ -35,11 +47,17 @@ import okio.Closeable
* configure your [Protocol] and [com.microsoft.thrifty.transport.Transport]
* objects appropriately.
*/
@Suppress("UNCHECKED_CAST")
@OptIn(ExperimentalForeignApi::class)
actual open class AsyncClientBase protected actual constructor(
protocol: Protocol,
private val listener: Listener
) : ClientBase(protocol), Closeable {
private val closed = atomic(false)
private var queue = dispatch_queue_create("client-queue", dispatch_attr_serial())
private val pendingCalls = mutableSetOf<MethodCall<*>>()
/**
* Exposes important events in the client's lifecycle.
*/
@ -78,10 +96,84 @@ actual open class AsyncClientBase protected actual constructor(
* @param methodCall the remote method call to be invoked
*/
protected actual fun enqueue(methodCall: MethodCall<*>) {
TODO()
check(!closed.value) { "Client has been closed" }
pendingCalls.add(methodCall)
dispatch_async(queue) {
pendingCalls.remove(methodCall)
if (closed.value) {
methodCall.callback?.onError(CancellationException("Client has been closed"))
return@dispatch_async
}
var result: Any? = null
var error: Exception? = null
try {
result = invokeRequest(methodCall)
} catch (e: IOException) {
fail(methodCall, e)
close(e)
return@dispatch_async
} catch (e: RuntimeException) {
fail(methodCall, e)
close(e)
return@dispatch_async
} catch (e: ServerException) {
error = e.thriftException
} catch (e: Exception) {
if (e is Struct) {
error = e
} else {
throw AssertionError("wat")
}
}
if (error != null) {
fail(methodCall, error)
} else {
complete(methodCall, result)
}
}
}
override fun close() {
TODO()
override fun close() = close(error = null)
private fun close(error: Exception?) {
if (closed.getAndSet(true)) {
return
}
dispatch_suspend(queue)
queue = null
for (call in pendingCalls) {
val e = error ?: CancellationException("Client has been closed")
fail(call, e)
}
dispatch_async(dispatch_get_global_queue(QOS_CLASS_USER_INITIATED.convert(), 0.convert())) {
if (error != null) {
listener.onError(error)
} else {
listener.onTransportClosed()
}
}
}
@Suppress("UNCHECKED_CAST")
private fun complete(call: MethodCall<*>, result: Any?) {
val q = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED.convert(), 0.convert())
dispatch_async(q) {
val callback = call.callback as ServiceMethodCallback<Any?>?
callback?.onSuccess(result)
}
}
private fun fail(call: MethodCall<*>, exception: Exception) {
val q = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED.convert(), 0.convert())
dispatch_async(q) {
call.callback?.onError(exception)
}
}
}

Просмотреть файл

@ -0,0 +1,215 @@
/*
* Thrifty
*
* Copyright (c) Microsoft Corporation
*
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
*
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
*/
package com.microsoft.thrifty.transport
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.addressOf
import kotlinx.cinterop.convert
import kotlinx.cinterop.usePinned
import okio.IOException
import platform.Foundation.NSCondition
import platform.Foundation.NSError
import platform.Foundation.NSMakeRange
import platform.Foundation.NSMutableData
import platform.Foundation.NSMutableURLRequest
import platform.Foundation.NSTimeInterval
import platform.Foundation.NSURL
import platform.Foundation.NSURLResponse
import platform.Foundation.NSURLSession
import platform.Foundation.NSURLSessionTask
import platform.Foundation.appendBytes
import platform.Foundation.dataTaskWithRequest
import platform.Foundation.getBytes
import platform.Foundation.setHTTPBody
import platform.Foundation.setHTTPMethod
import platform.Foundation.setValue
@OptIn(ExperimentalForeignApi::class)
actual class HttpTransport actual constructor(url: String) : Transport {
private val url = NSURL.URLWithString(url)!!
private val customHeaders = mutableMapOf<String, String>()
private var readTimeout: NSTimeInterval = 0.0
private var connectTimeout: NSTimeInterval = 0.0
private var writing: Boolean = true
// When writing, [data] will act as a send buffer, sent on [flush].
// When reading, will hold response bytes that can be read out
// by calls to [read], and [consumed] will track how many bytes have
// been read out.
private var data: NSMutableData = NSMutableData()
private var consumed = 0UL
// This is used to signal when the response has been received.
private val condition = NSCondition()
private var response: NSURLResponse? = null
private var responseErr: NSError? = null
private var task: NSURLSessionTask? = null
override fun close() {
condition.locked {
if (task != null) {
task!!.cancel()
task = null
}
}
}
override fun read(buffer: ByteArray, offset: Int, count: Int): Int {
require(!writing) { "Cannot read before calling flush()" }
require(count > 0) { "Cannot read a negative or zero number of bytes" }
require(offset >= 0) { "Cannot read into a negative offset" }
require(offset < buffer.size) { "Offset is outside of buffer bounds" }
require(offset + count <= buffer.size) { "Not enough room in buffer for requested read" }
condition.waitFor { response != null || responseErr != null }
if (responseErr != null) {
throw IOException("Response error: $responseErr")
}
val remaining = data.length() - consumed
val toCopy = minOf(remaining, count.convert())
buffer.usePinned { pinned ->
data.getBytes(pinned.addressOf(offset), NSMakeRange(consumed.convert(), toCopy.convert()))
}
// If we copied bytes, move the pointer.
if (toCopy > 0U) {
consumed += toCopy
}
return toCopy.convert()
}
override fun write(buffer: ByteArray, offset: Int, count: Int) {
require(offset >= 0) { "offset < 0: $offset" }
require(count >= 0) { "count < 0: $count" }
require(offset + count <= buffer.size) { "offset + count > buffer.size: $offset + $count > ${buffer.size}" }
if (!writing) {
// Maybe there's still data in the buffer to be read,
// but if our user is writing, then let's just go with it.
condition.locked {
if (task != null) {
task!!.cancel()
task = null
}
data.setLength(0U)
response = null
responseErr = null
consumed = 0U
writing = true
}
}
buffer.usePinned { pinned ->
data.appendBytes(pinned.addressOf(offset), count.convert())
}
}
override fun flush() {
require(writing) { "Cannot flush after calling read()" }
writing = false
val urlRequest = NSMutableURLRequest(url)
urlRequest.setHTTPMethod("POST")
urlRequest.setValue(value = "application/x-thrift", forHTTPHeaderField = "Content-Type")
urlRequest.setValue(value = "application/x-thrift", forHTTPHeaderField = "Accept")
urlRequest.setValue(value = "Java/THttpClient", forHTTPHeaderField = "User-Agent")
for ((key, value) in customHeaders) {
urlRequest.setValue(value, forHTTPHeaderField = key)
}
if (readTimeout != 0.0) {
urlRequest.setTimeoutInterval(readTimeout)
}
urlRequest.setHTTPBody(data)
val session = NSURLSession.sharedSession()
val task = session.dataTaskWithRequest(urlRequest) { data, response, error ->
if (data != null) {
this.data = data.mutableCopy() as NSMutableData
} else {
this.data.setLength(0U)
}
consumed = 0U
condition.locked {
this.response = response
this.responseErr = error
condition.signal()
}
}
condition.locked {
this.task = task
}
task.resume()
}
actual fun setConnectTimeout(timeout: Int) {
this.connectTimeout = millisToTimeInterval(timeout.toLong())
}
actual fun setReadTimeout(timeout: Int) {
this.readTimeout = millisToTimeInterval(timeout.toLong())
}
actual fun setCustomHeaders(headers: Map<String, String>) {
customHeaders.clear()
customHeaders.putAll(headers)
}
actual fun setCustomHeader(key: String, value: String) {
customHeaders[key] = value
}
}
fun millisToTimeInterval(millis: Long): NSTimeInterval {
// NSTimeInterval is a double-precision floating point number representing
// seconds. So to go from millis to NSTimeInterval, we divide by 1000.0.
return millis / 1000.0
}
inline fun NSCondition.locked(block: () -> Unit) {
lock()
try {
block()
} finally {
unlock()
}
}
inline fun NSCondition.waitFor(crossinline condition: () -> Boolean) {
locked {
while (!condition()) {
wait()
}
}
}

Просмотреть файл

@ -135,7 +135,9 @@ class NwSocket(
}
if (!sem.waitWithTimeout(readWriteTimeoutMillis)) {
throw IOException("Timed out waiting for read")
val e = IOException("Timed out waiting for read")
println(e.stackTraceToString())
throw e
}
networkError?.throwError()

Просмотреть файл

@ -0,0 +1,71 @@
/*
* Thrifty
*
* Copyright (c) Microsoft Corporation
*
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
*
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
*/
package com.microsoft.thrifty.transport
import com.microsoft.thrifty.runtime.kgen.coro.Insanity
import com.microsoft.thrifty.runtime.kgen.coro.Numberz
import com.microsoft.thrifty.runtime.kgen.coro.ThriftTestClient
import com.microsoft.thrifty.runtime.kgen.coro.UserId
import com.microsoft.thrifty.protocol.BinaryProtocol
import com.microsoft.thrifty.service.AsyncClientBase
import com.microsoft.thrifty.service.ServiceMethodCallback
import io.kotest.matchers.maps.beEmpty
import io.kotest.matchers.shouldNot
import kotlinx.coroutines.runBlocking
import platform.Foundation.NSProcessInfo
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.fail
class HttpTransportTest {
private var port: Int = -1
@BeforeTest
fun setUp() {
val portVar = NSProcessInfo.processInfo.environment["THRIFTY_HTTP_SERVER_PORT"]
requireNotNull(portVar)
port = (portVar as String).toInt()
}
@Test
fun testHttpTransport() = runBlocking {
val transport = HttpTransport("http://localhost:$port/test/service")
val protocol = BinaryProtocol(transport)
val client = ThriftTestClient(protocol, object : AsyncClientBase.Listener {
override fun onTransportClosed() {
println("transport closed")
}
override fun onError(error: Throwable) {
fail("error: $error")
}
})
val insanity = Insanity.Builder()
.build()
val result = client.testInsanity(insanity)
result shouldNot beEmpty()
transport.close()
}
}

Просмотреть файл

@ -64,7 +64,7 @@ import java.net.URL
*
* @see [THRIFT-970](https://issues.apache.org/jira/browse/THRIFT-970)
*/
open class HttpTransport(url: String) : Transport {
actual open class HttpTransport actual constructor(url: String) : Transport {
private val url: URL = URL(url)
private var currentState: Transport = Writing()
private var connectTimeout: Int? = null
@ -146,20 +146,20 @@ open class HttpTransport(url: String) : Transport {
connection.doOutput = true
}
fun setConnectTimeout(timeout: Int) {
actual fun setConnectTimeout(timeout: Int) {
connectTimeout = timeout
}
fun setReadTimeout(timeout: Int) {
actual fun setReadTimeout(timeout: Int) {
readTimeout = timeout
}
fun setCustomHeaders(headers: Map<String, String>) {
actual fun setCustomHeaders(headers: Map<String, String>) {
customHeaders.clear()
customHeaders.putAll(headers)
}
fun setCustomHeader(key: String, value: String) {
actual fun setCustomHeader(key: String, value: String) {
customHeaders[key] = value
}

Просмотреть файл

@ -24,3 +24,7 @@ dispatch_block_t dispatch_data_default_destructor() {
dispatch_queue_t dispatch_get_target_default_queue() {
return DISPATCH_TARGET_QUEUE_DEFAULT;
}
dispatch_queue_attr_t dispatch_attr_serial() {
return DISPATCH_QUEUE_SERIAL;
}

Просмотреть файл

@ -0,0 +1,39 @@
/*
* Thrifty
*
* Copyright (c) Microsoft Corporation
*
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
*
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
*/
package com.microsoft.thrifty.mains;
import com.microsoft.thrifty.testing.HttpServer;
import com.microsoft.thrifty.testing.ServerProtocol;
import com.microsoft.thrifty.testing.ServerTransport;
public class HttpServerMain {
public static void main(String[] args) {
try (HttpServer server = new HttpServer()) {
server.run(ServerProtocol.BINARY, ServerTransport.HTTP);
System.out.println("port " + server.port());
server.await();
} catch (Exception e) {
e.printStackTrace();
}
}
}

Просмотреть файл

@ -20,13 +20,14 @@
*/
package com.microsoft.thrifty.testing;
import java.io.Closeable;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.core.StandardContext;
import org.apache.catalina.startup.Tomcat;
import static com.microsoft.thrifty.testing.TestServer.getProtocolFactory;
public class HttpServer implements TestServerInterface {
public class HttpServer implements TestServerInterface, Closeable {
private Tomcat tomcat;
@Override
@ -66,4 +67,8 @@ public class HttpServer implements TestServerInterface {
throw new RuntimeException(e);
}
}
public void await() {
tomcat.getServer().await();
}
}