Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
a72f6ea
feat(13-01): restructure polymorphic generation to nested sealed hier…
halotukozak Mar 24, 2026
cbf6804
feat(13-01): wire classNameLookup through ClientGenerator and CodeGen…
halotukozak Mar 24, 2026
e4030dd
Merge remote-tracking branch 'origin/master' into feat/nested-sealed-…
halotukozak Mar 30, 2026
ad24a54
Merge origin/master into feat/nested-sealed-hierarchies
halotukozak Mar 31, 2026
517e2ce
Merge origin/master into feat/nested-sealed-hierarchies
halotukozak Apr 1, 2026
bc1e7b9
fix(core): propagate classNameLookup in recursive toTypeName calls
halotukozak Apr 1, 2026
0ba6291
refactor: replace `ModelPackage` with `Hierarchy` for sealed schema h…
halotukozak Apr 1, 2026
9d4347f
refactor: replace lazy properties with constructor initialization in …
halotukozak Apr 1, 2026
3795864
refactor: simplify `resolveSerialName` logic and centralize `SCHEMA_P…
halotukozak Apr 1, 2026
f44924c
fix: update all tests for Hierarchy context and fix anyOfWithoutDiscr…
halotukozak Apr 1, 2026
ce1c3a8
refactor: simplify Hierarchy and extract shared constructor building …
halotukozak Apr 1, 2026
ed99cbf
refactor: simplify `Hierarchy` and `ModelGenerator` with cleaner null…
halotukozak Apr 1, 2026
be04037
refactor: enhance `Hierarchy` with caching and incremental schema upd…
halotukozak Apr 1, 2026
8570c44
refactor: replace `CacheGroup` with `MemoScope`, streamline `Hierarch…
halotukozak Apr 2, 2026
7bf5821
refactor: improve `ModelGenerator` and `Memo` with cleaner design and…
halotukozak Apr 2, 2026
ec4b207
refactor: update functional test to verify sealed class generation wi…
halotukozak Apr 2, 2026
0cda7c0
refactor: enhance `Hierarchy` with `anyOfParents` cache and doc impro…
halotukozak Apr 2, 2026
db8de77
merge: resolve master into feat/nested-sealed-hierarchies
halotukozak Apr 2, 2026
70c4f7d
refactor(core): replace sealed classes with sealed interfaces and int…
halotukozak Apr 4, 2026
cfd16f2
test: update assertion to check for sealed interface instead of seale…
halotukozak Apr 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions core/src/main/kotlin/com/avsystem/justworks/core/Memo.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.avsystem.justworks.core

import arrow.atomic.Atomic
import arrow.atomic.update

@JvmInline
value class MemoScope private constructor(private val memos: Atomic<Set<Memo<*>>>) {
constructor() : this(Atomic(emptySet()))

fun register(memo: Memo<*>) {
memos.update { it + memo }
}

fun reset() {
memos.get().forEach { it.reset() }
}
}

class Memo<T>(private val compute: () -> T) {
private val holder = Atomic(lazy(compute))

operator fun getValue(thisRef: Any?, property: Any?): T = holder.get().value

fun reset() {
holder.set(lazy(compute))
}
}

fun <T> memoized(memoScope: MemoScope, compute: () -> T): Memo<T> = Memo(compute).also(memoScope::register)
2 changes: 2 additions & 0 deletions core/src/main/kotlin/com/avsystem/justworks/core/Utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ package com.avsystem.justworks.core

import kotlin.enums.enumEntries

internal const val SCHEMA_PREFIX = "#/components/schemas/"

