Skip to content

Introduce decorator: wrap and autowrap #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ install:
- pip install -r requirements.txt

script:
pytest
- PYTEST_ADDOPTS="-s" python setup.py test
- pip install -e .
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
six>=1.10.0
numpy>=1.10.4
matplotlib>=2.0.0
tensorflow>=1.0.0
-e .

# test/dev environment
tensorflow>=1.4.0

# vim: set ft=conf:
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,17 @@ def run(self):
install_requires=[
'six',
'numpy',
'biwrap==0.1.6',
'matplotlib>=2.0.0',
],
setup_requires=[
'pytest-runner',
],
tests_require=[
'pytest',
'pytest-pudb',
'imgcat',
'termcolor',
],
cmdclass={
'deploy': DeployCommand,
Expand Down
5 changes: 4 additions & 1 deletion tfplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import matplotlib
matplotlib.use('Agg')

from .ops import plot, plot_many, wrap, wrap_axesplot
from .ops import plot, plot_many
from .wrapper import wrap, wrap_axesplot
from .wrapper import autowrap

from .figure import subplots
from . import summary

Expand Down
168 changes: 0 additions & 168 deletions tfplot/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,175 +135,7 @@ def plot_many(plot_func, in_tensors, name='PlotMany',



def wrap(plot_func, _sentinel=None,
batch=False, name=None,
**kwargs):
'''
Wrap a plot function as a TensorFlow operation. It will return a python
function that creates a TensorFlow plot operation applying the arguments
as input.

For example, if ``plot_func`` is a python function that takes two
arrays as input, and draw a plot by returning a matplotlib Figure,
we can wrap this function as a `Tensor` factory, such as:

>>> tf_plot = tfplot.wrap(plot_func, name="MyPlot", batch=True)
>>> # x, y = get_batch_inputs(batch_size=4, ...)

>>> plot_x = tf_plot(x)
Tensor("MyPlot:0", shape=(4, ?, ?, 4), dtype=uint8)
>>> plot_y = tf_plot(y)
Tensor("MyPlot_1:0", shape=(4, ?, ?, 4), dtype=uint8)

Args:
plot_func: A python function or callable to wrap. See the documentation
of :func:`tfplot.plot()` for details.
batch: If True, all the tensors passed as argument will be
assumed to be batched. Default value is False.
name: A default name for the operation (optional). If not given, the
name of ``plot_func`` will be used.
kwargs: An optional kwargs that will be passed by default to
``plot_func``.

Returns:
A python function that will create a TensorFlow plot operation,
passing the provided arguments.
'''

if not hasattr(plot_func, '__call__'):
raise TypeError("plot_func should be callable")
if _sentinel is not None:
raise RuntimeError("Invalid call: it can have only one unnamed argument, " +
"please pass named arguments for batch, name, etc.")

if name is None:
name = _clean_name(plot_func.__name__)

def _wrapped_fn(*args, **kwargs_call):
_plot = plot_many if batch else plot
_name = kwargs_call.pop('name', name)
return _plot(plot_func, list(args), name=_name,
**merge_kwargs(kwargs, kwargs_call))

_wrapped_fn.__name__ = 'wrapped_fn[%s]' % plot_func
return _wrapped_fn


def wrap_axesplot(axesplot_func, _sentinel=None,
batch=False, name=None,
figsize=None, tight_layout=False, **kwargs):
'''
Wrap an axesplot function as a TensorFlow operation. It will return a
python function that creates a TensorFlow plot operation applying the
arguments as input.

An axesplot function ``axesplot_func`` can be either:

- an unbounded method of matplotlib `Axes` (or `AxesSubplot`) class,
such as ``Axes.scatter()`` and ``Axes.text()``, etc, or
- a simple python function that takes the named argument ``ax``,
of type `Axes` or `AxesSubplot`, on which the plot will be drawn.
Some good examples of this family includes ``seaborn.heatmap(ax=...)``.

The resulting function can be used as a Tensor factory. When the created
tensorflow plot op is being executed, a new matplotlib figure which
consists of a single `AxesSubplot` will be created, and the axes plot
will be used as an argument for ``axesplot_func``. For example,

>>> import seaborn.apionly as sns
>>> tf_heatmap = tfplot.wrap_axesplot(sns.heatmap, name="HeatmapPlot", figsize=(4, 4), cmap='jet')

>>> plot_op = tf_heatmap(attention_map, cmap)
Tensor(HeatmapPlot:0", shape=(?, ?, 4), dtype=uint8)

Args:
axesplot_func: An unbounded method of matplotlib `Axes` or `AxesSubplot`,
or a python function or callable which has the `ax` parameter for
specifying the axis to draw on.
batch: If True, all the tensors passed as argument will be
assumed to be batched. Default value is False.
name: A default name for the operation (optional). If not given, the
name of ``axesplot_func`` will be used.
figsize: The figure size for the figure to be created.
tight_layout: If True, the resulting figure will have no margins for
axis. Equivalent to calling ``fig.subplots_adjust(0, 0, 1, 1)``.
kwargs: An optional kwargs that will be passed by default to
``axesplot_func``.

Returns:
A python function that will create a TensorFlow plot operation,
passing the provied arguments and a new instance of `AxesSubplot` into
``axesplot_func``.
'''

if not hasattr(axesplot_func, '__call__'):
raise TypeError("axesplot_func should be callable")
if _sentinel is not None:
raise RuntimeError("Invalid call: it can have only one unnamed argument, " +
"please pass named arguments for batch, name, etc.")

def _create_subplots():
if figsize is not None:
fig, ax = figure.subplots(figsize=figsize)
else:
fig, ax = figure.subplots()

if tight_layout:
fig.subplots_adjust(0, 0, 1, 1)
return fig, ax

# (1) instance method of Axes -- ax.xyz()
def _fig_axesplot_method(*args, **kwargs_call):
fig, ax = _create_subplots()
axesplot_func.__get__(ax)(*args, **merge_kwargs(kwargs, kwargs_call))
return fig

# (2) xyz(ax=...) style
def _fig_axesplot_fn(*args, **kwargs_call):
fig, ax = _create_subplots()
axesplot_func(*args, ax=ax, **merge_kwargs(kwargs, kwargs_call))
return fig

method_class = util.get_class_defining_method(axesplot_func)
if method_class is not None and issubclass(method_class, Axes):
# (1) Axes.xyz()
if hasattr(axesplot_func, '__self__') and axesplot_func.__self__:
raise ValueError("axesplot_func should be a unbound method of " +
"Axes or AxesSubplot, but given a bound method " +
str(axesplot_func))
fig_axesplot_func = _fig_axesplot_method
else:
# (2) xyz(ax=...)
if 'ax' not in util.getargspec(axesplot_func).args:
raise TypeError("axesplot_func must take 'ax' parameter to specify Axes")
fig_axesplot_func = _fig_axesplot_fn

if name is None:
name = _clean_name(axesplot_func.__name__)

def _wrapped_factory_fn(*args, **kwargs_call):
_plot = plot_many if batch else plot
_name = kwargs_call.pop('name', name)
return _plot(fig_axesplot_func, list(args), name=_name,
**kwargs_call)

_wrapped_factory_fn.__name__ = 'wrapped_axesplot_fn[%s]' % axesplot_func
return _wrapped_factory_fn


def _clean_name(s):
"""
Convert a string to a valid variable, function, or scope name.
"""
return re.sub('[^0-9a-zA-Z_]', '', s)





__all__ = (
'plot',
'plot_many',
'wrap',
'wrap_axesplot',
)
92 changes: 0 additions & 92 deletions tfplot/ops_test.py

This file was deleted.

10 changes: 9 additions & 1 deletion tfplot/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@ def get_class_defining_method(m):
# getargspec(fn)
if six.PY2:
getargspec = inspect.getargspec
else:

def getargspec_allargs(func):
argspec = getargspec(func)
return argspec.args

else: # Python 3
getargspec = inspect.getfullargspec

def getargspec_allargs(func):
argspec = getargspec(func)
return argspec.args + argspec.kwonlyargs


def merge_kwargs(kwargs, kwargs_new):
Expand Down
Loading