Terry Cho 2023. 9. 8. 13:05

모두의 연구소 JAX님 영상 강의

요약

  • Numpy의 대체제이지만 GPU,TPU 사용이 가능함. 구글과 허깅페이스가 강력하게 밀고 있는 프레임웍.
  • Functional programming 모델로 Numpy와 프로그래밍 모델이 다소 다름. Immutable(변경 불가능 특징)을 가짐. 예를 들어 
# In NumPy arrays are mutable
x = np.arange(size)
print(x)
x[index] = value
print(x)

JAX의 경우 직접 값을 변경하는것이 불가능하고 set function을 직접 사용해야 한다.

# Solution/workaround:
y = x.at[index].set(value)
print(x)
print(y)
  • JIT 컴파일 방식을 사용함. @jit 데코레이터나 jax.jit() 명령을 사용해서 컴파일해야함. 
  • AutoGrad : 자동으로 Gradient를 구해줌 (자동 미분) ex) jax.grad(f) : f를 미분하는 함수. 

 

 

JAX Ecosystem

  • Haiku : Neural network
  • Flax : Deep learning (Google Research의 많은 모델이 Flax를 이용해서 개발됨, 허깅페이스도 Flax를 많이 지원함)
  • RLax : Reinforcement Learning
  • Chex : 테스트 환경
  • Graph : GNNs