2
2
Word Swap by Changing Location
3
3
-------------------------------
4
4
"""
5
+ from collections import defaultdict
6
+
5
7
import more_itertools as mit
6
8
import numpy as np
7
9
@@ -25,12 +27,15 @@ def idx_to_words(ls, words):
25
27
26
28
27
29
class WordSwapChangeLocation (WordSwap ):
28
- def __init__ (self , n = 3 , confidence_score = 0.7 , language = "en" , ** kwargs ):
30
+ def __init__ (
31
+ self , n = 3 , confidence_score = 0.7 , language = "en" , consistent = False , ** kwargs
32
+ ):
29
33
"""Transformation that changes recognized locations of a sentence to
30
34
another location that is given in the location map.
31
35
32
36
:param n: Number of new locations to generate
33
37
:param confidence_score: Location will only be changed if it's above the confidence score
38
+ :param consistent: Whether to change all instances of the same location to the same new location
34
39
35
40
>>> from textattack.transformations import WordSwapChangeLocation
36
41
>>> from textattack.augmentation import Augmenter
@@ -44,6 +49,7 @@ def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
44
49
self .n = n
45
50
self .confidence_score = confidence_score
46
51
self .language = language
52
+ self .consistent = consistent
47
53
48
54
def _get_transformations (self , current_text , indices_to_modify ):
49
55
words = current_text .words
@@ -64,26 +70,55 @@ def _get_transformations(self, current_text, indices_to_modify):
64
70
location_idx = [list (group ) for group in mit .consecutive_groups (location_idx )]
65
71
location_words = idx_to_words (location_idx , words )
66
72
73
+ if self .consistent :
74
+ location_to_indices = self ._build_location_to_indicies_map (
75
+ location_words , current_text
76
+ )
77
+
67
78
transformed_texts = []
68
79
for location in location_words :
69
80
idx = location [0 ]
70
- word = location [1 ].capitalize ()
81
+ word = self ._capitalize (location [1 ])
82
+
83
+ # If doing consistent replacements, only replace the
84
+ # word if it hasn't been replaced in a previous iteration
85
+ if self .consistent and word not in location_to_indices :
86
+ continue
87
+
71
88
replacement_words = self ._get_new_location (word )
72
89
for r in replacement_words :
73
90
if r == word :
74
91
continue
75
- text = current_text
76
-
77
- # if original location is more than a single word, remain only the starting word
78
- if len (idx ) > 1 :
79
- index = idx [1 ]
80
- for i in idx [1 :]:
81
- text = text .delete_word_at_index (index )
82
92
83
- # replace the starting word with new location
84
- text = text .replace_word_at_index (idx [0 ], r )
93
+ if self .consistent :
94
+ indices_to_delete = []
95
+ if len (idx ) > 1 :
96
+ for i in location_to_indices [word ]:
97
+ for j in range (1 , len (idx )):
98
+ indices_to_delete .append (i + j )
99
+
100
+ transformed_texts .append (
101
+ current_text .replace_words_at_indices (
102
+ location_to_indices [word ] + indices_to_delete ,
103
+ ([r ] * len (location_to_indices [word ]))
104
+ + (["" ] * len (indices_to_delete )),
105
+ )
106
+ )
107
+ else :
108
+ # If the original location is more than a single word, keep only the starting word
109
+ # and replace the starting word with the new word
110
+ indices_to_delete = idx [1 :]
111
+ transformed_texts .append (
112
+ current_text .replace_words_at_indices (
113
+ [idx [0 ]] + indices_to_delete ,
114
+ [r ] + ["" ] * len (indices_to_delete ),
115
+ )
116
+ )
117
+
118
+ if self .consistent :
119
+ # Delete this word to mark it as replaced
120
+ del location_to_indices [word ]
85
121
86
- transformed_texts .append (text )
87
122
return transformed_texts
88
123
89
124
def _get_new_location (self , word ):
@@ -101,3 +136,57 @@ def _get_new_location(self, word):
101
136
elif word in NAMED_ENTITIES ["city" ]:
102
137
return np .random .choice (NAMED_ENTITIES ["city" ], self .n )
103
138
return []
139
+
140
+ def _capitalize (self , string ):
141
+ """Capitalizes all words in the string."""
142
+ return " " .join (word .capitalize () for word in string .split ())
143
+
144
+ def _build_location_to_indicies_map (self , location_words , text ):
145
+ """Returns a map of each location and the starting indicies of all
146
+ appearances of that location in the text."""
147
+
148
+ location_to_indices = defaultdict (list )
149
+ if len (location_words ) == 0 :
150
+ return location_to_indices
151
+
152
+ location_words .sort (
153
+ # Sort by the number of words in the location
154
+ key = lambda index_location_pair : index_location_pair [0 ][- 1 ]
155
+ - index_location_pair [0 ][0 ]
156
+ + 1 ,
157
+ reverse = True ,
158
+ )
159
+ max_length = location_words [0 ][0 ][- 1 ] - location_words [0 ][0 ][0 ] + 1
160
+
161
+ for idx , location in location_words :
162
+ words_in_location = idx [- 1 ] - idx [0 ] + 1
163
+ found = False
164
+ location_start = idx [0 ]
165
+
166
+ # Check each window of n words containing the original tagged location
167
+ # for n from the max_length down to the original location length.
168
+ # This prevents cases where the NER tagger misses a word in a location
169
+ # (e.g. it does not tag "New" in "New York")
170
+ for length in range (max_length , words_in_location , - 1 ):
171
+ for start in range (
172
+ location_start - length + words_in_location ,
173
+ location_start + 1 ,
174
+ ):
175
+ if start + length > len (text .words ):
176
+ break
177
+
178
+ expanded_location = self ._capitalize (
179
+ " " .join (text .words [start : start + length ])
180
+ )
181
+ if expanded_location in location_to_indices :
182
+ location_to_indices [expanded_location ].append (start )
183
+ found = True
184
+ break
185
+
186
+ if found :
187
+ break
188
+
189
+ if not found :
190
+ location_to_indices [self ._capitalize (location )].append (idx [0 ])
191
+
192
+ return location_to_indices
0 commit comments