Machine Learning with JAX - From Zero to Hero | Tutorial #1

❤️ Become The AI Epiphany Patreon ❤️
/ theaiepiphany
👨‍👩‍👧‍👦 Join our Discord community 👨‍👩‍👧‍👦
/ discord
With this video I'm kicking off a series of tutorials on JAX!
JAX is a powerful and increasingly popular ML library built by the Google Research team. The 2 most popular deep learning frameworks built on top of JAX are Haiku (DeepMind) and Flax (Google Research).
In this video I cover the basics as well as the nitty-gritty details of jit, grad, vmap, and various other idiosyncrasies of JAX.
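As a quick taste of the three transforms mentioned above, here is a minimal sketch of my own (the toy predict/loss functions are illustrative, not code from the video):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Toy linear model: jit compiles it, grad differentiates it,
    # and vmap batches it over inputs.
    return jnp.dot(w, x)

def loss(w, x, y):
    return (predict(w, x) - y) ** 2

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

fast_loss = jax.jit(loss)                       # compile with XLA
g = jax.grad(loss)(w, x, 5.0)                   # d(loss)/dw, same shape as w
batched = jax.vmap(predict, in_axes=(None, 0))  # map over a batch of x
```

The same three function transformations compose freely, e.g. `jax.jit(jax.vmap(jax.grad(loss)))`.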
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
✅ JAX GitHub: github.com/google/jax
✅ JAX docs: jax.readthedocs.io/
✅ My notebook: github.com/gordicaleksa/get-s...
✅ Useful video on autodiff: • What is Automatic Diff...
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
⌚️ Timetable:
00:00:00 What is JAX? JAX ecosystem
00:03:35 JAX basics
00:10:05 JAX is accelerator agnostic
00:15:00 jit explained
00:17:45 grad explained
00:27:25 The power of JAX autodiff (Hessians and beyond)
00:31:00 vmap explained
00:36:50 JAX API (NumPy, lax, XLA)
00:39:40 The nitty-gritty details of jit
00:46:55 Static arguments
00:50:05 Gotcha 1: Pure functions
00:56:00 Gotcha 2: In-Place Updates
00:57:35 Gotcha 3: Out-of-Bounds Indexing
00:59:55 Gotcha 4: Non-Array Inputs
01:01:50 Gotcha 5: Random Numbers
01:09:40 Gotcha 6: Control Flow
01:13:45 Gotcha 7: NaNs and float32
02:15:25 Quick summary
02:16:00 Conclusion: who should be using JAX?
02:17:10 Outro
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💰 BECOME A PATREON OF THE AI EPIPHANY ❤️
If these videos, GitHub projects, and blogs help you,
consider helping me out by supporting me on Patreon!
The AI Epiphany - / theaiepiphany
One-time donation - www.paypal.com/paypalme/theai...
Huge thank you to these AI Epiphany patreons:
Eli Mahler
Petar Veličković
Bartłomiej Danek
Zvonimir Sabljic
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💼 LinkedIn - / aleksagordic
🐦 Twitter - / gordic_aleksa
👨‍👩‍👧‍👦 Discord - / discord
📺 YouTube - / theaiepiphany
📚 Medium - / gordicaleksa
💻 GitHub - github.com/gordicaleksa
📢 AI Newsletter - aiepiphany.substack.com/
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
#jax #machinelearning #framework

Comments: 90

  • @billykotsos4642 · 2 years ago

    This channel is reaching GOAT-level status

  • @TheAIEpiphany · 2 years ago

    Hahaha thank you! Step by step

  • @mk91-vz1oj · 13 days ago

    Well done for reading out the jax quick start documentation

  • @matthewkhoo1153 · 2 years ago

    When I started learning JAX, I personally thought it stands for JIT (J), Autograd (A), XLA (X), which is essentially an abbreviation for a bunch of abbreviations. Given that those features are the 'highlights' of JAX, it's very possible. If that's the case, pretty cool naming from DeepMind. Anyways, there aren't many comprehensive resources for JAX right now, so I'm really looking forward to this series! Cheers Aleksa.

  • @TheAIEpiphany · 2 years ago

    Wow great point I totally overlooked it. 😅 That looks like a better hypothesis than the one I had. If you google "Jax meaning" it seems it's a legit name and means something like "God has been gracious; has shown favor". 😂

  • @matthewkhoo1153 · 2 years ago

    @@TheAIEpiphany Probably an alternative in case the name 'Jack' is too boring lmao. Had a similar experience, first time I googled "jax tutorial" it was a guide for a game character haha.

  • @DanielSuo · 11 months ago

    Believe it stands for "Just Another XLA" compiler.

  • @user-wr4yl7tx3w · 2 years ago

    this video is such a great service to the community. really great examples to help better understand Jax at a nuanced level.

  • @sarahel-sherif3318 · 2 years ago

    Great material and great efforts , excited to see FLAX and Haiku ,thank you

  • @mikesmith853 · 2 years ago

    Awesome work!! JAX is a fantastic library. This series is the reason I finally subscribed to your channel. Thanks for your work!

  • @TheAIEpiphany · 2 years ago

    Thank you! 😄

  • @akashraut3581 · 2 years ago

    Finally some jax tutorial.. Keep them coming

  • @TheAIEpiphany · 2 years ago

    Yup it was about time I started learning it. Will do sir

  • @mariuskombou6729 · 3 months ago

    Thanks for this video!! that was really interesting for a new user of JAX like me

  • @jawadmansoor6064 · 2 years ago

    Great video thanks, kindly complete the tutorial series on Flax as well.

  • @sacramentofwilderness6656 · 2 years ago

    Thanks for the great tutorial, pointing out the strong and weak points of the JAX framework, its caveats and salient features. What confuses me somewhat is the behavior where overshooting an array index clips the index to the maximum or does nothing. In C/C++, if one does this, then usually for a small displacement some memory outside the given data is modified, and for a larger indexing mistake one receives a SEGMENTATION FAULT. Clipping the index makes the program safer, but besides being counterintuitive it adds some small extra cost for fetching the index.
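For anyone who wants to see the clipping behavior discussed here concretely, a small sketch (the `mode="fill"` option is JAX's documented way to opt into a different out-of-bounds policy):

```python
import jax.numpy as jnp

x = jnp.arange(3)            # [0, 1, 2]

# Out-of-bounds reads clamp the index to the last valid element
# instead of faulting:
out_of_bounds_read = x[10]   # behaves like x[2]

# Out-of-bounds writes are silently dropped:
unchanged = x.at[10].set(99) # still [0, 1, 2]

# The policy can be chosen explicitly, e.g. return a sentinel instead:
sentinel = x.at[10].get(mode="fill", fill_value=-1)
```

One documented reason for this design is that on accelerators the computation is traced and batched, so raising a Python exception from device code at the offending element is not straightforward.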

  • @TheAIEpiphany · 2 years ago

    Thank you! It is confusing. It'd be cool to understand why exactly it is difficult to handle this "correctly" (by throwing an exception).

  • @vijayanand7270 · 2 years ago

    Great video! 🔥 Would love to see paper implementations too 💯

  • @Khushpich · 2 years ago

    Great stuff, looking forward to the next jax tutorials

  • @TheAIEpiphany · 2 years ago

    Thanks!

  • @deoabhijit5935 · 2 years ago

    Looking forward to more such great videos

  • @RamithHettiarachchi · 2 years ago

    Thank you for the tutorial! By the way, according to their paper (Compiling machine learning programs via high-level tracing), JAX stands for just after execution 😃

  • @kaneyxx · 2 years ago

    Thanks for the tutorial! Love it!

  • @TheAIEpiphany · 2 years ago

    Glad to hear that ^^

  • @maikkschischo1790 · 1 year ago

    Great work. Thank you, Aleksa. I learned a lot. Coming from R, I like the functional approach here. I'd be interested to hear your current opinion of JAX, now that you know it better.

  • @DullPigeon2750 · 1 year ago

    Wonderful explanation about vmap function

  • @shyamalchandra6597 · 2 years ago

    Great job! Keep it up!😀

  • @broccoli322 · 1 year ago

    Great job. Very nice tutorial

  • @johanngerberding5956 · 2 years ago

    Congrats on your DeepMind job, man (read your post). Nice channel, keep going!

  • @TheAIEpiphany · 2 years ago

    Thank you!! 🙏😄

  • @TheAIEpiphany · 2 years ago

    I also open-sourced an accompanying repo here: github.com/gordicaleksa/get-started-with-JAX I recommend opening the notebook in parallel while watching the video so that you can play with and tweak the code as well. Just open the notebook, click the Colab button at the top of the file, and voila! You'll avoid having to set up the Python env, and everything will just work! (You can optionally choose a GPU as an accelerator.)

  • @TheParkitny · 2 years ago

    great video and content, this channel needs more recognition.

  • @TheAIEpiphany · 2 years ago

    Thank you! 🥳

  • @not_a_human_being · 9 months ago

    50:40 - I would argue here that it's not necessary to pass all the parameters into the function; as long as it's not changing any of the params, it's OK to use external globals(), e.g. for some reference tables. This definition (though academically thorough) makes practical application a bit more cumbersome. I believe rule "2." alone is sufficient to make this work: no need to pass a long list of params, just make sure not to update/change anything external inside the function, and treat whatever is not passed in as static. Alternatively, you can re-create the jitted function every time you anticipate that your globals might have changed, so you effectively re-trace it with the new globals(). In some cases that feels much preferable to passing everything in; for instance, you can use all sorts of globals inside it and then just re-create it right before your training loop.
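A small sketch of my own illustrating the trade-off this comment describes: jit caches on argument shapes/dtypes only, so a captured global is baked in at trace time, and re-creating the jitted function is indeed one way to pick up a new value:

```python
import jax
import jax.numpy as jnp

scale = 2.0

@jax.jit
def times_scale(x):
    # Impure: reads the global `scale`; its value is baked in at trace time.
    return x * scale

first = times_scale(jnp.array(1.0))   # traces with scale == 2.0 -> 2.0

scale = 10.0
stale = times_scale(jnp.array(1.0))   # cached trace still uses 2.0 -> 2.0

# Re-creating the jitted function forces a fresh trace with the new global:
times_scale_v2 = jax.jit(lambda x: x * scale)
fresh = times_scale_v2(jnp.array(1.0))  # 10.0
```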

  • @alexanderchernyavskiy9538 · 2 years ago

    oh, quite impressive series with the perfect explanation

  • @TheAIEpiphany · 2 years ago

    Thank you man! 🙏

  • @alinajafistudent6906 · 1 year ago

    Very Cool, Thanks

  • @bionhoward3159 · 2 years ago

    i love jax ... thank you for your work!

  • @TheAIEpiphany · 2 years ago

    You're welcome!

  • @user-vg4gv9zj3d · 2 years ago

    It's really good Tutorial!! thx :)

  • @1potdish271 · 2 years ago

    Hi Aleksa, first of all, thank you very much for sharing great content. I learn a lot from you. Could you please explain some upsides of JAX over other frameworks? I really need motivation to get started with JAX. Thank you. Cheers :)

  • @sallanmega1 · 2 years ago

    Thank you for the amazing content. Greetings from Spain

  • @TheAIEpiphany · 2 years ago

    Gracias y saludos desde Belgrado! 🙏 Tengo muchos amigos en España.

  • @gim8377 · 4 months ago

    Thank you so much !

  • @yulanliu3839 · 2 years ago

    JAX = Just After eXecution (related to the tracing behaviour). JIT = Just In Time (related to the compilation behaviour).

  • @TheAIEpiphany · 2 years ago

    Thank you! A couple more people also pointed it out. Makes sense

  • @promethesured · 2 months ago

    ty ty ty ty ty for this video

  • @PhucLe-qs7nx · 2 years ago

    JAX is Just After eXecution, representing the paradigm of tracing and transforming (grad, vmap, jit, ...) after the first execution.

  • @TheAIEpiphany · 2 years ago

    Hmmm, source?

  • @PhucLe-qs7nx · 2 years ago

    @@TheAIEpiphany Sorry, I can't remember it now, but it's somewhere in the documentation or a JAX GitHub issue/discussion.

  • @TheAIEpiphany · 2 years ago

    @@PhucLe-qs7nx Thanks in any case! One of the other comments mentioned it simply stands for Jit Autograd XLA. 😄 That sounds reasonable as well.

  • @michaellaskin3407 · 2 years ago

    great video!

  • @MikeOxmol_ · 2 years ago

    Saw your video retweeted by someone, watched it and subbed, because your content is great :) How often will you be uploading the following Jax vids?

  • @TheAIEpiphany · 2 years ago

    Thanks! Next one tomorrow or Thursday.

  • @yagneshm.bhadiyadra4359 · 1 year ago

    Thank you for a great content!

  • @TheAIEpiphany · 1 year ago

    Thanks!

  • @mohammedelfatihsalahmohame7288 · 2 years ago

    I am glad that I found this channel

  • @TheAIEpiphany · 2 years ago

    Welcome 🚀

  • @mrlovalovaa · 2 years ago

    Finally the Jax !!!!!!!

  • @adels1388 · 2 years ago

    Thanks a lot. Keep up the good work. Am I wrong, or should the derivative at 20:15 be (x1*2, x2*2, x3*2)? I mean, you take the gradient with respect to a vector, so you should take the derivative with respect to each variable separately.

  • @TheAIEpiphany · 2 years ago

    Of course what did I do? 😂

  • @adels1388 · 2 years ago

    @@TheAIEpiphany you wrote x1*2+x2*2+x3*2. I replaced + with comma :)

  • @TheAIEpiphany · 2 years ago

    @@adels1388 My bad 😅 Thanks for noticing, the printed result was correct...

  • @vaishnav4035 · 7 months ago

    Thank you so much ..

  • @adityakane5669 · 2 years ago

    At around 1:10:34, you used static_argnums=(0,) for jit. Wouldn't this extremely slow down the program, since it will have to retrace for every new value of x? Code to reproduce:

        def switch(x):
            print("traced")
            if x > 10.:
                return 1.
            else:
                return 0.

        jit_switch = jit(switch, static_argnums=(0,))

        x = 5
        jit_switch(x)
        x = 16
        jit_switch(x)

        # Output:
        # traced
        # DeviceArray(0., dtype=float32, weak_type=True)
        # traced
        # DeviceArray(1., dtype=float32, weak_type=True)
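Yes, a static argument triggers a retrace for every new value. When the branch only depends on a runtime value, JAX's own alternative is jax.lax.cond, which traces both branches once and keeps the argument fully traced; a small sketch of my own:

```python
import jax

# Instead of making x static (which retraces for every new value),
# lax.cond stages both branches into a single compiled function.
@jax.jit
def switch(x):
    return jax.lax.cond(x > 10.0,
                        lambda _: 1.0,   # taken when x > 10
                        lambda _: 0.0,   # taken otherwise
                        x)
```

This traces exactly once and then dispatches on the runtime predicate inside XLA, e.g. switch(5.0) gives 0.0 and switch(16.0) gives 1.0.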

  • @yagneshm.bhadiyadra4359 · 1 year ago

    Can we say that if we made all arguments static, it would be as good as normal code without JAX? Thank you for these videos, btw.

  • @user-wr4yl7tx3w · 2 years ago

    Would it be better to use Julia and not have to worry about the gotchas? And still get the performance.

  • @kenbobcorn · 2 years ago

    Can someone explain to me why at 20:50 jnp.sum() is required, and why it returns [0, 2, 4]? I would assume it would return 0 + 2 + 4 = 6, as described in the comment and when using sum(), but it doesn't; it just returns a vector of the original size with all the elements doubled.

  • @TheAIEpiphany · 2 years ago

    I made a mistake. It's going to return a vector (df/dx1, df/dx2, df/dx3) and not the sum. f = x1^2 + x2^2 + x3^2 and grad takes derivatives independently for x1, x2 and x3 since they are all bundled into the first argument of the function f. Hope that makes sense. You can always consult the docs and experiment yourself.

  • @Gannicus99 · 2 years ago

    Good question. Took me a moment as well, but the function gives the sum, whereas grad(function) gives the three gradients, one per parameter, since the output of grad is used for the SGD parameter updates: w1, w2, w3 = w1 - lr*df/dx1, w2 - lr*df/dx2, w3 - lr*df/dx3.
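The point in this thread can be checked in a few lines: grad requires a scalar-valued function (hence the jnp.sum), but differentiates with respect to the whole first argument, so the result has one partial derivative per component:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued: f(x) = x1^2 + x2^2 + x3^2
    return jnp.sum(x ** 2)

x = jnp.array([0.0, 1.0, 2.0])

# grad returns (df/dx1, df/dx2, df/dx3) = (2*x1, 2*x2, 2*x3),
# not their sum, because x is a single (vector) argument.
g = jax.grad(f)(x)   # [0., 2., 4.]
```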

  • @nicecomment5411 · 1 year ago

    Thanks for the video. I have a question: how do I run TensorFlow/JAX in the browser? (Not in an online notebook.)

  • @teetanrobotics5363 · 2 years ago

    amazing content

  • @TheAIEpiphany · 2 years ago

    Thank you!

  • @Gannicus99 · 2 years ago

    Function pureness rule #2 means one cannot use closure variables (variables captured from a wrapping function)? That's good to know, since JAX states that it is functional but does not cover closure use, because jit caching only considers explicitly passed function parameters. Closure variables are hacky, but they are valid Python code. Just not in JAX.

  • @peterkanini867 · 1 year ago

    What font are you using for the Colab notebook?

  • @alvinhew1872 · 2 years ago

    Hi, are there any resources on how to freeze certain layers of the network for transfer learning?

  • @TheAIEpiphany · 2 years ago

    jax.lax.stop_gradient
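To expand on that one-line answer, here is a minimal sketch of my own (the toy model/loss and the parameter names are illustrative, not a full transfer-learning recipe): stop_gradient blocks the backward pass through the wrapped value, so the "frozen" parameter receives a zero gradient.

```python
import jax

def model(frozen_w, trainable_w, x):
    # stop_gradient blocks gradients from flowing into frozen_w.
    h = jax.lax.stop_gradient(frozen_w) * x
    return trainable_w * h

def loss(params, x):
    frozen_w, trainable_w = params
    return model(frozen_w, trainable_w, x) ** 2

# grad w.r.t. the whole params pytree:
grads = jax.grad(loss)((2.0, 3.0), 1.0)
# grads[0] is 0.0 (frozen), grads[1] is nonzero (trainable)
```

In practice, libraries like Flax and Haiku also let you freeze layers by filtering the parameter pytree passed to the optimizer; stop_gradient is the lower-level JAX primitive underneath.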

  • @zumpitu · 11 months ago

    Hi! Thank you for your video. Isn't that very similar to Numba?

  • @arshsharma8627 · 1 month ago

    Bro, you're great

  • @tshegofatsotshego375 · 2 years ago

    Is it possible to use JAX with Python's statsmodels?

  • @TheMazyProduction · 2 years ago

    Noice

  • @adrianstaniec · 2 years ago

    great content, horrible font ;)

  • @TheAIEpiphany · 2 years ago

    Hahaha thank you and thank you!

  • @L4rsTrysToMakeTut · 2 years ago

    Why not julia lang?

  • @TheAIEpiphany · 2 years ago

    Why would this video imply that you shouldn't give Julia a shot? I may make a video on Julia in the future. I personally wanted to learn JAX since I'll be using it in DeepMind.

  • @user-wr4yl7tx3w · 2 years ago

    at kzread.info/dash/bejne/haeo19iMXde8k5M.html, what I noticed was that when I tried, print(grad(f_jit)(2.)), even with the static_argnums.

  • @alokraj7120 · 2 years ago

    Can you make a video on how to install JAX in Anaconda, plain Python, and other Python IDEs?

  • @kirtipandya4618 · 2 years ago

    JAX = Just Autograd and Xla

  • @samueltrif5472 · 11 months ago

    one hour of nothing lol

  • @heyman620 · 2 years ago

    49:00 jnp.reshape(x, (np.prod(x.shape),)) works.
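For completeness, a small sketch showing that flatten (the reshape discussed at 49:00) can be written several equivalent ways in jax.numpy, without needing np.prod:

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

# Equivalent ways to flatten a JAX array:
a = jnp.reshape(x, (x.size,))   # explicit total length
b = x.reshape(-1)               # let JAX infer the length
c = jnp.ravel(x)                # dedicated flatten helper
```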