-
Notifications
You must be signed in to change notification settings - Fork 480
Expand file tree
/
Copy pathreplica_pool.h
More file actions
383 lines (318 loc) · 12.7 KB
/
replica_pool.h
File metadata and controls
383 lines (318 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
#pragma once
#include <atomic>
#include <chrono>
#include <cstddef>
#include <exception>
#include <future>
#include <limits>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "batch_reader.h"
#include "models/model.h"
#include "thread_pool.h"
#include "utils.h"
namespace ctranslate2 {
// Options controlling how a ReplicaPool builds its worker threads and queue.
struct ReplicaPoolConfig {
  // Number of computation threads given to each replica. The value is passed to
  // set_num_threads() in every worker thread, and also while loading when the
  // pool loads the model itself from a ModelLoader.
  // NOTE(review): the meaning of 0 is defined by set_num_threads() (presumably
  // "use the library default") -- confirm.
  size_t num_threads_per_replica{0};

  // Bound on the number of batches waiting in the work queue:
  //   == 0 -> 4 x the number of replicas,
  //    > 0 -> exactly this many,
  //    < 0 -> unbounded (std::numeric_limits<size_t>::max()).
  long max_queued_batches{0};

  // Forwarded unchanged to the ThreadPool constructor.
  // NOTE(review): presumably -1 disables pinning worker threads to CPU cores;
  // confirm against thread_pool.h.
  int cpu_core_offset{-1};
};
template <typename Replica>
class ReplicaWorker;

// Base class to implement a pool of model replicas that can run in parallel.
//
// The pool owns a ThreadPool with one ReplicaWorker per model replica. Posted
// functions are wrapped into jobs; the worker that runs a job passes its own
// replica to the function, and the results are delivered through
// std::promise / std::future pairs.
template <typename Replica>
class ReplicaPool {
public:
  virtual ~ReplicaPool() = default;

  // Loads the model(s) with model_loader and creates one replica per loaded model.
  // config.num_threads_per_replica is also applied to the loading thread.
  // Throws std::invalid_argument if loading yields no model.
  ReplicaPool(const models::ModelLoader& model_loader,
              const ReplicaPoolConfig& config = {}) {
    initialize_pool(model_loader, config);
  }

  // Convenience constructor: builds a ModelLoader for model_path with the given
  // device settings, then loads the model(s) from it.
  ReplicaPool(const std::string& model_path,
              const Device device,
              const ComputeType compute_type = ComputeType::DEFAULT,
              const std::vector<int>& device_indices = {0},
              const bool tensor_parallel = false,
              const ReplicaPoolConfig& config = {}) {
    models::ModelLoader model_loader(model_path);
    model_loader.device = device;
    model_loader.device_indices = device_indices;
    model_loader.compute_type = compute_type;
    model_loader.tensor_parallel = tensor_parallel;
    initialize_pool(model_loader, config);
  }

  // Creates a pool with a single replica of an already loaded model.
  // Throws std::invalid_argument if model is null.
  ReplicaPool(const std::shared_ptr<const models::Model>& model,
              const ReplicaPoolConfig& config = {}) {
    initialize_pool({model}, config);
  }

  // Creates one replica per already loaded model.
  // Throws std::invalid_argument if replicas is empty or contains a null model.
  ReplicaPool(const std::vector<std::shared_ptr<const models::Model>>& replicas,
              const ReplicaPoolConfig& config = {}) {
    initialize_pool(replicas, config);
  }

  // Posts a function and return its result as a future.
  // The function will be run with the first available replica.
  // The function must have the signature: Result(Replica&)
  // Exceptions thrown by the function are stored in the returned future.
  template <typename Result, typename Func>
  std::future<Result> post(Func func) {
    // Adapt the single-result function to the batched interface (one result).
    auto batched_func = [func = std::move(func)](Replica& replica) mutable {
      std::vector<Result> results;
      results.reserve(1);
      results.emplace_back(func(replica));
      return results;
    };
    auto futures = post_batch<Result>(std::move(batched_func), 1);
    return std::move(futures[0]);
  }

  // Posts a function and return one future per result.
  // The function will be run with the first available replica.
  // The function must have the signature: std::vector<Result>(Replica&)
  // futures[i] receives the i-th element of the returned vector. If the function
  // throws, or returns fewer than num_results elements, every future receives
  // the exception.
  template <typename Result, typename Func>
  std::vector<std::future<Result>> post_batch(Func func, size_t num_results) {
    std::vector<std::promise<Result>> promises(num_results);
    std::vector<std::future<Result>> futures;
    futures.reserve(promises.size());
    for (auto& promise : promises)
      futures.emplace_back(promise.get_future());
    post_batch(std::move(func), std::move(promises));
    return futures;
  }

  // Same as above, but taking the list of promises directly.
  template <typename Result, typename Func>
  void post_batch(Func func, std::vector<std::promise<Result>> promises) {
    // The replica is resolved when the job runs, from the worker executing it.
    auto wrapped_func = [func = std::move(func)]() mutable {
      return func(get_thread_replica());
    };
    post_func(std::move(wrapped_func), std::move(promises));
  }

  // Number of batches in the work queue.
  size_t num_queued_batches() const {
    return _thread_pool->num_queued_jobs();
  }

  // Number of batches in the work queue or currently processed by a worker.
  size_t num_active_batches() const {
    return _thread_pool->num_active_jobs();
  }

  // Number of parallel replicas.
  size_t num_replicas() const {
    return _thread_pool->num_threads();
  }

  // Detaches the models used by each replica for unloading.
  // Entries are null for replicas that hold no model.
  // This method is not thread-safe.
  std::vector<std::shared_ptr<const models::Model>> detach_models() {
    std::vector<std::shared_ptr<const models::Model>> models;
    models.reserve(num_replicas());
    for (size_t i = 0; i < num_replicas(); ++i)
      models.emplace_back(replica_worker(i).detach_model());
    return models;
  }

  // Assigns a model to each replica (models[i] goes to replica i).
  // Throws std::invalid_argument if the number of models does not match
  // num_replicas() or if any model is null; in that case no replica is changed.
  // This method is not thread-safe.
  void set_models(const std::vector<std::shared_ptr<const models::Model>>& models) {
    if (models.size() != num_replicas())
      throw std::invalid_argument("The number of models does not match the number "
                                  "of parallel replicas");
    // Validate everything first so a bad entry cannot leave the pool partially updated.
    for (const auto& model : models) {
      if (!model)
        throw std::invalid_argument("Cannot assign a null model to a replica");
    }
    for (size_t i = 0; i < models.size(); ++i)
      replica_worker(i).set_model(models[i]);
  }

  // Clears the cache of each worker.
  // Workers whose allocator is not registered yet are skipped.
  // This method is not thread-safe.
  void clear_cache() const {
    for (size_t i = 0; i < num_replicas(); ++i) {
      auto* allocator = replica_worker(i).allocator();
      if (allocator)
        allocator->clear_cache();
    }
  }

  // Returns the replica owned by the first worker.
  // Throws std::runtime_error if that worker currently holds no replica.
  const Replica& get_first_replica() const {
    return replica_worker(0).replica();
  }

protected:
  // Rebatches examples with rebatch_input() and posts one job per batch.
  // func must have the signature: std::vector<Result>(Replica&, const Batch&).
  // Returns one future per example, in the same order as examples.
  template <typename Result, typename Func>
  std::vector<std::future<Result>>
  post_examples(const std::vector<Example>& examples,
                size_t max_batch_size,
                BatchType batch_type,
                const Func& func) {
    std::vector<std::promise<Result>> promises(examples.size());
    std::vector<std::future<Result>> futures;
    futures.reserve(promises.size());
    for (auto& promise : promises)
      futures.emplace_back(promise.get_future());
    post_examples(examples, max_batch_size, batch_type, std::move(promises), func);
    return futures;
  }

  // Same as above, but fulfilling the given promises (promises[i] <-> examples[i]).
  template <typename Result, typename Func>
  void post_examples(const std::vector<Example>& examples,
                     size_t max_batch_size,
                     BatchType batch_type,
                     std::vector<std::promise<Result>> promises,
                     const Func& func) {
    for (auto& batch : rebatch_input(examples, max_batch_size, batch_type)) {
      // example_index maps each entry of the batch back to its position in examples.
      std::vector<std::promise<Result>> batch_promises;
      batch_promises.reserve(batch.num_examples());
      for (const size_t index : batch.example_index)
        batch_promises.emplace_back(std::move(promises[index]));
      post_batch<Result>(
        [batch = std::move(batch), func](Replica& replica) { return func(replica, batch); },
        std::move(batch_promises));
    }
  }

  // Reads examples from batch_reader until it is exhausted, runs func on them
  // via post_examples(), and passes each result to result_writer in input order.
  // When read_batch_size is 0, it defaults to max_batch_size if that is 1, and to
  // 16 * max_batch_size otherwise. Ready results are written as soon as possible;
  // the remaining ones are awaited after the reader is exhausted.
  template <typename Result, typename ResultWriter, typename Func>
  void consume_batches(BatchReader& batch_reader,
                       ResultWriter& result_writer,
                       const Func& func,
                       size_t max_batch_size,
                       size_t read_batch_size,
                       BatchType batch_type) {
    std::queue<std::future<Result>> results;

    // Writes results from the front of the queue: all of them when blocking,
    // otherwise only while the front result is already available.
    auto pop_results = [&results, &result_writer](bool blocking) {
      constexpr std::chrono::seconds zero_sec(0);
      while (!results.empty()
             && (blocking
                 || results.front().wait_for(zero_sec) == std::future_status::ready)) {
        result_writer(results.front().get());
        results.pop();
      }
    };

    if (read_batch_size == 0)
      read_batch_size = (max_batch_size == 1 ? max_batch_size : max_batch_size * 16);

    while (true) {
      auto examples = batch_reader.get_next(read_batch_size, batch_type);
      if (examples.empty())
        break;
      auto futures = post_examples<Result>(examples, max_batch_size, batch_type, func);
      for (auto& future : futures)
        results.emplace(std::move(future));
      pop_results(/*blocking=*/false);
    }

    pop_results(/*blocking=*/true);
  }

private:
  std::unique_ptr<ThreadPool> _thread_pool;

  // Returns the worker at the given index. Every worker is created as a
  // ReplicaWorker<Replica> in initialize_pool(), which makes the downcast valid.
  ReplicaWorker<Replica>& replica_worker(size_t index) const {
    return static_cast<ReplicaWorker<Replica>&>(_thread_pool->get_worker(index));
  }

  // Returns the replica of the worker running the calling thread (as reported by
  // ThreadPool::get_local_worker()). Throws std::runtime_error if that worker
  // holds no replica.
  static Replica& get_thread_replica() {
    auto& worker = static_cast<ReplicaWorker<Replica>&>(ThreadPool::get_local_worker());
    return worker.replica();
  }

  void initialize_pool(const models::ModelLoader& model_loader,
                       const ReplicaPoolConfig& config) {
    // The same number of computation threads should be used for loading and running model.
    set_num_threads(config.num_threads_per_replica);
    initialize_pool(model_loader.load(), config);
  }

  // Creates one ReplicaWorker per model and starts the thread pool.
  // Throws std::invalid_argument if models is empty or contains a null model.
  void initialize_pool(const std::vector<std::shared_ptr<const models::Model>>& models,
                       const ReplicaPoolConfig& config) {
    // An empty pool would never run any job and makes get_first_replica() read
    // past the worker list.
    if (models.empty())
      throw std::invalid_argument("Cannot create a replica pool without any model");

    std::vector<std::unique_ptr<Worker>> workers;
    workers.reserve(models.size());
    for (const auto& model : models) {
      if (!model)
        throw std::invalid_argument("Cannot create a replica from a null model");
      workers.emplace_back(std::make_unique<ReplicaWorker<Replica>>(model, config.num_threads_per_replica));
    }

    // 0 -> default bound of 4 batches per replica, > 0 -> explicit bound,
    // < 0 -> unbounded.
    size_t max_queue_size = std::numeric_limits<size_t>::max();
    if (config.max_queued_batches == 0)
      max_queue_size = 4 * workers.size();
    else if (config.max_queued_batches > 0)
      max_queue_size = static_cast<size_t>(config.max_queued_batches);

    _thread_pool = std::make_unique<ThreadPool>(std::move(workers),
                                                max_queue_size,
                                                config.cpu_core_offset);
  }

  template <typename Result, typename Func>
  void post_func(Func func, std::vector<std::promise<Result>> promises) {
    _thread_pool->post(std::make_unique<BatchJob<Result, Func>>(std::move(promises),
                                                                std::move(func)));
  }

  // Job that calls a batch function once and fulfills promises[i] with the i-th
  // returned result, or all promises with the same exception on failure.
  template <typename Result, typename Func>
  class BatchJob : public Job {
  public:
    BatchJob(std::vector<std::promise<Result>> promises, Func func)
      : _promises(std::move(promises))
      , _func(std::move(func))
    {
    }

    void run() override {
      std::vector<Result> results;
      std::exception_ptr exception;

      try {
        results = _func();
        // Each promise reads results[i]: fewer results than promises would index
        // out of bounds, so report it as an error to every waiting future.
        if (results.size() < _promises.size())
          throw std::runtime_error("The batch function returned "
                                   + std::to_string(results.size())
                                   + " result(s) but "
                                   + std::to_string(_promises.size())
                                   + " were expected");
      } catch (...) {
        exception = std::current_exception();
      }

      for (size_t i = 0; i < _promises.size(); ++i) {
        if (exception)
          _promises[i].set_exception(exception);
        else
          _promises[i].set_value(std::move(results[i]));
      }
    }

  private:
    std::vector<std::promise<Result>> _promises;
    Func _func;
  };
};
// Model replica worker.
template <typename Replica>
class ReplicaWorker : public Worker {
public:
ReplicaWorker(const std::shared_ptr<const models::Model>& model, size_t num_threads)
: _device(model->device())
, _device_index(model->device_index())
, _num_threads(num_threads)
, _allocator(nullptr)
, _shutting_down(false)
{
set_model(model);
}
Replica& replica() {
if (!_replica)
throw std::runtime_error("No model replica is available in this thread");
return *_replica;
}
void set_model(const std::shared_ptr<const models::Model>& model) {
_replica = Replica::create_from_model(*model);
}
std::shared_ptr<const models::Model> detach_model() {
if (!_replica)
return nullptr;
auto model = _replica->model();
_replica.reset();
return model;
}
Allocator* allocator() {
return _allocator;
}
protected:
void initialize() override {
set_device_index(_device, _device_index);
// Set the number of computation threads for the current thread.
set_num_threads(_num_threads);
// Register the memory allocator used in this thread.
_allocator = &get_allocator(_device);
}
void prepare_shutdown() override {
// Set shutdown flag BEFORE queue.close() so that idle() won't enter
// synchronize_stream() while the queue is being closed. This prevents
// the deadlock where idle() blocks on CUDA sync while close() has
// already signaled workers to stop.
_shutting_down.store(true, std::memory_order_release);
}
void idle() override {
// Check shutdown flag before synchronizing — synchronize_stream() can
// block indefinitely on CUDA teardown if called during shutdown.
if (_shutting_down.load(std::memory_order_acquire))
return;
// When no new jobs are immediately available, we synchronize the CUDA stream
// so that the CudaAsyncAllocator can release some memory.
synchronize_stream(_device);
}
void finalize() override {
_shutting_down.store(true, std::memory_order_release);
_replica.reset();
}
private:
const Device _device;
const int _device_index;
const size_t _num_threads;
Allocator* _allocator;
std::unique_ptr<Replica> _replica;
std::atomic<bool> _shutting_down;
};
}