Support for Jina Code model #292

Merged — 2 commits, Jun 21, 2024

1 change: 1 addition & 0 deletions README.md
@@ -77,6 +77,7 @@ Examples of supported models:
| N/A | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
| N/A | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |

You can explore the list of best performing text embeddings
models [here](https://huggingface.co/spaces/mteb/leaderboard).
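
A quick way to exercise the newly listed model once a server is running is TEI's standard `/embed` route. The sketch below is illustrative, not part of this PR: it assumes a local instance already serving `jinaai/jina-embeddings-v2-base-code` on port 8080, and uses the `reqwest` crate (with its `blocking` and `json` features) plus `serde_json`.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumes a text-embeddings-inference server is already running locally,
    // e.g. launched with --model-id jinaai/jina-embeddings-v2-base-code.
    let client = reqwest::blocking::Client::new();
    let response = client
        .post("http://localhost:8080/embed")
        .json(&json!({ "inputs": "fn add(a: i32, b: i32) -> i32 { a + b }" }))
        .send()?
        .error_for_status()?;

    // TEI returns one embedding vector per input string.
    let embeddings: Vec<Vec<f32>> = response.json()?;
    println!("dimensions: {}", embeddings[0].len());
    Ok(())
}
```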
67 changes: 47 additions & 20 deletions backends/candle/src/lib.rs
@@ -11,12 +11,12 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel,
NomicConfig, PositionEmbeddingType,
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel,
JinaCodeConfig, JinaCodeBertModel, Model, NomicBertModel, NomicConfig, PositionEmbeddingType,
};
#[cfg(feature = "cuda")]
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel,
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel,
};
use candle::{DType, Device};
use candle_nn::VarBuilder;
@@ -36,6 +36,10 @@ enum Config {
XlmRoberta(BertConfig),
Camembert(BertConfig),
Roberta(BertConfig),
#[serde(rename(deserialize = "jina_bert"))]
JinaBert(JinaConfig),
#[serde(rename(deserialize = "jina_code_bert"))]
JinaCodeBert(JinaCodeConfig),
#[serde(rename(deserialize = "distilbert"))]
DistilBert(DistilBertConfig),
#[serde(rename(deserialize = "nomic_bert"))]
@@ -117,13 +121,16 @@ impl CandleBackend {
"`cuda` feature is not enabled".to_string(),
)),
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
if config.position_embedding_type == PositionEmbeddingType::Alibi {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
} else {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
(Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
(Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
}
(
Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
@@ -154,23 +161,43 @@
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
if config.position_embedding_type == PositionEmbeddingType::Alibi {
tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
Ok(Box::new(
FlashJinaBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting FlashBert model on {:?}", device);
Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
}
} else {
if config.position_embedding_type == PositionEmbeddingType::Alibi {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
} else {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
}
#[cfg(feature = "cuda")]
(Config::JinaBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,))
} else {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
}
#[cfg(feature = "cuda")]
(Config::JinaCodeBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,))
} else {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
}
}
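
The `#[serde(rename(deserialize = "jina_bert"))]` and `#[serde(rename(deserialize = "jina_code_bert"))]` attributes above work because `Config` is an internally tagged enum: the model's config.json is deserialized using its `model_type` field to pick the variant, and the match in the `CandleBackend` constructor then selects a model implementation per variant. Here is a minimal, self-contained sketch of that pattern; `DemoConfig` is an illustrative stand-in, not TEI's real config type.

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct DemoConfig {
    hidden_size: usize,
}

// The `model_type` field in config.json selects the enum variant.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    Bert(DemoConfig),
    #[serde(rename(deserialize = "jina_bert"))]
    JinaBert(DemoConfig),
    #[serde(rename(deserialize = "jina_code_bert"))]
    JinaCodeBert(DemoConfig),
}

fn main() -> serde_json::Result<()> {
    let raw = r#"{ "model_type": "jina_code_bert", "hidden_size": 768 }"#;
    // Dispatch on the deserialized variant, one model per config type.
    match serde_json::from_str::<Config>(raw)? {
        Config::Bert(c) => println!("BertModel, hidden_size={}", c.hidden_size),
        Config::JinaBert(c) => println!("JinaBertModel, hidden_size={}", c.hidden_size),
        Config::JinaCodeBert(c) => println!("JinaCodeBertModel, hidden_size={}", c.hidden_size),
    }
    Ok(())
}
```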
10 changes: 9 additions & 1 deletion backends/candle/src/models.rs
@@ -15,6 +15,9 @@ mod flash_bert;
#[cfg(feature = "cuda")]
mod flash_jina;

#[cfg(feature = "cuda")]
mod flash_jina_code;

#[cfg(feature = "cuda")]
mod flash_nomic;

@@ -24,7 +27,8 @@ mod flash_distilbert;
pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
use candle::{Result, Tensor};
pub use distilbert::{DistilBertConfig, DistilBertModel};
pub use jina::JinaBertModel;
pub use jina::{JinaConfig, JinaBertModel};
pub use jina_code::{JinaCodeConfig, JinaCodeBertModel};
pub use nomic::{NomicBertModel, NomicConfig};
use text_embeddings_backend_core::Batch;

@@ -34,6 +38,10 @@ pub use flash_bert::FlashBertModel;
#[cfg(feature = "cuda")]
pub use flash_jina::FlashJinaBertModel;

#[cfg(feature = "cuda")]
pub use flash_jina_code::FlashJinaCodeBertModel;


#[cfg(feature = "cuda")]
pub use flash_nomic::FlashNomicBertModel;

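
Both the `mod flash_jina_code;` declaration and its `pub use` re-export above are gated on the `cuda` Cargo feature, so CPU-only builds never compile the flash-attention path at all. Below is a minimal sketch of the same pattern, assuming a `cuda` feature declared in Cargo.toml (`[features] cuda = []`).

```rust
#[cfg(feature = "cuda")]
mod flash_jina_code {
    // Stand-in for the real FlashJinaCodeBertModel.
    pub struct FlashJinaCodeBertModel;
}

// The re-export only exists when the feature is enabled.
#[cfg(feature = "cuda")]
pub use flash_jina_code::FlashJinaCodeBertModel;

fn main() {
    // Compiled only when building with `cargo run --features cuda`.
    #[cfg(feature = "cuda")]
    let _model = FlashJinaCodeBertModel;

    #[cfg(not(feature = "cuda"))]
    println!("built without the cuda feature; flash models unavailable");
}
```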
3 changes: 2 additions & 1 deletion backends/candle/src/models/flash_jina.rs
@@ -1,7 +1,8 @@
use crate::alibi::alibi_head_slopes;
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::bert::{BertConfig, PositionEmbeddingType};
use crate::models::bert::PositionEmbeddingType;
use crate::models::jina::{JinaConfig, BertEmbeddings};
use crate::models::jina::BertEmbeddings;
use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor};