Kotlin generate union as sealed class (#253)
This commit is contained in:
Родитель
e222e48e69
Коммит
d1fc321541
|
@ -456,4 +456,17 @@ struct MapsOfCollections {
|
|||
1: map<set<i32>, set<string>> mapOfSets;
|
||||
2: map<list<double>, list<i64>> mapOfLists;
|
||||
3: map<map<i32, i32>, map<i8, i8>> mapOfMaps;
|
||||
}
|
||||
|
||||
union TestUnion {
|
||||
1: i32 AnInt;
|
||||
2: i64 ALong;
|
||||
3: string Text;
|
||||
4: Bonk aBonk;
|
||||
}
|
||||
|
||||
union EmptyUnion {};
|
||||
|
||||
struct HasEmptyUnion {
|
||||
1: EmptyUnion theEmptyUnion;
|
||||
}
|
|
@ -232,7 +232,7 @@ class KotlinCodeGenerator(
|
|||
schema.typedefs.forEach { typedefsByNamespace.put(it.kotlinNamespace, generateTypeAlias(it)) }
|
||||
schema.enums.forEach { specsByNamespace.put(it.kotlinNamespace, generateEnumClass(it)) }
|
||||
schema.structs.forEach { specsByNamespace.put(it.kotlinNamespace, generateDataClass(schema, it)) }
|
||||
schema.unions.forEach { specsByNamespace.put(it.kotlinNamespace, generateDataClass(schema, it)) }
|
||||
schema.unions.forEach { specsByNamespace.put(it.kotlinNamespace, generateSealedClass(schema, it)) }
|
||||
schema.exceptions.forEach { specsByNamespace.put(it.kotlinNamespace, generateDataClass(schema, it)) }
|
||||
|
||||
val constantNameAllocators = mutableMapOf<String, NameAllocator>()
|
||||
|
@ -390,8 +390,8 @@ class KotlinCodeGenerator(
|
|||
if (struct.isDeprecated) addAnnotation(makeDeprecated())
|
||||
if (struct.hasJavadoc) addKdoc("%L", struct.documentation)
|
||||
if (struct.isException) superclass(Exception::class)
|
||||
|
||||
if (parcelize) {
|
||||
|
||||
addAnnotation(makeParcelable())
|
||||
addAnnotation(suppressLint("ParcelCreator")) // Android Studio bug with Parcelize
|
||||
}
|
||||
|
@ -499,6 +499,76 @@ class KotlinCodeGenerator(
|
|||
.build()
|
||||
}
|
||||
|
||||
|
||||
internal fun generateSealedClass(schema: Schema, struct: StructType): TypeSpec {
|
||||
val structClassName = ClassName(struct.kotlinNamespace, struct.name)
|
||||
val typeBuilder = TypeSpec.classBuilder(structClassName).apply {
|
||||
addGeneratedAnnotation()
|
||||
|
||||
addModifiers(KModifier.SEALED)
|
||||
|
||||
if (struct.isDeprecated) addAnnotation(makeDeprecated())
|
||||
if (struct.hasJavadoc) addKdoc("%L", struct.documentation)
|
||||
}
|
||||
|
||||
val nameAllocator = nameAllocators[struct]
|
||||
for (field in struct.fields) {
|
||||
val name = nameAllocator.get(field)
|
||||
val type = field.type.typeName
|
||||
|
||||
|
||||
val propertySpec = PropertySpec.varBuilder("value", type.asNullable())
|
||||
.initializer("value")
|
||||
|
||||
val propConstructor = FunSpec.constructorBuilder()
|
||||
.addParameter("value", type.asNullable())
|
||||
|
||||
val dataProp = TypeSpec.classBuilder(name)
|
||||
.addModifiers(KModifier.DATA)
|
||||
.superclass(structClassName)
|
||||
.addProperty(propertySpec.build())
|
||||
.primaryConstructor(propConstructor.build())
|
||||
.build()
|
||||
typeBuilder.addType(dataProp)
|
||||
}
|
||||
var builderTypeName : ClassName? = null
|
||||
var adapterInterfaceTypeName = KtAdapter::class
|
||||
.asTypeName()
|
||||
.parameterizedBy(struct.typeName)
|
||||
if (!builderlessDataClasses) {
|
||||
builderTypeName = ClassName(struct.kotlinNamespace, struct.name, "Builder")
|
||||
|
||||
typeBuilder.addType(generateBuilderForSealed(schema, struct))
|
||||
adapterInterfaceTypeName = Adapter::class.asTypeName().parameterizedBy(
|
||||
struct.typeName, builderTypeName)
|
||||
}
|
||||
|
||||
val adapterTypeName = ClassName(struct.kotlinNamespace, struct.name, "${struct.name}Adapter")
|
||||
|
||||
typeBuilder.addType(generateAdapterForSealed(struct, adapterTypeName, adapterInterfaceTypeName, builderTypeName))
|
||||
|
||||
val companionBuilder = TypeSpec.companionObjectBuilder()
|
||||
companionBuilder.addProperty(PropertySpec.builder("ADAPTER", adapterInterfaceTypeName)
|
||||
.initializer("%T()", adapterTypeName)
|
||||
.jvmField()
|
||||
.build())
|
||||
|
||||
|
||||
if (shouldImplementStruct) {
|
||||
typeBuilder
|
||||
.addSuperinterface(Struct::class)
|
||||
.addFunction(FunSpec.builder("write")
|
||||
.addModifiers(KModifier.OVERRIDE)
|
||||
.addParameter("protocol", Protocol::class)
|
||||
.addStatement("%L.write(protocol, this)", nameAllocator.get(Tags.ADAPTER))
|
||||
.build())
|
||||
}
|
||||
|
||||
return typeBuilder
|
||||
.addType(companionBuilder.build())
|
||||
.build()
|
||||
}
|
||||
|
||||
// endregion Structs
|
||||
|
||||
// region Redaction/obfuscation
|
||||
|
@ -647,6 +717,86 @@ class KotlinCodeGenerator(
|
|||
.build()
|
||||
}
|
||||
|
||||
|
||||
internal fun generateBuilderForSealed(schema: Schema, struct: StructType): TypeSpec {
|
||||
val structTypeName = ClassName(struct.kotlinNamespace, struct.name)
|
||||
val spec = TypeSpec.classBuilder("Builder")
|
||||
.addSuperinterface(StructBuilder::class.asTypeName().parameterizedBy(structTypeName))
|
||||
val buildFunSpec = FunSpec.builder("build")
|
||||
.addModifiers(KModifier.OVERRIDE)
|
||||
.returns(structTypeName)
|
||||
.beginControlFlow("return when")
|
||||
|
||||
val resetFunSpec = FunSpec.builder("reset")
|
||||
.addModifiers(KModifier.OVERRIDE)
|
||||
|
||||
val copyCtor = FunSpec.constructorBuilder()
|
||||
.addParameter("source", structTypeName)
|
||||
.callThisConstructor()
|
||||
.beginControlFlow("when(source)")
|
||||
|
||||
val defaultCtor = FunSpec.constructorBuilder()
|
||||
|
||||
val nameAllocator = nameAllocators[struct]
|
||||
for (field in struct.fields) {
|
||||
val name = nameAllocator.get(field)
|
||||
val type = field.type.typeName
|
||||
|
||||
// Add a private var
|
||||
|
||||
val defaultValueBlock = field.defaultValue?.let {
|
||||
renderConstValue(schema, field.type, it)
|
||||
} ?: CodeBlock.of("null")
|
||||
|
||||
val propertySpec = PropertySpec.varBuilder(name, type.asNullable(), KModifier.PRIVATE)
|
||||
|
||||
// Add a builder fun
|
||||
var content = "return apply {\n"
|
||||
for (field2 in struct.fields) {
|
||||
val name2 = nameAllocator.get(field2)
|
||||
if (name == name2) {
|
||||
content += " this.$name2 = value\n"
|
||||
} else {
|
||||
content += " this.$name2 = null\n"
|
||||
}
|
||||
}
|
||||
content += "}"
|
||||
val builderFunSpec = FunSpec.builder(name)
|
||||
.addParameter("value", type)
|
||||
.addStatement(content)
|
||||
|
||||
// Add initialization in default ctor
|
||||
defaultCtor.addStatement("this.$name = %L", defaultValueBlock)
|
||||
|
||||
// Add initialization in copy ctor
|
||||
copyCtor.addStatement("is $name -> this.$name = source.value")
|
||||
|
||||
// reset field
|
||||
resetFunSpec.addStatement("this.$name = %L", defaultValueBlock)
|
||||
|
||||
// Add field to build-method ctor-invocation arg builder
|
||||
// TODO: Add newlines and indents if numFields > 1
|
||||
buildFunSpec.addStatement("$name != null -> ${struct.name}.$name($name)")
|
||||
|
||||
// Finish off the property and builder fun
|
||||
spec.addProperty(propertySpec.build())
|
||||
spec.addFunction(builderFunSpec.build())
|
||||
}
|
||||
|
||||
buildFunSpec.addStatement("else -> throw IllegalStateException(\"unpossible\")")
|
||||
buildFunSpec.endControlFlow()
|
||||
|
||||
copyCtor.endControlFlow()
|
||||
|
||||
return spec
|
||||
.addFunction(defaultCtor.build())
|
||||
.addFunction(copyCtor.build())
|
||||
.addFunction(buildFunSpec.build())
|
||||
.addFunction(resetFunSpec.build())
|
||||
.build()
|
||||
}
|
||||
|
||||
|
||||
// endregion Builders
|
||||
|
||||
// region Adapters
|
||||
|
@ -819,6 +969,146 @@ class KotlinCodeGenerator(
|
|||
.build()
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generates an adapter for the given struct type.
|
||||
*
|
||||
* The kind of adapter generated depends on whether a [builderType] is
|
||||
* provided. If so, a conventional [com.microsoft.thrifty.Adapter] gets
|
||||
* created, making use of the given [builderType]. If not, a so-called
|
||||
* "builderless" [com.microsoft.thrifty.kotlin.Adapter] is the result.
|
||||
*/
|
||||
internal fun generateAdapterForSealed(
|
||||
struct: StructType,
|
||||
adapterName: ClassName,
|
||||
adapterInterfaceName: TypeName,
|
||||
builderType: ClassName?): TypeSpec {
|
||||
val adapter = TypeSpec.classBuilder(adapterName)
|
||||
.addModifiers(KModifier.PRIVATE)
|
||||
.addSuperinterface(adapterInterfaceName)
|
||||
|
||||
val reader = FunSpec.builder("read").apply {
|
||||
addModifiers(KModifier.OVERRIDE)
|
||||
returns(struct.typeName)
|
||||
addParameter("protocol", Protocol::class)
|
||||
|
||||
if (builderType != null) {
|
||||
addParameter("builder", builderType)
|
||||
}
|
||||
}
|
||||
|
||||
val writer = FunSpec.builder("write")
|
||||
.addModifiers(KModifier.OVERRIDE)
|
||||
.addParameter("protocol", Protocol::class)
|
||||
.addParameter("struct", struct.typeName)
|
||||
|
||||
// Writer first, b/c it is easier
|
||||
|
||||
val nameAllocator = nameAllocators[struct]
|
||||
|
||||
writer.addStatement("protocol.writeStructBegin(%S)", struct.name)
|
||||
for (field in struct.fields) {
|
||||
val name = nameAllocator.get(field)
|
||||
val fieldType = field.type
|
||||
|
||||
if (!field.required) {
|
||||
writer.beginControlFlow("if (struct is $name)")
|
||||
}
|
||||
|
||||
writer.addStatement("protocol.writeFieldBegin(%S, %L, %T.%L)",
|
||||
field.name,
|
||||
field.id,
|
||||
TType::class,
|
||||
fieldType.typeCodeName)
|
||||
|
||||
generateWriteCall(writer, "struct.value!!", fieldType)
|
||||
|
||||
writer.addStatement("protocol.writeFieldEnd()")
|
||||
|
||||
if (!field.required) {
|
||||
writer.endControlFlow()
|
||||
}
|
||||
}
|
||||
writer.addStatement("protocol.writeFieldStop()")
|
||||
writer.addStatement("protocol.writeStructEnd()")
|
||||
|
||||
// Reader next
|
||||
|
||||
reader.addStatement("protocol.readStructBegin()")
|
||||
if (builderType == null) {
|
||||
reader.addStatement("var result : ${struct.name}? = null")
|
||||
}
|
||||
reader.beginControlFlow("while (true)")
|
||||
|
||||
reader.addStatement("val fieldMeta = protocol.readFieldBegin()")
|
||||
|
||||
reader.beginControlFlow("if (fieldMeta.typeId == %T.STOP)", TType::class)
|
||||
reader.addStatement("break")
|
||||
reader.endControlFlow()
|
||||
|
||||
|
||||
if (struct.fields.isNotEmpty()) {
|
||||
reader.beginControlFlow("when (fieldMeta.fieldId.toInt())")
|
||||
|
||||
for (field in struct.fields) {
|
||||
val name = nameAllocator.get(field)
|
||||
val fieldType = field.type
|
||||
|
||||
reader.addCode {
|
||||
addStatement("${field.id} -> {%>")
|
||||
beginControlFlow("if (fieldMeta.typeId == %T.%L)", TType::class, fieldType.typeCodeName)
|
||||
|
||||
generateReadCall(this, name, fieldType)
|
||||
|
||||
if (builderType != null) {
|
||||
addStatement("builder.$name($name)")
|
||||
} else {
|
||||
addStatement("result = $name($name)")
|
||||
}
|
||||
|
||||
nextControlFlow("else")
|
||||
addStatement("%T.skip(protocol, fieldMeta.typeId)", ProtocolUtil::class)
|
||||
endControlFlow()
|
||||
addStatement("%<}")
|
||||
}
|
||||
}
|
||||
|
||||
reader.addStatement("else -> %T.skip(protocol, fieldMeta.typeId)", ProtocolUtil::class)
|
||||
reader.endControlFlow() // when (fieldMeta.fieldId.toInt())
|
||||
} else {
|
||||
reader.addStatement("%T.skip(protocol, fieldMeta.typeId)", ProtocolUtil::class)
|
||||
}
|
||||
|
||||
reader.addStatement("protocol.readFieldEnd()")
|
||||
reader.endControlFlow() // while (true)
|
||||
reader.addStatement("protocol.readStructEnd()")
|
||||
|
||||
if (builderType != null) {
|
||||
reader.addStatement("return builder.build()")
|
||||
} else {
|
||||
reader.addCode {
|
||||
beginControlFlow("if (null == result)")
|
||||
addStatement("throw IllegalStateException(\"unreadable\")")
|
||||
nextControlFlow("else")
|
||||
addStatement("return result")
|
||||
endControlFlow()
|
||||
}
|
||||
}
|
||||
|
||||
if (builderType != null) {
|
||||
adapter.addFunction(FunSpec.builder("read")
|
||||
.addModifiers(KModifier.OVERRIDE)
|
||||
.addParameter("protocol", Protocol::class)
|
||||
.addStatement("return read(protocol, %T())", builderType)
|
||||
.build())
|
||||
}
|
||||
|
||||
return adapter
|
||||
.addFunction(reader.build())
|
||||
.addFunction(writer.build())
|
||||
.build()
|
||||
}
|
||||
|
||||
private fun generateWriteCall(writer: FunSpec.Builder, name: String, type: ThriftType) {
|
||||
|
||||
// Assumptions:
|
||||
|
|
|
@ -30,17 +30,21 @@ import com.squareup.kotlinpoet.KModifier
|
|||
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
|
||||
import com.squareup.kotlinpoet.TypeSpec
|
||||
import com.squareup.kotlinpoet.asTypeName
|
||||
import io.kotlintest.shouldNot
|
||||
import io.kotlintest.matchers.string.contain
|
||||
import io.kotlintest.should
|
||||
import io.kotlintest.shouldBe
|
||||
import org.junit.Ignore
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import org.junit.rules.TemporaryFolder
|
||||
|
||||
class KotlinCodeGeneratorTest {
|
||||
@get:Rule val tempDir = TemporaryFolder()
|
||||
@get:Rule
|
||||
val tempDir = TemporaryFolder()
|
||||
|
||||
@Test fun `struct to data class`() {
|
||||
@Test
|
||||
fun `struct to data class`() {
|
||||
val schema = load("""
|
||||
namespace kt com.test
|
||||
|
||||
|
@ -82,7 +86,8 @@ class KotlinCodeGeneratorTest {
|
|||
files.forEach { println("$it") }
|
||||
}
|
||||
|
||||
@Test fun `output styles work as advertised`() {
|
||||
@Test
|
||||
fun `output styles work as advertised`() {
|
||||
val thrift = """
|
||||
namespace kt com.test
|
||||
|
||||
|
@ -101,13 +106,14 @@ class KotlinCodeGeneratorTest {
|
|||
|
||||
// Default should be one file per namespace
|
||||
gen.outputStyle shouldBe KotlinCodeGenerator.OutputStyle.FILE_PER_NAMESPACE
|
||||
gen.generate(schema).size shouldBe 1
|
||||
gen.generate(schema).size shouldBe 1
|
||||
|
||||
gen.outputStyle = KotlinCodeGenerator.OutputStyle.FILE_PER_TYPE
|
||||
gen.generate(schema).size shouldBe 2
|
||||
}
|
||||
|
||||
@Test fun `file-per-type puts constants into a file named 'Constants'`() {
|
||||
@Test
|
||||
fun `file-per-type puts constants into a file named 'Constants'`() {
|
||||
val thrift = """
|
||||
namespace kt com.test
|
||||
|
||||
|
@ -123,7 +129,8 @@ class KotlinCodeGeneratorTest {
|
|||
specs.single().name shouldBe "Constants" // ".kt" suffix is appended when the file is written out
|
||||
}
|
||||
|
||||
@Test fun `empty structs get default equals, hashcode, and toString methods`() {
|
||||
@Test
|
||||
fun `empty structs get default equals, hashcode, and toString methods`() {
|
||||
val thrift = """
|
||||
namespace kt com.test
|
||||
|
||||
|
@ -140,10 +147,11 @@ class KotlinCodeGeneratorTest {
|
|||
struct.modifiers.any { it == KModifier.DATA } shouldBe false
|
||||
struct.funSpecs.any { it.name == "toString" } shouldBe true
|
||||
struct.funSpecs.any { it.name == "hashCode" } shouldBe true
|
||||
struct.funSpecs.any { it.name == "equals" } shouldBe true
|
||||
struct.funSpecs.any { it.name == "equals" } shouldBe true
|
||||
}
|
||||
|
||||
@Test fun `Non-empty structs are data classes`() {
|
||||
@Test
|
||||
fun `Non-empty structs are data classes`() {
|
||||
val thrift = """
|
||||
namespace kt com.test
|
||||
|
||||
|
@ -160,10 +168,11 @@ class KotlinCodeGeneratorTest {
|
|||
struct.modifiers.any { it == KModifier.DATA } shouldBe true
|
||||
struct.funSpecs.any { it.name == "toString" } shouldBe false
|
||||
struct.funSpecs.any { it.name == "hashCode" } shouldBe false
|
||||
struct.funSpecs.any { it.name == "equals" } shouldBe false
|
||||
struct.funSpecs.any { it.name == "equals" } shouldBe false
|
||||
}
|
||||
|
||||
@Test fun `exceptions with reserved field names get renamed fields`() {
|
||||
@Test
|
||||
fun `exceptions with reserved field names get renamed fields`() {
|
||||
val thrift = """
|
||||
namespace kt com.test
|
||||
|
||||
|
@ -176,7 +185,8 @@ class KotlinCodeGeneratorTest {
|
|||
xception.propertySpecs.single().name shouldBe "message_"
|
||||
}
|
||||
|
||||
@Test fun services() {
|
||||
@Test
|
||||
fun services() {
|
||||
val thrift = """
|
||||
namespace kt test.services
|
||||
|
||||
|
@ -193,7 +203,8 @@ class KotlinCodeGeneratorTest {
|
|||
generate(thrift).forEach { println(it) }
|
||||
}
|
||||
|
||||
@Test fun `typedefs become typealiases`() {
|
||||
@Test
|
||||
fun `typedefs become typealiases`() {
|
||||
val thrift = """
|
||||
namespace kt test.typedefs
|
||||
|
||||
|
@ -207,7 +218,8 @@ class KotlinCodeGeneratorTest {
|
|||
generate(thrift).forEach { println(it) }
|
||||
}
|
||||
|
||||
@Test fun `services that return typedefs`() {
|
||||
@Test
|
||||
fun `services that return typedefs`() {
|
||||
val thrift = """
|
||||
namespace kt test.typedefs
|
||||
|
||||
|
@ -226,7 +238,8 @@ class KotlinCodeGeneratorTest {
|
|||
.parameterizedBy(ClassName("test.typedefs", "TheNumber"))
|
||||
}
|
||||
|
||||
@Test fun `constants that are typedefs`() {
|
||||
@Test
|
||||
fun `constants that are typedefs`() {
|
||||
val thrift = """
|
||||
|namespace kt test.typedefs
|
||||
|
|
||||
|
@ -248,7 +261,8 @@ class KotlinCodeGeneratorTest {
|
|||
""".trimMargin()
|
||||
}
|
||||
|
||||
@Test fun `Parcelize annotations for structs and enums`() {
|
||||
@Test
|
||||
fun `Parcelize annotations for structs and enums`() {
|
||||
val thrift = """
|
||||
|namespace kt test.parcelize
|
||||
|
|
||||
|
@ -263,7 +277,7 @@ class KotlinCodeGeneratorTest {
|
|||
|
||||
val file = generate(thrift) { parcelize() }.single()
|
||||
val struct = file.members.single { it is TypeSpec && it.name == "Foo" } as TypeSpec
|
||||
val anEnum = file.members.single { it is TypeSpec && it.name == "AnEnum"} as TypeSpec
|
||||
val anEnum = file.members.single { it is TypeSpec && it.name == "AnEnum" } as TypeSpec
|
||||
val svc = file.members.single { it is TypeSpec && it.name == "SvcClient" } as TypeSpec
|
||||
|
||||
val parcelize = ClassName("kotlinx.android.parcel", "Parcelize")
|
||||
|
@ -273,7 +287,8 @@ class KotlinCodeGeneratorTest {
|
|||
svc.annotations.any { it.type == parcelize } shouldBe false
|
||||
}
|
||||
|
||||
@Test fun `Custom map-type constants`() {
|
||||
@Test
|
||||
fun `Custom map-type constants`() {
|
||||
val thrift = """
|
||||
|namespace kt test.map_consts
|
||||
|
|
||||
|
@ -301,7 +316,8 @@ class KotlinCodeGeneratorTest {
|
|||
""".trimMargin()
|
||||
}
|
||||
|
||||
@Test fun `suspend-fun service clients`() {
|
||||
@Test
|
||||
fun `suspend-fun service clients`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|
@ -334,6 +350,521 @@ class KotlinCodeGeneratorTest {
|
|||
""".trimMargin())
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `union generate sealed`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { coroutineServiceClients() }
|
||||
|
||||
file.single().toString() should contain("""
|
||||
|sealed class Union : Struct {
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union properties as data`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { coroutineServiceClients() }
|
||||
|
||||
file.single().toString() should contain("""
|
||||
|
|
||||
| data class Foo(var value: Int?) : Union()
|
||||
|
|
||||
| data class Bar(var value: Long?) : Union()
|
||||
|
|
||||
| data class Baz(var value: String?) : Union()
|
||||
|
|
||||
| data class NotFoo(var value: Int?) : Union()
|
||||
|
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union has builder`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { coroutineServiceClients() }
|
||||
|
||||
file.single().toString() should contain("""
|
||||
| class Builder : StructBuilder<Union> {
|
||||
| private var Foo: Int?
|
||||
|
|
||||
| private var Bar: Long?
|
||||
|
|
||||
| private var Baz: String?
|
||||
|
|
||||
| private var NotFoo: Int?
|
||||
|
|
||||
| constructor() {
|
||||
| this.Foo = null
|
||||
| this.Bar = null
|
||||
| this.Baz = null
|
||||
| this.NotFoo = null
|
||||
| }
|
||||
|
|
||||
| constructor(source: Union) : this() {
|
||||
| when(source) {
|
||||
| is Foo -> this.Foo = source.value
|
||||
| is Bar -> this.Bar = source.value
|
||||
| is Baz -> this.Baz = source.value
|
||||
| is NotFoo -> this.NotFoo = source.value
|
||||
| }
|
||||
| }
|
||||
|
|
||||
| fun Foo(value: Int) = apply {
|
||||
| this.Foo = value
|
||||
| this.Bar = null
|
||||
| this.Baz = null
|
||||
| this.NotFoo = null
|
||||
| }
|
||||
|
|
||||
| fun Bar(value: Long) = apply {
|
||||
| this.Foo = null
|
||||
| this.Bar = value
|
||||
| this.Baz = null
|
||||
| this.NotFoo = null
|
||||
| }
|
||||
|
|
||||
| fun Baz(value: String) = apply {
|
||||
| this.Foo = null
|
||||
| this.Bar = null
|
||||
| this.Baz = value
|
||||
| this.NotFoo = null
|
||||
| }
|
||||
|
|
||||
| fun NotFoo(value: Int) = apply {
|
||||
| this.Foo = null
|
||||
| this.Bar = null
|
||||
| this.Baz = null
|
||||
| this.NotFoo = value
|
||||
| }
|
||||
|
|
||||
| override fun build(): Union = when {
|
||||
| Foo != null -> Union.Foo(Foo)
|
||||
| Bar != null -> Union.Bar(Bar)
|
||||
| Baz != null -> Union.Baz(Baz)
|
||||
| NotFoo != null -> Union.NotFoo(NotFoo)
|
||||
| else -> throw IllegalStateException("unpossible")
|
||||
| }
|
||||
|
|
||||
| override fun reset() {
|
||||
| this.Foo = null
|
||||
| this.Bar = null
|
||||
| this.Baz = null
|
||||
| this.NotFoo = null
|
||||
| }
|
||||
| }
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union wont generate builder when disabled`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { builderlessDataClasses() }
|
||||
|
||||
file.single().toString() shouldNot contain("""
|
||||
| class Builder
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
fun `union wont generate struct when disabled`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) //{ shouldImplementStruct() }
|
||||
|
||||
file.single().toString() shouldNot contain("""
|
||||
| : Struct
|
||||
""".trimMargin())
|
||||
|
||||
file.single().toString() shouldNot contain("""
|
||||
| write
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union generate write function`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) //{ shouldImplementStruct() }
|
||||
|
||||
file.single().toString() should contain(""" |
|
||||
| override fun write(protocol: Protocol, struct: Union) {
|
||||
| protocol.writeStructBegin("Union")
|
||||
| if (struct is Foo) {
|
||||
| protocol.writeFieldBegin("Foo", 1, TType.I32)
|
||||
| protocol.writeI32(struct.value!!)
|
||||
| protocol.writeFieldEnd()
|
||||
| }
|
||||
| if (struct is Bar) {
|
||||
| protocol.writeFieldBegin("Bar", 2, TType.I64)
|
||||
| protocol.writeI64(struct.value!!)
|
||||
| protocol.writeFieldEnd()
|
||||
| }
|
||||
| if (struct is Baz) {
|
||||
| protocol.writeFieldBegin("Baz", 3, TType.STRING)
|
||||
| protocol.writeString(struct.value!!)
|
||||
| protocol.writeFieldEnd()
|
||||
| }
|
||||
| if (struct is NotFoo) {
|
||||
| protocol.writeFieldBegin("NotFoo", 4, TType.I32)
|
||||
| protocol.writeI32(struct.value!!)
|
||||
| protocol.writeFieldEnd()
|
||||
| }
|
||||
| protocol.writeFieldStop()
|
||||
| protocol.writeStructEnd()
|
||||
| }
|
||||
| }
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union generate read function`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) //{ shouldImplementStruct() }
|
||||
|
||||
file.single().toString() should contain("""
|
||||
| override fun read(protocol: Protocol) = read(protocol, Builder())
|
||||
|
|
||||
| override fun read(protocol: Protocol, builder: Builder): Union {
|
||||
| protocol.readStructBegin()
|
||||
| while (true) {
|
||||
| val fieldMeta = protocol.readFieldBegin()
|
||||
| if (fieldMeta.typeId == TType.STOP) {
|
||||
| break
|
||||
| }
|
||||
| when (fieldMeta.fieldId.toInt()) {
|
||||
| 1 -> {
|
||||
| if (fieldMeta.typeId == TType.I32) {
|
||||
| val Foo = protocol.readI32()
|
||||
| builder.Foo(Foo)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| 2 -> {
|
||||
| if (fieldMeta.typeId == TType.I64) {
|
||||
| val Bar = protocol.readI64()
|
||||
| builder.Bar(Bar)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| 3 -> {
|
||||
| if (fieldMeta.typeId == TType.STRING) {
|
||||
| val Baz = protocol.readString()
|
||||
| builder.Baz(Baz)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| 4 -> {
|
||||
| if (fieldMeta.typeId == TType.I32) {
|
||||
| val NotFoo = protocol.readI32()
|
||||
| builder.NotFoo(NotFoo)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| else -> ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| protocol.readFieldEnd()
|
||||
| }
|
||||
| protocol.readStructEnd()
|
||||
| return builder.build()
|
||||
| }
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union generate read function without builder`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { builderlessDataClasses() }
|
||||
|
||||
|
||||
file.single().toString() should contain("""
|
||||
| override fun read(protocol: Protocol): Union {
|
||||
| protocol.readStructBegin()
|
||||
| var result : Union? = null
|
||||
| while (true) {
|
||||
| val fieldMeta = protocol.readFieldBegin()
|
||||
| if (fieldMeta.typeId == TType.STOP) {
|
||||
| break
|
||||
| }
|
||||
| when (fieldMeta.fieldId.toInt()) {
|
||||
| 1 -> {
|
||||
| if (fieldMeta.typeId == TType.I32) {
|
||||
| val Foo = protocol.readI32()
|
||||
| result = Foo(Foo)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| 2 -> {
|
||||
| if (fieldMeta.typeId == TType.I64) {
|
||||
| val Bar = protocol.readI64()
|
||||
| result = Bar(Bar)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| 3 -> {
|
||||
| if (fieldMeta.typeId == TType.STRING) {
|
||||
| val Baz = protocol.readString()
|
||||
| result = Baz(Baz)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| 4 -> {
|
||||
| if (fieldMeta.typeId == TType.I32) {
|
||||
| val NotFoo = protocol.readI32()
|
||||
| result = NotFoo(NotFoo)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| else -> ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| protocol.readFieldEnd()
|
||||
| }
|
||||
| protocol.readStructEnd()
|
||||
| if (null == result) {
|
||||
| throw IllegalStateException("unreadable")
|
||||
| } else {
|
||||
| return result
|
||||
| }
|
||||
| }
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union generate Adapter with builder`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift)
|
||||
|
||||
file.single().toString() should contain("""
|
||||
| private class UnionAdapter : Adapter<Union, Builder> {
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `union generate Adapter`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
| 1: i32 Foo;
|
||||
| 2: i64 Bar;
|
||||
| 3: string Baz;
|
||||
| 4: i32 NotFoo;
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { builderlessDataClasses() }
|
||||
|
||||
file.single().toString() should contain("""
|
||||
| private class UnionAdapter : Adapter<Union> {
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `empty union generate sealed`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|union Union {
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { coroutineServiceClients() }
|
||||
|
||||
file.single().toString() should contain("""
|
||||
|sealed class Union : Struct {
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `struct with union`() {
|
||||
val thrift = """
|
||||
|namespace kt test.coro
|
||||
|
|
||||
|struct Bonk {
|
||||
| 1: string message;
|
||||
| 2: i32 type;
|
||||
|}
|
||||
|
|
||||
|union UnionStruct {
|
||||
| 1: Bonk Struct
|
||||
|}
|
||||
""".trimMargin()
|
||||
|
||||
val file = generate(thrift) { coroutineServiceClients() }
|
||||
|
||||
file.single().toString() should contain("""
|
||||
|sealed class UnionStruct : Struct {
|
||||
| override fun write(protocol: Protocol) {
|
||||
| ADAPTER.write(protocol, this)
|
||||
| }
|
||||
|
|
||||
| data class Struct(var value: Bonk?) : UnionStruct()
|
||||
|
|
||||
| class Builder : StructBuilder<UnionStruct> {
|
||||
| private var Struct: Bonk?
|
||||
|
|
||||
| constructor() {
|
||||
| this.Struct = null
|
||||
| }
|
||||
|
|
||||
| constructor(source: UnionStruct) : this() {
|
||||
| when(source) {
|
||||
| is Struct -> this.Struct = source.value
|
||||
| }
|
||||
| }
|
||||
|
|
||||
| fun Struct(value: Bonk) = apply {
|
||||
| this.Struct = value
|
||||
| }
|
||||
|
|
||||
| override fun build(): UnionStruct = when {
|
||||
| Struct != null -> UnionStruct.Struct(Struct)
|
||||
| else -> throw IllegalStateException("unpossible")
|
||||
| }
|
||||
|
|
||||
| override fun reset() {
|
||||
| this.Struct = null
|
||||
| }
|
||||
| }
|
||||
|
|
||||
| private class UnionStructAdapter : Adapter<UnionStruct, Builder> {
|
||||
| override fun read(protocol: Protocol) = read(protocol, Builder())
|
||||
|
|
||||
| override fun read(protocol: Protocol, builder: Builder): UnionStruct {
|
||||
| protocol.readStructBegin()
|
||||
| while (true) {
|
||||
| val fieldMeta = protocol.readFieldBegin()
|
||||
| if (fieldMeta.typeId == TType.STOP) {
|
||||
| break
|
||||
| }
|
||||
| when (fieldMeta.fieldId.toInt()) {
|
||||
| 1 -> {
|
||||
| if (fieldMeta.typeId == TType.STRUCT) {
|
||||
| val Struct = Bonk.ADAPTER.read(protocol)
|
||||
| builder.Struct(Struct)
|
||||
| } else {
|
||||
| ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| }
|
||||
| else -> ProtocolUtil.skip(protocol, fieldMeta.typeId)
|
||||
| }
|
||||
| protocol.readFieldEnd()
|
||||
| }
|
||||
| protocol.readStructEnd()
|
||||
| return builder.build()
|
||||
| }
|
||||
|
|
||||
| override fun write(protocol: Protocol, struct: UnionStruct) {
|
||||
| protocol.writeStructBegin("UnionStruct")
|
||||
| if (struct is Struct) {
|
||||
| protocol.writeFieldBegin("Struct", 1, TType.STRUCT)
|
||||
| Bonk.ADAPTER.write(protocol, struct.value!!)
|
||||
| protocol.writeFieldEnd()
|
||||
| }
|
||||
| protocol.writeFieldStop()
|
||||
| protocol.writeStructEnd()
|
||||
| }
|
||||
| }
|
||||
""".trimMargin())
|
||||
}
|
||||
|
||||
private fun generate(thrift: String, config: (KotlinCodeGenerator.() -> KotlinCodeGenerator)? = null): List<FileSpec> {
|
||||
val configOrDefault = config ?: { this }
|
||||
return KotlinCodeGenerator()
|
||||
|
|
Загрузка…
Ссылка в новой задаче