inline fun <reified T : Enum<T>> String.toEnumOrNull(): T? = enumEntries<T>().find { it.name.equals(this, true) }
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@ object CodeGenerator {
modelPackage: String,
apiPackage: String,
outputDir: File,
): Result = context(ModelPackage(modelPackage), ApiPackage(apiPackage)) {
val modelRegistry = NameRegistry()
val apiRegistry = NameRegistry()
): Result {
val hierarchy = Hierarchy(ModelPackage(modelPackage)).apply { addSchemas(spec.schemas) }

val (modelFiles, resolvedSpec) = ModelGenerator.generateWithResolvedSpec(spec, modelRegistry)
val (modelFiles, resolvedSpec) = context(hierarchy, NameRegistry()) {
ModelGenerator.generateWithResolvedSpec(spec)
}

modelFiles.forEach { it.writeTo(outputDir) }

val hasPolymorphicTypes = modelFiles.any { it.name == SERIALIZERS_MODULE.simpleName }

val clientFiles = ClientGenerator.generate(resolvedSpec, hasPolymorphicTypes, apiRegistry)
val clientFiles = context(hierarchy, ApiPackage(apiPackage), NameRegistry()) {
ClientGenerator.generate(resolvedSpec, hasPolymorphicTypes)
}

clientFiles.forEach { it.writeTo(outputDir) }

Expand Down
80 changes: 80 additions & 0 deletions core/src/main/kotlin/com/avsystem/justworks/core/gen/Hierarchy.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package com.avsystem.justworks.core.gen

import com.avsystem.justworks.core.MemoScope
import com.avsystem.justworks.core.memoized
import com.avsystem.justworks.core.model.SchemaModel
import com.avsystem.justworks.core.model.TypeRef
import com.squareup.kotlinpoet.ClassName

internal class Hierarchy(val modelPackage: ModelPackage) {
private val schemas = mutableSetOf<SchemaModel>()

private val memoScope = MemoScope()

/**
* Updates the underlying schemas and invalidates all cached derived views.
* This is necessary when schemas are updated (e.g., after inlining types).
*/
fun addSchemas(newSchemas: List<SchemaModel>) {
schemas += newSchemas
memoScope.reset()
}

/** All schemas indexed by name for quick lookup. */
val schemasById: Map<String, SchemaModel> by memoized(memoScope) {
schemas.associateBy { it.name }
}

/** Schemas that define polymorphic variants via oneOf or anyOf. */
private val polymorphicSchemas: List<SchemaModel> by memoized(memoScope) {
schemas.filterNot { it.variants().isNullOrEmpty() }
}

/** Maps parent schema name to its variant schema names (for both oneOf and anyOf). */
val sealedHierarchies: Map<String, List<String>> by memoized(memoScope) {
polymorphicSchemas
.associate { schema ->
schema.name to schema
.variants()
?.filterIsInstance<TypeRef.Reference>()
?.map { it.schemaName }
.orEmpty()
}
}

/** Parent schema names that use anyOf without a discriminator (JsonContentPolymorphicSerializer pattern). */
val anyOfWithoutDiscriminator: Set<String> by memoized(memoScope) {
polymorphicSchemas
.asSequence()
.filter { !it.anyOf.isNullOrEmpty() && it.discriminator == null }
.map { it.name }
.toSet()
}

/** Inverse of [sealedHierarchies] for anyOf-without-discriminator: variant name to its parent names. */
val anyOfParents: Map<String, Set<String>> by memoized(memoScope) {
sealedHierarchies
.asSequence()
.filter { (parent, _) -> parent in anyOfWithoutDiscriminator }
.flatMap { (parent, variants) -> variants.map { it to parent } }
.groupBy({ it.first }, { it.second })
.mapValues { (_, parents) -> parents.toSet() }
}

/** Maps schema name to its [ClassName], using nested class for discriminated hierarchy variants. */
private val lookup: Map<String, ClassName> by memoized(memoScope) {
sealedHierarchies
.asSequence()
.filterNot { (parent, _) -> parent in anyOfWithoutDiscriminator }
.flatMap { (parent, variants) ->
val parentClass = ClassName(modelPackage, parent)
variants.map { variant -> variant to parentClass.nestedClass(variant.toPascalCase()) } +
(parent to parentClass)
}.toMap()
}

/** Resolves a schema name to its [ClassName], falling back to a flat top-level class. */
operator fun get(name: String): ClassName = lookup[name] ?: ClassName(modelPackage, name)
}

private fun SchemaModel.variants() = oneOf ?: anyOf
17 changes: 14 additions & 3 deletions core/src/main/kotlin/com/avsystem/justworks/core/gen/Utils.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package com.avsystem.justworks.core.gen

import com.avsystem.justworks.core.SCHEMA_PREFIX
import com.avsystem.justworks.core.model.PrimitiveType
import com.avsystem.justworks.core.model.PropertyModel
import com.avsystem.justworks.core.model.SchemaModel
import com.avsystem.justworks.core.model.TypeRef
import com.squareup.kotlinpoet.BOOLEAN
import com.squareup.kotlinpoet.BYTE_ARRAY
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.DOUBLE
import com.squareup.kotlinpoet.FLOAT
import com.squareup.kotlinpoet.INT
Expand All @@ -28,7 +29,7 @@ internal val TypeRef.requiredProperties: Set<String>
is TypeRef.Array, is TypeRef.Map, is TypeRef.Primitive, is TypeRef.Reference, TypeRef.Unknown -> emptySet()
}

