-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
31 lines (25 loc) · 860 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def tensor_to_latex(tensor, *, fmt=".3f"):
    """
    Convert a 0D, 1D or 2D PyTorch tensor to a LaTeX bmatrix string.

    A 1D tensor is rendered as a column vector, and a 0D (scalar) tensor as a
    1x1 matrix. Every entry is formatted with ``format(value, fmt)``. Each row,
    including the last, is terminated with ``\\\\``.

    Example:
    ```python
    import torch
    from utils import tensor_to_latex
    from IPython.display import display, Math
    tensor = torch.tensor([1, 2, 3, 4, 5])
    latex_tensor = tensor_to_latex(tensor)
    display(Math(f"\\mathrm{{tensor}} = {latex_tensor}"))
    ```

    Args:
        tensor (torch.Tensor): A PyTorch tensor with at most 2 dimensions
        fmt (str): Format spec applied to each entry (default ".3f")

    Returns:
        str: A LaTeX bmatrix string

    Raises:
        ValueError: If ``tensor`` has more than 2 dimensions.
    """
    ndim = tensor.dim()
    if ndim > 2:
        raise ValueError(
            f"tensor_to_latex expects a tensor with at most 2 dimensions, got {ndim}"
        )
    if ndim == 0:
        # A 0-d tensor cannot be iterated over; show it as a 1x1 matrix.
        tensor = tensor.reshape(1, 1)
    elif ndim == 1:
        # Convert 1D tensor to a 2D column vector for display
        tensor = tensor.unsqueeze(1)
    # tolist() converts all elements to Python scalars in one call, instead of
    # one .item() call per element; the values produced are the same.
    rows = tensor.tolist()
    body = "".join(
        " & ".join(format(val, fmt) for val in row) + " \\\\\n" for row in rows
    )
    return "\\begin{bmatrix}\n" + body + "\\end{bmatrix}"