Commit c6a40a7
feat(llm): add provider-agnostic parameter mapping system
Implements flexible LLM parameter transformation to support provider-specific naming conventions (e.g., max_tokens -> max_new_tokens for HuggingFace)
1 parent 89225dc commit c6a40a7

File tree: 7 files changed, +749 -8 lines
nemoguardrails/actions/llm/utils.py (18 additions, 3 deletions)

@@ -30,6 +30,7 @@
     tool_calls_var,
 )
 from nemoguardrails.integrations.langchain.message_utils import dicts_to_messages
+from nemoguardrails.llm.parameter_mapping import get_llm_provider, transform_llm_params
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo

@@ -97,9 +98,23 @@ async def llm_call(
     _setup_llm_call_info(llm, model_name, model_provider)
     all_callbacks = _prepare_callbacks(custom_callback_handlers)

-    generation_llm: Union[BaseLanguageModel, Runnable] = (
-        llm.bind(stop=stop, **llm_params) if llm_params and llm is not None else llm
-    )
+    if llm_params or stop:
+        params_to_transform = llm_params.copy() if llm_params else {}
+        if stop is not None:
+            params_to_transform["stop"] = stop
+
+        inferred_model_name = model_name or _infer_model_name(llm)
+        inferred_provider = model_provider or get_llm_provider(llm)
+        transformed_params = transform_llm_params(
+            params_to_transform,
+            provider=inferred_provider,
+            model_name=inferred_model_name,
+        )
+        generation_llm: Union[BaseLanguageModel, Runnable] = llm.bind(
+            **transformed_params
+        )
+    else:
+        generation_llm: Union[BaseLanguageModel, Runnable] = llm

     if isinstance(prompt, str):
         response = await _invoke_with_string_prompt(
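For context, a minimal sketch of what this buys callers of llm_call: internal code keeps passing the canonical max_tokens, and the mapping layer renames it before llm.bind() is invoked. The model name "gpt2" below is an arbitrary placeholder, not taken from the diff.

from nemoguardrails.llm.parameter_mapping import transform_llm_params

# Canonical parameter names, as used throughout nemoguardrails.
params = {"max_tokens": 100, "temperature": 0.7, "stop": ["\n"]}

# The built-in HuggingFace mapping renames max_tokens to max_new_tokens;
# unmapped parameters pass through unchanged.
print(transform_llm_params(params, provider="huggingface", model_name="gpt2"))
# {'max_new_tokens': 100, 'temperature': 0.7, 'stop': ['\n']}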
nemoguardrails/llm/parameter_mapping.py (new file, 187 additions)

# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for transforming LLM parameters between internal and provider-specific formats."""

import logging
from typing import Any, Dict, Optional

from langchain.base_language import BaseLanguageModel

log = logging.getLogger(__name__)

_llm_parameter_mappings = {}

PROVIDER_PARAMETER_MAPPINGS = {
    "huggingface": {
        "max_tokens": "max_new_tokens",
    },
    "google_vertexai": {
        "max_tokens": "max_output_tokens",
    },
}


def register_llm_parameter_mapping(
    provider: str, model_name: str, parameter_mapping: Dict[str, Optional[str]]
) -> None:
    """Register a parameter mapping for a specific provider and model combination.

    Args:
        provider: The LLM provider name
        model_name: The model name
        parameter_mapping: The parameter mapping dictionary
    """
    key = (provider, model_name)
    _llm_parameter_mappings[key] = parameter_mapping
    log.debug("Registered parameter mapping for %s/%s", provider, model_name)


def get_llm_parameter_mapping(
    provider: str, model_name: str
) -> Optional[Dict[str, Optional[str]]]:
    """Get the registered parameter mapping for a provider and model combination.

    Args:
        provider: The LLM provider name
        model_name: The model name

    Returns:
        The parameter mapping if registered, None otherwise
    """
    return _llm_parameter_mappings.get((provider, model_name))


def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
    """Infer the provider name from the LLM's module path.

    This function extracts the provider name from LangChain package naming conventions:
    - langchain_openai -> openai
    - langchain_anthropic -> anthropic
    - langchain_google_genai -> google_genai
    - langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
    - langchain_community.chat_models.ollama -> ollama

    Args:
        llm: The LLM instance

    Returns:
        The inferred provider name, or None if it cannot be determined
    """
    module = type(llm).__module__

    if module.startswith("langchain_"):
        package = module.split(".")[0]
        provider = package.replace("langchain_", "")

        if provider == "community":
            parts = module.split(".")
            if len(parts) >= 3:
                provider = parts[-1]
                log.debug(
                    "Inferred provider '%s' from community module %s", provider, module
                )
                return provider
        else:
            log.debug("Inferred provider '%s' from module %s", provider, module)
            return provider

    log.debug("Could not infer provider from module %s", module)
    return None


def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
    """Get the provider name for an LLM instance by inferring from the module path.

    This function extracts the provider name from LangChain package naming conventions.
    See _infer_provider_from_module for details on the inference logic.

    Args:
        llm: The LLM instance

    Returns:
        The provider name if it can be inferred, None otherwise
    """
    return _infer_provider_from_module(llm)


def transform_llm_params(
    llm_params: Dict[str, Any],
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    parameter_mapping: Optional[Dict[str, Optional[str]]] = None,
) -> Dict[str, Any]:
    """Transform LLM parameters using provider-specific or custom mappings.

    Args:
        llm_params: The original parameters dictionary
        provider: Optional provider name
        model_name: Optional model name
        parameter_mapping: Custom mapping dictionary. If None, uses built-in provider
            mappings. Each key is the internal parameter name, each value is the
            provider parameter name. If a value is None, the parameter is dropped.

    Returns:
        Transformed parameters dictionary
    """
    if not llm_params:
        return llm_params

    if parameter_mapping is not None:
        return _apply_mapping(llm_params, parameter_mapping)

    has_instance_mapping = (provider, model_name) in _llm_parameter_mappings
    has_builtin_mapping = provider in PROVIDER_PARAMETER_MAPPINGS

    if not has_instance_mapping and not has_builtin_mapping:
        return llm_params

    mapping = None
    if has_instance_mapping:
        mapping = _llm_parameter_mappings.get((provider, model_name))
        log.debug("Using registered parameter mapping for %s/%s", provider, model_name)
    if not mapping and has_builtin_mapping:
        mapping = PROVIDER_PARAMETER_MAPPINGS[provider]
        log.debug("Using built-in parameter mapping for provider: %s", provider)

    return _apply_mapping(llm_params, mapping) if mapping else llm_params


def _apply_mapping(
    llm_params: Dict[str, Any], mapping: Dict[str, Optional[str]]
) -> Dict[str, Any]:
    """Apply a parameter mapping transformation.

    Args:
        llm_params: The original parameters dictionary
        mapping: The parameter mapping dictionary

    Returns:
        Transformed parameters dictionary
    """
    transformed_params = {}

    for param_name, param_value in llm_params.items():
        if param_name in mapping:
            mapped_name = mapping[param_name]
            if mapped_name is not None:
                transformed_params[mapped_name] = param_value
                log.debug("Mapped parameter %s -> %s", param_name, mapped_name)
            else:
                log.debug("Dropped parameter %s", param_name)
        else:
            transformed_params[param_name] = param_value

    return transformed_params
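A short usage sketch of the registration path (the provider and model names are hypothetical): a mapping registered for a specific provider/model pair takes precedence over the built-in provider mappings, and a None value drops the parameter entirely.

from nemoguardrails.llm.parameter_mapping import (
    register_llm_parameter_mapping,
    transform_llm_params,
)

# Hypothetical backend that wants max_length and rejects frequency_penalty.
register_llm_parameter_mapping(
    "my_provider", "my-model", {"max_tokens": "max_length", "frequency_penalty": None}
)

params = {"max_tokens": 256, "frequency_penalty": 0.5, "temperature": 0.1}
print(transform_llm_params(params, provider="my_provider", model_name="my-model"))
# {'max_length': 256, 'temperature': 0.1}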

nemoguardrails/rails/llm/config.py (6 additions, 0 deletions)

@@ -123,6 +123,12 @@ class Model(BaseModel):
         description="Configuration parameters for reasoning LLMs.",
     )
     parameters: Dict[str, Any] = Field(default_factory=dict)
+    parameter_mapping: Optional[Dict[str, Optional[str]]] = Field(
+        default=None,
+        description="Optional parameter mapping to transform parameter names for provider-specific requirements. "
+        "Keys are internal parameter names, values are provider parameter names. "
+        "Set value to null to drop a parameter.",
+    )

     mode: Literal["chat", "text"] = Field(
         default="chat",

nemoguardrails/rails/llm/llmrails.py (24 additions, 1 deletion)

@@ -74,6 +74,7 @@
     ModelInitializationError,
     init_llm_model,
 )
+from nemoguardrails.llm.parameter_mapping import register_llm_parameter_mapping
 from nemoguardrails.logging.explain import ExplainInfo
 from nemoguardrails.logging.processing_log import compute_generation_log
 from nemoguardrails.logging.stats import LLMStats

@@ -443,11 +444,21 @@ def _init_llms(self):
         if self.llm:
             # If an LLM was provided via constructor, use it as the main LLM
             # Log a warning if a main LLM is also specified in the config
-            if any(model.type == "main" for model in self.config.models):
+            main_model = next(
+                (model for model in self.config.models if model.type == "main"), None
+            )
+            if main_model:
                 log.warning(
                     "Both an LLM was provided via constructor and a main LLM is specified in the config. "
                     "The LLM provided via constructor will be used and the main LLM from config will be ignored."
                 )
+                # Still register parameter mapping from config if available
+                if main_model.parameter_mapping and main_model.model:
+                    register_llm_parameter_mapping(
+                        main_model.engine,
+                        main_model.model,
+                        main_model.parameter_mapping,
+                    )
             self.runtime.register_action_param("llm", self.llm)

             self._configure_main_llm_streaming(self.llm)

@@ -465,6 +476,12 @@ def _init_llms(self):
                 mode="chat",
                 kwargs=kwargs,
             )
+            if main_model.parameter_mapping and main_model.model:
+                register_llm_parameter_mapping(
+                    main_model.engine,
+                    main_model.model,
+                    main_model.parameter_mapping,
+                )
             self.runtime.register_action_param("llm", self.llm)

             self._configure_main_llm_streaming(

@@ -500,6 +517,12 @@ def _init_llms(self):
                 kwargs=kwargs,
             )

+            if llm_config.parameter_mapping and llm_config.model:
+                register_llm_parameter_mapping(
+                    llm_config.engine,
+                    llm_config.model,
+                    llm_config.parameter_mapping,
+                )
             if llm_config.type == "main":
                 # If a main LLM was already injected, skip creating another
                 # one. Otherwise, create and register it.