    Performance of Flash Attention and torch.compile()

    Posted by RobinDong on 2024-03-01 05:49:47

    I am building a small repo of multi-modal models (CLIP, ALBEF, BLIP, etc.). The GPT code is mostly adapted from nanoGPT. Along the way I became curious about how much performance “Flash Attention” and “torch.compile()” actually buy.
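    For context, a nanoGPT-style attention block computes softmax(QK^T / sqrt(d_k)) V explicitly. The code below is a minimal sketch of what such a baseline roughly looks like; the class name and constructor arguments are illustrative, not copied from the actual repo:

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CausalSelfAttention(nn.Module):
        """Manual (non-flash) causal self-attention, nanoGPT style (illustrative)."""
        def __init__(self, n_embd, n_head, block_size, dropout=0.1):
            super().__init__()
            assert n_embd % n_head == 0
            self.n_head = n_head
            self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # fused q/k/v projection
            self.c_proj = nn.Linear(n_embd, n_embd)
            self.attn_dropout = nn.Dropout(dropout)
            # causal mask: each position may only attend to itself and the past
            mask = torch.tril(torch.ones(block_size, block_size))
            self.register_buffer("mask", mask.view(1, 1, block_size, block_size))

        def forward(self, x):
            B, T, C = x.size()
            q, k, v = self.c_attn(x).split(C, dim=2)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            # materializes the full (T, T) attention matrix, which is exactly what Flash Attention avoids
            att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
            return self.c_proj(y)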

    The metrics with my original code (w/o Flash Attention, w/o torch.compile()):

    [100] loss: 4.0315 time 23.7708
    [200] loss: 4.0020 time 23.9010
    [300] loss: 3.8115 time 23.9407
    [400] loss: 3.7021 time 23.9785
    [500] loss: 3.6626 time 24.0076
    [600] loss: 3.7109 time 24.0060

    The metrics after adding Flash Attention:

    [100] loss: 4.1204 time 23.0655
    [200] loss: 3.8950 time 23.2243
    [300] loss: 3.9116 time 23.2714
    [400] loss: 3.7837 time 23.2864
    [500] loss: 3.8313 time 23.2993
    [600] loss: 3.9138 time 23.3255
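    Enabling Flash Attention in PyTorch 2.x amounts to replacing the manual attention math with torch.nn.functional.scaled_dot_product_attention, which dispatches to a fused flash kernel when the hardware and dtypes allow it. A sketch of that drop-in change, mirroring the hypothetical class above:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FlashCausalSelfAttention(nn.Module):
        """Causal self-attention routed through the fused SDPA / Flash Attention kernel."""
        def __init__(self, n_embd, n_head, dropout=0.1):
            super().__init__()
            assert n_embd % n_head == 0
            self.n_head = n_head
            self.dropout = dropout
            self.c_attn = nn.Linear(n_embd, 3 * n_embd)
            self.c_proj = nn.Linear(n_embd, n_embd)

        def forward(self, x):
            B, T, C = x.size()
            q, k, v = self.c_attn(x).split(C, dim=2)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            # fused kernel: the (T, T) attention matrix is never materialized in HBM
            y = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
            y = y.transpose(1, 2).contiguous().view(B, T, C)
            return self.c_proj(y)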

    The metrics after adding Flash Attention and torch.compile():

    [100] loss: 3.9969 time 14.8842
    [200] loss: 3.8506 time 15.0004
    [300] loss: 3.8702 time 15.0050
    [400] loss: 3.7977 time 15.0061
    [500] loss: 3.7374 time 15.0492
    [600] loss: 3.6589 time 15.0661
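    torch.compile() itself is a one-line change in the training script. Below is a minimal, self-contained sketch of where it goes; a toy model and synthetic data stand in for the real GPT and data loader:

    import torch
    import torch.nn as nn

    # toy stand-in model just to show where torch.compile() is applied;
    # in the real repo this would be the GPT built from the attention blocks above
    model = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256)).cuda()
    model = torch.compile(model)  # TorchDynamo captures the graph, TorchInductor fuses kernels

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    x = torch.randn(64, 256, device="cuda")
    target = torch.randn(64, 256, device="cuda")

    for step in range(10):
        # the first iterations are slower while compilation happens;
        # later ones run the generated fused kernels
        loss = nn.functional.mse_loss(model(x), target)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()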

    It seems that “torch.compile()” is a much bigger win than “Flash Attention” alone: Flash Attention brought the reported time from roughly 24.0 down to about 23.3 (around 3%), while adding torch.compile() brought it down to about 15.0 (around 37%), presumably because it fuses and optimizes kernels across the whole model rather than just the attention operation.


