Skip to content

Commit

Permalink
Remove CTask (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
rikhuijzer authored Feb 20, 2022
1 parent 4a85b47 commit 01c2727
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.6.10"
version = "0.7.0"

[deps]
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
Expand Down
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(ctask)
a = copy(ttask)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 2
@show consume(ctask) # 3
@show consume(ttask) # 2
@show consume(ttask) # 3
```

Heap allocated objects are shallow copied:
Expand All @@ -45,17 +45,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(t)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 4
@show consume(ctask) # 5
@show consume(ttask) # 4
@show consume(ttask) # 5
```

In constrast to standard arrays, which are only shallow copied during
Expand All @@ -74,17 +74,17 @@ function f()
end
end

ctask = CTask(f)
ttask = TapedTask(f)

@show consume(ctask) # 0
@show consume(ctask) # 1
@show consume(ttask) # 0
@show consume(ttask) # 1

a = copy(ctask)
a = copy(ttask)
@show consume(a) # 2
@show consume(a) # 3

@show consume(ctask) # 2
@show consume(ctask) # 3
@show consume(ttask) # 2
@show consume(ttask) # 3
```

Note: The [Turing](https://github.com/TuringLang/Turing.jl)
Expand Down
10 changes: 4 additions & 6 deletions perf/p0.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ]add Turing#hg/new-libtask2

using Libtask
using Turing, DynamicPPL, AdvancedPS
using BenchmarkTools
Expand All @@ -26,8 +24,8 @@ args = m.evaluator[2:end];
@btime f(args...)
# (2.0, VarInfo (2 variables (μ, σ), dimension 2; logp: -6.162))

@show "CTask construction..."
t = @btime Libtask.CTask(f, args...)
@show "TapedTask construction..."
t = @btime TapedTask(f, args...)
# schedule(t.task) # work fine!
# @show Libtask.result(t.tf)
@show "Run a tape..."
Expand All @@ -39,8 +37,8 @@ m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo());
@show "Directly call..."
@btime m.evaluator[1](m.evaluator[2:end]...)

@show "CTask construction..."
t = @btime Libtask.CTask(m.evaluator[1], m.evaluator[2:end]...);
@show "TapedTask construction..."
t = @btime TapedTask(m.evaluator[1], m.evaluator[2:end]...);
# schedule(t.task)
# @show Libtask.result(t.tf.tape)
@show "Run a tape..."
Expand Down
2 changes: 1 addition & 1 deletion perf/p2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo())
f = m.evaluator[1]
args = m.evaluator[2:end]

t = Libtask.CTask(f, args...)
t = TapedTask(f, args...)

t.tf(args...)

Expand Down
6 changes: 1 addition & 5 deletions src/Libtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ using MacroTools

using LRUCache

export CTask, consume, produce
export TapedTask, consume, produce
export TArray, tzeros, tfill, TRef

export TapedTask

include("tapedfunction.jl")
include("tapedtask.jl")

include("tarray.jl")
include("tref.jl")

const CTask = TapedTask

end
16 changes: 10 additions & 6 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@ struct TapedTaskException
backtrace::Vector{Any}
end

struct TapedTask
struct TapedTask{F}
task::Task
tf::TapedFunction
tf::TapedFunction{F}
produce_ch::Channel{Any}
consume_ch::Channel{Int}
produced_val::Vector{Any}

function TapedTask(
t::Task, tf::TapedFunction, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, pch, cch, Any[])
t::Task,
tf::TapedFunction{F},
produce_ch::Channel{Any},
consume_ch::Channel{Int}
) where {F}
new{F}(t, tf, produce_ch, consume_ch, Any[])
end
end

Expand Down Expand Up @@ -148,8 +152,8 @@ function Base.iterate(t::TapedTask, state=nothing)
nothing
end
end
Base.IteratorSize(::Type{TapedTask}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{TapedTask}) = Base.EltypeUnknown()
Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()


# copy the task
Expand Down
9 changes: 4 additions & 5 deletions test/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using BenchmarkTools
using Libtask


macro rep(cnt, exp)
blk =:(begin end)
for _ in 1:eval(cnt)
Expand Down Expand Up @@ -47,10 +46,10 @@ function f()
end

@btime begin
ctask = CTask(f)
consume(ctask)
consume(ctask)
a = copy(ctask)
ttask = TapedTask(f)
consume(ttask)
consume(ttask)
a = copy(ttask)
consume(a)
consume(a)
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ using Libtask
using Test

include("tf.jl")
include("ctask.jl")
include("tapedtask.jl")
include("tarray.jl")
include("tref.jl")

if get(ENV, "BENCHMARK", nothing) != nothing
if haskey(ENV, "BENCHMARK")
include("benchmarks.jl")
end
78 changes: 39 additions & 39 deletions test/ctask.jl → test/tapedtask.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "ctask" begin
@testset "tapedtask" begin
# Test case 1: stack allocated objects are deep copied.
@testset "stack allocated objects" begin
function f()
Expand All @@ -9,14 +9,14 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 0
@test consume(ctask) == 1
a = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 0
@test consume(ttask) == 1
a = copy(ttask)
@test consume(a) == 2
@test consume(a) == 3
@test consume(ctask) == 2
@test consume(ctask) == 3
@test consume(ttask) == 2
@test consume(ttask) == 3

@inferred Libtask.TapedFunction(f)
end
Expand All @@ -31,16 +31,16 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 0
@test consume(ctask) == 1
a = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 0
@test consume(ttask) == 1
a = copy(ttask)
@test consume(a) == 2
@test consume(a) == 3
@test consume(ctask) == 4
@test consume(ctask) == 5
@test consume(ctask) == 6
@test consume(ctask) == 7
@test consume(ttask) == 4
@test consume(ttask) == 5
@test consume(ttask) == 6
@test consume(ttask) == 7
end

@testset "iteration" begin
Expand All @@ -52,20 +52,20 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)

next = iterate(ctask)
next = iterate(ttask)
@test next === (1, nothing)

val, state = next
next = iterate(ctask, state)
next = iterate(ttask, state)
@test next === (2, nothing)

val, state = next
next = iterate(ctask, state)
next = iterate(ttask, state)
@test next === (3, nothing)

a = collect(Iterators.take(ctask, 7))
a = collect(Iterators.take(ttask, 7))
@test eltype(a) === Int
@test a == 4:10
end
Expand All @@ -82,14 +82,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa MethodError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa MethodError
@test ttask.task.exception isa MethodError
end
end

Expand All @@ -103,14 +103,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa ErrorException
end
if VERSION >= v"1.5"
@test ctask.task.exception isa ErrorException
@test ttask.task.exception isa ErrorException
end
end

Expand All @@ -125,14 +125,14 @@
end
end

ctask = CTask(f)
ttask = TapedTask(f)
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa BoundsError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa BoundsError
@test ttask.task.exception isa BoundsError
end
end

Expand All @@ -147,15 +147,15 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 2
ttask = TapedTask(f)
@test consume(ttask) == 2
try
consume(ctask)
consume(ttask)
catch ex
@test ex isa BoundsError
end
if VERSION >= v"1.5"
@test ctask.task.exception isa BoundsError
@test ttask.task.exception isa BoundsError
end
end

Expand All @@ -170,17 +170,17 @@
end
end

ctask = CTask(f)
@test consume(ctask) == 2
ctask2 = copy(ctask)
ttask = TapedTask(f)
@test consume(ttask) == 2
ttask2 = copy(ttask)
try
consume(ctask2)
consume(ttask2)
catch ex
@test ex isa BoundsError
end
@test ctask.task.exception === nothing
@test ttask.task.exception === nothing
if VERSION >= v"1.5"
@test ctask2.task.exception isa BoundsError
@test ttask2.task.exception isa BoundsError
end
end
end
Expand Down
Loading

2 comments on commit 01c2727

@rikhuijzer
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/55084

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.0 -m "<description of version>" 01c2727aa10c4959b75b848334a3416d767bb1be
git push origin v0.7.0

Please sign in to comment.