Skip to content

Commit 5866e41

Browse files
committed
Add FLASH_ATTN_HDIMS option to limit kernel compilation
In many applications, model head dimensions are known in advance, so it is possible to opt out of compiling kernels that will never be used, regardless of model choice. Signed-off-by: Tin Švagelj <tin.svagelj@live.com>
1 parent 226c95d commit 5866e41

2 files changed

Lines changed: 113 additions & 83 deletions

File tree

CMakeLists.txt

Lines changed: 30 additions & 67 deletions
Original file line number · Diff line number · Diff line change
@@ -23,6 +23,7 @@ option(BUILD_TESTS "Compile the tests" OFF)
2323
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
2424
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
2525
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
26+
set(FLASH_ATTN_HDIMS "" CACHE STRING "Head dimensions to compile for flash attention (e.g. '32;64'). Empty means all.")
2627
option(ENABLE_ADDRESS_SANITIZER "ASAN" OFF)
2728

2829
MESSAGE(STATUS "Compiler Id: ${CMAKE_CXX_COMPILER_ID}")
@@ -606,74 +607,36 @@ if (WITH_CUDA)
606607
endif()
607608
if (WITH_FLASH_ATTN)
608609
add_definitions(-DCT2_WITH_FLASH_ATTN)
609-
list(APPEND SOURCES
610-
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
611-
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
612-
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
613-
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
614-
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
615-
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
616-
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
617-
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
618-
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
619-
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
620-
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
621-
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
622-
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
623-
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
624-
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
625-
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
626-
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
627-
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
628-
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
629-
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
630-
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
631-
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
632-
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
633-
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
634-
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
635-
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
636-
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
637-
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
638-
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
639-
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
640-
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
641-
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
642-
)
643610

644-
set_source_files_properties(
645-
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
646-
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
647-
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
648-
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
649-
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
650-
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
651-
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
652-
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
653-
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
654-
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
655-
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
656-
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
657-
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
658-
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
659-
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
660-
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
661-
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
662-
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
663-
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
664-
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
665-
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
666-
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
667-
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
668-
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
669-
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
670-
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
671-
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
672-
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
673-
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
674-
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
675-
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
676-
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
611+
set(_ALL_FLASH_HDIMS 32 64 96 128 160 192 224 256)
612+
if(FLASH_ATTN_HDIMS)
613+
set(_FLASH_HDIMS ${FLASH_ATTN_HDIMS})
614+
else()
615+
set(_FLASH_HDIMS ${_ALL_FLASH_HDIMS})
616+
endif()
617+
618+
message(STATUS "Flash attention head dimensions: ${_FLASH_HDIMS}")
619+
620+
# Define which hdims are compiled so HEADDIM_SWITCH can limit instantiation
621+
foreach(_hdim ${_FLASH_HDIMS})
622+
add_definitions(-DCT2_FLASH_ATTN_HDIM_${_hdim})
623+
endforeach()
624+
if(FLASH_ATTN_HDIMS)
625+
add_definitions(-DCT2_FLASH_ATTN_HDIMS_RESTRICTED)
626+
endif()
627+
628+
set(_FLASH_ATTN_SOURCES "")
629+
foreach(_hdim ${_FLASH_HDIMS})
630+
list(APPEND _FLASH_ATTN_SOURCES
631+
src/ops/flash-attention/flash_fwd_hdim${_hdim}_bf16_sm80.cu
632+
src/ops/flash-attention/flash_fwd_hdim${_hdim}_fp16_sm80.cu
633+
src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_bf16_sm80.cu
634+
src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_fp16_sm80.cu
635+
)
636+
endforeach()
637+
638+
list(APPEND SOURCES ${_FLASH_ATTN_SOURCES})
639+
set_source_files_properties(${_FLASH_ATTN_SOURCES}
677640
PROPERTIES COMPILE_FLAGS "--use_fast_math")
678641
endif()
679642
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)

include/ctranslate2/ops/flash-attention/static_switch.h

Lines changed: 83 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -78,31 +78,98 @@
7878
} \
7979
}()
8080

