diff --git a/server/core/src/main/java/dev/slimevr/firmware/serial.kt b/server/core/src/main/java/dev/slimevr/firmware/serial.kt index 2d7f7cdc9..acd262342 100644 --- a/server/core/src/main/java/dev/slimevr/firmware/serial.kt +++ b/server/core/src/main/java/dev/slimevr/firmware/serial.kt @@ -183,8 +183,12 @@ internal suspend fun doSerialFlashPostFlash( // wait for the tracker with that MAC to connect to the server via UDP val connected = withTimeoutOrNull(60_000) { server.context.state - .map { state -> state.devices.values.any { it.context.state.value.macAddress?.uppercase() == macAddress - && it.context.state.value.status != TrackerStatus.DISCONNECTED } } + .map { state -> + state.devices.values.any { + it.context.state.value.macAddress?.uppercase() == macAddress && + it.context.state.value.status != TrackerStatus.DISCONNECTED + } + } .filter { it } .first() } diff --git a/server/core/src/main/java/dev/slimevr/tracker/device.kt b/server/core/src/main/java/dev/slimevr/tracker/device.kt index 6f41ec1e6..4d8a3058c 100644 --- a/server/core/src/main/java/dev/slimevr/tracker/device.kt +++ b/server/core/src/main/java/dev/slimevr/tracker/device.kt @@ -64,9 +64,6 @@ fun createDevice( address: String, macAddress: String? = null, origin: DeviceOrigin, - boardType: BoardType, - mcuType: McuType, - firmware: String? = null, protocolVersion: Int, serverContext: VRServer, ): Device { @@ -78,13 +75,13 @@ fun createDevice( origin = origin, address = address, macAddress = macAddress, - boardType = boardType, - firmware = firmware, - mcuType = mcuType, protocolVersion = protocolVersion, ping = null, signalStrength = null, status = TrackerStatus.DISCONNECTED, + mcuType = McuType.Other, + boardType = BoardType.UNKNOWN, + firmware = null ) val behaviours = listOf(DeviceStatsBehaviour) diff --git a/server/core/src/main/java/dev/slimevr/tracker/udp/connection.kt b/server/core/src/main/java/dev/slimevr/tracker/udp/connection.kt index 9a4b85c6f..02d92c8ea 100644 --- a/server/core/src/main/java/dev/slimevr/tracker/udp/connection.kt +++ b/server/core/src/main/java/dev/slimevr/tracker/udp/connection.kt @@ -24,6 +24,7 @@ import kotlinx.coroutines.isActive import kotlinx.coroutines.launch import kotlinx.io.Buffer import kotlinx.io.readByteArray +import solarxr_protocol.datatypes.TrackerStatus import java.net.DatagramPacket import java.net.DatagramSocket import java.net.InetAddress @@ -52,11 +53,14 @@ sealed interface UDPConnectionActions { data class Handshake(val deviceId: Int) : UDPConnectionActions data class LastPacket(val packetNum: Long? = null, val time: Long) : UDPConnectionActions data class AssignTracker(val trackerId: TrackerIdNum) : UDPConnectionActions + data object Disconnected : UDPConnectionActions } typealias UDPConnectionContext = Context typealias UDPConnectionBehaviour = CustomBehaviour +private const val CONNECTION_TIMEOUT_MS = 5000L + val PacketBehaviour = UDPConnectionBehaviour( reducer = { s, a -> when (a) { @@ -78,7 +82,7 @@ val PacketBehaviour = UDPConnectionBehaviour( val state = it.context.state.value val now = System.currentTimeMillis() - if (now - state.lastPacket > 5000 && packet.packetNumber == 0L) { + if (now - state.lastPacket > CONNECTION_TIMEOUT_MS && packet.packetNumber == 0L) { it.context.dispatch( UDPConnectionActions.LastPacket( packetNum = 0, @@ -160,6 +164,10 @@ val HandshakeBehaviour = UDPConnectionBehaviour( deviceId = a.deviceId, ) + is UDPConnectionActions.Disconnected -> s.copy( + didHandshake = false, + ) + else -> s } }, @@ -167,32 +175,66 @@ val HandshakeBehaviour = UDPConnectionBehaviour( it.packetEvents.onPacket { packet -> val state = it.context.state.value - if (state.deviceId == null) { + val device = if (state.deviceId == null) { val deviceId = it.serverContext.nextHandle() - val newDevice = createDevice( id = deviceId, scope = it.serverContext.context.scope, - address = it.context.state.value.address, + address = state.address, macAddress = packet.data.macString, - boardType = packet.data.boardType, - protocolVersion = packet.data.protocolVersion, - mcuType = packet.data.mcuType, - firmware = packet.data.firmware, origin = DeviceOrigin.UDP, + protocolVersion = packet.data.protocolVersion, serverContext = it.serverContext, ) - - it.serverContext.context.dispatch( - VRServerActions.NewDevice( - deviceId = deviceId, - context = newDevice, - ), - ) + it.serverContext.context.dispatch(VRServerActions.NewDevice(deviceId = deviceId, context = newDevice)) it.context.dispatch(UDPConnectionActions.Handshake(deviceId)) - it.send(Handshake()) + newDevice } else { - it.send(Handshake()) + it.context.dispatch(UDPConnectionActions.Handshake(state.deviceId)) + it.getDevice() ?: run { + AppLogger.udp.warn("Reconnect handshake but device ${state.deviceId} not found") + it.send(Handshake()) + return@onPacket + } + } + + // Apply handshake fields to device, always, for both first connect and reconnect + device.context.dispatch( + DeviceActions.Update { + copy( + macAddress = packet.data.macString ?: macAddress, + boardType = packet.data.boardType, + mcuType = packet.data.mcuType, + firmware = packet.data.firmware ?: firmware, + protocolVersion = packet.data.protocolVersion, + ) + }, + ) + + it.send(Handshake()) + } + }, +) + +val TimeoutBehaviour = UDPConnectionBehaviour( + observer = { + it.context.scope.launch { + while (isActive) { + val state = it.context.state.value + if (!state.didHandshake) { + delay(500) + continue + } + val timeUntilTimeout = CONNECTION_TIMEOUT_MS - (System.currentTimeMillis() - state.lastPacket) + if (timeUntilTimeout <= 0) { + AppLogger.udp.info("Connection timed out for ${state.id}") + it.context.dispatch(UDPConnectionActions.Disconnected) + it.getDevice()?.context?.dispatch( + DeviceActions.Update { copy(status = TrackerStatus.DISCONNECTED) }, + ) + } else { + delay(timeUntilTimeout + 1) + } } } }, @@ -251,7 +293,6 @@ val SensorInfoBehaviour = UDPConnectionBehaviour( val action = TrackerActions.Update { copy( sensorType = event.data.imuType, - status = event.data.status, ) } @@ -340,6 +381,7 @@ data class UDPConnection( val behaviours = listOf( PacketBehaviour, HandshakeBehaviour, + TimeoutBehaviour, PingBehaviour, DeviceStatsBehaviour, SensorInfoBehaviour, @@ -386,6 +428,12 @@ data class UDPConnection( // Dedicated coroutine per connection so the receive loop is never blocked by packet processing scope.launch { for (event in packetChannel) { + // We skip any packet from the tracker that are not handshake packets + // if we didn't do a handshake with the server + // this prevents from receiving packets if the server does not know about the + // tracker yet. This usually happen when you restart the server with already + // connected trackers + if (!context.state.value.didHandshake && event.data !is PreHandshakePacket) continue dispatcher.emit(event) } } diff --git a/server/core/src/main/java/dev/slimevr/tracker/udp/packets.kt b/server/core/src/main/java/dev/slimevr/tracker/udp/packets.kt index 3e3a3e4ed..f6d73dabd 100644 --- a/server/core/src/main/java/dev/slimevr/tracker/udp/packets.kt +++ b/server/core/src/main/java/dev/slimevr/tracker/udp/packets.kt @@ -68,6 +68,9 @@ sealed interface UDPPacket { fun write(dst: Sink) {} } +/** Packets that are processed before the handshake is complete */ +sealed interface PreHandshakePacket : UDPPacket + sealed interface SensorSpecificPacket : UDPPacket { val sensorId: Int } @@ -81,7 +84,7 @@ data class Handshake( val protocolVersion: Int = 0, val firmware: String? = null, val macString: String? = null, -) : UDPPacket { +) : PreHandshakePacket { override fun write(dst: Sink) { dst.writeByte(PacketType.HANDSHAKE.id.toByte()) dst.write("Hey OVR =D 5".toByteArray(Charsets.US_ASCII)) @@ -126,7 +129,7 @@ data class Accel(val acceleration: Vector3 = Vector3.NULL, override val sensorId } } -data class PingPong(val pingId: Int = 0) : UDPPacket { +data class PingPong(val pingId: Int = 0) : PreHandshakePacket { override fun write(dst: Sink) { dst.writeInt(pingId) } diff --git a/server/desktop/src/main/java/dev/slimevr/desktop/ipc/protocol.kt b/server/desktop/src/main/java/dev/slimevr/desktop/ipc/protocol.kt index dd987e968..e59c3afb4 100644 --- a/server/desktop/src/main/java/dev/slimevr/desktop/ipc/protocol.kt +++ b/server/desktop/src/main/java/dev/slimevr/desktop/ipc/protocol.kt @@ -8,6 +8,7 @@ import dev.slimevr.desktop.platform.TrackerAdded import dev.slimevr.desktop.platform.Version import dev.slimevr.solarxr.createSolarXRConnection import dev.slimevr.solarxr.onSolarXRMessage +import dev.slimevr.tracker.DeviceActions import dev.slimevr.tracker.DeviceOrigin import dev.slimevr.tracker.TrackerActions import dev.slimevr.tracker.createDevice @@ -21,9 +22,7 @@ import kotlinx.coroutines.flow.runningFold import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import solarxr_protocol.datatypes.hardware_info.BoardType import solarxr_protocol.datatypes.hardware_info.ImuType -import solarxr_protocol.datatypes.hardware_info.McuType import java.nio.ByteBuffer const val PROTOCOL_VERSION = 5 @@ -104,32 +103,47 @@ suspend fun handleFeederConnection( val msg = ProtobufMessage.ADAPTER.decode(bytes) if (msg.tracker_added != null) { - val deviceId = server.nextHandle() - val device = createDevice( - scope = this, - id = deviceId, - address = msg.tracker_added.tracker_serial, - origin = DeviceOrigin.FEEDER, - boardType = BoardType.UNKNOWN, - firmware = msg.version.toString(), - protocolVersion = 0, - mcuType = McuType.Other, - macAddress = msg.tracker_added.tracker_serial, // FIXME: prob not correct - serverContext = server, - ) - server.context.dispatch(VRServerActions.NewDevice(deviceId, device)) + val serial = msg.tracker_added.tracker_serial + val protocolVersion = msg.version?.protocol_version ?: 0 + val firmware = msg.version?.toString() - val trackerId = server.nextHandle() - val tracker = createTracker( - scope = this, - id = trackerId, - deviceId = deviceId, - sensorType = ImuType.MPU9250, // TODO: prob need to make sensor type optional - hardwareId = msg.tracker_added.tracker_serial, - origin = DeviceOrigin.FEEDER, - serverContext = server, + // Check for existing tracker with same hardwareId (reconnect case) + val existingTracker = server.context.state.value.trackers.values + .find { it.context.state.value.hardwareId == serial } + + val device = if (existingTracker != null) { + server.getDevice(existingTracker.context.state.value.deviceId) ?: error("could not find existing device") + } else { + val deviceId = server.nextHandle() + val newDevice = createDevice( + scope = this, + id = deviceId, + address = serial, + macAddress = serial, // FIXME: prob not correct + origin = DeviceOrigin.FEEDER, + protocolVersion = protocolVersion, + serverContext = server, + ) + server.context.dispatch(VRServerActions.NewDevice(deviceId, newDevice)) + + val trackerId = server.nextHandle() + val tracker = createTracker( + scope = this, + id = trackerId, + deviceId = deviceId, + sensorType = ImuType.MPU9250, // TODO: prob need to make sensor type optional + hardwareId = serial, + origin = DeviceOrigin.FEEDER, + serverContext = server, + ) + server.context.dispatch(VRServerActions.NewTracker(trackerId, tracker)) + + newDevice + } + + device.context.dispatch( + DeviceActions.Update { copy(firmware = firmware, protocolVersion = protocolVersion) }, ) - server.context.dispatch(VRServerActions.NewTracker(trackerId, tracker)) } if (msg.position != null) {