diff --git a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerModels.swift b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerModels.swift index 7f882f028..51f98e6f5 100644 --- a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerModels.swift +++ b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerModels.swift @@ -85,7 +85,12 @@ public struct OfflineDiarizerModels: Sendable { logger.info("Loading offline diarization models from \(modelsDirectory.path)") let loadStart = Date() - let inferenceComputeUnits: MLComputeUnits = .all + // Honor a caller-supplied compute-units choice (the `configuration:` + // parameter was previously accepted but ignored here). Defaults to `.all`, + // so existing callers are unaffected; callers that need CPU/GPU-only + // inference (e.g. to avoid cross-device ANE numeric variance in the + // embeddings) can now pass `MLModelConfiguration(computeUnits:)`. + let inferenceComputeUnits: MLComputeUnits = configuration?.computeUnits ?? .all let segmentationAndEmbeddingNames: [String] = [ ModelNames.OfflineDiarizer.segmentationPath, diff --git a/Tests/FluidAudioTests/CI/BasicInitializationTests.swift b/Tests/FluidAudioTests/CI/BasicInitializationTests.swift index fe05d44c1..bfa0ab25e 100644 --- a/Tests/FluidAudioTests/CI/BasicInitializationTests.swift +++ b/Tests/FluidAudioTests/CI/BasicInitializationTests.swift @@ -348,6 +348,30 @@ extension CoreMLDiarizerTests { XCTAssertEqual(models.embeddingModel.configuration.computeUnits, customConfig.computeUnits) } + + /// Tests that the OFFLINE diarizer model loader honors a user-specified + /// configuration's compute units. This is parity with + /// `testModelLoadingCustomConfig` above (the streaming `DiarizerModels` + /// loader already honored it); `OfflineDiarizerModels.load(configuration:)` + /// previously accepted the parameter but ignored it and always loaded `.all`. + func testOfflineModelLoadingCustomConfig() async throws { + + XCTExpectFailure("Download might fail in CI environment", strict: false) + + let customConfig = MLModelConfiguration() + customConfig.computeUnits = .cpuOnly + + let models = try await OfflineDiarizerModels.load(configuration: customConfig) + + // Segmentation, embedding, and PLDA-rho models load with the requested + // compute units; the fbank front-end intentionally stays on `.cpuOnly`. + XCTAssertEqual( + models.segmentationModel.configuration.computeUnits, customConfig.computeUnits) + XCTAssertEqual( + models.embeddingModel.configuration.computeUnits, customConfig.computeUnits) + XCTAssertEqual( + models.pldaRhoModel.configuration.computeUnits, customConfig.computeUnits) + } } // MARK: - CoreML Backend Specific Test