Skip to content

Commit 46428ca

Browse files
authored
fix(config): race condition in createWorkRoot() (#2338)
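Creating a temporary directory whose name is derived from a timestamp is inherently racy: two runs started within the same second pick the same name, and the separate existence check and mkdir leave a window in which another process can create the directory first. Use the standard library function (os.MkdirTemp) to create temporary directories, and relax the tests to check only what matters (that the directory is created under the temp root with the expected name prefix).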
1 parent f2d1a10 commit 46428ca

File tree

2 files changed

+21
-60
lines changed

2 files changed

+21
-60
lines changed

internal/config/config.go

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"path/filepath"
2525
"regexp"
2626
"strings"
27-
"time"
2827
)
2928

3029
const (
@@ -60,7 +59,6 @@ const (
6059

6160
// are variables so it can be replaced during testing.
6261
var (
63-
now = time.Now
6462
tempDir = os.TempDir
6563
currentUser = user.Current
6664
)
@@ -272,19 +270,9 @@ func (c *Config) createWorkRoot() error {
272270
slog.Info("Using specified working directory", "dir", c.WorkRoot)
273271
return nil
274272
}
275-
t := now()
276-
path := filepath.Join(tempDir(), fmt.Sprintf("librarian-%s", formatTimestamp(t)))
277-
278-
_, err := os.Stat(path)
279-
switch {
280-
case os.IsNotExist(err):
281-
if err := os.Mkdir(path, 0755); err != nil {
282-
return fmt.Errorf("unable to create temporary working directory '%s': %w", path, err)
283-
}
284-
case err == nil:
285-
return fmt.Errorf("temporary working directory already exists: %s", path)
286-
default:
287-
return fmt.Errorf("unable to check directory '%s': %w", path, err)
273+
path, err := os.MkdirTemp(tempDir(), "librarian-*")
274+
if err != nil {
275+
return err
288276
}
289277

290278
slog.Info("Temporary working directory", "dir", path)
@@ -364,8 +352,3 @@ func validateHostMount(hostMount, defaultValue string) (bool, error) {
364352

365353
return true, nil
366354
}
367-
368-
// formatTimestamp renders t as a compact ISO 8601 basic-format UTC timestamp,
// for example "20240115T103045Z".
//
// The trailing "Z" in the layout is a literal character, not a Go zone
// directive, so t is converted to UTC before formatting; otherwise a local
// time would be emitted with a suffix that falsely claims it is UTC.
func formatTimestamp(t time.Time) string {
	const layout = "20060102T150405Z" // yyyyMMdd'T'HHmmss'Z' in Go reference-time form
	return t.UTC().Format(layout)
}

internal/config/config_test.go

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@ package config
1616

1717
import (
1818
"errors"
19-
"fmt"
2019
"os"
2120
"os/user"
2221
"path/filepath"
2322
"strings"
2423
"testing"
25-
"time"
2624

2725
"github.com/google/go-cmp/cmp"
2826
)
@@ -253,23 +251,17 @@ func TestIsValid(t *testing.T) {
253251
}
254252

255253
func TestCreateWorkRoot(t *testing.T) {
256-
timestamp := time.Now()
257254
localTempDir := t.TempDir()
258-
now = func() time.Time {
259-
return timestamp
260-
}
261255
tempDir = func() string {
262256
return localTempDir
263257
}
264-
defer func() {
265-
now = time.Now
258+
t.Cleanup(func() {
266259
tempDir = os.TempDir
267-
}()
260+
})
268261
for _, test := range []struct {
269262
name string
270263
config *Config
271264
setup func(t *testing.T) (string, func())
272-
errMsg string
273265
}{
274266
{
275267
name: "configured root",
@@ -294,53 +286,44 @@ func TestCreateWorkRoot(t *testing.T) {
294286
name: "without override, new dir",
295287
config: &Config{},
296288
setup: func(t *testing.T) (string, func()) {
297-
expectedPath := filepath.Join(localTempDir, fmt.Sprintf("librarian-%s", formatTimestamp(timestamp)))
289+
expectedPath := filepath.Join(localTempDir, "librarian-")
298290
return expectedPath, func() {
299291
if err := os.RemoveAll(expectedPath); err != nil {
300292
t.Errorf("os.RemoveAll(%q) = %v; want nil", expectedPath, err)
301293
}
302294
}
303295
},
304296
},
305-
{
306-
name: "without override, dir exists",
307-
config: &Config{},
308-
setup: func(t *testing.T) (string, func()) {
309-
expectedPath := filepath.Join(localTempDir, fmt.Sprintf("librarian-%s", formatTimestamp(timestamp)))
310-
if err := os.Mkdir(expectedPath, 0755); err != nil {
311-
t.Fatalf("failed to create test dir: %v", err)
312-
}
313-
return expectedPath, func() {
314-
if err := os.RemoveAll(expectedPath); err != nil {
315-
t.Errorf("os.RemoveAll(%q) = %v; want nil", expectedPath, err)
316-
}
317-
}
318-
},
319-
errMsg: "working directory already exists",
320-
},
321297
} {
322298
t.Run(test.name, func(t *testing.T) {
323299
want, cleanup := test.setup(t)
324300
defer cleanup()
325301

326-
err := test.config.createWorkRoot()
327-
if test.errMsg != "" {
328-
if !strings.Contains(err.Error(), test.errMsg) {
329-
t.Errorf("createWorkRoot() = %q, want contains %q", err, test.errMsg)
330-
}
331-
return
332-
} else if err != nil {
302+
if err := test.config.createWorkRoot(); err != nil {
333303
t.Errorf("createWorkRoot() got unexpected error: %v", err)
334304
return
335305
}
336306

337-
if test.config.WorkRoot != want {
307+
if !strings.HasPrefix(test.config.WorkRoot, want) {
338308
t.Errorf("createWorkRoot() = %v, want %v", test.config.WorkRoot, want)
339309
}
340310
})
341311
}
342312
}
343313

314+
func TestCreateWorkRootError(t *testing.T) {
315+
tempDir = func() string {
316+
return filepath.Join("--invalid--", "--not-a-directory--")
317+
}
318+
t.Cleanup(func() {
319+
tempDir = os.TempDir
320+
})
321+
config := &Config{}
322+
if err := config.createWorkRoot(); err == nil {
323+
t.Errorf("createWorkRoot() expected an error got: %v", config.WorkRoot)
324+
}
325+
}
326+
344327
func TestDeriveRepo(t *testing.T) {
345328
for _, test := range []struct {
346329
name string
@@ -417,12 +400,7 @@ func TestSetDefaults(t *testing.T) {
417400
}, nil
418401
}
419402

420-
timestamp := time.Now()
421-
now = func() time.Time {
422-
return timestamp
423-
}
424403
t.Cleanup(func() {
425-
now = time.Now
426404
currentUser = user.Current
427405
})
428406
for _, test := range []struct {

0 commit comments

Comments
 (0)