Skip to content

Commit 0ec4dc6

Browse files
mudler authored and siddimore committed
feat(multimodal): allow to template placeholders (mudler#3728)
feat(multimodal): allow to template image placeholders Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 16dfee9 commit 0ec4dc6

File tree

5 files changed

+66
-4
lines changed

5 files changed

+66
-4
lines changed

core/config/backend_config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ type TemplateConfig struct {
196196
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
197197
// It defaults to \n
198198
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`
199+
200+
Video string `yaml:"video"`
201+
Image string `yaml:"image"`
202+
Audio string `yaml:"audio"`
199203
}
200204

201205
func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {

core/http/endpoints/openai/request.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/mudler/LocalAI/core/schema"
1313
"github.com/mudler/LocalAI/pkg/functions"
1414
"github.com/mudler/LocalAI/pkg/model"
15+
"github.com/mudler/LocalAI/pkg/templates"
1516
"github.com/mudler/LocalAI/pkg/utils"
1617
"github.com/rs/zerolog/log"
1718
)
@@ -168,8 +169,13 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
168169
continue CONTENT
169170
}
170171
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
172+
173+
t := "[vid-{{.ID}}]{{.Text}}"
174+
if config.TemplateConfig.Video != "" {
175+
t = config.TemplateConfig.Video
176+
}
171177
// set a placeholder for each video
172-
input.Messages[i].StringContent = fmt.Sprintf("[vid-%d]", vidIndex) + input.Messages[i].StringContent
178+
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent)
173179
vidIndex++
174180
case "audio_url", "audio":
175181
// Decode content as base64 either if it's an URL or base64 text
@@ -180,7 +186,11 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
180186
}
181187
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
182188
// set a placeholder for each audio
183-
input.Messages[i].StringContent = fmt.Sprintf("[audio-%d]", audioIndex) + input.Messages[i].StringContent
189+
t := "[audio-{{.ID}}]{{.Text}}"
190+
if config.TemplateConfig.Audio != "" {
191+
t = config.TemplateConfig.Audio
192+
}
193+
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent)
184194
audioIndex++
185195
case "image_url", "image":
186196
// Decode content as base64 either if it's an URL or base64 text
@@ -189,9 +199,14 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
189199
log.Error().Msgf("Failed encoding image: %s", err)
190200
continue CONTENT
191201
}
202+
203+
t := "[img-{{.ID}}]{{.Text}}"
204+
if config.TemplateConfig.Image != "" {
205+
t = config.TemplateConfig.Image
206+
}
192207
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
193208
// set a placeholder for each image
194-
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", imgIndex) + input.Messages[i].StringContent
209+
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent)
195210
imgIndex++
196211
}
197212
}

pkg/model/initializers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
314314

315315
client = NewModel(modelID, serverAddress, process)
316316
} else {
317-
log.Debug().Msg("external backend is uri")
317+
log.Debug().Msg("external backend is a uri")
318318
// address
319319
client = NewModel(modelID, uri, nil)
320320
}

pkg/templates/multimodal.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package templates
2+
3+
import (
4+
"bytes"
5+
"text/template"
6+
)
7+
8+
// TemplateMultiModal renders templateString with text/template and returns
// the rendered output.
//
// The template is executed against a value exposing two fields:
//   - .ID is templateID, the numeric identifier to embed in the placeholder.
//   - .Text is text, the content the placeholder is combined with.
//
// For example, "[img-{{.ID}}]{{.Text}}" with templateID 1 and text "bar"
// yields "[img-1]bar".
//
// A non-nil error is returned when templateString fails to parse or fails to
// execute. In both cases the returned string is empty.
func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
	tmpl, err := template.New("template").Parse(templateString)
	if err != nil {
		return "", err
	}

	data := struct {
		ID   int
		Text string
	}{
		ID:   templateID,
		Text: text,
	}

	var result bytes.Buffer
	if err := tmpl.Execute(&result, data); err != nil {
		// Execute may have written part of the output before failing. Drop
		// it so a truncated placeholder is never handed back to the caller.
		return "", err
	}
	return result.String(), nil
}

pkg/templates/multimodal_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package templates_test
2+
3+
import (
4+
. "github.com/mudler/LocalAI/pkg/templates" // Update with your module path
5+
6+
// Update with your module path
7+
. "github.com/onsi/ginkgo/v2"
8+
. "github.com/onsi/gomega"
9+
)
10+
11+
var _ = Describe("EvaluateTemplate", func() {
12+
Context("templating simple strings for multimodal chat", func() {
13+
It("should template messages correctly", func() {
14+
result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar")
15+
Expect(err).NotTo(HaveOccurred())
16+
Expect(result).To(Equal("[img-1]bar"))
17+
})
18+
})
19+
})

0 commit comments

Comments
 (0)