context(modelPackage: ModelPackage)
context(hierarchy: Hierarchy)
internal fun TypeRef.toTypeName(): TypeName = when (this) {
is TypeRef.Primitive -> {
when (type) {
Expand All @@ -54,7 +55,7 @@ internal fun TypeRef.toTypeName(): TypeName = when (this) {
}

is TypeRef.Reference -> {
ClassName(modelPackage, schemaName)
hierarchy[schemaName]
}

is TypeRef.Inline -> {
Expand All @@ -67,3 +68,13 @@ internal fun TypeRef.toTypeName(): TypeName = when (this) {
}

internal fun TypeRef.isBinaryUpload(): Boolean = this is TypeRef.Primitive && this.type == PrimitiveType.BYTE_ARRAY

/**
* Resolves the @SerialName value for a variant within a oneOf schema.
*/
internal fun SchemaModel.resolveSerialName(variantSchemaName: String): String = discriminator
?.mapping
?.firstNotNullOfOrNull { (serialName, refPath) ->
serialName.takeIf { refPath.removePrefix(SCHEMA_PREFIX) == variantSchemaName }
}
?: variantSchemaName
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import com.avsystem.justworks.core.gen.GENERATED_SERIALIZERS_MODULE
import com.avsystem.justworks.core.gen.HTTP_CLIENT
import com.avsystem.justworks.core.gen.HTTP_ERROR
import com.avsystem.justworks.core.gen.HTTP_SUCCESS
import com.avsystem.justworks.core.gen.ModelPackage
import com.avsystem.justworks.core.gen.Hierarchy
import com.avsystem.justworks.core.gen.NameRegistry
import com.avsystem.justworks.core.gen.RAISE
import com.avsystem.justworks.core.gen.TOKEN
Expand Down Expand Up @@ -51,29 +51,22 @@ internal object ClientGenerator {
private const val DEFAULT_TAG = "Default"
private const val API_SUFFIX = "Api"

context(_: ModelPackage, _: ApiPackage)
fun generate(
spec: ApiSpec,
hasPolymorphicTypes: Boolean,
nameRegistry: NameRegistry,
): List<FileSpec> {
context(_: Hierarchy, _: ApiPackage, _: NameRegistry)
fun generate(spec: ApiSpec, hasPolymorphicTypes: Boolean): List<FileSpec> {
val grouped = spec.endpoints.groupBy { it.tags.firstOrNull() ?: DEFAULT_TAG }
return grouped.map { (tag, endpoints) ->
generateClientFile(tag, endpoints, hasPolymorphicTypes, nameRegistry)
}
return grouped.map { (tag, endpoints) -> generateClientFile(tag, endpoints, hasPolymorphicTypes) }
}

context(modelPackage: ModelPackage, apiPackage: ApiPackage)
context(hierarchy: Hierarchy, apiPackage: ApiPackage, nameRegistry: NameRegistry)
private fun generateClientFile(
tag: String,
endpoints: List<Endpoint>,
hasPolymorphicTypes: Boolean,
nameRegistry: NameRegistry,
): FileSpec {
val className = ClassName(apiPackage, nameRegistry.register("${tag.toPascalCase()}$API_SUFFIX"))

val clientInitializer = if (hasPolymorphicTypes) {
val generatedSerializersModule = MemberName(modelPackage, GENERATED_SERIALIZERS_MODULE)
val generatedSerializersModule = MemberName(hierarchy.modelPackage, GENERATED_SERIALIZERS_MODULE)
CodeBlock.of("${CREATE_HTTP_CLIENT}(%M)", generatedSerializersModule)
} else {
CodeBlock.of("${CREATE_HTTP_CLIENT}()")
Expand Down Expand Up @@ -101,17 +94,18 @@ internal object ClientGenerator {
.primaryConstructor(primaryConstructor)
.addProperty(httpClientProperty)

val methodRegistry = NameRegistry()
classBuilder.addFunctions(endpoints.map { generateEndpointFunction(it, methodRegistry) })
context(NameRegistry()) {
classBuilder.addFunctions(endpoints.map { generateEndpointFunction(it) })
}

return FileSpec
.builder(className)
.addType(classBuilder.build())
.build()
}

context(_: ModelPackage)
private fun generateEndpointFunction(endpoint: Endpoint, methodRegistry: NameRegistry): FunSpec {
context(_: Hierarchy, methodRegistry: NameRegistry)
private fun generateEndpointFunction(endpoint: Endpoint): FunSpec {
val functionName = methodRegistry.register(endpoint.operationId.toCamelCase())
val returnBodyType = resolveReturnType(endpoint)
val returnType = HTTP_SUCCESS.parameterizedBy(returnBodyType)
Expand Down Expand Up @@ -166,7 +160,7 @@ internal object ClientGenerator {
return funBuilder.build()
}

context(_: ModelPackage)
context(_: Hierarchy)
private fun resolveReturnType(endpoint: Endpoint): TypeName = endpoint.responses.entries
.asSequence()
.filter { it.key.startsWith("2") }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.avsystem.justworks.core.gen.client
import com.avsystem.justworks.core.gen.BODY
import com.avsystem.justworks.core.gen.CHANNEL_PROVIDER
import com.avsystem.justworks.core.gen.CONTENT_TYPE_CLASS
import com.avsystem.justworks.core.gen.ModelPackage
import com.avsystem.justworks.core.gen.Hierarchy
import com.avsystem.justworks.core.gen.isBinaryUpload
import com.avsystem.justworks.core.gen.properties
import com.avsystem.justworks.core.gen.requiredProperties
Expand All @@ -16,7 +16,7 @@ import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.STRING

internal object ParametersGenerator {
context(_: ModelPackage)
context(_: Hierarchy)
fun buildMultipartParameters(requestBody: RequestBody): List<ParameterSpec> =
requestBody.schema.properties.flatMap { prop ->
val name = prop.name.toCamelCase()
Expand All @@ -33,13 +33,13 @@ internal object ParametersGenerator {
}
}

context(_: ModelPackage)
context(_: Hierarchy)
fun buildFormParameters(requestBody: RequestBody): List<ParameterSpec> = requestBody.schema.properties.map { prop ->
val isRequired = requestBody.required && prop.name in requestBody.schema.requiredProperties
buildNullableParameter(prop.type, prop.name, isRequired)
}

context(_: ModelPackage)
context(_: Hierarchy)
fun buildNullableParameter(
typeRef: TypeRef,
name: String,
Expand All @@ -50,7 +50,7 @@ internal object ParametersGenerator {
return builder.build()
}

context(_: ModelPackage)
context(_: Hierarchy)
fun buildBodyParams(requestBody: RequestBody) = when (requestBody.contentType) {
ContentType.MULTIPART_FORM_DATA -> buildMultipartParameters(requestBody)
ContentType.FORM_URL_ENCODED -> buildFormParameters(requestBody)
Expand Down
Loading
Loading