Skip to content

Commit 04014dd

Browse files
authored
Merge pull request #7 from wookayin/decorator
Introduce decorator: wrap and autowrap Merging as functionality becomes stable.
2 parents 5a63dcb + 4020c13 commit 04014dd

File tree

9 files changed

+545
-267
lines changed

9 files changed

+545
-267
lines changed

.travis.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ install:
99
- pip install -r requirements.txt
1010

1111
script:
12-
pytest
12+
- PYTEST_ADDOPTS="-s" python setup.py test
13+
- pip install -e .

requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
six>=1.10.0
2-
numpy>=1.10.4
3-
matplotlib>=2.0.0
4-
tensorflow>=1.0.0
1+
-e .
2+
3+
# test/dev environment
4+
tensorflow>=1.4.0
55

66
# vim: set ft=conf:

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,17 @@ def run(self):
9393
install_requires=[
9494
'six',
9595
'numpy',
96+
'biwrap==0.1.6',
9697
'matplotlib>=2.0.0',
9798
],
9899
setup_requires=[
99100
'pytest-runner',
100101
],
101102
tests_require=[
102103
'pytest',
104+
'pytest-pudb',
105+
'imgcat',
106+
'termcolor',
103107
],
104108
cmdclass={
105109
'deploy': DeployCommand,

tfplot/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import matplotlib
66
matplotlib.use('Agg')
77

8-
from .ops import plot, plot_many, wrap, wrap_axesplot
8+
from .ops import plot, plot_many
9+
from .wrapper import wrap, wrap_axesplot
10+
from .wrapper import autowrap
11+
912
from .figure import subplots
1013
from . import summary
1114

tfplot/ops.py

Lines changed: 0 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -135,175 +135,7 @@ def plot_many(plot_func, in_tensors, name='PlotMany',
135135

136136

137137

138-
def wrap(plot_func, _sentinel=None,
139-
batch=False, name=None,
140-
**kwargs):
141-
'''
142-
Wrap a plot function as a TensorFlow operation. It will return a python
143-
function that creates a TensorFlow plot operation applying the arguments
144-
as input.
145-
146-
For example, if ``plot_func`` is a python function that takes two
147-
arrays as input, and draw a plot by returning a matplotlib Figure,
148-
we can wrap this function as a `Tensor` factory, such as:
149-
150-
>>> tf_plot = tfplot.wrap(plot_func, name="MyPlot", batch=True)
151-
>>> # x, y = get_batch_inputs(batch_size=4, ...)
152-
153-
>>> plot_x = tf_plot(x)
154-
Tensor("MyPlot:0", shape=(4, ?, ?, 4), dtype=uint8)
155-
>>> plot_y = tf_plot(y)
156-
Tensor("MyPlot_1:0", shape=(4, ?, ?, 4), dtype=uint8)
157-
158-
Args:
159-
plot_func: A python function or callable to wrap. See the documentation
160-
of :func:`tfplot.plot()` for details.
161-
batch: If True, all the tensors passed as argument will be
162-
assumed to be batched. Default value is False.
163-
name: A default name for the operation (optional). If not given, the
164-
name of ``plot_func`` will be used.
165-
kwargs: An optional kwargs that will be passed by default to
166-
``plot_func``.
167-
168-
Returns:
169-
A python function that will create a TensorFlow plot operation,
170-
passing the provided arguments.
171-
'''
172-
173-
if not hasattr(plot_func, '__call__'):
174-
raise TypeError("plot_func should be callable")
175-
if _sentinel is not None:
176-
raise RuntimeError("Invalid call: it can have only one unnamed argument, " +
177-
"please pass named arguments for batch, name, etc.")
178-
179-
if name is None:
180-
name = _clean_name(plot_func.__name__)
181-
182-
def _wrapped_fn(*args, **kwargs_call):
183-
_plot = plot_many if batch else plot
184-
_name = kwargs_call.pop('name', name)
185-
return _plot(plot_func, list(args), name=_name,
186-
**merge_kwargs(kwargs, kwargs_call))
187-
188-
_wrapped_fn.__name__ = 'wrapped_fn[%s]' % plot_func
189-
return _wrapped_fn
190-
191-
192-
def wrap_axesplot(axesplot_func, _sentinel=None,
193-
batch=False, name=None,
194-
figsize=None, tight_layout=False, **kwargs):
195-
'''
196-
Wrap an axesplot function as a TensorFlow operation. It will return a
197-
python function that creates a TensorFlow plot operation applying the
198-
arguments as input.
199-
200-
An axesplot function ``axesplot_func`` can be either:
201-
202-
- an unbounded method of matplotlib `Axes` (or `AxesSubplot`) class,
203-
such as ``Axes.scatter()`` and ``Axes.text()``, etc, or
204-
- a simple python function that takes the named argument ``ax``,
205-
of type `Axes` or `AxesSubplot`, on which the plot will be drawn.
206-
Some good examples of this family includes ``seaborn.heatmap(ax=...)``.
207-
208-
The resulting function can be used as a Tensor factory. When the created
209-
tensorflow plot op is being executed, a new matplotlib figure which
210-
consists of a single `AxesSubplot` will be created, and the axes plot
211-
will be used as an argument for ``axesplot_func``. For example,
212-
213-
>>> import seaborn.apionly as sns
214-
>>> tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, name="HeatmapPlot", figsize=(4, 4), cmap='jet')
215-
216-
>>> plot_op = tf_heatmap(attention_map, cmap)
217-
Tensor(HeatmapPlot:0", shape=(?, ?, 4), dtype=uint8)
218-
219-
Args:
220-
axesplot_func: An unbounded method of matplotlib `Axes` or `AxesSubplot`,
221-
or a python function or callable which has the `ax` parameter for
222-
specifying the axis to draw on.
223-
batch: If True, all the tensors passed as argument will be
224-
assumed to be batched. Default value is False.
225-
name: A default name for the operation (optional). If not given, the
226-
name of ``axesplot_func`` will be used.
227-
figsize: The figure size for the figure to be created.
228-
tight_layout: If True, the resulting figure will have no margins for
229-
axis. Equivalent to calling ``fig.subplots_adjust(0, 0, 1, 1)``.
230-
kwargs: An optional kwargs that will be passed by default to
231-
``axesplot_func``.
232-
233-
Returns:
234-
A python function that will create a TensorFlow plot operation,
235-
passing the provied arguments and a new instance of `AxesSubplot` into
236-
``axesplot_func``.
237-
'''
238-
239-
if not hasattr(axesplot_func, '__call__'):
240-
raise TypeError("axesplot_func should be callable")
241-
if _sentinel is not None:
242-
raise RuntimeError("Invalid call: it can have only one unnamed argument, " +
243-
"please pass named arguments for batch, name, etc.")
244-
245-
def _create_subplots():
246-
if figsize is not None:
247-
fig, ax = figure.subplots(figsize=figsize)
248-
else:
249-
fig, ax = figure.subplots()
250-
251-
if tight_layout:
252-
fig.subplots_adjust(0, 0, 1, 1)
253-
return fig, ax
254-
255-
# (1) instance method of Axes -- ax.xyz()
256-
def _fig_axesplot_method(*args, **kwargs_call):
257-
fig, ax = _create_subplots()
258-
axesplot_func.__get__(ax)(*args, **merge_kwargs(kwargs, kwargs_call))
259-
return fig
260-
261-
# (2) xyz(ax=...) style
262-
def _fig_axesplot_fn(*args, **kwargs_call):
263-
fig, ax = _create_subplots()
264-
axesplot_func(*args, ax=ax, **merge_kwargs(kwargs, kwargs_call))
265-
return fig
266-
267-
method_class = util.get_class_defining_method(axesplot_func)
268-
if method_class is not None and issubclass(method_class, Axes):
269-
# (1) Axes.xyz()
270-
if hasattr(axesplot_func, '__self__') and axesplot_func.__self__:
271-
raise ValueError("axesplot_func should be a unbound method of " +
272-
"Axes or AxesSubplot, but given a bound method " +
273-
str(axesplot_func))
274-
fig_axesplot_func = _fig_axesplot_method
275-
else:
276-
# (2) xyz(ax=...)
277-
if 'ax' not in util.getargspec(axesplot_func).args:
278-
raise TypeError("axesplot_func must take 'ax' parameter to specify Axes")
279-
fig_axesplot_func = _fig_axesplot_fn
280-
281-
if name is None:
282-
name = _clean_name(axesplot_func.__name__)
283-
284-
def _wrapped_factory_fn(*args, **kwargs_call):
285-
_plot = plot_many if batch else plot
286-
_name = kwargs_call.pop('name', name)
287-
return _plot(fig_axesplot_func, list(args), name=_name,
288-
**kwargs_call)
289-
290-
_wrapped_factory_fn.__name__ = 'wrapped_axesplot_fn[%s]' % axesplot_func
291-
return _wrapped_factory_fn
292-
293-
294-
def _clean_name(s):
295-
"""
296-
Convert a string to a valid variable, function, or scope name.
297-
"""
298-
return re.sub('[^0-9a-zA-Z_]', '', s)
299-
300-
301-
302-
303-
304138
__all__ = (
305139
'plot',
306140
'plot_many',
307-
'wrap',
308-
'wrap_axesplot',
309141
)

tfplot/ops_test.py

Lines changed: 0 additions & 92 deletions
This file was deleted.

tfplot/util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,17 @@ def get_class_defining_method(m):
3838
# getargspec(fn)
3939
if six.PY2:
4040
getargspec = inspect.getargspec
41-
else:
41+
42+
def getargspec_allargs(func):
43+
argspec = getargspec(func)
44+
return argspec.args
45+
46+
else: # Python 3
4247
getargspec = inspect.getfullargspec
4348

49+
def getargspec_allargs(func):
50+
argspec = getargspec(func)
51+
return argspec.args + argspec.kwonlyargs
4452

4553

4654
def merge_kwargs(kwargs, kwargs_new):

0 commit comments

Comments
 (0)