Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
- Use the uv created .venv when trying to run python code or python tests
- Use semantic versioning commit messages
- After you make non trival changes, run ruff linting, then ruff formating, then the tests
- Always use semantic commit messages.
- To build the frontend, always run from project root: `uv run python align_browser/build.py experiment-data/test-experiments --output-dir align_browser/static --build-only`

## Testing

Expand Down
55 changes: 53 additions & 2 deletions align_browser/csv_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def format_kdma_config(kdma_values: List[Dict[str, Any]]) -> str:
return ",".join(kdma_strings)


def extract_scene_id(item: InputOutputItem) -> str:
"""Extract the scene_id from an input/output item."""
if not item.input.full_state:
return ""

meta_info = item.input.full_state.get("meta_info", {})
return meta_info.get("scene_id", "")


def extract_choice_text(item: InputOutputItem) -> str:
"""Extract the human-readable choice text from an input/output item."""
if not item.output or "choice" not in item.output:
Expand Down Expand Up @@ -55,7 +64,16 @@ def extract_choice_kdma(item: InputOutputItem) -> str:
return ""

selected_choice = choices[choice_index]
return selected_choice.get("kdma_association", "")
kdma_association = selected_choice.get("kdma_association", "")

# Format the KDMA association dictionary as a string
if isinstance(kdma_association, dict) and kdma_association:
kdma_strings = []
for kdma_name, value in kdma_association.items():
kdma_strings.append(f"{kdma_name}:{value}")
return ",".join(kdma_strings)

return str(kdma_association) if kdma_association else ""


def extract_justification(item: InputOutputItem) -> str:
Expand All @@ -67,6 +85,35 @@ def extract_justification(item: InputOutputItem) -> str:
return action.get("justification", "")


def extract_choice_info(item: InputOutputItem) -> str:
"""Extract the choice_info as a JSON string, truncating ICL examples."""
if not item.choice_info:
return ""

import json

# Create a copy to avoid modifying the original
filtered_choice_info = {}

for key, value in item.choice_info.items():
if key == "icl_example_responses" and isinstance(value, dict):
# Keep only first example for each KDMA, truncate the rest
truncated_icl = {}
for kdma, examples in value.items():
if isinstance(examples, list) and len(examples) > 0:
# Keep first example, replace rest with "truncated"
truncated_icl[kdma] = [examples[0]]
if len(examples) > 1:
truncated_icl[kdma].append("truncated")
else:
truncated_icl[kdma] = examples
filtered_choice_info[key] = truncated_icl
else:
filtered_choice_info[key] = value

return json.dumps(filtered_choice_info, separators=(",", ":"))


