diff --git a/src/finch/compiled.py b/src/finch/compiled.py index f56411d..2b7eb6b 100644 --- a/src/finch/compiled.py +++ b/src/finch/compiled.py @@ -34,18 +34,18 @@ class GalleyScheduler(AbstractScheduler): def __init__(self, verbose=False): self.verbose=verbose + def get_julia_scheduler(self): + return jl.Finch.galley_scheduler(verbose=self.verbose) + class DefaultScheduler(AbstractScheduler): def __init__(self, verbose=False): self.verbose=verbose -def get_julia_scheduler(opt): - if isinstance(opt, DefaultScheduler): - return jl.Finch.default_scheduler(verbose=opt.verbose) - elif isinstance(opt, GalleyScheduler): - return jl.Finch.galley_scheduler(verbose=opt.verbose) + def get_julia_scheduler(self): + return jl.Finch.default_scheduler(verbose=self.verbose) def set_optimizer(opt): - jl.Finch.set_scheduler_b(get_julia_scheduler(opt)) + jl.Finch.set_scheduler_b(opt.get_julia_scheduler()) return def compute(tensor: Tensor, *, verbose: bool = False, opt=None, tag=-1): @@ -53,5 +53,5 @@ def compute(tensor: Tensor, *, verbose: bool = False, opt=None, tag=-1): if opt == None: return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag)) else: - return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=get_julia_scheduler(opt))) + return Tensor(jl.Finch.compute(tensor._obj, verbose=verbose, tag=tag, ctx=opt.get_julia_scheduler())) return tensor