diff --git a/loader/include.go b/loader/include.go index 823c2f7a..3e49b8d8 100644 --- a/loader/include.go +++ b/loader/include.go @@ -177,6 +177,9 @@ func importResources(source map[string]any, target map[string]any) error { if err := importResource(source, target, "configs"); err != nil { return err } + if err := importResource(source, target, "models"); err != nil { + return err + } return nil } diff --git a/override/merge.go b/override/merge.go index 6fae6e5f..26b36469 100644 --- a/override/merge.go +++ b/override/merge.go @@ -39,6 +39,7 @@ type merger func(any, any, tree.Path) (any, error) var mergeSpecials = map[tree.Path]merger{} func init() { + mergeSpecials["models.*.runtime_flags"] = override mergeSpecials["networks.*.ipam.config"] = mergeIPAMConfig mergeSpecials["networks.*.labels"] = mergeToSequence mergeSpecials["volumes.*.labels"] = mergeToSequence @@ -160,11 +161,53 @@ func mergeDependsOn(c any, o any, path tree.Path) (any, error) { } func mergeModels(c any, o any, path tree.Path) (any, error) { + // Check if both sides are string arrays for short syntax only + if rightArr, ok := c.([]any); ok { + if leftArr, ok := o.([]any); ok { + if isStringArray(rightArr) && isStringArray(leftArr) { + return mergeStringArrays(rightArr, leftArr), nil + } + } + } + + // Otherwise, use map merge for long syntax or mixed syntax right := convertIntoMapping(c, nil) left := convertIntoMapping(o, nil) return mergeMappings(right, left, path) } +func isStringArray(arr []any) bool { + for _, item := range arr { + if _, ok := item.(string); !ok { + return false + } + } + return true +} + +func mergeStringArrays(right, left []any) []any { + seen := make(map[string]bool) + var result []any + + for _, item := range right { + str := item.(string) + if !seen[str] { + result = append(result, str) + seen[str] = true + } + } + + for _, item := range left { + str := item.(string) + if !seen[str] { + result = append(result, str) + seen[str] = true + } + } + + return result +} + func mergeNetworks(c any, o any, path tree.Path) (any, error) { right := convertIntoMapping(c, nil) left := convertIntoMapping(o, nil) diff --git a/override/merge_models_test.go b/override/merge_models_test.go new file mode 100644 index 00000000..8c694a73 --- /dev/null +++ b/override/merge_models_test.go @@ -0,0 +1,337 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package override + +import ( + "testing" +) + +func Test_mergeYamlServiceModelsShortSyntax(t *testing.T) { + assertMergeYaml(t, ` +services: + test: + image: foo + models: + - llm + - embedding-model +`, ` +services: + test: + models: + - vision-model +`, ` +services: + test: + image: foo + models: + - llm + - embedding-model + - vision-model +`) +} + +func Test_mergeYamlServiceModelsLongSyntax(t *testing.T) { + assertMergeYaml(t, ` +services: + test: + image: foo + models: + llm: + endpoint_var: AI_MODEL_URL + model_var: AI_MODEL_NAME +`, ` +services: + test: + models: + embedding-model: + endpoint_var: EMBEDDING_URL + model_var: EMBEDDING_MODEL +`, ` +services: + test: + image: foo + models: + llm: + endpoint_var: AI_MODEL_URL + model_var: AI_MODEL_NAME + embedding-model: + endpoint_var: EMBEDDING_URL + model_var: EMBEDDING_MODEL +`) +} + +func Test_mergeYamlServiceModelsMixed(t *testing.T) { + assertMergeYaml(t, ` +services: + test: + image: foo + models: + - llm + - embedding-model +`, ` +services: + test: + models: + vision-model: + endpoint_var: VISION_URL + model_var: VISION_MODEL +`, ` +services: + test: + image: foo + models: + llm: + embedding-model: + vision-model: + endpoint_var: VISION_URL + model_var: VISION_MODEL +`) +} + +func Test_mergeYamlServiceModelsOverride(t *testing.T) { + assertMergeYaml(t, ` +services: + test: + image: foo + models: + llm: + endpoint_var: OLD_MODEL_URL + model_var: OLD_MODEL_NAME +`, ` +services: + test: + models: + llm: + endpoint_var: NEW_MODEL_URL + model_var: NEW_MODEL_NAME +`, ` +services: + test: + image: foo + models: + llm: + endpoint_var: NEW_MODEL_URL + model_var: NEW_MODEL_NAME +`) +} + +func Test_mergeYamlTopLevelModels(t *testing.T) { + assertMergeYaml(t, ` +services: + test: + image: foo +models: + llm: + model: ai/smollm2 + context_size: 2048 + runtime_flags: + - "--gpu" +`, ` +services: + test: + image: foo +models: + embedding-model: + model: ai/all-minilm + context_size: 512 + runtime_flags: + - "--cpu" +`, ` +services: + test: + image: foo +models: + llm: + model: ai/smollm2 + context_size: 2048 + runtime_flags: + - "--gpu" + embedding-model: + model: ai/all-minilm + context_size: 512 + runtime_flags: + - "--cpu" +`) +} + +func Test_mergeYamlTopLevelModelsOverride(t *testing.T) { + assertMergeYaml(t, ` +services: + test: + image: foo +models: + llm: + model: ai/smollm2 + context_size: 2048 + runtime_flags: + - "--gpu" +`, ` +services: + test: + image: foo +models: + llm: + model: ai/gpt-4 + context_size: 8192 + runtime_flags: + - "--gpu" + - "--fp16" +`, ` +services: + test: + image: foo +models: + llm: + model: ai/gpt-4 + context_size: 8192 + runtime_flags: + - "--gpu" + - "--fp16" +`) +} + +func Test_mergeYamlModelsCompleteScenario(t *testing.T) { + assertMergeYaml(t, ` +services: + app: + image: myapp + models: + - llm + worker: + image: worker + models: + embedding-model: + endpoint_var: EMBEDDING_URL +models: + llm: + model: ai/smollm2 + context_size: 2048 + embedding-model: + model: ai/all-minilm + context_size: 512 +`, ` +services: + app: + models: + - vision-model + worker: + models: + llm: + endpoint_var: LLM_URL + model_var: LLM_NAME +models: + vision-model: + model: ai/clip + context_size: 1024 + llm: + model: ai/gpt-4 + context_size: 8192 +`, ` +services: + app: + image: myapp + models: + - llm + - vision-model + worker: + image: worker + models: + embedding-model: + endpoint_var: EMBEDDING_URL + llm: + endpoint_var: LLM_URL + model_var: LLM_NAME +models: + llm: + model: ai/gpt-4 + context_size: 8192 + embedding-model: + model: ai/all-minilm + context_size: 512 + vision-model: + model: ai/clip + context_size: 1024 +`) +} + +func Test_mergeYamlModelsRuntimeFlagsMerge(t *testing.T) { + assertMergeYaml(t, ` +services: + test: + image: foo +models: + llm: + model: ai/smollm2 + runtime_flags: + - "--gpu" + - "--batch-size=32" +`, ` +services: + test: + image: foo +models: + llm: + model: ai/smollm2 + runtime_flags: + - "--fp16" + - "--batch-size=64" +`, ` +services: + test: + image: foo +models: + llm: + model: ai/smollm2 + runtime_flags: + - "--fp16" + - "--batch-size=64" +`) +} + +func Test_mergeYamlModelsMultipleServices(t *testing.T) { + assertMergeYaml(t, ` +services: + go-genai: + models: + - llm +models: + llm: + model: ai/smollm2 + context_size: 2048 +`, ` +services: + node-genai: + models: + - llm +models: + llm: + model: ai/smollm2 + context_size: 2048 +`, ` +services: + go-genai: + models: + - llm + node-genai: + models: + - llm +models: + llm: + model: ai/smollm2 + context_size: 2048 +`) +} diff --git a/transform/canonical.go b/transform/canonical.go index d0525f02..9fd1a338 100644 --- a/transform/canonical.go +++ b/transform/canonical.go @@ -38,8 +38,8 @@ func init() { transformers["services.*.label_file"] = transformStringOrList transformers["services.*.extends"] = transformExtends transformers["services.*.gpus"] = transformGpus - transformers["services.*.networks"] = transformStringSliceToMap - transformers["services.*.models"] = transformStringSliceToMap + transformers["services.*.networks"] = transformStringSliceToMap(nil) + transformers["services.*.models"] = transformStringSliceToMap(map[string]any{}) transformers["services.*.volumes.*"] = transformVolumeMount transformers["services.*.dns"] = transformStringOrList transformers["services.*.devices.*"] = transformDeviceMapping diff --git a/transform/services.go b/transform/services.go index d9df42c8..4cb643e6 100644 --- a/transform/services.go +++ b/transform/services.go @@ -29,13 +29,15 @@ func transformService(data any, p tree.Path, ignoreParseError bool) (any, error) } } -func transformStringSliceToMap(data any, _ tree.Path, _ bool) (any, error) { - if slice, ok := data.([]any); ok { - mapping := make(map[string]any, len(slice)) - for _, net := range slice { - mapping[net.(string)] = nil +func transformStringSliceToMap(defaultValue any) func(data any, _ tree.Path, _ bool) (any, error) { + return func(data any, _ tree.Path, _ bool) (any, error) { + if slice, ok := data.([]any); ok { + mapping := make(map[string]any, len(slice)) + for _, item := range slice { + mapping[item.(string)] = defaultValue + } + return mapping, nil } - return mapping, nil + return data, nil } - return data, nil }