def get_decision_time(
timing_data: Optional[Dict[str, Any]], item_index: int
) -> Optional[float]:
Expand Down Expand Up @@ -151,11 +198,13 @@ def experiment_to_csv_rows(
"kdma_config": kdma_config,
"alignment_target_id": alignment_target_id,
"scenario_id": item.input.scenario_id,
"scene_id": extract_scene_id(item),
"state_description": item.input.state
if hasattr(item.input, "state")
else "",
"choice_text": extract_choice_text(item),
"choice_kdma_association": extract_choice_kdma(item),
"choice_info": extract_choice_info(item),
"justification": extract_justification(item),
"decision_time_s": get_decision_time(timing_data, idx),
"score": get_score(scores_data, idx),
Expand Down Expand Up @@ -185,9 +234,11 @@ def write_experiments_to_csv(
"kdma_config",
"alignment_target_id",
"scenario_id",
"scene_id",
"state_description",
"choice_text",
"choice_kdma_association",
"choice_text",
"choice_info",
"justification",
"decision_time_s",
"score",
Expand Down
2 changes: 2 additions & 0 deletions align_browser/experiment_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ class InputOutputItem(BaseModel):

input: InputData
output: Optional[Dict[str, Any]] = None
choice_info: Optional[Dict[str, Any]] = None
label: Optional[List[Dict[str, Any]]] = None
original_index: int # Index in the original file


Expand Down
91 changes: 49 additions & 42 deletions align_browser/static/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import {
getParameterValueFromRun,
isParameterLinked,
isResultParameter,
PARAMETER_PRIORITY_ORDER,
API_TO_APP,
APP_TO_API
} from './state.js';

import {
Expand All @@ -24,23 +27,38 @@ import {
getValidKDMAsForRun
} from './table-formatter.js';

// Constants
const KDMA_SLIDER_DEBOUNCE_MS = 500;

// Generic function to preserve linked parameters after validation
function preserveLinkedParameters(validatedParams, originalParams, appState) {
const preserved = { ...validatedParams };
// Takes snake_case params from API and returns mixed camelCase/snake_case for internal use
function preserveLinkedParameters(validatedSnakeParams, originalCamelParams, appState, changedParamCamel) {
// Start with the validated params from API (in snake_case)
const preserved = { ...validatedSnakeParams };

// Iterate through all possible linked parameters
const linkableParams = ['scenario', 'scene', 'admType', 'llmBackbone', 'kdmaValues', 'runVariant'];
for (const paramName of linkableParams) {
if (isParameterLinked(paramName, appState)) {
// Preserve the original value for linked parameters
preserved[paramName] = originalParams[paramName];
// Convert changed param to snake_case for priority checking
const changedParamSnake = APP_TO_API[changedParamCamel] || changedParamCamel;
const changedIndex = PARAMETER_PRIORITY_ORDER.indexOf(changedParamSnake);

// Only preserve linked parameters that are HIGHER priority than the changed parameter
PARAMETER_PRIORITY_ORDER.forEach((snakeParam, paramIndex) => {
const camelParam = API_TO_APP[snakeParam];

if (camelParam && isParameterLinked(camelParam, appState)) {
// Only preserve if this parameter has higher priority (lower index) than changed parameter
// OR if it's the same parameter (to prevent it from being reset)
if (paramIndex <= changedIndex) {
// Get the original value using the camelCase key from originalCamelParams
preserved[snakeParam] = originalCamelParams[camelParam];
}
// Lower priority linked parameters will use the validated values
}
}
});

return preserved;
// Convert back to camelCase format for internal use
const result = {};
Object.entries(APP_TO_API).forEach(([camelKey, snakeKey]) => {
result[camelKey] = preserved[snakeKey];
});
return result;
}

// CSV Download functionality
Expand Down Expand Up @@ -175,18 +193,22 @@ document.addEventListener("DOMContentLoaded", () => {
const validParams = result.params;
const validOptions = result.options;

// For propagated updates of linked parameters, use raw values without validation
// For propagated updates of linked parameters, we need to validate and cascade lower priority params
if (isPropagatedUpdate && isParameterLinked(paramType, appState)) {
// Use the raw params values (don't validate for propagated linked parameters)
columnParameters.set(runId, createParameterStructure(params));
// Use the special preserveLinkedParameters that respects priority hierarchy
// Pass the API params directly (snake_case) and original params (camelCase)
const finalParams = preserveLinkedParameters(validParams, params, appState, paramType);

// Store the parameters with proper cascading
columnParameters.set(runId, createParameterStructure(finalParams));

// Update the actual run state with raw values
run.scenario = params.scenario;
run.scene = params.scene;
run.admType = params.admType;
run.llmBackbone = params.llmBackbone;
run.runVariant = params.runVariant;
run.kdmaValues = params.kdmaValues;
// Update the actual run state
run.scenario = finalParams.scenario;
run.scene = finalParams.scene;
run.admType = finalParams.admType;
run.llmBackbone = finalParams.llmBackbone;
run.runVariant = finalParams.runVariant;
run.kdmaValues = finalParams.kdmaValues;

// Store the updated available options for UI dropdowns
run.availableOptions = {
Expand All @@ -200,26 +222,16 @@ document.addEventListener("DOMContentLoaded", () => {
}
};

return params; // Return the raw params for propagated linked parameters
return finalParams; // Return the params with proper cascading
}

// For direct user updates (including on linked parameters), always validate for proper cascading
// This ensures the source column gets valid parameter combinations

// For unlinked parameters, use validated parameters
const kdmaValues = validParams.kdma_values || {};

const correctedParams = {
scenario: validParams.scenario,
scene: validParams.scene,
admType: validParams.adm,
llmBackbone: validParams.llm,
kdmaValues: kdmaValues,
runVariant: validParams.run_variant
};

// Preserve any linked parameters - they should not be changed by validation
const finalParams = preserveLinkedParameters(correctedParams, params, appState);
// Pass the changed parameter so we know which linked params to preserve vs cascade
// Pass the API params directly (snake_case) and original params (camelCase)
const finalParams = preserveLinkedParameters(validParams, params, appState, paramType);

// Store corrected parameters
columnParameters.set(runId, createParameterStructure(finalParams));
Expand Down Expand Up @@ -700,12 +712,7 @@ document.addEventListener("DOMContentLoaded", () => {
// Update link state visual indicators
const linkIcon = row.querySelector('.link-icon');
// Map snake_case parameter names to camelCase for link checking
const linkParamName = {
'adm_type': 'admType',
'llm_backbone': 'llmBackbone',
'run_variant': 'runVariant',
'kdma_values': 'kdmaValues'
}[paramName] || paramName;
const linkParamName = API_TO_APP[paramName] || paramName;

if (isParameterLinked(linkParamName, appState)) {
row.classList.add('linked');
Expand Down
76 changes: 34 additions & 42 deletions align_browser/static/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,27 @@

import { showWarning } from './notifications.js';

// Centralized parameter mappings
export const PARAMETER_MAPPINGS = {
// Maps validation API names back to internal names
API_TO_INTERNAL: {
'scenario': 'scenario',
'scene': 'scene',
'adm': 'admType',
'llm': 'llmBackbone',
'kdma_values': 'kdmaValues',
'run_variant': 'runVariant'
}
export const API_TO_APP = {
'scenario': 'scenario',
'scene': 'scene',
'adm': 'admType',
'llm': 'llmBackbone',
'kdma_values': 'kdmaValues',
'run_variant': 'runVariant'
};

export const APP_TO_API = {
'scenario': 'scenario',
'scene': 'scene',
'admType': 'adm',
'llmBackbone': 'llm',
'kdmaValues': 'kdma_values',
'runVariant': 'run_variant'
};

// Priority order for parameter cascading
export const PARAMETER_PRIORITY_ORDER = ['scenario', 'scene', 'kdma_values', 'adm', 'llm', 'run_variant'];

// Constants for KDMA processing
const KDMA_CONSTANTS = {
DECIMAL_PRECISION: 10, // For 1 decimal place normalization
Expand Down Expand Up @@ -230,15 +238,6 @@ export function decodeStateFromURL() {
return null;
}

// Configuration for parameter validation system
const PARAMETER_CONFIG = {
// Priority order for parameter cascading
PRIORITY_ORDER: ['scenario', 'scene', 'kdma_values', 'adm', 'llm', 'run_variant'],

// Parameters that require special handling
SPECIAL_COMPARISON_PARAMS: new Set(['kdma_values'])
};

// Parameter update system with priority-based cascading
const updateParametersBase = (priorityOrder) => (manifest) => (currentParams, changes) => {
const newParams = { ...currentParams, ...changes };
Expand All @@ -260,15 +259,13 @@ const updateParametersBase = (priorityOrder) => (manifest) => (currentParams, ch

// Only check constraint if the current selection has a non-null value for this parameter
if (currentSelection[param] !== null && currentSelection[param] !== undefined) {
// Special handling for parameters that need custom comparison
if (PARAMETER_CONFIG.SPECIAL_COMPARISON_PARAMS.has(param)) {
if (param === 'kdma_values') {
const manifestKdmas = manifestEntry[param];
const selectionKdmas = currentSelection[param];

if (!KDMAUtils.deepEqual(manifestKdmas, selectionKdmas)) {
return false;
}
// Special handling for kdma_values which needs deep comparison
if (param === 'kdma_values') {
const manifestKdmas = manifestEntry[param];
const selectionKdmas = currentSelection[param];

if (!KDMAUtils.deepEqual(manifestKdmas, selectionKdmas)) {
return false;
}
} else if (manifestEntry[param] !== currentSelection[param]) {
return false;
Expand Down Expand Up @@ -310,13 +307,11 @@ const updateParametersBase = (priorityOrder) => (manifest) => (currentParams, ch
// Only change if current value is invalid
let isValid = validOptions.includes(currentValue);

// For special parameters, use custom comparison logic
if (PARAMETER_CONFIG.SPECIAL_COMPARISON_PARAMS.has(param) && !isValid) {
if (param === 'kdma_values') {
isValid = validOptions.some(option => {
return KDMAUtils.deepEqual(option, currentValue);
});
}
// For kdma_values, use custom comparison logic
if (param === 'kdma_values' && !isValid) {
isValid = validOptions.some(option => {
return KDMAUtils.deepEqual(option, currentValue);
});
}

if (!isValid) {
Expand All @@ -338,7 +333,7 @@ const updateParametersBase = (priorityOrder) => (manifest) => (currentParams, ch
};

// Export updateParameters with priority order already curried
export const updateParameters = updateParametersBase(PARAMETER_CONFIG.PRIORITY_ORDER);
export const updateParameters = updateParametersBase(PARAMETER_PRIORITY_ORDER);

export function toggleParameterLink(paramName, appState, callbacks) {
if (appState.linkedParameters.has(paramName)) {
Expand All @@ -350,11 +345,8 @@ export function toggleParameterLink(paramName, appState, callbacks) {
appState.linkedParameters.add(paramName);
// When enabling link, propagate the leftmost column's value
const firstRun = Array.from(appState.pinnedRuns.values())[0];
let propagationResult = null;
if (firstRun) {
const currentValue = getParameterValueFromRun(firstRun, paramName);
propagationResult = propagateParameterToAllRuns(paramName, currentValue, firstRun.id, appState, callbacks);
}
const currentValue = getParameterValueFromRun(firstRun, paramName);
const propagationResult = propagateParameterToAllRuns(paramName, currentValue, firstRun.id, appState, callbacks);
callbacks.renderTable();
callbacks.updateURL();
return propagationResult;
Expand Down
Loading
Loading