-
Notifications
You must be signed in to change notification settings - Fork 521
/
Copy pathconfig.py
48 lines (35 loc) · 1.29 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from jetson_containers import L4T_VERSION, CUDA_ARCHITECTURES
def jax(version, requires=None, alias=None, default=False):
    """
    Install JAX from pip server or build the wheel from source.

    Creates two package definitions from the module-level ``package``
    template: a regular ``jax:<version>`` package and a
    ``jax:<version>-builder`` variant whose build args add
    ``FORCE_BUILD='on'``.

    Args:
        version: JAX version string, e.g. ``'0.4.38'``. For
            ``JAX_BUILD_VERSION`` it is zero-padded to three components
            (``'0.4'`` -> ``'0.4.0'``); longer versions are used as-is.
        requires: Optional requirement string (e.g. ``'>=35'``) stored under
            ``'requires'``; it is set before the builder is copied, so both
            packages carry it.
        alias: Optional extra alias (a string or a list of strings) for the
            regular package. Not applied to the builder.
        default: If true, the regular package is also aliased ``'jax'`` and
            the builder ``'jax:builder'``.

    Returns:
        A ``(pkg, builder)`` tuple of package dicts.
    """
    pkg = package.copy()

    pkg['name'] = f'jax:{version}'
    pkg['dockerfile'] = 'Dockerfile'

    # JAX_BUILD_VERSION wants MAJOR.MINOR.PATCH: pad short versions with '0'
    # components ('0.4' -> '0.4.0'); 3+ component versions are unchanged.
    parts = version.split('.')
    build_version = '.'.join(parts + ['0'] * (3 - len(parts)))

    pkg['build_args'] = {
        # Each entry x is rendered as x/10 with one decimal (87 -> '8.7'),
        # joined with ';'.
        'JAX_CUDA_ARCH_ARGS': ';'.join(f'{x / 10:.1f}' for x in CUDA_ARCHITECTURES),
        'JAX_VERSION': version,
        'JAX_BUILD_VERSION': build_version,
    }

    if L4T_VERSION.major >= 36:
        pkg['build_args']['ENABLE_NCCL'] = 1

    if requires:
        pkg['requires'] = requires

    # Copied before any aliases are assigned below, so the builder does not
    # pick up the regular package's aliases from this function.
    builder = pkg.copy()
    builder['name'] = builder['name'] + '-builder'
    # pkg.copy() is shallow: build a new dict so FORCE_BUILD does not leak
    # into pkg['build_args'].
    builder['build_args'] = {**builder['build_args'], 'FORCE_BUILD': 'on'}

    aliases = []
    if alias:
        # Previously this parameter was accepted but ignored.
        aliases.extend([alias] if isinstance(alias, str) else alias)
    if default:
        aliases.append('jax')
        builder['alias'] = 'jax:builder'
    if aliases:
        # A single alias stays a plain string (identical to the previous
        # default=True output). NOTE(review): assumes the package loader
        # also accepts a list for 'alias' when several are given - confirm.
        pkg['alias'] = aliases[0] if len(aliases) == 1 else aliases

    return pkg, builder
# One (pkg, builder) pair per supported JAX release, all requiring L4T >= 35.
# 0.4.38 is the default and is aliased 'jax' / 'jax:builder'.
# Per the original note, 0.4.38 works from JetPack 5 (CUDA 11.8, cuDNN 8.6).
package = [
    jax(release, requires='>=35', default=(release == '0.4.38'))
    for release in ('0.4.28', '0.4.30', '0.4.32', '0.4.35', '0.4.38')
]