@@ -95,6 +95,9 @@ def wrap_axesplot(axesplot_func, _sentinel=None,
95
95
batch = False , name = None ,
96
96
figsize = None , tight_layout = False , ** kwargs ):
97
97
'''
98
+ DEPRECATED: Use ``tfplot.autowrap()`` instead. Will be removed
99
+ in the next version.
100
+
98
101
Wrap an axesplot function as a TensorFlow operation. It will return a
99
102
python function that creates a TensorFlow plot operation applying the
100
103
arguments as input.
@@ -248,6 +251,17 @@ def autowrap(plot_func=REQUIRED, _sentinel=None,
248
251
if arg_name in util .getargspec_allargs (plot_func )
249
252
)
250
253
254
+ # check if func is an instance method of Axes, e.g. ax.scatter()
255
+ method_class = util .get_class_defining_method (plot_func )
256
+ is_axesplot_bind = False
257
+ if method_class is not None and issubclass (method_class , Axes ):
258
+ if hasattr (plot_func , '__self__' ) and plot_func .__self__ :
259
+ raise ValueError ("plot_func should be a unbound method of " +
260
+ "Axes or AxesSubplot, but given a bound method " +
261
+ str (plot_func ))
262
+ is_axesplot_bind = True
263
+
264
+
251
265
def _create_subplots (_kwargs ):
252
266
# recognize overriding parameters for creating subplots, e.g. figsize
253
267
_figsize = _kwargs .pop ('figsize' , figsize )
@@ -260,7 +274,7 @@ def _create_subplots(_kwargs):
260
274
@functools .wraps (plot_func )
261
275
def _wrapped_plot_fn (* args , ** kwargs_call ):
262
276
# (1) auto-inject fig, ax
263
- if fig_ax_mode :
277
+ if fig_ax_mode or is_axesplot_bind :
264
278
# auto-create rather than manually
265
279
fig , ax = _create_subplots (kwargs_call )
266
280
fig_ax_kwargs = dict (
@@ -269,13 +283,21 @@ def _wrapped_plot_fn(*args, **kwargs_call):
269
283
)
270
284
271
285
# (2) body
272
- ret = plot_func (* args , ** merge_kwargs (kwargs_call , fig_ax_kwargs )) # TODO conflict??
286
+ if is_axesplot_bind : # e.g. Axesplot.scatter -> bind 'ax' as self
287
+ ret = plot_func .__get__ (ax )(
288
+ * args , ** merge_kwargs (kwargs_call , fig_ax_kwargs ))
289
+ else :
290
+ ret = plot_func (* args , ** merge_kwargs (kwargs_call , fig_ax_kwargs )) # TODO conflict??
273
291
274
292
# (3) return value handling
275
293
if ret is None and fig_ax_mode :
276
294
# even if the function doesn't return anything,
277
295
# but we know that `fig` is what we just need to draw.
278
296
ret = fig
297
+ elif is_axesplot_bind :
298
+ # for Axesplot methods, ignore the return value
299
+ # and use the fig instance created before as target figure
300
+ ret = fig
279
301
elif isinstance (ret , Axes ):
280
302
ret = fig = ret .figure
281
303
elif isinstance (ret , Figure ):
@@ -304,4 +326,5 @@ def _clean_name(s):
304
326
__all__ = (
305
327
'wrap' ,
306
328
'wrap_axesplot' ,
329
+ 'autowrap' ,
307
330
)
0 commit comments