Skip to content

Commit a7cd188

Browse files
committed
Add initial implementation of tfplot.contrib (#11)
Add three functions: - tfplot.contrib.probmap - tfplot.contrib.probmap_simple - tfplot.contrib.batch
1 parent 14648a5 commit a7cd188

File tree

5 files changed

+99
-1
lines changed

5 files changed

+99
-1
lines changed

docs/api/contrib.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
.. _api_tfplot_summary:
2+
3+
:mod:`tfplot.summary`
4+
---------------------
5+
6+
.. contents::
7+
:local:
8+
9+
.. automodule:: tfplot.contrib
10+
:members:

docs/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ API References
88

99
tfplot
1010
figure
11+
contrib
1112
summary

tfplot/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from __future__ import print_function
44

55
import matplotlib
6-
matplotlib.use('Agg')
6+
if not matplotlib.rcParams.get('backend', None):
7+
matplotlib.use('Agg')
78

89
from .ops import plot, plot_many
910
from .wrapper import wrap, wrap_axesplot

tfplot/contrib.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
'''Some predefined plot functions.'''
2+
3+
from __future__ import absolute_import
4+
from __future__ import division
5+
from __future__ import print_function
6+
7+
from .wrapper import autowrap
8+
9+
10+
__all__ = (
11+
'probmap',
12+
'probmap_simple',
13+
'batch',
14+
)
15+
16+
17+
@autowrap
18+
def probmap(x, cmap='jet', colorbar=True,
19+
vmin=None, vmax=None, axis=True, ax=None):
20+
'''
21+
Display a heatmap in color. The resulting op will be a RGBA image Tensor.
22+
23+
Args:
24+
x: A 2-D image-like tensor to draw.
25+
cmap: Matplotlib colormap. Defaults 'jet'
26+
axis: If True (default), x-axis and y-axis will appear.
27+
colorbar: If True (default), a colorbar will be placed on the right.
28+
vmin: A scalar. Minimum value of the range. See ``matplotlib.axes.Axes.imshow``.
29+
vmax: A scalar. Maximum value of the range. See ``matplotlib.axes.Axes.imshow``.
30+
31+
Returns:
32+
A `uint8` `Tensor` of shape ``(?, ?, 4)`` containing the resulting plot.
33+
'''
34+
assert ax is not None, "autowrap did not set ax"
35+
36+
axim = ax.imshow(x, cmap=cmap, vmin=vmin, vmax=vmax)
37+
if colorbar:
38+
ax.figure.colorbar(axim)
39+
if not axis:
40+
ax.axis('off')
41+
42+
if not axis and not colorbar:
43+
ax.figure.subplots_adjust(0, 0, 1, 1)
44+
else:
45+
ax.figure.tight_layout()
46+
47+
48+
def probmap_simple(x, **kwargs):
49+
'''
50+
Display a heatmap in color, but only displays the image content.
51+
The resulting op will be a RGBA image Tensor.
52+
53+
It reduces to ``probmap`` having `colorbar` and `axis` off.
54+
See the documentation of ``probmap`` for available arguments.
55+
'''
56+
# pylint: disable=unexpected-keyword-arg
57+
return probmap(x,
58+
colorbar=kwargs.pop('colorbar', False),
59+
axis=kwargs.pop('axis', False),
60+
figsize=kwargs.pop('figsize', (3, 3)),
61+
**kwargs)
62+
# pylint: enable=unexpected-keyword-arg
63+
64+
65+
def batch(func):
66+
'''
67+
Make an autowrapped plot function (... -> RGBA tf.Tensor) work in a batch
68+
manner.
69+
70+
Example:
71+
72+
>>> p
73+
Tensor("p:0", shape=(batch_size, 16, 16, 4), dtype=uint8)
74+
>>> tfplot.contrib.batch(tfplot.contrib.probmap)(p)
75+
Tensor("probmap/PlotImages:0", shape=(batch_size, ?, ?, 4), dtype=uint8)
76+
'''
77+
if not hasattr(func, '__unwrapped__'):
78+
raise ValueError("The given function is not wrapped with tfplot.autowrap()!")
79+
80+
func = func.__unwrapped__
81+
return autowrap(func, batch=True)

tfplot/wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,16 @@ def _wrapped_plot_fn(*args, **kwargs_call):
307307

308308
return ret
309309

310+
# return the wrapper (a factory of Tensor)
310311
_wrapped_fn = wrap(_wrapped_plot_fn, batch=batch, name=name) # TODO kwargs
311312

312313
_wrapped_fn.__name__ = 'autowrap[%s]' % plot_func.__name__
313314
if hasattr(plot_func, '__qualname__'):
314315
_wrapped_fn.__qualname__ = 'autowrap[%s.%s]' % (plot_func.__module__, plot_func.__qualname__)
316+
317+
# expose the unwrapped python function as well
318+
_wrapped_fn.__unwrapped__ = plot_func
319+
315320
return _wrapped_fn
316321

317322

0 commit comments

Comments
 (0)