From 17298d58c2dc8e3ce14fede34ae802a7326848c4 Mon Sep 17 00:00:00 2001 From: Tiago Davi Date: Fri, 4 Feb 2022 18:41:09 -0300 Subject: [PATCH 1/4] improve variance and standard deviation --- nx/lib/nx.ex | 70 ++++++++++++++++++++++++++++++++++++++++----- nx/test/nx_test.exs | 33 +++++++++++++++++++++ 2 files changed, 96 insertions(+), 7 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 2b53358093..a0f5a144fc 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -5417,7 +5417,7 @@ defmodule Nx do @doc """ Returns the mean for the tensor. - If the `:axis` option is given, it aggregates over + If the `:axes` option is given, it aggregates over that dimension, effectively removing it. `axes: [0]` implies aggregating over the highest order dimension and so forth. If the axis is negative, then counts @@ -9278,21 +9278,53 @@ defmodule Nx do f32 1.6666666269302368 > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0]) + #Nx.Tensor< + f32[2] + [1.0, 1.0] + > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1]) + #Nx.Tensor< + f32[2] + [0.25, 0.25] + > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1) + #Nx.Tensor< + f32[2] + [2.0, 2.0] + > + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1) + #Nx.Tensor< + f32[2] + [0.5, 0.5] + > """ @doc type: :aggregation @spec variance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() def variance(tensor, opts \\ []) do - %T{shape: shape} = tensor = to_tensor(tensor) - - opts = keyword!(opts, ddof: 0) - total = size(shape) + %T{shape: shape, names: names} = tensor = to_tensor(tensor) + opts = keyword!(opts, [:axes, ddof: 0]) + axes = opts[:axes] ddof = Keyword.fetch!(opts, :ddof) - mean = mean(tensor) + opts = Keyword.delete(opts, :ddof) + + total = + if axes do + mean_den(shape, Nx.Shape.normalize_axes(shape, axes, names)) + else + size(shape) + end + + mean = mean(tensor, Keyword.put(opts, :keep_axes, true)) tensor |> subtract(mean) |> power(2) - |> sum() + |> sum(opts) |> divide(total - ddof) end @@ -9316,6 +9348,30 @@ defmodule Nx do f32 1.29099440574646 > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0]) + #Nx.Tensor< + f32[2] + [1.0, 1.0] + > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1]) + #Nx.Tensor< + f32[2] + [0.5, 0.5] + > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [0], ddof: 1) + #Nx.Tensor< + f32[2] + [1.4142135381698608, 1.4142135381698608] + > + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), axes: [1], ddof: 1) + #Nx.Tensor< + f32[2] + [0.7071067690849304, 0.7071067690849304] + > """ @doc type: :aggregation @spec standard_deviation(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 53879e2f68..12e5b7d0b1 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -1931,6 +1931,25 @@ defmodule NxTest do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.variance(t, ddof: 1) == Nx.tensor(3.5) end + + test "should use the optional axes on x" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) + + assert Nx.variance(t, axes: [:x]) == + Nx.tensor([1.5555557012557983, 4.222222328186035], names: [:y]) + end + + test "should use the optional axes on y" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) + assert Nx.variance(t, axes: [:y]) == Nx.tensor([0.25, 0.25, 0.25], names: [:x]) + end + + test "should use the optional axes and ddof" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) + + assert Nx.variance(t, axes: [0], ddof: 1) == + Nx.tensor([2.3333334922790527, 6.333333492279053], names: [:y]) + end end describe "standard_deviation/1" do @@ -1943,5 +1962,19 @@ defmodule NxTest do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.standard_deviation(t, ddof: 1) == Nx.tensor(1.8708287477493286) end + + test "should use the optional axes" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) + + assert Nx.standard_deviation(t, axes: [0]) == + Nx.tensor([1.247219204902649, 2.054804801940918]) + end + + test "should use the optional axes and ddof" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) + + assert Nx.standard_deviation(t, axes: [1], ddof: 1) == + Nx.tensor([0.7071067690849304, 0.7071067690849304, 0.7071067690849304]) + end end end From 06f0a45b04f5af5a64002cb379a3daf628286163 Mon Sep 17 00:00:00 2001 From: Tiago Davi Date: Mon, 7 Feb 2022 12:57:39 -0300 Subject: [PATCH 2/4] improve based on guidelines --- nx/lib/nx.ex | 26 +++++++++++++++++++++++--- nx/test/nx_test.exs | 10 ++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index a0f5a144fc..4d2f58cf5b 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -9302,15 +9302,25 @@ defmodule Nx do f32[2] [0.5, 0.5] > + + ### Keeping axes + + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], keep_axes: true) + #Nx.Tensor< + f32[2][1] + [ + [0.25], + [0.25] + ] + > """ @doc type: :aggregation @spec variance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() def variance(tensor, opts \\ []) do %T{shape: shape, names: names} = tensor = to_tensor(tensor) - opts = keyword!(opts, [:axes, ddof: 0]) + opts = keyword!(opts, [:axes, ddof: 0, keep_axes: false]) axes = opts[:axes] - ddof = Keyword.fetch!(opts, :ddof) - opts = Keyword.delete(opts, :ddof) + {ddof, opts} = Keyword.pop!(opts, :ddof) total = if axes do @@ -9372,6 +9382,16 @@ defmodule Nx do f32[2] [0.7071067690849304, 0.7071067690849304] > + + ### Keeping axes + + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), keep_axes: true) + #Nx.Tensor< + f32[1][1] + [ + [1.1180340051651] + ] + > """ @doc type: :aggregation @spec standard_deviation(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 12e5b7d0b1..41fcfbdd33 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -1950,6 +1950,11 @@ defmodule NxTest do assert Nx.variance(t, axes: [0], ddof: 1) == Nx.tensor([2.3333334922790527, 6.333333492279053], names: [:y]) end + + test "should keep axes" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) + assert Nx.variance(t, keep_axes: true) == Nx.tensor([[2.9166667461395264]]) + end end describe "standard_deviation/1" do @@ -1976,5 +1981,10 @@ defmodule NxTest do assert Nx.standard_deviation(t, axes: [1], ddof: 1) == Nx.tensor([0.7071067690849304, 0.7071067690849304, 0.7071067690849304]) end + + test "should keep axes" do + t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) + assert Nx.standard_deviation(t, keep_axes: true) == Nx.tensor([[1.7078251838684082]]) + end end end From 0ca60fa733cab3e5d02697d8462e512a92ec244b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 7 Feb 2022 17:56:27 +0100 Subject: [PATCH 3/4] Apply suggestions from code review --- nx/lib/nx.ex | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 4d2f58cf5b..72ad7860b6 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -9303,16 +9303,16 @@ defmodule Nx do [0.5, 0.5] > - ### Keeping axes + ### Keeping axes - iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], keep_axes: true) - #Nx.Tensor< - f32[2][1] - [ - [0.25], - [0.25] - ] - > + iex> Nx.variance(Nx.tensor([[1, 2], [3, 4]]), axes: [1], keep_axes: true) + #Nx.Tensor< + f32[2][1] + [ + [0.25], + [0.25] + ] + > """ @doc type: :aggregation @spec variance(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() @@ -9383,15 +9383,15 @@ defmodule Nx do [0.7071067690849304, 0.7071067690849304] > - ### Keeping axes + ### Keeping axes - iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), keep_axes: true) - #Nx.Tensor< - f32[1][1] - [ - [1.1180340051651] - ] - > + iex> Nx.standard_deviation(Nx.tensor([[1, 2], [3, 4]]), keep_axes: true) + #Nx.Tensor< + f32[1][1] + [ + [1.1180340051651] + ] + > """ @doc type: :aggregation @spec standard_deviation(tensor :: Nx.Tensor.t(), opts :: Keyword.t()) :: Nx.Tensor.t() From af6b2d79a77a7259ec5ebb247f9ef89dad4dabce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 7 Feb 2022 17:57:50 +0100 Subject: [PATCH 4/4] Update nx_test.exs --- nx/test/nx_test.exs | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 41fcfbdd33..caf9de05f9 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -1922,67 +1922,51 @@ defmodule NxTest do end describe "variance/1" do - test "should calculate the variance of a tensor" do + test "calculates variance of a tensor" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.variance(t) == Nx.tensor(2.9166667461395264) end - test "should use the optional ddof" do + test "uses optional ddof" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.variance(t, ddof: 1) == Nx.tensor(3.5) end - test "should use the optional axes on x" do + test "uses optional axes" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) assert Nx.variance(t, axes: [:x]) == Nx.tensor([1.5555557012557983, 4.222222328186035], names: [:y]) - end - test "should use the optional axes on y" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) assert Nx.variance(t, axes: [:y]) == Nx.tensor([0.25, 0.25, 0.25], names: [:x]) end - test "should use the optional axes and ddof" do - t = Nx.tensor([[4, 5], [2, 3], [1, 0]], names: [:x, :y]) - - assert Nx.variance(t, axes: [0], ddof: 1) == - Nx.tensor([2.3333334922790527, 6.333333492279053], names: [:y]) - end - - test "should keep axes" do + test "uses optional keep axes" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.variance(t, keep_axes: true) == Nx.tensor([[2.9166667461395264]]) end end describe "standard_deviation/1" do - test "should calculate the standard deviation of a tensor" do + test "calculates the standard deviation of a tensor" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.standard_deviation(t) == Nx.tensor(1.707825127659933) end - test "should use the optional ddof" do + test "uses optional ddof" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.standard_deviation(t, ddof: 1) == Nx.tensor(1.8708287477493286) end - test "should use the optional axes" do + test "uses optional axes" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.standard_deviation(t, axes: [0]) == Nx.tensor([1.247219204902649, 2.054804801940918]) end - test "should use the optional axes and ddof" do - t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) - - assert Nx.standard_deviation(t, axes: [1], ddof: 1) == - Nx.tensor([0.7071067690849304, 0.7071067690849304, 0.7071067690849304]) - end - - test "should keep axes" do + test "uses optional keep axes" do t = Nx.tensor([[4, 5], [2, 3], [1, 0]]) assert Nx.standard_deviation(t, keep_axes: true) == Nx.tensor([[1.7078251838684082]]) end