Skip to content

Commit b1fd7fe

Browse files
authored
server: more support for mixed-case model names (ollama#8017)
Fixes ollama#7944
1 parent 36d111e commit b1fd7fe

8 files changed

+123
-38
lines changed

cmd/cmd.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
601601
var data [][]string
602602

603603
for _, m := range models.Models {
604-
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
604+
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
605605
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
606606
}
607607
}

server/images.go

+4
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
376376
switch command {
377377
case "model", "adapter":
378378
if name := model.ParseName(c.Args); name.IsValid() && command == "model" {
379+
name, err := getExistingName(name)
380+
if err != nil {
381+
return err
382+
}
379383
baseLayers, err = parseFromModel(ctx, name, fn)
380384
if err != nil {
381385
return err

server/modelpath.go

+11-4
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ package server
33
import (
44
"errors"
55
"fmt"
6+
"io/fs"
67
"net/url"
78
"os"
89
"path/filepath"
910
"regexp"
1011
"strings"
1112

1213
"github.com/ollama/ollama/envconfig"
14+
"github.com/ollama/ollama/types/model"
1315
)
1416

1517
type ModelPath struct {
@@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string {
9395

9496
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
9597
func (mp ModelPath) GetManifestPath() (string, error) {
96-
if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
97-
return filepath.Join(envconfig.Models(), "manifests", p), nil
98+
name := model.Name{
99+
Host: mp.Registry,
100+
Namespace: mp.Namespace,
101+
Model: mp.Repository,
102+
Tag: mp.Tag,
98103
}
99-
100-
return "", errModelPathInvalid
104+
if !name.IsValid() {
105+
return "", fs.ErrNotExist
106+
}
107+
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
101108
}
102109

103110
func (mp ModelPath) BaseURL() *url.URL {

server/modelpath_test.go

-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package server
22

33
import (
4-
"errors"
54
"os"
65
"path/filepath"
76
"testing"
@@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) {
155154
})
156155
}
157156
}
158-
159-
func TestInsecureModelpath(t *testing.T) {
160-
mp := ParseModelPath("../../..:something")
161-
if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
162-
t.Errorf("expected error: %v", err)
163-
}
164-
}

server/routes.go

+90-22
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"errors"
1010
"fmt"
1111
"io"
12+
"io/fs"
1213
"log/slog"
1314
"math"
1415
"net"
@@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
120121
return
121122
}
122123

123-
model, err := GetModel(req.Model)
124+
name := model.ParseName(req.Model)
125+
if !name.IsValid() {
126+
// Ideally this is "invalid model name" but we're keeping with
127+
// what the API currently returns until we can change it.
128+
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
129+
return
130+
}
131+
132+
// We cannot currently consolidate this into GetModel because all we'll
133+
// induce infinite recursion given the current code structure.
134+
name, err := getExistingName(name)
135+
if err != nil {
136+
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
137+
return
138+
}
139+
140+
model, err := GetModel(name.String())
124141
if err != nil {
125142
switch {
126-
case os.IsNotExist(err):
143+
case errors.Is(err, fs.ErrNotExist):
127144
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
128145
case err.Error() == "invalid model name":
129146
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
157174
caps = append(caps, CapabilityInsert)
158175
}
159176

160-
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
177+
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
161178
if errors.Is(err, errCapabilityCompletion) {
162179
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
163180
return
@@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
386403
}
387404
}
388405

389-
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
406+
name, err := getExistingName(model.ParseName(req.Model))
407+
if err != nil {
408+
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
409+
return
410+
}
411+
412+
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
390413
if err != nil {
391414
handleScheduleError(c, req.Model, err)
392415
return
@@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
489512
return
490513
}
491514

492-
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
515+
name := model.ParseName(req.Model)
516+
if !name.IsValid() {
517+
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
518+
return
519+
}
520+
521+
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
493522
if err != nil {
494523
handleScheduleError(c, req.Model, err)
495524
return
@@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
582611
return
583612
}
584613

585-
var model string
614+
var mname string
586615
if req.Model != "" {
587-
model = req.Model
616+
mname = req.Model
588617
} else if req.Name != "" {
589-
model = req.Name
618+
mname = req.Name
590619
} else {
591620
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
592621
return
@@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
606635
ctx, cancel := context.WithCancel(c.Request.Context())
607636
defer cancel()
608637

609-
if err := PushModel(ctx, model, regOpts, fn); err != nil {
638+
name, err := getExistingName(model.ParseName(mname))
639+
if err != nil {
640+
ch <- gin.H{"error": err.Error()}
641+
return
642+
}
643+
644+
if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
610645
ch <- gin.H{"error": err.Error()}
611646
}
612647
}()
@@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
619654
streamResponse(c, ch)
620655
}
621656

