Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ public struct OfflineDiarizerModels: Sendable {
logger.info("Loading offline diarization models from \(modelsDirectory.path)")

let loadStart = Date()
let inferenceComputeUnits: MLComputeUnits = .all
// Honor a caller-supplied compute-units choice (the `configuration:`
// parameter was previously accepted but ignored here). Defaults to `.all`,
// so existing callers are unaffected; callers that need CPU/GPU-only
// inference (e.g. to avoid cross-device ANE numeric variance in the
// embeddings) can now pass `MLModelConfiguration(computeUnits:)`.
let inferenceComputeUnits: MLComputeUnits = configuration?.computeUnits ?? .all

let segmentationAndEmbeddingNames: [String] = [
ModelNames.OfflineDiarizer.segmentationPath,
Expand Down
24 changes: 24 additions & 0 deletions Tests/FluidAudioTests/CI/BasicInitializationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,30 @@ extension CoreMLDiarizerTests {

XCTAssertEqual(models.embeddingModel.configuration.computeUnits, customConfig.computeUnits)
}

/// Tests that the OFFLINE diarizer model loader honors a user-specified
/// configuration's compute units. This is parity with
/// `testModelLoadingCustomConfig` above (the streaming `DiarizerModels`
/// loader already honored it); `OfflineDiarizerModels.load(configuration:)`
/// previously accepted the parameter but ignored it and always loaded `.all`.
func testOfflineModelLoadingCustomConfig() async throws {

XCTExpectFailure("Download might fail in CI environment", strict: false)

let customConfig = MLModelConfiguration()
customConfig.computeUnits = .cpuOnly

let models = try await OfflineDiarizerModels.load(configuration: customConfig)

// Segmentation, embedding, and PLDA-rho models load with the requested
// compute units; the fbank front-end intentionally stays on `.cpuOnly`.
XCTAssertEqual(
models.segmentationModel.configuration.computeUnits, customConfig.computeUnits)
XCTAssertEqual(
models.embeddingModel.configuration.computeUnits, customConfig.computeUnits)
XCTAssertEqual(
models.pldaRhoModel.configuration.computeUnits, customConfig.computeUnits)
}
}

// MARK: - CoreML Backend Specific Test
Expand Down