Support for Jina Code model #292
Conversation
That's ultra cool, thanks for the hard work! :)
I only have very small comments.
Regarding how to correctly select the model type, I'm not sure yet. It's very annoying that they didn't create a new one AGAIN.
let query_layer = self.query_linear.forward(hidden_states)?;
let query_layer = self.layer_norm_q.forward(&query_layer, None)?;

let key_layer = self.key_linear.forward(hidden_states)?;
let key_layer = self.layer_norm_k.forward(&key_layer, None)?;

let value_layer = self.value_linear.forward(hidden_states)?;
You should still merge qkv and do one fused linear.
It's faster to do one big linear instead of 3 smaller ones.
You can then split the result and apply the layer norm for q and k.
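For reference, a minimal sketch of what that fusion could look like, reusing the Linear and layer-norm signatures already used in this PR. The concatenation axis, the hidden_size field, and the D::Minus1 split are assumptions about the checkpoint layout, not code from this branch:

// Loading: concatenate the three projections along the output dimension
// so a single matmul produces q, k and v at once (assumed [out, in] layout).
let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?;
let qkv_bias = Tensor::cat(&[&query_bias, &key_bias, &value_bias], 0)?;
let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None);

// Forward: one big linear, then split the result back into q, k and v.
let qkv = self.qkv_linear.forward(hidden_states)?;
let query_layer = qkv.narrow(D::Minus1, 0, self.hidden_size)?;
let key_layer = qkv.narrow(D::Minus1, self.hidden_size, self.hidden_size)?;
let value_layer = qkv.narrow(D::Minus1, 2 * self.hidden_size, self.hidden_size)?;

// As in the unfused version, layer norm is only applied to q and k.
let query_layer = self.layer_norm_q.forward(&query_layer, None)?;
let key_layer = self.layer_norm_k.forward(&key_layer, None)?;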
let query_linear = Linear::new(query_weight, Some(query_bias), None);
let key_linear = Linear::new(key_weight, Some(key_bias), None);
let value_linear = Linear::new(value_weight, Some(value_bias), None);
See other comment on merging
let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
Why are there transposes here?
let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?;
let gated = hidden_states.i((.., .., self.intermediate_size..))?;
You can use narrow instead:
let non_gated = hidden_states.narrow(2, 0, self.intermediate_size)?;
let gated = hidden_states.narrow(2, self.intermediate_size, self.intermediate_size)?;
I'm a bit confused, why is hidden_states dim3 here? Shouldn't it be dim2 and narrow on 1?
let query_layer = self.query_linear.forward(hidden_states)?;
let query_layer = self.layer_norm_q.forward(&query_layer, None)?;

let key_layer = self.key_linear.forward(hidden_states)?;
let key_layer = self.layer_norm_k.forward(&key_layer, None)?;

let value_layer = self.value_linear.forward(hidden_states)?;
Same.
let non_gated = hidden_states.i((.., .., 0..self.intermediate_size))?;
let gated = hidden_states.i((.., .., self.intermediate_size..))?;
Same.
I'm not sure why we are dim3 in the flash implementation though. We should always be dim2 (concatenated ids, hidden_dim). Can you check if there is an added dim somewhere for some reason?
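For illustration, if hidden_states really is 2-dimensional in the flash path (the [total_tokens, 2 * intermediate_size] layout is an assumption here, following the comment above), the narrow calls would simply move to dim 1:

// Assumed 2D layout: all sequences concatenated, no batch dimension.
let non_gated = hidden_states.narrow(1, 0, self.intermediate_size)?;
let gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?;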
I will merge this branch into a dev branch and modify the few nitpicks above to make this PR part of today's release.
Not a problem at all. Thanks for helping!
What does this PR do?
We implement the modified architecture for Jina models found here. This allows us to run inference for jinaai/jina-embeddings-v2-base-code, a long context code embeddings model.

IMPORTANT CAVEAT
The way models are currently loaded in the repo uses the model_type key found in the config.json for a given model. The way the Jina models are currently loaded is a bit hackish, as we use the Alibi position encoding to route to the JinaBertModel architecture rather than the BertModel architecture. Whenever different models share the same model_type key (e.g. "model_type": "bert") and have "position_embedding_type": "alibi", they all get routed to the JinaBertModel architecture at startup.

To get around this, I locally changed the model_type of the normal Jina models (e.g. jina-embeddings-v2-base-en) to jina_bert and the model_type of the Jina code models to jina_code_bert. This lets me load the models properly, but it is obviously not a clean solution that would let someone plug in the embedding models without changing their config files. A more robust way of loading the different model architectures would be necessary in order to properly support this new implementation.
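To make the caveat concrete, here is a rough sketch of the kind of dispatch this routing amounts to. It is illustrative only, not the repo's actual loading code, and the Model trait, enum variants, and load signatures are approximations:

// Any BERT-style config that declares Alibi position embeddings is routed
// to JinaBertModel, whether or not it is actually a Jina checkpoint.
let model: Box<dyn Model> = match (config.model_type.as_str(), &config.position_embedding_type) {
    ("bert", PositionEmbeddingType::Alibi) => Box::new(JinaBertModel::load(vb, &config)?),
    ("bert", _) => Box::new(BertModel::load(vb, &config)?),
    (other, _) => candle::bail!("unsupported model_type: {other}"),
};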
Fixes #270
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@OlivierDehaene