diff --git a/core/backend/llm.go b/core/backend/llm.go
index cf8a4f861439..05151383354e 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -4,8 +4,8 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"os"
 	"regexp"
+	"slices"
 	"strings"
 	"sync"
 	"unicode/utf8"
@@ -14,6 +14,7 @@ import (
 
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/core/services"
 
 	"github.com/mudler/LocalAI/core/gallery"
 	"github.com/mudler/LocalAI/pkg/grpc/proto"
@@ -34,15 +35,19 @@ type TokenUsage struct {
 	TimingTokenGeneration float64
 }
 
-func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, cl *config.BackendConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
 	modelFile := c.Model
 
 	// Check if the modelFile exists, if it doesn't try to load it from the gallery
 	if o.AutoloadGalleries { // experimental
-		if _, err := os.Stat(modelFile); os.IsNotExist(err) {
+		modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS)
+		if err != nil {
+			return nil, err
+		}
+		if !slices.Contains(modelNames, c.Name) {
 			utils.ResetDownloadTimers()
 			// if we failed to load the model, we try to download it
-			err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, modelFile, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
+			err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, c.Name, loader.ModelPath, o.BackendsPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
 			if err != nil {
 				log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
 				//return nil, err
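
Note on the llm.go change above: the gallery autoload path no longer probes the filesystem with os.Stat on c.Model; it asks the model service which models are already known and keys both the check and the install on c.Name. Below is a minimal runnable sketch of that control flow; listKnownModels and installFromGallery are hypothetical stand-ins for services.ListModels and gallery.InstallModelFromGallery, not the real LocalAI APIs.

package main

import (
	"fmt"
	"slices"
)

// listKnownModels is a hypothetical stand-in for services.ListModels.
func listKnownModels() ([]string, error) {
	return []string{"phi-2", "llava"}, nil
}

// installFromGallery is a hypothetical stand-in for gallery.InstallModelFromGallery.
func installFromGallery(name string) error {
	fmt.Printf("installing %q from the gallery\n", name)
	return nil
}

// ensureModel mirrors the new autoload check: list known models first,
// and only fall back to a gallery install when the name is missing.
func ensureModel(name string) error {
	known, err := listKnownModels()
	if err != nil {
		return err // unlike the old os.Stat probe, listing itself can fail
	}
	if !slices.Contains(known, name) {
		if err := installFromGallery(name); err != nil {
			return fmt.Errorf("failed to install model %q from gallery: %w", name, err)
		}
	}
	return nil
}

func main() {
	if err := ensureModel("mistral-7b"); err != nil {
		fmt.Println(err)
	}
}
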
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index 684e0c4199b5..1e33c977b42d 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -41,7 +41,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		}
 		responses <- initialMessage
 
-		ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+		ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -68,7 +68,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 	}
 	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
 		result := ""
-		_, tokenUsage, _ := ComputeChoices(req, prompt, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			result += s
 			// TODO: Change generated BNF grammar to be compliant with the schema so we can
 			// stream the result token by token here.
@@ -92,7 +92,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 		}
 		responses <- initialMessage
 
-		result, err := handleQuestion(config, req, ml, startupOptions, functionResults, result, prompt)
+		result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
 		if err != nil {
 			log.Error().Err(err).Msg("error handling question")
 			return
@@ -383,7 +383,8 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 
 	// no streaming mode
 	default:
-		result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
+
+		tokenCallback := func(s string, c *[]schema.Choice) {
 			if !shouldUseFn {
 				// no function is called, just reply and use stop as finish reason
 				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
@@ -403,7 +404,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 
 			switch {
 			case noActionsToRun:
-				result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
+				result, err := handleQuestion(config, cl, input, ml, startupOptions, results, s, predInput)
 				if err != nil {
 					log.Error().Err(err).Msg("error handling question")
 					return
@@ -458,7 +459,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 				}
 			}
 
-		}, nil)
+		}
+
+		result, tokenUsage, err := ComputeChoices(
+			input,
+			predInput,
+			config,
+			cl,
+			startupOptions,
+			ml,
+			tokenCallback,
+			nil,
+		)
 		if err != nil {
 			return err
 		}
@@ -489,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 	}
 }
 
-func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
+func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
 
 	if len(funcResults) == 0 && result != "" {
 		log.Debug().Msgf("nothing function results but we had a message from the LLM")
@@ -538,7 +550,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m
 		audios = append(audios, m.StringAudios...)
 	}
 
-	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil)
+	predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil)
 	if err != nil {
 		log.Error().Err(err).Msg("model inference failed")
 		return "", err
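
Note on the chat.go change above: besides threading cl through, the non-streaming branch lifts the long inline choice callback into a named tokenCallback before calling ComputeChoices; behavior is unchanged, the closure still captures the surrounding endpoint state. A small self-contained sketch of the same refactor pattern, using hypothetical types rather than LocalAI's:

package main

import "fmt"

type choice struct{ text string }

// compute is a hypothetical stand-in for ComputeChoices: it calls the
// callback once per generated string so the caller can collect choices.
func compute(results []string, cb func(string, *[]choice)) []choice {
	var choices []choice
	for _, r := range results {
		cb(r, &choices)
	}
	return choices
}

func main() {
	// Naming the closure keeps the call site readable once the
	// argument list grows, as it did here when cl was added.
	tokenCallback := func(s string, c *[]choice) {
		*c = append(*c, choice{text: s})
	}
	fmt.Println(compute([]string{"hello", "world"}, tokenCallback))
}
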
diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go
index 79da6cd63359..9620d2c12cb2 100644
--- a/core/http/endpoints/openai/completion.go
+++ b/core/http/endpoints/openai/completion.go
@@ -31,7 +31,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 	created := int(time.Now().Unix())
 
 	process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
-		ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+		tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -58,7 +58,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 			responses <- resp
 
 			return true
-		})
+		}
+		ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
 
 		close(responses)
 	}
@@ -168,7 +169,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
 		}
 
 		r, tokenUsage, err := ComputeChoices(
-			input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
+			input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
 				*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
 			}, nil)
 		if err != nil {
diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go
index 7409e2915d80..c4cb5e19ca62 100644
--- a/core/http/endpoints/openai/edit.go
+++ b/core/http/endpoints/openai/edit.go
@@ -56,7 +56,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
 			log.Debug().Msgf("Template found, input modified to: %s", i)
 		}
 
-		r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) {
+		r, tokenUsage, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
 			*c = append(*c, schema.Choice{Text: s})
 		}, nil)
 		if err != nil {
diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go
index 4b84773a0a42..c0deeb02e9b6 100644
--- a/core/http/endpoints/openai/inference.go
+++ b/core/http/endpoints/openai/inference.go
@@ -12,6 +12,7 @@ func ComputeChoices(
 	req *schema.OpenAIRequest,
 	predInput string,
 	config *config.BackendConfig,
+	bcl *config.BackendConfigLoader,
 	o *config.ApplicationConfig,
 	loader *model.ModelLoader,
 	cb func(string, *[]schema.Choice),
@@ -37,7 +38,7 @@ func ComputeChoices(
 	}
 
 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback)
+	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback)
 	if err != nil {
 		return result, backend.TokenUsage{}, err
 	}
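
Taken together, the signature changes thread the backend config loader down the call chain: the endpoint's cl goes into ComputeChoices (as bcl) and from there into backend.ModelInference (as cl), so the autoload check can enumerate configured models instead of statting the filesystem. A toy sketch of that dependency threading, with hypothetical stand-in types rather than LocalAI's:

package main

import "fmt"

// Hypothetical stand-ins for LocalAI's types, kept tiny to show only
// how the loader is threaded through the call chain.
type backendConfigLoader struct{ known []string }
type applicationConfig struct{ autoloadGalleries bool }

// modelInference mirrors the new backend.ModelInference shape: it now
// receives the config loader so the autoload check can list known models.
func modelInference(prompt string, cl *backendConfigLoader, o *applicationConfig) (string, error) {
	if o.autoloadGalleries {
		fmt.Printf("consulting %d known models before inference\n", len(cl.known))
	}
	return "response to: " + prompt, nil
}

// computeChoices mirrors the new ComputeChoices shape: it simply
// forwards bcl to the inference layer.
func computeChoices(prompt string, bcl *backendConfigLoader, o *applicationConfig) (string, error) {
	return modelInference(prompt, bcl, o)
}

func main() {
	cl := &backendConfigLoader{known: []string{"phi-2"}}
	out, err := computeChoices("hello", cl, &applicationConfig{autoloadGalleries: true})
	fmt.Println(out, err)
}
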