
Commit 444b38b

QuentinFuxa and claude committed
Fix self-attention for decoder-only models and add set_alignment_heads API
Self-attention was not returning attention weights for decoder-only models (Generator) because the attention pointer was always nullptr in TransformerDecoderLayer. The attention pointer is now passed to self-attention when there is no encoder attention (the decoder-only case).

Also adds set_alignment_heads() to the Generator Python API, allowing users to select specific (layer, head) pairs instead of the default (last layer, head 0). The attention from the selected heads is concatenated in the output and can be reshaped to (num_heads, context_length).

Fixes multi-head attention handling in decoding.cc to support variable-rank attention tensors (rank 3 for multi-head vs. rank 2 for averaged).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent 948f8f2 commit 444b38b

8 files changed

Lines changed: 76 additions & 11 deletions
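
A rough usage sketch of the API described in the commit message (not part of this diff; it assumes a build of this fork where generate_batch exposes a return_attention option and the generation result carries an attention field, as the commit describes — the model path, prompt tokens, and option names below are placeholders):

    import numpy as np
    import ctranslate2

    generator = ctranslate2.Generator("model_dir", device="cpu")  # placeholder model path

    # Collect three specific (layer, head) pairs instead of the default (last layer, head 0).
    alignment_heads = [(31, 0), (31, 3), (33, 7)]
    generator.set_alignment_heads(alignment_heads)

    results = generator.generate_batch(
        [["<s>", "Hello"]],      # placeholder prompt tokens
        max_length=16,
        return_attention=True,   # option assumed to exist in this fork
    )

    # Each generated step yields num_heads * context_length floats
    # (the selected heads concatenated); reshape to (num_heads, context_length).
    for step_attention in results[0].attention[0]:
        per_head = np.asarray(step_attention).reshape(len(alignment_heads), -1)
        print(per_head.shape)
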


include/ctranslate2/layers/decoder.h

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,11 @@ namespace ctranslate2 {
     public:
       Decoder(Device device);
 
+      // Configure which attention heads to collect when return_attention is enabled.
+      virtual void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
+        (void)alignment_heads;
+      }
+
       virtual DecoderState initial_state(bool iterative_decoding = true) const = 0;
 
       // Forwards one step.

include/ctranslate2/layers/transformer.h

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ namespace ctranslate2 {
                  StorageView* attention = nullptr) override;
 
       void set_alignment_heads(const dim_t layer, const dim_t num_heads_to_average);
-      void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads);
+      void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) override;
 
       std::unique_ptr<StorageView>
       get_layer_alignment_heads(const dim_t layer, const dim_t batch_size) const;

include/ctranslate2/models/language_model.h

Lines changed: 8 additions & 0 deletions
@@ -58,6 +58,12 @@ namespace ctranslate2 {
                         const StorageView& lengths,
                         const bool return_log_probs);
 
+      // Configure which attention heads to collect when return_attention is enabled.
+      // Each pair is (layer_index, head_index).
+      virtual void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
+        (void)alignment_heads;
+      }
+
     protected:
       virtual bool skip_scoring(const std::vector<std::string>& tokens,
                                 const ScoringOptions& options,
@@ -89,6 +95,8 @@ namespace ctranslate2 {
       DecoderReplica(const std::shared_ptr<const LanguageModel>& model,
                      std::unique_ptr<layers::Decoder> decoder);
 
+      void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) override;
+
     protected:
       bool skip_scoring(const std::vector<std::string>& tokens,
                         const ScoringOptions& options,

include/ctranslate2/replica_pool.h

Lines changed: 9 additions & 0 deletions
@@ -152,6 +152,15 @@ namespace ctranslate2 {
       return worker.replica();
     }
 
+    // Apply a function to each replica. Not thread-safe.
+    template <typename Func>
+    void for_each_replica(Func func) {
+      for (size_t i = 0; i < num_replicas(); ++i) {
+        auto& worker = static_cast<ReplicaWorker<Replica>&>(_thread_pool->get_worker(i));
+        func(worker.replica());
+      }
+    }
+
   protected:
     template <typename Result, typename Func>
     std::vector<std::future<Result>>

python/cpp/generator.cc

Lines changed: 23 additions & 0 deletions
@@ -11,6 +11,12 @@ namespace ctranslate2 {
     public:
       using ReplicaPoolHelper::ReplicaPoolHelper;
 
+      void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
+        _pool->for_each_replica([&](models::SequenceGeneratorReplica& replica) {
+          replica.set_alignment_heads(alignment_heads);
+        });
+      }
+
       std::variant<std::vector<GenerationResult>,
                    std::vector<AsyncResult<GenerationResult>>>
       generate_batch(const BatchTokens& tokens,
@@ -185,6 +191,23 @@ namespace ctranslate2 {
         .def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches,
                                "Number of batches waiting to be processed or currently processed.")
 
+        .def("set_alignment_heads", &GeneratorWrapper::set_alignment_heads,
+             py::arg("alignment_heads"),
+             R"pbdoc(
+                 Configure which attention heads to collect when ``return_attention=True``.
+
+                 By default, only head 0 of the last layer is returned (averaged).
+                 Use this method to select specific (layer, head) pairs. The attention
+                 from the selected heads will be concatenated in the output.
+
+                 Arguments:
+                     alignment_heads: List of (layer_index, head_index) pairs to collect.
+
+                 Example:
+
+                     >>> generator.set_alignment_heads([(31, 0), (31, 3), (33, 7)])
+             )pbdoc")
+
         .def("generate_batch", &GeneratorWrapper::generate_batch,
              py::arg("start_tokens"),
             py::kw_only(),

src/decoding.cc

Lines changed: 22 additions & 7 deletions
@@ -146,13 +146,23 @@ namespace ctranslate2 {
       if (!history)
         return {};
 
-      const auto source_length = history.dim(-1);
+      // For averaged attention: history is (batch, beam, steps, ctx)
+      // For per-head attention: history is (batch, beam, steps, heads, ctx)
+      // Compute total floats per time step (ctx or heads*ctx).
+      dim_t step_size = 1;
+      for (dim_t d = 3; d < history.rank(); ++d)
+        step_size *= history.dim(d);
 
       std::vector<std::vector<float>> attention;
       attention.reserve(end - start);
+      // Compute stride for the steps dimension: step_size floats per step.
+      // Base offset for (batch, beam) = batch * (beam_stride) + beam * (steps * step_size).
+      const dim_t steps = history.dim(2);
+      const dim_t beam_stride = steps * step_size;
+      const float* base = history.data<float>() + batch * history.dim(1) * beam_stride + beam * beam_stride;
       for (dim_t t = start; t < end; ++t) {
-        const auto* vector = history.index<float>({batch, beam, t, 0});
-        attention.emplace_back(vector, vector + source_length);
+        const float* vector = base + t * step_size;
+        attention.emplace_back(vector, vector + step_size);
       }
       return attention;
     }
@@ -911,8 +921,11 @@ namespace ctranslate2 {
             && (return_prefix || step >= prefix_length)) {
           results[batch_id].hypotheses[0].push_back(word_id);
           if (attention_step) {
-            const auto* attn = attention_step.index<float>({i, 0});
-            results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1));
+            // For averaged attention: shape (batch, ctx) -> take ctx floats
+            // For per-head attention: shape (batch, heads, ctx) -> take heads*ctx floats
+            const dim_t attn_size = attention_step.size() / attention_step.dim(0);
+            const auto* attn = attention_step.data<float>() + i * attn_size;
+            results[batch_id].attention[0].emplace_back(attn, attn + attn_size);
           }
         }
 
@@ -1166,9 +1179,11 @@ namespace ctranslate2 {
         if (options.return_attention) {
           if (attention.device() != Device::CPU)
             attention = attention.to_float32().to(Device::CPU);
+          // Compute floats per time step (ctx or heads*ctx for multi-head).
+          const dim_t step_size = attention.size() / (attention.dim(0) * attention.dim(1));
          for (dim_t t = 0; t < prefix_length; ++t) {
-            const float* vector = attention.index<float>({0, t, 0});
-            result.attention[i].emplace_back(vector, vector + attention.dim(-1));
+            const float* vector = attention.data<float>() + t * step_size;
+            result.attention[i].emplace_back(vector, vector + step_size);
          }
        }
      }
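
The change above replaces history.index<float>() with explicit strides so that a trailing (heads, ctx) block is copied per time step whatever the tensor rank. A toy numpy check of that arithmetic (not part of the repository; shapes are arbitrary, and a row-major layout like StorageView's is assumed):

    import numpy as np

    # Per-head attention history: (batch, beam, steps, heads, ctx).
    batch, beam, steps, heads, ctx = 2, 1, 4, 3, 5
    history = np.random.rand(batch, beam, steps, heads, ctx).astype(np.float32)

    step_size = heads * ctx          # floats per time step (just ctx in the averaged case)
    beam_stride = steps * step_size  # floats per beam entry
    flat = history.reshape(-1)

    b, k, t = 1, 0, 2
    base = b * beam * beam_stride + k * beam_stride
    vector = flat[base + t * step_size : base + (t + 1) * step_size]

    # The flat slice matches the (heads, ctx) block for that (batch, beam, step).
    assert np.array_equal(vector, history[b, k, t].reshape(-1))
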

src/layers/transformer.cc

Lines changed: 3 additions & 3 deletions
@@ -222,7 +222,7 @@ namespace ctranslate2 {
                        context,
                        cached_self_attn_keys,
                        cached_self_attn_values,
-                       nullptr,
+                       _encoder_attention ? nullptr : attention,
                        input_padder,
                        input_padder,
                        true,
@@ -291,7 +291,7 @@ namespace ctranslate2 {
                        attn,
                        cached_self_attn_keys,
                        cached_self_attn_values,
-                       nullptr,
+                       _encoder_attention ? nullptr : attention,
                        input_padder,
                        input_padder,
                        true,
@@ -315,7 +315,7 @@ namespace ctranslate2 {
                        output,
                        cached_self_attn_keys,
                        cached_self_attn_values,
-                       nullptr,
+                       _encoder_attention ? nullptr : attention,
                        input_padder,
                        input_padder,
                        true,

src/models/language_model.cc

Lines changed: 5 additions & 0 deletions
@@ -110,6 +110,11 @@ namespace ctranslate2 {
     {
     }
 
+    void DecoderReplica::set_alignment_heads(
+        const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
+      _decoder->set_alignment_heads(alignment_heads);
+    }
+
     std::vector<ScoringResult>
     DecoderReplica::run_scoring(const std::vector<std::vector<std::string>>& tokens,
                                 const ScoringOptions& options) {
