Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit bc4ef44

Browse files
NER example: fix metrics computation
1 parent 8a9dd72 commit bc4ef44

File tree

3 files changed

+23 -5 lines changed

3 files changed

+23 -5 lines changed

example/named_entity_recognition/src/metrics.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
# -*- coding: utf-8 -*-
2121

22+
import logging
2223
import mxnet as mx
2324
import numpy as np
2425
import pickle
@@ -43,17 +44,17 @@ def classifer_metrics(label, pred):
4344
corr_pred = (prediction == label) == (pred_is_entity == True)
4445

4546
#how many entities are there?
46-
num_entities = np.sum(label_is_entity)
47-
entity_preds = np.sum(pred_is_entity)
48-
47+
# better to cast to float for safer further ratio computations
48+
num_entities = float(np.sum(label_is_entity))
49+
entity_preds = float(np.sum(pred_is_entity))
4950
#how many times did we correctly predict an entity?
50-
correct_entitites = np.sum(corr_pred[pred_is_entity])
51+
correct_entitites = float(np.sum(corr_pred[pred_is_entity]))
5152

5253
#precision: when we predict entity, how often are we right?
5354
if entity_preds == 0:
5455
precision = np.nan
5556
else:
56-
precision = correct_entitites/entity_preds
57+
precision = correct_entitites / entity_preds
5758

5859
#recall: of the things that were an entity, how many did we catch?
5960
recall = correct_entitites / num_entities
@@ -64,6 +65,8 @@ def classifer_metrics(label, pred):
6465
f1 = 0
6566
else:
6667
f1 = 2 * precision * recall / (precision + recall)
68+
69+
logging.debug("Metrics results: precision=%f recall=%f f1=%f", precision, recall, f1)
6770
return precision, recall, f1
6871

6972
def entity_precision(label, pred):

example/named_entity_recognition/src/ner.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,7 @@ def build_vocab(nested_list):
9393
"""
9494
# Build vocabulary
9595
word_counts = Counter(itertools.chain(*nested_list))
96+
logging.getLogger().info("build_vocab: word_counts=%d" % (len(word_counts)))
9697

9798
# Mapping from index to label
9899
vocabulary_inv = [x[0] for x in word_counts.most_common()]
@@ -114,6 +115,7 @@ def build_iters(data_dir, max_records, train_fraction, batch_size, buckets=None)
114115
:param buckets: size of each bucket in the iterators
115116
:return: train_iter, val_iter, word_to_index, index_to_word, pos_to_index, index_to_pos
116117
"""
118+
117119
# Read in data as numpy array
118120
df = pd.read_pickle(os.path.join(data_dir, "ner_data.pkl"))[:max_records]
119121

@@ -135,12 +137,14 @@ def build_iters(data_dir, max_records, train_fraction, batch_size, buckets=None)
135137

136138
# Split into training and testing data
137139
idx=int(len(indexed_tokens)*train_fraction)
140+
logging.info("Preparing train/test datasets splitting at idx %d on total %d sentences using a batchsize of %d", idx, len(indexed_tokens), batch_size)
138141
X_token_train, X_char_train, Y_train = indexed_tokens[:idx], indexed_chars[:idx], indexed_entities[:idx]
139142
X_token_test, X_char_test, Y_test = indexed_tokens[idx:], indexed_chars[idx:], indexed_entities[idx:]
140143

141144
# build iterators to feed batches to network
142145
train_iter = iterators.BucketNerIter(sentences=X_token_train, characters=X_char_train, label=Y_train,
143146
max_token_chars=5, batch_size=batch_size, buckets=buckets)
147+
logging.info("Creating the val_iter using %d sentences", len(X_token_test))
144148
val_iter = iterators.BucketNerIter(sentences=X_token_test, characters=X_char_test, label=Y_test,
145149
max_token_chars=train_iter.max_token_chars, batch_size=batch_size, buckets=train_iter.buckets)
146150
return train_iter, val_iter, word_to_index, char_to_index, entity_to_index
@@ -205,6 +209,8 @@ def sym_gen(seq_len):
205209
def train(train_iter, val_iter):
206210
import metrics
207211
devs = mx.cpu() if args.gpus is None or args.gpus is '' else [mx.gpu(int(i)) for i in args.gpus.split(',')]
212+
logging.info("train on device %s using optimizer %s at learningrate %f for %d epochs using %d records: lstm_state_size=%d ...",
213+
devs, args.optimizer, args.lr, args.num_epochs, args.max_records, args.lstm_state_size)
208214
module = mx.mod.BucketingModule(sym_gen, train_iter.default_bucket_key, context=devs)
209215
module.fit(train_data=train_iter,
210216
eval_data=val_iter,
@@ -225,6 +231,8 @@ def train(train_iter, val_iter):
225231
train_iter, val_iter, word_to_index, char_to_index, entity_to_index = build_iters(args.data_dir, args.max_records,
226232
args.train_fraction, args.batch_size, args.buckets)
227233

234+
logging.info("validation iterator: %s", val_iter)
235+
228236
# Define the recurrent layer
229237
bi_cell = mx.rnn.SequentialRNNCell()
230238
for layer_num in range(args.lstm_layers):

example/named_entity_recognition/src/preprocess.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
# -*- coding: utf-8 -*-
2121

22+
import logging
2223
import pandas as pd
2324
import numpy as np
2425

@@ -45,6 +46,12 @@
4546

4647
#join the results on utterance id
4748
df = df1.merge(df2.merge(df3, how = "left", on = "utterance_id"), how = "left", on = "utterance_id")
49+
pd.option_context('display.max_colwidth', None)
50+
pd.option_context('display.max_rowwidth', None)
51+
52+
logging.info("preprocess: 1st sentence:")
53+
logging.info(df['token'].iloc[0].tolist())
54+
logging.info(df['BILOU_tag'].iloc[0].tolist())
4855

4956
#save the dataframe to a csv file
5057
df.to_pickle("../data/ner_data.pkl")

0 commit comments

Comments
 (0)