Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSValueParameter
import com.google.devtools.ksp.symbol.Visibility
import com.slack.circuit.codegen.CodegenMode.KOTLIN_INJECT_ANVIL
import com.slack.circuit.codegen.CodegenMode.METRO
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FileSpec
Expand Down Expand Up @@ -425,7 +427,7 @@ private class CircuitSymbolProcessor(
val creatorOrConstructor: KSFunctionDeclaration?
val targetClass: KSClassDeclaration
if (isAssisted) {
if (codegenMode == KOTLIN_INJECT_ANVIL) {
if (codegenMode == KOTLIN_INJECT_ANVIL || codegenMode == METRO) {
creatorOrConstructor = injectableConstructor
targetClass = declaration
} else {
Expand Down Expand Up @@ -499,31 +501,74 @@ private class CircuitSymbolProcessor(
codegenMode.runtime.getProviderBlock(CodeBlock.of("provider"))
} else if (isAssisted) {
// Inject the target class's assisted factory, on which we'll call create().
if (codegenMode == KOTLIN_INJECT_ANVIL) {
val factoryLambda =
LambdaTypeName.get(
receiver = null,
parameters =
assistedKSParams.map { ksParam ->
ParameterSpec.builder(
ksParam.name!!.getShortName(),
ksParam.type.toTypeName(),
)
.build()
},
returnType = targetClass.toClassName(),
when (codegenMode) {
KOTLIN_INJECT_ANVIL -> {
val factoryLambda =
LambdaTypeName.get(
receiver = null,
parameters =
assistedKSParams.map { ksParam ->
ParameterSpec.builder(
ksParam.name!!.getShortName(),
ksParam.type.toTypeName(),
)
.build()
},
returnType = targetClass.toClassName(),
)
constructorParams.add(ParameterSpec.builder("factory", factoryLambda).build())
CodeBlock.of("factory(%L)", assistedParams)
}
METRO -> {
val assistedFactory =
declaration.declarations.filterIsInstance<KSClassDeclaration>().find { nestedClass
->
nestedClass.isAnnotationPresentWithLeniency(
codegenMode.runtime.assistedFactory!!
)
}
requireNotNull(assistedFactory) {
"No assisted factory found for ${declaration.qualifiedName?.asString()}"
}
val constructorAssistedParameters =
assistedKSParams.map { it.toAssistedParameterType("factory") }
val assistedFactoryCreate =
assistedFactory.getAllFunctions().find { assistedFactoryFunction ->
val assistedFunctionParameters =
assistedFactoryFunction.parameters
.filter { it.isAnnotationPresentWithLeniency(codegenMode.runtime.assisted) }
.map { it.toAssistedParameterType("factory") }
val assistedParamsMatch =
constructorAssistedParameters == assistedFunctionParameters
val numberOfFunctionParamsMatch =
constructorAssistedParameters.size == assistedFactoryFunction.parameters.size

assistedParamsMatch && numberOfFunctionParamsMatch
}
requireNotNull(assistedFactoryCreate) {
"No assisted factory create function found " +
"for ${declaration.qualifiedName?.asString()}"
}

constructorParams.add(
ParameterSpec.builder("factory", assistedFactory.toClassName()).build()
)
constructorParams.add(ParameterSpec.builder("factory", factoryLambda).build())
CodeBlock.of("factory(%L)", assistedParams)
} else {
constructorParams.add(
ParameterSpec.builder("factory", declaration.toClassName()).build()
)
CodeBlock.of(
"factory.%L(%L)",
creatorOrConstructor!!.simpleName.getShortName(),
assistedParams,
)
CodeBlock.of(
"factory.%L(%L)",
assistedFactoryCreate.simpleName.getShortName(),
assistedParams,
)
}
else -> {
constructorParams.add(
ParameterSpec.builder("factory", declaration.toClassName()).build()
)
CodeBlock.of(
"factory.%L(%L)",
creatorOrConstructor!!.simpleName.getShortName(),
assistedParams,
)
}
}
} else {
// Simple constructor call, no injection.
Expand All @@ -537,6 +582,9 @@ private class CircuitSymbolProcessor(

/**
 * Describes one assisted parameter: the name of the factory it is associated with, its type, and
 * the parameter's own name.
 *
 * As a data class, two instances are equal only when all three properties are equal, which is what
 * allows lists of assisted parameters to be compared with `==`.
 */
private data class AssistedType(
  val factoryName: String,
  val type: TypeName,
  val name: String,
)

/**
 * Builds an [AssistedType] for this parameter, tagged with the given [factoryName].
 *
 * The parameter's short name and its type (converted via `toTypeName()`) are captured. The
 * parameter must have a name; a `null` name trips the `!!` assertion.
 */
private fun KSValueParameter.toAssistedParameterType(factoryName: String): AssistedType {
  // Read the name before the type to keep the same evaluation order as the single-expression form.
  val parameterName = name!!.getShortName()
  val parameterType = type.toTypeName()
  return AssistedType(factoryName = factoryName, type = parameterType, name = parameterName)
}

/**
* Returns a [CodeBlock] representation of all named assisted parameters on this
* [KSFunctionDeclaration] to be used in generated invocation code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ class CircuitSymbolProcessorTest {
"""
.trimIndent(),
)
/**
 * Stub source file declaring a bare `Inject` annotation in the `dev.zacsweers.metro` package.
 * It is added to the test compilation's sources when running with `CodegenMode.METRO`.
 */
private val metroAnnotation =
kotlin(
"Inject.kt",
"""
package dev.zacsweers.metro

annotation class Inject
"""
.trimIndent(),
)
/**
 * Stub source file declaring an `Assisted` annotation (with an optional `value` string that
 * defaults to empty) in the `dev.zacsweers.metro` package. It is added to the test compilation's
 * sources when running with `CodegenMode.METRO`.
 */
private val metroAssistedAnnotation =
kotlin(
"Assisted.kt",
"""
package dev.zacsweers.metro

annotation class Assisted(val value: String = "")
"""
.trimIndent(),
)
/**
 * Stub source file declaring a bare `AssistedFactory` annotation in the `dev.zacsweers.metro`
 * package. It is added to the test compilation's sources when running with `CodegenMode.METRO`.
 */
private val metroAssistedFactoryAnnotation =
kotlin(
"AssistedFactory.kt",
"""
package dev.zacsweers.metro

annotation class AssistedFactory
"""
.trimIndent(),
)
private val screens =
kotlin(
"Screens.kt",
Expand Down Expand Up @@ -1376,10 +1406,176 @@ class CircuitSymbolProcessorTest {
)
}

/**
 * METRO mode, non-assisted presenter class.
 *
 * The input is a presenter annotated with `@Inject` and `@CircuitInject` that has no constructor
 * parameters. The expected generated `FavoritesPresenterFactory` is annotated with `@Inject` and
 * `@ContributesIntoSet(AppScope::class)`, takes a `Provider<FavoritesPresenter>`, and returns
 * `provider()` for `FavoritesScreen` (and `null` for any other screen).
 */
@Test
fun presenterClass_simpleInjection_metro() {
assertGeneratedFile(
sourceFile =
kotlin(
"FavoritesPresenter.kt",
"""
package test

import com.slack.circuit.codegen.annotations.CircuitInject
import com.slack.circuit.runtime.presenter.Presenter
import androidx.compose.runtime.Composable
import dev.zacsweers.metro.Inject

@Inject
@CircuitInject(FavoritesScreen::class, AppScope::class)
class FavoritesPresenter : Presenter<FavoritesScreen.State> {
@Composable
override fun present(): FavoritesScreen.State {
throw NotImplementedError()
}
}
"""
.trimIndent(),
),
generatedFilePath = "test/FavoritesPresenterFactory.kt",
expectedContent =
"""
package test

import com.slack.circuit.runtime.CircuitContext
import com.slack.circuit.runtime.Navigator
import com.slack.circuit.runtime.presenter.Presenter
import com.slack.circuit.runtime.screen.Screen
import dev.zacsweers.metro.ContributesIntoSet
import dev.zacsweers.metro.Inject
import dev.zacsweers.metro.Provider

@Inject
@ContributesIntoSet(AppScope::class)
public class FavoritesPresenterFactory(
private val provider: Provider<FavoritesPresenter>,
) : Presenter.Factory {
override fun create(
screen: Screen,
navigator: Navigator,
context: CircuitContext,
): Presenter<*>? = when (screen) {
is FavoritesScreen -> provider()
else -> null
}
}
"""
.trimIndent(),
codegenMode = CodegenMode.METRO,
)
}

/**
 * METRO mode, assisted presenter class.
 *
 * The input presenter takes an `@Assisted` `Navigator` and declares a nested `@AssistedFactory`
 * `Factory` interface with a matching `create` function. The expected generated
 * `FavoritesPresenterFactory` injects `FavoritesPresenter.Factory` and delegates to
 * `factory.create(navigator = navigator)` for `FavoritesScreen`.
 *
 * NOTE(review): the fixture imports `dev.zacsweers.metro.AppScope`, but none of the metro stub
 * sources visible in this test declare `AppScope` in that package, and the expected output uses
 * `AppScope` without an import (i.e. from package `test`). Confirm whether this import is
 * intentional or should be dropped.
 */
@Test
fun presenterClass_assistedInjection_metro() {
assertGeneratedFile(
sourceFile =
kotlin(
"FavoritesPresenter.kt",
"""
package test

import com.slack.circuit.codegen.annotations.CircuitInject
import com.slack.circuit.runtime.Navigator
import com.slack.circuit.runtime.presenter.Presenter
import androidx.compose.runtime.Composable
import dev.zacsweers.metro.Assisted
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.AssistedFactory
import dev.zacsweers.metro.Inject

@Inject
@CircuitInject(FavoritesScreen::class, AppScope::class)
class FavoritesPresenter(
@Assisted private val navigator: Navigator
) : Presenter<FavoritesScreen.State> {
@AssistedFactory
fun interface Factory {
fun create(@Assisted navigator: Navigator): FavoritesPresenter
}
@Composable
override fun present(): FavoritesScreen.State {
throw NotImplementedError()
}
}
"""
.trimIndent(),
),
generatedFilePath = "test/FavoritesPresenterFactory.kt",
expectedContent =
"""
package test

import com.slack.circuit.runtime.CircuitContext
import com.slack.circuit.runtime.Navigator
import com.slack.circuit.runtime.presenter.Presenter
import com.slack.circuit.runtime.screen.Screen
import dev.zacsweers.metro.ContributesIntoSet
import dev.zacsweers.metro.Inject

@Inject
@ContributesIntoSet(AppScope::class)
public class FavoritesPresenterFactory(
private val factory: FavoritesPresenter.Factory,
) : Presenter.Factory {
override fun create(
screen: Screen,
navigator: Navigator,
context: CircuitContext,
): Presenter<*>? = when (screen) {
is FavoritesScreen -> factory.create(navigator = navigator)
else -> null
}
}
"""
.trimIndent(),
codegenMode = CodegenMode.METRO,
)
}

/**
 * METRO mode, assisted presenter class that is missing its assisted factory.
 *
 * The input presenter has an `@Assisted` constructor parameter but declares no nested
 * `@AssistedFactory` type. Processing is expected to fail, and the reported messages must contain
 * "No assisted factory found".
 *
 * NOTE(review): as in the assisted success test, the fixture imports
 * `dev.zacsweers.metro.AppScope`, which none of the metro stub sources visible in this test
 * declare. Confirm the import is intentional.
 */
@Test
fun invalidAssistedInjection_metro() {
assertProcessingError(
sourceFile =
kotlin(
"FavoritesPresenter.kt",
"""
package test

import com.slack.circuit.codegen.annotations.CircuitInject
import com.slack.circuit.runtime.Navigator
import com.slack.circuit.runtime.presenter.Presenter
import androidx.compose.runtime.Composable
import dev.zacsweers.metro.Assisted
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.AssistedFactory
import dev.zacsweers.metro.Inject

@Inject
@CircuitInject(FavoritesScreen::class, AppScope::class)
class FavoritesPresenter(
@Assisted private val navigator: Navigator
) : Presenter<FavoritesScreen.State> {

// No AssistedFactory

@Composable
override fun present(): FavoritesScreen.State {
throw NotImplementedError()
}
}
"""
.trimIndent(),
),
codegenMode = CodegenMode.METRO,
) { messages ->
assertThat(messages).contains("No assisted factory found")
}
}

/**
 * Codegen modes a test can request via the `codegenMode` argument of the assertion helpers.
 * The selected mode determines which stub annotation sources are added to the compilation.
 */
private enum class CodegenMode {
ANVIL,
HILT,
KOTLIN_INJECT_ANVIL,
// Metro: compiled against the `dev.zacsweers.metro` stub annotations declared in this test class.
METRO,
}

private fun assertGeneratedFile(
Expand Down Expand Up @@ -1422,6 +1618,13 @@ class CircuitSymbolProcessorTest {
CodegenMode.KOTLIN_INJECT_ANVIL -> {
listOf(appScope, kotlinInjectAnnotation)
}
CodegenMode.METRO ->
listOf(
appScope,
metroAnnotation,
metroAssistedAnnotation,
metroAssistedFactoryAnnotation,
)
}
inheritClassPath = true
symbolProcessorProviders += CircuitSymbolProcessorProvider()
Expand Down