Skip to content

Commit bae756f

Browse files
committed
fix minor bugs identified while adding genric
1 parent b436900 commit bae756f

5 files changed

Lines changed: 100 additions & 62 deletions

File tree

errors.go

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -509,22 +509,34 @@ func (e *Error) As(target interface{}) bool {
509509
if e == nil {
510510
return false
511511
}
512-
// Handle *Error target.
513-
if targetPtr, ok := target.(*Error); ok {
512+
// Handle **Error target (i.e. caller passed &myErrPtr where myErrPtr is *Error).
513+
// Traverse the chain and return the first *Error that has a name; if none has a
514+
// name, return the first *Error in the chain. This satisfies both:
515+
// - TestErrorAs: wraps Named("target") -> finds it by name
516+
// - TestErrorFullChain: finds Named("AuthError") deep in the chain
517+
if targetPtr, ok := target.(**Error); ok {
518+
var first *Error
514519
current := e
515520
for current != nil {
521+
if first == nil {
522+
first = current
523+
}
516524
if current.name != "" {
517-
*targetPtr = *current
525+
*targetPtr = current
518526
return true
519527
}
520528
if next, ok := current.cause.(*Error); ok {
521529
current = next
522530
} else if current.cause != nil {
523531
return errors.As(current.cause, target)
524532
} else {
525-
return false
533+
break
526534
}
527535
}
536+
if first != nil {
537+
*targetPtr = first
538+
return true
539+
}
528540
return false
529541
}
530542
// Handle *error target.
@@ -622,6 +634,8 @@ func (e *Error) Copy() *Error {
622634
newErr.code = e.code
623635
newErr.category = e.category
624636
newErr.count = e.count
637+
newErr.callback = e.callback // was silently dropped by Copy
638+
newErr.formatWrapped = e.formatWrapped // was silently dropped by Copy
625639

626640
if e.smallCount > 0 {
627641
newErr.smallCount = e.smallCount
@@ -1016,9 +1030,12 @@ var (
10161030
// data, _ := json.Marshal(err)
10171031
// fmt.Println(string(data))
10181032
func (e *Error) MarshalJSON() ([]byte, error) {
1019-
// Get buffer from pool.
1033+
// Get buffer from pool. Do NOT defer-return it — we must copy the result
1034+
// out of buf's backing array and return the buf to the pool BEFORE we return
1035+
// the copied slice. If we defer the Put, another goroutine can Get the same
1036+
// buf and overwrite its backing array while the caller is still reading our
1037+
// returned slice (the race the detector flags).
10201038
buf := jsonBufferPool.Get().(*bytes.Buffer)
1021-
defer jsonBufferPool.Put(buf)
10221039
buf.Reset()
10231040

10241041
// Create new encoder.
@@ -1066,11 +1083,16 @@ func (e *Error) MarshalJSON() ([]byte, error) {
10661083
return nil, err
10671084
}
10681085

1069-
// Remove trailing newline.
1070-
result := buf.Bytes()
1071-
if len(result) > 0 && result[len(result)-1] == '\n' {
1072-
result = result[:len(result)-1]
1086+
// Copy bytes out of buf before returning buf to the pool.
1087+
// buf.Bytes() is a slice into buf's internal array — if we put buf back first
1088+
// and another goroutine resets it, they share the same backing memory.
1089+
raw := buf.Bytes()
1090+
if len(raw) > 0 && raw[len(raw)-1] == '\n' {
1091+
raw = raw[:len(raw)-1]
10731092
}
1093+
result := make([]byte, len(raw))
1094+
copy(result, raw)
1095+
jsonBufferPool.Put(buf)
10741096
return result, nil
10751097
}
10761098

@@ -1164,7 +1186,8 @@ func (e *Error) Stack() []string {
11641186
//
11651187
// err := errors.New("failed").Trace()
11661188
func (e *Error) Trace() *Error {
1167-
if e.stack == nil {
1189+
// Check len rather than nil for the same reason as WithStack.
1190+
if len(e.stack) == 0 {
11681191
e.stack = captureStack(2)
11691192
}
11701193
return e
@@ -1279,31 +1302,30 @@ func (e *Error) With(keyValues ...interface{}) *Error {
12791302
keyValues = append(keyValues, "(MISSING)")
12801303
}
12811304

1282-
// Fast path for small context when we can add all pairs to smallContext
1305+
// Acquire the lock once up-front. The previous "optimistic read then lock"
1306+
// pattern read e.smallCount and e.context without holding the lock, which
1307+
// the race detector correctly flagged as a data race when two goroutines
1308+
// call With() on the same *Error concurrently.
1309+
e.mu.Lock()
1310+
defer e.mu.Unlock()
1311+
1312+
// Fast path: all pairs fit in the fixed-size smallContext array.
12831313
if e.smallCount < contextSize && e.context == nil {
12841314
remainingSlots := contextSize - int(e.smallCount)
12851315
if len(keyValues)/2 <= remainingSlots {
1286-
e.mu.Lock()
1287-
// Recheck conditions after acquiring lock
1288-
if e.smallCount < contextSize && e.context == nil {
1289-
for i := 0; i < len(keyValues); i += 2 {
1290-
key, ok := keyValues[i].(string)
1291-
if !ok {
1292-
key = fmt.Sprintf("%v", keyValues[i])
1293-
}
1294-
e.smallContext[e.smallCount] = contextItem{key, keyValues[i+1]}
1295-
e.smallCount++
1316+
for i := 0; i < len(keyValues); i += 2 {
1317+
key, ok := keyValues[i].(string)
1318+
if !ok {
1319+
key = fmt.Sprintf("%v", keyValues[i])
12961320
}
1297-
e.mu.Unlock()
1298-
return e
1321+
e.smallContext[e.smallCount] = contextItem{key, keyValues[i+1]}
1322+
e.smallCount++
12991323
}
1300-
e.mu.Unlock()
1324+
return e
13011325
}
13021326
}
13031327

1304-
// Slow path - either we have too many pairs or already using map context
1305-
e.mu.Lock()
1306-
defer e.mu.Unlock()
1328+
// Slow path: too many pairs or already using map context.
13071329

13081330
// Initialize map context if needed
13091331
if e.context == nil {
@@ -1377,7 +1399,9 @@ func (e *Error) WithRetryable() *Error {
13771399
//
13781400
// err := errors.New("failed").WithStack()
13791401
func (e *Error) WithStack() *Error {
1380-
if e.stack == nil {
1402+
// Check len rather than nil: a pooled error has stack reset to stack[:0]
1403+
// (non-nil but empty). The nil check would skip capture for recycled errors.
1404+
if len(e.stack) == 0 {
13811405
e.stack = captureStack(1)
13821406
}
13831407
return e

errors_test.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,6 @@ func TestContextStorage(t *testing.T) {
916916
}
917917

918918
// TestNewf verifies Newf behavior, including %w wrapping, formatting, and error cases.
919-
// errors_test.go
920919

921920
// TestNewf verifies Newf behavior, including %w wrapping, formatting, and error cases.
922921
// It now expects the string output for %w cases to match fmt.Errorf.
@@ -959,7 +958,7 @@ func TestNewf(t *testing.T) {
959958
wantInternalMsg: "",
960959
},
961960

962-
// --- %w wrapping cases (EXPECTATIONS UPDATED) ---
961+
// %w wrapping cases (EXPECTATIONS UPDATED)
963962
{
964963
name: "wrap standard error",
965964
format: "prefix %w",
@@ -1061,7 +1060,7 @@ func TestNewf(t *testing.T) {
10611060
t.Errorf("Newf().Error() = %q, want %q", gotMsg, tt.wantFinalMsg)
10621061
}
10631062

1064-
// --- Cause verification remains crucial ---
1063+
// Cause verification remains crucial
10651064
gotCause := errors.Unwrap(got)
10661065
if tt.wantCause != nil {
10671066
// Use errors.Is for robust checking, especially if causes might be wrapped themselves
@@ -1097,21 +1096,28 @@ func TestNewf(t *testing.T) {
10971096
//
10981097
// Rationale for using compareWrappedErrorStrings helper:
10991098
//
1100-
// 1. Goal: Ensure essential compatibility - correct error wrapping (for Unwrap/Is/As)
1101-
// and preservation of the message content surrounding the wrapped error.
1102-
// 2. Formatting Difference: This library consistently formats wrapped errors in its
1103-
// Error() method as "MESSAGE: CAUSE_ERROR" (or just "CAUSE_ERROR" if MESSAGE is empty).
1104-
// The standard fmt.Errorf has more complex and variable spacing rules depending on
1105-
// characters around %w (e.g., sometimes omitting the colon, adding spaces differently).
1106-
// 3. Semantic Comparison: Attempting to replicate fmt.Errorf's exact spacing makes the
1107-
// library code brittle and overly complex. Therefore, this test focuses on *semantic*
1108-
// equivalence rather than exact string matching.
1109-
// 4. Helper Logic: compareWrappedErrorStrings verifies compatibility by:
1110-
// a) Checking that errors.Unwrap returns the same underlying cause instance.
1111-
// b) Extracting the textual prefix from this library's error string (before ": CAUSE").
1112-
// c) Extracting the textual remainder from fmt.Errorf's string by removing the cause string.
1113-
// d) Normalizing both extracted parts (trimming space, collapsing internal whitespace).
1114-
// e) Comparing the normalized parts to ensure the core message content matches.
1099+
// Goal: Ensure essential compatibility - correct error wrapping (for Unwrap/Is/As)
1100+
//
1101+
// and preservation of the message content surrounding the wrapped error.
1102+
//
1103+
// Formatting Difference: This library consistently formats wrapped errors in its
1104+
//
1105+
// Error() method as "MESSAGE: CAUSE_ERROR" (or just "CAUSE_ERROR" if MESSAGE is empty).
1106+
// The standard fmt.Errorf has more complex and variable spacing rules depending on
1107+
// characters around %w (e.g., sometimes omitting the colon, adding spaces differently).
1108+
//
1109+
// Semantic Comparison: Attempting to replicate fmt.Errorf's exact spacing makes the
1110+
//
1111+
// library code brittle and overly complex. Therefore, this test focuses on *semantic*
1112+
// equivalence rather than exact string matching.
1113+
//
1114+
// Helper Logic: compareWrappedErrorStrings verifies compatibility by:
1115+
//
1116+
// a) Checking that errors.Unwrap returns the same underlying cause instance.
1117+
// b) Extracting the textual prefix from this library's error string (before ": CAUSE").
1118+
// c) Extracting the textual remainder from fmt.Errorf's string by removing the cause string.
1119+
// d) Normalizing both extracted parts (trimming space, collapsing internal whitespace).
1120+
// e) Comparing the normalized parts to ensure the core message content matches.
11151121
//
11161122
// This approach ensures functional compatibility without being overly sensitive to minor
11171123
// formatting variations between the libraries.
@@ -1151,7 +1157,7 @@ func TestNewfCompatibilityWithFmtErrorf(t *testing.T) {
11511157
}
11521158
// Consider defer customErrImpl.Free() if needed
11531159

1154-
// --- Verify Cause ---
1160+
// Verify Cause
11551161
stdUnwrapped := errors.Unwrap(stdErr)
11561162
customUnwrapped := errors.Unwrap(customErrImpl)
11571163

@@ -1172,7 +1178,7 @@ func TestNewfCompatibilityWithFmtErrorf(t *testing.T) {
11721178
}
11731179
}
11741180

1175-
// --- Verify String Output (Exact Match) ---
1181+
// Verify String Output (Exact Match)
11761182
gotStr := customErrImpl.Error()
11771183
wantStr := stdErr.Error()
11781184
if gotStr != wantStr {

multi_error.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,18 @@ func (m *MultiError) Merge(other *MultiError) {
210210
return
211211
}
212212

213+
// Snapshot other's errors under its own read lock, then release before
214+
// acquiring m's write lock inside Add. This prevents two bugs:
215+
// Self-merge deadlock: when m == other, holding other.mu.RLock then
216+
// calling m.Add (which takes m.mu.Lock) deadlocks on the same mutex.
217+
// Concurrent-write race: m had no lock protection during the loop,
218+
// so a concurrent Add on m could corrupt the slice.
213219
other.mu.RLock()
214-
defer other.mu.RUnlock()
220+
snapshot := make([]error, len(other.errors))
221+
copy(snapshot, other.errors)
222+
other.mu.RUnlock()
215223

216-
for _, err := range other.errors {
217-
m.Add(err)
218-
}
224+
m.Add(snapshot...)
219225
}
220226

221227
// IsNull checks if the MultiError is empty or contains only null errors.
@@ -324,9 +330,9 @@ func (m *MultiError) MarshalJSON() ([]byte, error) {
324330
m.mu.RLock()
325331
defer m.mu.RUnlock()
326332

327-
// Get buffer from pool for efficiency
333+
// Get buffer from pool. Do NOT use defer for Put — see errors.go MarshalJSON
334+
// for the full explanation. We must copy bytes out before returning the buf.
328335
buf := jsonBufferPool.Get().(*bytes.Buffer)
329-
defer jsonBufferPool.Put(buf)
330336
buf.Reset()
331337

332338
// Create encoder
@@ -379,11 +385,14 @@ func (m *MultiError) MarshalJSON() ([]byte, error) {
379385
return nil, fmt.Errorf("failed to marshal MultiError: %v", err)
380386
}
381387

382-
// Remove trailing newline
383-
result := buf.Bytes()
384-
if len(result) > 0 && result[len(result)-1] == '\n' {
385-
result = result[:len(result)-1]
388+
// Copy out of buf's backing array before returning buf to pool.
389+
raw := buf.Bytes()
390+
if len(raw) > 0 && raw[len(raw)-1] == '\n' {
391+
raw = raw[:len(raw)-1]
386392
}
393+
result := make([]byte, len(raw))
394+
copy(result, raw)
395+
jsonBufferPool.Put(buf)
387396
return result, nil
388397
}
389398

pool.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// pool.go
21
package errors
32

43
import (

retry.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ func (r *Retry) ExecuteContext(ctx context.Context, fn func() error) error {
198198
break
199199
}
200200

201-
// --- Calculate and apply delay ---
201+
// Calculate and apply delay
202202
currentDelay := r.backoff.Backoff(attempt, r.delay)
203203
if r.maxDelay > 0 && currentDelay > r.maxDelay { // Check maxDelay > 0 before capping
204204
currentDelay = r.maxDelay
@@ -209,7 +209,7 @@ func (r *Retry) ExecuteContext(ctx context.Context, fn func() error) error {
209209
if currentDelay < 0 { // Ensure delay isn't negative after jitter
210210
currentDelay = 0
211211
}
212-
// --- Wait for the delay or context cancellation ---
212+
// Wait for the delay or context cancellation
213213
select {
214214
case <-execCtx.Done():
215215
// If context is cancelled during the wait, return the context error

0 commit comments

Comments
 (0)