I constantly find myself having to print raw tensors (either by hand or by dumping them to stdout), especially when reading PyTorch/JAX code, just to understand the transformations, e.g. for something like:
`x = torch.randn(32, 3, 224, 224).unfold(2, 16, 16).unfold(3, 16, 16).reshape(32, 3, 196, 256).transpose(1, 2).reshape(32, 196, 768).view(32, 196, 12, 64).transpose(1, 2)`
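One lightweight way to follow a chain like this without printing any values is to track only the shape through each call. Below is a minimal sketch, assuming just the four ops used above; `ShapeTrace` is a hypothetical helper (not a PyTorch API) that mimics PyTorch's shape rules and records every step.

```python
from math import prod


class ShapeTrace:
    """Hypothetical helper (not a PyTorch API): follows only the *shape* of a
    tensor through unfold / reshape / view / transpose, using PyTorch's shape
    rules, and records every step. No data is allocated."""

    def __init__(self, shape, steps=None):
        self.shape = tuple(shape)
        # All traces derived from one input share this list, so a linear
        # chain ends up holding the full history of shapes.
        self.steps = steps if steps is not None else [("input", self.shape)]

    def _step(self, label, new_shape):
        self.steps.append((label, tuple(new_shape)))
        return ShapeTrace(new_shape, self.steps)

    def unfold(self, dim, size, step):
        # Tensor.unfold: `dim` becomes the number of windows and a new
        # trailing dimension of length `size` is appended.
        s = list(self.shape)
        s[dim] = (s[dim] - size) // step + 1
        return self._step(f"unfold({dim}, {size}, {step})", s + [size])

    def reshape(self, *shape):
        if prod(shape) != prod(self.shape):
            raise ValueError(f"cannot reshape {self.shape} into {shape}")
        return self._step(f"reshape{shape}", shape)

    def view(self, *shape):
        # Same element-count rule as reshape; a real view() also needs
        # compatible strides, which a shape-only sketch cannot check.
        if prod(shape) != prod(self.shape):
            raise ValueError(f"cannot view {self.shape} as {shape}")
        return self._step(f"view{shape}", shape)

    def transpose(self, d0, d1):
        s = list(self.shape)
        s[d0], s[d1] = s[d1], s[d0]
        return self._step(f"transpose({d0}, {d1})", s)


x = (ShapeTrace((32, 3, 224, 224))
     .unfold(2, 16, 16).unfold(3, 16, 16)
     .reshape(32, 3, 196, 256).transpose(1, 2)
     .reshape(32, 196, 768).view(32, 196, 12, 64).transpose(1, 2))

for label, shape in x.steps:
    print(f"{label:<26} {shape}")
```

For the chain above this prints (32, 3, 224, 224) → (32, 3, 14, 224, 16) → (32, 3, 14, 14, 16, 16) → (32, 3, 196, 256) → (32, 196, 3, 256) → (32, 196, 768) → (32, 196, 12, 64) → (32, 12, 196, 64), i.e. 14×14 = 196 patches of 16×16 pixels, flattened to 3 × 256 = 768 features and then split into 12 groups of 64.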
How do folks visualize tensors to quickly understand the data flow in complex NNs?
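Shapes alone don't say which input element lands where. Another option is to derive that index map once by undoing the ops in reverse order. A minimal sketch, assuming the exact chain and sizes from the post; `source_index` is a hypothetical helper, not a library function.

```python
def source_index(b, h, n, d, grid=14, patch=16, head_dim=64):
    """Hypothetical helper: for the chain in the post, return the index
    (b, c, y, x) in the original (32, 3, 224, 224) tensor whose value ends
    up at out[b, h, n, d] of the final (32, 12, 196, 64) tensor."""
    # Undo transpose(1, 2) + view(32, 196, 12, 64): feature index k = h*64 + d
    k = h * head_dim + d
    # Undo transpose(1, 2) + reshape(32, 196, 768): k = c*256 + (i*16 + j)
    c, k_in_patch = divmod(k, patch * patch)
    # Undo reshape(32, 3, 196, 256): n = ph*14 + pw, k_in_patch = i*16 + j
    ph, pw = divmod(n, grid)
    i, j = divmod(k_in_patch, patch)
    # Undo the two unfolds: patch (ph, pw), offset (i, j) inside the patch
    return b, c, ph * patch + i, pw * patch + j


print(source_index(0, 0, 0, 0))    # -> (0, 0, 0, 0)
print(source_index(0, 0, 0, 63))   # -> (0, 0, 3, 15)
print(source_index(0, 4, 0, 0))    # -> (0, 1, 0, 0)
print(source_index(0, 0, 1, 0))    # -> (0, 0, 0, 16)
```

Reading off a few indices shows something the shapes hide: because there is no projection between the reshape and the final split, each of the 12 groups of 64 values is a 4-row by 16-column strip of a single color channel inside one patch. If PyTorch is at hand, the map can be spot-checked by feeding `torch.arange(32 * 3 * 224 * 224).reshape(32, 3, 224, 224)` through the chain instead of `torch.randn`, so that every output value equals its own flat source index.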
Comments URL: https://news.ycombinator.com/item?id=45564494
Source: news.ycombinator.com