Skip to content

Commit

Permalink
Merge pull request #115 from Cu3PO42/feat/inherited-sockets
Browse files Browse the repository at this point in the history
Add support for inherited sockets
  • Loading branch information
flemming-n-larsen authored Dec 21, 2024
2 parents a45c776 + d2b21a2 commit 936d197
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 21 deletions.
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ kotlin = "2.1.0"
gson = "2.11.0"
gson-extras = "1.3.0"
jansi = "2.4.1"
java-websocket = "1.5.7"
java-websocket = "1.6.0"
json = "20240303"
jsonschema2pojo = "1.2.2"
kotlinx-serialization-json = "1.7.3"
Expand Down
30 changes: 24 additions & 6 deletions server/src/main/kotlin/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.fusesource.jansi.AnsiConsole
import picocli.CommandLine
import picocli.CommandLine.*
import picocli.CommandLine.Model.CommandSpec
import java.nio.channels.ServerSocketChannel
import java.util.*
import kotlin.system.exitProcess

Expand Down Expand Up @@ -51,6 +52,8 @@ class Server : Runnable {
private const val MIN_PORT = 1000
private const val MAX_PORT = 65535

const val INHERIT = "inherit"

@Option(names = ["-v", "--version"], description = ["Display version info"])
private var isVersionInfoRequested = false

Expand All @@ -59,10 +62,16 @@ class Server : Runnable {

@Option(
names = ["-p", "--port"],
type = [Int::class],
description = ["Port number (default: $DEFAULT_PORT)"]
type = [String::class],
description = ["Port number (default: $DEFAULT_PORT) or '$INHERIT' to use socket activation (if supported by the system)"]
)
var port: Int = DEFAULT_PORT
private var port: String = DEFAULT_PORT.toString()

val useInheritedChannel: Boolean
get() = port.equals(INHERIT, ignoreCase = true)

val portNumber: Int
get() = if (useInheritedChannel) getInheritedPort() else port.toIntOrNull() ?: DEFAULT_PORT

@Option(
names = ["-g", "--games"],
Expand Down Expand Up @@ -99,6 +108,11 @@ class Server : Runnable {
var tps: Int = DEFAULT_TURNS_PER_SECOND

val cmdLine = CommandLine(Server())

private fun getInheritedPort(): Int {
val channel = System.inheritedChannel() as? ServerSocketChannel
return channel?.socket()?.localPort ?: -1
}
}

@Spec
Expand Down Expand Up @@ -145,7 +159,7 @@ class Server : Runnable {
}

private fun validatePort() {
if (port !in MIN_PORT..MAX_PORT) {
if (!useInheritedChannel && portNumber !in MIN_PORT..MAX_PORT) {
reportInvalidPort()
exitProcess(1) // general error
}
Expand All @@ -154,13 +168,17 @@ class Server : Runnable {
private fun reportInvalidPort() {
System.err.println(
"""
Port must be between $MIN_PORT and $MAX_PORT.
Default port $DEFAULT_PORT will be used for HTTP.
Port must be either 'inherit' or a number between $MIN_PORT and $MAX_PORT.
Default port is $DEFAULT_PORT used for http.
""".trimIndent()
)
}

private fun startExitInputMonitorThread() {
// When inheriting a channel, it is passed as FD3, i.e. stdin. In this case, it does not
// make sense to monitor for an exit command.
if (useInheritedChannel) return

Thread {
monitorStandardInputForExit()
}.apply {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,46 @@ package dev.robocode.tankroyale.server.dev.robocode.tankroyale.server.connection

import dev.robocode.tankroyale.server.Server
import org.java_websocket.WebSocket
import org.slf4j.LoggerFactory
import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.channels.ServerSocketChannel
import kotlin.system.exitProcess

class MultiServerWebSocketObserver(observer: IClientWebSocketObserver) {

private val loopbackServerWebSocketObserver = ServerWebSocketObserver(InetSocketAddress(Server.port), observer)
private val localhostServerWebSocketObserver = ServerWebSocketObserver(InetSocketAddress(InetAddress.getLocalHost(), Server.port), observer)
private val log = LoggerFactory.getLogger(this::class.java)

private val webSocketServer = if (Server.useInheritedChannel) {
validateInheritedChannel()
arrayOf(ServerWebSocketObserver(observer))
} else {
arrayOf(
ServerWebSocketObserver(InetSocketAddress(Server.portNumber), observer),
ServerWebSocketObserver(InetSocketAddress(InetAddress.getLocalHost(), Server.portNumber), observer)
)
}

fun start() {
loopbackServerWebSocketObserver.run()
localhostServerWebSocketObserver.run()
webSocketServer.forEach { it.run() }
}

fun broadcast(clientSockets: Collection<WebSocket>, message: String) {
localhostServerWebSocketObserver.broadcast(message, clientSockets)
webSocketServer.forEach { it.broadcast(message, clientSockets) }
}

private fun validateInheritedChannel() {
val inheritedChannel = System.inheritedChannel()
if (inheritedChannel == null) {
log.error(
"The '${Server.INHERIT}' mode require that a valid socket is passed to this server via file descriptor 3 (fd 3). " +
"Make sure the socket was passed correctly"
)
exitProcess(2)
}
if (inheritedChannel !is ServerSocketChannel) {
log.error("The '${Server.INHERIT}' mode expects a server socket")
exitProcess(2)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,35 @@ import org.java_websocket.handshake.ClientHandshake
import org.java_websocket.server.WebSocketServer
import org.slf4j.LoggerFactory
import java.net.InetSocketAddress
import java.nio.channels.ServerSocketChannel

class ServerWebSocketObserver(
address: InetSocketAddress,
private val observer: IClientWebSocketObserver
) : WebSocketServer(address) {

class ServerWebSocketObserver : WebSocketServer {
private val log = LoggerFactory.getLogger(this::class.java)
private val observer: IClientWebSocketObserver
private val useInheritedChannel: Boolean

constructor(
address: InetSocketAddress,
observer: IClientWebSocketObserver
) : super(address) {
this.observer = observer
this.useInheritedChannel = false
isTcpNoDelay = true
}

init {
// Disable Nagle's algorithm
constructor(
observer: IClientWebSocketObserver
) : super(
(System.inheritedChannel() as? ServerSocketChannel)
?: throw IllegalStateException("No inherited server socket channel available")
) {
this.observer = observer
this.useInheritedChannel = true
isTcpNoDelay = true
}

override fun onStart() {
log.debug("onStart()")
log.debug("onStart(){}", if (useInheritedChannel) " with inherited channel" else "")
}

override fun onOpen(clientSocket: WebSocket, handshake: ClientHandshake) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class GameServer(

/** Starts this server */
fun start() {
log.info("Starting server on port ${Server.port} with supporting game type(s): ${gameTypes.joinToString()}")
log.info("Starting server on port ${Server.portNumber} with supporting game type(s): ${gameTypes.joinToString()}")
connectionHandler.start()
}

Expand Down

0 comments on commit 936d197

Please sign in to comment.