Add "load pytorch tensor" section into the burn book #2316

Open
@med1844

This issue is based on discussion #2315 (cc @antimora).

To the best of my knowledge, this is how to load a PyTorch tensor in Burn:

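  1. In Python, wrap the tensor in a dict before saving it, e.g.

    import torch

    tensor = torch.randn(2, 3, 4)  # any 3-D tensor; its rank must match D on the Rust side
    # The dict key ("some_key") must match the field name of the Burn struct
    torch.save({"some_key": tensor}, "path/to/tensor.pt")
  2. In Rust, declare a module whose field name matches the dict key, then load its record with PyTorchFileRecorder:

    use burn::backend::NdArray; // requires burn's `ndarray` feature
    use burn::module::{Module, Param};
    use burn::record::{FullPrecisionSettings, Recorder};
    use burn::tensor::{backend::Backend, Tensor};
    use burn_import::pytorch::PyTorchFileRecorder; // from the `burn-import` crate

    // The derive generates `FloatTensorRecord<B, D>`, whose `some_key` field
    // receives the tensor stored under "some_key" in the .pt file.
    #[derive(Module, Debug)]
    struct FloatTensor<B: Backend, const D: usize> {
        some_key: Param<Tensor<B, D>>,
    }

    fn main() {
        type B = NdArray;
        let device = Default::default();
        // `load` comes from the `Recorder` trait, so it must be in scope.
        let record: FloatTensorRecord<B, 3> =
            PyTorchFileRecorder::<FullPrecisionSettings>::new()
                .load("path/to/tensor.pt".into(), &device)
                .expect("failed to load path/to/tensor.pt");
        // Unwrap the parameter to get a plain Tensor<B, 3>.
        let tensor = record.some_key.val();
        println!("{tensor}");
    }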

Labels: documentation (Improvements or additions to documentation), good first issue (Good for newcomers)