diff --git a/cmd/torrent/download.go b/cmd/torrent/download.go index 0b5d4d58..123b3659 100644 --- a/cmd/torrent/download.go +++ b/cmd/torrent/download.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "expvar" "fmt" @@ -15,7 +16,6 @@ import ( "time" "github.com/anacrolix/log" - "github.com/anacrolix/missinggo/v2" "github.com/anacrolix/tagflag" "github.com/anacrolix/torrent" "github.com/anacrolix/torrent/iplist" @@ -89,7 +89,7 @@ func resolveTestPeers(addrs []string) (ret []torrent.PeerInfo) { return } -func addTorrents(client *torrent.Client, flags downloadFlags) error { +func addTorrents(ctx context.Context, client *torrent.Client, flags downloadFlags, wg *sync.WaitGroup) error { testPeers := resolveTestPeers(flags.TestPeer) for _, arg := range flags.Torrent { t, err := func() (*torrent.Torrent, error) { @@ -137,10 +137,30 @@ func addTorrents(client *torrent.Client, flags downloadFlags) error { torrentBar(t, flags.PieceStates) } t.AddPeers(testPeers) + wg.Add(1) go func() { - <-t.GotInfo() + defer wg.Done() + select { + case <-ctx.Done(): + return + case <-t.GotInfo(): + } + if flags.SaveMetainfos { + path := fmt.Sprintf("%v.torrent", t.InfoHash().HexString()) + err := writeMetainfoToFile(t.Metainfo(), path) + if err == nil { + log.Printf("wrote %q", path) + } else { + log.Printf("error writing %q: %v", path, err) + } + } if len(flags.File) == 0 { t.DownloadAll() + wg.Add(1) + go func() { + defer wg.Done() + waitForPieces(ctx, t, 0, t.NumPieces()) + }() if flags.LinearDiscard { r := t.NewReader() io.Copy(io.Discard, r) @@ -150,6 +170,11 @@ func addTorrents(client *torrent.Client, flags downloadFlags) error { for _, f := range t.Files() { for _, fileArg := range flags.File { if f.DisplayPath() == fileArg { + wg.Add(1) + go func() { + defer wg.Done() + waitForPieces(ctx, t, f.BeginPieceIndex(), f.EndPieceIndex()) + }() f.Download() if flags.LinearDiscard { r := f.NewReader() @@ -167,12 +192,52 @@ func addTorrents(client *torrent.Client, flags downloadFlags) error { return nil } +func waitForPieces(ctx context.Context, t *torrent.Torrent, beginIndex, endIndex int) { + sub := t.SubscribePieceStateChanges() + defer sub.Close() + pending := make(map[int]struct{}) + for i := beginIndex; i < endIndex; i++ { + pending[i] = struct{}{} + } + expected := storage.Completion{ + Complete: true, + Ok: true, + } + for { + select { + case ev := <-sub.Values: + if ev.Completion == expected { + delete(pending, ev.Index) + } + if len(pending) == 0 { + return + } + case <-ctx.Done(): + return + } + } +} + +func writeMetainfoToFile(mi metainfo.MetaInfo, path string) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0640) + if err != nil { + return err + } + defer f.Close() + err = mi.Write(f) + if err != nil { + return err + } + return f.Close() +} + type downloadFlags struct { Debug bool DownloadCmd } type DownloadCmd struct { + SaveMetainfos bool Mmap bool `help:"memory-map torrent data"` Seed bool `help:"seed after download is complete"` Addr string `help:"network listen addr"` @@ -211,15 +276,6 @@ func statsEnabled(flags downloadFlags) bool { return flags.Stats } -func exitSignalHandlers(notify *missinggo.SynchronizedEvent) { - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) - for { - log.Printf("close signal received: %+v", <-c) - notify.Set() - } -} - func downloadErr(flags downloadFlags) error { clientConfig := torrent.NewDefaultClientConfig() clientConfig.DisableWebseeds = flags.DisableWebseeds @@ -269,35 +325,29 @@ func downloadErr(flags downloadFlags) error { } clientConfig.MaxUnverifiedBytes = flags.MaxUnverifiedBytes.Int64() - var stop missinggo.SynchronizedEvent - defer func() { - stop.Set() - }() + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() client, err := torrent.NewClient(clientConfig) if err != nil { return fmt.Errorf("creating client: %w", err) } - var clientClose sync.Once // In certain situations, close was being called more than once. - defer clientClose.Do(func() { client.Close() }) - go exitSignalHandlers(&stop) - go func() { - <-stop.C() - clientClose.Do(func() { client.Close() }) - }() + defer client.Close() // Write status on the root path on the default HTTP muxer. This will be bound to localhost // somewhere if GOPPROF is set, thanks to the envpprof import. http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { client.WriteStatus(w) }) - err = addTorrents(client, flags) - started := time.Now() + var wg sync.WaitGroup + err = addTorrents(ctx, client, flags, &wg) if err != nil { return fmt.Errorf("adding torrents: %w", err) } + started := time.Now() defer outputStats(client, flags) - if client.WaitAll() { + wg.Wait() + if ctx.Err() == nil { log.Print("downloaded ALL the torrents") } else { err = errors.New("y u no complete torrents?!") @@ -314,7 +364,7 @@ func downloadErr(flags downloadFlags) error { log.Print("no torrents to seed") } else { outputStats(client, flags) - <-stop.C() + <-ctx.Done() } } spew.Dump(expvar.Get("torrent").(*expvar.Map).Get("chunks received"))