
feat: refactor main_ds.py (1/n) Model class #572


Open
wants to merge 1 commit into main

Conversation

cdoern
Contributor

@cdoern cdoern commented May 27, 2025

Introduce a new design for key components of main_ds.py, namely splitting Model initialization, Accelerator initialization, Optimizer initialization, and Checkpoint-saving initialization into separate classes. This commit introduces the Model class.

NOTE: a follow-up to this work will introduce classes/structure for the DataLoader, Sampler, etc. This was left out of this PR given the already large scope of the change.

The Model class wraps the various AutoModel classes we support and aims to be a lightweight wrapper that makes the library easier to use with different model types. setup_optimizer resides within the Model class and returns one of the optimizer types we support.

These classes are one of a few steps needed to "SDK-ify" the training library.

Adding structure to code via classes can be either someone's favorite or least favorite thing, so I figured I'd explain myself before continuing. Here is my rationale:

Classes provide logical structure to code, especially code meant to be a publicly consumable SDK, and allow you to associate related objects and methods with one another.

Being able to group functionality under the Model, Accelerator, and Checkpointer classes inherently reduces code complexity and duplication. Storing values like self.distributed_framework, self.lora_config, etc. on the class, so that they are accessible from its different methods, drastically reduces the number of arguments each method takes and eliminates complex return values. Simpler methods and simpler arguments/return values make the code easier to test.
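To make that shape concrete, here is a rough sketch of the idea (illustrative only: apart from the names mentioned above, the constructor fields and method bodies are assumptions, not code from this PR):

class Model:
    def __init__(
        self,
        model_path: str,
        model_type: ModelTypes,
        distributed_framework: DistributedBackend,
        lora_config=None,
        noise_alpha: float | None = None,
    ):
        # Stored once on the instance, so later methods don't need to
        # take them as parameters or thread them through return values.
        self.model_type = model_type
        self.distributed_framework = distributed_framework
        self.lora_config = lora_config
        self.noise_alpha = noise_alpha
        self.model = self._load_model(model_path)  # hypothetical loader

    def setup_optimizer(self, learning_rate: float):
        # Reads self.distributed_framework / self.lora_config from the
        # instance, keeping the signature small.
        ...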

@mergify mergify bot added the testing and ci-failure labels and removed the ci-failure label on May 27, 2025
class ModelTypes(Enum):
    LIGER = "Liger"
    CAUSALLM = "Causallm"
    DOLOMITE = "Dolomite"
@RobotSail (Member)

We've dropped dolomite, no need to include this.

Contributor

@RobotSail Interesting! What does it mean exactly? If I grep through the code, I still see hits for dolomite, including the mandatory dependency on instructlab-dolomite. Was some decision made to drop it? Should we clean these remnants from the tree then?


mergify bot commented May 28, 2025

This pull request has merge conflicts that must be resolved before it can be merged. @cdoern please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Contributor

@booxter booxter left a comment


I haven't reviewed the tests or the Accelerator class in detail, and I need to step off this PR, so I'm posting the questions and concerns I have collected so far.

parser.add_argument(
    "--model-class",
    type=str,
    default=ModelTypes.CAUSALLM.value,
Contributor

nit: you can use choices=[x.value for x in ModelTypes] to avoid listing them below
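A minimal sketch of that suggestion, using the ModelTypes enum from this PR (the surrounding parser setup is abbreviated):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-class",
    type=str,
    default=ModelTypes.CAUSALLM.value,
    # derive the accepted values from the enum instead of hard-coding them
    choices=[x.value for x in ModelTypes],
)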

@@ -141,6 +141,19 @@ class FSDPOptions(BaseModel):
    sharding_strategy: ShardingStrategies = ShardingStrategies.HYBRID_SHARD


class Optimizers(Enum):
Contributor

(No action required, observation) I think it's more common to name enums in the singular, not the plural. But it's a matter of habit, of course.

# public API
class ModelTypes(Enum):
    LIGER = "Liger"
    CAUSALLM = "Causallm"
Contributor

should we use "correct" case? CausalLM?

    from deepspeed.ops.adam import DeepSpeedCPUAdam
except ImportError:
    DeepSpeedCPUAdam = None
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
Contributor

(No action required) I know it was done in main_ds so you are not introducing anything new here, but consider not running code / issuing warnings when importing the module. An import should not, generally, produce side effects of this sort, especially in a library. Consider warning later when the missing class is actually referred to / used.
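A sketch of the deferred-failure pattern being suggested (the get_cpu_adam helper is illustrative, not from the PR):

try:
    from deepspeed.ops.adam import DeepSpeedCPUAdam
except ImportError:
    DeepSpeedCPUAdam = None  # noted, but no warning printed at import time


def get_cpu_adam(params, **kwargs):
    # Complain only when the optimizer is actually requested.
    if DeepSpeedCPUAdam is None:
        raise RuntimeError(
            "DeepSpeedCPUAdam requires the 'deepspeed' package to be installed."
        )
    return DeepSpeedCPUAdam(params, **kwargs)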

    output_dir: str,
    distributed_framework: DistributedBackend,
    model_type: ModelTypes,
    noise_alpha: Optional[float],
Contributor

nit: use type | None instead of Optional
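Applied to the parameters above (sketched as a constructor for concreteness; the enclosing function isn't shown in the diff):

def __init__(
    self,
    output_dir: str,
    distributed_framework: DistributedBackend,
    model_type: ModelTypes,
    noise_alpha: float | None,  # PEP 604 syntax, Python 3.10+
):
    ...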

)
self.model.config.eos_token_id = self.tokenizer.eos_token_id

if "ForCausalLM" not in self.model.__class__.__name__:
Contributor

this is fragile; can you think of a more robust way of checking it? If not, maybe the Model class could have a helper method to hide the check?
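One possible shape for that helper (a sketch; the underlying check is unchanged, just hidden behind a method so callers don't depend on the class-naming convention):

def is_causal_lm(self) -> bool:
    # transformers names its causal-LM heads "*ForCausalLM", so the
    # class-name check works today; encapsulating it here means callers
    # won't break if a more robust signal is adopted later.
    return "ForCausalLM" in self.model.__class__.__name__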

from .utils import add_noisy_embeddings, convert_loss_to_reduce_sum

self.model = convert_loss_to_reduce_sum(
    self.model, use_dolomite=(self.model_type == "dolomite")
Contributor

incorrect enum == str check: self.model_type is a ModelTypes member, so comparing it against the string "dolomite" is always False
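A sketch of the fix, assuming self.model_type holds a ModelTypes member:

self.model = convert_loss_to_reduce_sum(
    self.model,
    # compare enum to enum; an Enum member never equals a bare string
    use_dolomite=(self.model_type == ModelTypes.DOLOMITE),
)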

)
self.model = add_noisy_embeddings(self.model, noise_alpha=self.noise_alpha)

@staticmethod
Contributor

remove these two functions from utils.py then?

"""Check if a GPU supports FlashAttention."""
major, minor = torch.cuda.get_device_capability(device_id)
# Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
is_sm8x = major == 8 and minor >= 0
Contributor

(No action required) Could be:

if ...:
    return True
if ...:
    return True
if ...:
    return True
return False
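Filled in against the capability check above, that early-return style might look like this (a sketch; the SM thresholds are taken from the comment in the original code):

import torch


def supports_flash_attention(device_id: int = 0) -> bool:
    """Check if a GPU supports FlashAttention."""
    major, _minor = torch.cuda.get_device_capability(device_id)
    if major == 8:  # Ampere (SM 8.x)
        return True
    if major >= 9:  # SM 9.0 or newer
        return True
    return False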

@@ -692,15 +503,35 @@ def main(args):
extra={"hparams": True},
)

-    model, lr_scheduler, optimizer, accelerator = setup_model(
-        args, tokenizer, train_loader, grad_accum, flash_enabled
+    accelerator = setup_accelerator(args, m, grad_accum)
Contributor

don't you mean to use the Accelerator class here? I think this belongs to the setup_accelerator module, which is replaced by the Accelerator class in this PR. (btw please also remove the old code.)

@cdoern
Contributor Author

cdoern commented May 30, 2025

@booxter thanks for the review. I actually meant to remove Accelerator in this PR, which is why there is a confusing non-usage of that class. I intend to introduce it in a 2/n PR, just for clarity.

In regard to most of the other comments: a lot of them are inherited from the existing code or are mis-steps I made when splitting out my mega PR (I forgot to take my changes from utils.py, for example). Will take another pass here. Thanks!
