9
9
"errors"
10
10
"fmt"
11
11
"io"
12
+ "io/fs"
12
13
"log/slog"
13
14
"math"
14
15
"net"
@@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
120
121
return
121
122
}
122
123
123
- model , err := GetModel (req .Model )
124
+ name := model .ParseName (req .Model )
125
+ if ! name .IsValid () {
126
+ // Ideally this is "invalid model name" but we're keeping with
127
+ // what the API currently returns until we can change it.
128
+ c .JSON (http .StatusNotFound , gin.H {"error" : fmt .Sprintf ("model '%s' not found" , req .Model )})
129
+ return
130
+ }
131
+
132
+ // We cannot currently consolidate this into GetModel because doing so
133
+ // would induce infinite recursion given the current code structure.
134
+ name , err := getExistingName (name )
135
+ if err != nil {
136
+ c .JSON (http .StatusNotFound , gin.H {"error" : fmt .Sprintf ("model '%s' not found" , req .Model )})
137
+ return
138
+ }
139
+
140
+ model , err := GetModel (name .String ())
124
141
if err != nil {
125
142
switch {
126
- case os . IsNotExist (err ):
143
+ case errors . Is (err , fs . ErrNotExist ):
127
144
c .JSON (http .StatusNotFound , gin.H {"error" : fmt .Sprintf ("model '%s' not found" , req .Model )})
128
145
case err .Error () == "invalid model name" :
129
146
c .JSON (http .StatusBadRequest , gin.H {"error" : err .Error ()})
@@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
157
174
caps = append (caps , CapabilityInsert )
158
175
}
159
176
160
- r , m , opts , err := s .scheduleRunner (c .Request .Context (), req . Model , caps , req .Options , req .KeepAlive )
177
+ r , m , opts , err := s .scheduleRunner (c .Request .Context (), name . String () , caps , req .Options , req .KeepAlive )
161
178
if errors .Is (err , errCapabilityCompletion ) {
162
179
c .JSON (http .StatusBadRequest , gin.H {"error" : fmt .Sprintf ("%q does not support generate" , req .Model )})
163
180
return
@@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
386
403
}
387
404
}
388
405
389
- r , m , opts , err := s .scheduleRunner (c .Request .Context (), req .Model , []Capability {}, req .Options , req .KeepAlive )
406
+ name , err := getExistingName (model .ParseName (req .Model ))
407
+ if err != nil {
408
+ c .JSON (http .StatusNotFound , gin.H {"error" : fmt .Sprintf ("model '%s' not found" , req .Model )})
409
+ return
410
+ }
411
+
412
+ r , m , opts , err := s .scheduleRunner (c .Request .Context (), name .String (), []Capability {}, req .Options , req .KeepAlive )
390
413
if err != nil {
391
414
handleScheduleError (c , req .Model , err )
392
415
return
@@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
489
512
return
490
513
}
491
514
492
- r , _ , _ , err := s .scheduleRunner (c .Request .Context (), req .Model , []Capability {}, req .Options , req .KeepAlive )
515
+ name := model .ParseName (req .Model )
516
+ if ! name .IsValid () {
517
+ c .JSON (http .StatusBadRequest , gin.H {"error" : "model is required" })
518
+ return
519
+ }
520
+
521
+ r , _ , _ , err := s .scheduleRunner (c .Request .Context (), name .String (), []Capability {}, req .Options , req .KeepAlive )
493
522
if err != nil {
494
523
handleScheduleError (c , req .Model , err )
495
524
return
@@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
582
611
return
583
612
}
584
613
585
- var model string
614
+ var mname string
586
615
if req .Model != "" {
587
- model = req .Model
616
+ mname = req .Model
588
617
} else if req .Name != "" {
589
- model = req .Name
618
+ mname = req .Name
590
619
} else {
591
620
c .AbortWithStatusJSON (http .StatusBadRequest , gin.H {"error" : "model is required" })
592
621
return
@@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
606
635
ctx , cancel := context .WithCancel (c .Request .Context ())
607
636
defer cancel ()
608
637
609
- if err := PushModel (ctx , model , regOpts , fn ); err != nil {
638
+ name , err := getExistingName (model .ParseName (mname ))
639
+ if err != nil {
640
+ ch <- gin.H {"error" : err .Error ()}
641
+ return
642
+ }
643
+
644
+ if err := PushModel (ctx , name .DisplayShortest (), regOpts , fn ); err != nil {
610
645
ch <- gin.H {"error" : err .Error ()}
611
646
}
612
647
}()
@@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
619
654
streamResponse (c , ch )
620
655
}
621
656
622
- // getExistingName returns the original, on disk name if the input name is a
623
- // case-insensitive match, otherwise it returns the input name.
657
+ // getExistingName searches the models directory for the longest prefix match of
658
+ // the input name and returns the input name with all existing parts replaced
659
+ // with each part found. If no parts are found, the input name is returned as
660
+ // is.
624
661
func getExistingName (n model.Name ) (model.Name , error ) {
625
662
var zero model.Name
626
663
existing , err := Manifests (true )
627
664
if err != nil {
628
665
return zero , err
629
666
}
667
+ var set model.Name // tracks parts already canonicalized
630
668
for e := range existing {
631
- if n .EqualFold (e ) {
632
- return e , nil
669
+ if set .Host == "" && strings .EqualFold (e .Host , n .Host ) {
670
+ n .Host = e .Host
671
+ }
672
+ if set .Namespace == "" && strings .EqualFold (e .Namespace , n .Namespace ) {
673
+ n .Namespace = e .Namespace
674
+ }
675
+ if set .Model == "" && strings .EqualFold (e .Model , n .Model ) {
676
+ n .Model = e .Model
677
+ }
678
+ if set .Tag == "" && strings .EqualFold (e .Tag , n .Tag ) {
679
+ n .Tag = e .Tag
633
680
}
634
681
}
635
682
return n , nil
@@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
658
705
}
659
706
660
707
if r .Path == "" && r .Modelfile == "" {
661
- c .AbortWithStatusJSON (http .StatusBadRequest , gin.H {"error" : "path or modelfile are required" })
708
+ c .AbortWithStatusJSON (http .StatusBadRequest , gin.H {"error" : "path or Modelfile are required" })
662
709
return
663
710
}
664
711
@@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
722
769
return
723
770
}
724
771
772
+ n , err := getExistingName (n )
773
+ if err != nil {
774
+ c .JSON (http .StatusNotFound , gin.H {"error" : fmt .Sprintf ("model '%s' not found" , cmp .Or (r .Model , r .Name ))})
775
+ return
776
+ }
777
+
725
778
m , err := ParseNamedManifest (n )
726
779
if err != nil {
727
780
switch {
@@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
782
835
}
783
836
784
837
func GetModelInfo (req api.ShowRequest ) (* api.ShowResponse , error ) {
785
- m , err := GetModel (req .Model )
838
+ name := model .ParseName (req .Model )
839
+ if ! name .IsValid () {
840
+ return nil , errModelPathInvalid
841
+ }
842
+ name , err := getExistingName (name )
843
+ if err != nil {
844
+ return nil , err
845
+ }
846
+
847
+ m , err := GetModel (name .String ())
786
848
if err != nil {
787
849
return nil , err
788
850
}
@@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
805
867
msgs [i ] = api.Message {Role : msg .Role , Content : msg .Content }
806
868
}
807
869
808
- n := model .ParseName (req .Model )
809
- if ! n .IsValid () {
810
- return nil , errors .New ("invalid model name" )
811
- }
812
-
813
- manifest , err := ParseNamedManifest (n )
870
+ manifest , err := ParseNamedManifest (name )
814
871
if err != nil {
815
872
return nil , err
816
873
}
@@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
1431
1488
caps = append (caps , CapabilityTools )
1432
1489
}
1433
1490
1434
- r , m , opts , err := s .scheduleRunner (c .Request .Context (), req .Model , caps , req .Options , req .KeepAlive )
1491
+ name := model .ParseName (req .Model )
1492
+ if ! name .IsValid () {
1493
+ c .JSON (http .StatusBadRequest , gin.H {"error" : "model is required" })
1494
+ return
1495
+ }
1496
+ name , err := getExistingName (name )
1497
+ if err != nil {
1498
+ c .JSON (http .StatusBadRequest , gin.H {"error" : "model is required" })
1499
+ return
1500
+ }
1501
+
1502
+ r , m , opts , err := s .scheduleRunner (c .Request .Context (), name .String (), caps , req .Options , req .KeepAlive )
1435
1503
if errors .Is (err , errCapabilityCompletion ) {
1436
1504
c .JSON (http .StatusBadRequest , gin.H {"error" : fmt .Sprintf ("%q does not support chat" , req .Model )})
1437
1505
return
0 commit comments