Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions codon/cir/llvm/gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,16 +945,19 @@ void patchPTXVar(llvm::Module *M, llvm::GlobalValue *ptxVar,
}
} // namespace

void applyGPUTransformations(llvm::Module *M, const std::string &ptxFilename) {
llvm::LLVMContext &context = M->getContext();
/// Creates a copy of the given module that is set up for the GPU target.
///
/// The copy is produced with llvm::CloneModule, then given the normalized
/// GPU_TRIPLE as its target triple and GPU_DL as its data layout. When
/// fast-math is enabled, the "nvvm-reflect-ftz" module flag is additionally
/// set to 1 with Override behavior.
/// NOTE(review): the flag is presumably read by the NVPTX backend's
/// nvvm-reflect handling to select flush-to-zero code paths -- confirm.
///
/// @param M module to clone; dereferenced unconditionally, so must not be null
/// @return the cloned, GPU-targeted module; the input module is not modified
///         by any statement in this function
std::unique_ptr<llvm::Module> prepareGPUmodule(llvm::Module *M) {
  auto gpuModule = llvm::CloneModule(*M);

  const auto gpuTriple = llvm::Triple::normalize(GPU_TRIPLE);
  gpuModule->setTargetTriple(gpuTriple);
  gpuModule->setDataLayout(GPU_DL);

  if (!isFastMathOn())
    return gpuModule;

  gpuModule->addModuleFlag(llvm::Module::ModFlagBehavior::Override,
                           "nvvm-reflect-ftz", 1);
  return gpuModule;
}

void applyGPUTransformations(llvm::Module *M, std::unique_ptr<llvm::Module> clone, const std::string &ptxFilename) {
llvm::LLVMContext &context = M->getContext();
llvm::NamedMDNode *nvvmAnno = clone->getOrInsertNamedMetadata("nvvm.annotations");
std::vector<llvm::Function *> kernelCandidates;
std::vector<llvm::GlobalValue *> kernels;
Expand Down
4 changes: 3 additions & 1 deletion codon/cir/llvm/gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ namespace ir {
/// annotation)
/// @param ptxFilename Filename for output PTX code; empty to use filename based on
/// module
void applyGPUTransformations(llvm::Module *module, const std::string &ptxFilename = "");

/// Clones the given module and configures the clone for the GPU target
/// (GPU target triple and data layout).
/// @param module LLVM module to clone
/// @return the cloned module, owned by the caller
std::unique_ptr<llvm::Module> prepareGPUmodule(llvm::Module *module);
/// Applies GPU transformations using a previously prepared GPU module.
/// @param module LLVM module
/// @param clone GPU-targeted module; taken by value, so ownership is
///              transferred to this function
/// @param ptxFilename Filename for output PTX code; defaults to the empty
///                    string
void applyGPUTransformations(llvm::Module *module, std::unique_ptr<llvm::Module> clone, const std::string &ptxFilename = "");

} // namespace ir
} // namespace codon
81 changes: 54 additions & 27 deletions codon/cir/llvm/optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -995,28 +995,9 @@ llvm::cl::opt<bool>
llvm::cl::desc("Disable architecture-specific optimizations"),
llvm::cl::init(false));

