Skip to content

Commit db4ae20

Browse files
authored
Merge pull request #752 from k-ivey/consistent-word-swap
Consistent word swap
2 parents 189d11f + bebf70f commit db4ae20

File tree

3 files changed

+187
-13
lines changed

3 files changed

+187
-13
lines changed

tests/test_transformations.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,34 @@ def test_word_swap_change_location():
3333
assert entity_original == entity_augmented
3434

3535

36+
def test_word_swap_change_location_consistent():
37+
from flair.data import Sentence
38+
from flair.models import SequenceTagger
39+
40+
from textattack.augmentation import Augmenter
41+
from textattack.transformations.word_swaps import WordSwapChangeLocation
42+
43+
augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True))
44+
s = "I am in New York. I love living in New York."
45+
s_augmented = augmenter.augment(s)
46+
augmented_text = Sentence(s_augmented[0])
47+
tagger = SequenceTagger.load("flair/ner-english")
48+
original_text = Sentence(s)
49+
tagger.predict(original_text)
50+
tagger.predict(augmented_text)
51+
52+
entity_original = []
53+
entity_augmented = []
54+
55+
for entity in original_text.get_spans("ner"):
56+
entity_original.append(entity.tag)
57+
for entity in augmented_text.get_spans("ner"):
58+
entity_augmented.append(entity.tag)
59+
60+
assert entity_original == entity_augmented
61+
assert s_augmented[0].count("New York") == 0
62+
63+
3664
def test_word_swap_change_name():
3765
from flair.data import Sentence
3866
from flair.models import SequenceTagger
@@ -59,6 +87,34 @@ def test_word_swap_change_name():
5987
assert entity_original == entity_augmented
6088

6189

