-
Notifications
You must be signed in to change notification settings - Fork 479
Expand file tree
/
Copy pathmodule.cc
More file actions
93 lines (72 loc) · 3.06 KB
/
module.cc
File metadata and controls
93 lines (72 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <pybind11/pybind11.h>
#include <ctranslate2/devices.h>
#include <ctranslate2/models/model.h>
#include <ctranslate2/random.h>
#include <ctranslate2/types.h>
#include <ctranslate2/utils.h>
#include "module.h"
#include "utils.h"
// Builds the set of compute type names that CTranslate2 reports as usable on
// the device named by `device_str` (parsed with ctranslate2::str_to_device)
// at index `device_index`.
//
// "float32" is always part of the result. "bfloat16", "float16" and "int16"
// are added when the matching mayiuse_* query succeeds. When int8 is
// supported, "int8" and "int8_float32" are added, plus the mixed
// "int8_float16" / "int8_bfloat16" variants when float16 / bfloat16 are
// also supported.
static std::unordered_set<std::string>
get_supported_compute_types(const std::string& device_str, const int device_index) {
  const auto dev = ctranslate2::str_to_device(device_str);

  // Query each precision once, in a fixed order.
  const bool has_bf16 = ctranslate2::mayiuse_bfloat16(dev, device_index);
  const bool has_fp16 = ctranslate2::mayiuse_float16(dev, device_index);
  const bool has_i16 = ctranslate2::mayiuse_int16(dev, device_index);
  const bool has_i8 = ctranslate2::mayiuse_int8(dev, device_index);

  std::unordered_set<std::string> result{"float32"};

  // Adds `name` only when the corresponding capability flag is set.
  const auto add_if = [&result](const bool supported, const char* name) {
    if (supported)
      result.emplace(name);
  };

  add_if(has_bf16, "bfloat16");
  add_if(has_fp16, "float16");
  add_if(has_i16, "int16");

  // The int8 variants all depend on int8 support; the mixed ones further
  // require the matching floating-point type.
  if (has_i8) {
    result.emplace("int8");
    result.emplace("int8_float32");
    add_if(has_fp16, "int8_float16");
    add_if(has_bf16, "int8_bfloat16");
  }

  return result;
}
// Entry point of the `_ext` Python extension module; `m` is the module object
// that every binding below is attached to.
// NOTE(review): `py` is not declared in this file; presumably it is the
// pybind11 namespace alias provided by module.h or utils.h — confirm.
PYBIND11_MODULE(_ext, m)
{
// pybind11 options set through this object apply while it is alive, which
// here means the rest of this function. It stops pybind11 from appending the
// list of enum members to the docstrings of enums registered below.
py::options options;
options.disable_enum_members_docstring();
// Free functions exposed directly on the module.
m.def("contains_model", &ctranslate2::models::contains_model, py::arg("path"),
"Helper function to check if a directory seems to contain a CTranslate2 model.");
// Exposed to Python under a CUDA-specific name, but bound to the generic
// ctranslate2::get_gpu_count.
m.def("get_cuda_device_count", &ctranslate2::get_gpu_count,
"Returns the number of visible GPU devices.");
// Binds the file-local helper above; device_index defaults to 0 on the
// Python side.
m.def("get_supported_compute_types", &get_supported_compute_types,
py::arg("device"),
py::arg("device_index")=0,
R"pbdoc(
Returns the set of supported compute types on a device.
Arguments:
device: Device name (cpu or cuda).
device_index: Device index.
Example:
>>> ctranslate2.get_supported_compute_types("cpu")
{'int16', 'float32', 'int8', 'int8_float32'}
>>> ctranslate2.get_supported_compute_types("cuda")
{'float32', 'int8_float16', 'float16', 'int8', 'int8_float32'}
)pbdoc");
m.def("set_random_seed", &ctranslate2::set_random_seed, py::arg("seed"),
"Sets the seed of random generators.");
// Each register_* helper attaches the bindings for one component to `m`.
// NOTE(review): the order probably matters. pybind11 can only render a
// readable type name in a generated signature if that type is already
// registered, so the stats/result types are likely registered before the
// classes that return them. Confirm before reordering.
ctranslate2::python::register_logging(m);
ctranslate2::python::register_storage_view(m);
ctranslate2::python::register_translation_stats(m);
ctranslate2::python::register_translation_result(m);
ctranslate2::python::register_scoring_result(m);
ctranslate2::python::register_generation_result(m);
ctranslate2::python::register_translator(m);
ctranslate2::python::register_generator(m);
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wavlm(m);
ctranslate2::python::register_wav2vec2(m);
ctranslate2::python::register_wav2vec2bert(m);
ctranslate2::python::register_mpi(m);
}