Machine Learning with JAX - From Hero to HeroPro+ | Tutorial #2

❤️ Become The AI Epiphany Patreon ❤️
/ theaiepiphany
👨‍👩‍👧‍👦 Join our Discord community 👨‍👩‍👧‍👦
/ discord
This is the second video in the JAX series of tutorials.
JAX is a powerful and increasingly popular ML library built by the Google Research team. The two most popular deep learning frameworks built on top of JAX are Haiku (DeepMind) and Flax (Google Research).
In this video, we continue and cover the additional components needed to train complex ML models (such as neural networks) on multiple machines!
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
✅ My GitHub repo: github.com/gordicaleksa/get-s...
✅ JAX GitHub: github.com/google/jax
✅ JAX docs: jax.readthedocs.io/
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
⌚️ Timetable:
00:00:00 My get started with JAX repo
00:01:25 Stateful to stateless conversion
00:11:00 PyTrees in depth
00:17:45 Training an MLP in pure JAX
00:27:30 Custom PyTrees
00:32:55 Parallelism in JAX (TPU example)
00:40:05 Communication between devices
00:46:05 value_and_grad and has_aux
00:48:45 Training an ML model on multiple machines
00:58:50 stop_gradient, per-example gradients
01:06:45 Implementing MAML in 3 lines
01:08:35 Outro
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💰 BECOME A PATREON OF THE AI EPIPHANY ❤️
If these videos, GitHub projects, and blogs help you,
consider helping me out by supporting me on Patreon!
The AI Epiphany - / theaiepiphany
One-time donation - www.paypal.com/paypalme/theai...
Huge thank you to these AI Epiphany patrons:
Eli Mahler
Petar Veličković
Bartłomiej Danek
Zvonimir Sabljic
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
💼 LinkedIn - / aleksagordic
🐦 Twitter - / gordic_aleksa
👨‍👩‍👧‍👦 Discord - / discord
📺 YouTube - / theaiepiphany
📚 Medium - / gordicaleksa
💻 GitHub - github.com/gordicaleksa
📢 AI Newsletter - aiepiphany.substack.com/
▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬
#jax #machinelearning #framework

Comments: 34

  • @akashraut3581 · 2 years ago

    Next video we go from HeroPro+ to ultraHeroProUltimateMaster+

  • @TheAIEpiphany · 2 years ago

    Code is here: github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_2_JAX_HeroPro%2B_Colab.ipynb. Just click the Colab button and you're ready to play with the code yourself.

  • @varunsai9736 · 2 years ago

    Thanks for your effort, it will be of great help to everyone in the community.

  • @TheAIEpiphany · 2 years ago

    @varunsai9736 I certainly hope so! Thank you. 🙏

  • @sprmsv · 4 months ago

    These videos are really great and helpful. Right to the point, no wasted time. Thanks!!

  • @mariolinovalencia7776 · 1 year ago

    Such a good job. Great video

  • @tauhidkhan4849 · 2 years ago

    Thanks for this JAX series ❤️ I am planning to implement a CV research paper in JAX and Flax. It will be of great help, thanks 👍

  • @TheAIEpiphany · 2 years ago

    Awesome! 🙏 Flax video is coming soon as well! 😄

  • @jimmyweber3166 · 11 months ago

    [Combination of gradient loss across devices] Hello guys! Firstly, thank you so much for the amazing tutorials @TheAIEpiphany! Secondly, I'd like to clarify the mathematics behind combining the gradients of the loss across multiple devices at 55:34. The question arises: is it correct to compute the gradient as the average of the gradients from different devices? In other words, will it give the same gradient as if we were doing everything on one device?

    The answer is YES, it is correct, but only if the loss is defined as a weighted sum over the samples. This follows from the fact that the gradient of a weighted sum equals the weighted sum of the gradients. In this context, the loss is a mean across samples (or batches), which is a weighted sum. The same principle also applies to the cross-entropy loss.

    Additionally, the batch sizes across the devices should be the same. Otherwise the result would not be a plain mean but a weighted sum, with each device's weight equal to the normalised batch size allocated to that device.

    Hope my comment is clear and demystifies some questions one might have wondered about :) PS: For anyone who did not follow, the conclusion is: "it is good to do as @TheAIEpiphany is doing" (because we are dealing with MSE/cross-entropy and the batch size is the same across devices).
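
    A minimal sketch of the gradient averaging described above, assuming an MSE loss (a mean over the local mini-batch) and equal per-device batch sizes; loss_fn, update, and the learning rate are illustrative placeholders, not the notebook's exact code:

        import functools
        import jax
        import jax.numpy as jnp

        def loss_fn(params, x, y):
            # per-device loss is a mean over the local mini-batch
            pred = x @ params["w"] + params["b"]
            return jnp.mean((pred - y) ** 2)

        @functools.partial(jax.pmap, axis_name="devices")
        def update(params, x, y):
            loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
            # averaging across devices reproduces the single-device gradient,
            # because the global loss is the mean of the per-device means
            grads = jax.lax.pmean(grads, axis_name="devices")
            loss = jax.lax.pmean(loss, axis_name="devices")
            new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
            return new_params, loss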

  • @eriknorlander5814 · 2 years ago

    You are an inspiration, sir! Thanks for these videos!

  • @TheAIEpiphany · 2 years ago

    Thank you Erik!

  • @objectobjectobject4707 · 2 years ago

    Great video! I'm still at the beginning, working through the tasks you recommended for ML beginners on MNIST, but this is definitely one of my favourite ML channels :)

  • @TheAIEpiphany · 2 years ago

    Haha, thanks a lot! 😄 Keep it up

  • @jonathansum9084 · 2 years ago

    You are very nice! Thank you for your video :D

  • @mathmo · 1 year ago

    Great videos, Aleksa! I found the names of the x and y arguments in the MLP forward function confusing, since they are really batches of xs and ys. You could have used vmap there instead of writing it in already-batched form, but I guess it's a good exercise for your viewers to rewrite it in unbatched form and apply vmap :)
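
    A minimal sketch of the unbatched-then-vmapped alternative the comment suggests; forward_single and the (w, b) parameter layout are illustrative, not the notebook's exact MLP code:

        import jax

        def forward_single(params, x):
            # forward pass for a single example; params is a list of (w, b) pairs
            for w, b in params[:-1]:
                x = jax.nn.relu(w @ x + b)
            w_last, b_last = params[-1]
            return w_last @ x + b_last

        # map over the leading (batch) axis of x only; params are shared across examples
        forward_batched = jax.vmap(forward_single, in_axes=(None, 0))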

  • @nickkonovalchuk9280 · 2 years ago

    Both useful knowledge and memes in a single video!

  • @juansensio8323 · 2 years ago

    Great series of tutorials, congrats! It would be nice to see a performance comparison between JAX and PyTorch on some real-world use case (GPU and TPU) :)

  • @TheAIEpiphany · 2 years ago

    That video is in the pipeline as well ^^

  • @juansensio8323 · 2 years ago

    @TheAIEpiphany Really looking forward to it!

  • @lidiias5976 · 5 months ago

    Also, will there be a tutorial on JAXopt? It would be highly appreciated! Thanks for your videos.

  • @lidiias5976 · 5 months ago

    jax.tree_multimap was deprecated in JAX 0.3.5 and removed in JAX 0.3.16. What can we use to replace this function in the "Training an MLP in pure JAX" part of the video?
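
    For reference, in newer JAX versions jax.tree_util.tree_map accepts multiple pytrees, which covers the old tree_multimap use case; the params/grads pytrees and the learning rate below are illustrative placeholders:

        import jax

        params = {"w": 1.0, "b": 2.0}  # hypothetical parameter pytree
        grads = {"w": 0.1, "b": 0.2}   # hypothetical gradient pytree

        # multi-tree mapping, formerly done with jax.tree_multimap
        new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)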

  • @thomashirtz · 2 years ago

    8:53 Why do we want to return the state if it doesn't change? (In general, I mean.) Is it just good practice, so that you don't need to think about whether it may or may not change and can always return it?
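
    For context, a minimal sketch of the convention the question refers to: a stateless function always returns (output, new_state), even when the state is unchanged, so callers thread state uniformly. The counter example is illustrative only:

        def counter_increment(state):
            # the state actually changes here
            new_state = state + 1
            return new_state, new_state  # (output, new_state)

        def counter_peek(state):
            # the state does not change, but it is still returned for a uniform interface
            return state, state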

  • @arturtoshev3765 · 2 years ago

    In the middle of the notebook, I saw the comment "# notice how we do jit only at the highest level - XLA will have plenty of space to optimize". Do you have a reference on when to jit only at the highest level and when to jit individual nested functions, and what the advantages/risks of each approach are? I used to jit every single function until now, so I'm curious what I can gain from a single high-level jit.

  • @TheAIEpiphany · 2 years ago

    Simple: always jit at the highest level possible. If for some reason you have to avoid jitting some code, go a level lower. And repeat.

  • @jimmyweber3166 · 11 months ago

    To be precise: this holds only if the function you jit has parts that can actually benefit from compilation (e.g. be fused or parallelised). For example, if you define a train function with a for loop over the number of epochs, you should not jit it! You should rather jit one level lower (here, the update function). The epochs cannot be parallelised (one epoch has to finish before the next one starts), and if you were to jit the train function, compilation would take much longer because the whole epoch loop gets traced into the compiled function. But @TheAIEpiphany explains "why you should not jit everything" much better than I can in his first tutorial :)
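
    A minimal sketch of the pattern both replies describe: jit the per-step update and keep the epoch loop in plain Python; loss_fn, the data format, and the learning rate are placeholders:

        import jax

        @jax.jit
        def update(params, x, y):
            # compiled once (per input shape), then reused on every step
            grads = jax.grad(loss_fn)(params, x, y)  # loss_fn assumed defined elsewhere
            return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

        def train(params, data, num_epochs):
            # plain Python loop: jitting this would trace and unroll every iteration
            for _ in range(num_epochs):
                for x, y in data:
                    params = update(params, x, y)
            return params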

  • @santiagogomezhernandez9 · 2 years ago

    Nice video. Do you think the research community will embrace JAX as they did with PyTorch?

  • @TheAIEpiphany · 2 years ago

    It depends on many factors; I guess we'll have to wait and see. If Google starts pushing JAX more than TF, then sure, I'd be much more confident. But even as things stand, I see JAX getting more and more love from the research community.

  • @promethesured · 2 months ago

    So the parallelism you demoed with pmap was data parallelism, correct? Replicating the whole model across all the devices, sending a different batch to each device, and then collecting the averaged model back on the host device after the forward and backward passes? Am I understanding that correctly?
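
    For reference, a minimal sketch of the data-parallel setup the question describes (replicate the parameters, shard the batch); params, batch_x, and batch_y are illustrative placeholders:

        import jax
        import jax.numpy as jnp

        n_devices = jax.local_device_count()

        # replicate the model parameters onto every device
        replicated_params = jax.tree_util.tree_map(
            lambda p: jnp.stack([p] * n_devices), params)  # params assumed defined elsewhere

        # split the global batch into one shard per device: (n_devices, per_device_batch, ...)
        shard = lambda a: a.reshape(n_devices, -1, *a.shape[1:])
        sharded_x, sharded_y = shard(batch_x), shard(batch_y)  # batch assumed defined elsewhere

        # a pmapped update function (as sketched further above) then runs on all shards in parallel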

  • @cheese-power · 11 months ago

    4:29 Why is it called key and subkey, and not subkey and subsubkey? Aren't the two keys on the same level of the descendant tree?
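
    For reference, the usual split pattern from the JAX docs: both outputs of jax.random.split are fresh keys at the same level; by convention the one carried forward for future splits is called key and the one consumed right away is called subkey:

        import jax

        key = jax.random.PRNGKey(0)
        key, subkey = jax.random.split(key)  # two new keys derived from the old one
        x = jax.random.normal(subkey, (3,))  # consume subkey; keep key for the next split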

  • @1potdish271 · 2 years ago

    What problem is JAX solving compared to PyTorch?

  • @jawadmansoor6064 · 2 years ago

    But how can you train a model in JAX? If you have to set everything up from scratch then I think it is not very useful; I am sure PyTorch/TF are not that far behind in terms of speed.

  • @TheAIEpiphany · 2 years ago

    Check out the 3rd video for building models from scratch in pure JAX. As for frameworks, there are Flax and Haiku; I'll cover them next (they are the equivalents of TF/PyTorch).

  • @tempdeltavalue · 2 years ago

    pmap, vmap - 🤯

  • @thomashirtz · 2 years ago

    I wish you had spent a few more minutes on the last example 😔