Kotlin generate union as sealed class (#253)

This commit is contained in:
Thibault Duperron 2019-02-04 19:13:09 +01:00 коммит произвёл Ben Bader
Родитель e222e48e69
Коммит d1fc321541
3 изменённых файлов: 854 добавлений и 20 удалений

Просмотреть файл

@ -456,4 +456,17 @@ struct MapsOfCollections {
1: map<set<i32>, set<string>> mapOfSets;
2: map<list<double>, list<i64>> mapOfLists;
3: map<map<i32, i32>, map<i8, i8>> mapOfMaps;
}
union TestUnion {
1: i32 AnInt;
2: i64 ALong;
3: string Text;
4: Bonk aBonk;
}
union EmptyUnion {};
struct HasEmptyUnion {
1: EmptyUnion theEmptyUnion;
}

Просмотреть файл

@ -232,7 +232,7 @@ class KotlinCodeGenerator(
schema.typedefs.forEach { typedefsByNamespace.put(it.kotlinNamespace, generateTypeAlias(it)) }
schema.enums.forEach { specsByNamespace.put(it.kotlinNamespace, generateEnumClass(it)) }
schema.structs.forEach { specsByNamespace.put(it.kotlinNamespace, generateDataClass(schema, it)) }
schema.unions.forEach { specsByNamespace.put(it.kotlinNamespace, generateDataClass(schema, it)) }
schema.unions.forEach { specsByNamespace.put(it.kotlinNamespace, generateSealedClass(schema, it)) }
schema.exceptions.forEach { specsByNamespace.put(it.kotlinNamespace, generateDataClass(schema, it)) }
val constantNameAllocators = mutableMapOf<String, NameAllocator>()
@ -390,8 +390,8 @@ class KotlinCodeGenerator(
if (struct.isDeprecated) addAnnotation(makeDeprecated())
if (struct.hasJavadoc) addKdoc("%L", struct.documentation)
if (struct.isException) superclass(Exception::class)
if (parcelize) {
addAnnotation(makeParcelable())
addAnnotation(suppressLint("ParcelCreator")) // Android Studio bug with Parcelize
}
@ -499,6 +499,76 @@ class KotlinCodeGenerator(
.build()
}
internal fun generateSealedClass(schema: Schema, struct: StructType): TypeSpec {
val structClassName = ClassName(struct.kotlinNamespace, struct.name)
val typeBuilder = TypeSpec.classBuilder(structClassName).apply {
addGeneratedAnnotation()
addModifiers(KModifier.SEALED)
if (struct.isDeprecated) addAnnotation(makeDeprecated())
if (struct.hasJavadoc) addKdoc("%L", struct.documentation)
}
val nameAllocator = nameAllocators[struct]
for (field in struct.fields) {
val name = nameAllocator.get(field)
val type = field.type.typeName
val propertySpec = PropertySpec.varBuilder("value", type.asNullable())
.initializer("value")
val propConstructor = FunSpec.constructorBuilder()
.addParameter("value", type.asNullable())
val dataProp = TypeSpec.classBuilder(name)
.addModifiers(KModifier.DATA)
.superclass(structClassName)
.addProperty(propertySpec.build())
.primaryConstructor(propConstructor.build())
.build()
typeBuilder.addType(dataProp)
}
var builderTypeName : ClassName? = null
var adapterInterfaceTypeName = KtAdapter::class
.asTypeName()
.parameterizedBy(struct.typeName)
if (!builderlessDataClasses) {
builderTypeName = ClassName(struct.kotlinNamespace, struct.name, "Builder")
typeBuilder.addType(generateBuilderForSealed(schema, struct))
adapterInterfaceTypeName = Adapter::class.asTypeName().parameterizedBy(
struct.typeName, builderTypeName)
}
val adapterTypeName = ClassName(struct.kotlinNamespace, struct.name, "${struct.name}Adapter")
typeBuilder.addType(generateAdapterForSealed(struct, adapterTypeName, adapterInterfaceTypeName, builderTypeName))
val companionBuilder = TypeSpec.companionObjectBuilder()
companionBuilder.addProperty(PropertySpec.builder("ADAPTER", adapterInterfaceTypeName)
.initializer("%T()", adapterTypeName)
.jvmField()
.build())
if (shouldImplementStruct) {
typeBuilder
.addSuperinterface(Struct::class)
.addFunction(FunSpec.builder("write")
.addModifiers(KModifier.OVERRIDE)
.addParameter("protocol", Protocol::class)
.addStatement("%L.write(protocol, this)", nameAllocator.get(Tags.ADAPTER))
.build())
}
return typeBuilder
.addType(companionBuilder.build())
.build()
}
// endregion Structs
// region Redaction/obfuscation
@ -647,6 +717,86 @@ class KotlinCodeGenerator(
.build()
}
internal fun generateBuilderForSealed(schema: Schema, struct: StructType): TypeSpec {
val structTypeName = ClassName(struct.kotlinNamespace, struct.name)
val spec = TypeSpec.classBuilder("Builder")
.addSuperinterface(StructBuilder::class.asTypeName().parameterizedBy(structTypeName))
val buildFunSpec = FunSpec.builder("build")
.addModifiers(KModifier.OVERRIDE)
.returns(structTypeName)
.beginControlFlow("return when")
val resetFunSpec = FunSpec.builder("reset")
.addModifiers(KModifier.OVERRIDE)
val copyCtor = FunSpec.constructorBuilder()
.addParameter("source", structTypeName)
.callThisConstructor()
.beginControlFlow("when(source)")
val defaultCtor = FunSpec.constructorBuilder()
val nameAllocator = nameAllocators[struct]
for (field in struct.fields) {
val name = nameAllocator.get(field)
val type = field.type.typeName
// Add a private var
val defaultValueBlock = field.defaultValue?.let {
renderConstValue(schema, field.type, it)
} ?: CodeBlock.of("null")
val propertySpec = PropertySpec.varBuilder(name, type.asNullable(), KModifier.PRIVATE)
// Add a builder fun
var content = "return apply {\n"
for (field2 in struct.fields) {
val name2 = nameAllocator.get(field2)
if (name == name2) {
content += " this.$name2 = value\n"
} else {
content += " this.$name2 = null\n"
}
}
content += "}"
val builderFunSpec = FunSpec.builder(name)
.addParameter("value", type)
.addStatement(content)
// Add initialization in default ctor
defaultCtor.addStatement("this.$name = %L", defaultValueBlock)
// Add initialization in copy ctor
copyCtor.addStatement("is $name -> this.$name = source.value")
// reset field
resetFunSpec.addStatement("this.$name = %L", defaultValueBlock)
// Add field to build-method ctor-invocation arg builder
// TODO: Add newlines and indents if numFields > 1
buildFunSpec.addStatement("$name != null -> ${struct.name}.$name($name)")
// Finish off the property and builder fun
spec.addProperty(propertySpec.build())
spec.addFunction(builderFunSpec.build())
}
buildFunSpec.addStatement("else -> throw IllegalStateException(\"unpossible\")")
buildFunSpec.endControlFlow()
copyCtor.endControlFlow()
return spec
.addFunction(defaultCtor.build())
.addFunction(copyCtor.build())
.addFunction(buildFunSpec.build())
.addFunction(resetFunSpec.build())
.build()
}
// endregion Builders
// region Adapters
@ -819,6 +969,146 @@ class KotlinCodeGenerator(
.build()
}
/**
* Generates an adapter for the given struct type.
*
* The kind of adapter generated depends on whether a [builderType] is
* provided. If so, a conventional [com.microsoft.thrifty.Adapter] gets
* created, making use of the given [builderType]. If not, a so-called
* "builderless" [com.microsoft.thrifty.kotlin.Adapter] is the result.
*/
internal fun generateAdapterForSealed(
struct: StructType,
adapterName: ClassName,
adapterInterfaceName: TypeName,
builderType: ClassName?): TypeSpec {
val adapter = TypeSpec.classBuilder(adapterName)
.addModifiers(KModifier.PRIVATE)
.addSuperinterface(adapterInterfaceName)
val reader = FunSpec.builder("read").apply {
addModifiers(KModifier.OVERRIDE)
returns(struct.typeName)
addParameter("protocol", Protocol::class)
if (builderType != null) {
addParameter("builder", builderType)
}
}
val writer = FunSpec.builder("write")
.addModifiers(KModifier.OVERRIDE)
.addParameter("protocol", Protocol::class)
.addParameter("struct", struct.typeName)
// Writer first, b/c it is easier
val nameAllocator = nameAllocators[struct]
writer.addStatement("protocol.writeStructBegin(%S)", struct.name)
for (field in struct.fields) {
val name = nameAllocator.get(field)
val fieldType = field.type
if (!field.required) {
writer.beginControlFlow("if (struct is $name)")
}
writer.addStatement("protocol.writeFieldBegin(%S, %L, %T.%L)",
field.name,
field.id,
TType::class,
fieldType.typeCodeName)
generateWriteCall(writer, "struct.value!!", fieldType)
writer.addStatement("protocol.writeFieldEnd()")
if (!field.required) {
writer.endControlFlow()
}
}
writer.addStatement("protocol.writeFieldStop()")
writer.addStatement("protocol.writeStructEnd()")
// Reader next
reader.addStatement("protocol.readStructBegin()")
if (builderType == null) {
reader.addStatement("var result : ${struct.name}? = null")
}
reader.beginControlFlow("while (true)")
reader.addStatement("val fieldMeta = protocol.readFieldBegin()")
reader.beginControlFlow("if (fieldMeta.typeId == %T.STOP)", TType::class)
reader.addStatement("break")
reader.endControlFlow()
if (struct.fields.isNotEmpty()) {
reader.beginControlFlow("when (fieldMeta.fieldId.toInt())")
for (field in struct.fields) {
val name = nameAllocator.get(field)
val fieldType = field.type
reader.addCode {
addStatement("${field.id} -> {%>")
beginControlFlow("if (fieldMeta.typeId == %T.%L)", TType::class, fieldType.typeCodeName)
generateReadCall(this, name, fieldType)
if (builderType != null) {
addStatement("builder.$name($name)")
} else {
addStatement("result = $name($name)")
}
nextControlFlow("else")
addStatement("%T.skip(protocol, fieldMeta.typeId)", ProtocolUtil::class)
endControlFlow()
addStatement("%<}")
}
}
reader.addStatement("else -> %T.skip(protocol, fieldMeta.typeId)", ProtocolUtil::class)
reader.endControlFlow() // when (fieldMeta.fieldId.toInt())
} else {
reader.addStatement("%T.skip(protocol, fieldMeta.typeId)", ProtocolUtil::class)
}
reader.addStatement("protocol.readFieldEnd()")
reader.endControlFlow() // while (true)
reader.addStatement("protocol.readStructEnd()")
if (builderType != null) {
reader.addStatement("return builder.build()")
} else {
reader.addCode {
beginControlFlow("if (null == result)")
addStatement("throw IllegalStateException(\"unreadable\")")
nextControlFlow("else")
addStatement("return result")
endControlFlow()
}
}
if (builderType != null) {
adapter.addFunction(FunSpec.builder("read")
.addModifiers(KModifier.OVERRIDE)
.addParameter("protocol", Protocol::class)
.addStatement("return read(protocol, %T())", builderType)
.build())
}
return adapter
.addFunction(reader.build())
.addFunction(writer.build())
.build()
}
private fun generateWriteCall(writer: FunSpec.Builder, name: String, type: ThriftType) {
// Assumptions:

Просмотреть файл

@ -30,17 +30,21 @@ import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asTypeName
import io.kotlintest.shouldNot
import io.kotlintest.matchers.string.contain
import io.kotlintest.should
import io.kotlintest.shouldBe
import org.junit.Ignore
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
class KotlinCodeGeneratorTest {
@get:Rule val tempDir = TemporaryFolder()
@get:Rule
val tempDir = TemporaryFolder()
@Test fun `struct to data class`() {
@Test
fun `struct to data class`() {
val schema = load("""
namespace kt com.test
@ -82,7 +86,8 @@ class KotlinCodeGeneratorTest {
files.forEach { println("$it") }
}
@Test fun `output styles work as advertised`() {
@Test
fun `output styles work as advertised`() {
val thrift = """
namespace kt com.test
@ -101,13 +106,14 @@ class KotlinCodeGeneratorTest {
// Default should be one file per namespace
gen.outputStyle shouldBe KotlinCodeGenerator.OutputStyle.FILE_PER_NAMESPACE
gen.generate(schema).size shouldBe 1
gen.generate(schema).size shouldBe 1
gen.outputStyle = KotlinCodeGenerator.OutputStyle.FILE_PER_TYPE
gen.generate(schema).size shouldBe 2
}
@Test fun `file-per-type puts constants into a file named 'Constants'`() {
@Test
fun `file-per-type puts constants into a file named 'Constants'`() {
val thrift = """
namespace kt com.test
@ -123,7 +129,8 @@ class KotlinCodeGeneratorTest {
specs.single().name shouldBe "Constants" // ".kt" suffix is appended when the file is written out
}
@Test fun `empty structs get default equals, hashcode, and toString methods`() {
@Test
fun `empty structs get default equals, hashcode, and toString methods`() {
val thrift = """
namespace kt com.test
@ -140,10 +147,11 @@ class KotlinCodeGeneratorTest {
struct.modifiers.any { it == KModifier.DATA } shouldBe false
struct.funSpecs.any { it.name == "toString" } shouldBe true
struct.funSpecs.any { it.name == "hashCode" } shouldBe true
struct.funSpecs.any { it.name == "equals" } shouldBe true
struct.funSpecs.any { it.name == "equals" } shouldBe true
}
@Test fun `Non-empty structs are data classes`() {
@Test
fun `Non-empty structs are data classes`() {
val thrift = """
namespace kt com.test
@ -160,10 +168,11 @@ class KotlinCodeGeneratorTest {
struct.modifiers.any { it == KModifier.DATA } shouldBe true
struct.funSpecs.any { it.name == "toString" } shouldBe false
struct.funSpecs.any { it.name == "hashCode" } shouldBe false
struct.funSpecs.any { it.name == "equals" } shouldBe false
struct.funSpecs.any { it.name == "equals" } shouldBe false
}
@Test fun `exceptions with reserved field names get renamed fields`() {
@Test
fun `exceptions with reserved field names get renamed fields`() {
val thrift = """
namespace kt com.test
@ -176,7 +185,8 @@ class KotlinCodeGeneratorTest {
xception.propertySpecs.single().name shouldBe "message_"
}
@Test fun services() {
@Test
fun services() {
val thrift = """
namespace kt test.services
@ -193,7 +203,8 @@ class KotlinCodeGeneratorTest {
generate(thrift).forEach { println(it) }
}
@Test fun `typedefs become typealiases`() {
@Test
fun `typedefs become typealiases`() {
val thrift = """
namespace kt test.typedefs
@ -207,7 +218,8 @@ class KotlinCodeGeneratorTest {
generate(thrift).forEach { println(it) }
}
@Test fun `services that return typedefs`() {
@Test
fun `services that return typedefs`() {
val thrift = """
namespace kt test.typedefs
@ -226,7 +238,8 @@ class KotlinCodeGeneratorTest {
.parameterizedBy(ClassName("test.typedefs", "TheNumber"))
}
@Test fun `constants that are typedefs`() {
@Test
fun `constants that are typedefs`() {
val thrift = """
|namespace kt test.typedefs
|
@ -248,7 +261,8 @@ class KotlinCodeGeneratorTest {
""".trimMargin()
}
@Test fun `Parcelize annotations for structs and enums`() {
@Test
fun `Parcelize annotations for structs and enums`() {
val thrift = """
|namespace kt test.parcelize
|
@ -263,7 +277,7 @@ class KotlinCodeGeneratorTest {
val file = generate(thrift) { parcelize() }.single()
val struct = file.members.single { it is TypeSpec && it.name == "Foo" } as TypeSpec
val anEnum = file.members.single { it is TypeSpec && it.name == "AnEnum"} as TypeSpec
val anEnum = file.members.single { it is TypeSpec && it.name == "AnEnum" } as TypeSpec
val svc = file.members.single { it is TypeSpec && it.name == "SvcClient" } as TypeSpec
val parcelize = ClassName("kotlinx.android.parcel", "Parcelize")
@ -273,7 +287,8 @@ class KotlinCodeGeneratorTest {
svc.annotations.any { it.type == parcelize } shouldBe false
}
@Test fun `Custom map-type constants`() {
@Test
fun `Custom map-type constants`() {
val thrift = """
|namespace kt test.map_consts
|
@ -301,7 +316,8 @@ class KotlinCodeGeneratorTest {
""".trimMargin()
}
@Test fun `suspend-fun service clients`() {
@Test
fun `suspend-fun service clients`() {
val thrift = """
|namespace kt test.coro
|
@ -334,6 +350,521 @@ class KotlinCodeGeneratorTest {
""".trimMargin())
}
@Test
fun `union generate sealed`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) { coroutineServiceClients() }
file.single().toString() should contain("""
|sealed class Union : Struct {
""".trimMargin())
}
@Test
fun `union properties as data`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) { coroutineServiceClients() }
file.single().toString() should contain("""
|
| data class Foo(var value: Int?) : Union()
|
| data class Bar(var value: Long?) : Union()
|
| data class Baz(var value: String?) : Union()
|
| data class NotFoo(var value: Int?) : Union()
|
""".trimMargin())
}
@Test
fun `union has builder`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) { coroutineServiceClients() }
file.single().toString() should contain("""
| class Builder : StructBuilder<Union> {
| private var Foo: Int?
|
| private var Bar: Long?
|
| private var Baz: String?
|
| private var NotFoo: Int?
|
| constructor() {
| this.Foo = null
| this.Bar = null
| this.Baz = null
| this.NotFoo = null
| }
|
| constructor(source: Union) : this() {
| when(source) {
| is Foo -> this.Foo = source.value
| is Bar -> this.Bar = source.value
| is Baz -> this.Baz = source.value
| is NotFoo -> this.NotFoo = source.value
| }
| }
|
| fun Foo(value: Int) = apply {
| this.Foo = value
| this.Bar = null
| this.Baz = null
| this.NotFoo = null
| }
|
| fun Bar(value: Long) = apply {
| this.Foo = null
| this.Bar = value
| this.Baz = null
| this.NotFoo = null
| }
|
| fun Baz(value: String) = apply {
| this.Foo = null
| this.Bar = null
| this.Baz = value
| this.NotFoo = null
| }
|
| fun NotFoo(value: Int) = apply {
| this.Foo = null
| this.Bar = null
| this.Baz = null
| this.NotFoo = value
| }
|
| override fun build(): Union = when {
| Foo != null -> Union.Foo(Foo)
| Bar != null -> Union.Bar(Bar)
| Baz != null -> Union.Baz(Baz)
| NotFoo != null -> Union.NotFoo(NotFoo)
| else -> throw IllegalStateException("unpossible")
| }
|
| override fun reset() {
| this.Foo = null
| this.Bar = null
| this.Baz = null
| this.NotFoo = null
| }
| }
""".trimMargin())
}
@Test
fun `union wont generate builder when disabled`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) { builderlessDataClasses() }
file.single().toString() shouldNot contain("""
| class Builder
""".trimMargin())
}
@Ignore
@Test
fun `union wont generate struct when disabled`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) //{ shouldImplementStruct() }
file.single().toString() shouldNot contain("""
| : Struct
""".trimMargin())
file.single().toString() shouldNot contain("""
| write
""".trimMargin())
}
@Test
fun `union generate write function`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) //{ shouldImplementStruct() }
file.single().toString() should contain(""" |
| override fun write(protocol: Protocol, struct: Union) {
| protocol.writeStructBegin("Union")
| if (struct is Foo) {
| protocol.writeFieldBegin("Foo", 1, TType.I32)
| protocol.writeI32(struct.value!!)
| protocol.writeFieldEnd()
| }
| if (struct is Bar) {
| protocol.writeFieldBegin("Bar", 2, TType.I64)
| protocol.writeI64(struct.value!!)
| protocol.writeFieldEnd()
| }
| if (struct is Baz) {
| protocol.writeFieldBegin("Baz", 3, TType.STRING)
| protocol.writeString(struct.value!!)
| protocol.writeFieldEnd()
| }
| if (struct is NotFoo) {
| protocol.writeFieldBegin("NotFoo", 4, TType.I32)
| protocol.writeI32(struct.value!!)
| protocol.writeFieldEnd()
| }
| protocol.writeFieldStop()
| protocol.writeStructEnd()
| }
| }
""".trimMargin())
}
@Test
fun `union generate read function`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) //{ shouldImplementStruct() }
file.single().toString() should contain("""
| override fun read(protocol: Protocol) = read(protocol, Builder())
|
| override fun read(protocol: Protocol, builder: Builder): Union {
| protocol.readStructBegin()
| while (true) {
| val fieldMeta = protocol.readFieldBegin()
| if (fieldMeta.typeId == TType.STOP) {
| break
| }
| when (fieldMeta.fieldId.toInt()) {
| 1 -> {
| if (fieldMeta.typeId == TType.I32) {
| val Foo = protocol.readI32()
| builder.Foo(Foo)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| 2 -> {
| if (fieldMeta.typeId == TType.I64) {
| val Bar = protocol.readI64()
| builder.Bar(Bar)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| 3 -> {
| if (fieldMeta.typeId == TType.STRING) {
| val Baz = protocol.readString()
| builder.Baz(Baz)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| 4 -> {
| if (fieldMeta.typeId == TType.I32) {
| val NotFoo = protocol.readI32()
| builder.NotFoo(NotFoo)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| else -> ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| protocol.readFieldEnd()
| }
| protocol.readStructEnd()
| return builder.build()
| }
""".trimMargin())
}
@Test
fun `union generate read function without builder`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) { builderlessDataClasses() }
file.single().toString() should contain("""
| override fun read(protocol: Protocol): Union {
| protocol.readStructBegin()
| var result : Union? = null
| while (true) {
| val fieldMeta = protocol.readFieldBegin()
| if (fieldMeta.typeId == TType.STOP) {
| break
| }
| when (fieldMeta.fieldId.toInt()) {
| 1 -> {
| if (fieldMeta.typeId == TType.I32) {
| val Foo = protocol.readI32()
| result = Foo(Foo)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| 2 -> {
| if (fieldMeta.typeId == TType.I64) {
| val Bar = protocol.readI64()
| result = Bar(Bar)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| 3 -> {
| if (fieldMeta.typeId == TType.STRING) {
| val Baz = protocol.readString()
| result = Baz(Baz)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| 4 -> {
| if (fieldMeta.typeId == TType.I32) {
| val NotFoo = protocol.readI32()
| result = NotFoo(NotFoo)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| else -> ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| protocol.readFieldEnd()
| }
| protocol.readStructEnd()
| if (null == result) {
| throw IllegalStateException("unreadable")
| } else {
| return result
| }
| }
""".trimMargin())
}
@Test
fun `union generate Adapter with builder`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift)
file.single().toString() should contain("""
| private class UnionAdapter : Adapter<Union, Builder> {
""".trimMargin())
}
@Test
fun `union generate Adapter`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
| 1: i32 Foo;
| 2: i64 Bar;
| 3: string Baz;
| 4: i32 NotFoo;
|}
""".trimMargin()
val file = generate(thrift) { builderlessDataClasses() }
file.single().toString() should contain("""
| private class UnionAdapter : Adapter<Union> {
""".trimMargin())
}
@Test
fun `empty union generate sealed`() {
val thrift = """
|namespace kt test.coro
|
|union Union {
|}
""".trimMargin()
val file = generate(thrift) { coroutineServiceClients() }
file.single().toString() should contain("""
|sealed class Union : Struct {
""".trimMargin())
}
@Test
fun `struct with union`() {
val thrift = """
|namespace kt test.coro
|
|struct Bonk {
| 1: string message;
| 2: i32 type;
|}
|
|union UnionStruct {
| 1: Bonk Struct
|}
""".trimMargin()
val file = generate(thrift) { coroutineServiceClients() }
file.single().toString() should contain("""
|sealed class UnionStruct : Struct {
| override fun write(protocol: Protocol) {
| ADAPTER.write(protocol, this)
| }
|
| data class Struct(var value: Bonk?) : UnionStruct()
|
| class Builder : StructBuilder<UnionStruct> {
| private var Struct: Bonk?
|
| constructor() {
| this.Struct = null
| }
|
| constructor(source: UnionStruct) : this() {
| when(source) {
| is Struct -> this.Struct = source.value
| }
| }
|
| fun Struct(value: Bonk) = apply {
| this.Struct = value
| }
|
| override fun build(): UnionStruct = when {
| Struct != null -> UnionStruct.Struct(Struct)
| else -> throw IllegalStateException("unpossible")
| }
|
| override fun reset() {
| this.Struct = null
| }
| }
|
| private class UnionStructAdapter : Adapter<UnionStruct, Builder> {
| override fun read(protocol: Protocol) = read(protocol, Builder())
|
| override fun read(protocol: Protocol, builder: Builder): UnionStruct {
| protocol.readStructBegin()
| while (true) {
| val fieldMeta = protocol.readFieldBegin()
| if (fieldMeta.typeId == TType.STOP) {
| break
| }
| when (fieldMeta.fieldId.toInt()) {
| 1 -> {
| if (fieldMeta.typeId == TType.STRUCT) {
| val Struct = Bonk.ADAPTER.read(protocol)
| builder.Struct(Struct)
| } else {
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| }
| else -> ProtocolUtil.skip(protocol, fieldMeta.typeId)
| }
| protocol.readFieldEnd()
| }
| protocol.readStructEnd()
| return builder.build()
| }
|
| override fun write(protocol: Protocol, struct: UnionStruct) {
| protocol.writeStructBegin("UnionStruct")
| if (struct is Struct) {
| protocol.writeFieldBegin("Struct", 1, TType.STRUCT)
| Bonk.ADAPTER.write(protocol, struct.value!!)
| protocol.writeFieldEnd()
| }
| protocol.writeFieldStop()
| protocol.writeStructEnd()
| }
| }
""".trimMargin())
}
private fun generate(thrift: String, config: (KotlinCodeGenerator.() -> KotlinCodeGenerator)? = null): List<FileSpec> {
val configOrDefault = config ?: { this }
return KotlinCodeGenerator()