diff --git a/src/app/(dashboard)/dashboard/endpoint/EndpointPageClient.js b/src/app/(dashboard)/dashboard/endpoint/EndpointPageClient.js index 836a0ecf..f0a69129 100644 --- a/src/app/(dashboard)/dashboard/endpoint/EndpointPageClient.js +++ b/src/app/(dashboard)/dashboard/endpoint/EndpointPageClient.js @@ -2,7 +2,7 @@ import { useState, useEffect } from "react"; import PropTypes from "prop-types"; -import { Card, Button, Input, Modal, CardSkeleton, Toggle } from "@/shared/components"; +import { Card, Button, Input, Modal, CardSkeleton, Toggle, AllowedModelsInput } from "@/shared/components"; import { useCopyToClipboard } from "@/shared/hooks/useCopyToClipboard"; /* ========== CLOUD CODE — COMMENTED OUT (replaced by Tunnel) ========== @@ -24,7 +24,14 @@ export default function APIPageClient({ machineId }) { const [loading, setLoading] = useState(true); const [showAddModal, setShowAddModal] = useState(false); const [newKeyName, setNewKeyName] = useState(""); + const [newKeyAllowedModels, setNewKeyAllowedModels] = useState([]); const [createdKey, setCreatedKey] = useState(null); + const [editingKey, setEditingKey] = useState(null); + const [showEditModal, setShowEditModal] = useState(false); + const [editKeyName, setEditKeyName] = useState(""); + const [editKeyAllowedModels, setEditKeyAllowedModels] = useState([]); + const [editKeyIsActive, setEditKeyIsActive] = useState(true); + const [apiError, setApiError] = useState(null); /* ========== CLOUD STATE — COMMENTED OUT (replaced by Tunnel) ========== const [cloudEnabled, setCloudEnabled] = useState(false); @@ -328,11 +335,16 @@ export default function APIPageClient({ machineId }) { const handleCreateKey = async () => { if (!newKeyName.trim()) return; + setApiError(null); + try { const res = await fetch("/api/keys", { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ name: newKeyName }), + body: JSON.stringify({ + name: newKeyName, + allowedModels: newKeyAllowedModels.length > 0 ? newKeyAllowedModels : [], + }), }); const data = await res.json(); @@ -340,10 +352,14 @@ export default function APIPageClient({ machineId }) { setCreatedKey(data.key); await fetchData(); setNewKeyName(""); + setNewKeyAllowedModels([]); setShowAddModal(false); + } else { + setApiError(data.error || "Failed to create key"); } } catch (error) { console.log("Error creating key:", error); + setApiError("Failed to create key"); } }; @@ -381,6 +397,45 @@ export default function APIPageClient({ machineId }) { } }; + const handleEditKey = (key) => { + setEditingKey(key); + setEditKeyName(key.name); + setEditKeyAllowedModels(key.allowedModels || []); + setEditKeyIsActive(key.isActive ?? true); + setShowEditModal(true); + setApiError(null); + }; + + const handleUpdateKey = async () => { + if (!editKeyName.trim() || !editingKey) return; + + setApiError(null); + + try { + const res = await fetch(`/api/keys/${editingKey.id}`, { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + name: editKeyName, + isActive: editKeyIsActive, + allowedModels: editKeyAllowedModels, + }), + }); + const data = await res.json(); + + if (res.ok) { + await fetchData(); + setShowEditModal(false); + setEditingKey(null); + } else { + setApiError(data.error || "Failed to update key"); + } + } catch (error) { + console.log("Error updating key:", error); + setApiError("Failed to update key"); + } + }; + const maskKey = (fullKey) => { if (!fullKey) return ""; return fullKey.length > 8 ? fullKey.slice(0, 8) + "..." : fullKey; @@ -395,6 +450,33 @@ export default function APIPageClient({ machineId }) { }); }; + const getRestrictionDisplay = (allowedModels) => { + if (!allowedModels || allowedModels.length === 0) { + return ( +
+ public + All Models +
+ ); + } + + // Show first 2 patterns + count + const display = allowedModels.slice(0, 2).join(", "); + const remaining = allowedModels.length - 2; + + return ( +
+ lock + + {display} + + {remaining > 0 && ( + +{remaining} more + )} +
+ ); + }; + const [baseUrl, setBaseUrl] = useState("/v1"); // Hydration fix: Only access window on client side @@ -549,6 +631,9 @@ export default function APIPageClient({ machineId }) { +
+ {getRestrictionDisplay(key.allowedModels)} +

Created {new Date(key.createdAt).toLocaleDateString()}

@@ -571,9 +656,17 @@ export default function APIPageClient({ machineId }) { }} title={key.isActive ? "Pause key" : "Resume key"} /> + @@ -595,7 +688,10 @@ export default function APIPageClient({ machineId }) { onClose={() => { setShowAddModal(false); setNewKeyName(""); + setNewKeyAllowedModels([]); + setApiError(null); }} + size="lg" >
setNewKeyName(e.target.value)} placeholder="Production Key" + required + /> + + + + {apiError && ( +
+ error + {apiError} +
+ )} +
+
+
+ + + {/* Edit Key Modal */} + { + setShowEditModal(false); + setEditingKey(null); + setApiError(null); + }} + size="lg" + > +
+ setEditKeyName(e.target.value)} + placeholder="Production Key" + required + /> + +
+ +
+

