Skip to content

Commit

Permalink
support routers which inject beans
Browse files Browse the repository at this point in the history
  • Loading branch information
wakingrufus committed Aug 9, 2024
1 parent 7904bd9 commit b82671a
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 18 deletions.
33 changes: 29 additions & 4 deletions docs/webmvc.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,39 @@ open class TestKotlinApplication : SpringFunkApplication {
## Router DSL
The Router DSL exposes the Spring built-in `RouterFunctionDsl`.

### Example
#### Example
```kotlin
open class TestKotlinApplication : SpringFunkApplication {
override fun dsl(): SpringDslContainer.() -> Unit = {
webmvc {
router {
GET("/dsl") {
ServerResponse.ok().body(Dto("Hello World"))
routes {
route {
GET("/dsl") {
ServerResponse.ok().body(Dto("Hello World"))
}
}
}
}
}
}
```

### Routers with Bean Injection

In order to use bean injection in your routers, declare a separate router function.
Then register this function using `ref()` to inject, similar to the beans DSL.

#### Example
```kotlin
fun helloWorldApi(serviceClass: ServiceClass) = router {
GET("/hello", serviceClass::get)
}
open class TestKotlinApplication : SpringFunkApplication {
override fun dsl(): SpringDslContainer.() -> Unit = {
webmvc {
routes {
router {
helloWorldApi(ref())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.github.wakingrufus.funk.webmvc

import com.github.wakingrufus.funk.core.SpringDslMarker
import org.springframework.context.support.BeanDefinitionDsl
import org.springframework.web.servlet.function.RouterFunction
import org.springframework.web.servlet.function.RouterFunctionDsl
import org.springframework.web.servlet.function.ServerResponse

class RoutesDsl {
private val routes = mutableListOf<BeanDefinitionDsl.BeanSupplierContext.() -> RouterFunction<ServerResponse>>()

@SpringDslMarker
fun router(f: BeanDefinitionDsl.BeanSupplierContext.() -> RouterFunction<ServerResponse>) {
routes.add(f)
}

@SpringDslMarker
fun route(router: RouterFunctionDsl.() -> Unit) {
routes.add { org.springframework.web.servlet.function.router(router) }
}

fun merge(f: BeanDefinitionDsl.BeanSupplierContext): RouterFunction<ServerResponse> =
routes.map { it.invoke(f) }.reduce(RouterFunction<ServerResponse>::and)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@ package com.github.wakingrufus.funk.webmvc
import com.github.wakingrufus.funk.core.SpringDsl
import com.github.wakingrufus.funk.core.SpringDslContainer
import com.github.wakingrufus.funk.core.SpringDslMarker
import org.springframework.web.servlet.function.RouterFunction
import org.springframework.web.servlet.function.RouterFunctionDsl
import org.springframework.web.servlet.function.ServerResponse

@SpringDslMarker
class WebmvcDsl : SpringDsl {
internal var routerDsl: RouterFunction<ServerResponse>? = null
internal var routes: RoutesDsl? = null
internal var enableWebmvcDsl: EnableWebMvcDsl? = null
internal var converterDsl: WebMvcConverterDsl? = null

Expand All @@ -22,8 +19,8 @@ class WebmvcDsl : SpringDsl {
}

@SpringDslMarker
fun router(config: RouterFunctionDsl.() -> Unit) {
routerDsl = org.springframework.web.servlet.function.router(config)
fun routes(config: RoutesDsl.() -> Unit) {
routes = RoutesDsl().apply(config)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import org.springframework.boot.web.servlet.filter.OrderedHiddenHttpMethodFilter
import org.springframework.boot.web.servlet.server.CookieSameSiteSupplier
import org.springframework.context.ApplicationContextInitializer
import org.springframework.context.support.GenericApplicationContext
import org.springframework.context.support.beans
import org.springframework.context.support.registerBean
import org.springframework.web.filter.RequestContextFilter
import org.springframework.web.servlet.function.RouterFunction
Expand Down Expand Up @@ -55,8 +56,12 @@ class WebmvcInitializer : ApplicationContextInitializer<GenericApplicationContex
standardJacksonObjectMapperBuilderCustomizer(context, jacksonProps)
}
}
routerDsl?.run {
context.registerBean<RouterFunction<ServerResponse>> { this }
routes?.also { r ->
beans {
bean {
r.merge(this)
}
}.initialize(context)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.github.wakingrufus.funk.webmvc

import com.github.wakingrufus.funk.base.SpringFunkApplication
import com.github.wakingrufus.funk.beans.beans
import com.github.wakingrufus.funk.core.SpringDslContainer
import io.github.oshai.kotlinlogging.KotlinLogging
import org.assertj.core.api.Assertions.assertThat
Expand All @@ -11,7 +12,9 @@ import org.springframework.boot.test.web.client.TestRestTemplate
import org.springframework.boot.test.web.client.getForEntity
import org.springframework.context.ApplicationContext
import org.springframework.test.context.ContextConfiguration
import org.springframework.web.servlet.function.ServerRequest
import org.springframework.web.servlet.function.ServerResponse
import org.springframework.web.servlet.function.router
import java.net.URI

@SpringBootTest(
Expand All @@ -34,19 +37,34 @@ internal class DslIntegrationTest {
assertThat(response.statusCode.value()).isEqualTo(200)
context.beanDefinitionNames.forEach { log.info { it } }
}

}

internal class FunkApplication : SpringFunkApplication {
override fun dsl(): SpringDslContainer.() -> Unit = {
beans {
bean<ServiceClass>()
}
webmvc {
enableWebMvc {
jetty()
}
router {
GET("/dsl") {
ServerResponse.ok().build()

routes {
router {
helloWorldApi(ref())
}
}
}
}
}

class ServiceClass {
fun get(req: ServerRequest): ServerResponse {
return ServerResponse.ok().body("Hello, World")
}
}

fun helloWorldApi(serviceClass: ServiceClass) = router {
GET("/dsl", serviceClass::get)
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ internal class JacksonFunkApplication : SpringFunkApplication {
converters {
jackson()
}
router {
GET("/dsl") {
ServerResponse.ok().body(Dto("Hello World"))
routes {
route {
GET("/dsl") {
ServerResponse.ok().body(Dto("Hello World"))
}
}
}
}
Expand Down

0 comments on commit b82671a

Please sign in to comment.