Set up JuliaFormatter #67

Open · wants to merge 2 commits into `master`
6 changes: 6 additions & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1,6 @@
+style = "blue"
+
+ignore = ["src/Wrapper.jl"]
+pipe_to_function_call = false
+whitespace_in_kwargs = true
+whitespace_typedefs = true
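This config selects BlueStyle and overrides a handful of options; `whitespace_typedefs = true`, for example, is why signatures below change from `where {T,N}` to `where {T, N}`, and the generated `src/Wrapper.jl` is excluded from formatting. A minimal sketch of running the formatter locally with this config (assuming JuliaFormatter.jl is installed in the active environment):

```julia
using JuliaFormatter

# format(".") walks the repository, picks up .JuliaFormatter.toml from the
# project root, rewrites files in place, and returns `true` when everything
# was already formatted.
format(".")
```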
8 changes: 8 additions & 0 deletions .github/workflows/Format.yml
@@ -0,0 +1,8 @@
+name: Format suggestions
+on:
+  pull_request:
+jobs:
+  code-style:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: julia-actions/julia-format@v3
3 changes: 3 additions & 0 deletions README.md
@@ -1,5 +1,8 @@
 # Torch.jl

+[![Build Status](https://github.com/FluxML/Torch.jl/actions/workflows/CI.yaml/badge.svg?branch=master)](https://github.com/FluxML/Torch.jl/actions/workflows/CI.yaml?query=branch%3Amaster)
+[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/JuliaDiff/BlueStyle)
+
 Sensible extensions for exposing torch in Julia.

 This package is aimed at providing the `Tensor` type, which offloads all computations over to [ATen](https://pytorch.org/cppdocs/), the foundational tensor library for PyTorch, written in C++.
4 changes: 2 additions & 2 deletions deps/julia_wrapper_generator/generator.jl
@@ -22,11 +22,11 @@ function rewrite!(e::Expr)
 end

 function rewrite!(e::Expr, ::Val{:function})
-    rewrite!(e.args[2], Val(e.args[2].head))
+    return rewrite!(e.args[2], Val(e.args[2].head))
 end

 function rewrite!(e::Expr, ::Val{:block})
-    e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
+    return e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
 end

 function rewrite!(dag::ExprDAG)
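For context, `rewrite!` post-processes each function expression emitted by the wrapper generator, wrapping the first expression of the function body in `@runtime_error_check`. A hypothetical before/after (the `at_print` wrapper below is illustrative, not taken from this diff):

```julia
# Before rewrite! (sketch of a generated wrapper):
function at_print(t)
    ccall((:at_print, libtorch_c_api), Cint, (Ptr{Cvoid},), t)
end

# After rewrite!(e, Val(:function)) recurses into the body and
# rewrite!(e, Val(:block)) wraps its first expression:
function at_print(t)
    @runtime_error_check ccall((:at_print, libtorch_c_api), Cint, (Ptr{Cvoid},), t)
end
```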
2 changes: 1 addition & 1 deletion deps/julia_wrapper_generator/generator.toml
@@ -1,6 +1,6 @@
 [general]
 library_name = "libtorch_c_api"
-output_file_path = "../../src/wrapper.jl"
+output_file_path = "../../src/Wrapper.jl"
 prologue_file_path = "./prologue.jl"
 module_name = "Wrapper"
 jll_pkg_name = "TorchCAPI_jll"
20 changes: 10 additions & 10 deletions deps/julia_wrapper_generator/prologue.jl
@@ -1,15 +1,15 @@
 function get_error()
-  err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
-  unsafe_string(err)
+    err = cglobal((:myerr, libtorch_c_api), Cstring) |> unsafe_load
+    return unsafe_string(err)
 end

 macro runtime_error_check(ex)
-  quote
-    x = $ex
-    if x == 1
-      cs = get_error()
-      flush_error()
-      throw(cs)
-    end
-  end |> esc
+    return quote
+        x = $ex
+        if x == 1
+            cs = get_error()
+            flush_error()
+            throw(cs)
+        end
+    end |> esc
 end
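The contract here: each wrapped C call returns a nonzero status on failure, and the message is fetched from the C-side `myerr` global. A sketch of what `@runtime_error_check` expands to around a call (the name `c_call` is a placeholder):

```julia
# Approximate expansion of `@runtime_error_check c_call(args...)`:
x = c_call(args...)    # the wrapped call; 1 signals failure
if x == 1
    cs = get_error()   # read the message string from the C global `myerr`
    flush_error()      # name suggests it clears the pending C-side error state
    throw(cs)          # note: throws the raw message string
end
```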
30 changes: 16 additions & 14 deletions src/Torch.jl
@@ -13,7 +13,7 @@

 TURN_ON_LOGGING = false

-include("wrapper.jl")
+include("Wrapper.jl")

 using .Wrapper

@@ -32,23 +32,25 @@
 include("grads.jl")
 include("utils.jl")

-@init @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
-  using .Flux
+@init @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
+    using .Flux

-  function (tbn::Flux.BatchNorm)(x::Tensor)
-    tbn.λ.(Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1))
-  end
+    function (tbn::Flux.BatchNorm)(x::Tensor)
+        return tbn.λ.(
+            Torch.batchnorm(x, tbn.γ, tbn.β, tbn.μ, tbn.σ², 0, tbn.momentum, tbn.ϵ, 1)
+        )
+    end

-  function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T,N}) where {T,N}
-    ptr = Ref(Ptr{Cvoid}())
+    function Flux.Zygote.accum(t1::Tensor, t2::Tensor{T, N}) where {T, N}
+        ptr = Ref(Ptr{Cvoid}())

-    Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
-    Tensor{T,N}(ptr[], Torch.on(t1))
-  end
+        Torch.Wrapper.atg_add_(ptr, t1.ptr, t2.ptr)
+        return Tensor{T, N}(ptr[], Torch.on(t1))
+    end

-  eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
-  eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
-  torch(x) = Flux.fmap(to_tensor, x)
+    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_copy_data))
+    eval(:(Flux.Zygote.@nograd Torch.Wrapper.at_dim))
+    torch(x) = Flux.fmap(to_tensor, x)
 end

 end # module
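The `@require` block is what ties Torch tensors into Flux when both packages are loaded: `BatchNorm` dispatches to `Torch.batchnorm`, Zygote's gradient accumulation uses the in-place `atg_add_`, and `torch(x)` maps a model's arrays onto `Tensor`s. A hedged usage sketch (assumes a working libtorch setup; `Torch.torch` only exists once Flux has been loaded):

```julia
using Flux, Torch   # loading Flux triggers the @require block above

m = Chain(Dense(4, 2), BatchNorm(2))  # an ordinary Flux model
tm = Torch.torch(m)                   # Flux.fmap(to_tensor, m): parameters become Tensors

# Calls to tm now run through libtorch: BatchNorm dispatches to
# Torch.batchnorm, and gradients accumulate via the in-place atg_add_.
```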
src/wrapper.jl → src/Wrapper.jl: file renamed without changes.
39 changes: 19 additions & 20 deletions src/broadcast.jl
@@ -9,23 +9,23 @@
 # Base.BroadcastStyle(::Type{Tensor}) = TensorStyle()

 for op in (:+, :-, :/)
-  @eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
-    $op(t1, t2)
-  end
+    @eval function broadcasted(::typeof($op), t1::Tensor, t2::Tensor)
+        return $op(t1, t2)
+    end
 end

 for op in (:+, :-)
-  @eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
-    t_ = reshape(t2, -1, 1)
-    $op(t1, t_)
-  end
+    @eval function broadcasted(::typeof($op), t1::Tensor, t2::TensorVector)
+        t_ = reshape(t2, -1, 1)
+        return $op(t1, t_)
+    end
 end

-function broadcasted(::typeof(*), t1::Tensor{T,N}, t2::Tensor{T,M}) where {T,N,M}
-  ptr = Ref(Ptr{Cvoid}())
+function broadcasted(::typeof(*), t1::Tensor{T, N}, t2::Tensor{T, M}) where {T, N, M}
+    ptr = Ref(Ptr{Cvoid}())

-  atg_mul(ptr, t1.ptr, t2.ptr)
-  Tensor{T,max(N,M)}(ptr[], on(t1))
+    atg_mul(ptr, t1.ptr, t2.ptr)
+    return Tensor{T, max(N, M)}(ptr[], on(t1))
 end

 broadcasted(::typeof(NNlib.relu), t::Tensor) = NNlib.relu(t)
@@ -34,22 +34,21 @@
 broadcasted(::typeof(NNlib.sigmoid), t::Tensor) = NNlib.sigmoid(t)

 for op in (:+, :-, :*, :/)
-  @eval function broadcasted(::typeof($op), t::Tensor, args...)
-    $op(t, args...)
-  end
+    @eval function broadcasted(::typeof($op), t::Tensor, args...)
+        return $op(t, args...)
+    end
 end

 broadcasted(::typeof(sqrt), t::Tensor) = sqrt(t)

-function broadcasted(::typeof(copy), t::Tensor{T,N}) where {T,N}
-  t
+function broadcasted(::typeof(copy), t::Tensor{T, N}) where {T, N}
+    return t
 end

 @adjoint function broadcast(::typeof(NNlib.sigmoid), t::Tensor)
-
-  NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
+    return NNlib.sigmoid(t), Δ -> (∇sigmoid(Δ, t),)
 end

-@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where T
-  relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)),)
+@adjoint function broadcasted(::typeof(NNlib.relu), t::Tensor{T}) where {T}
+    return relu(t), Δ -> (nothing, ∇leaky_relu(Δ, t, zero(T)))
 end
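These overloads short-circuit Julia's lazy broadcasting for `Tensor`s: instead of fusing a kernel on the Julia side, a dotted call is forwarded to the eager operator, which executes as a single ATen call. A sketch, assuming `t1` and `t2` are same-typed `Tensor`s on the same device:

```julia
t3 = t1 .+ t2    # broadcasted(+, t1, t2) simply calls t1 + t2: one ATen kernel
t4 = t1 .* t2    # atg_mul on the underlying pointers; result rank is max(N, M)
t5 = relu.(t1)   # NNlib.relu broadcast forwards to the Tensor-specific relu
```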