- Fix streaming STT to accept audio chunks from gRPC stream instead of local microphone
- Add proper PCM audio buffer conversion for 16-bit, 16kHz, mono audio
- Add StreamingResultHandler for safe callback handling
- Properly manage streaming session state and cleanup

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
import Foundation
import Speech
import AVFoundation

// MARK: - Result Types

/// Transcription result
public struct TranscriptionResult: Sendable {
    public let text: String
    public let segments: [TranscriptionSegmentResult]
    public let detectedLanguage: String
    public let confidence: Float
}

/// Individual transcription segment
public struct TranscriptionSegmentResult: Sendable {
    public let text: String
    public let startTime: Float
    public let endTime: Float
    public let confidence: Float
}

/// Streaming transcription update
public struct StreamingTranscriptionUpdate: Sendable {
    public let partialText: String
    public let isFinal: Bool
    public let finalText: String?
    public let segments: [TranscriptionSegmentResult]
}

/// Transcription configuration
public struct TranscriptionConfig: Sendable {
    public var languageCode: String?
    public var enablePunctuation: Bool
    public var enableTimestamps: Bool

    public static let `default` = TranscriptionConfig(
        languageCode: nil,
        enablePunctuation: true,
        enableTimestamps: false
    )

    public init(
        languageCode: String? = nil,
        enablePunctuation: Bool = true,
        enableTimestamps: Bool = false
    ) {
        self.languageCode = languageCode
        self.enablePunctuation = enablePunctuation
        self.enableTimestamps = enableTimestamps
    }
}

// MARK: - Errors

public enum SpeechToTextError: Error, CustomStringConvertible, Sendable {
    case notAvailable
    case authorizationDenied
    case modelNotReady(String)
    case transcriptionFailed(String)
    case invalidAudioFormat
    case audioProcessingFailed(String)
    case unsupportedMimeType(String)

    public var description: String {
        switch self {
        case .notAvailable: return "Speech recognition not available on this system"
        case .authorizationDenied: return "Speech recognition authorization denied"
        case .modelNotReady(let reason): return "Speech model not ready: \(reason)"
        case .transcriptionFailed(let reason): return "Transcription failed: \(reason)"
        case .invalidAudioFormat: return "Invalid audio format"
        case .audioProcessingFailed(let reason): return "Audio processing failed: \(reason)"
        case .unsupportedMimeType(let type): return "Unsupported audio MIME type: \(type)"
        }
    }
}

// MARK: - Service Actor

public actor SpeechToTextService {

    /// Service availability status
    public private(set) var isAvailable: Bool = false

    /// Streaming session state
    private var isStreamingActive: Bool = false
    private var streamingRequest: SFSpeechAudioBufferRecognitionRequest?
    private var streamingRecognizer: SFSpeechRecognizer?
    private var streamingTask: SFSpeechRecognitionTask?
    private var streamingContinuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation?

    public init() async {
        await checkAvailability()
    }

    // MARK: - Public API

    /// Transcribe audio data (file-based)
    public func transcribe(
        audioData: Data,
        mimeType: String,
        config: TranscriptionConfig = .default
    ) async throws -> TranscriptionResult {
        guard isAvailable else {
            throw SpeechToTextError.notAvailable
        }

        // Write the audio data to a temporary file so the URL-based recognizer can read it
        let tempURL = try createTempAudioFile(data: audioData, mimeType: mimeType)
        defer { try? FileManager.default.removeItem(at: tempURL) }

        return try await transcribeWithSFSpeechRecognizer(url: tempURL, config: config)
    }
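
    // Usage sketch for file-based transcription (illustrative caller;
    // `wavData` is a hypothetical stand-in for real audio bytes, not part
    // of this API):
    //
    //     let service = await SpeechToTextService()
    //     let result = try await service.transcribe(
    //         audioData: wavData,
    //         mimeType: "audio/wav",
    //         config: TranscriptionConfig(languageCode: "en-US", enableTimestamps: true)
    //     )
    //     print(result.text, result.confidence)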

    /// Stream transcription from audio chunks sent via gRPC
    public func streamTranscribe(
        config: TranscriptionConfig = .default
    ) -> AsyncThrowingStream<StreamingTranscriptionUpdate, Error> {
        AsyncThrowingStream { continuation in
            Task {
                guard await self.isAvailable else {
                    continuation.finish(throwing: SpeechToTextError.notAvailable)
                    return
                }

                do {
                    try await self.startStreamingSession(config: config, continuation: continuation)
                } catch {
                    continuation.finish(throwing: error)
                }
            }
        }
    }
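
    // Streaming usage sketch (assumes a hypothetical `grpcChunks` async
    // sequence yielding Data frames of 16-bit/16 kHz/mono PCM; the feeding
    // loop and the update consumer run concurrently):
    //
    //     let updates = await service.streamTranscribe()
    //     Task {
    //         for try await chunk in grpcChunks {
    //             try await service.feedAudioChunk(chunk)
    //         }
    //         await service.endStreamingSession() // signals endAudio()
    //     }
    //     for try await update in updates {
    //         print(update.isFinal ? (update.finalText ?? "") : update.partialText)
    //     }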

    /// Feed audio chunk for streaming transcription (PCM audio data)
    public func feedAudioChunk(_ chunk: Data) async throws {
        guard isStreamingActive, let request = streamingRequest else {
            throw SpeechToTextError.transcriptionFailed("No active streaming session")
        }

        // Convert raw PCM data to an audio buffer.
        // Assumes 16-bit signed PCM, mono, 16 kHz (a common format for speech).
        let audioFormat = AVAudioFormat(
            commonFormat: .pcmFormatInt16,
            sampleRate: 16000,
            channels: 1,
            interleaved: true
        )!

        let frameCount = UInt32(chunk.count / 2) // 2 bytes per Int16 sample
        guard frameCount > 0 else { return } // ignore empty chunks rather than failing
        guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFormat, frameCapacity: frameCount) else {
            throw SpeechToTextError.audioProcessingFailed("Failed to create audio buffer")
        }

        buffer.frameLength = frameCount

        // Copy the samples into the buffer
        chunk.withUnsafeBytes { rawPtr in
            if let int16Ptr = rawPtr.baseAddress?.assumingMemoryBound(to: Int16.self) {
                buffer.int16ChannelData?[0].update(from: int16Ptr, count: Int(frameCount))
            }
        }

        request.append(buffer)
    }
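
    // Sketch of producing a compatible chunk on the sender side (assumes the
    // samples are already 16 kHz mono Int16 in host byte order;
    // `capturedSamples` is hypothetical, and resampling is out of scope):
    //
    //     let samples: [Int16] = capturedSamples
    //     let chunk = samples.withUnsafeBufferPointer { Data(buffer: $0) }
    //     try await service.feedAudioChunk(chunk)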

    /// End streaming session
    public func endStreamingSession() async {
        streamingRequest?.endAudio()
        isStreamingActive = false
        streamingRequest = nil
        streamingTask = nil
        streamingRecognizer = nil
        streamingContinuation = nil
    }

    /// Get status information
    public func getStatus() -> String {
        if isAvailable {
            return "SFSpeechRecognizer available"
        } else {
            return "Speech recognition not available"
        }
    }

    // MARK: - Private Implementation

    private func checkAvailability() async {
        // Check SFSpeechRecognizer authorization, requesting it if not yet determined
        let status = SFSpeechRecognizer.authorizationStatus()
        switch status {
        case .authorized:
            isAvailable = !SFSpeechRecognizer.supportedLocales().isEmpty
        case .notDetermined:
            // Request authorization
            isAvailable = await withCheckedContinuation { continuation in
                SFSpeechRecognizer.requestAuthorization { newStatus in
                    continuation.resume(returning: newStatus == .authorized)
                }
            }
        default:
            isAvailable = false
        }
    }
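
    // Note (deployment assumption, not enforced by this code): the host app
    // is expected to declare NSSpeechRecognitionUsageDescription in its
    // Info.plist; without that key, the authorization request cannot succeed.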

    /// Create temporary audio file from data
    private func createTempAudioFile(data: Data, mimeType: String) throws -> URL {
        let ext = extensionForMimeType(mimeType)
        let tempDir = FileManager.default.temporaryDirectory
        let fileName = UUID().uuidString + "." + ext
        let fileURL = tempDir.appendingPathComponent(fileName)

        try data.write(to: fileURL)
        return fileURL
    }

    /// Get file extension for MIME type
    private func extensionForMimeType(_ mimeType: String) -> String {
        switch mimeType.lowercased() {
        case "audio/wav", "audio/wave", "audio/x-wav":
            return "wav"
        case "audio/mp3", "audio/mpeg":
            return "mp3"
        case "audio/m4a", "audio/mp4", "audio/x-m4a":
            return "m4a"
        case "audio/aac":
            return "aac"
        case "audio/flac":
            return "flac"
        default:
            // Fall back to .wav for unknown types instead of throwing .unsupportedMimeType
            return "wav"
        }
    }

    /// Transcribe using SFSpeechRecognizer
    private func transcribeWithSFSpeechRecognizer(
        url: URL,
        config: TranscriptionConfig
    ) async throws -> TranscriptionResult {
        let locale = Locale(identifier: config.languageCode ?? "en-US")
        guard let recognizer = SFSpeechRecognizer(locale: locale) else {
            throw SpeechToTextError.notAvailable
        }

        guard recognizer.isAvailable else {
            throw SpeechToTextError.notAvailable
        }

        let request = SFSpeechURLRecognitionRequest(url: url)
        request.shouldReportPartialResults = false

        return try await withCheckedThrowingContinuation { continuation in
            var hasResumed = false

            recognizer.recognitionTask(with: request) { result, error in
                guard !hasResumed else { return }

                if let error = error {
                    hasResumed = true
                    continuation.resume(throwing: SpeechToTextError.transcriptionFailed(error.localizedDescription))
                    return
                }

                guard let result = result, result.isFinal else { return }

                hasResumed = true

                let transcription = result.bestTranscription
                var segments: [TranscriptionSegmentResult] = []

                if config.enableTimestamps {
                    for segment in transcription.segments {
                        segments.append(TranscriptionSegmentResult(
                            text: segment.substring,
                            startTime: Float(segment.timestamp),
                            endTime: Float(segment.timestamp + segment.duration),
                            confidence: segment.confidence
                        ))
                    }
                }

                let transcriptionResult = TranscriptionResult(
                    text: transcription.formattedString,
                    segments: segments,
                    detectedLanguage: config.languageCode ?? "en-US",
                    confidence: segments.isEmpty ? 1.0 : segments.reduce(0) { $0 + $1.confidence } / Float(segments.count)
                )

                continuation.resume(returning: transcriptionResult)
            }
        }
    }

    /// Start streaming session for gRPC audio chunks
    private func startStreamingSession(
        config: TranscriptionConfig,
        continuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation
    ) async throws {
        let locale = Locale(identifier: config.languageCode ?? "en-US")
        guard let recognizer = SFSpeechRecognizer(locale: locale) else {
            throw SpeechToTextError.notAvailable
        }

        guard recognizer.isAvailable else {
            throw SpeechToTextError.notAvailable
        }

        // Set up streaming state
        isStreamingActive = true
        streamingRecognizer = recognizer
        streamingContinuation = continuation

        let request = SFSpeechAudioBufferRecognitionRequest()
        request.shouldReportPartialResults = true
        streamingRequest = request

        // Create a wrapper to handle results safely
        let service = self
        let resultHandler = StreamingResultHandler(
            config: config,
            continuation: continuation,
            onFinish: {
                Task { await service.endStreamingSession() }
            }
        )

        streamingTask = recognizer.recognitionTask(with: request) { result, error in
            resultHandler.handleResult(result: result, error: error)
        }
    }
}

// MARK: - Streaming Result Handler

/// Wrapper to safely handle streaming recognition results
private final class StreamingResultHandler: @unchecked Sendable {
    private let config: TranscriptionConfig
    private let continuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation
    private let onFinish: () -> Void

    init(
        config: TranscriptionConfig,
        continuation: AsyncThrowingStream<StreamingTranscriptionUpdate, Error>.Continuation,
        onFinish: @escaping () -> Void
    ) {
        self.config = config
        self.continuation = continuation
        self.onFinish = onFinish
    }

    func handleResult(result: SFSpeechRecognitionResult?, error: Error?) {
        if let error = error {
            continuation.finish(throwing: SpeechToTextError.transcriptionFailed(error.localizedDescription))
            onFinish()
            return
        }

        guard let result = result else { return }

        let transcription = result.bestTranscription
        var segments: [TranscriptionSegmentResult] = []

        if config.enableTimestamps {
            for segment in transcription.segments {
                segments.append(TranscriptionSegmentResult(
                    text: segment.substring,
                    startTime: Float(segment.timestamp),
                    endTime: Float(segment.timestamp + segment.duration),
                    confidence: segment.confidence
                ))
            }
        }

        let update = StreamingTranscriptionUpdate(
            partialText: transcription.formattedString,
            isFinal: result.isFinal,
            finalText: result.isFinal ? transcription.formattedString : nil,
            segments: segments
        )
        continuation.yield(update)

        if result.isFinal {
            continuation.finish()
            onFinish()
        }
    }
}