
[Feat] Support InternVL sft training in xtuner lite #1011


Closed · wants to merge 22 commits

Conversation


@C1rN09 C1rN09 commented Mar 21, 2025

[Feat] Enable SFT Training for InternVL Models in XTuner Lite

This PR introduces foundational support for Supervised Fine-Tuning (SFT) of InternVL models in XTuner Lite, implementing a basic parallel training strategy while maintaining compatibility with the latest XTuner development branch. The implementation prioritizes correctness and simplicity, establishing a baseline for future optimizations (e.g., advanced parallel strategies, PyTorch 2.0 compilation).

Implementation Highlights

  • Parallel Strategy Framework:
    • Language Model (LLaMA/Qwen): Utilizes existing FSDP2 + TP + SP + torch.compile paradigms
    • Vision Transformer (ViT): Implements pure FSDP2 with input chunking and output all-gather (see the sketch after this list)
  • System Integration:
    • Updated InternVL2 dataset adaption for XTuner development branch
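
A conceptual sketch of the ViT path under this scheme. The function, tensor shapes, and the use of the differentiable all_gather are illustrative assumptions about what "input chunking and output all-gather" means here, not the PR's actual code:

import torch
import torch.distributed as dist
from torch.distributed.nn.functional import all_gather  # autograd-aware all-gather

def vit_forward_chunked(vit, all_tiles: torch.Tensor) -> torch.Tensor:
    # `all_tiles` holds the image tiles of the whole data-parallel batch, replicated on
    # every rank and assumed to be evenly divisible by the world size.
    world_size, rank = dist.get_world_size(), dist.get_rank()

    # Input chunking: each rank runs the FSDP2-sharded ViT only on its own slice of tiles.
    local_feats = vit(all_tiles.chunk(world_size, dim=0)[rank])

    # Output all-gather: every rank recovers the vision features of the full batch so it
    # can splice them into its own language-model inputs; the differentiable all_gather
    # keeps gradients flowing back to the locally computed chunk.
    return torch.cat(all_gather(local_feats), dim=0)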

Integration Example

internvl = AutoPatch.from_causal_lm(
    internvl,
    fsdp_config=FSDPConfig(
        tp_size=args.tp_size,
        sp_size=1  # SP parallelism currently constrained
    ),
)
internvl.fully_shard()

Compatibility Considerations

  1. API Modifications:

    • Added optional module2name and checkpoint_loader parameters to fully_shard
    • Maintains backward compatibility through default None values
  2. Embedding Layer Handling:

    • Automatic resizing to world_size multiples during training
    • Original dimensions restored pre-checkpointing
    • Note: Potential accuracy impact when the original vocab size is not a multiple of world_size (see the sketch after this list)
  3. Embedding Layer Sharding:

    • Implemented separate FSDPParamGroup for language model embeddings
    • Observed accuracy degradation in specific settings (PyTorch 2.7 with inter-node communication, i.e. >= 2 nodes); PyTorch 2.5.1 works fine
    • Further investigation required
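
A hedged sketch of the vocab-padding behaviour described in item 2 (Embedding Layer Handling). The helper names are illustrative; `resize_token_embeddings` is the standard Hugging Face API, and the PR's internal mechanics may differ:

import math
import torch.distributed as dist

def pad_vocab_to_world_size(model, original_vocab_size: int) -> int:
    # Round the vocabulary up to the next multiple of the world size before training.
    world_size = dist.get_world_size()
    padded = math.ceil(original_vocab_size / world_size) * world_size
    if padded != original_vocab_size:
        # The extra rows are never indexed by real tokens, but padding can still have a
        # small accuracy impact, as noted above.
        model.resize_token_embeddings(padded)
    return padded

def restore_vocab_before_save(model, original_vocab_size: int) -> None:
    # Shrink back to the original vocabulary before checkpointing so the saved weights
    # match the official model config.
    model.resize_token_embeddings(original_vocab_size)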


vision_model = self.patched_model.vision_model
# compiled_layers: List[nn.Module] = []
for layer_idx, layer in enumerate(vision_model.encoder.layers):
Collaborator

It feels like the ViT and the LLM could use different sharding rules and parameters, but right now they are forced to share a single fsdp_config, which is not very user-friendly.

Author

This approach was chosen for interface consistency. For now, though, the ViT part is implemented with pure FSDP only.

vision_model.embeddings.apply(param_init_fn)
self.patched_model.mlp1.apply(param_init_fn)

fully_shard(
Collaborator

So TP is not supported for the ViT yet, right?

Author

@C1rN09 C1rN09 Mar 28, 2025

Not supported. There are three main blockers:

  1. PyTorch's DTensor implementation for Convolution layers has issues that are not fixed at least as of PyTorch 2.5.1
  2. The qkv Linear in InternVL attention produces outputs laid out as (3, num_head, head_dim). Directly applying ColwiseParallel, i.e. a Shard(1) on qkv, yields outputs in an unexpected layout, so the attention computation would need to be patched (see the sketch below)
  3. The seq_len of InternVL's attention inputs/outputs is patch_size^2 + 1, an odd number that is not divisible by tp_size. Without manual padding, a PyTorch bug causes a runtime error; with manual padding, many patch functions would be needed to thread cu_seq_lens and related arguments through every layer

Since the ViT part is generally small, the speed impact of its parallel strategy is limited; to keep this PR correct and minimal, the ViT is therefore sharded with pure FSDP only. Features such as tp/sp can be added later.
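
A single-process illustration of blocker 2, using toy sizes. This is an assumption-level sketch of the layout problem, not InternVL's or the PR's code:

import torch

hidden, num_heads, tp_size = 64, 8, 2
head_dim = hidden // num_heads
qkv = torch.nn.Linear(hidden, 3 * hidden, bias=False)

x = torch.randn(1, 5, hidden)                  # (batch, seq_len, hidden)
fused = qkv(x)                                 # (1, 5, 3 * hidden)
# InternVL-style view of the fused projection: (..., 3, num_heads, head_dim).
q, k, v = fused.view(1, 5, 3, num_heads, head_dim).unbind(dim=2)

# Under ColwiseParallel, each TP rank would only hold a contiguous slice of the
# 3 * hidden output dimension, e.g. rank 0 sees columns [0, 3 * hidden // tp_size).
rank0_local = fused[..., : 3 * hidden // tp_size]
# That slice crosses the q/k boundary, so reshaping the local output into
# (..., 3, local_heads, head_dim) no longer lines up with q, k and v, which is why the
# attention code would need patching before TP can be enabled for the ViT.
print(q.shape, k.shape, v.shape, rank0_local.shape)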

return_dict: Optional[bool] = None,
tp_mesh: Optional[DeviceMesh] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
Collaborator

Please support SP.

Author

Not planned for this PR. Currently, an AssertionError is raised in the forward pass to remind the user.

@C1rN09 C1rN09 force-pushed the xtuner-support-vl branch 4 times, most recently from 0862097 to 7fa885d Compare March 31, 2025 08:56
C1rN09 and others added 21 commits April 16, 2025 16:30
…hard` arguments

In some cases, model checkpoints are stored in an unofficial manner. For example, VLMs use a language model as part of the backbone, but the checkpoint load path and state-key mapping have changed. In order to reuse the language model dispatch code, we have to allow a custom checkpoint loading strategy.

In many scenarios (e.g. VLMs), the language model's embedding layer is used outside of its `forward` call, so it is better to shard it in a separate FSDP unit.

When training InternVL with a fixed image size, position embedding interpolation is usually unnecessary and should be bypassed. Moreover, `F.interpolate` doesn't support DTensor ops as of torch 2.5.1.
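
A minimal sketch of the bypass described in that commit message; the helper name and tensor layout are assumptions, not the PR's or InternVL's actual code:

import torch
import torch.nn.functional as F

def get_patch_pos_embed(patch_pos: torch.Tensor, H: int, W: int) -> torch.Tensor:
    # patch_pos: (1, num_patches, dim), with the class-token row already stripped.
    if patch_pos.shape[1] == H * W:
        # Fixed image size: the stored table already matches the input resolution,
        # so skip F.interpolate entirely (it also fails on DTensor as of torch 2.5.1).
        return patch_pos
    side = int(patch_pos.shape[1] ** 0.5)
    grid = patch_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(H, W), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, H * W, -1)
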
@C1rN09 C1rN09 force-pushed the xtuner-support-vl branch from 80014cb to 877a8d1 Compare April 16, 2025 08:40
@hhaAndroid hhaAndroid closed this Apr 16, 2025