Skip to content

Commit 7bc29ea

Browse files
committed
Make tfplot.autowrap() support Axes methods
As a result, `tfplot.autowrap()` can totally replace `tfplot.wrap_axesplot()` (which now shall be deprecated). Follow-up of #7.
1 parent 3cefadd commit 7bc29ea

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

tfplot/wrapper.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def wrap_axesplot(axesplot_func, _sentinel=None,
9595
batch=False, name=None,
9696
figsize=None, tight_layout=False, **kwargs):
9797
'''
98+
DEPRECATED: Use ``tfplot.autowrap()`` instead. Will be removed
99+
in the next version.
100+
98101
Wrap an axesplot function as a TensorFlow operation. It will return a
99102
python function that creates a TensorFlow plot operation applying the
100103
arguments as input.
@@ -248,6 +251,17 @@ def autowrap(plot_func=REQUIRED, _sentinel=None,
248251
if arg_name in util.getargspec_allargs(plot_func)
249252
)
250253

254+
# check if func is an instance method of Axes, e.g. ax.scatter()
255+
method_class = util.get_class_defining_method(plot_func)
256+
is_axesplot_bind = False
257+
if method_class is not None and issubclass(method_class, Axes):
258+
if hasattr(plot_func, '__self__') and plot_func.__self__:
259+
raise ValueError("plot_func should be a unbound method of " +
260+
"Axes or AxesSubplot, but given a bound method " +
261+
str(plot_func))
262+
is_axesplot_bind = True
263+
264+
251265
def _create_subplots(_kwargs):
252266
# recognize overriding parameters for creating subplots, e.g. figsize
253267
_figsize = _kwargs.pop('figsize', figsize)
@@ -260,7 +274,7 @@ def _create_subplots(_kwargs):
260274
@functools.wraps(plot_func)
261275
def _wrapped_plot_fn(*args, **kwargs_call):
262276
# (1) auto-inject fig, ax
263-
if fig_ax_mode:
277+
if fig_ax_mode or is_axesplot_bind:
264278
# auto-create rather than manually
265279
fig, ax = _create_subplots(kwargs_call)
266280
fig_ax_kwargs = dict(
@@ -269,13 +283,21 @@ def _wrapped_plot_fn(*args, **kwargs_call):
269283
)
270284

271285
# (2) body
272-
ret = plot_func(*args, **merge_kwargs(kwargs_call, fig_ax_kwargs)) # TODO conflict??
286+
if is_axesplot_bind: # e.g. Axesplot.scatter -> bind 'ax' as self
287+
ret = plot_func.__get__(ax)(
288+
*args, **merge_kwargs(kwargs_call, fig_ax_kwargs))
289+
else:
290+
ret = plot_func(*args, **merge_kwargs(kwargs_call, fig_ax_kwargs)) # TODO conflict??
273291

274292
# (3) return value handling
275293
if ret is None and fig_ax_mode:
276294
# even if the function doesn't return anything,
277295
# but we know that `fig` is what we just need to draw.
278296
ret = fig
297+
elif is_axesplot_bind:
298+
# for Axesplot methods, ignore the return value
299+
# and use the fig instance created before as target figure
300+
ret = fig
279301
elif isinstance(ret, Axes):
280302
ret = fig = ret.figure
281303
elif isinstance(ret, Figure):
@@ -304,4 +326,5 @@ def _clean_name(s):
304326
__all__ = (
305327
'wrap',
306328
'wrap_axesplot',
329+
'autowrap',
307330
)

tfplot/wrapper_test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def _check_plot_op_shape(self, op):
3636
self.assertTrue(op.get_shape().is_compatible_with([None, None, 4])) # RGB-A
3737
self.assertEqual(op.dtype, tf.uint8)
3838

39-
4039
def test_wrap_simplefunction(self):
4140
'''Basic functionality test of tfplot.wrap() in successful cases.'''
4241

@@ -55,7 +54,6 @@ def _fn_to_wrap(message="str"):
5554
self._check_plot_op_shape(plot_op)
5655
self.assertEqual(plot_op.name, 'Wrapped:0')
5756

58-
5957
def test_wrap_axesplot_axes(self):
6058
'''Basic functionality test of tfplot.wrap_axesplot() in successful cases.'''
6159

@@ -65,9 +63,8 @@ def test_wrap_axesplot_axes(self):
6563
cprint("\n tf_scatter: %s" % tf_scatter, color='magenta')
6664

6765
plot_op = tf_scatter([1, 2, 3], [1, 4, 9])
68-
6966
self._check_plot_op_shape(plot_op)
70-
self.assertEqual(plot_op.name, 'scatter:0')
67+
self.assertRegex(plot_op.name, 'scatter(_\d)?:0')
7168

7269
def test_wrap_axesplot_kwarg(self):
7370
'''Basic functionality test of tfplot.wrap_axesplot() in successful cases.'''
@@ -145,6 +142,17 @@ def foo(values):
145142
op = foo(tf.convert_to_tensor([2, 2, 3, 3]))
146143
self._execute_plot_op(op)
147144

145+
def test_autowrap_axesplot(self):
146+
'''Does autowrap also work with Axes.xxxx methods?
147+
needs to handle binding (e.g. self) carefully! '''
148+
from matplotlib.axes import Axes
149+
tf_scatter = tfplot.autowrap(Axes.scatter, name='ScatterAutowrap')
150+
cprint("\n tf_scatter: %s" % tf_scatter, color='magenta')
151+
152+
op = tf_scatter([1, 2, 3], [1, 4, 9])
153+
self._execute_plot_op(op)
154+
155+
148156
def test_wrap_autoinject_figax(self):
149157
"""Tests whether @tfplot.autowrap work in many use cases"""
150158
@tfplot.autowrap

0 commit comments

Comments
 (0)