fix(candle): fix FlashJinaCodeModel #302

Merged · 2 commits · Jun 21, 2024
8 changes: 4 additions & 4 deletions Makefile
@@ -1,11 +1,11 @@
integration-tests:
-	cargo test --release
+	cargo test

cuda-integration-tests:
-	cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --release
+	cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --profile release-debug

integration-tests-review:
-	cargo insta test --review --release
+	cargo insta test --review

cuda-integration-tests-review:
-	cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --release
+	cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --profile release-debug
File renamed without changes.
127 changes: 69 additions & 58 deletions backends/candle/src/lib.rs
@@ -11,17 +11,17 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel,
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
Model, NomicBertModel, NomicConfig,
};
#[cfg(feature = "cuda")]
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel,
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
FlashNomicBertModel,
};
use anyhow::Context;
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::BertConfig;
use nohash_hasher::BuildNoHashHasher;
use serde::Deserialize;
use std::collections::HashMap;
@@ -30,17 +30,28 @@ use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};

/// This enum is needed to be able to differentiate between jina models that also use
/// the `bert` model type and valid Bert models.
/// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
/// run but is still better than the other options...
#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(tag = "_name_or_path")]
pub enum BertConfigWrapper {
#[serde(rename = "jinaai/jina-bert-implementation")]
JinaBert(BertConfig),
#[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
JinaCodeBert(BertConfig),
#[serde(untagged)]
Bert(BertConfig),
}
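
Reviewer note: the `_name_or_path` tag above relies on serde's support for an `#[serde(untagged)]` fallback variant inside an internally tagged enum (available in the serde 1.0.17x series), so a config whose `_name_or_path` matches neither Jina repo drops through to plain `Bert`. A minimal sketch of that dispatch, using a trimmed-down, hypothetical stand-in for `BertConfig` (a single field) and assuming `serde_json` is on hand:

```rust
use serde::Deserialize;

// Trimmed stand-in for the real BertConfig (the actual struct has many more fields).
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct BertConfig {
    pub hidden_size: usize,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(tag = "_name_or_path")]
pub enum BertConfigWrapper {
    #[serde(rename = "jinaai/jina-bert-implementation")]
    JinaBert(BertConfig),
    #[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
    JinaCodeBert(BertConfig),
    #[serde(untagged)]
    Bert(BertConfig),
}

fn main() -> serde_json::Result<()> {
    // A jina-code checkpoint advertises its implementation repo in `_name_or_path`.
    let jina_code = r#"{"_name_or_path": "jinaai/jina-bert-v2-qk-post-norm", "hidden_size": 768}"#;
    // A vanilla BERT config carries an unrelated path (or omits the field entirely),
    // so it falls through to the untagged `Bert` variant.
    let vanilla = r#"{"_name_or_path": "bert-base-uncased", "hidden_size": 768}"#;

    assert!(matches!(
        serde_json::from_str::<BertConfigWrapper>(jina_code)?,
        BertConfigWrapper::JinaCodeBert(_)
    ));
    assert!(matches!(
        serde_json::from_str::<BertConfigWrapper>(vanilla)?,
        BertConfigWrapper::Bert(_)
    ));
    Ok(())
}
```
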

#[derive(Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
enum Config {
Bert(BertConfig),
Bert(BertConfigWrapper),
XlmRoberta(BertConfig),
Camembert(BertConfig),
Roberta(BertConfig),
#[serde(rename(deserialize = "jina_bert"))]
JinaBert(JinaConfig),
#[serde(rename(deserialize = "jina_code_bert"))]
JinaCodeBert(JinaCodeConfig),
#[serde(rename(deserialize = "distilbert"))]
DistilBert(DistilBertConfig),
#[serde(rename(deserialize = "nomic_bert"))]
@@ -76,7 +87,7 @@ impl CandleBackend {
"Runtime compute cap {} is not compatible with compile time compute cap {}",
get_runtime_compute_cap().unwrap(),
get_compile_compute_cap().unwrap()
)))
)));
}
Err(err) => {
tracing::warn!("Could not find a compatible CUDA device on host: {err:?}");
@@ -123,18 +134,22 @@ impl CandleBackend {
(_, Device::Cuda(_)) => Err(BackendError::Start(
"`cuda` feature is not enabled".to_string(),
)),
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
(Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
(Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
}
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
tracing::info!("Starting JinaCodeBert model on {:?}", device);
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::Bert(config) => {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
},
(
Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
Device::Cpu | Device::Metal(_),
@@ -158,48 +173,45 @@ impl CandleBackend {
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
if config.position_embedding_type == PositionEmbeddingType::Alibi {
tracing::info!("Starting FlashBert model on {:?}", device);
Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
} else {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting FlashJinaBert model on {:?}", device);
Ok(Box::new(
FlashJinaBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::JinaCodeBert(config) => {
tracing::info!("Starting FlashJinaCodeBert model on {:?}", device);
Ok(Box::new(
FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::Bert(config) => {
tracing::info!("Starting FlashBert model on {:?}", device);
Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
}
}
}
#[cfg(feature = "cuda")]
(Config::JinaBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,))
} else {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
#[cfg(feature = "cuda")]
(Config::JinaCodeBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,))
} else {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
tracing::info!("Starting JinaCodeBert model on {:?}", device);
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
BertConfigWrapper::Bert(config) => {
tracing::info!("Starting Bert model on {:?}", device);
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
}
}
}
#[cfg(feature = "cuda")]
@@ -209,7 +221,6 @@ impl CandleBackend {
) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
35 changes: 17 additions & 18 deletions backends/candle/src/models/flash_jina.rs
@@ -2,14 +2,13 @@ use crate::alibi::alibi_head_slopes
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::bert::PositionEmbeddingType;
use crate::models::jina::{JinaConfig, BertEmbeddings};
use crate::models::jina::BertEmbeddings;
use crate::models::Model;
use crate::models::jina::JinaEmbeddings;
use crate::models::{BertConfig, Model};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct AlibiBertAttention {
struct JinaAttention {
qkv_linear: Linear,
dense: Linear,
layer_norm: LayerNorm,
@@ -23,7 +22,7 @@ struct AlibiBertAttention {
span: tracing::Span,
}

impl AlibiBertAttention {
impl JinaAttention {
pub fn load(vb: VarBuilder, config: &BertConfig, alibi_slopes: Option<Tensor>) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
@@ -117,7 +116,7 @@ impl AlibiBertAttention {
}

struct JinaBertLayer {
attention: AlibiBertAttention,
attention: JinaAttention,
gated: Linear,
output: Linear,
layer_norm: LayerNorm,
@@ -130,7 +129,7 @@ struct JinaBertLayer {

impl JinaBertLayer {
pub fn load(vb: VarBuilder, config: &BertConfig, alibi: Option<Tensor>) -> Result<Self> {
let attention = AlibiBertAttention::load(vb.pp("attention"), config, alibi)?;
let attention = JinaAttention::load(vb.pp("attention"), config, alibi)?;

let gated_weight = vb
.pp("mlp")
@@ -174,14 +173,14 @@ impl JinaBertLayer {
let residual = hidden_states.clone();

let hidden_states = self.gated.forward(&hidden_states)?;
let gated = hidden_states.i((.., 0..self.intermediate_size))?;
let gated = hidden_states.narrow(1, 0, self.intermediate_size)?;
let gated = match self.act {
HiddenAct::Gelu => gated.gelu(),
HiddenAct::Relu => gated.relu(),
HiddenAct::Swiglu => gated.silu(),
}?;

let non_gated = hidden_states.i((.., self.intermediate_size..))?;
let non_gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
let hidden_states = (gated * non_gated)?;

let hidden_states = self.output.forward(&hidden_states)?;
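
Reviewer note on the `narrow` swap above: the gated up-projection packs the gated and non-gated halves side by side along the feature dimension, and `Tensor::narrow(dim, start, len)` selects the same columns the previous `IndexOp` ranges did. A standalone sketch with made-up sizes, assuming the `candle` alias for `candle-core` that this backend already uses:

```rust
use candle::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    // Made-up sizes: 4 packed tokens, intermediate_size = 3, so the gated
    // up-projection yields 2 * intermediate_size = 6 columns per token.
    let intermediate_size = 3;
    let hidden_states = Tensor::arange(0f32, 24f32, &Device::Cpu)?.reshape((4, 6))?;

    // Previous slicing via IndexOp ranges...
    let gated_old = hidden_states.i((.., 0..intermediate_size))?;
    let non_gated_old = hidden_states.i((.., intermediate_size..))?;

    // ...and the equivalent narrow(dim, start, len) calls used after this change.
    let gated = hidden_states.narrow(1, 0, intermediate_size)?;
    let non_gated = hidden_states.narrow(1, intermediate_size, intermediate_size)?;

    assert_eq!(gated.to_vec2::<f32>()?, gated_old.to_vec2::<f32>()?);
    assert_eq!(non_gated.to_vec2::<f32>()?, non_gated_old.to_vec2::<f32>()?);
    Ok(())
}
```
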
@@ -191,12 +190,12 @@ impl JinaBertLayer {
}
}

struct BertEncoder {
struct JinaBertEncoder {
layers: Vec<JinaBertLayer>,
span: tracing::Span,
}

impl BertEncoder {
impl JinaBertEncoder {
pub fn load(vb: VarBuilder, config: &BertConfig, alibi: Option<Tensor>) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| {
@@ -205,7 +204,7 @@ impl BertEncoder {
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");

Ok(BertEncoder { layers, span })
Ok(JinaBertEncoder { layers, span })
}

fn forward(&self, hidden_states: &Tensor, cu_seqlens: &Tensor, max_s: usize) -> Result<Tensor> {
@@ -223,8 +222,8 @@ impl BertEncoder {
}

pub struct FlashJinaBertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
embeddings: JinaEmbeddings,
encoder: JinaBertEncoder,
pool: Pool,
pub device: Device,

@@ -266,14 +265,14 @@ impl FlashJinaBertModel {
};

let (embeddings, encoder) = match (
BertEmbeddings::load(vb.pp("embeddings"), config),
BertEncoder::load(vb.pp("encoder"), config, alibi.clone()),
JinaEmbeddings::load(vb.pp("embeddings"), config),
JinaBertEncoder::load(vb.pp("encoder"), config, alibi.clone()),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(vb.pp("bert.embeddings"), config),
BertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()),
JinaEmbeddings::load(vb.pp("bert.embeddings"), config),
JinaBertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()),
) {
(embeddings, encoder)
} else {