-
Notifications
You must be signed in to change notification settings - Fork 352
[Flutter SDK] Add configuration validation for LLM, STT, and TTS (#450) #456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e98124d
15138c5
7154800
1f85521
40829dd
01d9b3c
616c961
418f015
115e330
8804c3d
875a9f7
86851fe
1fc42e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import 'package:runanywhere/core/protocols/component/component_configuration.dart'; | ||
| import 'package:runanywhere/core/types/model_types.dart'; | ||
| import 'package:runanywhere/foundation/error_types/sdk_error.dart'; | ||
|
|
||
/// Configuration for the LLM component.
///
/// Mirrors the validation contract used by the Swift and Kotlin SDKs so
/// invalid parameters fail in Dart before crossing the FFI boundary.
class LLMConfiguration implements ComponentConfiguration {
  /// Identifier of the model to run; null selects the SDK default.
  final String? modelId;

  /// Inference framework to prefer when more than one can serve the model.
  final InferenceFramework? preferredFramework;

  /// Size of the model's context window, in tokens.
  final int contextLength;

  /// Sampling temperature; 0 is deterministic, larger values are more random.
  final double temperature;

  /// Upper bound on the number of tokens generated per request.
  final int maxTokens;

  /// Optional system prompt prepended to every conversation.
  final String? systemPrompt;

  /// Whether responses are streamed token-by-token.
  final bool streamingEnabled;

  const LLMConfiguration({
    this.modelId,
    this.preferredFramework,
    this.contextLength = 2048,
    this.temperature = 0.7,
    this.maxTokens = 100,
    this.systemPrompt,
    this.streamingEnabled = true,
  });

  /// Checks every numeric parameter and throws [SDKError.validationFailed]
  /// for the first one that is out of range.
  @override
  void validate() {
    // The context window must be a positive token count of at most 32768.
    final contextOk = contextLength >= 1 && contextLength <= 32768;
    if (!contextOk) {
      throw SDKError.validationFailed(
        'Context length must be between 1 and 32768',
      );
    }

    // isFinite rejects NaN and infinities before the range comparison.
    final temperatureOk =
        temperature.isFinite && temperature >= 0 && temperature <= 2.0;
    if (!temperatureOk) {
      throw SDKError.validationFailed(
        'Temperature must be between 0 and 2.0',
      );
    }

    // Generation length can never exceed the context window.
    final maxTokensOk = maxTokens >= 1 && maxTokens <= contextLength;
    if (!maxTokensOk) {
      throw SDKError.validationFailed(
        'Max tokens must be between 1 and context length',
      );
    }
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| import 'package:runanywhere/core/protocols/component/component_configuration.dart'; | ||
| import 'package:runanywhere/core/types/model_types.dart'; | ||
| import 'package:runanywhere/foundation/error_types/sdk_error.dart'; | ||
|
|
||
/// Configuration for the STT component.
///
/// Mirrors the validation contract used by the Swift and Kotlin SDKs so
/// invalid parameters fail in Dart before crossing the FFI boundary.
class STTConfiguration implements ComponentConfiguration {
  /// Identifier of the model to run; null selects the SDK default.
  final String? modelId;

  /// Inference framework to prefer when more than one can serve the model.
  final InferenceFramework? preferredFramework;

  /// BCP-47 language tag for transcription (e.g. 'en-US').
  final String language;

  /// Input audio sample rate in Hz.
  final int sampleRate;

  /// Whether punctuation is inserted into transcripts.
  final bool enablePunctuation;

  /// Whether speaker diarization is performed.
  final bool enableDiarization;

  /// Extra vocabulary terms to bias recognition toward.
  final List<String> vocabularyList;

  /// How many alternative transcriptions to produce per utterance.
  final int maxAlternatives;

  /// Whether word-level timestamps are emitted.
  final bool enableTimestamps;

  const STTConfiguration({
    this.modelId,
    this.preferredFramework,
    this.language = 'en-US',
    this.sampleRate = 16000,
    this.enablePunctuation = true,
    this.enableDiarization = false,
    this.vocabularyList = const <String>[],
    this.maxAlternatives = 1,
    this.enableTimestamps = true,
  });

  /// Checks every numeric parameter and throws [SDKError.validationFailed]
  /// for the first one that is out of range.
  @override
  void validate() {
    // Sample rate must be a positive frequency of at most 48 kHz.
    final sampleRateOk = sampleRate >= 1 && sampleRate <= 48000;
    if (!sampleRateOk) {
      throw SDKError.validationFailed(
        'Sample rate must be between 1 and 48000 Hz',
      );
    }

    // At least one alternative is required; more than 10 is rejected.
    final alternativesOk = maxAlternatives >= 1 && maxAlternatives <= 10;
    if (!alternativesOk) {
      throw SDKError.validationFailed(
        'Max alternatives must be between 1 and 10',
      );
    }
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| import 'package:runanywhere/core/protocols/component/component_configuration.dart'; | ||
| import 'package:runanywhere/foundation/error_types/sdk_error.dart'; | ||
|
|
||
/// Configuration for TTS synthesis.
class TTSConfiguration implements ComponentConfiguration {
  /// Voice identifier; 'system' selects the platform default voice.
  final String voice;

  /// BCP-47 language tag for synthesis (e.g. 'en-US').
  final String language;

  /// Speech rate multiplier; valid range is 0.5 through 2.0.
  final double speakingRate;

  /// Pitch multiplier; valid range is 0.5 through 2.0.
  final double pitch;

  /// Output volume; valid range is 0.0 (silent) through 1.0 (full).
  final double volume;

  /// Encoding of the synthesized audio (e.g. 'pcm').
  final String audioFormat;

  const TTSConfiguration({
    this.voice = 'system',
    this.language = 'en-US',
    this.speakingRate = 0.5,
    this.pitch = 1.0,
    this.volume = 1.0,
    this.audioFormat = 'pcm',
  });

  /// Checks every numeric parameter and throws [SDKError.validationFailed]
  /// for the first one that is out of range.
  @override
  void validate() {
    // isFinite rejects NaN and infinities before each range comparison.
    final rateOk =
        speakingRate.isFinite && speakingRate >= 0.5 && speakingRate <= 2.0;
    if (!rateOk) {
      throw SDKError.validationFailed(
        'Speaking rate must be between 0.5 and 2.0',
      );
    }

    final pitchOk = pitch.isFinite && pitch >= 0.5 && pitch <= 2.0;
    if (!pitchOk) {
      throw SDKError.validationFailed(
        'Pitch must be between 0.5 and 2.0',
      );
    }

    final volumeOk = volume.isFinite && volume >= 0.0 && volume <= 1.0;
    if (!volumeOk) {
      throw SDKError.validationFailed(
        'Volume must be between 0.0 and 1.0',
      );
    }
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,8 @@ import 'dart:ffi'; | |
| import 'dart:isolate'; // Keep for non-streaming generation | ||
|
|
||
| import 'package:ffi/ffi.dart'; | ||
|
|
||
| import 'package:runanywhere/features/llm/llm_configuration.dart'; | ||
| import 'package:runanywhere/foundation/error_types/sdk_error.dart'; | ||
| import 'package:runanywhere/foundation/logging/sdk_logger.dart'; | ||
| import 'package:runanywhere/native/ffi_types.dart'; | ||
| import 'package:runanywhere/native/platform_loader.dart'; | ||
|
|
@@ -48,6 +49,7 @@ class DartBridgeLLM { | |
|
|
||
| RacHandle? _handle; | ||
| String? _loadedModelId; | ||
| int? _loadedContextLength; | ||
| final _logger = SDKLogger('DartBridge.LLM'); | ||
|
|
||
| /// Active stream subscription for cancellation | ||
|
|
@@ -153,6 +155,7 @@ class DartBridgeLLM { | |
| String modelPath, | ||
| String modelId, | ||
| String modelName, | ||
| int? contextLength, | ||
| ) async { | ||
| final handle = getHandle(); | ||
|
|
||
|
|
@@ -181,6 +184,7 @@ class DartBridgeLLM { | |
| } | ||
|
|
||
| _loadedModelId = modelId; | ||
| _loadedContextLength = contextLength; | ||
| _logger.info('LLM model loaded: $modelId'); | ||
| } finally { | ||
| calloc.free(pathPtr); | ||
|
|
@@ -200,6 +204,7 @@ class DartBridgeLLM { | |
|
|
||
| cleanupFn(_handle!); | ||
| _loadedModelId = null; | ||
| _loadedContextLength = null; | ||
| _logger.info('LLM model unloaded'); | ||
| } catch (e) { | ||
| _logger.error('Failed to unload LLM model: $e'); | ||
|
|
@@ -247,6 +252,13 @@ class DartBridgeLLM { | |
| throw StateError('No LLM model loaded. Call loadModel() first.'); | ||
| } | ||
|
|
||
| _validateGenerationParameters( | ||
| contextLength: _requireLoadedContextLength(), | ||
| maxTokens: maxTokens, | ||
| temperature: temperature, | ||
| systemPrompt: systemPrompt, | ||
| ); | ||
|
|
||
| // Run FFI call in a separate isolate to avoid heap corruption | ||
| // from C++ background threads (Metal GPU operations) | ||
| final handleAddress = handle.address; | ||
|
|
@@ -290,6 +302,14 @@ class DartBridgeLLM { | |
| throw StateError('No LLM model loaded. Call loadModel() first.'); | ||
| } | ||
|
|
||
| _validateGenerationParameters( | ||
| contextLength: _requireLoadedContextLength(), | ||
| maxTokens: maxTokens, | ||
| temperature: temperature, | ||
| systemPrompt: systemPrompt, | ||
| streamingEnabled: true, | ||
| ); | ||
|
|
||
| // Create stream controller for emitting tokens to the caller | ||
| final controller = StreamController<String>(); | ||
|
|
||
|
|
@@ -367,6 +387,33 @@ class DartBridgeLLM { | |
| } | ||
| } | ||
|
|
||
| int _requireLoadedContextLength() { | ||
| final contextLength = _loadedContextLength; | ||
| if (contextLength != null && contextLength > 0) { | ||
| return contextLength; | ||
| } | ||
|
|
||
| throw SDKError.validationFailed( | ||
| 'Loaded model is missing context length metadata for maxTokens validation', | ||
| ); | ||
| } | ||
|
|
||
| void _validateGenerationParameters({ | ||
| required int contextLength, | ||
| required int maxTokens, | ||
| required double temperature, | ||
| String? systemPrompt, | ||
| bool streamingEnabled = false, | ||
| }) { | ||
| LLMConfiguration( | ||
| contextLength: contextLength, | ||
| maxTokens: maxTokens, | ||
| temperature: temperature, | ||
| systemPrompt: systemPrompt, | ||
| streamingEnabled: streamingEnabled, | ||
| ).validate(); | ||
| } | ||
|
Comment on lines
+401
to
+415
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hardcoded `contextLength` defeats `maxTokens` validation. Consider passing the actual model's context length here. Prompt To Fix With AI: this is a comment left during a code review.
Path: sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_llm.dart
Line: 383-396
Comment:
**Hardcoded `contextLength` defeats `maxTokens` validation**
The `contextLength` is hardcoded to `32768` (the maximum allowed), which means the `maxTokens <= contextLength` check in `LLMConfiguration.validate()` will never reject any value under 32768. A user could pass `maxTokens: 32768` even though the loaded model may have a much smaller context window (e.g. 2048 or 4096), sending an invalid value across the FFI boundary — exactly what this PR is trying to prevent.
Consider passing the actual model's context length here. Since `DartBridgeLLM` manages the C++ lifecycle, the real context length may be queryable from the native layer, or you could store it when the model is loaded and pass it through to validation.
How can I resolve this? If you propose a fix, please make it concise. |
||
|
|
||
| // MARK: - Cleanup | ||
|
|
||
| /// Destroy the component and release resources. | ||
|
|
@@ -380,6 +427,7 @@ class DartBridgeLLM { | |
| destroyFn(_handle!); | ||
| _handle = null; | ||
| _loadedModelId = null; | ||
| _loadedContextLength = null; | ||
| _logger.debug('LLM component destroyed'); | ||
| } catch (e) { | ||
| _logger.error('Failed to destroy LLM component: $e'); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.