Implement Java codegen for struct-valued constants (#503)

This commit is contained in:
Ben Bader 2022-12-07 17:07:00 -07:00 коммит произвёл GitHub
Родитель 141b856f9a
Коммит f9c1099d55
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 184 добавлений и 42 удалений

Просмотреть файл

@ -59,6 +59,11 @@ struct Bonk
typedef map<string,Bonk> MapType
const Bonk A_BONK = {
"message": "foobar",
"type": 100,
}
struct Bools {
1: bool im_true,
2: bool im_false,
@ -96,6 +101,21 @@ struct Insanity
2: list<Xtruct> xtructs
}
const Insanity TOTAL_INSANITY = {
"userMap": {
myNumberz: 1234
},
"xtructs": [
{
"string_thing": "hello",
},
{
"i32_thing": 1,
"bool_thing": 0,
},
]
}
struct CrazyNesting {
1: string string_field,
2: optional set<Insanity> set_field,

Просмотреть файл

@ -318,3 +318,23 @@ union UnionWithResult {
2: i64 bigResult;
3: string error;
}
const Insanity TOTAL_INSANITY = {
"userMap": {
myNumberz: 1234
},
"xtructs": [
{
"string_thing": "hello",
},
{
"i32_thing": 1,
"bool_thing": 0,
},
]
}
const Bonk A_BONK = {
"message": "foobar",
"type": 100,
}

Просмотреть файл

@ -38,15 +38,18 @@ import com.microsoft.thrifty.schema.parser.IntValueElement
import com.microsoft.thrifty.schema.parser.ListValueElement
import com.microsoft.thrifty.schema.parser.LiteralValueElement
import com.microsoft.thrifty.schema.parser.MapValueElement
import com.squareup.javapoet.ClassName
import com.squareup.javapoet.CodeBlock
import com.squareup.javapoet.NameAllocator
import com.squareup.javapoet.ParameterizedTypeName
import com.squareup.javapoet.TypeName
import java.util.Locale
import java.util.NoSuchElementException
import java.util.concurrent.atomic.AtomicInteger
internal class ConstantBuilder(
private val typeResolver: TypeResolver,
private val fieldNamer: FieldNamer,
private val schema: Schema
) {
@ -133,8 +136,29 @@ internal class ConstantBuilder(
}
override fun visitStruct(structType: StructType) {
// TODO: this
throw UnsupportedOperationException("struct-type default values are not yet implemented")
val structTypeName = typeResolver.getJavaClass(structType) as ClassName
val builderTypeName = structTypeName.nestedClass("Builder")
val loweredStructName = structType.name.replaceFirstChar { it.lowercase(Locale.US) }
val builderName = "${loweredStructName}Builder${scope.getAndIncrement()}"
initializer.addStatement("\$1T \$2N = new \$1T()", builderTypeName, builderName)
val fieldsByName = structType.fields.associateBy { it.name }
val map = (value as MapValueElement).value
for ((keyElement, valueElement) in map) {
val key = (keyElement as LiteralValueElement).value
val field = fieldsByName[key] ?: error("Struct ${structType.name} has no field named '$key'")
val setterName = fieldNamer.getName(field)
val valueName = renderConstValue(initializer, allocator, scope, field.type, valueElement)
initializer.addStatement("\$N.\$N(\$L)", builderName, setterName, valueName)
}
if (needsDeclaration) {
initializer.addStatement("\$T \$N = \$N.build()", structTypeName, name, builderName)
} else {
initializer.addStatement("\$N = \$N.build()", name, builderName)
}
}
override fun visitTypedef(typedefType: TypedefType) {
@ -189,12 +213,12 @@ internal class ConstantBuilder(
return constantOrError("Invalid boolean constant")
}
return CodeBlock.builder().add(name).build()
return CodeBlock.of(name)
}
override fun visitByte(byteType: BuiltinType): CodeBlock {
return if (value is IntValueElement) {
CodeBlock.builder().add("(byte) \$L", getNumberLiteral(value)).build()
CodeBlock.of("(byte) \$L", getNumberLiteral(value))
} else {
constantOrError("Invalid byte constant")
}
@ -202,7 +226,7 @@ internal class ConstantBuilder(
override fun visitI16(i16Type: BuiltinType): CodeBlock {
return if (value is IntValueElement) {
CodeBlock.builder().add("(short) \$L", getNumberLiteral(value)).build()
CodeBlock.of("(short) \$L", getNumberLiteral(value))
} else {
constantOrError("Invalid i16 constant")
}
@ -210,7 +234,7 @@ internal class ConstantBuilder(
override fun visitI32(i32Type: BuiltinType): CodeBlock {
return if (value is IntValueElement) {
CodeBlock.builder().add("\$L", getNumberLiteral(value)).build()
CodeBlock.of("\$L", getNumberLiteral(value))
} else {
constantOrError("Invalid i32 constant")
}
@ -218,7 +242,7 @@ internal class ConstantBuilder(
override fun visitI64(i64Type: BuiltinType): CodeBlock {
return if (value is IntValueElement) {
CodeBlock.builder().add("\$LL", getNumberLiteral(value)).build()
CodeBlock.of("\$LL", getNumberLiteral(value))
} else {
constantOrError("Invalid i64 constant")
}
@ -234,7 +258,7 @@ internal class ConstantBuilder(
override fun visitString(stringType: BuiltinType): CodeBlock {
return if (value is LiteralValueElement) {
CodeBlock.builder().add("\$S", value.value).build()
CodeBlock.of("\$S", value.value)
} else {
constantOrError("Invalid string constant")
}
@ -253,7 +277,12 @@ internal class ConstantBuilder(
when (value) {
is IntValueElement -> enumType.findMemberById(value.value.toInt())
is IdentifierValueElement -> {
// TODO(ben): Figure out how to handle const references
try {
return constantOrError("this is gross, sorry")
} catch (e: IllegalStateException) {
// Not a constant
}
// Remove the enum name prefix, assuming it is present
val name = value.value.split(".").last()
enumType.findMemberByName(name)
@ -265,18 +294,14 @@ internal class ConstantBuilder(
"No enum member in ${enumType.name} with value $value")
}
return CodeBlock.builder()
.add("\$T.\$L", typeResolver.getJavaClass(enumType), member.name)
.build()
return CodeBlock.of("\$T.\$L", typeResolver.getJavaClass(enumType), member.name)
}
override fun visitList(listType: ListType): CodeBlock {
return if (value is ListValueElement) {
if (value.value.isEmpty()) {
val elementType = typeResolver.getJavaClass(listType.elementType)
CodeBlock.builder()
.add("\$T.<\$T>emptyList()", TypeNames.COLLECTIONS, elementType)
.build()
CodeBlock.of("\$T.<\$T>emptyList()", TypeNames.COLLECTIONS, elementType)
} else {
visitCollection(listType, "list", "unmodifiableList")
}
@ -289,9 +314,7 @@ internal class ConstantBuilder(
return if (value is ListValueElement) { // not a typo; ListValueElement covers lists and sets.
if (value.value.isEmpty()) {
val elementType = typeResolver.getJavaClass(setType.elementType)
CodeBlock.builder()
.add("\$T.<\$T>emptySet()", TypeNames.COLLECTIONS, elementType)
.build()
CodeBlock.of("\$T.<\$T>emptySet()", TypeNames.COLLECTIONS, elementType)
} else {
visitCollection(setType, "set", "unmodifiableSet")
}
@ -305,9 +328,7 @@ internal class ConstantBuilder(
if (value.value.isEmpty()) {
val keyType = typeResolver.getJavaClass(mapType.keyType)
val valueType = typeResolver.getJavaClass(mapType.valueType)
CodeBlock.builder()
.add("\$T.<\$T, \$T>emptyMap()", TypeNames.COLLECTIONS, keyType, valueType)
.build()
CodeBlock.of("\$T.<\$T, \$T>emptyMap()", TypeNames.COLLECTIONS, keyType, valueType)
} else {
visitCollection(mapType, "map", "unmodifiableMap")
}
@ -322,11 +343,14 @@ internal class ConstantBuilder(
method: String): CodeBlock {
val name = allocator.newName(tempName, scope.getAndIncrement())
generateFieldInitializer(block, allocator, scope, name, type, value, true)
return CodeBlock.builder().add("\$T.\$L(\$N)", TypeNames.COLLECTIONS, method, name).build()
return CodeBlock.of("\$T.\$L(\$N)", TypeNames.COLLECTIONS, method, name)
}
override fun visitStruct(structType: StructType): CodeBlock {
throw IllegalStateException("nested structs not implemented")
val loweredStructName = structType.name.replaceFirstChar { it.lowercase(Locale.getDefault()) }
val name = allocator.newName(loweredStructName, scope.getAndIncrement())
generateFieldInitializer(block, allocator, scope, name, type, value, true)
return CodeBlock.of(name)
}
override fun visitTypedef(typedefType: TypedefType): CodeBlock {
@ -363,7 +387,14 @@ internal class ConstantBuilder(
.firstOrNull() ?: throw IllegalStateException(message)
val packageName = c.getNamespaceFor(NamespaceScope.JAVA)
return CodeBlock.builder().add("$packageName.Constants.$name").build()
return CodeBlock.of("$packageName.Constants.$name")
}
private inline fun buildCodeBlock(fn: CodeBlock.Builder.() -> Unit): CodeBlock {
return CodeBlock.builder().let { builder ->
builder.fn()
builder.build()
}
}
}
}

Просмотреть файл

@ -82,7 +82,7 @@ class ThriftyCodeGenerator {
typeResolver.setSetClass(TypeNames.HASH_SET)
typeResolver.setMapClass(TypeNames.HASH_MAP)
constantBuilder = ConstantBuilder(typeResolver, schema)
constantBuilder = ConstantBuilder(typeResolver, fieldNamer, schema)
serviceBuilder = ServiceBuilder(typeResolver, constantBuilder, fieldNamer)
}
@ -852,6 +852,10 @@ class ThriftyCodeGenerator {
return toString.build()
}
/**
* Builds a `Constants` type containing all constants defined for one
* single package.
*/
private fun buildConst(constants: Collection<Constant>): TypeSpec {
val builder = TypeSpec.classBuilder("Constants")
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
@ -934,7 +938,7 @@ class ThriftyCodeGenerator {
tempName,
type,
constant.value,
true)
needsDeclaration = true)
staticInit.addStatement("\$N = \$T.\$L(\$N)",
constant.name,
TypeNames.COLLECTIONS,
@ -945,7 +949,19 @@ class ThriftyCodeGenerator {
}
override fun visitStruct(structType: StructType) {
throw UnsupportedOperationException("Struct-type constants are not supported")
val cve = constant.value
if (cve is MapValueElement && cve.value.isEmpty()) {
field.initializer("new \$T.Builder().build()", typeResolver.getJavaClass(constant.type))
} else {
constantBuilder.generateFieldInitializer(
staticInit,
allocator,
scope,
constant.name,
constant.type.trueType,
cve,
needsDeclaration = false)
}
}
override fun visitTypedef(typedefType: TypedefType) {

Просмотреть файл

@ -1888,13 +1888,16 @@ class KotlinCodeGenerator(
block.add("%T(\n", className)
for (field in structType.fields) {
// TODO: Once we allow default values, support them here
val fieldValue = fieldValues[field.name]
?: error("No value for struct field '$field.name'")
block.add("%L = ", names[field])
recursivelyRenderConstValue(block, field.type, fieldValue)
block.add(",\n")
if (fieldValue != null) {
block.add("%L = ", names[field])
recursivelyRenderConstValue(block, field.type, fieldValue)
block.add(",\n")
} else {
check(!field.required) { "Missing value for required field '${field.name}'" }
// TODO: if there's a default value, support it
block.add("%L = null,\n", names[field])
}
}
block.add("⇤)")
@ -1903,9 +1906,11 @@ class KotlinCodeGenerator(
block.add("%T().let·{\n", builderType)
for (field in structType.fields) {
// TODO: Once we allow default values, support them here
val fieldValue = fieldValues[field.name]
?: error("No value for struct field '$field.name'")
if (fieldValue == null) {
check(!field.required) { "Missing value for required field '${field.name}'" }
continue
}
block.add("it.%L(", names[field])
recursivelyRenderConstValue(block, field.type, fieldValue)

Просмотреть файл

@ -30,6 +30,7 @@ import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asTypeName
import io.kotest.matchers.ints.shouldBeLessThan
import io.kotest.matchers.should
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNot
@ -1368,6 +1369,33 @@ class KotlinCodeGeneratorTest {
lines[1] shouldContain "\\/\\/ Generated on: [0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}\\.\\S+".toRegex()
}
@Test
fun `constant reference`() {
val thrift = """
|namespace kt test.const_ref
|
|struct Node {
| 1: optional list<Node> n;
|}
|
|const Node D = {"n": [B, C]}
|const Node C = {"n": [A]}
|const Node B = {"n": [A]}
|const Node A = {}
""".trimMargin()
val files = generate(thrift) { filePerType() }
val constants = files.single { it.name.contains("Constants") }
val kotlinText = constants.toString()
val positionOfA = kotlinText.indexOf("val A: Node")
val positionOfD = kotlinText.indexOf("val D: Node")
positionOfA shouldBeLessThan positionOfD
files.shouldCompile()
}
private fun generate(thrift: String, config: (KotlinCodeGenerator.() -> KotlinCodeGenerator)? = null): List<FileSpec> {
val configOrDefault = config ?: { emitFileComment(false) }
return KotlinCodeGenerator()

Просмотреть файл

@ -116,6 +116,10 @@ class Constant private constructor (
return Builder(this)
}
override fun toString(): String {
return "Constant(name=$name, loc=${location.path})"
}
/**
* An object that can build [Constants][Constant].
*/
@ -448,10 +452,9 @@ class Constant private constructor (
Constant.validate(symbolTable, value, field.type)
}
// TODO: Relax this requirement and allow non-required or default-valued fields to be unspecified
check(allFields.isEmpty()) {
val missingFieldNames = allFields.keys.joinToString(", ")
"Expected all fields to be set; missing: $missingFieldNames"
check(allFields.none { it.value.required }) {
val missingRequiredFieldNames = allFields.filter { it.value.required }.map { it.key }.joinToString(", ")
"Some required fields are unset: $missingRequiredFieldNames"
}
} else {
super.validate(symbolTable, expected, valueElement)
@ -528,9 +531,14 @@ class Constant private constructor (
return when (cve) {
is IdentifierValueElement -> getScalarConstantReference()
is MapValueElement -> cve.value.values.flatMap { elem ->
val visitor = ConstantReferenceVisitor(elem, linker)
mapType.valueType.accept(visitor)
is MapValueElement -> {
cve.value.keys.flatMap { elem ->
val visitor = ConstantReferenceVisitor(elem, linker)
mapType.keyType.accept(visitor)
} + cve.value.values.flatMap { elem ->
val visitor = ConstantReferenceVisitor(elem, linker)
mapType.valueType.accept(visitor)
}.distinct()
}
else -> error("no")

Просмотреть файл

@ -1022,6 +1022,20 @@ class LoaderTest {
strs.referencedConstants.map { it.name} shouldBe listOf("A", "B")
}
@Test
fun constantInMapKey() {
val thrift = """
const string KEY = "foo"
const map<string, string> MAP = {
KEY: "bar"
}
""".trimIndent()
val schema = load(thrift)
val (key, map) = schema.constants
map.referencedConstants shouldBe listOf(key)
}
@Test
fun topologicallySortsConstants() {
val thrift = """