// swift-apple-intelligence-grpc/Sources/AppleIntelligenceCore/Services/SpeechToTextService.swift

import Foundation
import Speech
import AVFoundation
// MARK: - Result Types
/// Transcription result
public struct TranscriptionResult: Sendable {
public let text: String
public let segments: [TranscriptionSegmentResult]
public let detectedLanguage: String
public let confidence: Float
}
/// Individual transcription segment
public struct TranscriptionSegmentResult: Sendable {
public let text: String
public let startTime: Float
public let endTime: Float
public let confidence: Float
}
/// Streaming transcription update
public struct StreamingTranscriptionUpdate: Sendable {
public let partialText: String
public let isFinal: Bool
public let finalText: String?
public let segments: [TranscriptionSegmentResult]
}
/// Transcription configuration
public struct TranscriptionConfig: Sendable {
public var languageCode: String?
public var enablePunctuation: Bool
public var enableTimestamps: Bool
public static let `default` = TranscriptionConfig(
languageCode: nil,
enablePunctuation: true,
enableTimestamps: false
)
public init(
languageCode: String? = nil,
enablePunctuation: Bool = true,
enableTimestamps: Bool = false
) {
self.languageCode = languageCode
self.enablePunctuation = enablePunctuation
self.enableTimestamps = enableTimestamps
}
}
// MARK: - Errors
public enum SpeechToTextError: Error, CustomStringConvertible, Sendable {
case notAvailable
case authorizationDenied
case modelNotReady(String)
case transcriptionFailed(String)
case invalidAudioFormat
case audioProcessingFailed(String)
case unsupportedMimeType(String)
public var description: String {
switch self {
case .notAvailable: return "Speech recognition not available on this system"
case .authorizationDenied: return "Speech recognition authorization denied"
case .modelNotReady(let reason): return "Speech model not ready: \(reason)"
case .transcriptionFailed(let reason): return "Transcription failed: \(reason)"
case .invalidAudioFormat: return "Invalid audio format"
case .audioProcessingFailed(let reason): return "Audio processing failed: \(reason)"
case .unsupportedMimeType(let type): return "Unsupported audio MIME type: \(type)"
}
}
}
// MARK: - Service Actor
public actor SpeechToTextService {
/// Service availability status
public private(set) var isAvailable: Bool = false
/// Streaming session state
private var isStreamingActive: Bool = false
private var streamingRequest: SFSpeechAudioBufferRecognitionRequest?
private var streamingRecognizer: SFSpeechRecognizer?
private var streamingTask: SFSpeechRecognitionTask?
private var streamingContinuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation?
public init() async {
await checkAvailability()
}
// MARK: - Public API
/// Transcribe audio data (file-based)
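///
/// Example (illustrative sketch; `wavData` is assumed to be a complete WAV payload
/// received from the gRPC client):
/// ```swift
/// let service = await SpeechToTextService()
/// let result = try await service.transcribe(
///     audioData: wavData,
///     mimeType: "audio/wav",
///     config: TranscriptionConfig(languageCode: "en-US", enableTimestamps: true)
/// )
/// print(result.text, result.confidence)
/// ```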
public func transcribe(
audioData: Data,
mimeType: String,
config: TranscriptionConfig = .default
) async throws -> TranscriptionResult {
guard isAvailable else {
throw SpeechToTextError.notAvailable
}
// Convert audio data to file URL for processing
let tempURL = try createTempAudioFile(data: audioData, mimeType: mimeType)
defer { try? FileManager.default.removeItem(at: tempURL) }
return try await transcribeWithSFSpeechRecognizer(url: tempURL, config: config)
}
/// Start a streaming transcription session for audio received from the gRPC client.
/// Feed PCM chunks with `feedAudioChunk(_:)` and close the session with `endStreamingSession()`.
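///
/// Example (sketch of how a caller might consume the update stream; `service` is illustrative):
/// ```swift
/// let updates = await service.streamTranscribe(config: .default)
/// for try await update in updates {
///     if update.isFinal {
///         print("final:", update.finalText ?? "")
///     } else {
///         print("partial:", update.partialText)
///     }
/// }
/// ```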
public func streamTranscribe(
config: TranscriptionConfig = .default
) -> AsyncThrowingStream<StreamingTranscriptionUpdate, Error> {
AsyncThrowingStream { continuation in
// Tear down the streaming session when the consumer cancels or the stream terminates.
continuation.onTermination = { _ in
Task { await self.endStreamingSession() }
}
Task {
guard await self.isAvailable else {
continuation.finish(throwing: SpeechToTextError.notAvailable)
return
}
do {
try await self.startStreamingSession(config: config, continuation: continuation)
} catch {
continuation.finish(throwing: error)
}
}
}
}
/// Feed a chunk of raw PCM audio (16-bit signed, 16 kHz, mono) to the active streaming session
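///
/// Example (sketch; `floatSamples` is assumed to be mono 16 kHz `Float` samples in -1...1,
/// e.g. decoded from the gRPC request):
/// ```swift
/// let int16Samples = floatSamples.map { Int16(max(-1, min(1, $0)) * Float(Int16.max)) }
/// let chunk = int16Samples.withUnsafeBufferPointer { Data(buffer: $0) }
/// try await service.feedAudioChunk(chunk)
/// ```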
public func feedAudioChunk(_ chunk: Data) async throws {
guard isStreamingActive, let request = streamingRequest else {
throw SpeechToTextError.transcriptionFailed("No active streaming session")
}
// Convert raw PCM data to an audio buffer.
// The stream is assumed to carry 16-bit signed PCM, mono, at 16 kHz (a common speech format).
guard let audioFormat = AVAudioFormat(
commonFormat: .pcmFormatInt16,
sampleRate: 16000,
channels: 1,
interleaved: true
) else {
throw SpeechToTextError.audioProcessingFailed("Failed to create audio format")
}
let frameCount = UInt32(chunk.count / 2) // 2 bytes per Int16 sample
guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFormat, frameCapacity: frameCount) else {
throw SpeechToTextError.audioProcessingFailed("Failed to create audio buffer")
}
buffer.frameLength = frameCount
// Copy data into buffer
chunk.withUnsafeBytes { rawPtr in
if let int16Ptr = rawPtr.baseAddress?.assumingMemoryBound(to: Int16.self) {
buffer.int16ChannelData?[0].update(from: int16Ptr, count: Int(frameCount))
}
}
request.append(buffer)
}
/// End the active streaming session. Signalling `endAudio()` lets the recognizer deliver
/// its final result, which finishes the update stream.
public func endStreamingSession() async {
streamingRequest?.endAudio()
isStreamingActive = false
streamingRequest = nil
streamingTask = nil
streamingRecognizer = nil
streamingContinuation = nil
}
/// Get status information
public func getStatus() -> String {
if isAvailable {
return "SFSpeechRecognizer available"
} else {
return "Speech recognition not available"
}
}
// MARK: - Private Implementation
private func checkAvailability() async {
// Check SFSpeechRecognizer availability
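// Note: requesting authorization requires the host app's Info.plist to include
// an NSSpeechRecognitionUsageDescription entry.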
let status = SFSpeechRecognizer.authorizationStatus()
switch status {
case .authorized:
isAvailable = !SFSpeechRecognizer.supportedLocales().isEmpty
case .notDetermined:
// Request authorization
isAvailable = await withCheckedContinuation { continuation in
SFSpeechRecognizer.requestAuthorization { newStatus in
continuation.resume(returning: newStatus == .authorized)
}
}
default:
isAvailable = false
}
}
/// Create temporary audio file from data
private func createTempAudioFile(data: Data, mimeType: String) throws -> URL {
let ext = extensionForMimeType(mimeType)
let tempDir = FileManager.default.temporaryDirectory
let fileName = UUID().uuidString + "." + ext
let fileURL = tempDir.appendingPathComponent(fileName)
try data.write(to: fileURL)
return fileURL
}
/// Get file extension for MIME type
private func extensionForMimeType(_ mimeType: String) -> String {
switch mimeType.lowercased() {
case "audio/wav", "audio/wave", "audio/x-wav":
return "wav"
case "audio/mp3", "audio/mpeg":
return "mp3"
case "audio/m4a", "audio/mp4", "audio/x-m4a":
return "m4a"
case "audio/aac":
return "aac"
case "audio/flac":
return "flac"
default:
return "wav"
}
}
/// Transcribe using SFSpeechRecognizer
private func transcribeWithSFSpeechRecognizer(
url: URL,
config: TranscriptionConfig
) async throws -> TranscriptionResult {
let locale = Locale(identifier: config.languageCode ?? "en-US")
guard let recognizer = SFSpeechRecognizer(locale: locale) else {
throw SpeechToTextError.notAvailable
}
guard recognizer.isAvailable else {
throw SpeechToTextError.notAvailable
}
let request = SFSpeechURLRecognitionRequest(url: url)
request.shouldReportPartialResults = false
return try await withCheckedThrowingContinuation { continuation in
var hasResumed = false
recognizer.recognitionTask(with: request) { result, error in
guard !hasResumed else { return }
if let error = error {
hasResumed = true
continuation.resume(throwing: SpeechToTextError.transcriptionFailed(error.localizedDescription))
return
}
guard let result = result, result.isFinal else { return }
hasResumed = true
let transcription = result.bestTranscription
var segments: [TranscriptionSegmentResult] = []
if config.enableTimestamps {
for segment in transcription.segments {
segments.append(TranscriptionSegmentResult(
text: segment.substring,
startTime: Float(segment.timestamp),
endTime: Float(segment.timestamp + segment.duration),
confidence: segment.confidence
))
}
}
let transcriptionResult = TranscriptionResult(
text: transcription.formattedString,
segments: segments,
detectedLanguage: config.languageCode ?? "en-US",
confidence: segments.isEmpty ? 1.0 : segments.reduce(0) { $0 + $1.confidence } / Float(segments.count)
)
continuation.resume(returning: transcriptionResult)
}
}
}
/// Start streaming session for gRPC audio chunks
private func startStreamingSession(
config: TranscriptionConfig,
continuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation
) async throws {
let locale = Locale(identifier: config.languageCode ?? "en-US")
guard let recognizer = SFSpeechRecognizer(locale: locale) else {
throw SpeechToTextError.notAvailable
}
guard recognizer.isAvailable else {
throw SpeechToTextError.notAvailable
}
// Set up streaming state
isStreamingActive = true
streamingRecognizer = recognizer
streamingContinuation = continuation
let request = SFSpeechAudioBufferRecognitionRequest()
request.shouldReportPartialResults = true
streamingRequest = request
// Create wrapper to handle results safely
let service = self
let resultHandler = StreamingResultHandler(
config: config,
continuation: continuation,
onFinish: {
Task { await service.endStreamingSession() }
}
)
streamingTask = recognizer.recognitionTask(with: request) { result, error in
resultHandler.handleResult(result: result, error: error)
}
}
}
// MARK: - Streaming Result Handler
/// Forwards recognition callbacks (delivered on an arbitrary queue) to the async update stream.
/// Marked `@unchecked Sendable` because its stored properties are immutable.
private final class StreamingResultHandler: @unchecked Sendable {
private let config: TranscriptionConfig
private let continuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation
private let onFinish: () -> Void
init(
config: TranscriptionConfig,
continuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation,
onFinish: @escaping () -> Void
) {
self.config = config
self.continuation = continuation
self.onFinish = onFinish
}
func handleResult(result: SFSpeechRecognitionResult?, error: Error?) {
if let error = error {
continuation.finish(throwing: SpeechToTextError.transcriptionFailed(error.localizedDescription))
onFinish()
return
}
guard let result = result else { return }
let transcription = result.bestTranscription
var segments: [TranscriptionSegmentResult] = []
if config.enableTimestamps {
for segment in transcription.segments {
segments.append(TranscriptionSegmentResult(
text: segment.substring,
startTime: Float(segment.timestamp),
endTime: Float(segment.timestamp + segment.duration),
confidence: segment.confidence
))
}
}
let update = StreamingTranscriptionUpdate(
partialText: transcription.formattedString,
isFinal: result.isFinal,
finalText: result.isFinal ? transcription.formattedString : nil,
segments: segments
)
continuation.yield(update)
if result.isFinal {
continuation.finish()
onFinish()
}
}
}
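// MARK: - Usage Sketch
//
// Illustrative outline of how a gRPC handler could drive a streaming session.
// `requestStream`, `responseStream`, and `message.audioChunk` are assumptions about the
// surrounding gRPC code, not part of this module:
//
//   let service = await SpeechToTextService()
//   let updates = await service.streamTranscribe(config: .default)
//
//   // Pump incoming audio chunks into the recognizer, then signal end of audio.
//   let pump = Task {
//       for try await message in requestStream {
//           try await service.feedAudioChunk(message.audioChunk)
//       }
//       await service.endStreamingSession()
//   }
//
//   // Relay partial and final transcription updates back to the client.
//   for try await update in updates {
//       try await responseStream.send(update)
//   }
//   try await pump.value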