Skip to content

Commit c7a9e93

Browse files
authored
[-] resolve race conditions in YAML file operations, fixes #1198 (#1209)
Fixes race conditions in YAML-based metrics and sources I/O by adding an embedded `sync.Mutex` to serialize read–modify–write cycles so concurrent updates are not lost. Public methods now lock and call unlocked internal helpers (`getMetrics/getSources`, `writeMetrics/writeSources`), and new concurrent tests confirm the fix.
1 parent a8e7e00 commit c7a9e93

4 files changed

Lines changed: 214 additions & 28 deletions

File tree

internal/metrics/yaml.go

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"path/filepath"
1010
"strings"
11+
"sync"
1112

1213
"gopkg.in/yaml.v3"
1314
)
@@ -25,17 +26,34 @@ func NewYAMLMetricReaderWriter(ctx context.Context, path string) (ReaderWriter,
2526
type fileMetricReader struct {
2627
ctx context.Context
2728
path string
29+
sync.Mutex
2830
}
2931

32+
// WriteMetrics writes metrics to file with locking
3033
func (fmr *fileMetricReader) WriteMetrics(metricDefs *Metrics) error {
34+
fmr.Lock()
35+
defer fmr.Unlock()
36+
return fmr.writeMetrics(metricDefs)
37+
}
38+
39+
// writeMetrics writes metrics to file without locking (internal use only)
40+
func (fmr *fileMetricReader) writeMetrics(metricDefs *Metrics) error {
3141
yamlData, _ := yaml.Marshal(metricDefs)
3242
return os.WriteFile(fmr.path, yamlData, 0644)
3343
}
3444

3545
//go:embed metrics.yaml
3646
var defaultMetricsYAML []byte
3747

48+
// GetMetrics reads metrics from file or returns default metrics if path is empty with locking
3849
func (fmr *fileMetricReader) GetMetrics() (metrics *Metrics, err error) {
50+
fmr.Lock()
51+
defer fmr.Unlock()
52+
return fmr.getMetrics()
53+
}
54+
55+
// getMetrics reads metrics from file or returns default metrics if path is empty without locking (internal use only)
56+
func (fmr *fileMetricReader) getMetrics() (metrics *Metrics, err error) {
3957
metrics = &Metrics{MetricDefs{}, PresetDefs{}}
4058
if fmr.path == "" {
4159
err = yaml.Unmarshal(defaultMetricsYAML, metrics)
@@ -57,19 +75,20 @@ func (fmr *fileMetricReader) GetMetrics() (metrics *Metrics, err error) {
5775
return nil
5876
}
5977
var m *Metrics
60-
if m, err = fmr.getMetrics(path); err == nil {
78+
if m, err = fmr.loadMetricsFromFile(path); err == nil {
6179
maps.Copy(metrics.PresetDefs, m.PresetDefs)
6280
maps.Copy(metrics.MetricDefs, m.MetricDefs)
6381
}
6482
return err
6583
})
6684
case mode.IsRegular():
67-
metrics, err = fmr.getMetrics(fmr.path)
85+
metrics, err = fmr.loadMetricsFromFile(fmr.path)
6886
}
6987
return
7088
}
7189