90+
def test_word_swap_change_name_consistent():
91+
from flair.data import Sentence
92+
from flair.models import SequenceTagger
93+
94+
from textattack.augmentation import Augmenter
95+
from textattack.transformations.word_swaps import WordSwapChangeName
96+
97+
augmenter = Augmenter(transformation=WordSwapChangeName(consistent=True))
98+
s = "My name is Anthony Davis. Anthony Davis plays basketball."
99+
s_augmented = augmenter.augment(s)
100+
augmented_text = Sentence(s_augmented[0])
101+
tagger = SequenceTagger.load("flair/ner-english")
102+
original_text = Sentence(s)
103+
tagger.predict(original_text)
104+
tagger.predict(augmented_text)
105+
106+
entity_original = []
107+
entity_augmented = []
108+
109+
for entity in original_text.get_spans("ner"):
110+
entity_original.append(entity.tag)
111+
for entity in augmented_text.get_spans("ner"):
112+
entity_augmented.append(entity.tag)
113+
114+
assert entity_original == entity_augmented
115+
assert s_augmented[0].count("Anthony") == 0 or s_augmented[0].count("Davis") == 0
116+
117+
62118
def test_chinese_morphonym_character_swap():
63119
from textattack.augmentation import Augmenter
64120
from textattack.transformations.word_swaps.chn_transformations import (

textattack/transformations/word_swaps/word_swap_change_location.py

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Word Swap by Changing Location
33
-------------------------------
44
"""
5+
from collections import defaultdict
6+
57
import more_itertools as mit
68
import numpy as np
79

@@ -25,12 +27,15 @@ def idx_to_words(ls, words):
2527

2628

2729
class WordSwapChangeLocation(WordSwap):
28-
def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
30+
def __init__(
31+
self, n=3, confidence_score=0.7, language="en", consistent=False, **kwargs
32+
):
2933
"""Transformation that changes recognized locations of a sentence to
3034
another location that is given in the location map.
3135
3236
:param n: Number of new locations to generate
3337
:param confidence_score: Location will only be changed if it's above the confidence score
38+
:param consistent: Whether to change all instances of the same location to the same new location
3439
3540
>>> from textattack.transformations import WordSwapChangeLocation
3641
>>> from textattack.augmentation import Augmenter
@@ -44,6 +49,7 @@ def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
4449
self.n = n
4550
self.confidence_score = confidence_score
4651
self.language = language
52+
self.consistent = consistent
4753

4854
def _get_transformations(self, current_text, indices_to_modify):
4955
words = current_text.words
@@ -64,26 +70,55 @@ def _get_transformations(self, current_text, indices_to_modify):
6470
location_idx = [list(group) for group in mit.consecutive_groups(location_idx)]
6571
location_words = idx_to_words(location_idx, words)
6672

73+
if self.consistent:
74+
location_to_indices = self._build_location_to_indicies_map(
75+
location_words, current_text
76+
)
77+
6778
transformed_texts = []
6879
for location in location_words:
6980
idx = location[0]
70-
word = location[1].capitalize()
81+
word = self._capitalize(location[1])
82+
83+
# If doing consistent replacements, only replace the
84+
# word if it hasn't been replaced in a previous iteration
85+
if self.consistent and word not in location_to_indices:
86+
continue
87+
7188
replacement_words = self._get_new_location(word)
7289
for r in replacement_words:
7390
if r == word:
7491
continue
75-
text = current_text
76-
77-
# if original location is more than a single word, remain only the starting word
78-
if len(idx) > 1:
79-
index = idx[1]
80-
for i in idx[1:]:
81-
text = text.delete_word_at_index(index)
8292

83-
# replace the starting word with new location
84-
text = text.replace_word_at_index(idx[0], r)
93+
if self.consistent:
94+
indices_to_delete = []
95+
if len(idx) > 1:
96+
for i in location_to_indices[word]:
97+
for j in range(1, len(idx)):
98+
indices_to_delete.append(i + j)
99+
100+
transformed_texts.append(
101+
current_text.replace_words_at_indices(
102+
location_to_indices[word] + indices_to_delete,
103+
([r] * len(location_to_indices[word]))
104+
+ ([""] * len(indices_to_delete)),
105+
)
106+
)
107+
else:
108+
# If the original location is more than a single word, keep only the starting word
109+
# and replace the starting word with the new word
110+
indices_to_delete = idx[1:]
111+
transformed_texts.append(
112+
current_text.replace_words_at_indices(
113+
[idx[0]] + indices_to_delete,
114+
[r] + [""] * len(indices_to_delete),
115+
)
116+
)
117+
118+
if self.consistent:
119+
# Delete this word to mark it as replaced
120+
del location_to_indices[word]
85121

86-
transformed_texts.append(text)
87122
return transformed_texts
88123

89124
def _get_new_location(self, word):
@@ -101,3 +136,57 @@ def _get_new_location(self, word):
101136
elif word in NAMED_ENTITIES["city"]:
102137
return np.random.choice(NAMED_ENTITIES["city"], self.n)
103138
return []
139+
140+
def _capitalize(self, string):
141+
"""Capitalizes all words in the string."""
142+
return " ".join(word.capitalize() for word in string.split())
143+
144+
def _build_location_to_indicies_map(self, location_words, text):
145+
"""Returns a map of each location and the starting indices of all
146+
appearances of that location in the text."""
147+
148+
location_to_indices = defaultdict(list)
149+
if len(location_words) == 0:
150+
return location_to_indices
151+
152+
location_words.sort(
153+
# Sort by the number of words in the location
154+
key=lambda index_location_pair: index_location_pair[0][-1]
155+
- index_location_pair[0][0]
156+
+ 1,
157+
reverse=True,
158+
)
159+
max_length = location_words[0][0][-1] - location_words[0][0][0] + 1
160+
161+
for idx, location in location_words:
162+
words_in_location = idx[-1] - idx[0] + 1
163+
found = False
164+
location_start = idx[0]
165+
166+
# Check each window of n words containing the original tagged location
167+
# for n from the max_length down to the original location length.
168+
# This prevents cases where the NER tagger misses a word in a location
169+
# (e.g. it does not tag "New" in "New York")
170+
for length in range(max_length, words_in_location, -1):
171+
for start in range(
172+
location_start - length + words_in_location,
173+
location_start + 1,
174+
):
175+
if start + length > len(text.words):
176+
break
177+
178+
expanded_location = self._capitalize(
179+
" ".join(text.words[start : start + length])
180+
)
181+
if expanded_location in location_to_indices:
182+
location_to_indices[expanded_location].append(start)
183+
found = True
184+
break
185+
186+
if found:
187+
break
188+
189+
if not found:
190+
location_to_indices[self._capitalize(location)].append(idx[0])
191+
192+
return location_to_indices

textattack/transformations/word_swaps/word_swap_change_name.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
-------------------------------
44
"""
55

6+
from collections import defaultdict
7+
68
import numpy as np
79

810
from textattack.shared.data import PERSON_NAMES
@@ -18,6 +20,7 @@ def __init__(
1820
last_only=False,
1921
confidence_score=0.7,
2022
language="en",
23+
consistent=False,
2124
**kwargs
2225
):
2326
"""Transforms an input by replacing names of recognized name entity.
@@ -26,6 +29,7 @@ def __init__(
2629
:param first_only: Whether to change first name only
2730
:param last_only: Whether to change last name only
2831
:param confidence_score: Name will only be changed when it's above confidence score
32+
:param consistent: Whether to change all instances of the same name to the same new name
2933
>>> from textattack.transformations import WordSwapChangeName
3034
>>> from textattack.augmentation import Augmenter
3135
@@ -42,6 +46,7 @@ def __init__(
4246
self.last_only = last_only
4347
self.confidence_score = confidence_score
4448
self.language = language
49+
self.consistent = consistent
4550

4651
def _get_transformations(self, current_text, indices_to_modify):
4752
transformed_texts = []
@@ -52,14 +57,38 @@ def _get_transformations(self, current_text, indices_to_modify):
5257
else:
5358
model_name = "flair/ner-multi-fast"
5459

60+
if self.consistent:
61+
word_to_indices = defaultdict(list)
62+
for i in indices_to_modify:
63+
word_to_replace = current_text.words[i].capitalize()
64+
word_to_indices[word_to_replace].append(i)
65+
5566
for i in indices_to_modify:
5667
word_to_replace = current_text.words[i].capitalize()
68+
# If we're doing consistent replacements, only replace the word
69+
# if it hasn't already been replaced in a previous iteration
70+
if self.consistent and word_to_replace not in word_to_indices:
71+
continue
5772
word_to_replace_ner = current_text.ner_of_word_index(i, model_name)
73+
5874
replacement_words = self._get_replacement_words(
5975
word_to_replace, word_to_replace_ner
6076
)
77+
6178
for r in replacement_words:
62-
transformed_texts.append(current_text.replace_word_at_index(i, r))
79+
if self.consistent:
80+
transformed_texts.append(
81+
current_text.replace_words_at_indices(
82+
word_to_indices[word_to_replace],
83+
[r] * len(word_to_indices[word_to_replace]),
84+
)
85+
)
86+
else:
87+
transformed_texts.append(current_text.replace_word_at_index(i, r))
88+
89+
# Delete this word to mark it as replaced
90+
if self.consistent and len(replacement_words) != 0:
91+
del word_to_indices[word_to_replace]
6392

6493
return transformed_texts
6594

0 commit comments

Comments
 (0)