Skip to content

Commit a1e2fad

Browse files
authored
Merge pull request #1435 from containers/multimodal
Don't use jinja in the multimodal case
2 parents e041cc8 + 98e15e4 commit a1e2fad

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

ramalama/model.py

Lines changed: 29 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -322,12 +322,16 @@ def setup_mounts(self, model_path, args):
322322
if self.store is not None:
323323
_, tag, _ = self.extract_model_identifiers()
324324
ref_file = self.store.get_ref_file(tag)
325-
if ref_file is not None and ref_file.chat_template_name != "":
326-
chat_template_path = self.store.get_snapshot_file_path(ref_file.hash, ref_file.chat_template_name)
327-
self.engine.add([f"--mount=type=bind,src={chat_template_path},destination={MNT_CHAT_TEMPLATE_FILE},ro"])
328-
if ref_file is not None and ref_file.mmproj_name != "":
329-
mmproj_path = self.store.get_snapshot_file_path(ref_file.hash, ref_file.mmproj_name)
330-
self.engine.add([f"--mount=type=bind,src={mmproj_path},destination={MNT_MMPROJ_FILE},ro"])
325+
if ref_file is not None:
326+
if ref_file.chat_template_name != "":
327+
chat_template_path = self.store.get_snapshot_file_path(ref_file.hash, ref_file.chat_template_name)
328+
self.engine.add(
329+
[f"--mount=type=bind,src={chat_template_path},destination={MNT_CHAT_TEMPLATE_FILE},ro"]
330+
)
331+
332+
if ref_file.mmproj_name != "":
333+
mmproj_path = self.store.get_snapshot_file_path(ref_file.hash, ref_file.mmproj_name)
334+
self.engine.add([f"--mount=type=bind,src={mmproj_path},destination={MNT_MMPROJ_FILE},ro"])
331335

332336
def handle_rag_mode(self, args, cmd_args):
333337
# force accel_image to use -rag version. Drop TAG if it exists
@@ -523,19 +527,23 @@ def build_exec_args_serve(self, args, exec_model_path, chat_template_path="", mm
523527
exec_args += ["llama-server", "--port", args.port, "--model", exec_model_path]
524528
if mmproj_path:
525529
exec_args += ["--mmproj", mmproj_path]
530+
else:
531+
exec_args += ["--jinja"]
532+
526533
exec_args += [
527534
"--alias",
528535
self.model,
529536
"--ctx-size",
530537
f"{args.context}",
531538
"--temp",
532539
f"{args.temp}",
533-
"--jinja",
534540
"--cache-reuse",
535541
"256",
536542
] + args.runtime_args
543+
537544
if draft_model_path:
538545
exec_args += ['--model_draft', draft_model_path]
546+
539547
# Placeholder for clustering, it might be kept for override
540548
rpc_nodes = os.getenv("RAMALAMA_LLAMACPP_RPC_NODES")
541549
if rpc_nodes:
@@ -609,35 +617,34 @@ def execute_command(self, model_path, exec_args, args):
609617
def serve(self, args, quiet=False):
610618
self.validate_args(args)
611619
args.port = compute_serving_port(args.port, args.debug, quiet)
612-
613620
model_path = self.get_model_path(args)
614621
if is_split_file_model(model_path):
615622
mnt_file = MNT_DIR + '/' + self.mnt_path
616623
else:
617624
mnt_file = MNT_FILE
618625

619626
exec_model_path = mnt_file if args.container or args.generate else model_path
620-
621627
chat_template_path = ""
622628
mmproj_path = ""
623629
if self.store is not None:
624630
_, tag, _ = self.extract_model_identifiers()
625631
ref_file = self.store.get_ref_file(tag)
626-
if ref_file is not None and ref_file.chat_template_name != "":
627-
chat_template_path = (
628-
MNT_CHAT_TEMPLATE_FILE
629-
if args.container or args.generate
630-
else self.store.get_snapshot_file_path(ref_file.hash, ref_file.chat_template_name)
631-
)
632-
if ref_file is not None and ref_file.mmproj_name != "":
633-
mmproj_path = (
634-
MNT_MMPROJ_FILE
635-
if args.container or args.generate
636-
else self.store.get_snapshot_file_path(ref_file.hash, ref_file.mmproj_name)
637-
)
632+
if ref_file is not None:
633+
if ref_file.chat_template_name != "":
634+
chat_template_path = (
635+
MNT_CHAT_TEMPLATE_FILE
636+
if args.container or args.generate
637+
else self.store.get_snapshot_file_path(ref_file.hash, ref_file.chat_template_name)
638+
)
639+
640+
if ref_file.mmproj_name != "":
641+
mmproj_path = (
642+
MNT_MMPROJ_FILE
643+
if args.container or args.generate
644+
else self.store.get_snapshot_file_path(ref_file.hash, ref_file.mmproj_name)
645+
)
638646

639647
exec_args = self.build_exec_args_serve(args, exec_model_path, chat_template_path, mmproj_path)
640-
641648
exec_args = self.handle_runtime(args, exec_args, exec_model_path)
642649
if self.generate_container_config(model_path, chat_template_path, args, exec_args):
643650
return

0 commit comments

Comments (0)