# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for transforming LLM parameters between internal and provider-specific formats."""

import logging
from typing import Any, Dict, Optional

from langchain.base_language import BaseLanguageModel

log = logging.getLogger(__name__)

# Parameter mappings registered at runtime, keyed by (provider, model_name) tuples.
_llm_parameter_mappings = {}

# Built-in mappings for providers whose parameter names differ from the internal ones.
PROVIDER_PARAMETER_MAPPINGS = {
    "huggingface": {
        "max_tokens": "max_new_tokens",
    },
    "google_vertexai": {
        "max_tokens": "max_output_tokens",
    },
}


def register_llm_parameter_mapping(
    provider: str, model_name: str, parameter_mapping: Dict[str, Optional[str]]
) -> None:
    """Register a parameter mapping for a specific provider and model combination.

    Args:
        provider: The LLM provider name
        model_name: The model name
        parameter_mapping: The parameter mapping dictionary. Keys are internal
            parameter names; values are provider-specific names, or None to
            drop the parameter.
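
    Example:
        A minimal illustration with made-up provider and model names; any
        strings work as the lookup key:

            >>> register_llm_parameter_mapping(
            ...     "my_provider", "my-model", {"max_tokens": "max_len", "seed": None}
            ... )
            >>> get_llm_parameter_mapping("my_provider", "my-model")
            {'max_tokens': 'max_len', 'seed': None}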
    """
    key = (provider, model_name)
    _llm_parameter_mappings[key] = parameter_mapping
    log.debug("Registered parameter mapping for %s/%s", provider, model_name)


def get_llm_parameter_mapping(
    provider: str, model_name: str
) -> Optional[Dict[str, Optional[str]]]:
    """Get the registered parameter mapping for a provider and model combination.

    Args:
        provider: The LLM provider name
        model_name: The model name

    Returns:
        The parameter mapping if registered, None otherwise
    """
    return _llm_parameter_mappings.get((provider, model_name))


def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
    """Infer provider name from the LLM's module path.

    This function extracts the provider name from LangChain package naming conventions:
    - langchain_openai -> openai
    - langchain_anthropic -> anthropic
    - langchain_google_genai -> google_genai
    - langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
    - langchain_community.chat_models.ollama -> ollama

    Args:
        llm: The LLM instance

    Returns:
        The inferred provider name, or None if it cannot be determined
    """
    module = type(llm).__module__

    if module.startswith("langchain_"):
        package = module.split(".")[0]
        provider = package.replace("langchain_", "")

        if provider == "community":
            # For langchain_community, the provider is the last module segment,
            # e.g. langchain_community.chat_models.ollama -> ollama.
            parts = module.split(".")
            if len(parts) >= 3:
                provider = parts[-1]
                log.debug(
                    "Inferred provider '%s' from community module %s", provider, module
                )
                return provider
        else:
            log.debug("Inferred provider '%s' from module %s", provider, module)
            return provider

    log.debug("Could not infer provider from module %s", module)
    return None


def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
    """Get the provider name for an LLM instance by inferring from module path.

    This function extracts the provider name from LangChain package naming conventions.
    See _infer_provider_from_module for details on the inference logic.

    Args:
        llm: The LLM instance

    Returns:
        The provider name if it can be inferred, None otherwise
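
    Example:
        A sketch using a stand-in class rather than a real LangChain model,
        purely to show the module-path inference (only ``__module__`` matters
        here):

            >>> class _FakeChat:
            ...     pass
            >>> _FakeChat.__module__ = "langchain_openai.chat_models.base"
            >>> get_llm_provider(_FakeChat())
            'openai'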
    """
    return _infer_provider_from_module(llm)


def transform_llm_params(
    llm_params: Dict[str, Any],
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    parameter_mapping: Optional[Dict[str, Optional[str]]] = None,
) -> Dict[str, Any]:
    """Transform LLM parameters using provider-specific or custom mappings.

    Args:
        llm_params: The original parameters dictionary
        provider: Optional provider name, used to look up registered and built-in mappings
        model_name: Optional model name, used together with provider to look up registered mappings
        parameter_mapping: Custom mapping dictionary. If None, uses built-in provider mappings.
            Key is the internal parameter name, value is the provider parameter name.
            If the value is None, the parameter is dropped.

    Returns:
        Transformed parameters dictionary
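
    Example:
        An illustrative call (parameter values are arbitrary) that exercises
        the built-in ``huggingface`` mapping:

            >>> transform_llm_params(
            ...     {"max_tokens": 100, "temperature": 0.5}, provider="huggingface"
            ... )
            {'max_new_tokens': 100, 'temperature': 0.5}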
    """
    if not llm_params:
        return llm_params

    # An explicitly provided mapping always wins.
    if parameter_mapping is not None:
        return _apply_mapping(llm_params, parameter_mapping)

    has_instance_mapping = (provider, model_name) in _llm_parameter_mappings
    has_builtin_mapping = provider in PROVIDER_PARAMETER_MAPPINGS

    if not has_instance_mapping and not has_builtin_mapping:
        return llm_params

    # Prefer a registered (provider, model_name) mapping; fall back to the
    # built-in provider mapping when none is registered (or it is empty).
    mapping = None
    if has_instance_mapping:
        mapping = _llm_parameter_mappings.get((provider, model_name))
        log.debug("Using registered parameter mapping for %s/%s", provider, model_name)
    if not mapping and has_builtin_mapping:
        mapping = PROVIDER_PARAMETER_MAPPINGS[provider]
        log.debug("Using built-in parameter mapping for provider: %s", provider)

    return _apply_mapping(llm_params, mapping) if mapping else llm_params


def _apply_mapping(
    llm_params: Dict[str, Any], mapping: Dict[str, Optional[str]]
) -> Dict[str, Any]:
    """Apply parameter mapping transformation.

    Args:
        llm_params: The original parameters dictionary
        mapping: The parameter mapping dictionary

    Returns:
        Transformed parameters dictionary
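
    Example:
        An illustrative mapping that renames one parameter and drops another
        (mapped to ``None``):

            >>> _apply_mapping(
            ...     {"max_tokens": 64, "seed": 7}, {"max_tokens": "max_len", "seed": None}
            ... )
            {'max_len': 64}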
    """
    transformed_params = {}

    for param_name, param_value in llm_params.items():
        if param_name in mapping:
            mapped_name = mapping[param_name]
            if mapped_name is not None:
                transformed_params[mapped_name] = param_value
                log.debug("Mapped parameter %s -> %s", param_name, mapped_name)
            else:
                log.debug("Dropped parameter %s", param_name)
        else:
            # Parameters without a mapping entry pass through unchanged.
            transformed_params[param_name] = param_value

    return transformed_params