72-
func (fmr *fileMetricReader) getMetrics(metricsFilePath string) (metrics *Metrics, err error) {
90+
// loadMetricsFromFile reads metrics from a single YAML file
91+
func (fmr *fileMetricReader) loadMetricsFromFile(metricsFilePath string) (metrics *Metrics, err error) {
7392
var yamlFile []byte
7493
if yamlFile, err = os.ReadFile(metricsFilePath); err != nil {
7594
return
@@ -79,26 +98,35 @@ func (fmr *fileMetricReader) getMetrics(metricsFilePath string) (metrics *Metric
7998
return
8099
}
81100

101+
// DeleteMetric deletes a metric by name and writes the updated metrics back to file
82102
func (fmr *fileMetricReader) DeleteMetric(metricName string) error {
83-
metrics, err := fmr.GetMetrics()
103+
fmr.Lock()
104+
defer fmr.Unlock()
105+
metrics, err := fmr.getMetrics()
84106
if err != nil {
85107
return err
86108
}
87109
delete(metrics.MetricDefs, metricName)
88-
return fmr.WriteMetrics(metrics)
110+
return fmr.writeMetrics(metrics)
89111
}
90112

113+
// UpdateMetric updates an existing metric or creates it if it doesn't exist, then writes the updated metrics back to file
91114
func (fmr *fileMetricReader) UpdateMetric(metricName string, metric Metric) error {
92-
metrics, err := fmr.GetMetrics()
115+
fmr.Lock()
116+
defer fmr.Unlock()
117+
metrics, err := fmr.getMetrics()
93118
if err != nil {
94119
return err
95120
}
96121
metrics.MetricDefs[metricName] = metric
97-
return fmr.WriteMetrics(metrics)
122+
return fmr.writeMetrics(metrics)
98123
}
99124

125+
// CreateMetric creates a new metric if it doesn't already exist, then writes the updated metrics back to file
100126
func (fmr *fileMetricReader) CreateMetric(metricName string, metric Metric) error {
101-
metrics, err := fmr.GetMetrics()
127+
fmr.Lock()
128+
defer fmr.Unlock()
129+
metrics, err := fmr.getMetrics()
102130
if err != nil {
103131
return err
104132
}
@@ -107,29 +135,38 @@ func (fmr *fileMetricReader) CreateMetric(metricName string, metric Metric) erro
107135
return ErrMetricExists
108136
}
109137
metrics.MetricDefs[metricName] = metric
110-
return fmr.WriteMetrics(metrics)
138+
return fmr.writeMetrics(metrics)
111139
}
112140

141+
// DeletePreset deletes a preset by name and writes the updated metrics back to file
113142
func (fmr *fileMetricReader) DeletePreset(presetName string) error {
114-
metrics, err := fmr.GetMetrics()
143+
fmr.Lock()
144+
defer fmr.Unlock()
145+
metrics, err := fmr.getMetrics()
115146
if err != nil {
116147
return err
117148
}
118149
delete(metrics.PresetDefs, presetName)
119-
return fmr.WriteMetrics(metrics)
150+
return fmr.writeMetrics(metrics)
120151
}
121152

153+
// UpdatePreset updates an existing preset or creates it if it doesn't exist, then writes the updated metrics back to file
122154
func (fmr *fileMetricReader) UpdatePreset(presetName string, preset Preset) error {
123-
metrics, err := fmr.GetMetrics()
155+
fmr.Lock()
156+
defer fmr.Unlock()
157+
metrics, err := fmr.getMetrics()
124158
if err != nil {
125159
return err
126160
}
127161
metrics.PresetDefs[presetName] = preset
128-
return fmr.WriteMetrics(metrics)
162+
return fmr.writeMetrics(metrics)
129163
}
130164

165+
// CreatePreset creates a new preset if it doesn't already exist, then writes the updated metrics back to file
131166
func (fmr *fileMetricReader) CreatePreset(presetName string, preset Preset) error {
132-
metrics, err := fmr.GetMetrics()
167+
fmr.Lock()
168+
defer fmr.Unlock()
169+
metrics, err := fmr.getMetrics()
133170
if err != nil {
134171
return err
135172
}
@@ -138,5 +175,5 @@ func (fmr *fileMetricReader) CreatePreset(presetName string, preset Preset) erro
138175
return ErrPresetExists
139176
}
140177
metrics.PresetDefs[presetName] = preset
141-
return fmr.WriteMetrics(metrics)
178+
return fmr.writeMetrics(metrics)
142179
}

internal/metrics/yaml_test.go

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package metrics_test
22

33
import (
4+
"fmt"
45
"os"
56
"path/filepath"
7+
"sync"
68
"testing"
9+
"time"
710

811
"github.com/cybertec-postgresql/pgwatch/v5/internal/metrics"
912
"github.com/stretchr/testify/assert"
@@ -390,4 +393,84 @@ func TestMetricsDir(t *testing.T) {
390393
a.Equal("preset1 description", ms.PresetDefs["preset1"].Description)
391394
a.Equal("metric2 description", ms.MetricDefs["metric2"].Description)
392395
a.Equal("preset2 description", ms.PresetDefs["preset2"].Description)
393-
}
396+
}
397+
398+
func TestConcurrentMetricUpdates(t *testing.T) {
399+
a := assert.New(t)
400+
tempDir := t.TempDir()
401+
tempFile := filepath.Join(tempDir, "metrics.yaml")
402+
403+
yamlrw, err := metrics.NewYAMLMetricReaderWriter(ctx, tempFile)
404+
a.NoError(err)
405+
406+
// Create initial empty metrics file
407+
initialMetrics := &metrics.Metrics{
408+
MetricDefs: make(map[string]metrics.Metric),
409+
PresetDefs: make(map[string]metrics.Preset),
410+
}
411+
err = yamlrw.WriteMetrics(initialMetrics)
412+
a.NoError(err)
413+
414+
numGoroutines := 10
415+
var wg sync.WaitGroup
416+
417+
// Each goroutine will add a unique metric
418+
for id := range numGoroutines {
419+
wg.Go(func() {
420+
metricName := fmt.Sprintf("metric_%d", id)
421+
testMetric := metrics.Metric{
422+
Description: fmt.Sprintf("Test metric %d", id),
423+
SQLs: map[int]string{
424+
1: fmt.Sprintf("SELECT %d", id),
425+
},
426+
}
427+
time.Sleep(time.Millisecond * time.Duration(id%3))
428+
err := yamlrw.UpdateMetric(metricName, testMetric)
429+
a.NoError(err, "Error during concurrent update")
430+
})
431+
}
432+
wg.Wait()
433+
434+
finalMetrics, err := yamlrw.GetMetrics()
435+
a.NoError(err)
436+
a.Equal(numGoroutines, len(finalMetrics.MetricDefs), "Some updates were lost due to race condition!")
437+
}
438+
439+
func TestConcurrentPresetUpdates(t *testing.T) {
440+
a := assert.New(t)
441+
tempDir := t.TempDir()
442+
tempFile := filepath.Join(tempDir, "metrics.yaml")
443+
444+
yamlrw, err := metrics.NewYAMLMetricReaderWriter(ctx, tempFile)
445+
a.NoError(err)
446+
447+
// Create initial empty metrics file
448+
initialMetrics := &metrics.Metrics{
449+
MetricDefs: make(map[string]metrics.Metric),
450+
PresetDefs: make(map[string]metrics.Preset),
451+
}
452+
err = yamlrw.WriteMetrics(initialMetrics)
453+
a.NoError(err)
454+
455+
numGoroutines := 10
456+
var wg sync.WaitGroup
457+
458+
for id := range numGoroutines {
459+
wg.Go(func() {
460+
presetName := fmt.Sprintf("preset_%d", id)
461+
testPreset := metrics.Preset{
462+
Description: fmt.Sprintf("Test preset %d", id),
463+
Metrics: map[string]float64{fmt.Sprintf("metric_%d", id): 60},
464+
}
465+
time.Sleep(time.Millisecond * time.Duration(id%3))
466+
err := yamlrw.UpdatePreset(presetName, testPreset)
467+
a.NoError(err, "Error during concurrent update")
468+
})
469+
}
470+
wg.Wait()
471+
472+
// ensure all presets were saved
473+
finalMetrics, err := yamlrw.GetMetrics()
474+
a.NoError(err)
475+
a.Equal(numGoroutines, len(finalMetrics.PresetDefs), "Some updates were lost due to race condition!")
476+
}

internal/sources/yaml.go

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"path/filepath"
1010
"slices"
1111
"strings"
12+
"sync"
1213

1314
"gopkg.in/yaml.v3"
1415
)
@@ -23,30 +24,45 @@ func NewYAMLSourcesReaderWriter(ctx context.Context, path string) (ReaderWriter,
2324
type fileSourcesReaderWriter struct {
2425
ctx context.Context
2526
path string
27+
sync.Mutex
2628
}
2729

30+
// WriteSources writes sources to file with locking
2831
func (fcr *fileSourcesReaderWriter) WriteSources(mds Sources) error {
32+
fcr.Lock()
33+
defer fcr.Unlock()
34+
return fcr.writeSources(mds)
35+
}
36+
37+
// writeSources writes sources to file without locking (internal use only)
38+
func (fcr *fileSourcesReaderWriter) writeSources(mds Sources) error {
2939
yamlData, _ := yaml.Marshal(mds)
3040
return os.WriteFile(fcr.path, yamlData, 0644)
3141
}
3242

43+
// UpdateSource updates an existing source or creates it if it doesn't exist, then writes the updated sources back to file
3344
func (fcr *fileSourcesReaderWriter) UpdateSource(md Source) error {
34-
dbs, err := fcr.GetSources()
45+
fcr.Lock()
46+
defer fcr.Unlock()
47+
dbs, err := fcr.getSources()
3548
if err != nil {
3649
return err
3750
}
3851
for i, db := range dbs {
3952
if db.Name == md.Name {
4053
dbs[i] = md
41-
return fcr.WriteSources(dbs)
54+
return fcr.writeSources(dbs)
4255
}
4356
}
4457
dbs = append(dbs, md)
45-
return fcr.WriteSources(dbs)
58+
return fcr.writeSources(dbs)
4659
}
4760

61+
// CreateSource creates a new source if it doesn't already exist, then writes the updated sources back to file
4862
func (fcr *fileSourcesReaderWriter) CreateSource(md Source) error {
49-
dbs, err := fcr.GetSources()
63+
fcr.Lock()
64+
defer fcr.Unlock()
65+
dbs, err := fcr.getSources()
5066
if err != nil {
5167
return err
5268
}
@@ -57,19 +73,30 @@ func (fcr *fileSourcesReaderWriter) CreateSource(md Source) error {
5773
}
5874
}
5975
dbs = append(dbs, md)
60-
return fcr.WriteSources(dbs)
76+
return fcr.writeSources(dbs)
6177
}
6278

