Commit b87761f

Add pretrain processing into dygraph bert model, test=release/1.8 (#4718)
1 parent e4ad047 commit b87761f

File tree

4 files changed: +424, -8 lines changed


dygraph/bert/model/bert.py

Lines changed: 8 additions & 8 deletions
@@ -230,37 +230,37 @@ def forward(self, src_ids, position_ids, sentence_ids, input_mask,
 
         enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
                                                      sentence_ids, input_mask)
+
         reshaped_emb_out = fluid.layers.reshape(
             x=enc_output, shape=[-1, self._emb_size])
 
         mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
-
         mask_trans_feat = self.pooled_fc(mask_feat)
-        mask_trans_feat = self.pre_process_layer(None, mask_trans_feat, "n",
-                                                 self._prepostprocess_dropout)
+        mask_trans_feat = self.pre_process_layer(mask_trans_feat)
 
         if self._weight_sharing:
             fc_out = fluid.layers.matmul(
                 x=mask_trans_feat,
-                y=self.bert_layer._src_emb._w,
+                y=self.bert_layer._src_emb.weight,
                 transpose_y=True)
             fc_out += self.fc_create_params
         else:
             fc_out = self.out_fc(mask_trans_feat)
 
-        mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
-            logits=fc_out, label=mask_label)
+        mask_lm_loss, mask_lm_softmax = fluid.layers.softmax_with_cross_entropy(
+            logits=fc_out, label=mask_label, return_softmax=True)
         mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
 
         next_sent_fc_out = self.next_sent_fc(next_sent_feat)
 
         next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(
             logits=next_sent_fc_out, label=labels, return_softmax=True)
 
+        lm_acc = fluid.layers.accuracy(input=mask_lm_softmax, label=mask_label)
+
         next_sent_acc = fluid.layers.accuracy(
             input=next_sent_softmax, label=labels)
-
         mean_next_sent_loss = fluid.layers.mean(next_sent_loss)
 
         loss = mean_next_sent_loss + mean_mask_lm_loss
-        return next_sent_acc, mean_mask_lm_loss, loss
+        return lm_acc, next_sent_acc, mean_mask_lm_loss, loss
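
The substance of the bert.py change: the masked-LM head now requests the softmax output from softmax_with_cross_entropy (return_softmax=True) and feeds it to fluid.layers.accuracy, so pretraining reports masked-LM accuracy (lm_acc) alongside next-sentence accuracy, and forward() returns four values instead of three. Below is a minimal, self-contained sketch of just that metric tail in the Paddle 1.8 fluid dygraph API, with toy shapes standing in for the real [num_masked_tokens, vocab_size] tensors:

import numpy as np
import paddle.fluid as fluid

# Toy stand-ins for the masked-LM head: 6 masked positions, a 10-word vocab.
with fluid.dygraph.guard():
    fc_out = fluid.dygraph.to_variable(
        np.random.rand(6, 10).astype("float32"))           # logits over vocab
    mask_label = fluid.dygraph.to_variable(
        np.random.randint(0, 10, (6, 1)).astype("int64"))  # gold token ids

    # return_softmax=True yields the probabilities alongside the loss...
    mask_lm_loss, mask_lm_softmax = fluid.layers.softmax_with_cross_entropy(
        logits=fc_out, label=mask_label, return_softmax=True)
    # ...which is exactly what fluid.layers.accuracy needs for top-1 accuracy.
    lm_acc = fluid.layers.accuracy(input=mask_lm_softmax, label=mask_label)
    print(float(fluid.layers.mean(mask_lm_loss).numpy()), float(lm_acc.numpy()))

The other two edits are dygraph API cleanups: the shared embedding table is read through the public .weight attribute instead of the private _w, and pre_process_layer is now called with just the input tensor.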

dygraph/bert/run_train_multi_gpu.sh

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# pretrain config
+SAVE_STEPS=10000
+BATCH_SIZE=4096
+LR_RATE=1e-4
+WEIGHT_DECAY=0.01
+MAX_LEN=512
+TRAIN_DATA_DIR=data/train
+VALIDATION_DATA_DIR=data/validation
+CONFIG_PATH=data/demo_config/bert_config.json
+VOCAB_PATH=data/demo_config/vocab.txt
+# Change your train arguments:
+GPU_TO_USE=0,1
+# start pretrain
+python -m paddle.distributed.launch --selected_gpus=$GPU_TO_USE --log_dir ./pretrain_log ./train.py ${is_distributed}\
+       --use_cuda true\
+       --use_data_parallel true\
+       --weight_sharing true\
+       --batch_size ${BATCH_SIZE} \
+       --data_dir ${TRAIN_DATA_DIR} \
+       --validation_set_dir ${VALIDATION_DATA_DIR} \
+       --bert_config_path ${CONFIG_PATH} \
+       --vocab_path ${VOCAB_PATH} \
+       --generate_neg_sample true\
+       --checkpoints ./output \
+       --save_steps ${SAVE_STEPS} \
+       --learning_rate ${LR_RATE} \
+       --weight_decay ${WEIGHT_DECAY:-0} \
+       --max_seq_len ${MAX_LEN} \
+       --skip_steps 20 \
+       --validation_steps 1000 \
+       --num_iteration_per_drop_scope 10 \
+       --use_fp16 false \
+       --verbose true
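
This script launches one trainer process per GPU in GPU_TO_USE via paddle.distributed.launch and passes --use_data_parallel true to train.py. train.py itself is not shown in this commit view; the sketch below is only the standard Paddle 1.x dygraph data-parallel pattern that flag conventionally maps to (prepare_context, DataParallel, scale_loss, apply_collective_grads), with a trivial Linear layer standing in for the BERT model:

import numpy as np
import paddle.fluid as fluid

# Run under the launcher, e.g.:
#   python -m paddle.distributed.launch --selected_gpus=0,1 sketch.py
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place):
    strategy = fluid.dygraph.parallel.prepare_context()  # reads launcher env vars
    model = fluid.dygraph.Linear(16, 4)                  # stand-in for the BERT model
    model = fluid.dygraph.parallel.DataParallel(model, strategy)
    optimizer = fluid.optimizer.Adam(
        learning_rate=1e-4, parameter_list=model.parameters())

    x = fluid.dygraph.to_variable(np.random.rand(8, 16).astype("float32"))
    loss = fluid.layers.reduce_mean(model(x))
    loss = model.scale_loss(loss)    # rescale so per-trainer losses average out
    loss.backward()
    model.apply_collective_grads()   # all-reduce gradients across trainers
    optimizer.minimize(loss)
    model.clear_gradients()

Note that ${is_distributed} is interpolated into the command line but never defined in the script, so it expands to nothing unless exported by the caller.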

dygraph/bert/run_train_single_gpu.sh

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# pretrain config
+SAVE_STEPS=100
+BATCH_SIZE=4096
+LR_RATE=1e-4
+WEIGHT_DECAY=0.01
+MAX_LEN=512
+TRAIN_DATA_DIR=data/train
+VALIDATION_DATA_DIR=data/validation
+CONFIG_PATH=data/demo_config/bert_config.json
+VOCAB_PATH=data/demo_config/vocab.txt
+# Change your train arguments:
+# start pretrain
+python -u ./train.py --use_cuda true\
+       --use_data_parallel false\
+       --weight_sharing true\
+       --batch_size ${BATCH_SIZE} \
+       --data_dir ${TRAIN_DATA_DIR} \
+       --validation_set_dir ${VALIDATION_DATA_DIR} \
+       --bert_config_path ${CONFIG_PATH} \
+       --vocab_path ${VOCAB_PATH} \
+       --generate_neg_sample true\
+       --checkpoints ./output \
+       --save_steps ${SAVE_STEPS} \
+       --learning_rate ${LR_RATE} \
+       --weight_decay ${WEIGHT_DECAY:-0} \
+       --max_seq_len ${MAX_LEN} \
+       --skip_steps 20 \
+       --validation_steps 100 \
+       --num_iteration_per_drop_scope 10 \
+       --use_fp16 false \
+       --verbose true
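
Apart from dropping the launcher, this script differs from the multi-GPU one only in --use_data_parallel false and a more frequent checkpoint/validation cadence (SAVE_STEPS=100, --validation_steps 100). Both scripts drive train.py, which this page does not show; as a rough, hypothetical sketch, a dygraph pretrain step consuming the four values now returned by forward() might look like the following (the function, batch-field names, and argument order after input_mask are illustrative assumptions, not the repo's code):

def pretrain_step(model, optimizer, batch):
    # `model` is assumed to be the pretraining layer patched above; `batch`
    # an already-converted dict of dygraph variables from the data reader.
    lm_acc, next_sent_acc, mean_mask_lm_loss, loss = model(
        batch["src_ids"], batch["position_ids"], batch["sentence_ids"],
        batch["input_mask"], batch["mask_label"], batch["mask_pos"],
        batch["labels"])
    loss.backward()
    optimizer.minimize(loss)
    model.clear_gradients()
    return (float(loss.numpy()), float(mean_mask_lm_loss.numpy()),
            float(lm_acc.numpy()), float(next_sent_acc.numpy()))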
