【Hacker News Repost】We fine-tuned Llama 405B on AMD GPUs
-
Title: We fine-tuned Llama 405B on AMD GPUs
Text: Hey HN, we recently fine-tuned the llama3.1 405B model on 8x AMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax

We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).

Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., torch.cuda, or scaled_dot_product_attention, which is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.

Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which the XLA compiler then optimizes before applying hardware-specific optimizations. This clean separation allowed us to run the same LLaMA3 JAX code on both Google TPUs and AMD GPUs with no changes.

Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX, and now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.

We'd love to hear your thoughts on our vision and repo!
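To make the sharding claim concrete, here is a minimal JAX sketch of the declarative sharding style the post describes. The mesh axis name, tensor shapes, and the toy matmul are illustrative assumptions rather than code from the felafax repo; the point is that jax.sharding describes the layout while XLA inserts the backend-specific collectives, so the same script runs on TPU, CUDA, or ROCm devices.

```python
# Minimal sketch (not from the felafax repo): declarative sharding in JAX.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over all visible accelerators (e.g., 8x MI300x, or TPU cores).
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)),
            axis_names=("model",))

# Shard the weight matrix column-wise along the "model" axis; replicate activations.
w = jax.device_put(jnp.ones((4096, 4096), jnp.bfloat16),
                   NamedSharding(mesh, P(None, "model")))
x = jax.device_put(jnp.ones((8, 4096), jnp.bfloat16),
                   NamedSharding(mesh, P()))

@jax.jit
def forward(x, w):
    # The Python code never mentions CUDA or ROCm; XLA lowers this matmul (and
    # any collectives the sharding requires) for whichever backend owns the devices.
    return x @ w

y = forward(x, w)
print(y.sharding)  # the output stays sharded along the "model" axis
```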
Url: https://publish.obsidian.md/felafax/pages/Tune+Llama3+405B+on+AMD+MI300x+(our+journey)
Post by: felarof
Comments:
3abiton: Firstly, great work! I dabbled with AMD GPUs and ROCm support a year ago, and it was obvious AMD was still a long way from catching up with Nvidia. While opting for JAX is an interesting approach, what were the challenges for you in deviating from PyTorch (being the standard library for ML)?
latchkey: Nice work! I was just playing with the inference side of things with 405B myself this weekend [0].

I'm not convinced that 'torch.cuda' is really that bad, since the AMD version of PyTorch just translates that for you. More like a naming problem than anything. The fact is that it is just as easy to grab the rocm:pytorch container as it is the rocm:jax container.

I don't see very many numbers posted. What MFU did you get?

[0] https://x.com/HotAisle/status/1837580046732874026
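For readers unfamiliar with the metric latchkey is asking about: MFU (Model FLOPs Utilization) is the ratio of the FLOPs a training run actually achieves to the hardware's theoretical peak. A rough back-of-the-envelope sketch follows; the throughput and peak-FLOPs figures are placeholders, not measurements from the post.

```python
# Back-of-the-envelope MFU estimate. For a dense decoder-only transformer,
# training FLOPs per token are commonly approximated as 6 * parameter count
# (forward + backward pass).

def mfu(params: float, tokens_per_second: float,
        num_devices: int, peak_flops_per_device: float) -> float:
    """Model FLOPs Utilization = achieved training FLOPs / peak hardware FLOPs."""
    achieved = 6.0 * params * tokens_per_second
    peak = num_devices * peak_flops_per_device
    return achieved / peak

# Example: 405B parameters on 8 GPUs, assuming ~1.3e15 peak dense BF16 FLOPs per
# device and a hypothetical 100 tokens/s of training throughput (both placeholders).
print(f"MFU: {mfu(405e9, 100, 8, 1.3e15):.1%}")
```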
yeahwhatever10: Where is the performance data?
manojlds: Thought this was a post from Obsidian at first. Why haven't they done the GitHub.com vs GitHub.io thing yet?