79+
// DeleteSource deletes a source by name and writes the updated sources back to file
6380
func (fcr *fileSourcesReaderWriter) DeleteSource(name string) error {
64-
dbs, err := fcr.GetSources()
81+
fcr.Lock()
82+
defer fcr.Unlock()
83+
dbs, err := fcr.getSources()
6584
if err != nil {
6685
return err
6786
}
6887
dbs = slices.DeleteFunc(dbs, func(md Source) bool { return md.Name == name })
69-
return fcr.WriteSources(dbs)
88+
return fcr.writeSources(dbs)
7089
}
7190

91+
// GetSources reads sources from file with locking
7292
func (fcr *fileSourcesReaderWriter) GetSources() (dbs Sources, err error) {
93+
fcr.Lock()
94+
defer fcr.Unlock()
95+
return fcr.getSources()
96+
}
97+
98+
// getSources reads sources from file without locking (internal use only)
99+
func (fcr *fileSourcesReaderWriter) getSources() (dbs Sources, err error) {
73100
var fi fs.FileInfo
74101
if fi, err = os.Stat(fcr.path); err != nil {
75102
return
@@ -85,21 +112,22 @@ func (fcr *fileSourcesReaderWriter) GetSources() (dbs Sources, err error) {
85112
return nil
86113
}
87114
var mdbs Sources
88-
if mdbs, err = fcr.getSources(path); err == nil {
115+
if mdbs, err = fcr.loadSourcesFromFile(path); err == nil {
89116
dbs = append(dbs, mdbs...)
90117
}
91118
return err
92119
})
93120
case mode.IsRegular():
94-
dbs, err = fcr.getSources(fcr.path)
121+
dbs, err = fcr.loadSourcesFromFile(fcr.path)
95122
}
96123
if err != nil {
97124
return nil, err
98125
}
99126
return dbs.Validate()
100127
}
101128

102-
func (fcr *fileSourcesReaderWriter) getSources(configFilePath string) (dbs Sources, err error) {
129+
// loadSourcesFromFile reads sources from a single YAML file, expands environment variables, and returns them
130+
func (fcr *fileSourcesReaderWriter) loadSourcesFromFile(configFilePath string) (dbs Sources, err error) {
103131
var yamlFile []byte
104132
if yamlFile, err = os.ReadFile(configFilePath); err != nil {
105133
return

0 commit comments

Comments
 (0)