void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit,
PluginManager *plugins) {
applyDebugTransformations(module, debug, jit);
applyFastMathTransformations(module);

llvm::LoopAnalysisManager lam;
llvm::FunctionAnalysisManager fam;
llvm::CGSCCAnalysisManager cgam;
llvm::ModuleAnalysisManager mam;
auto machine = getTargetMachine(module, /*setFunctionAttributes=*/true);
llvm::PassBuilder pb(machine.get());

llvm::Triple moduleTriple(module->getTargetTriple());
llvm::TargetLibraryInfoImpl tlii(moduleTriple);
fam.registerPass([&] { return llvm::TargetLibraryAnalysis(tlii); });

pb.registerModuleAnalyses(mam);
pb.registerCGSCCAnalyses(cgam);
pb.registerFunctionAnalyses(fam);
pb.registerLoopAnalyses(lam);
pb.crossRegisterProxies(lam, fam, cgam, mam);

/// Registers Codon's extra LLVM passes on the given pass builder.
/// @param pb pass builder that receives the extension-point callbacks
/// @param debug forwarded to each plugin's addLLVMPasses
/// @param plugins plugin manager whose plugins may add passes; may be null
/// @param includeNative if false, addNativeLLVMPasses is never called
/// @param includePlugins if false, plugin passes are never added
void registerCodonLLVMOptimizationPasses(llvm::PassBuilder &pb, bool debug,
PluginManager *plugins, bool includeNative,
bool includePlugins) {
// NOTE(review): the callbacks below capture by reference ([&]). They are
// stored in `pb` and presumably invoked after this function has returned, so
// they must not come to depend on this function's parameters or locals --
// confirm the callback bodies use only their own arguments.
pb.registerLateLoopOptimizationsEPCallback(
[&](llvm::LoopPassManager &pm, llvm::OptimizationLevel opt) {
if (opt.isOptimizingForSpeed())
Expand All @@ -1035,14 +1016,43 @@ void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit,
}
});

// Native passes need both the DisableNative option to be off and the caller
// to have requested them.
if (!DisableNative)
if (!DisableNative && includeNative)
addNativeLLVMPasses(&pb);

// Plugin passes are added only on request and only when a plugin manager
// was supplied.
if (plugins) {
if (includePlugins && plugins) {
for (auto *plugin : *plugins) {
plugin->dsl->addLLVMPasses(&pb, debug);
}
}
}

void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit,
PluginManager *plugins, bool includeNative,
bool includePlugins) {
applyDebugTransformations(module, debug, jit);
applyFastMathTransformations(module);

llvm::LoopAnalysisManager lam;
llvm::FunctionAnalysisManager fam;
llvm::CGSCCAnalysisManager cgam;
llvm::ModuleAnalysisManager mam;
auto machine =
includeNative ? getTargetMachine(module, /*setFunctionAttributes=*/true)
: std::unique_ptr<llvm::TargetMachine>();
llvm::PassBuilder pb(machine.get());

llvm::Triple moduleTriple(module->getTargetTriple());
llvm::TargetLibraryInfoImpl tlii(moduleTriple);
fam.registerPass([&] { return llvm::TargetLibraryAnalysis(tlii); });

pb.registerModuleAnalyses(mam);
pb.registerCGSCCAnalyses(cgam);
pb.registerFunctionAnalyses(fam);
pb.registerLoopAnalyses(lam);
pb.crossRegisterProxies(lam, fam, cgam, mam);

registerCodonLLVMOptimizationPasses(pb, debug, plugins, includeNative,
includePlugins);

if (debug) {
llvm::ModulePassManager mpm =
Expand Down Expand Up @@ -1074,17 +1084,34 @@ void verify(llvm::Module *module) {

/// Runs the optimization pipeline on the host module and on a GPU-targeted
/// clone of it, then applies the GPU transformations.
///
/// The module is verified on entry and again on exit.
/// @param module host LLVM module
/// @param debug forwarded to the optimization passes; when true the second
///              host optimization round ("llvm/opt2") is skipped
/// @param jit forwarded to the optimization passes
/// @param plugins forwarded to the optimization passes
void optimize(llvm::Module *module, bool debug, bool jit, PluginManager *plugins) {
verify(module);
// The GPU clone is taken before any optimization pass runs on `module`, so
// the GPU rounds below start from the same input IR as the host rounds.
std::unique_ptr<llvm::Module> GPUmodule;
{
TIME("preparing/gpu");
GPUmodule = prepareGPUmodule(module);
}
{
TIME("llvm/opt1");
runLLVMOptimizationPasses(module, debug, jit, plugins);
runLLVMOptimizationPasses(module, debug, jit, plugins, true, true);
}
if (!debug) {
TIME("llvm/opt2");
runLLVMOptimizationPasses(module, debug, jit, plugins);
runLLVMOptimizationPasses(module, debug, jit, plugins, true, true);
}
// The GPU clone is optimized with native and plugin passes both disabled.
{
TIME("llvm/gpuopt1");
runLLVMOptimizationPasses(GPUmodule.get(), debug, jit, plugins,
/*includeNative=*/false,
/*includePlugins=*/false);
}
// NOTE(review): unlike "llvm/opt2" above, this second GPU round is not
// guarded by `!debug` -- confirm that running it in debug builds is intended.
{
TIME("llvm/gpuopt2");
runLLVMOptimizationPasses(GPUmodule.get(), debug, jit, plugins,
/*includeNative=*/false,
/*includePlugins=*/false);
}
{
TIME("llvm/gpu");
applyGPUTransformations(module);
// Ownership of the clone moves into the callee; GPUmodule is null afterwards.
applyGPUTransformations(module, std::move(GPUmodule));
}
verify(module);
}
Expand Down
Loading