-
Notifications
You must be signed in to change notification settings - Fork 9.7k
Add Differentiable Physics: Mass-Spring System example #1332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
7d3d2a1
39a0c8e
3d01b48
77c8abc
84072d5
8a7cf5e
f20fa20
e14251f
96e04f1
35b0afa
f1a806e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,154 @@ | ||||
import torch | ||||
import torch.nn as nn | ||||
import torch.optim as optim | ||||
import argparse | ||||
import matplotlib.pyplot as plt | ||||
import os | ||||
|
||||
|
||||
class MassSpringSystem(nn.Module): | ||||
def __init__(self, num_particles, springs, mass=1.0, dt=0.01, gravity=9.81, device="cpu"): | ||||
super().__init__() | ||||
self.device = device | ||||
self.mass = mass | ||||
self.springs = springs | ||||
self.dt = dt | ||||
self.gravity = gravity | ||||
|
||||
# Particle 0 is fixed at the origin | ||||
self.initial_position_0 = torch.tensor([0.0, 0.0], device=device) | ||||
|
||||
# Remaining particles are trainable | ||||
self.initial_positions_rest = nn.Parameter(torch.randn(num_particles - 1, 2, device=device)) | ||||
|
||||
# Velocities | ||||
self.velocities = torch.zeros(num_particles, 2, device=device) | ||||
|
||||
def forward(self, steps): | ||||
positions = torch.cat([self.initial_position_0.unsqueeze(0), self.initial_positions_rest], dim=0) | ||||
velocities = self.velocities | ||||
|
||||
for _ in range(steps): | ||||
forces = torch.zeros_like(positions) | ||||
|
||||
# Compute spring forces | ||||
for (i, j, rest_length, stiffness) in self.springs: | ||||
xi, xj = positions[i], positions[j] | ||||
dir_vec = xj - xi | ||||
dist = dir_vec.norm() | ||||
force = stiffness * (dist - rest_length) * dir_vec / (dist + 1e-6) | ||||
forces[i] += force | ||||
forces[j] -= force | ||||
|
||||
# Apply gravity | ||||
forces[:, 1] -= self.gravity * self.mass | ||||
|
||||
# Semi-implicit Euler integration | ||||
acceleration = forces / self.mass | ||||
velocities = velocities + acceleration * self.dt | ||||
positions = positions + velocities * self.dt | ||||
|
||||
# Fix particle 0 at origin | ||||
positions[0] = self.initial_position_0 | ||||
velocities[0] = torch.tensor([0.0, 0.0], device=positions.device) | ||||
|
||||
return positions | ||||
|
||||
|
||||
def visualize_positions(initial, final, target, save_path="mass_spring_viz.png"): | ||||
plt.figure(figsize=(6, 4)) | ||||
plt.scatter(initial[:, 0], initial[:, 1], c='blue', label='Initial', marker='x') | ||||
plt.scatter(final[:, 0], final[:, 1], c='green', label='Final', marker='o') | ||||
plt.scatter(target[:, 0], target[:, 1], c='red', label='Target', marker='*') | ||||
plt.title("Mass-Spring System Positions") | ||||
plt.xlabel("X") | ||||
plt.ylabel("Y") | ||||
plt.legend() | ||||
plt.grid(True) | ||||
plt.tight_layout() | ||||
plt.savefig(save_path) | ||||
print(f"Saved visualization to {os.path.abspath(save_path)}") | ||||
plt.close() | ||||
|
||||
|
||||
def train(args): | ||||
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu") | ||||
print(f"Using device: {device}") | ||||
system = MassSpringSystem( | ||||
num_particles=args.num_particles, | ||||
springs=[(0, 1, 1.0, args.stiffness)], | ||||
mass=args.mass, | ||||
dt=args.dt, | ||||
gravity=args.gravity, | ||||
device=device, | ||||
) | ||||
|
||||
optimizer = optim.Adam(system.parameters(), lr=args.lr) | ||||
target_positions = torch.tensor( | ||||
[[0.0, 0.0], [1.0, 0.0]], device=device | ||||
) | ||||
|
||||
for epoch in range(args.epochs): | ||||
optimizer.zero_grad() | ||||
final_positions = system(args.steps) | ||||
loss = (final_positions - target_positions).pow(2).mean() | ||||
loss.backward() | ||||
optimizer.step() | ||||
|
||||
if (epoch + 1) % args.log_interval == 0: | ||||
print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.6f}") | ||||
|
||||
# Visualization | ||||
initial_positions = torch.cat([system.initial_position_0.unsqueeze(0), system.initial_positions_rest.detach()], dim=0).cpu().numpy() | ||||
visualize_positions(initial_positions, final_positions.detach().cpu().numpy(), target_positions.cpu().numpy()) | ||||
|
||||
print("\nTraining completed.") | ||||
print(f"Final positions:\n{final_positions.detach().cpu().numpy()}") | ||||
print(f"Target positions:\n{target_positions.cpu().numpy()}") | ||||
|
||||
|
||||
def evaluate(args): | ||||
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove:
Suggested change
|
||||
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu") | ||||
print(f"Using device: {device}") | ||||
system = MassSpringSystem( | ||||
num_particles=args.num_particles, | ||||
springs=[(0, 1, 1.0, args.stiffness)], | ||||
mass=args.mass, | ||||
dt=args.dt, | ||||
gravity=args.gravity, | ||||
device=device, | ||||
) | ||||
|
||||
with torch.no_grad(): | ||||
final_positions = system(args.steps) | ||||
print(f"Final positions after {args.steps} steps:\n{final_positions.cpu().numpy()}") | ||||
|
||||
|
||||
def parse_args(): | ||||
parser = argparse.ArgumentParser(description="Differentiable Physics: Mass-Spring System") | ||||
parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs") | ||||
parser.add_argument("--steps", type=int, default=50, help="Number of simulation steps per forward pass") | ||||
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate") | ||||
parser.add_argument("--dt", type=float, default=0.01, help="Time step for integration") | ||||
parser.add_argument("--mass", type=float, default=1.0, help="Mass of each particle") | ||||
parser.add_argument("--stiffness", type=float, default=10.0, help="Spring stiffness constant") | ||||
parser.add_argument("--num_particles", type=int, default=2, help="Number of particles in the system") | ||||
parser.add_argument("--mode", choices=["train", "eval"], default="train", help="Mode: train or eval") | ||||
parser.add_argument("--log_interval", type=int, default=100, help="Print loss every n epochs") | ||||
parser.add_argument("--gravity", type=float, default=9.81, help="Gravity strength") | ||||
return parser.parse_args() | ||||
|
||||
|
||||
def main(): | ||||
args = parse_args() | ||||
|
||||
if args.mode == "train": | ||||
train(args) | ||||
elif args.mode == "eval": | ||||
evaluate(args) | ||||
|
||||
|
||||
if __name__ == "__main__": | ||||
main() |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,43 @@ | ||||||
# Differentiable Physics: Mass-Spring System | ||||||
|
||||||
This example demonstrates a simple differentiable mass-spring system using PyTorch. | ||||||
|
||||||
Particles are connected by springs and evolve under the forces exerted by the springs and gravity. | ||||||
The system is fully differentiable, allowing the optimization of particle positions to match a target configuration using gradient-based learning. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I afraid I don't understand the task you are trying to solve here. Can it please, be thoroughly described and link to the associated paper provided? Looking into provided image:
|
||||||
|
||||||
--- | ||||||
|
||||||
## Files | ||||||
|
||||||
- `mass_spring.py` — Implements the mass-spring simulation, training loop, and evaluation. | ||||||
- `README.md` — Usage instructions and description. | ||||||
|
||||||
|
||||||
--- | ||||||
|
||||||
## Requirements | ||||||
|
||||||
- Python 3.8+ | ||||||
- PyTorch | ||||||
- pip install -r requirements.txt | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. matplotlib is missing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add
Suggested change
to be consistent with other examples |
||||||
No external dependencies are required apart from PyTorch. | ||||||
|
||||||
--- | ||||||
|
||||||
## Usage | ||||||
|
||||||
First, ensure PyTorch is installed. | ||||||
|
||||||
#### Train the system | ||||||
|
||||||
```bash | ||||||
python mass_spring.py --mode train | ||||||
Comment on lines
+34
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code block is not properly closed. |
||||||
|
||||||
|
||||||
##### Visualization | ||||||
|
||||||
After training, the system's final positions are compared to the target positions. The plot below illustrates this comparison: | ||||||
|
||||||
 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not rendered. Maybe because above code block was not closed. |
||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,3 @@ | ||||||
torch | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @AbhiLegend,
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Due to usage of
Suggested change
|
||||||
matplotlib | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -1,40 +1,27 @@ | ||||||||||
#!/bin/bash | ||||||||||
# | ||||||||||
# This script runs through the code in each of the python examples. | ||||||||||
# The purpose is just as an integration test, not to actually train models in any meaningful way. | ||||||||||
# The purpose is just as an integration test, not to actually train models in any meaningful way. | ||||||||||
# For that reason, most of these set epochs = 1 and --dry-run. | ||||||||||
# | ||||||||||
# Optionally specify a comma separated list of examples to run. Can be run as: | ||||||||||
# * To run all examples: | ||||||||||
# To run all examples: | ||||||||||
# ./run_python_examples.sh | ||||||||||
# * To run few specific examples: | ||||||||||
# ./run_python_examples.sh "dcgan,fast_neural_style" | ||||||||||
# | ||||||||||
# To test examples on CUDA accelerator, run as: | ||||||||||
# USE_CUDA=True ./run_python_examples.sh | ||||||||||
# To run specific examples: | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these changes be better moved to separate PR. These are unrelated to the example being added and some are arguable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made it to as it was. |
||||||||||
# ./run_python_examples.sh "dcgan,fast_neural_style" | ||||||||||
# | ||||||||||
# To test examples on hardware accelerator (CUDA, MPS, XPU, etc.), run as: | ||||||||||
# USE_ACCEL=True ./run_python_examples.sh | ||||||||||
# NOTE: USE_ACCEL relies on torch.accelerator API and not all examples are converted | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why a note was dropped? it highlights the planned work which is not yet done and I believe is helpful for maintenance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed |
||||||||||
# to use it at the moment. Thus, expect failures using this flag on non-CUDA accelerators | ||||||||||
# and consider to run examples one by one. | ||||||||||
# USE_CUDA=True ./run_python_examples.sh → for CUDA | ||||||||||
# USE_ACCEL=True ./run_python_examples.sh → for any accelerator (CUDA/MPS/XPU) | ||||||||||
# | ||||||||||
# Script requires uv to be installed. When executed, script will install prerequisites from | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This highlights the convention that each example is expected to have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. restored to original |
||||||||||
# `requirements.txt` for each example. If ran within activated virtual environment (uv venv, | ||||||||||
# python -m venv, conda) this might reinstall some of the packages. To change pip installation | ||||||||||
# index or to pass additional pip install options, run as: | ||||||||||
# PIP_INSTALL_ARGS="--pre -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" \ | ||||||||||
# ./run_python_examples.sh | ||||||||||
# To use a custom pip install source: | ||||||||||
# PIP_INSTALL_ARGS="--pre -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" ./run_python_examples.sh | ||||||||||
# | ||||||||||
# To force script to create virtual environment for each example, run as: | ||||||||||
# To force venv per example: | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed |
||||||||||
# VIRTUAL_ENV=".venv" ./run_python_examples.sh | ||||||||||
# Script will remove environments it creates in a teardown step after execution of each example. | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, I don't see why this needs to be dropped. That's explicit clarification of the behavior. It's helpful. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rectified. |
||||||||||
|
||||||||||
BASE_DIR="$(pwd)/$(dirname $0)" | ||||||||||
source $BASE_DIR/utils.sh | ||||||||||
|
||||||||||
# TODO: Leave only USE_ACCEL and drop USE_CUDA once all examples will be converted | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This work is not completed. Why drop the TODO? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made the changes. |
||||||||||
# to torch.accelerator API. For now, just add USE_ACCEL as an alias for USE_CUDA. | ||||||||||
if [ -n "$USE_ACCEL" ]; then | ||||||||||
USE_CUDA=$USE_ACCEL | ||||||||||
fi | ||||||||||
|
@@ -53,7 +40,7 @@ case $USE_CUDA in | |||||||||
ACCEL_FLAG="" | ||||||||||
;; | ||||||||||
"") | ||||||||||
exit 1; | ||||||||||
exit 1 | ||||||||||
;; | ||||||||||
esac | ||||||||||
|
||||||||||
|
@@ -67,7 +54,6 @@ function fast_neural_style() { | |||||||||
uv run download_saved_models.py | ||||||||||
fi | ||||||||||
test -d "saved_models" || { error "saved models not found"; return; } | ||||||||||
|
||||||||||
echo "running fast neural style model" | ||||||||||
uv run neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg $ACCEL_FLAG || error "neural_style.py failed" | ||||||||||
} | ||||||||||
|
@@ -92,10 +78,11 @@ function language_translation() { | |||||||||
function mnist() { | ||||||||||
uv run main.py --epochs 1 --dry-run || error "mnist example failed" | ||||||||||
} | ||||||||||
|
||||||||||
function mnist_forward_forward() { | ||||||||||
uv run main.py --epochs 1 --no_accel || error "mnist forward forward failed" | ||||||||||
|
||||||||||
} | ||||||||||
|
||||||||||
function mnist_hogwild() { | ||||||||||
uv run main.py --epochs 1 --dry-run $CUDA_FLAG || error "mnist hogwild failed" | ||||||||||
} | ||||||||||
|
@@ -119,13 +106,12 @@ function reinforcement_learning() { | |||||||||
|
||||||||||
function snli() { | ||||||||||
echo "installing 'en' model if not installed" | ||||||||||
uv run -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; } | ||||||||||
uv run -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; } | ||||||||||
echo "training..." | ||||||||||
uv run train.py --epochs 1 --dev_every 1 --no-bidirectional --dry-run || error "couldn't train snli" | ||||||||||
} | ||||||||||
|
||||||||||
function fx() { | ||||||||||
# uv run custom_tracer.py || error "fx custom tracer has failed" UnboundLocalError: local variable 'tabulate' referenced before assignment | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cleanup, here and in few places below. But someone needs to clarify whether this can be dropped or not. In any case - better not to mix up such changes with the new example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated. |
||||||||||
uv run invert.py || error "fx invert has failed" | ||||||||||
uv run module_tracer.py || error "fx module tracer has failed" | ||||||||||
uv run primitive_library.py || error "fx primitive library has failed" | ||||||||||
|
@@ -140,7 +126,7 @@ function super_resolution() { | |||||||||
} | ||||||||||
|
||||||||||
function time_sequence_prediction() { | ||||||||||
uv run generate_sine_wave.py || { error "generate sine wave failed"; return; } | ||||||||||
uv run generate_sine_wave.py || { error "generate sine wave failed"; return; } | ||||||||||
uv run train.py --steps 2 || error "time sequence prediction training failed" | ||||||||||
} | ||||||||||
|
||||||||||
|
@@ -164,6 +150,12 @@ function gat() { | |||||||||
uv run main.py --epochs 1 --dry-run || error "graph attention network failed" | ||||||||||
} | ||||||||||
|
||||||||||
function differentiable_physics() { | ||||||||||
pushd differentiable_physics | ||||||||||
python -m uv run mass_spring.py --mode train --epochs 5 --steps 3 || error "differentiable_physics example failed" | ||||||||||
popd | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the reason to use different pattern for tests just for this example? One potential problem -
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rectified it. |
||||||||||
} | ||||||||||
|
||||||||||
eval "base_$(declare -f stop)" | ||||||||||
|
||||||||||
function stop() { | ||||||||||
|
@@ -196,12 +188,9 @@ function stop() { | |||||||||
} | ||||||||||
|
||||||||||
function run_all() { | ||||||||||
# cpp moved to `run_cpp_examples.sh``` | ||||||||||
run dcgan | ||||||||||
# distributed moved to `run_distributed_examples.sh` | ||||||||||
run fast_neural_style | ||||||||||
run imagenet | ||||||||||
# language_translation | ||||||||||
run mnist | ||||||||||
run mnist_forward_forward | ||||||||||
run mnist_hogwild | ||||||||||
|
@@ -212,14 +201,13 @@ function run_all() { | |||||||||
run super_resolution | ||||||||||
run time_sequence_prediction | ||||||||||
run vae | ||||||||||
# vision_transformer - example broken see https://github.com/pytorch/examples/issues/1184 and https://github.com/pytorch/examples/pull/1258 for more details | ||||||||||
run word_language_model | ||||||||||
run fx | ||||||||||
run gcn | ||||||||||
run gat | ||||||||||
run differentiable_physics | ||||||||||
} | ||||||||||
|
||||||||||
# by default, run all examples | ||||||||||
if [ "" == "$EXAMPLES" ]; then | ||||||||||
run_all | ||||||||||
else | ||||||||||
|
@@ -236,7 +224,5 @@ if [ "" == "$ERRORS" ]; then | |||||||||
else | ||||||||||
echo "Some python examples failed:" | ||||||||||
printf "$ERRORS\n" | ||||||||||
#Exit with error (0-255) in case of failure in one of the tests. | ||||||||||
exit 1 | ||||||||||
|
||||||||||
fi |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove commented code as it's replaced by
torch.accelerator
: