Machine Learning with JAX - From Zero to Hero | Tutorial #1
❤️ Become The AI Epiphany Patreon ❤️
/ theaiepiphany
👨👩👧👦 Join our Discord community 👨👩👧👦
/ discord
With this video I'm kicking off a series of tutorials on JAX!
JAX is a powerful and increasingly popular ML library built by the Google Research team. The 2 most popular deep learning frameworks built on top of JAX are Haiku (DeepMind) and Flax (Google Research).
In this video I cover the basics as well as the nitty-gritty details of jit, grad, vmap, and various other idiosyncrasies of JAX.
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
✅ JAX GitHub: github.com/google/jax
✅ JAX docs: jax.readthedocs.io/
✅ My notebook: github.com/gordicaleksa/get-s...
✅ Useful video on autodiff: • What is Automatic Diff...
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
⌚️ Timetable:
00:00:00 What is JAX? JAX ecosystem
00:03:35 JAX basics
00:10:05 JAX is accelerator agnostic
00:15:00 jit explained
00:17:45 grad explained
00:27:25 The power of JAX autodiff (Hessians and beyond)
00:31:00 vmap explained
00:36:50 JAX API (NumPy, lax, XLA)
00:39:40 The nitty-gritty details of jit
00:46:55 Static arguments
00:50:05 Gotcha 1: Pure functions
00:56:00 Gotcha 2: In-Place Updates
00:57:35 Gotcha 3: Out-of-Bounds Indexing
00:59:55 Gotcha 4: Non-Array Inputs
01:01:50 Gotcha 5: Random Numbers
01:09:40 Gotcha 6: Control Flow
01:13:45 Gotcha 7: NaNs and float32
02:15:25 Quick summary
02:16:00 Conclusion: who should be using JAX?
02:17:10 Outro
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💰 BECOME A PATREON OF THE AI EPIPHANY ❤️
If these videos, GitHub projects, and blogs help you,
consider helping me out by supporting me on Patreon!
The AI Epiphany - / theaiepiphany
One-time donation - www.paypal.com/paypalme/theai...
Huge thank you to these AI Epiphany patrons:
Eli Mahler
Petar Veličković
Bartłomiej Danek
Zvonimir Sabljic
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💼 LinkedIn - / aleksagordic
🐦 Twitter - / gordic_aleksa
👨👩👧👦 Discord - / discord
📺 YouTube - / theaiepiphany
📚 Medium - / gordicaleksa
💻 GitHub - github.com/gordicaleksa
📢 AI Newsletter - aiepiphany.substack.com/
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
#jax #machinelearning #framework
Comments: 90
This channel is going on the GOAT level status
@TheAIEpiphany
2 years ago
Hahaha thank you! Step by step
Well done for reading out the jax quick start documentation
When I started learning JAX, I personally thought it stood for JIT (J), Autograd (A), XLA (X), which is essentially an abbreviation for a bunch of abbreviations. Given that those features are the highlights of JAX, it's very possible. If that's the case, pretty cool naming from DeepMind. Anyway, there aren't many comprehensive resources for JAX right now, so I'm really looking forward to this series! Cheers, Aleksa.
@TheAIEpiphany
2 years ago
Wow, great point, I totally overlooked it. 😅 That looks like a better hypothesis than the one I had. If you google "Jax meaning", it seems it's a legit name that means something like "God has been gracious; has shown favor". 😂
@matthewkhoo1153
2 years ago
@@TheAIEpiphany Probably an alternative in case the name 'Jack' is too boring lmao. Had a similar experience, first time I googled "jax tutorial" it was a guide for a game character haha.
@DanielSuo
11 months ago
Believe it stands for "Just Another XLA" compiler.
this video is such a great service to the community. really great examples to help better understand Jax at a nuanced level.
Great material and great effort. Excited to see Flax and Haiku, thank you!
Awesome work!! JAX is a fantastic library. This series is the reason I finally subscribed to your channel. Thanks for your work!
@TheAIEpiphany
2 years ago
Thank you! 😄
Finally some jax tutorial.. Keep them coming
@TheAIEpiphany
2 years ago
Yup it was about time I started learning it. Will do sir
Thanks for this video!! that was really interesting for a new user of JAX like me
Great video, thanks. Kindly complete the tutorial series on Flax as well.
Thanks for the great tutorial pointing out the strong and weak points of the JAX framework, with caveats and salient features. What confuses me somewhat is the behavior where overshooting the array index clips the index to the maximum, or does nothing. In C/C++, if one does this, then usually when the displacement is small some memory outside the given data gets modified, and for a large indexing mistake one would receive a SEGMENTATION FAULT. Clipping the index makes the program safer, but besides the counterintuitive behavior it adds some small additional cost for clamping the index.
@TheAIEpiphany
2 years ago
Thank you! It is confusing. It'd be cool to understand why exactly it is difficult to handle this "correctly" (throwing an exception).
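For anyone following along, here is a quick sketch of the behavior discussed in this thread: out-of-bounds reads are clamped to the last valid index, and out-of-bounds `.at[].set` updates are dropped by default.

```python
import jax.numpy as jnp

x = jnp.arange(3)  # [0, 1, 2]

# Out-of-bounds reads are clamped to the last valid index
# instead of raising, so x[10] silently returns x[2].
assert int(x[10]) == 2

# Out-of-bounds functional updates are dropped by default:
# the update is a no-op rather than an error.
y = x.at[10].set(99)
assert jnp.array_equal(y, x)
```

The docs also describe a `mode=` argument on `.at[...]` operations for selecting other out-of-bounds semantics (e.g. filling reads with a sentinel value).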
Great video! 🔥 Would love paper implementations too 💯
Great stuff, looking forward to the next jax tutorials
@TheAIEpiphany
2 years ago
Thanks!
Looking forward to more great videos like this.
Thank you for the tutorial! By the way, according to their paper (Compiling machine learning programs via high-level tracing), JAX stands for just after execution 😃
Thanks for the tutorial! Love it!
@TheAIEpiphany
2 years ago
Glad to hear that ^^
Great work. Thank you, Aleksa. I learned a lot. Coming from R, I like the functional approach here. Would be interested to hear about your current opinion about jax, after knowing it better.
Wonderful explanation about vmap function
Great job! Keep it up!😀
Great job. Very nice tutorial
Congrats on your DeepMind job, man (read your post). Nice channel, keep going!
@TheAIEpiphany
2 years ago
Thank you!! 🙏😄
I also open-sourced an accompanying repo here: github.com/gordicaleksa/get-started-with-JAX I recommend opening the notebook in parallel while watching the video so that you can play with and tweak the code as well. Just open the notebook, click the Colab button at the top of the file, and voila! You'll avoid having to set up the Python env, and everything will just work! (You can optionally choose a GPU as an accelerator.)
great video and content, this channel needs more recognition.
@TheAIEpiphany
2 years ago
Thank you! 🥳
50:40 - I would argue here that it's not necessary to pass all the parameters into the function; as long as it doesn't change any of the params, it's OK to use external globals, e.g. for reference tables. This definition (though academically thorough) makes practical application a bit more cumbersome. I believe rule 2 alone is sufficient to make this work: no need to pass a long list of params, just make sure not to update/change anything external inside the function, and treat whatever is not passed in as static. Alternatively, you can re-create the jitted function every time you anticipate that your globals might have changed, effectively re-creating your jit function with the new globals. In some cases that feels much preferable to passing everything in. For instance, you can use all sorts of globals inside it, then just re-create it right before your training loop.
oh, quite impressive series with the perfect explanation
@TheAIEpiphany
2 years ago
Thank you man! 🙏
Very Cool, Thanks
i love jax ... thank you for your work!
@TheAIEpiphany
2 years ago
You're welcome!
It's really good Tutorial!! thx :)
Hi Aleksa, first of all, thank you very much for sharing great content; I learn a lot from you. Could you please explain some upsides of JAX over other frameworks? I really need motivation to get started with JAX. Thank you. Cheers :)
Thank you for the amazing content. Greetings from Spain
@TheAIEpiphany
2 years ago
Gracias y saludos desde Belgrado! 🙏 Tengo muchos amigos en España.
Thank you so much !
JAX = Just After eXecution (related to the tracing behaviour)
JIT = Just In Time (related to the compilation behaviour)
@TheAIEpiphany
2 years ago
Thank you! A couple more people also pointed it out. Makes sense
ty ty ty ty ty for this video
JAX is "Just After eXecution", representing the paradigm of tracing and transforming (grad, vmap, jit, ...) after the first execution.
@TheAIEpiphany
2 years ago
Hmmm, source?
@PhucLe-qs7nx
2 years ago
@@TheAIEpiphany Sorry, I can't remember it now. But it's somewhere in the documentation or a JAX GitHub issue/discussion.
@TheAIEpiphany
2 years ago
@@PhucLe-qs7nx Thanks in any case! One of the other comments mentioned it simply stands for JIT, Autograd, XLA. 😄 That sounds reasonable as well.
great video!
Saw your video retweeted by someone, watched it and subbed, because your content is great :) How often will you be uploading the following Jax vids?
@TheAIEpiphany
2 years ago
Thanks! Next one tomorrow or Thursday.
Thank you for a great content!
@TheAIEpiphany
1 year ago
Thanks!
I am glad that I found this channel
@TheAIEpiphany
2 years ago
Welcome 🚀
Finally the Jax !!!!!!!
Thanks a lot. Keep up the good work. Am I wrong, or should the derivative at 20:15 be (x1*2, x2*2, x3*2)? I mean, you take the gradient with respect to a vector, so you should take the derivative with respect to each variable separately.
@TheAIEpiphany
2 years ago
Of course what did I do? 😂
@adels1388
2 years ago
@@TheAIEpiphany You wrote x1*2 + x2*2 + x3*2. I replaced + with commas :)
@TheAIEpiphany
2 years ago
@@adels1388 My bad 😅 Thanks for noticing, the printed result was correct...
Thank you so much ..
At around 1:10:34, you used static_argnums=(0,) for jit. Wouldn't this extremely slow down the program, as it will have to retrace for all new values of x? Code to reproduce:

def switch(x):
    print("traced")
    if x > 10.:
        return 1.
    else:
        return 0.

jit_switch = jit(switch, static_argnums=(0,))

x = 5
jit_switch(x)
x = 16
jit_switch(x)

# Output:
# traced
# DeviceArray(0., dtype=float32, weak_type=True)
# traced
# DeviceArray(1., dtype=float32, weak_type=True)
Can we say that if we make all arguments static, then it will be as good as normal code without JAX? Thank you for these videos, btw.
Would it be better to use Julia and not have to worry about the gotchas, while still getting the performance?
Can someone explain to me why jnp.sum() is required at 20:50 and why it returns [0, 2, 4]? I would assume it would return 0 + 2 + 4 = 6, like it's described in the comment and like using sum(), but it doesn't; it just returns a vector of the original size with all the elements doubled.
@TheAIEpiphany
2 years ago
I made a mistake. It will return a vector (df/dx1, df/dx2, df/dx3), not the sum. f = x1^2 + x2^2 + x3^2, and grad takes derivatives independently for x1, x2, and x3, since they are all bundled into the first argument of the function f. Hope that makes sense. You can always consult the docs and experiment yourself.
@Gannicus99
2 years ago
Good question. Took me a moment as well. The function gives the sum, whereas grad(function) gives the three gradients, one per parameter, since the output of grad is used for SGD parameter updates: w1, w2, w3 = w1 - lr*df/dw1, w2 - lr*df/dw2, w3 - lr*df/dw3.
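To make this thread concrete, a small sketch: grad requires a scalar-valued function, which is why jnp.sum is needed, and the result is the per-component gradient 2*x, not its sum.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    # f(x) = x1^2 + x2^2 + x3^2; jnp.sum makes the output a scalar,
    # which grad requires
    return jnp.sum(x ** 2)

x = jnp.array([0., 1., 2.])
assert float(f(x)) == 5.0

# grad returns the vector of partial derivatives (df/dx1, df/dx2, df/dx3) = 2*x
g = grad(f)(x)
assert jnp.array_equal(g, jnp.array([0., 2., 4.]))
```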
Thanks for the video. I have a question: how do I run JAX in the browser? (Not in an online notebook.)
amazing content
@TheAIEpiphany
2 years ago
Thank you!
Function purity rule #2 means one cannot use closure variables (variables of a wrapping function)? That's good to know, since JAX states that it is functional but does not support closure use, because jit caching only considers explicitly passed function parameters. Closure variables are hacky, but they are valid Python code. Just not in JAX.
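A small sketch of the closure point (the name make_scaler is made up for illustration): the closed-over value is baked in when the jitted function is first traced, so "updating" it means rebuilding the function, just as with globals.

```python
from jax import jit

def make_scaler(c):
    @jit
    def scale(x):
        # c is closed over; its value is baked in when scale is first traced
        return c * x
    return scale

f = make_scaler(2.0)
assert float(f(3.0)) == 6.0

# to "update" the closed-over value, build (and re-trace) a new function
f = make_scaler(10.0)
assert float(f(3.0)) == 30.0
```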
What font are you using for the Colab notebook?
Hi, are there any resources on how to freeze certain layers of the network for transfer learning?
@TheAIEpiphany
2 years ago
jax.lax.stop_gradient
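To illustrate the reply, here is a minimal sketch using jax.lax.stop_gradient (the loss and parameter names are made up for illustration; in practice, Flax/Haiku users often partition the parameter pytree instead):

```python
import jax
import jax.numpy as jnp
from jax import grad

def loss(frozen_w, trainable_w, x):
    # stop_gradient blocks gradients from flowing into frozen_w
    h = jax.lax.stop_gradient(frozen_w) * x
    return jnp.sum(h * trainable_w)

frozen_w = jnp.array([1., 2.])
trainable_w = jnp.array([3., 4.])
x = jnp.array([1., 1.])

g_frozen = grad(loss, argnums=0)(frozen_w, trainable_w, x)
g_train = grad(loss, argnums=1)(frozen_w, trainable_w, x)

assert jnp.array_equal(g_frozen, jnp.zeros(2))        # frozen: zero gradient
assert jnp.array_equal(g_train, jnp.array([1., 2.]))  # d(loss)/d(trainable) = frozen_w * x
```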
Hi! Thank you for your video. Isn't that very similar to Numba?
bru youre great
Is it possible to use jax with python statsmodel?
Noice
great content, horrible font ;)
@TheAIEpiphany
2 years ago
Hahaha thank you and thank you!
Why not julia lang?
@TheAIEpiphany
2 years ago
Why would this video imply that you shouldn't give Julia a shot? I may make a video on Julia in the future. I personally wanted to learn JAX since I'll be using it in DeepMind.
at kzread.info/dash/bejne/haeo19iMXde8k5M.html, what I noticed was that when I tried, print(grad(f_jit)(2.)), even with the static_argnums.
Can you make a video on how to install JAX in Anaconda, plain Python, and other Python IDEs?
JAX = Just Autograd and Xla
one hour of nothing lol
49:00 jnp.reshape(x, (np.prod(x.shape),)) works.
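A sketch of why that works, as I read it: under jit the shape is a concrete Python tuple (static), so plain NumPy (not jnp) can compute on it; jnp.ravel is an equivalent shortcut.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def flatten(x):
    # x.shape is a concrete Python tuple even under jit,
    # so np.prod(x.shape) is a plain static integer
    return jnp.reshape(x, (np.prod(x.shape),))

x = jnp.arange(6.).reshape(2, 3)
flat = flatten(x)
assert flat.shape == (6,)
assert jnp.array_equal(flat, jnp.ravel(x))
```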