Skip to content

Commit

Permalink
Move get_julia_scheduler into the class definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebd99 committed Dec 5, 2024
1 parent 1219c0e commit 4a973ab
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/finch/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@ class GalleyScheduler(AbstractScheduler):
def __init__(self, verbose=False):
self.verbose=verbose

def get_julia_scheduler(self):
return jl.Finch.galley_scheduler(verbose=self.verbose)

class DefaultScheduler(AbstractScheduler):
def __init__(self, verbose=False):
self.verbose=verbose

def get_julia_scheduler(opt):
if isinstance(opt, DefaultScheduler):
return jl.Finch.default_scheduler(verbose=opt.verbose)
elif isinstance(opt, GalleyScheduler):
return jl.Finch.galley_scheduler(verbose=opt.verbose)
def get_julia_scheduler(self):
return jl.Finch.default_scheduler(verbose=self.verbose)

def set_optimizer(opt):
jl.Finch.set_scheduler_b(get_julia_scheduler(opt))
jl.Finch.set_scheduler_b(opt.get_julia_scheduler())
return

def compute(tensor: Tensor, *, verbose: bool = False, opt=None, tag=-1):
if not tensor.is_computed():
if opt == None:
return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag))
else:
return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=get_julia_scheduler(opt)))
return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=opt.get_julia_scheduler()))
return tensor

0 comments on commit 4a973ab

Please sign in to comment.