fixes race conditions

This commit is contained in:
Marcin Czenko 2025-10-24 04:57:10 +02:00
parent f0aa462dc0
commit 05ab491bca
No known key found for this signature in database
GPG Key ID: A0449219BDBA98AE

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -15,6 +16,7 @@ type CodexIndexDownloader struct {
codexClient CodexClientInterface codexClient CodexClientInterface
indexCid string indexCid string
filePath string filePath string
mu sync.RWMutex // protects all fields below
datasetSize int64 // stores the dataset size from the manifest datasetSize int64 // stores the dataset size from the manifest
bytesCompleted int64 // tracks download progress bytesCompleted int64 // tracks download progress
downloadComplete bool // true when file is fully downloaded and renamed downloadComplete bool // true when file is fully downloaded and renamed
@ -42,8 +44,10 @@ func (d *CodexIndexDownloader) GotManifest() <-chan struct{} {
go func() { go func() {
// Reset datasetSize to 0 to indicate no successful fetch yet // Reset datasetSize to 0 to indicate no successful fetch yet
d.mu.Lock()
d.datasetSize = 0 d.datasetSize = 0
d.downloadError = nil d.downloadError = nil
d.mu.Unlock()
// Check for cancellation before starting // Check for cancellation before starting
select { select {
@ -68,7 +72,9 @@ func (d *CodexIndexDownloader) GotManifest() <-chan struct{} {
// Fetch manifest from Codex // Fetch manifest from Codex
manifest, err := d.codexClient.FetchManifestWithContext(ctx, d.indexCid) manifest, err := d.codexClient.FetchManifestWithContext(ctx, d.indexCid)
if err != nil { if err != nil {
d.mu.Lock()
d.downloadError = err d.downloadError = err
d.mu.Unlock()
d.logger.Debug("failed to fetch manifest", d.logger.Debug("failed to fetch manifest",
zap.String("indexCid", d.indexCid), zap.String("indexCid", d.indexCid),
zap.Error(err)) zap.Error(err))
@ -79,7 +85,9 @@ func (d *CodexIndexDownloader) GotManifest() <-chan struct{} {
// Verify that the CID matches our configured indexCid // Verify that the CID matches our configured indexCid
if manifest.CID != d.indexCid { if manifest.CID != d.indexCid {
d.mu.Lock()
d.downloadError = fmt.Errorf("manifest CID mismatch: expected %s, got %s", d.indexCid, manifest.CID) d.downloadError = fmt.Errorf("manifest CID mismatch: expected %s, got %s", d.indexCid, manifest.CID)
d.mu.Unlock()
d.logger.Debug("manifest CID mismatch", d.logger.Debug("manifest CID mismatch",
zap.String("expected", d.indexCid), zap.String("expected", d.indexCid),
zap.String("got", manifest.CID)) zap.String("got", manifest.CID))
@ -87,7 +95,9 @@ func (d *CodexIndexDownloader) GotManifest() <-chan struct{} {
} }
// Store the dataset size for later use - this indicates success // Store the dataset size for later use - this indicates success
d.mu.Lock()
d.datasetSize = manifest.Manifest.DatasetSize d.datasetSize = manifest.Manifest.DatasetSize
d.mu.Unlock()
// Success! Close the channel to signal completion // Success! Close the channel to signal completion
close(ch) close(ch)
@ -98,15 +108,19 @@ func (d *CodexIndexDownloader) GotManifest() <-chan struct{} {
// GetDatasetSize returns the dataset size from the last successfully fetched manifest // GetDatasetSize returns the dataset size from the last successfully fetched manifest
func (d *CodexIndexDownloader) GetDatasetSize() int64 { func (d *CodexIndexDownloader) GetDatasetSize() int64 {
d.mu.RLock()
defer d.mu.RUnlock()
return d.datasetSize return d.datasetSize
} }
// DownloadIndexFile starts downloading the index file from Codex and writes it to the configured file path // DownloadIndexFile starts downloading the index file from Codex and writes it to the configured file path
func (d *CodexIndexDownloader) DownloadIndexFile() { func (d *CodexIndexDownloader) DownloadIndexFile() {
// Reset progress counter and completion flag // Reset progress counter and completion flag
d.mu.Lock()
d.bytesCompleted = 0 d.bytesCompleted = 0
d.downloadComplete = false d.downloadComplete = false
d.downloadError = nil d.downloadError = nil
d.mu.Unlock()
// Create cancellable context // Create cancellable context
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -129,7 +143,9 @@ func (d *CodexIndexDownloader) DownloadIndexFile() {
// This ensures atomic rename works (same filesystem) // This ensures atomic rename works (same filesystem)
tmpFile, err := os.CreateTemp(filepath.Dir(d.filePath), ".codex-download-*.tmp") tmpFile, err := os.CreateTemp(filepath.Dir(d.filePath), ".codex-download-*.tmp")
if err != nil { if err != nil {
d.mu.Lock()
d.downloadError = fmt.Errorf("failed to create temporary file: %w", err) d.downloadError = fmt.Errorf("failed to create temporary file: %w", err)
d.mu.Unlock()
d.logger.Debug("failed to create temporary file", d.logger.Debug("failed to create temporary file",
zap.String("filePath", d.filePath), zap.String("filePath", d.filePath),
zap.Error(err)) zap.Error(err))
@ -146,12 +162,15 @@ func (d *CodexIndexDownloader) DownloadIndexFile() {
progressWriter := &progressWriter{ progressWriter := &progressWriter{
writer: tmpFile, writer: tmpFile,
completed: &d.bytesCompleted, completed: &d.bytesCompleted,
mu: &d.mu,
} }
// Use CodexClient to download and stream to temporary file with context for cancellation // Use CodexClient to download and stream to temporary file with context for cancellation
err = d.codexClient.DownloadWithContext(ctx, d.indexCid, progressWriter) err = d.codexClient.DownloadWithContext(ctx, d.indexCid, progressWriter)
if err != nil { if err != nil {
d.mu.Lock()
d.downloadError = fmt.Errorf("failed to download index file: %w", err) d.downloadError = fmt.Errorf("failed to download index file: %w", err)
d.mu.Unlock()
d.logger.Debug("failed to download index file", d.logger.Debug("failed to download index file",
zap.String("indexCid", d.indexCid), zap.String("indexCid", d.indexCid),
zap.String("filePath", d.filePath), zap.String("filePath", d.filePath),
@ -162,7 +181,9 @@ func (d *CodexIndexDownloader) DownloadIndexFile() {
// Close the temporary file before renaming // Close the temporary file before renaming
if err := tmpFile.Close(); err != nil { if err := tmpFile.Close(); err != nil {
d.mu.Lock()
d.downloadError = fmt.Errorf("failed to close temporary file: %w", err) d.downloadError = fmt.Errorf("failed to close temporary file: %w", err)
d.mu.Unlock()
d.logger.Debug("failed to close temporary file", d.logger.Debug("failed to close temporary file",
zap.String("tmpPath", tmpPath), zap.String("tmpPath", tmpPath),
zap.Error(err)) zap.Error(err))
@ -172,7 +193,9 @@ func (d *CodexIndexDownloader) DownloadIndexFile() {
// Atomically rename temporary file to final destination // Atomically rename temporary file to final destination
// This ensures we only have a complete file at filePath // This ensures we only have a complete file at filePath
if err := os.Rename(tmpPath, d.filePath); err != nil { if err := os.Rename(tmpPath, d.filePath); err != nil {
d.mu.Lock()
d.downloadError = fmt.Errorf("failed to rename temporary file to final destination: %w", err) d.downloadError = fmt.Errorf("failed to rename temporary file to final destination: %w", err)
d.mu.Unlock()
d.logger.Debug("failed to rename temporary file to final destination", d.logger.Debug("failed to rename temporary file to final destination",
zap.String("tmpPath", tmpPath), zap.String("tmpPath", tmpPath),
zap.String("filePath", d.filePath), zap.String("filePath", d.filePath),
@ -181,27 +204,37 @@ func (d *CodexIndexDownloader) DownloadIndexFile() {
} }
// Mark download as complete only after successful rename // Mark download as complete only after successful rename
d.mu.Lock()
d.downloadComplete = true d.downloadComplete = true
d.mu.Unlock()
}() }()
} }
// BytesCompleted returns the number of bytes downloaded so far // BytesCompleted returns the number of bytes downloaded so far
func (d *CodexIndexDownloader) BytesCompleted() int64 { func (d *CodexIndexDownloader) BytesCompleted() int64 {
d.mu.RLock()
defer d.mu.RUnlock()
return d.bytesCompleted return d.bytesCompleted
} }
// IsDownloadComplete returns true when the file has been fully downloaded and saved to disk // IsDownloadComplete returns true when the file has been fully downloaded and saved to disk
func (d *CodexIndexDownloader) IsDownloadComplete() bool { func (d *CodexIndexDownloader) IsDownloadComplete() bool {
d.mu.RLock()
defer d.mu.RUnlock()
return d.downloadComplete return d.downloadComplete
} }
// GetError returns the last error that occurred during manifest fetch or download, or nil if no error // GetError returns the last error that occurred during manifest fetch or download, or nil if no error
func (d *CodexIndexDownloader) GetError() error { func (d *CodexIndexDownloader) GetError() error {
d.mu.RLock()
defer d.mu.RUnlock()
return d.downloadError return d.downloadError
} }
// Length returns the total dataset size (equivalent to torrent file length) // Length returns the total dataset size (equivalent to torrent file length)
func (d *CodexIndexDownloader) Length() int64 { func (d *CodexIndexDownloader) Length() int64 {
d.mu.RLock()
defer d.mu.RUnlock()
return d.datasetSize return d.datasetSize
} }
@ -209,12 +242,15 @@ func (d *CodexIndexDownloader) Length() int64 {
type progressWriter struct { type progressWriter struct {
writer io.Writer writer io.Writer
completed *int64 completed *int64
mu *sync.RWMutex
} }
func (pw *progressWriter) Write(p []byte) (n int, err error) { func (pw *progressWriter) Write(p []byte) (n int, err error) {
n, err = pw.writer.Write(p) n, err = pw.writer.Write(p)
if n > 0 { if n > 0 {
pw.mu.Lock()
*pw.completed += int64(n) *pw.completed += int64(n)
pw.mu.Unlock()
} }
return n, err return n, err
} }