ENH: Native support for dims in tensors and tensor operations #954


Open
williambdean opened this issue Jul 23, 2024 · 7 comments

@williambdean
Contributor

williambdean commented Jul 23, 2024

Before

import pytensor.tensor as pt

# Need to specify shapes up front
a = pt.vector("a", shape=(2, ))
b = pt.vector("b", shape=(3, ))

# a + b fails: shapes (2,) and (3,) don't broadcast
# A new axis must be added manually to align them
result = a + b[:, None]

After

a = pt.vector("a", dims="channel")
b = pt.vector("b", dims="geo")

result = a + b
# result.type TensorType(float64, dims=("channel", "geo")) # xarray-like ordering of dims
# The + operation handles alignment based on dims; the same would hold for other elementwise operations
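
For reference, this mirrors xarray's broadcasting by dimension name (a minimal sketch using xarray directly, not the proposed PyTensor API):

import numpy as np
import xarray as xr

a = xr.DataArray(np.ones(2), dims="channel")
b = xr.DataArray(np.ones(3), dims="geo")

# xarray aligns on dim names, so no manual axis juggling is needed
result = a + b
print(result.dims)  # ('channel', 'geo')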

Context for the issue:

Use of the Prior class in PyMC-Marketing, and its potential usefulness elsewhere and in PyMC directly:

dist = Prior(
    "Normal", 
    # Variables are automatically transposed before passing to PyMC distributions
    mu=Prior("Normal", dims="geo"), 
    sigma=Prior("HalfNormal", dims="geo"), 
    dims=("geo", "channel"), 
)

References:
PyMC-Marketing auto-broadcasting handling: https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/prior.py#L131-L168
PyMC Discussion: pymc-devs/pymc#7416

@williambdean
Contributor Author

williambdean commented Jul 23, 2024

Not sure how eval would work in the case where shapes are not provided.

a = pt.vector("a", dims="channel")
b = pt.vector("b", dims="geo")
result = a + b
# raise since shapes are unknown?
result.eval({"a": np.array([1, 2, 3]),  "b": np.array([1, 2])})
# shape might be required for dims?
a = pt.vector("a", dims="channel", shape=(3, ))
b = pt.vector("b", dims="geo", shape=(2, ))

@ricardoV94
Member

> Not sure how eval would work in the case where shapes are not provided.
>
> a = pt.vector("a", dims="channel")
> b = pt.vector("b", dims="geo")
> result = a + b
> # raise since shapes are unknown?
> result.eval({"a": np.array([1, 2, 3]), "b": np.array([1, 2])})
> # shape might be required for dims?
> a = pt.vector("a", dims="channel", shape=(3, ))
> b = pt.vector("b", dims="geo", shape=(2, ))

It should be optional when it's enough to know the shapes at runtime. When doing addition, the shape isn't needed.
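
For comparison, plain tensor variables already behave this way: shapes can be left unspecified at graph-build time, and eval works once concrete values are supplied (a minimal sketch with the existing API):

import numpy as np
import pytensor.tensor as pt

# No static shapes: both vectors have shape (None,) at build time
a = pt.vector("a")
b = pt.vector("b")
result = a + b[:, None]

# Shapes are only needed at runtime, when concrete values are passed in
print(result.eval({"a": np.array([1.0, 2.0]), "b": np.array([1.0, 2.0, 3.0])}))
# [[2. 3.]
#  [3. 4.]
#  [4. 5.]]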

@ricardoV94
Member

There's a draft PR that started on this idea: #407

@ricardoV94
Member

Also, we probably want to use different types for dimmed and regular variables, since they have completely different semantics.

@ricardoV94
Member

ricardoV94 commented Jul 23, 2024

There's also the question of what the output should be: xarrays? If dims order is arbitrary, users don't know what they're getting, but building an xarray for the output (if not for intermediate operations, which our backends obviously don't support) could be costly. Unless numpy arrays can be wrapped in xarray DataArrays without copy costs.

Maybe a simpler object in between xarray and np arrays?
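
One hypothetical shape such an in-between object could take (DimmedArray is illustrative, not a proposed API): a thin wrapper pairing a numpy array with its dim names, without xarray's indexing and alignment machinery:

from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class DimmedArray:
    """Hypothetical minimal output type: values plus dim names, nothing more."""
    values: np.ndarray
    dims: tuple[str, ...]

    def __post_init__(self):
        assert self.values.ndim == len(self.dims)

    def transpose(self, *dims: str) -> "DimmedArray":
        # Reorder axes by dim name, like xarray's .transpose
        axes = tuple(self.dims.index(d) for d in dims)
        return DimmedArray(self.values.transpose(axes), dims)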

@williambdean
Contributor Author

> There's also the question of what the output should be: xarrays? If dims order is arbitrary, users don't know what they're getting, but building an xarray for the output (if not for intermediate operations, which our backends obviously don't support) could be costly. Unless numpy arrays can be wrapped in xarray DataArrays without copy costs.
>
> Maybe a simpler object in between xarray and np arrays?

I'd think that it wouldn't be xarray. That seems like a pretty large dependency to add.

I would think the dims would be ordered according to the order of operations:

# 407 syntax?
a = px.as_xtensor_variable("a", dims=("channel", ))
b = px.as_xtensor_variable("b", dims=("geo", ))
result1 = a + b # (channel, geo)
result2 = b + a # (geo, channel)

Couldn't this be constructed in a way where result1.owner is just Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0), with the expansions inserted by some logic based on the dims in the operation?
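
That graph can already be built by hand with the existing API, which suggests the lowering is mostly bookkeeping (a minimal sketch; the exact Op reprs may differ across versions):

import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")

# What a dim-aware Add could lower to: expand each operand to the full
# (channel, geo) space, then rely on ordinary broadcasting
result1 = pt.expand_dims(a, 1) + pt.expand_dims(b, 0)
print(result1.owner)  # roughly Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)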

@ricardoV94
Member

ricardoV94 commented Jul 23, 2024

> I would think the dims would be ordered according to the order of operations

In my draft PR I sorted dims alphabetically; I don't yet know what will work out better, tbh. We definitely don't want to compute both a + b and b + a in a final graph, since they are identical modulo transposition, but our regular backend should be able to merge those operations, so we may not need to worry. We definitely need at least predictable function outputs, even if everything in the middle can be done in whatever order we want.
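
A minimal sketch of the alphabetical rule (output_dims is a hypothetical helper, not code from the draft PR):

def output_dims(*operand_dims: tuple[str, ...]) -> tuple[str, ...]:
    """Union of all operand dims, sorted alphabetically so that
    a + b and b + a produce the same output dims."""
    return tuple(sorted({d for dims in operand_dims for d in dims}))

print(output_dims(("channel",), ("geo",)))  # ('channel', 'geo')
print(output_dims(("geo",), ("channel",)))  # ('channel', 'geo')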

Btw, order can have an impact on performance, as indicated by our incredibly slow sum along axis 0 xD: #935

But we definitely don't need to worry about that yet.