81+
// When FLASH_ATTN_HDIMS is restricted via cmake, only instantiate selected
82+
// head dimensions. Others throw at runtime instead of generating link-time
83+
// symbol references. CT2_FLASH_ATTN_HDIM_N is defined per compiled hdim.
84+
#define _HEADDIM_DISPATCH(DIM, ...) \
85+
constexpr static int kHeadDim = DIM; \
86+
return __VA_ARGS__();
87+
88+
#define _HEADDIM_UNSUPPORTED(DIM) \
89+
throw std::runtime_error( \
90+
"Flash attention head dim " #DIM " not compiled. " \
91+
"Rebuild CTranslate2 with FLASH_ATTN_HDIMS including " #DIM);
92+
93+
#ifndef CT2_FLASH_ATTN_HDIM_32
94+
#define _HEADDIM_CASE_32(...) _HEADDIM_UNSUPPORTED(32)
95+
#else
96+
#define _HEADDIM_CASE_32(...) _HEADDIM_DISPATCH(32, __VA_ARGS__)
97+
#endif
98+
#ifndef CT2_FLASH_ATTN_HDIM_64
99+
#define _HEADDIM_CASE_64(...) _HEADDIM_UNSUPPORTED(64)
100+
#else
101+
#define _HEADDIM_CASE_64(...) _HEADDIM_DISPATCH(64, __VA_ARGS__)
102+
#endif
103+
#ifndef CT2_FLASH_ATTN_HDIM_96
104+
#define _HEADDIM_CASE_96(...) _HEADDIM_UNSUPPORTED(96)
105+
#else
106+
#define _HEADDIM_CASE_96(...) _HEADDIM_DISPATCH(96, __VA_ARGS__)
107+
#endif
108+
#ifndef CT2_FLASH_ATTN_HDIM_128
109+
#define _HEADDIM_CASE_128(...) _HEADDIM_UNSUPPORTED(128)
110+
#else
111+
#define _HEADDIM_CASE_128(...) _HEADDIM_DISPATCH(128, __VA_ARGS__)
112+
#endif
113+
#ifndef CT2_FLASH_ATTN_HDIM_160
114+
#define _HEADDIM_CASE_160(...) _HEADDIM_UNSUPPORTED(160)
115+
#else
116+
#define _HEADDIM_CASE_160(...) _HEADDIM_DISPATCH(160, __VA_ARGS__)
117+
#endif
118+
#ifndef CT2_FLASH_ATTN_HDIM_192
119+
#define _HEADDIM_CASE_192(...) _HEADDIM_UNSUPPORTED(192)
120+
#else
121+
#define _HEADDIM_CASE_192(...) _HEADDIM_DISPATCH(192, __VA_ARGS__)
122+
#endif
123+
#ifndef CT2_FLASH_ATTN_HDIM_224
124+
#define _HEADDIM_CASE_224(...) _HEADDIM_UNSUPPORTED(224)
125+
#else
126+
#define _HEADDIM_CASE_224(...) _HEADDIM_DISPATCH(224, __VA_ARGS__)
127+
#endif
128+
#ifndef CT2_FLASH_ATTN_HDIM_256
129+
#define _HEADDIM_CASE_256(...) _HEADDIM_UNSUPPORTED(256)
130+
#else
131+
#define _HEADDIM_CASE_256(...) _HEADDIM_DISPATCH(256, __VA_ARGS__)
132+
#endif
133+
134+
// The CMake above defines CT2_FLASH_ATTN_HDIM_* for every compiled hdim (all of
135+
// them when FLASH_ATTN_HDIMS is unset). This fallback covers builds that define
136+
// none of those macros: unless restriction was requested, dispatch every hdim.
137+
#ifndef CT2_FLASH_ATTN_HDIMS_RESTRICTED
138+
#undef _HEADDIM_CASE_32
139+
#undef _HEADDIM_CASE_64
140+
#undef _HEADDIM_CASE_96
141+
#undef _HEADDIM_CASE_128
142+
#undef _HEADDIM_CASE_160
143+
#undef _HEADDIM_CASE_192
144+
#undef _HEADDIM_CASE_224
145+
#undef _HEADDIM_CASE_256
146+
#define _HEADDIM_CASE_32(...) _HEADDIM_DISPATCH(32, __VA_ARGS__)
147+
#define _HEADDIM_CASE_64(...) _HEADDIM_DISPATCH(64, __VA_ARGS__)
148+
#define _HEADDIM_CASE_96(...) _HEADDIM_DISPATCH(96, __VA_ARGS__)
149+
#define _HEADDIM_CASE_128(...) _HEADDIM_DISPATCH(128, __VA_ARGS__)
150+
#define _HEADDIM_CASE_160(...) _HEADDIM_DISPATCH(160, __VA_ARGS__)
151+
#define _HEADDIM_CASE_192(...) _HEADDIM_DISPATCH(192, __VA_ARGS__)
152+
#define _HEADDIM_CASE_224(...) _HEADDIM_DISPATCH(224, __VA_ARGS__)
153+
#define _HEADDIM_CASE_256(...) _HEADDIM_DISPATCH(256, __VA_ARGS__)
154+
#endif
155+
81156
#define HEADDIM_SWITCH(HEADDIM, ...) \
82157
[&] { \
83158
if (HEADDIM <= 32) { \
84-
constexpr static int kHeadDim = 32; \
85-
return __VA_ARGS__(); \
159+
_HEADDIM_CASE_32(__VA_ARGS__) \
86160
} else if (HEADDIM <= 64) { \
87-
constexpr static int kHeadDim = 64; \
88-
return __VA_ARGS__(); \
161+
_HEADDIM_CASE_64(__VA_ARGS__) \
89162
} else if (HEADDIM <= 96) { \
90-
constexpr static int kHeadDim = 96; \
91-
return __VA_ARGS__(); \
163+
_HEADDIM_CASE_96(__VA_ARGS__) \
92164
} else if (HEADDIM <= 128) { \
93-
constexpr static int kHeadDim = 128; \
94-
return __VA_ARGS__(); \
165+
_HEADDIM_CASE_128(__VA_ARGS__) \
95166
} else if (HEADDIM <= 160) { \
96-
constexpr static int kHeadDim = 160; \
97-
return __VA_ARGS__(); \
167+
_HEADDIM_CASE_160(__VA_ARGS__) \
98168
} else if (HEADDIM <= 192) { \
99-
constexpr static int kHeadDim = 192; \
100-
return __VA_ARGS__(); \
169+
_HEADDIM_CASE_192(__VA_ARGS__) \
101170
} else if (HEADDIM <= 224) { \
102-
constexpr static int kHeadDim = 224; \
103-
return __VA_ARGS__(); \
171+
_HEADDIM_CASE_224(__VA_ARGS__) \
104172
} else if (HEADDIM <= 256) { \
105-
constexpr static int kHeadDim = 256; \
106-
return __VA_ARGS__(); \
173+
_HEADDIM_CASE_256(__VA_ARGS__) \
107174
} \
108175
}()

0 commit comments

Comments
 (0)