Active

+

+ {editKeyIsActive ? "Key is active and working" : "Key is paused and won't work"} +

+
+
+ + + + {apiError && ( +
+ error + {apiError} +
+ )} + +
+ + + ))} +
+
+ + {/* Help Section */} +
+

+ Pattern Examples: +

+ +
+ + ); +} diff --git a/src/shared/components/index.js b/src/shared/components/index.js index d46db7d1..9bbc54e9 100644 --- a/src/shared/components/index.js +++ b/src/shared/components/index.js @@ -23,6 +23,7 @@ export { default as KiroOAuthWrapper } from "./KiroOAuthWrapper"; export { default as KiroSocialOAuthModal } from "./KiroSocialOAuthModal"; export { default as CursorAuthModal } from "./CursorAuthModal"; export { default as SegmentedControl } from "./SegmentedControl"; +export { default as AllowedModelsInput } from "./AllowedModelsInput"; // Layouts export * from "./layouts"; diff --git a/src/shared/utils/apiKey.js b/src/shared/utils/apiKey.js index 046abceb..02be09a3 100644 --- a/src/shared/utils/apiKey.js +++ b/src/shared/utils/apiKey.js @@ -96,3 +96,6 @@ export function isNewFormatKey(apiKey) { return parsed?.isNewFormat === true; } +// Re-export model pattern matcher utilities +export { isModelAllowed, validateAllowedModelsFormat } from './model-pattern-matcher.js'; + diff --git a/src/shared/utils/model-pattern-matcher.js b/src/shared/utils/model-pattern-matcher.js new file mode 100644 index 00000000..f983e39e --- /dev/null +++ b/src/shared/utils/model-pattern-matcher.js @@ -0,0 +1,142 @@ +/** + * Model Pattern Matcher + * Validates if a model matches any pattern in an allowlist + */ + +/** + * Normalize model string to provider/model format + * @param {string} model - Model string (may be alias or full format) + * @returns {string} Normalized format "provider/model" + */ +function normalizeModel(model) { + if (!model || typeof model !== 'string') { + return ''; + } + + // Already in provider/model format + if (model.includes('/')) { + return model.toLowerCase().trim(); + } + + // Handle edge cases (single word = assume it's a model name) + return model.toLowerCase().trim(); +} + +/** + * Check if model matches a single pattern + * @param {string} model - Normalized model (provider/model) + * @param {string} pattern - Pattern to match against + * @returns {boolean} True if matches + */ +function matchesPattern(model, pattern) { + if (!model || !pattern) return false; + + const normalizedPattern = pattern.toLowerCase().trim(); + const normalizedModel = model.toLowerCase().trim(); + + // Exact match + if (normalizedPattern === normalizedModel) { + return true; + } + + // Provider wildcard (e.g., "gh/*") + if (normalizedPattern.endsWith('/*')) { + const provider = normalizedPattern.slice(0, -2); + const modelProvider = normalizedModel.split('/')[0]; + return provider === modelProvider; + } + + // Global wildcard + if (normalizedPattern === '*') { + return true; + } + + return false; +} + +/** + * Check if model is allowed by any pattern in allowlist + * @param {string} model - Model to check + * @param {string[]} allowedModels - Array of allowed patterns + * @returns {{ allowed: boolean, reason?: string }} + */ +export function isModelAllowed(model, allowedModels) { + // Empty or missing allowlist = unrestricted + if (!allowedModels || !Array.isArray(allowedModels) || allowedModels.length === 0) { + return { allowed: true }; + } + + const normalizedModel = normalizeModel(model); + + if (!normalizedModel) { + return { + allowed: false, + reason: 'Invalid model format' + }; + } + + // Check each pattern + for (const pattern of allowedModels) { + if (matchesPattern(normalizedModel, pattern)) { + return { allowed: true }; + } + } + + // No matches found + return { + allowed: false, + reason: `Model '${model}' not allowed. Allowed patterns: ${allowedModels.join(', ')}` + }; +} + +/** + * Validate allowedModels array format + * @param {any} allowedModels - Value to validate + * @returns {{ valid: boolean, error?: string }} + */ +export function validateAllowedModelsFormat(allowedModels) { + // Empty/null is valid (unrestricted) + if (!allowedModels) { + return { valid: true }; + } + + // Must be array + if (!Array.isArray(allowedModels)) { + return { + valid: false, + error: 'allowedModels must be an array' + }; + } + + // Check each element is a string + for (let i = 0; i < allowedModels.length; i++) { + const pattern = allowedModels[i]; + + if (typeof pattern !== 'string') { + return { + valid: false, + error: `Pattern at index ${i} must be a string` + }; + } + + if (pattern.trim().length === 0) { + return { + valid: false, + error: `Pattern at index ${i} cannot be empty` + }; + } + + // Basic format validation (optional, can be extended) + const normalized = pattern.toLowerCase().trim(); + + // Check for invalid characters (allow alphanumeric, /, -, _, ., *) + if (!/^[\w\-\.\/\*]+$/.test(normalized)) { + return { + valid: false, + error: `Pattern '${pattern}' contains invalid characters` + }; + } + } + + return { valid: true }; +} diff --git a/src/sse/handlers/chat.js b/src/sse/handlers/chat.js index 5a1bd039..2e30ae63 100644 --- a/src/sse/handlers/chat.js +++ b/src/sse/handlers/chat.js @@ -68,10 +68,17 @@ export async function handleChat(request, clientRawRequest = null) { log.warn("AUTH", "Missing API key (requireApiKey=true)"); return errorResponse(HTTP_STATUS.UNAUTHORIZED, "Missing API key"); } - const valid = await isValidApiKey(apiKey); - if (!valid) { - log.warn("AUTH", "Invalid API key (requireApiKey=true)"); - return errorResponse(HTTP_STATUS.UNAUTHORIZED, "Invalid API key"); + + // Validate API key and check model allowlist + const validation = await isValidApiKey(apiKey, modelStr); + + if (!validation || (typeof validation === 'object' && !validation.valid)) { + const error = typeof validation === 'object' ? validation.error : 'Invalid API key'; + log.warn("AUTH", `API key validation failed: ${error}`); + + // Return 403 for model restrictions, 401 for invalid keys + const status = error.includes('not allowed') ? HTTP_STATUS.FORBIDDEN : HTTP_STATUS.UNAUTHORIZED; + return errorResponse(status, error); } } diff --git a/src/sse/services/auth.js b/src/sse/services/auth.js index fa9b7105..1b3dd18c 100644 --- a/src/sse/services/auth.js +++ b/src/sse/services/auth.js @@ -242,8 +242,19 @@ export function extractApiKey(request) { /** * Validate API key (optional - for local use can skip) + * @param {string} apiKey - API key to validate + * @param {string} [model] - Optional model to check against allowlist + * @returns {Promise<{ valid: boolean, error?: string, key?: object } | boolean>} */ -export async function isValidApiKey(apiKey) { +export async function isValidApiKey(apiKey, model = null) { if (!apiKey) return false; - return await validateApiKey(apiKey); + + // If model is provided, return full validation result + if (model) { + return await validateApiKey(apiKey, model); + } + + // Backward compatibility: return boolean + const result = await validateApiKey(apiKey); + return typeof result === 'boolean' ? result : result.valid; }