빅데이타 & 머신러닝/머신러닝
Google JAX 소개
Terry Cho
2023. 9. 8. 13:05
모두의 연구소 JAX님 영상 강의
요약
- Numpy의 대체제이지만 GPU,TPU 사용이 가능함. 구글과 허깅페이스가 강력하게 밀고 있는 프레임웍.
- Functional programming 모델로 Numpy와 프로그래밍 모델이 다소 다름. Immutable(변경 불가능 특징)을 가짐. 예를 들어
# In NumPy arrays are mutable
x = np.arange(size)
print(x)
x[index] = value
print(x)
JAX의 경우 직접 값을 변경하는것이 불가능하고 set function을 직접 사용해야 한다.
# Solution/workaround:
y = x.at[index].set(value)
print(x)
print(y)
- JIT 컴파일 방식을 사용함. @jit 데코레이터나 jax.jit() 명령을 사용해서 컴파일해야함.
- AutoGrad : 자동으로 Gradient를 구해줌 (자동 미분) ex) jax.grad(f) : f를 미분하는 함수.
JAX Ecosystem
- Haiku : Neural network
- Flax : Deep learning (Google Research의 많은 모델이 Flax를 이용해서 개발됨, 허깅페이스도 Flax를 많이 지원함)
- RLax : Reinforcement Learning
- Chex : 테스트 환경
- Graph : GNNs