This commit is contained in:
loucass003
2026-03-26 17:34:49 +01:00
parent 86dbdbddf8
commit 1001b7f887
19 changed files with 189 additions and 193 deletions

View File

@@ -4,6 +4,7 @@ import kotlin.reflect.KClass
class EventDispatcher<T : Any>(private val keyOf: (T) -> KClass<*> = { it::class }) {
@Volatile var listeners: Map<KClass<*>, List<suspend (T) -> Unit>> = emptyMap()
@Volatile private var globalListeners: List<suspend (T) -> Unit> = emptyList()
fun register(key: KClass<*>, callback: suspend (T) -> Unit) {

View File

@@ -0,0 +1,147 @@
package dev.slimevr.firmware
import dev.slimevr.VRServer
import dev.slimevr.context.BasicBehaviour
import dev.slimevr.context.Context
import dev.slimevr.context.createContext
import dev.slimevr.serial.SerialServer
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import solarxr_protocol.datatypes.DeviceIdTable
import solarxr_protocol.rpc.FirmwarePart
import solarxr_protocol.rpc.FirmwareUpdateDeviceId
import solarxr_protocol.rpc.FirmwareUpdateStatus
import solarxr_protocol.rpc.SerialDevicePort
data class FirmwareJobStatus(
val portLocation: String,
val firmwareDeviceId: FirmwareUpdateDeviceId,
val status: FirmwareUpdateStatus,
val progress: Int = 0,
)
data class FirmwareManagerState(
val jobs: Map<String, FirmwareJobStatus>,
)
sealed interface FirmwareManagerActions {
data class UpdateJob(
val portLocation: String,
val firmwareDeviceId: FirmwareUpdateDeviceId,
val status: FirmwareUpdateStatus,
val progress: Int = 0,
) : FirmwareManagerActions
data class RemoveJob(val portLocation: String) : FirmwareManagerActions
}
typealias FirmwareManagerContext = Context<FirmwareManagerState, FirmwareManagerActions>
typealias FirmwareManagerBehaviour = BasicBehaviour<FirmwareManagerState, FirmwareManagerActions>
val FirmwareManagerBaseBehaviour = FirmwareManagerBehaviour(
reducer = { s, a ->
when (a) {
is FirmwareManagerActions.UpdateJob -> s.copy(
jobs = s.jobs +
(
a.portLocation to FirmwareJobStatus(
portLocation = a.portLocation,
firmwareDeviceId = a.firmwareDeviceId,
status = a.status,
progress = a.progress,
)
),
)
is FirmwareManagerActions.RemoveJob -> s.copy(jobs = s.jobs - a.portLocation)
}
},
observer = null,
)
data class FirmwareManager(
val context: FirmwareManagerContext,
val flash: suspend (portLocation: String, parts: List<FirmwarePart>, needManualReboot: Boolean, ssid: String?, password: String?, server: VRServer) -> Unit,
val otaFlash: suspend (deviceIp: String, firmwareDeviceId: FirmwareUpdateDeviceId, part: FirmwarePart, VRServer) -> Unit,
val cancelAll: suspend () -> Unit,
)
fun createFirmwareManager(
serialServer: SerialServer,
scope: CoroutineScope,
): FirmwareManager {
val behaviours = listOf(FirmwareManagerBaseBehaviour)
val context = createContext(
initialState = FirmwareManagerState(jobs = mapOf()),
reducers = behaviours.map { it.reducer },
scope = scope,
)
val runningJobs = mutableMapOf<String, Job>()
val flash: suspend (String, List<FirmwarePart>, Boolean, String?, String?, VRServer) -> Unit = { portLocation, parts, needManualReboot, ssid, password, server ->
runningJobs[portLocation]?.cancelAndJoin()
runningJobs[portLocation] = scope.launch {
doSerialFlash(
portLocation = portLocation,
parts = parts,
needManualReboot = needManualReboot,
ssid = ssid,
password = password,
serialServer = serialServer,
server = server,
onStatus = { status, progress ->
context.dispatch(
FirmwareManagerActions.UpdateJob(
portLocation = portLocation,
firmwareDeviceId = SerialDevicePort(port = portLocation),
status = status,
progress = progress,
),
)
},
scope = scope,
)
}
}
val otaFlash: suspend (String, FirmwareUpdateDeviceId, FirmwarePart, VRServer) -> Unit = { deviceIp, firmwareDeviceId, part, server ->
runningJobs[deviceIp]?.cancelAndJoin()
runningJobs[deviceIp] = scope.launch {
doOtaFlash(
deviceIp = deviceIp,
deviceId = (firmwareDeviceId as? DeviceIdTable)?.id ?: error("device id should exist"),
part = part,
server = server,
onStatus = { status, progress ->
context.dispatch(
FirmwareManagerActions.UpdateJob(
portLocation = deviceIp,
firmwareDeviceId = firmwareDeviceId,
status = status,
progress = progress,
),
)
},
)
}
}
val cancelAll: suspend () -> Unit = {
runningJobs.values.forEach { it.cancelAndJoin() }
runningJobs.clear()
}
val manager = FirmwareManager(
context = context,
flash = flash,
otaFlash = otaFlash,
cancelAll = cancelAll,
)
behaviours.map { it.observer }.forEach { it?.invoke(context) }
return manager
}

View File

@@ -31,11 +31,9 @@ private const val OTA_PORT = 8266
private const val OTA_PASSWORD = "SlimeVR-OTA"
private const val OTA_CHUNK_SIZE = 2048
private fun bytesToMd5(bytes: ByteArray): String =
MessageDigest.getInstance("MD5").digest(bytes).joinToString("") { "%02x".format(it) }
private fun bytesToMd5(bytes: ByteArray): String = MessageDigest.getInstance("MD5").digest(bytes).joinToString("") { "%02x".format(it) }
private suspend fun sendDatagram(socket: BoundDatagramSocket, message: String, target: InetSocketAddress) =
socket.send(Datagram(buildPacket { writeFully(message.toByteArray()) }, target))
private suspend fun sendDatagram(socket: BoundDatagramSocket, message: String, target: InetSocketAddress) = socket.send(Datagram(buildPacket { writeFully(message.toByteArray()) }, target))
/**
* Sends the OTA invitation over UDP and performs the optional AUTH challenge-response.
@@ -155,7 +153,7 @@ suspend fun doOtaFlash(
onStatus(FirmwareUpdateStatus.REBOOTING, 0)
// wait for the tracker with that MAC to connect to the server via UDP
// wait for the tracker with the correct id to come online
val connected = withTimeoutOrNull(60_000) {
server.context.state
.map { state -> state.devices.values.any { it.context.state.value.id.toUByte() == deviceId.id } }

View File

@@ -3,85 +3,23 @@ package dev.slimevr.firmware
import dev.llelievr.espflashkotlin.Flasher
import dev.llelievr.espflashkotlin.FlashingProgressListener
import dev.slimevr.VRServer
import dev.slimevr.context.BasicBehaviour
import dev.slimevr.context.Context
import dev.slimevr.context.createContext
import dev.slimevr.serial.SerialConnection
import dev.slimevr.serial.SerialServer
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.mapNotNull
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeoutOrNull
import solarxr_protocol.datatypes.DeviceIdTable
import solarxr_protocol.rpc.FirmwarePart
import solarxr_protocol.rpc.FirmwareUpdateDeviceId
import solarxr_protocol.rpc.FirmwareUpdateStatus
import solarxr_protocol.rpc.SerialDevicePort
private val MAC_REGEX = Regex("mac: (([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2})", RegexOption.IGNORE_CASE)
data class FirmwareJobStatus(
val portLocation: String,
val firmwareDeviceId: FirmwareUpdateDeviceId,
val status: FirmwareUpdateStatus,
val progress: Int = 0,
)
data class FirmwareManagerState(
val jobs: Map<String, FirmwareJobStatus>,
)
sealed interface FirmwareManagerActions {
data class UpdateJob(
val portLocation: String,
val firmwareDeviceId: FirmwareUpdateDeviceId,
val status: FirmwareUpdateStatus,
val progress: Int = 0,
) : FirmwareManagerActions
data class RemoveJob(val portLocation: String) : FirmwareManagerActions
}
typealias FirmwareManagerContext = Context<FirmwareManagerState, FirmwareManagerActions>
typealias FirmwareManagerBehaviour = BasicBehaviour<FirmwareManagerState, FirmwareManagerActions>
val FirmwareManagerBaseBehaviour = FirmwareManagerBehaviour(
reducer = { s, a ->
when (a) {
is FirmwareManagerActions.UpdateJob -> s.copy(
jobs = s.jobs + (a.portLocation to FirmwareJobStatus(
portLocation = a.portLocation,
firmwareDeviceId = a.firmwareDeviceId,
status = a.status,
progress = a.progress,
)),
)
is FirmwareManagerActions.RemoveJob -> s.copy(jobs = s.jobs - a.portLocation)
}
},
observer = null,
)
data class FirmwareManager(
val context: FirmwareManagerContext,
val flash: suspend (portLocation: String, parts: List<FirmwarePart>, needManualReboot: Boolean, ssid: String?, password: String?, server: VRServer) -> Unit,
val otaFlash: suspend (deviceIp: String, firmwareDeviceId: FirmwareUpdateDeviceId, part: FirmwarePart, VRServer) -> Unit,
val cancelAll: suspend () -> Unit,
)
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun doSerialFlash(
portLocation: String,
@@ -256,79 +194,3 @@ internal suspend fun doSerialFlashPostFlash(
onStatus(FirmwareUpdateStatus.DONE, 0)
}
fun createFirmwareManager(
serialServer: SerialServer,
scope: CoroutineScope,
): FirmwareManager {
val behaviours = listOf(FirmwareManagerBaseBehaviour)
val context = createContext(
initialState = FirmwareManagerState(jobs = mapOf()),
reducers = behaviours.map { it.reducer },
scope = scope,
)
val runningJobs = mutableMapOf<String, Job>()
val flash: suspend (String, List<FirmwarePart>, Boolean, String?, String?, VRServer) -> Unit = { portLocation, parts, needManualReboot, ssid, password, server ->
runningJobs[portLocation]?.cancelAndJoin()
runningJobs[portLocation] = scope.launch {
doSerialFlash(
portLocation = portLocation,
parts = parts,
needManualReboot = needManualReboot,
ssid = ssid,
password = password,
serialServer = serialServer,
server = server,
onStatus = { status, progress ->
context.dispatch(FirmwareManagerActions.UpdateJob(
portLocation = portLocation,
firmwareDeviceId = SerialDevicePort(port = portLocation),
status = status,
progress = progress,
))
},
scope = scope,
)
}
}
val otaFlash: suspend (String, FirmwareUpdateDeviceId, FirmwarePart, VRServer) -> Unit = { deviceIp, firmwareDeviceId, part, server ->
runningJobs[deviceIp]?.cancelAndJoin()
runningJobs[deviceIp] = scope.launch {
doOtaFlash(
deviceIp = deviceIp,
deviceId = (firmwareDeviceId as? DeviceIdTable)?.id ?: error("device id should exist"),
part = part,
server = server,
onStatus = { status, progress ->
context.dispatch(
FirmwareManagerActions.UpdateJob(
portLocation = deviceIp,
firmwareDeviceId = firmwareDeviceId,
status = status,
progress = progress,
),
)
},
)
}
}
val cancelAll: suspend () -> Unit = {
runningJobs.values.forEach { it.cancelAndJoin() }
runningJobs.clear()
}
val manager = FirmwareManager(
context = context,
flash = flash,
otaFlash = otaFlash,
cancelAll = cancelAll,
)
behaviours.map { it.observer }.forEach { it?.invoke(context) }
return manager
}

View File

@@ -1,11 +1,11 @@
package dev.slimevr.solarxr
import com.google.flatbuffers.FlatBufferBuilder
import dev.slimevr.EventDispatcher
import dev.slimevr.VRServer
import dev.slimevr.context.Context
import dev.slimevr.context.CustomBehaviour
import dev.slimevr.context.createContext
import dev.slimevr.EventDispatcher
import io.ktor.util.moveToByteArray
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
@@ -27,7 +27,6 @@ sealed interface SolarXRConnectionActions {
typealias SolarXRConnectionContext = Context<SolarXRConnectionState, SolarXRConnectionActions>
typealias SolarXRConnectionBehaviour = CustomBehaviour<SolarXRConnectionState, SolarXRConnectionActions, SolarXRConnection>
data class SolarXRConnection(
val context: SolarXRConnectionContext,
val serverContext: VRServer,

View File

@@ -36,10 +36,14 @@ suspend fun createSolarXRWebsocketServer(serverContext: VRServer) {
routing {
webSocket {
val solarxrConnection =
createSolarXRConnection(serverContext, scope = this, onSend = {
AppLogger.solarxr.info("[WS] New connection")
val solarxrConnection = createSolarXRConnection(
serverContext,
scope = this,
onSend = {
send(Frame.Binary(fin = true, data = it))
})
},
)
for (frame in incoming) {
when (frame) {
@@ -49,7 +53,7 @@ suspend fun createSolarXRWebsocketServer(serverContext: VRServer) {
)
is Frame.Close -> {
AppLogger.solarxr.info("Connection closed")
AppLogger.solarxr.info("[WS] Connection closed")
}
else -> {}

View File

@@ -66,7 +66,7 @@ fun createDevice(
mcuType: McuType,
firmware: String? = null,
protocolVersion: Int,
serverContext: VRServer
serverContext: VRServer,
): Device {
val deviceState = DeviceState(
id = id,

View File

@@ -57,8 +57,6 @@ sealed interface UDPConnectionActions {
typealias UDPConnectionContext = Context<UDPConnectionState, UDPConnectionActions>
typealias UDPConnectionBehaviour = CustomBehaviour<UDPConnectionState, UDPConnectionActions, UDPConnection>
val PacketBehaviour = UDPConnectionBehaviour(
reducer = { s, a ->
when (a) {
@@ -175,7 +173,7 @@ val HandshakeBehaviour = UDPConnectionBehaviour(
val newDevice = createDevice(
id = deviceId,
scope = it.serverContext.context.scope,
address = it.context.state.value.address,
address = it.context.state.value.address,
macAddress = packet.data.macString,
boardType = packet.data.boardType,
protocolVersion = packet.data.protocolVersion,
@@ -299,13 +297,12 @@ val SensorRotationBehaviour = UDPConnectionBehaviour(
},
)
data class UDPConnection(
val context: UDPConnectionContext,
val serverContext: VRServer,
val packetEvents: UDPPacketDispatcher,
val packetChannel: Channel<PacketEvent<UDPPacket>>,
val send: (UDPPacket) -> Unit
val send: (UDPPacket) -> Unit,
) {
fun getDevice(): Device? {
val deviceId = context.state.value.deviceId
@@ -391,5 +388,3 @@ data class UDPConnection(
}
}
}

View File

@@ -90,7 +90,7 @@ data class Handshake(
companion object {
fun read(src: Source): Handshake = with(src) {
if (remaining == 0L) return Handshake()
val b = if (remaining >= 4) BoardType.fromValue(readInt().toUShort()) ?: BoardType.UNKNOWN else BoardType.UNKNOWN
val b = if (remaining >= 4) BoardType.fromValue(readInt().toUShort()) ?: BoardType.UNKNOWN else BoardType.UNKNOWN
val i = if (remaining >= 4) readInt() else 0
val m = if (remaining >= 4) McuType.fromValue(readInt().toUShort()) ?: McuType.Other else McuType.Other
if (remaining >= 12) {
@@ -304,49 +304,27 @@ data class ProtocolChange(val targetProtocol: Int = 0, val targetVersion: Int =
fun readPacket(type: PacketType, src: Source): UDPPacket = when (type) {
PacketType.HEARTBEAT -> Heartbeat
PacketType.HANDSHAKE -> Handshake.read(src)
PacketType.ROTATION -> Rotation.read(src)
PacketType.ACCEL -> Accel.read(src)
PacketType.PING_PONG -> PingPong.read(src)
PacketType.SERIAL -> Serial.read(src)
PacketType.BATTERY_LEVEL -> BatteryLevel.read(src)
PacketType.TAP -> Tap.read(src)
PacketType.ERROR -> ErrorPacket.read(src)
PacketType.SENSOR_INFO -> SensorInfo.read(src)
PacketType.ROTATION_2 -> Rotation2.read(src)
PacketType.ROTATION_DATA -> RotationData.read(src)
PacketType.MAGNETOMETER_ACCURACY -> MagnetometerAccuracy.read(src)
PacketType.SIGNAL_STRENGTH -> SignalStrength.read(src)
PacketType.TEMPERATURE -> Temperature.read(src)
PacketType.USER_ACTION -> UserActionPacket.read(src)
PacketType.FEATURE_FLAGS -> FeatureFlags.read(src)
PacketType.ROTATION_AND_ACCEL -> RotationAndAccel.read(src)
PacketType.ACK_CONFIG_CHANGE -> AckConfigChange.read(src)
PacketType.SET_CONFIG_FLAG -> SetConfigFlag()
PacketType.FLEX_DATA -> FlexData.read(src)
PacketType.POSITION -> PositionPacket.read(src)
PacketType.PROTOCOL_CHANGE -> ProtocolChange.read(src)
}

View File

@@ -71,8 +71,9 @@ suspend fun createUDPTrackerServer(
newContext.packetChannel.trySend(event)
}
}
if (took.inWholeMilliseconds > 2)
if (took.inWholeMilliseconds > 2) {
AppLogger.udp.warn("Packet processing took too long ${took.inWholeMilliseconds}")
}
}
}
}

View File

@@ -28,4 +28,4 @@ fun buildTestSerialServer(scope: CoroutineScope) = SerialServer.create(
fun buildTestVrServer(scope: CoroutineScope): VRServer {
val serialServer = buildTestSerialServer(scope)
return VRServer.create(scope, serialServer, createFirmwareManager(serialServer, scope))
}
}

View File

@@ -14,6 +14,8 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
import solarxr_protocol.datatypes.hardware_info.BoardType
import solarxr_protocol.datatypes.hardware_info.McuType
import solarxr_protocol.rpc.FirmwareUpdateStatus
import kotlin.test.Test
import kotlin.test.assertEquals
@@ -296,7 +298,17 @@ class DoSerialFlashTest {
delay(200)
server.onDataReceived("COM1", "looking for the server")
delay(300)
val device = createDevice(backgroundScope, vrServer.nextHandle(), address = "192.168.1.100", macAddress = "AA:BB:CC:DD:EE:FF", DeviceOrigin.UDP, vrServer)
val device = createDevice(
backgroundScope,
id = vrServer.nextHandle(),
address = "192.168.1.100",
macAddress = "AA:BB:CC:DD:EE:FF",
origin = DeviceOrigin.UDP,
protocolVersion = 0,
serverContext = vrServer,
boardType = BoardType.SLIMEVR,
mcuType = McuType.ESP8266
)
vrServer.context.dispatch(VRServerActions.NewDevice(device.context.state.value.id, device))
}

View File

@@ -13,13 +13,12 @@ import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
private fun serialJob(port: String, status: FirmwareUpdateStatus, progress: Int = 0) =
FirmwareManagerActions.UpdateJob(
portLocation = port,
firmwareDeviceId = SerialDevicePort(port = port),
status = status,
progress = progress,
)
private fun serialJob(port: String, status: FirmwareUpdateStatus, progress: Int = 0) = FirmwareManagerActions.UpdateJob(
portLocation = port,
firmwareDeviceId = SerialDevicePort(port = port),
status = status,
progress = progress,
)
class FirmwareManagerReducerTest {
private fun makeContext(scope: kotlinx.coroutines.CoroutineScope) = createContext(
@@ -86,4 +85,4 @@ class FirmwareManagerReducerTest {
assertEquals(1, context.state.value.jobs.size)
}
}
}

View File

@@ -51,4 +51,4 @@ class SerialConnectionReducerTest {
val result = reducer(state(connected = true), SerialConnectionActions.Disconnected)
assertFalse(result.connected)
}
}
}

View File

@@ -83,5 +83,5 @@ class DataFeedTest {
assertEquals(0, sendCount)
}
//TODO: need more tests for the content of a datafeed + check if the masks work
// TODO: need more tests for the content of a datafeed + check if the masks work
}