
Commit 7e34851: Merge branch 'master' into master (2 parents: 6d20feb + 89ccbf6)

File tree: 4 files changed, +313 / -1 lines
examples/FasterTransformer_HuggingFace_Bert/Bert_FT_trace.py (new file: 158 additions & 0 deletions)
```python
# This script is adapted from the run_glue example of NVIDIA FasterTransformer:
# https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/run_glue.py

import argparse
import logging
import os
import random

import numpy as np
import torch

from transformers import BertTokenizer
from utils.modeling_bert import BertForSequenceClassification, BertForQuestionAnswering

logger = logging.getLogger(__name__)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name",
    )
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--mode",
        default="sequence_classification",
        help="Set the model for sequence classification or question answering",
    )
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.",
    )
    parser.add_argument(
        "--batch_size", default=8, type=int, help="Batch size for tracing."
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--model_type", type=str, help="ori, ths, thsext")
    parser.add_argument("--data_type", type=str, help="fp32, fp16")
    parser.add_argument(
        "--ths_path",
        type=str,
        default="./lib/libpyt_fastertransformer.so",
        help="path of the pyt_fastertransformer dynamic lib file",
    )
    parser.add_argument(
        "--remove_padding",
        action="store_false",
        help="Padding removal in the encoder is enabled by default; pass this flag to disable it.",
    )
    parser.add_argument(
        "--allow_gemm_test",
        action="store_false",
        help="The GEMM test is enabled by default; pass this flag to disable it.",
    )

    args = parser.parse_args()

    # FasterTransformer requires a CUDA-capable GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Set seed
    set_seed(args)

    tokenizer = BertTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    logger.info("Parameters %s", args)

    checkpoints = [args.model_name_or_path]
    for checkpoint in checkpoints:
        use_ths = args.model_type.startswith("ths")
        if args.mode == "sequence_classification":
            model = BertForSequenceClassification.from_pretrained(checkpoint, torchscript=use_ths)
        elif args.mode == "question_answering":
            model = BertForQuestionAnswering.from_pretrained(checkpoint, torchscript=use_ths)
        model.to(args.device)

        if args.data_type == "fp16":
            logger.info("Use fp16")
            model.half()
        if args.model_type == "thsext":
            logger.info("Use custom BERT encoder for TorchScript")
            from utils.encoder import EncoderWeights, CustomEncoder

            weights = EncoderWeights(
                model.config.num_hidden_layers,
                model.config.hidden_size,
                torch.load(os.path.join(checkpoint, "pytorch_model.bin"), map_location="cpu"),
            )
            weights.to_cuda()
            if args.data_type == "fp16":
                weights.to_half()
            enc = CustomEncoder(
                model.config.num_hidden_layers,
                model.config.num_attention_heads,
                model.config.hidden_size // model.config.num_attention_heads,
                weights,
                remove_padding=args.remove_padding,
                allow_gemm_test=args.allow_gemm_test,
                path=os.path.abspath(args.ths_path),
            )
            enc_ = torch.jit.script(enc)
            model.replace_encoder(enc_)
        if use_ths:
            logger.info("Use TorchScript mode")
            # Dummy inputs for tracing
            fake_input_id = torch.LongTensor(args.batch_size, args.max_seq_length)
            fake_input_id.fill_(1)
            fake_input_id = fake_input_id.to(args.device)
            fake_mask = torch.ones(args.batch_size, args.max_seq_length).to(args.device)
            fake_type_id = fake_input_id.clone().detach()  # token type ids (unused by the trace below)
            if args.data_type == "fp16":
                fake_mask = fake_mask.half()
            model.eval()
            with torch.no_grad():
                print("Input id and mask sizes:", fake_input_id.size(), fake_mask.size())
                model_ = torch.jit.trace(model, (fake_input_id, fake_mask))
                model = model_
            torch.jit.save(model, "traced_model.pt")


if __name__ == "__main__":
    main()
```
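
For a quick offline sanity check of the artifact, here is a minimal sketch of loading and running `traced_model.pt` outside TorchServe. It assumes a model traced with `--model_type thsext --data_type fp16 --batch_size 1` as in the examples below; the library path is illustrative:

```python
import torch

# The FasterTransformer custom ops must be registered before a model traced
# with the custom encoder can be deserialized.
torch.classes.load_library("/workspace/FasterTransformer/build/lib/libpyt_fastertransformer.so")

model = torch.jit.load("traced_model.pt")
model.eval()

# Dummy inputs matching the shapes used for tracing (batch_size=1, max_seq_length=128)
input_ids = torch.ones(1, 128, dtype=torch.long, device="cuda")
attention_mask = torch.ones(1, 128, device="cuda").half()  # fp16 mask, as in the trace

with torch.no_grad():
    outputs = model(input_ids, attention_mask)
print(type(outputs))
```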
README.md (new file: 148 additions & 0 deletions)
## Faster Transformer

Batch inferencing with Transformers faces two challenges:

- Large batch sizes suffer from higher latency, while small or medium-sized batches become bound by kernel launch latency.
- Padding wastes a lot of compute: a batch of shape (batch_size, seq_length) must be padded to (batch_size, max_length), and the gap between avg_length and max_length translates directly into wasted computation; increasing the batch size worsens the situation. (A short sketch quantifying this follows the next paragraph.)
[FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/run_glue.py) (FT) from NVIDIA, together with [Effective Transformer](https://github.com/bytedance/effective_transformer) (EFFT), which is built on top of FT, addresses both challenges by fusing the CUDA kernels and dynamically removing padding during computation. The current FT implementation supports BERT-like encoder and decoder layers. In this example, we show how to get a TorchScripted (traced) EFFT variant of BERT models from HuggingFace (HF) for sequence classification and question answering, and how to serve it.
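To make the padding waste concrete, here is a minimal, self-contained sketch (illustrative numbers, not FT's actual kernels) of how much compute padding burns and of the token-gathering idea behind EFFT's padding removal:

```python
import torch

# Illustrative numbers: a batch padded to max_length while the average
# sequence is much shorter.
batch_size, max_length, avg_length = 8, 128, 40
wasted = 1 - avg_length / max_length
print(f"fraction of encoder compute spent on padding: {wasted:.0%}")  # ~69%

# The core idea behind Effective Transformer's padding removal: gather only
# the real tokens into a dense buffer before the heavy matmuls, then scatter
# the results back. (FT does this inside fused CUDA kernels.)
hidden = 16
tokens = torch.randn(batch_size, max_length, hidden)
mask = torch.zeros(batch_size, max_length, dtype=torch.bool)
for i in range(batch_size):
    mask[i, : torch.randint(16, avg_length + 1, (1,)).item()] = True

dense = tokens[mask]             # (num_real_tokens, hidden): no pad rows
restored = torch.zeros_like(tokens)
restored[mask] = dense           # scatter back to the padded layout
```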
### How to get a TorchScripted (traced) EFFT version of an HF BERT model and serve it

**Requirements**
Running FasterTransformer is currently recommended through the [NVIDIA docker and NGC container](https://github.com/NVIDIA/FasterTransformer#requirements); it also requires a [Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/), [Turing](https://www.nvidia.com/en-us/geforce/turing/), or [Ampere](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/) based GPU. For this example we used a **g4dn.2xlarge** EC2 instance, which has a T4 GPU.
**Set up a GPU machine that meets the requirements and connect to it.**
```bash
# Sign up for NGC at https://ngc.nvidia.com and get an API key
docker login nvcr.io
# Username: $oauthtoken
# Password: <API key>

docker pull nvcr.io/nvidia/pytorch:20.12-py3

nvidia-docker run -ti --gpus all --rm nvcr.io/nvidia/pytorch:20.12-py3 bash

git clone https://github.com/NVIDIA/FasterTransformer.git

cd FasterTransformer

mkdir -p build

cd build

# -DSM is the GPU compute capability: 60 (P40), 61 (P4), 70 (V100), 75 (T4), 80 (A100)
cmake -DSM=75 -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON ..

make

pip install transformers==2.5.1

cd /workspace

# clone TorchServe to access the examples
git clone https://github.com/pytorch/serve.git

# install TorchServe
cd serve

python ts_scripts/install_dependencies.py --cuda=cu102

pip install torchserve torch-model-archiver torch-workflow-archiver

cp examples/FasterTransformer_HuggingFace_Bert/Bert_FT_trace.py /workspace/FasterTransformer/build/pytorch
```
Now we are ready to produce the TorchScripted file. As mentioned at the beginning, two models are supported: BERT for sequence classification and BERT for question answering. For this step we first need to download the model weights, the same way as in the [HuggingFace example](https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers).
#### Sequence classification: EFFT traced model and serving
```bash
# Sequence classification
python ../Huggingface_Transformers/Download_Transformer_models.py

# This will download the model weights into the ../Huggingface_Transformers/Transformer_model directory

cd /workspace/FasterTransformer/build/

# This will generate the traced model "traced_model.pt"
# --data_type can be fp16 or fp32
python pytorch/Bert_FT_trace.py --mode sequence_classification --model_name_or_path "/workspace/serve/examples/Huggingface_Transformers/Transformer_model" --tokenizer_name "bert-base-uncased" --batch_size 1 --data_type fp16 --model_type thsext

cd -

# Make sure ../Huggingface_Transformers/setup_config.json has "save_mode":"torchscript" and "FasterTransformer":true:
{
  "model_name":"bert-base-uncased",
  "mode":"sequence_classification",
  "do_lower_case":true,
  "num_labels":"0",
  "save_mode":"torchscript",
  "max_length":"128",
  "captum_explanation":false,
  "embedding_name": "bert",
  "FasterTransformer":true
}

torch-model-archiver --model-name BERTSeqClassification --version 1.0 --serialized-file /workspace/FasterTransformer/build/traced_model.pt --handler ../Huggingface_Transformers/Transformer_handler_generalized.py --extra-files "../Huggingface_Transformers/setup_config.json,../Huggingface_Transformers/Seq_classification_artifacts/index_to_name.json,/workspace/FasterTransformer/build/lib/libpyt_fastertransformer.so"

mkdir -p model_store

mv BERTSeqClassification.mar model_store/

torchserve --start --model-store model_store --models my_tc=BERTSeqClassification.mar --ncs

curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ../Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt
```
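
For a quick programmatic check, the same request can be sent from Python (the endpoint and sample file simply mirror the curl call above; assumes the `requests` package is installed):

```python
import requests

sample = "../Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt"
with open(sample, "rb") as f:
    resp = requests.post("http://127.0.0.1:8080/predictions/my_tc", data=f)
print(resp.status_code, resp.text)
```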
#### Question answering: EFFT traced model and serving
```bash
# Question answering

# Change ../Huggingface_Transformers/setup_config.json as follows
# (note "save_mode":"torchscript" and "FasterTransformer":true):
{
  "model_name":"bert-base-uncased",
  "mode":"question_answering",
  "do_lower_case":true,
  "num_labels":"0",
  "save_mode":"torchscript",
  "max_length":"128",
  "captum_explanation":false,
  "embedding_name": "bert",
  "FasterTransformer":true
}

python ../Huggingface_Transformers/Download_Transformer_models.py

# This will download the model weights into the ../Huggingface_Transformers/Transformer_model directory

cd /workspace/FasterTransformer/build/

# This will generate the traced model "traced_model.pt"
# --data_type can be fp16 or fp32
python pytorch/Bert_FT_trace.py --mode question_answering --model_name_or_path "/workspace/serve/examples/Huggingface_Transformers/Transformer_model" --tokenizer_name "bert-base-uncased" --batch_size 1 --data_type fp16 --model_type thsext

cd -

torch-model-archiver --model-name BERTQA --version 1.0 --serialized-file /workspace/FasterTransformer/build/traced_model.pt --handler ../Huggingface_Transformers/Transformer_handler_generalized.py --extra-files "../Huggingface_Transformers/setup_config.json,/workspace/FasterTransformer/build/lib/libpyt_fastertransformer.so"

mkdir -p model_store

mv BERTQA.mar model_store/

torchserve --start --model-store model_store --models my_tc=BERTQA.mar --ncs

curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ../Huggingface_Transformers/QA_artifacts/sample_text_captum_input.txt
```

examples/Huggingface_Transformers/Transformer_handler_generalized.py (5 additions & 0 deletions)
```diff
@@ -52,6 +52,11 @@ def initialize(self, ctx):
         else:
             logger.warning("Missing the setup_config.json file.")
 
+        # Loading the shared object of the compiled Faster Transformer library if FasterTransformer is set
+        if self.setup_config["FasterTransformer"]:
+            faster_transformer_complied_path = os.path.join(model_dir, "libpyt_fastertransformer.so")
+            torch.classes.load_library(faster_transformer_complied_path)
+
         # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode
         # further setup config can be added.
         if self.setup_config["save_mode"] == "torchscript":
```

examples/Huggingface_Transformers/setup_config.json (2 additions & 1 deletion)
```diff
@@ -6,5 +6,6 @@
   "save_mode":"pretrained",
   "max_length":"150",
   "captum_explanation":true,
-  "embedding_name": "bert"
+  "embedding_name": "bert",
+  "FasterTransformer":false
 }
```
