
Commit 32854c0

andreyvelich and kannon92 authored and committed
KEP-2170: Create model and dataset initializers (kubeflow#2303)
* KEP-2170: Create model and dataset initializers
  Signed-off-by: Andrey Velichkevich <[email protected]>

* Add abstract classes
  Signed-off-by: Andrey Velichkevich <[email protected]>

* Add storage URI to config
  Signed-off-by: Andrey Velichkevich <[email protected]>

* Update .gitignore
  Co-authored-by: Kevin Hannon <[email protected]>
  Signed-off-by: Andrey Velichkevich <[email protected]>

* Fix the misspelling for initializer
  Signed-off-by: Andrey Velichkevich <[email protected]>

* Add .pt and .pth to ignore_patterns
  Signed-off-by: Andrey Velichkevich <[email protected]>

---------

Signed-off-by: Andrey Velichkevich <[email protected]>
Co-authored-by: Kevin Hannon <[email protected]>
Signed-off-by: sailesh duddupudi <[email protected]>
1 parent 82d0535 commit 32854c0

File tree

14 files changed: +246 -2 lines changed

.github/workflows/publish-core-images.yaml

Lines changed: 8 additions & 0 deletions

@@ -30,6 +30,14 @@ jobs:
           dockerfile: cmd/training-operator.v2alpha1/Dockerfile
           platforms: linux/amd64,linux/arm64,linux/ppc64le
           tag-prefix: v2alpha1
+        - component-name: model-initializer-v2
+          dockerfile: cmd/initializer_v2/model/Dockerfile
+          platforms: linux/amd64,linux/arm64
+          tag-prefix: v2
+        - component-name: dataset-initializer-v2
+          dockerfile: cmd/initializer_v2/dataset/Dockerfile
+          platforms: linux/amd64,linux/arm64
+          tag-prefix: v2
         - component-name: kubectl-delivery
           dockerfile: build/images/kubectl-delivery/Dockerfile
           platforms: linux/amd64,linux/arm64,linux/ppc64le

.gitignore

Lines changed: 2 additions & 2 deletions

@@ -11,8 +11,8 @@ cover.out
 .vscode/
 __debug_bin
 
-# Compiled python files.
-*.pyc
+# Python cache files
+__pycache__/
 
 # Emacs temporary files
 *~

cmd/initializer_v2/dataset/Dockerfile

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+FROM python:3.11-alpine
+
+WORKDIR /workspace
+
+# Copy the required Python modules.
+COPY cmd/initializer_v2/dataset/requirements.txt .
+COPY sdk/python/kubeflow sdk/python/kubeflow
+COPY pkg/initializer_v2 pkg/initializer_v2
+
+# Install the needed packages.
+RUN pip install -r requirements.txt
+
+ENTRYPOINT ["python", "-m", "pkg.initializer_v2.dataset"]

cmd/initializer_v2/dataset/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+huggingface_hub==0.23.4

cmd/initializer_v2/model/Dockerfile

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+FROM python:3.11-alpine
+
+WORKDIR /workspace
+
+# Copy the required Python modules.
+COPY cmd/initializer_v2/model/requirements.txt .
+COPY sdk/python/kubeflow sdk/python/kubeflow
+COPY pkg/initializer_v2 pkg/initializer_v2
+
+# Install the needed packages.
+RUN pip install -r requirements.txt
+
+ENTRYPOINT ["python", "-m", "pkg.initializer_v2.model"]

cmd/initializer_v2/model/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+huggingface_hub==0.23.4

pkg/initializer_v2/dataset/__main__.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+import logging
+import os
+from urllib.parse import urlparse
+
+import pkg.initializer_v2.utils.utils as utils
+from pkg.initializer_v2.dataset.huggingface import HuggingFace
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%dT%H:%M:%SZ",
+    level=logging.INFO,
+)
+
+if __name__ == "__main__":
+    logging.info("Starting dataset initialization")
+
+    try:
+        storage_uri = os.environ[utils.STORAGE_URI_ENV]
+    except Exception as e:
+        logging.error("STORAGE_URI env variable must be set.")
+        raise e
+
+    match urlparse(storage_uri).scheme:
+        # TODO (andreyvelich): Implement more dataset providers.
+        case utils.HF_SCHEME:
+            hf = HuggingFace()
+            hf.load_config()
+            hf.download_dataset()
+        case _:
+            logging.error("STORAGE_URI must have the valid dataset provider")
+            raise Exception
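
The entrypoint above dispatches on the scheme of the STORAGE_URI environment variable, and the provider later joins netloc and path into a HuggingFace repo id. A minimal sketch of that parsing, using an assumed hf:// URI that is not part of the commit:

from urllib.parse import urlparse

# Assumed example URI; the initializer reads the real value from the STORAGE_URI env variable.
storage_uri = "hf://username/my-dataset"

parsed = urlparse(storage_uri)
print(parsed.scheme)                # "hf", which matches utils.HF_SCHEME
print(parsed.netloc + parsed.path)  # "username/my-dataset", the HuggingFace repo id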

pkg/initializer_v2/dataset/config.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+from typing import Optional
+
+
+# TODO (andreyvelich): This should be moved under Training V2 SDK.
+@dataclass
+class HuggingFaceDatasetConfig:
+    storage_uri: str
+    access_token: Optional[str] = None

pkg/initializer_v2/dataset/huggingface.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+import logging
+from urllib.parse import urlparse
+
+import huggingface_hub
+
+import pkg.initializer_v2.utils.utils as utils
+
+# TODO (andreyvelich): This should be moved to SDK V2 constants.
+import sdk.python.kubeflow.storage_initializer.constants as constants
+from pkg.initializer_v2.dataset.config import HuggingFaceDatasetConfig
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%dT%H:%M:%SZ",
+    level=logging.INFO,
+)
+
+
+class HuggingFace(utils.DatasetProvider):
+
+    def load_config(self):
+        config_dict = utils.get_config_from_env(HuggingFaceDatasetConfig)
+        logging.info(f"Config for HuggingFace dataset initializer: {config_dict}")
+        self.config = HuggingFaceDatasetConfig(**config_dict)
+
+    def download_dataset(self):
+        storage_uri_parsed = urlparse(self.config.storage_uri)
+        dataset_uri = storage_uri_parsed.netloc + storage_uri_parsed.path
+
+        logging.info(f"Downloading dataset: {dataset_uri}")
+        logging.info("-" * 40)
+
+        if self.config.access_token:
+            huggingface_hub.login(self.config.access_token)
+
+        huggingface_hub.snapshot_download(
+            repo_id=dataset_uri,
+            repo_type="dataset",
+            local_dir=constants.VOLUME_PATH_DATASET,
+        )
+
+        logging.info("Dataset has been downloaded")
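
Because the provider reads its config through get_config_from_env, the dataclass fields map to the STORAGE_URI and ACCESS_TOKEN environment variables. A minimal sketch of driving the provider directly, assuming the repository root is on PYTHONPATH; the URI and token values are placeholders, not from the commit:

import os

from pkg.initializer_v2.dataset.huggingface import HuggingFace

# Placeholder values; in the initializer container they arrive as env variables.
os.environ["STORAGE_URI"] = "hf://username/my-dataset"
os.environ["ACCESS_TOKEN"] = "hf_xxx"  # optional, only needed for private or gated datasets

hf = HuggingFace()
hf.load_config()       # builds HuggingFaceDatasetConfig from STORAGE_URI / ACCESS_TOKEN
hf.download_dataset()  # snapshot_download into constants.VOLUME_PATH_DATASET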

pkg/initializer_v2/model/__main__.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import logging
+import os
+from urllib.parse import urlparse
+
+import pkg.initializer_v2.utils.utils as utils
+from pkg.initializer_v2.model.huggingface import HuggingFace
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%dT%H:%M:%SZ",
+    level=logging.INFO,
+)
+
+if __name__ == "__main__":
+    logging.info("Starting pre-trained model initialization")
+
+    try:
+        storage_uri = os.environ[utils.STORAGE_URI_ENV]
+    except Exception as e:
+        logging.error("STORAGE_URI env variable must be set.")
+        raise e
+
+    match urlparse(storage_uri).scheme:
+        # TODO (andreyvelich): Implement more model providers.
+        case utils.HF_SCHEME:
+            hf = HuggingFace()
+            hf.load_config()
+            hf.download_model()
+        case _:
+            logging.error(
+                f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
+            )
+            raise Exception
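
The model entrypoint mirrors the dataset one. One way to exercise its main guard outside the container is runpy; a sketch assuming the repository root is on PYTHONPATH, with placeholder env values that are not part of the commit:

import os
import runpy

# Placeholder values; the real ones come from the initializer's environment.
os.environ["STORAGE_URI"] = "hf://username/my-model"
os.environ["ACCESS_TOKEN"] = "hf_xxx"  # only needed for gated or private models

# run_name="__main__" makes the module's `if __name__ == "__main__"` block execute.
runpy.run_module("pkg.initializer_v2.model", run_name="__main__")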

pkg/initializer_v2/model/config.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+from typing import Optional
+
+
+# TODO (andreyvelich): This should be moved under Training V2 SDK.
+@dataclass
+class HuggingFaceModelInputConfig:
+    storage_uri: str
+    access_token: Optional[str] = None

pkg/initializer_v2/model/huggingface.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+import logging
+from urllib.parse import urlparse
+
+import huggingface_hub
+
+import pkg.initializer_v2.utils.utils as utils
+
+# TODO (andreyvelich): This should be moved to SDK V2 constants.
+import sdk.python.kubeflow.storage_initializer.constants as constants
+from pkg.initializer_v2.model.config import HuggingFaceModelInputConfig
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%dT%H:%M:%SZ",
+    level=logging.INFO,
+)
+
+
+class HuggingFace(utils.ModelProvider):
+
+    def load_config(self):
+        config_dict = utils.get_config_from_env(HuggingFaceModelInputConfig)
+        logging.info(f"Config for HuggingFace model initializer: {config_dict}")
+        self.config = HuggingFaceModelInputConfig(**config_dict)
+
+    def download_model(self):
+        storage_uri_parsed = urlparse(self.config.storage_uri)
+        model_uri = storage_uri_parsed.netloc + storage_uri_parsed.path
+
+        logging.info(f"Downloading model: {model_uri}")
+        logging.info("-" * 40)
+
+        if self.config.access_token:
+            huggingface_hub.login(self.config.access_token)
+
+        # TODO (andreyvelich): We should consider to follow vLLM approach with allow patterns.
+        # Ref: https://github.com/kubeflow/training-operator/pull/2303#discussion_r1815913663
+        # TODO (andreyvelich): We should update patterns for Mistral model
+        # Ref: https://github.com/kubeflow/training-operator/pull/2303#discussion_r1815914270
+        huggingface_hub.snapshot_download(
+            repo_id=model_uri,
+            local_dir=constants.VOLUME_PATH_MODEL,
+            allow_patterns=["*.json", "*.safetensors", "*.model"],
+            ignore_patterns=["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"],
+        )
+
+        logging.info("Model has been downloaded")
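
The allow/ignore patterns restrict the download to config, safetensors weights, and tokenizer files while skipping alternative weight formats. As a rough approximation of that filtering (huggingface_hub applies its own matching internally; this sketch only illustrates the allow list against made-up filenames):

from fnmatch import fnmatch

allow_patterns = ["*.json", "*.safetensors", "*.model"]

# Made-up file list to show which entries the allow list would keep.
files = [
    "config.json",
    "model-00001-of-00002.safetensors",
    "tokenizer.model",
    "pytorch_model.bin",
    "flax_model.msgpack",
]
kept = [f for f in files if any(fnmatch(f, p) for p in allow_patterns)]
print(kept)  # ['config.json', 'model-00001-of-00002.safetensors', 'tokenizer.model']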

pkg/initializer_v2/utils/__init__.py

Whitespace-only changes.

pkg/initializer_v2/utils/utils.py

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+import os
+from abc import ABC, abstractmethod
+from dataclasses import fields
+from typing import Dict
+
+STORAGE_URI_ENV = "STORAGE_URI"
+HF_SCHEME = "hf"
+
+
+class ModelProvider(ABC):
+    @abstractmethod
+    def load_config(self):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def download_model(self):
+        raise NotImplementedError()
+
+
+class DatasetProvider(ABC):
+    @abstractmethod
+    def load_config(self):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def download_dataset(self):
+        raise NotImplementedError()
+
+
+# Get DataClass config from the environment variables.
+# Env names must be equal to the DataClass parameters.
+def get_config_from_env(config) -> Dict[str, str]:
+    config_from_env = {}
+    for field in fields(config):
+        config_from_env[field.name] = os.getenv(field.name.upper())
+
+    return config_from_env
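
get_config_from_env uppercases each dataclass field name and reads the matching environment variable, returning None for anything unset. A small sketch with a hypothetical config class (ExampleConfig is not part of the commit), assuming the repository root is on PYTHONPATH:

import os
from dataclasses import dataclass
from typing import Optional

from pkg.initializer_v2.utils.utils import get_config_from_env

# Hypothetical dataclass used only to show the env-to-field mapping.
@dataclass
class ExampleConfig:
    storage_uri: str
    access_token: Optional[str] = None

os.environ["STORAGE_URI"] = "hf://username/my-dataset"
# ACCESS_TOKEN is left unset on purpose.

print(get_config_from_env(ExampleConfig))
# {'storage_uri': 'hf://username/my-dataset', 'access_token': None}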
