Refactor punting of array storage to Tensor.device #13

Open
hameerabbasi opened this issue Feb 21, 2024 · 0 comments · May be fixed by #19
Motivation

Users may need a way to:

  1. Query the format that a Tensor is currently in
  2. Create arrays with a specific format
  3. Convert to another format
  4. Materialize an array in a given format
  5. Do all of the above within a given memory space, i.e. CPU or GPU memory

These capabilities allow users to make decisions at the code level and regain some control over storage. In a meeting with @willow-ahrens and @mtsokol, we decided on the following interface for the long term.

Proposed interface

I propose a number of new classes (stubs and descriptions below):

```python
import abc
from typing import Any

# `jl` refers to the Julia bridge (e.g. Finch.jl via juliacall);
# `dtypes` refers to the package's dtype module.


class LeafLevel(abc.ABC):
    @abc.abstractmethod
    def _construct(self, *, dtype, fill_value) -> "jl.LeafLevel":
        ...


class Level(abc.ABC):
    @abc.abstractmethod
    def _construct(self, *, inner_level: "Level | LeafLevel") -> "jl.Level":
        ...


# Example impl of `Level`
class SparseList(Level):
    def __init__(self, index_type=dtypes.intp, pointer_type=dtypes.intp):
        self.index_type = index_type
        self.pointer_type = pointer_type

    def _construct(self, *, inner_level) -> "jl.SparseList":
        return jl.SparseList[self.index_type, self.pointer_type](inner_level)
```

```python
class Format:
    levels: tuple[Level, ...]
    order: tuple[int, ...]
    leaf: LeafLevel

    def __init__(self, *, levels: tuple[Level, ...], order: tuple[int, ...] | None = None, leaf: LeafLevel) -> None:
        if order is None:
            order = tuple(range(len(levels)))

        if len(order) != len(levels):
            raise ValueError(f"len(order) != len(levels), {order=}, {levels=}")

        if sorted(order) != list(range(len(order))):
            raise ValueError(f"sorted(order) != list(range(len(order))), {order=}")

        self.order = order
        self.levels = levels
        self.leaf = leaf

    def _construct(self, *, fill_value, dtype) -> "jl.Swizzle":
        # Build the storage from the innermost (leaf) level outward.
        out_level = self.leaf._construct(dtype=dtype, fill_value=fill_value)
        for level in reversed(self.levels):
            out_level = level._construct(inner_level=out_level)

        return jl.Swizzle(out_level, *reversed(self.order))
```
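The `_construct` fold above builds the storage from the leaf outward: the last entry of `levels` wraps the leaf directly, and the first entry ends up outermost. A minimal pure-Python sketch with mock level constructors (all names here are hypothetical stand-ins for the `jl` types, which are not importable in plain Python):

```python
# Hypothetical mocks standing in for jl.Element / jl.SparseList / jl.Dense.
def element(dtype, fill_value):
    return f"Element({dtype}, fill={fill_value})"

def sparse_list(inner):
    return f"SparseList({inner})"

def dense(inner):
    return f"Dense({inner})"

def construct(levels, leaf):
    """Fold the level constructors from the innermost (leaf) level outward."""
    out = leaf
    for lvl in reversed(levels):
        out = lvl(out)
    return out

storage = construct([dense, sparse_list], element("float64", 0.0))
print(storage)  # Dense(SparseList(Element(float64, fill=0.0)))
```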
    
```python
class Device:
    """The memory the `Tensor` will live on, as well as the execution context. Mixing devices will raise an error."""


class Tensor:
    device: Device
    format: Format
    fill_value: Any
    dtype: dtypes.DType

    ...

    def to_device(self, device: Device, /) -> "Tensor":
        ...

    def to_format(self, format: Format, /) -> "Tensor":
        ...


def asarray(x: Any, /, *, dtype=None, device: Device | None = None, format: Format | None = None, fill_value: Any = None) -> Tensor:
    # Massage dtype/device/format/fill_value into an acceptable form iff `None`.
    # (Stub: copying `x` into the constructed storage is omitted here.)
    return Tensor(jl_data=format._construct(fill_value=fill_value, dtype=dtype))
```
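The permutation check in `Format.__init__` can be exercised on its own, independent of any Finch types. A small sketch (the helper name `normalize_order` is hypothetical, introduced only for illustration):

```python
def normalize_order(order, ndim):
    """Default `order` to the identity and validate it is a permutation of range(ndim)."""
    if order is None:
        order = tuple(range(ndim))
    if len(order) != ndim:
        raise ValueError(f"len(order) != ndim, {order=}")
    if sorted(order) != list(range(ndim)):
        raise ValueError(f"`order` must be a permutation of range(ndim), {order=}")
    return order

print(normalize_order(None, 3))       # (0, 1, 2)
print(normalize_order((2, 0, 1), 3))  # (2, 0, 1)
```

Note that `sorted(order)` returns a `list`, so it must be compared against `list(range(...))` rather than a bare `range` object, which never compares equal to a list.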