โครงการ PyTorch (ที่ปัจจุบัน Meta ยกให้ Linux Foundation ไปดูแลต่อแล้ว) เปิดตัว PyTorch 2.0 เวอร์ชันอัพเกรดครั้งใหญ่ที่รอคอยกันมานาน และทดสอบแบบพรีวิวมาสักระยะหนึ่งแล้ว
ฟีเจอร์ใหม่ที่สำคัญที่สุดคือ torch.compile ที่เป็น API หลักตัวใหม่ของ PyTorch ที่ช่วยเพิ่มประสิทธิภาพของโมเดล (เฉลี่ย 21% บนทศนิยม Float32 และ 51% บนทศนิยม AMP/Float16) และจะเป็นแกนหลักของ PyTorch ซีรีส์ 2.x ต่อไปในอนาคต ตอนนี้ torch.compile ยังเป็น "ตัวเลือก" (optional) เลือกใช้ได้ตามต้องการ และเข้ากันได้กับโค้ดเก่า 100%
เบื้องหลัง torch.compile ประกอบด้วยเทคโนโลยีหลายอย่าง ได้แก่ TorchDynamo, AOTAutograd, PrimTorch, TorchInductor ที่เขียนขึ้นมาใหม่ด้วย Python แทนชิ้นส่วนเดิมที่เป็น C++ (PyTorch บอกว่าจะพยายามมุ่งไปในทิศที่เป็น Python มากขึ้น) ชิ้นส่วนหนึ่งที่น่าสนใจคือ TorchInductor เป็น deep learning compiler ที่สร้างโค้ดขึ้นมาให้ทำงานได้บนสถาปัตยกรรมหลายแบบ โดยการทำงานบนจีพียู NVIDIA/AMD จะใช้ผ่าน OpenAI Triton อีกที
ของใหม่อีกอย่างคือ Accelerated Transformers หรือชื่อเดิม Better Transformer เป็นการเขียน Transformer API เวอร์ชันใหม่ให้มีประสิทธิภาพดีขึ้นกว่าเดิม รองรับรูปแบบการใช้งานใหม่ๆ และออกแบบสถาปัตยกรรมให้รองรับเคอร์เนล (ของ PyTorch) แบบคัสตอมด้วย
รายละเอียดของฟีเจอร์ใหม่ทั้งหมดอ่านได้จาก PyTorch 2.0
ที่มา - PyTorch
Comments
Flash Attention integrate เข้ามาใน stable ซักที รอมานาน