
Support for Jina Code model #292


Merged · 2 commits into huggingface:main · Jun 21, 2024

Conversation

patricebechard (Contributor)

What does this PR do?

We implement the modified architecture for Jina models found here

This allows us to run inference for jinaai/jina-embeddings-v2-base-code, a long context code embeddings model.

IMPORTANT CAVEAT
The repo currently selects which implementation to load using the model_type key found in a given model's config.json. The way the Jina models are currently loaded is a bit hackish: we use the Alibi position-encoding setting to route to the JinaBertModel architecture rather than the BertModel architecture.

As a result, any models that share the same model_type key (e.g. "model_type": "bert") and set "position_embedding_type": "alibi" get routed to the JinaBertModel architecture at startup.

To get around this, I locally changed the model_type of the normal Jina models (e.g. jina-embeddings-v2-base-en) to jina_bert and the model_type of the Jina code models to jina_code_bert.

This enables me to properly load the models locally, but it is obviously not a clean solution that would enable someone to plug in the embedding models without changing the config files.

A more robust way of loading the different model architectures will be necessary in order to properly support this new implementation.
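To make the caveat concrete, here is a minimal sketch of the routing behavior described above (plain Rust with invented names; this is not the repo's actual loading code):

```rust
// Hypothetical sketch of the model_type / position_embedding_type routing
// described in the caveat above. Names are illustrative, not the repo's API.
#[derive(Debug, PartialEq)]
enum Backend {
    Bert,
    JinaBert,
}

fn select_backend(model_type: &str, position_embedding_type: Option<&str>) -> Backend {
    match (model_type, position_embedding_type) {
        // The hack: a plain "bert" model_type is routed to JinaBertModel
        // whenever the config declares Alibi position embeddings.
        ("bert", Some("alibi")) => Backend::JinaBert,
        // The locally patched config values mentioned above.
        ("jina_bert", _) | ("jina_code_bert", _) => Backend::JinaBert,
        _ => Backend::Bert,
    }
}
```

Under a scheme like this, any non-Jina BERT checkpoint that happens to declare Alibi embeddings would also be routed to JinaBertModel, which is why a more robust dispatch mechanism is needed.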

Fixes #270

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene (Contributor) left a comment:

That's ultra cool thanks for the hard work! :)
I only have very small comments.

Regarding how to correctly select the model type, I'm not sure yet. Very annoying that they didn't create a new one AGAIN.

Comment on lines +106 to +112
let query_layer = self.query_linear.forward(hidden_states)?;
let query_layer = self.layer_norm_q.forward(&query_layer, None)?;

let key_layer = self.key_linear.forward(hidden_states)?;
let key_layer = self.layer_norm_k.forward(&key_layer, None)?;

let value_layer = self.value_linear.forward(hidden_states)?;

You should still merge qkv and do one fused linear.
It's faster to do one big linear instead of 3 smaller ones.
You can then split the result and apply the layer norm for q and k.
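As a rough sketch of the suggestion (plain Rust over flat row-major buffers rather than candle tensors; split_qkv is an invented helper), a single fused projection produces rows of width 3 * hidden that are then split back into q, k, and v:

```rust
// Illustrative sketch (not candle): split the output of one fused QKV linear,
// whose rows are 3 * hidden wide, back into separate q, k and v buffers.
fn split_qkv(fused: &[f32], hidden: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    assert_eq!(fused.len() % (3 * hidden), 0, "rows must be 3 * hidden wide");
    let (mut q, mut k, mut v) = (Vec::new(), Vec::new(), Vec::new());
    for row in fused.chunks_exact(3 * hidden) {
        q.extend_from_slice(&row[..hidden]);
        k.extend_from_slice(&row[hidden..2 * hidden]);
        v.extend_from_slice(&row[2 * hidden..]);
    }
    (q, k, v)
}
```

The q and k slices would then each go through their respective layer norm, exactly as in the original three-projection code, while paying for one large matmul instead of three small ones.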

Comment on lines +60 to +62
let query_linear = Linear::new(query_weight, Some(query_bias), None);
let key_linear = Linear::new(key_weight, Some(key_bias), None);
let value_linear = Linear::new(value_weight, Some(value_bias), None);

See other comment on merging

Comment on lines +119 to +121
let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;

Why are there transposes here?

Comment on lines +214 to +215
let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?;
let gated = hidden_states.i((.., .., self.intermediate_size..))?;

You can use narrow instead:

Suggested change
let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?;
let gated = hidden_states.i((.., .., self.intermediate_size..))?;
let non_gated = hidden_states.narrow(2, 0, self.intermediate_size)?;
let gated = hidden_states.narrow(2, self.intermediate_size, self.intermediate_size)?;

I'm a bit confused, why is hidden_states 3-dimensional here? Shouldn't it be 2-dimensional, with the narrow on dim 1?
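For context, the lines under discussion implement a GLU-style feed-forward: one wide projection of width 2 * intermediate_size is split in half, and one half gates the other elementwise. A minimal sketch in plain Rust (the glu and gelu helpers are invented for illustration, and which half receives the activation is an assumption to check against the actual Jina implementation):

```rust
// Illustrative GLU-style gating over a single row of width 2 * intermediate_size.
// Assumption: the second half is passed through GELU and multiplies the first.
fn glu(hidden: &[f32], intermediate_size: usize) -> Vec<f32> {
    assert_eq!(hidden.len(), 2 * intermediate_size);
    let non_gated = &hidden[..intermediate_size]; // narrow(d, 0, intermediate_size)
    let gated = &hidden[intermediate_size..];     // narrow(d, intermediate_size, intermediate_size)
    non_gated
        .iter()
        .zip(gated)
        .map(|(x, g)| x * gelu(*g))
        .collect()
}

fn gelu(x: f32) -> f32 {
    // tanh approximation of GELU
    0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}
```

Whichever dimension the split happens on, narrow(dim, start, len) expresses the same two slices without the range-index sugar.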

Comment on lines +191 to +197
let query_layer = self.query_linear.forward(hidden_states)?;
let query_layer = self.layer_norm_q.forward(&query_layer, None)?;

let key_layer = self.key_linear.forward(hidden_states)?;
let key_layer = self.layer_norm_k.forward(&key_layer, None)?;

let value_layer = self.value_linear.forward(hidden_states)?;

Same.

Comment on lines +357 to +358
let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?;
let gated = hidden_states.i((.., .., self.intermediate_size..))?;

Same.

@OlivierDehaene

I'm not sure why we are dim3 in the flash implementation though. We should always be dim2 (concatenated ids, hidden_dim). Can you check if there is an added dim somewhere for some reason?

@OlivierDehaene

I will merge this branch in a dev branch and modify the few nitpicks above to make this PR part of today's release.
I hope you don't mind. Of course you will be credited for the PR!
Thanks again for the hard work!

@OlivierDehaene OlivierDehaene merged commit b8f6c78 into huggingface:main Jun 21, 2024
@patricebechard (Contributor, Author)

> I will merge this branch in a dev branch and modify the few nitpicks above to make this PR part of today's release.
>
> I hope you don't mind. Of course you will be credited for the PR!
>
> Thanks again for the hard work!

Not a problem at all. Thanks for helping!

MasakiMu319 pushed a commit to MasakiMu319/text-embeddings-inference that referenced this pull request Nov 27, 2024
aagnone3 pushed a commit to StratisLLC/hf-text-embeddings-inference that referenced this pull request Dec 11, 2024

Successfully merging this pull request may close these issues.

Support for jinaai/jina-embeddings-v2-base-code