622-
// getExistingName returns the original, on disk name if the input name is a
623-
// case-insensitive match, otherwise it returns the input name.
657+
// getExistingName searches the models directory for the longest prefix match of
658+
// the input name and returns the input name with all existing parts replaced
659+
// with each part found. If no parts are found, the input name is returned as
660+
// is.
624661
func getExistingName(n model.Name) (model.Name, error) {
625662
var zero model.Name
626663
existing, err := Manifests(true)
627664
if err != nil {
628665
return zero, err
629666
}
667+
var set model.Name // tracks parts already canonicalized
630668
for e := range existing {
631-
if n.EqualFold(e) {
632-
return e, nil
669+
if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
670+
n.Host = e.Host
671+
}
672+
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
673+
n.Namespace = e.Namespace
674+
}
675+
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
676+
n.Model = e.Model
677+
}
678+
if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
679+
n.Tag = e.Tag
633680
}
634681
}
635682
return n, nil
@@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
658705
}
659706

660707
if r.Path == "" && r.Modelfile == "" {
661-
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
708+
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"})
662709
return
663710
}
664711

@@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
722769
return
723770
}
724771

772+
n, err := getExistingName(n)
773+
if err != nil {
774+
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
775+
return
776+
}
777+
725778
m, err := ParseNamedManifest(n)
726779
if err != nil {
727780
switch {
@@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
782835
}
783836

784837
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
785-
m, err := GetModel(req.Model)
838+
name := model.ParseName(req.Model)
839+
if !name.IsValid() {
840+
return nil, errModelPathInvalid
841+
}
842+
name, err := getExistingName(name)
843+
if err != nil {
844+
return nil, err
845+
}
846+
847+
m, err := GetModel(name.String())
786848
if err != nil {
787849
return nil, err
788850
}
@@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
805867
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
806868
}
807869

808-
n := model.ParseName(req.Model)
809-
if !n.IsValid() {
810-
return nil, errors.New("invalid model name")
811-
}
812-
813-
manifest, err := ParseNamedManifest(n)
870+
manifest, err := ParseNamedManifest(name)
814871
if err != nil {
815872
return nil, err
816873
}
@@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
14311488
caps = append(caps, CapabilityTools)
14321489
}
14331490

1434-
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
1491+
name := model.ParseName(req.Model)
1492+
if !name.IsValid() {
1493+
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
1494+
return
1495+
}
1496+
name, err := getExistingName(name)
1497+
if err != nil {
1498+
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
1499+
return
1500+
}
1501+
1502+
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
14351503
if errors.Is(err, errCapabilityCompletion) {
14361504
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
14371505
return

server/routes_generate_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) {
719719
t.Errorf("expected status 400, got %d", w.Code)
720720
}
721721

722-
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
722+
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
723723
t.Errorf("mismatch (-got +want):\n%s", diff)
724724
}
725725
})

server/routes_test.go

+14
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) {
514514

515515
wantStableName := name()
516516

517+
t.Logf("stable name: %s", wantStableName)
518+
517519
// checkManifestList tests that there is strictly one manifest in the
518520
// models directory, and that the manifest is for the model under test.
519521
checkManifestList := func() {
@@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) {
601603
Destination: name(),
602604
}))
603605
checkManifestList()
606+
607+
t.Logf("pushing")
608+
rr := createRequest(t, s.PushHandler, api.PushRequest{
609+
Model: name(),
610+
Insecure: true,
611+
Username: "alice",
612+
Password: "x",
613+
})
614+
checkOK(rr)
615+
if !strings.Contains(rr.Body.String(), `"status":"success"`) {
616+
t.Errorf("got = %q, want success", rr.Body.String())
617+
}
604618
}
605619

606620
func TestShow(t *testing.T) {

types/model/name.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,12 @@ func (n Name) String() string {
223223
func (n Name) DisplayShortest() string {
224224
var sb strings.Builder
225225

226-
if n.Host != defaultHost {
226+
if !strings.EqualFold(n.Host, defaultHost) {
227227
sb.WriteString(n.Host)
228228
sb.WriteByte('/')
229229
sb.WriteString(n.Namespace)
230230
sb.WriteByte('/')
231-
} else if n.Namespace != defaultNamespace {
231+
} else if !strings.EqualFold(n.Namespace, defaultNamespace) {
232232
sb.WriteString(n.Namespace)
233233
sb.WriteByte('/')
234234
}

0 commit comments

Comments
 (0)