@@ -14,6 +14,7 @@ option(WITH_OPENBLAS "Compile with OpenBLAS backend" OFF)
# Optional compute backends; every one of them is disabled by default.
# WITH_HIP cannot be combined with WITH_CUDA, and WITH_CUDNN requires WITH_CUDA
# (both combinations are rejected with a FATAL_ERROR further down in this file).
option(WITH_RUY   "Compile with Ruy backend"   OFF)
option(WITH_CUDA  "Compile with CUDA backend"  OFF)
option(WITH_CUDNN "Compile with cuDNN backend" OFF)
option(WITH_HIP   "Compile with HIP backend"   OFF)

# Runtime loading, CPU dispatch and profiling switches.
option(CUDA_DYNAMIC_LOADING "Dynamically load CUDA libraries at runtime" OFF)
option(ENABLE_CPU_DISPATCH "Compile CPU kernels for multiple ISA and dispatch at runtime" ON)
option(ENABLE_PROFILING "Compile with profiling support" OFF)
@@ -491,6 +492,9 @@ ELSEIF (ENABLE_ADDRESS_SANITIZER)
491492ENDIF ()
492493
493494if (WITH_CUDA)
495+ if (WITH_HIP)
496+ message (FATAL_ERROR "WITH_CUDA=ON incompatible with WITH_HIP=ON" )
497+ endif ()
494498 find_package (CUDA 11.0 REQUIRED )
495499 list (APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR} /cmake)
496500 if (WITH_TENSOR_PARALLEL)
@@ -679,6 +683,94 @@ if (WITH_CUDA)
679683 )
680684
681685
elseif (WITH_HIP)
  # HIP/ROCm GPU backend.
  #
  # The sources listed in CUDA_SOURCES are compiled with the HIP language and
  # the library is built with both CT2_WITH_CUDA and CT2_USE_HIP defined.

  # Reject unsupported option combinations first, so a bad configuration fails
  # before any toolchain detection or package lookup is attempted.
  if (WITH_TENSOR_PARALLEL)
    message(FATAL_ERROR "WITH_HIP=ON incompatible with WITH_TENSOR_PARALLEL=ON")
  endif ()
  if (WITH_FLASH_ATTN)
    message(FATAL_ERROR "WITH_HIP=ON incompatible with WITH_FLASH_ATTN=ON")
  endif ()

  enable_language(HIP)
  set(CMAKE_CXX_STANDARD_REQUIRED ON)
  message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
  message(STATUS "CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}")

  # Locate the ROCm installation: honour the ROCM_PATH environment variable and
  # fall back to /opt/rocm when it is unset or set to an empty string.
  if (DEFINED ENV{ROCM_PATH} AND NOT "$ENV{ROCM_PATH}" STREQUAL "")
    set(ROCM_PATH "$ENV{ROCM_PATH}")
  else ()
    set(ROCM_PATH "/opt/rocm")
  endif ()
  list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}")

  find_package(hiprand REQUIRED)
  find_package(hipblas REQUIRED)
  find_package(rocprim REQUIRED)
  find_package(rocthrust REQUIRED)
  find_package(hipcub REQUIRED)

  # The AWQ ops (common, CPU and GPU implementations) are left out of the HIP
  # build. NOTE(review): presumably because the AWQ kernels are not ported to
  # HIP -- confirm.
  list(REMOVE_ITEM SOURCES
    src/ops/awq/dequantize.cc
    src/ops/awq/dequantize_cpu.cc
    src/ops/awq/gemm.cc
    src/ops/awq/gemm_cpu.cc
    src/ops/awq/gemv.cc
    src/ops/awq/gemv_cpu.cc
    )
  list(REMOVE_ITEM CUDA_SOURCES
    src/ops/awq/gemm_gpu.cu
    src/ops/awq/gemv_gpu.cu
    src/ops/awq/dequantize_gpu.cu
    )

  # Compile the remaining GPU sources as HIP, and pin the CPU-only
  # implementation files to the regular C++ compiler.
  set_source_files_properties(${CUDA_SOURCES} PROPERTIES LANGUAGE HIP)
  set_source_files_properties(
    src/cpu/allocator.cc
    src/cpu/backend.cc
    src/cpu/cpu_info.cc
    src/cpu/cpu_isa.cc
    src/cpu/kernels.cc
    src/cpu/parallel.cc
    src/cpu/primitives.cc
    src/ops/alibi_add_cpu.cc
    src/ops/bias_add_cpu.cc
    src/ops/concat_split_slide_cpu.cc
    src/ops/conv1d_cpu.cc
    src/ops/dequantize_cpu.cc
    src/ops/gather_cpu.cc
    src/ops/gumbel_max_cpu.cc
    src/ops/layer_norm_cpu.cc
    src/ops/mean_cpu.cc
    src/ops/median_filter_cpu.cc
    src/ops/multinomial_cpu.cc
    src/ops/quantize_cpu.cc
    src/ops/rms_norm_cpu.cc
    src/ops/rotary_cpu.cc
    src/ops/softmax_cpu.cc
    src/ops/tile_cpu.cc
    src/ops/topk_cpu.cc
    src/ops/topp_mask_cpu.cc
    src/ops/nccl_ops_cpu.cc
    PROPERTIES LANGUAGE CXX
    )

  # Kept at directory scope on purpose: the original made the ROCm library
  # directory visible to every target declared in this directory.
  link_directories("${ROCM_PATH}/lib")

  # Directory-scoped definitions (same scope as the previous mix of
  # add_definitions/add_compile_definitions), so every target created in this
  # directory sees them:
  #   CT2_WITH_CUDA / CT2_USE_HIP  - project feature macros for the GPU backend.
  #   __HIP_PLATFORM_AMD__         - HIP platform selection macro.
  #   __HIP_PLATFORM_HCC__         - second HIP platform macro defined by the
  #                                  original branch; NOTE(review): presumably
  #                                  for older ROCm headers -- confirm.
  add_compile_definitions(
    CT2_WITH_CUDA
    CT2_USE_HIP
    __HIP_PLATFORM_AMD__
    __HIP_PLATFORM_HCC__
    )

  add_library(${PROJECT_NAME}
    SHARED
    ${SOURCES}
    ${CUDA_SOURCES}
    )

  # PROJECT_SOURCE_DIR (rather than CMAKE_SOURCE_DIR) keeps these paths valid
  # when this project is consumed through add_subdirectory().
  target_include_directories(${PROJECT_NAME} PRIVATE
    "${PROJECT_SOURCE_DIR}"
    "${PROJECT_SOURCE_DIR}/include"
    "${ROCM_PATH}/include"
    )

  # Link through imported targets only, so a missing ROCm component is reported
  # at configure time instead of as an unresolved -l flag at link time.
  target_link_libraries(${PROJECT_NAME} PRIVATE
    hip::hiprand
    roc::hipblas
    roc::rocprim
    roc::rocthrust
    hip::hipcub
    )

  # The final link step is driven by the C++ toolchain, not the HIP one.
  set_target_properties(${PROJECT_NAME} PROPERTIES LINKER_LANGUAGE CXX)
772+
773+
682774elseif (WITH_CUDNN)
683775 message (FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON" )
684776else ()
0 commit comments