@@ -135,175 +135,7 @@ def plot_many(plot_func, in_tensors, name='PlotMany',
135
135
136
136
137
137
138
- def wrap (plot_func , _sentinel = None ,
139
- batch = False , name = None ,
140
- ** kwargs ):
141
- '''
142
- Wrap a plot function as a TensorFlow operation. It will return a python
143
- function that creates a TensorFlow plot operation applying the arguments
144
- as input.
145
-
146
- For example, if ``plot_func`` is a python function that takes two
147
- arrays as input, and draw a plot by returning a matplotlib Figure,
148
- we can wrap this function as a `Tensor` factory, such as:
149
-
150
- >>> tf_plot = tfplot.wrap(plot_func, name="MyPlot", batch=True)
151
- >>> # x, y = get_batch_inputs(batch_size=4, ...)
152
-
153
- >>> plot_x = tf_plot(x)
154
- Tensor("MyPlot:0", shape=(4, ?, ?, 4), dtype=uint8)
155
- >>> plot_y = tf_plot(y)
156
- Tensor("MyPlot_1:0", shape=(4, ?, ?, 4), dtype=uint8)
157
-
158
- Args:
159
- plot_func: A python function or callable to wrap. See the documentation
160
- of :func:`tfplot.plot()` for details.
161
- batch: If True, all the tensors passed as argument will be
162
- assumed to be batched. Default value is False.
163
- name: A default name for the operation (optional). If not given, the
164
- name of ``plot_func`` will be used.
165
- kwargs: An optional kwargs that will be passed by default to
166
- ``plot_func``.
167
-
168
- Returns:
169
- A python function that will create a TensorFlow plot operation,
170
- passing the provided arguments.
171
- '''
172
-
173
- if not hasattr (plot_func , '__call__' ):
174
- raise TypeError ("plot_func should be callable" )
175
- if _sentinel is not None :
176
- raise RuntimeError ("Invalid call: it can have only one unnamed argument, " +
177
- "please pass named arguments for batch, name, etc." )
178
-
179
- if name is None :
180
- name = _clean_name (plot_func .__name__ )
181
-
182
- def _wrapped_fn (* args , ** kwargs_call ):
183
- _plot = plot_many if batch else plot
184
- _name = kwargs_call .pop ('name' , name )
185
- return _plot (plot_func , list (args ), name = _name ,
186
- ** merge_kwargs (kwargs , kwargs_call ))
187
-
188
- _wrapped_fn .__name__ = 'wrapped_fn[%s]' % plot_func
189
- return _wrapped_fn
190
-
191
-
192
- def wrap_axesplot (axesplot_func , _sentinel = None ,
193
- batch = False , name = None ,
194
- figsize = None , tight_layout = False , ** kwargs ):
195
- '''
196
- Wrap an axesplot function as a TensorFlow operation. It will return a
197
- python function that creates a TensorFlow plot operation applying the
198
- arguments as input.
199
-
200
- An axesplot function ``axesplot_func`` can be either:
201
-
202
- - an unbounded method of matplotlib `Axes` (or `AxesSubplot`) class,
203
- such as ``Axes.scatter()`` and ``Axes.text()``, etc, or
204
- - a simple python function that takes the named argument ``ax``,
205
- of type `Axes` or `AxesSubplot`, on which the plot will be drawn.
206
- Some good examples of this family includes ``seaborn.heatmap(ax=...)``.
207
-
208
- The resulting function can be used as a Tensor factory. When the created
209
- tensorflow plot op is being executed, a new matplotlib figure which
210
- consists of a single `AxesSubplot` will be created, and the axes plot
211
- will be used as an argument for ``axesplot_func``. For example,
212
-
213
- >>> import seaborn.apionly as sns
214
- >>> tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, name="HeatmapPlot", figsize=(4, 4), cmap='jet')
215
-
216
- >>> plot_op = tf_heatmap(attention_map, cmap)
217
- Tensor(HeatmapPlot:0", shape=(?, ?, 4), dtype=uint8)
218
-
219
- Args:
220
- axesplot_func: An unbounded method of matplotlib `Axes` or `AxesSubplot`,
221
- or a python function or callable which has the `ax` parameter for
222
- specifying the axis to draw on.
223
- batch: If True, all the tensors passed as argument will be
224
- assumed to be batched. Default value is False.
225
- name: A default name for the operation (optional). If not given, the
226
- name of ``axesplot_func`` will be used.
227
- figsize: The figure size for the figure to be created.
228
- tight_layout: If True, the resulting figure will have no margins for
229
- axis. Equivalent to calling ``fig.subplots_adjust(0, 0, 1, 1)``.
230
- kwargs: An optional kwargs that will be passed by default to
231
- ``axesplot_func``.
232
-
233
- Returns:
234
- A python function that will create a TensorFlow plot operation,
235
- passing the provied arguments and a new instance of `AxesSubplot` into
236
- ``axesplot_func``.
237
- '''
238
-
239
- if not hasattr (axesplot_func , '__call__' ):
240
- raise TypeError ("axesplot_func should be callable" )
241
- if _sentinel is not None :
242
- raise RuntimeError ("Invalid call: it can have only one unnamed argument, " +
243
- "please pass named arguments for batch, name, etc." )
244
-
245
- def _create_subplots ():
246
- if figsize is not None :
247
- fig , ax = figure .subplots (figsize = figsize )
248
- else :
249
- fig , ax = figure .subplots ()
250
-
251
- if tight_layout :
252
- fig .subplots_adjust (0 , 0 , 1 , 1 )
253
- return fig , ax
254
-
255
- # (1) instance method of Axes -- ax.xyz()
256
- def _fig_axesplot_method (* args , ** kwargs_call ):
257
- fig , ax = _create_subplots ()
258
- axesplot_func .__get__ (ax )(* args , ** merge_kwargs (kwargs , kwargs_call ))
259
- return fig
260
-
261
- # (2) xyz(ax=...) style
262
- def _fig_axesplot_fn (* args , ** kwargs_call ):
263
- fig , ax = _create_subplots ()
264
- axesplot_func (* args , ax = ax , ** merge_kwargs (kwargs , kwargs_call ))
265
- return fig
266
-
267
- method_class = util .get_class_defining_method (axesplot_func )
268
- if method_class is not None and issubclass (method_class , Axes ):
269
- # (1) Axes.xyz()
270
- if hasattr (axesplot_func , '__self__' ) and axesplot_func .__self__ :
271
- raise ValueError ("axesplot_func should be a unbound method of " +
272
- "Axes or AxesSubplot, but given a bound method " +
273
- str (axesplot_func ))
274
- fig_axesplot_func = _fig_axesplot_method
275
- else :
276
- # (2) xyz(ax=...)
277
- if 'ax' not in util .getargspec (axesplot_func ).args :
278
- raise TypeError ("axesplot_func must take 'ax' parameter to specify Axes" )
279
- fig_axesplot_func = _fig_axesplot_fn
280
-
281
- if name is None :
282
- name = _clean_name (axesplot_func .__name__ )
283
-
284
- def _wrapped_factory_fn (* args , ** kwargs_call ):
285
- _plot = plot_many if batch else plot
286
- _name = kwargs_call .pop ('name' , name )
287
- return _plot (fig_axesplot_func , list (args ), name = _name ,
288
- ** kwargs_call )
289
-
290
- _wrapped_factory_fn .__name__ = 'wrapped_axesplot_fn[%s]' % axesplot_func
291
- return _wrapped_factory_fn
292
-
293
-
294
- def _clean_name (s ):
295
- """
296
- Convert a string to a valid variable, function, or scope name.
297
- """
298
- return re .sub ('[^0-9a-zA-Z_]' , '' , s )
299
-
300
-
301
-
302
-
303
-
304
138
__all__ = (
305
139
'plot' ,
306
140
'plot_many' ,
307
- 'wrap' ,
308
- 'wrap_axesplot' ,
309
141
)
0 commit comments