Implement Java codegen for struct-valued constants (#503)
This commit is contained in:
Родитель
141b856f9a
Коммит
f9c1099d55
|
@ -59,6 +59,11 @@ struct Bonk
|
|||
|
||||
typedef map<string,Bonk> MapType
|
||||
|
||||
const Bonk A_BONK = {
|
||||
"message": "foobar",
|
||||
"type": 100,
|
||||
}
|
||||
|
||||
struct Bools {
|
||||
1: bool im_true,
|
||||
2: bool im_false,
|
||||
|
@ -96,6 +101,21 @@ struct Insanity
|
|||
2: list<Xtruct> xtructs
|
||||
}
|
||||
|
||||
const Insanity TOTAL_INSANITY = {
|
||||
"userMap": {
|
||||
myNumberz: 1234
|
||||
},
|
||||
"xtructs": [
|
||||
{
|
||||
"string_thing": "hello",
|
||||
},
|
||||
{
|
||||
"i32_thing": 1,
|
||||
"bool_thing": 0,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
struct CrazyNesting {
|
||||
1: string string_field,
|
||||
2: optional set<Insanity> set_field,
|
||||
|
|
|
@ -318,3 +318,23 @@ union UnionWithResult {
|
|||
2: i64 bigResult;
|
||||
3: string error;
|
||||
}
|
||||
|
||||
const Insanity TOTAL_INSANITY = {
|
||||
"userMap": {
|
||||
myNumberz: 1234
|
||||
},
|
||||
"xtructs": [
|
||||
{
|
||||
"string_thing": "hello",
|
||||
},
|
||||
{
|
||||
"i32_thing": 1,
|
||||
"bool_thing": 0,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
const Bonk A_BONK = {
|
||||
"message": "foobar",
|
||||
"type": 100,
|
||||
}
|
||||
|
|
|
@ -38,15 +38,18 @@ import com.microsoft.thrifty.schema.parser.IntValueElement
|
|||
import com.microsoft.thrifty.schema.parser.ListValueElement
|
||||
import com.microsoft.thrifty.schema.parser.LiteralValueElement
|
||||
import com.microsoft.thrifty.schema.parser.MapValueElement
|
||||
import com.squareup.javapoet.ClassName
|
||||
import com.squareup.javapoet.CodeBlock
|
||||
import com.squareup.javapoet.NameAllocator
|
||||
import com.squareup.javapoet.ParameterizedTypeName
|
||||
import com.squareup.javapoet.TypeName
|
||||
import java.util.Locale
|
||||
import java.util.NoSuchElementException
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
internal class ConstantBuilder(
|
||||
private val typeResolver: TypeResolver,
|
||||
private val fieldNamer: FieldNamer,
|
||||
private val schema: Schema
|
||||
) {
|
||||
|
||||
|
@ -133,8 +136,29 @@ internal class ConstantBuilder(
|
|||
}
|
||||
|
||||
override fun visitStruct(structType: StructType) {
|
||||
// TODO: this
|
||||
throw UnsupportedOperationException("struct-type default values are not yet implemented")
|
||||
val structTypeName = typeResolver.getJavaClass(structType) as ClassName
|
||||
val builderTypeName = structTypeName.nestedClass("Builder")
|
||||
val loweredStructName = structType.name.replaceFirstChar { it.lowercase(Locale.US) }
|
||||
val builderName = "${loweredStructName}Builder${scope.getAndIncrement()}"
|
||||
|
||||
initializer.addStatement("\$1T \$2N = new \$1T()", builderTypeName, builderName)
|
||||
|
||||
val fieldsByName = structType.fields.associateBy { it.name }
|
||||
|
||||
val map = (value as MapValueElement).value
|
||||
for ((keyElement, valueElement) in map) {
|
||||
val key = (keyElement as LiteralValueElement).value
|
||||
val field = fieldsByName[key] ?: error("Struct ${structType.name} has no field named '$key'")
|
||||
val setterName = fieldNamer.getName(field)
|
||||
val valueName = renderConstValue(initializer, allocator, scope, field.type, valueElement)
|
||||
initializer.addStatement("\$N.\$N(\$L)", builderName, setterName, valueName)
|
||||
}
|
||||
|
||||
if (needsDeclaration) {
|
||||
initializer.addStatement("\$T \$N = \$N.build()", structTypeName, name, builderName)
|
||||
} else {
|
||||
initializer.addStatement("\$N = \$N.build()", name, builderName)
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitTypedef(typedefType: TypedefType) {
|
||||
|
@ -189,12 +213,12 @@ internal class ConstantBuilder(
|
|||
return constantOrError("Invalid boolean constant")
|
||||
}
|
||||
|
||||
return CodeBlock.builder().add(name).build()
|
||||
return CodeBlock.of(name)
|
||||
}
|
||||
|
||||
override fun visitByte(byteType: BuiltinType): CodeBlock {
|
||||
return if (value is IntValueElement) {
|
||||
CodeBlock.builder().add("(byte) \$L", getNumberLiteral(value)).build()
|
||||
CodeBlock.of("(byte) \$L", getNumberLiteral(value))
|
||||
} else {
|
||||
constantOrError("Invalid byte constant")
|
||||
}
|
||||
|
@ -202,7 +226,7 @@ internal class ConstantBuilder(
|
|||
|
||||
override fun visitI16(i16Type: BuiltinType): CodeBlock {
|
||||
return if (value is IntValueElement) {
|
||||
CodeBlock.builder().add("(short) \$L", getNumberLiteral(value)).build()
|
||||
CodeBlock.of("(short) \$L", getNumberLiteral(value))
|
||||
} else {
|
||||
constantOrError("Invalid i16 constant")
|
||||
}
|
||||
|
@ -210,7 +234,7 @@ internal class ConstantBuilder(
|
|||
|
||||
override fun visitI32(i32Type: BuiltinType): CodeBlock {
|
||||
return if (value is IntValueElement) {
|
||||
CodeBlock.builder().add("\$L", getNumberLiteral(value)).build()
|
||||
CodeBlock.of("\$L", getNumberLiteral(value))
|
||||
} else {
|
||||
constantOrError("Invalid i32 constant")
|
||||
}
|
||||
|
@ -218,7 +242,7 @@ internal class ConstantBuilder(
|
|||
|
||||
override fun visitI64(i64Type: BuiltinType): CodeBlock {
|
||||
return if (value is IntValueElement) {
|
||||
CodeBlock.builder().add("\$LL", getNumberLiteral(value)).build()
|
||||
CodeBlock.of("\$LL", getNumberLiteral(value))
|
||||
} else {
|
||||
constantOrError("Invalid i64 constant")
|
||||
}
|
||||
|
@ -234,7 +258,7 @@ internal class ConstantBuilder(
|
|||
|
||||
override fun visitString(stringType: BuiltinType): CodeBlock {
|
||||
return if (value is LiteralValueElement) {
|
||||
CodeBlock.builder().add("\$S", value.value).build()
|
||||
CodeBlock.of("\$S", value.value)
|
||||
} else {
|
||||
constantOrError("Invalid string constant")
|
||||
}
|
||||
|
@ -253,7 +277,12 @@ internal class ConstantBuilder(
|
|||
when (value) {
|
||||
is IntValueElement -> enumType.findMemberById(value.value.toInt())
|
||||
is IdentifierValueElement -> {
|
||||
// TODO(ben): Figure out how to handle const references
|
||||
try {
|
||||
return constantOrError("this is gross, sorry")
|
||||
} catch (e: IllegalStateException) {
|
||||
// Not a constant
|
||||
}
|
||||
|
||||
// Remove the enum name prefix, assuming it is present
|
||||
val name = value.value.split(".").last()
|
||||
enumType.findMemberByName(name)
|
||||
|
@ -265,18 +294,14 @@ internal class ConstantBuilder(
|
|||
"No enum member in ${enumType.name} with value $value")
|
||||
}
|
||||
|
||||
return CodeBlock.builder()
|
||||
.add("\$T.\$L", typeResolver.getJavaClass(enumType), member.name)
|
||||
.build()
|
||||
return CodeBlock.of("\$T.\$L", typeResolver.getJavaClass(enumType), member.name)
|
||||
}
|
||||
|
||||
override fun visitList(listType: ListType): CodeBlock {
|
||||
return if (value is ListValueElement) {
|
||||
if (value.value.isEmpty()) {
|
||||
val elementType = typeResolver.getJavaClass(listType.elementType)
|
||||
CodeBlock.builder()
|
||||
.add("\$T.<\$T>emptyList()", TypeNames.COLLECTIONS, elementType)
|
||||
.build()
|
||||
CodeBlock.of("\$T.<\$T>emptyList()", TypeNames.COLLECTIONS, elementType)
|
||||
} else {
|
||||
visitCollection(listType, "list", "unmodifiableList")
|
||||
}
|
||||
|
@ -289,9 +314,7 @@ internal class ConstantBuilder(
|
|||
return if (value is ListValueElement) { // not a typo; ListValueElement covers lists and sets.
|
||||
if (value.value.isEmpty()) {
|
||||
val elementType = typeResolver.getJavaClass(setType.elementType)
|
||||
CodeBlock.builder()
|
||||
.add("\$T.<\$T>emptySet()", TypeNames.COLLECTIONS, elementType)
|
||||
.build()
|
||||
CodeBlock.of("\$T.<\$T>emptySet()", TypeNames.COLLECTIONS, elementType)
|
||||
} else {
|
||||
visitCollection(setType, "set", "unmodifiableSet")
|
||||
}
|
||||
|
@ -305,9 +328,7 @@ internal class ConstantBuilder(
|
|||
if (value.value.isEmpty()) {
|
||||
val keyType = typeResolver.getJavaClass(mapType.keyType)
|
||||
val valueType = typeResolver.getJavaClass(mapType.valueType)
|
||||
CodeBlock.builder()
|
||||
.add("\$T.<\$T, \$T>emptyMap()", TypeNames.COLLECTIONS, keyType, valueType)
|
||||
.build()
|
||||
CodeBlock.of("\$T.<\$T, \$T>emptyMap()", TypeNames.COLLECTIONS, keyType, valueType)
|
||||
} else {
|
||||
visitCollection(mapType, "map", "unmodifiableMap")
|
||||
}
|
||||
|
@ -322,11 +343,14 @@ internal class ConstantBuilder(
|
|||
method: String): CodeBlock {
|
||||
val name = allocator.newName(tempName, scope.getAndIncrement())
|
||||
generateFieldInitializer(block, allocator, scope, name, type, value, true)
|
||||
return CodeBlock.builder().add("\$T.\$L(\$N)", TypeNames.COLLECTIONS, method, name).build()
|
||||
return CodeBlock.of("\$T.\$L(\$N)", TypeNames.COLLECTIONS, method, name)
|
||||
}
|
||||
|
||||
override fun visitStruct(structType: StructType): CodeBlock {
|
||||
throw IllegalStateException("nested structs not implemented")
|
||||
val loweredStructName = structType.name.replaceFirstChar { it.lowercase(Locale.getDefault()) }
|
||||
val name = allocator.newName(loweredStructName, scope.getAndIncrement())
|
||||
generateFieldInitializer(block, allocator, scope, name, type, value, true)
|
||||
return CodeBlock.of(name)
|
||||
}
|
||||
|
||||
override fun visitTypedef(typedefType: TypedefType): CodeBlock {
|
||||
|
@ -363,7 +387,14 @@ internal class ConstantBuilder(
|
|||
.firstOrNull() ?: throw IllegalStateException(message)
|
||||
|
||||
val packageName = c.getNamespaceFor(NamespaceScope.JAVA)
|
||||
return CodeBlock.builder().add("$packageName.Constants.$name").build()
|
||||
return CodeBlock.of("$packageName.Constants.$name")
|
||||
}
|
||||
|
||||
private inline fun buildCodeBlock(fn: CodeBlock.Builder.() -> Unit): CodeBlock {
|
||||
return CodeBlock.builder().let { builder ->
|
||||
builder.fn()
|
||||
builder.build()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -82,7 +82,7 @@ class ThriftyCodeGenerator {
|
|||
typeResolver.setSetClass(TypeNames.HASH_SET)
|
||||
typeResolver.setMapClass(TypeNames.HASH_MAP)
|
||||
|
||||
constantBuilder = ConstantBuilder(typeResolver, schema)
|
||||
constantBuilder = ConstantBuilder(typeResolver, fieldNamer, schema)
|
||||
serviceBuilder = ServiceBuilder(typeResolver, constantBuilder, fieldNamer)
|
||||
}
|
||||
|
||||
|
@ -852,6 +852,10 @@ class ThriftyCodeGenerator {
|
|||
return toString.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a `Constants` type containing all constants defined for one
|
||||
* single package.
|
||||
*/
|
||||
private fun buildConst(constants: Collection<Constant>): TypeSpec {
|
||||
val builder = TypeSpec.classBuilder("Constants")
|
||||
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
|
||||
|
@ -934,7 +938,7 @@ class ThriftyCodeGenerator {
|
|||
tempName,
|
||||
type,
|
||||
constant.value,
|
||||
true)
|
||||
needsDeclaration = true)
|
||||
staticInit.addStatement("\$N = \$T.\$L(\$N)",
|
||||
constant.name,
|
||||
TypeNames.COLLECTIONS,
|
||||
|
@ -945,7 +949,19 @@ class ThriftyCodeGenerator {
|
|||
}
|
||||
|
||||
override fun visitStruct(structType: StructType) {
|
||||
throw UnsupportedOperationException("Struct-type constants are not supported")
|
||||
val cve = constant.value
|
||||
if (cve is MapValueElement && cve.value.isEmpty()) {
|
||||
field.initializer("new \$T.Builder().build()", typeResolver.getJavaClass(constant.type))
|
||||
} else {
|
||||
constantBuilder.generateFieldInitializer(
|
||||
staticInit,
|
||||
allocator,
|
||||
scope,
|
||||
constant.name,
|
||||
constant.type.trueType,
|
||||
cve,
|
||||
needsDeclaration = false)
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitTypedef(typedefType: TypedefType) {
|
||||
|
|
|
@ -1888,13 +1888,16 @@ class KotlinCodeGenerator(
|
|||
block.add("%T(\n⇥", className)
|
||||
|
||||
for (field in structType.fields) {
|
||||
// TODO: Once we allow default values, support them here
|
||||
val fieldValue = fieldValues[field.name]
|
||||
?: error("No value for struct field '$field.name'")
|
||||
|
||||
block.add("%L = ", names[field])
|
||||
recursivelyRenderConstValue(block, field.type, fieldValue)
|
||||
block.add(",\n")
|
||||
if (fieldValue != null) {
|
||||
block.add("%L = ", names[field])
|
||||
recursivelyRenderConstValue(block, field.type, fieldValue)
|
||||
block.add(",\n")
|
||||
} else {
|
||||
check(!field.required) { "Missing value for required field '${field.name}'" }
|
||||
// TODO: if there's a default value, support it
|
||||
block.add("%L = null,\n", names[field])
|
||||
}
|
||||
}
|
||||
|
||||
block.add("⇤)")
|
||||
|
@ -1903,9 +1906,11 @@ class KotlinCodeGenerator(
|
|||
block.add("%T().let·{\n⇥", builderType)
|
||||
|
||||
for (field in structType.fields) {
|
||||
// TODO: Once we allow default values, support them here
|
||||
val fieldValue = fieldValues[field.name]
|
||||
?: error("No value for struct field '$field.name'")
|
||||
if (fieldValue == null) {
|
||||
check(!field.required) { "Missing value for required field '${field.name}'" }
|
||||
continue
|
||||
}
|
||||
|
||||
block.add("it.%L(", names[field])
|
||||
recursivelyRenderConstValue(block, field.type, fieldValue)
|
||||
|
|
|
@ -30,6 +30,7 @@ import com.squareup.kotlinpoet.KModifier
|
|||
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
|
||||
import com.squareup.kotlinpoet.TypeSpec
|
||||
import com.squareup.kotlinpoet.asTypeName
|
||||
import io.kotest.matchers.ints.shouldBeLessThan
|
||||
import io.kotest.matchers.should
|
||||
import io.kotest.matchers.shouldBe
|
||||
import io.kotest.matchers.shouldNot
|
||||
|
@ -1368,6 +1369,33 @@ class KotlinCodeGeneratorTest {
|
|||
lines[1] shouldContain "\\/\\/ Generated on: [0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}\\.\\S+".toRegex()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `constant reference`() {
|
||||
val thrift = """
|
||||
|namespace kt test.const_ref
|
||||
|
|
||||
|struct Node {
|
||||
| 1: optional list<Node> n;
|
||||
|}
|
||||
|
|
||||
|const Node D = {"n": [B, C]}
|
||||
|const Node C = {"n": [A]}
|
||||
|const Node B = {"n": [A]}
|
||||
|const Node A = {}
|
||||
""".trimMargin()
|
||||
|
||||
val files = generate(thrift) { filePerType() }
|
||||
val constants = files.single { it.name.contains("Constants") }
|
||||
val kotlinText = constants.toString()
|
||||
|
||||
val positionOfA = kotlinText.indexOf("val A: Node")
|
||||
val positionOfD = kotlinText.indexOf("val D: Node")
|
||||
|
||||
positionOfA shouldBeLessThan positionOfD
|
||||
|
||||
files.shouldCompile()
|
||||
}
|
||||
|
||||
private fun generate(thrift: String, config: (KotlinCodeGenerator.() -> KotlinCodeGenerator)? = null): List<FileSpec> {
|
||||
val configOrDefault = config ?: { emitFileComment(false) }
|
||||
return KotlinCodeGenerator()
|
||||
|
|
|
@ -116,6 +116,10 @@ class Constant private constructor (
|
|||
return Builder(this)
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "Constant(name=$name, loc=${location.path})"
|
||||
}
|
||||
|
||||
/**
|
||||
* An object that can build [Constants][Constant].
|
||||
*/
|
||||
|
@ -448,10 +452,9 @@ class Constant private constructor (
|
|||
Constant.validate(symbolTable, value, field.type)
|
||||
}
|
||||
|
||||
// TODO: Relax this requirement and allow non-required or default-valued fields to be unspecified
|
||||
check(allFields.isEmpty()) {
|
||||
val missingFieldNames = allFields.keys.joinToString(", ")
|
||||
"Expected all fields to be set; missing: $missingFieldNames"
|
||||
check(allFields.none { it.value.required }) {
|
||||
val missingRequiredFieldNames = allFields.filter { it.value.required }.map { it.key }.joinToString(", ")
|
||||
"Some required fields are unset: $missingRequiredFieldNames"
|
||||
}
|
||||
} else {
|
||||
super.validate(symbolTable, expected, valueElement)
|
||||
|
@ -528,9 +531,14 @@ class Constant private constructor (
|
|||
return when (cve) {
|
||||
is IdentifierValueElement -> getScalarConstantReference()
|
||||
|
||||
is MapValueElement -> cve.value.values.flatMap { elem ->
|
||||
val visitor = ConstantReferenceVisitor(elem, linker)
|
||||
mapType.valueType.accept(visitor)
|
||||
is MapValueElement -> {
|
||||
cve.value.keys.flatMap { elem ->
|
||||
val visitor = ConstantReferenceVisitor(elem, linker)
|
||||
mapType.keyType.accept(visitor)
|
||||
} + cve.value.values.flatMap { elem ->
|
||||
val visitor = ConstantReferenceVisitor(elem, linker)
|
||||
mapType.valueType.accept(visitor)
|
||||
}.distinct()
|
||||
}
|
||||
|
||||
else -> error("no")
|
||||
|
|
|
@ -1022,6 +1022,20 @@ class LoaderTest {
|
|||
strs.referencedConstants.map { it.name} shouldBe listOf("A", "B")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun constantInMapKey() {
|
||||
val thrift = """
|
||||
const string KEY = "foo"
|
||||
const map<string, string> MAP = {
|
||||
KEY: "bar"
|
||||
}
|
||||
""".trimIndent()
|
||||
|
||||
val schema = load(thrift)
|
||||
val (key, map) = schema.constants
|
||||
map.referencedConstants shouldBe listOf(key)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun topologicallySortsConstants() {
|
||||
val thrift = """
|
||||
|
|
Загрузка…
Ссылка в новой задаче