Device management and pmap #1148
Replies: 1 comment
-
To get some perspective, I suggest looking at what we already have:

- Elemwise uses (or can use) OpenMP in the C backend. The fusion rewrite is important here, and we don't currently fuse Blockwise.
- Scan could be implemented with `prange` in Numba.
- I'm not sure anything needs to be done for JAX.
- We used to have Ops to transfer data to and from the GPU (check Theano).

If we think about users like PyMC, note that we work with multiple graphs and automatically translate from forward graphs to logp graphs, and Ops like "send to GPU" don't necessarily make sense in the same places in both. That raises two questions: how much can we do automatically, so it doesn't require user input? And how do we avoid clobbering graph analysis with low-level device-management operations? Just some things to muse on.
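To make the Scan point concrete: when the scanned recurrences are independent across a batch dimension, the outer loop is embarrassingly parallel, which is exactly what `numba.prange` would exploit. The sketch below uses a stdlib thread pool as a stand-in for Numba's parallel loop; the helper names (`scan_row`, `parallel_scan`) are hypothetical, not existing API.

```python
# Sketch: a batched cumulative-sum "scan" where each row is independent,
# so the outer loop can run in parallel (the role numba.prange would play
# in a compiled backend). All names here are illustrative only.
from concurrent.futures import ThreadPoolExecutor
from itertools import accumulate

def scan_row(row):
    # The sequential inner recurrence: here, a plain cumulative sum.
    return list(accumulate(row))

def parallel_scan(batch):
    # Rows are independent, so the outer loop parallelizes safely;
    # pool.map preserves the input order of results.
    with ThreadPoolExecutor() as pool:
        return list(pool.map(scan_row, batch))

batch = [[1, 2, 3], [10, 20, 30]]
print(parallel_scan(batch))  # [[1, 3, 6], [10, 30, 60]]
```

The inner recurrence stays sequential (a scan is inherently ordered along its sequence axis); only the batch axis is parallelized.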
-
With the implementation of `Blockwise` and `pt.vectorize`, one of the remaining major features PyTensor doesn't have vis-à-vis the big tensor libraries is device management and pmap. I'm not suggesting we commit to developing and maintaining our own CUDA Ops (though it would be nice in an ideal world), but it would be nice to at least support CPU parallelism out of the box. How far away are we from something like that? I don't even have a sense of how to think about it.

Even if we couldn't do it ourselves at first, it would still be nice to be able to send information about which Ops should go to which devices when compiling to JAX or torch.
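One way to picture "sending device information to the backend": tag each node with a target device and let the compiler group nodes per device when lowering. Nothing like this exists in PyTensor today; `DeviceTaggedOp` and `plan_placement` below are invented names, purely a sketch of the idea.

```python
# Hypothetical sketch: annotate each node in a toy graph with a target
# device, then group nodes per device before handing them to a backend
# (JAX/torch). These names are invented for illustration, not PyTensor API.
from dataclasses import dataclass, field

@dataclass
class DeviceTaggedOp:
    name: str
    device: str = "cpu"  # e.g. "cpu", "gpu:0"
    inputs: list = field(default_factory=list)

def plan_placement(outputs):
    """Walk the graph from its outputs and return {device: [op names]}."""
    plan, seen, stack = {}, set(), list(outputs)
    while stack:
        op = stack.pop()
        if op.name in seen:
            continue
        seen.add(op.name)
        plan.setdefault(op.device, []).append(op.name)
        stack.extend(op.inputs)
    return plan

x = DeviceTaggedOp("x")
mm = DeviceTaggedOp("matmul", device="gpu:0", inputs=[x])
out = DeviceTaggedOp("sum", inputs=[mm])
print(plan_placement([out]))  # {'cpu': ['sum', 'x'], 'gpu:0': ['matmul']}
```

A placement plan like this could be one of the things a rewrite pass fills in automatically, which ties back to the question of how much should be inferred versus user-specified.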
Tagging @Ch0ronomato because he got me thinking about this.