Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support to skip tls when download file #348

Merged
merged 1 commit into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions cmd/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func newGetCmd(ctx context.Context) (cmd *cobra.Command) {
"Same with option --accept-preRelease")
flags.BoolVarP(&opt.Force, "force", "f", false, "Overwrite the exist file if this is true")
flags.IntVarP(&opt.Mod, "mod", "", -1, "The file permission, -1 means using the system default")
flags.BoolVarP(&opt.SkipTLS, "skip-tls", "k", false, "Skip the TLS")

flags.IntVarP(&opt.Timeout, "time", "", 10,
`The default timeout in seconds with the HTTP request`)
Expand Down Expand Up @@ -106,6 +107,7 @@ type downloadOption struct {
Magnet bool
Force bool
Mod int
SkipTLS bool

ContinueAt int64

Expand Down Expand Up @@ -297,24 +299,44 @@ func (o *downloadOption) runE(cmd *cobra.Command, args []string) (err error) {
targetURL = strings.Replace(targetURL, "raw.githubusercontent.com", fmt.Sprintf("%s/https://raw.githubusercontent.com", o.ProxyGitHub), 1)
}
logger.Printf("start to download from %s\n", targetURL)
var suggestedFilenameAware net.SuggestedFilenameAware
if o.Thread <= 1 {
downloader := &net.ContinueDownloader{}
suggestedFilenameAware = downloader
downloader.WithoutProxy(o.NoProxy).
WithRoundTripper(o.RoundTripper)
WithRoundTripper(o.RoundTripper).
WithInsecureSkipVerify(o.SkipTLS)
err = downloader.DownloadWithContinue(targetURL, o.Output, o.ContinueAt, -1, 0, o.ShowProgress)
} else {
downloader := &net.MultiThreadDownloader{}
suggestedFilenameAware = downloader
downloader.WithKeepParts(o.KeepPart).
WithShowProgress(o.ShowProgress).
WithoutProxy(o.NoProxy).
WithRoundTripper(o.RoundTripper)
WithRoundTripper(o.RoundTripper).
WithInsecureSkipVerify(o.SkipTLS)
err = downloader.Download(targetURL, o.Output, o.Thread)
}

// set file permission
if o.Mod != -1 {
err = sysos.Chmod(o.Output, fs.FileMode(o.Mod))
}

if err == nil {
logger.Printf("downloaded: %s\n", o.Output)
}

if suggested := suggestedFilenameAware.GetSuggestedFilename(); suggested != "" {
confirm := &survey.Confirm{
Message: fmt.Sprintf("Do you want to rename filename from '%s' to '%s'?", o.Output, suggested),
}
var yes bool
if confirmErr := survey.AskOne(confirm, &yes); confirmErr == nil && yes {
fmt.Println("rename")
err = sysos.Rename(o.Output, suggested)
}
}
return
}

Expand Down
187 changes: 47 additions & 140 deletions pkg/net/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"os"
"path"
"strconv"
"sync"
"strings"
"time"

"github.com/linuxsuren/http-downloader/pkg/common"
Expand Down Expand Up @@ -50,6 +50,7 @@ type HTTPDownloader struct {
Debug bool
RoundTripper http.RoundTripper
progressIndicator *ProgressIndicator
suggestedFilename string
}

// SetProxy set the proxy for a http
Expand Down Expand Up @@ -150,6 +151,14 @@ func (h *HTTPDownloader) DownloadFile() error {
}
}

if disposition, ok := resp.Header["Content-Disposition"]; ok && len(disposition) >= 1 {
h.suggestedFilename = strings.TrimPrefix(disposition[0], `filename="`)
h.suggestedFilename = strings.TrimSuffix(h.suggestedFilename, `"`)
if h.suggestedFilename == filepath {
h.suggestedFilename = ""
}
}

// pre-hook before get started to download file
if h.PreStart != nil && !h.PreStart(resp) {
return nil
Expand Down Expand Up @@ -192,127 +201,15 @@ func (h *HTTPDownloader) DownloadFile() error {
return err
}

// DownloadFileWithMultipleThread downloads the files with multiple threads
func DownloadFileWithMultipleThread(targetURL, targetFilePath string, thread int, showProgress bool) (err error) {
return DownloadFileWithMultipleThreadKeepParts(targetURL, targetFilePath, thread, false, showProgress)
// GetSuggestedFilename returns the suggested filename which comes from the HTTP response header.
// Returns empty string if the filename is same with the given name.
func (h *HTTPDownloader) GetSuggestedFilename() string {
return h.suggestedFilename
}

// MultiThreadDownloader is a download with multi-thread
type MultiThreadDownloader struct {
noProxy bool
keepParts, showProgress bool

roundTripper http.RoundTripper
}

// WithoutProxy indicates not use HTTP proxy
func (d *MultiThreadDownloader) WithoutProxy(noProxy bool) *MultiThreadDownloader {
d.noProxy = noProxy
return d
}

// WithShowProgress indicate if show the download progress
func (d *MultiThreadDownloader) WithShowProgress(showProgress bool) *MultiThreadDownloader {
d.showProgress = showProgress
return d
}

// WithKeepParts indicates if keeping the part files
func (d *MultiThreadDownloader) WithKeepParts(keepParts bool) *MultiThreadDownloader {
d.keepParts = keepParts
return d
}

// WithRoundTripper sets RoundTripper
func (d *MultiThreadDownloader) WithRoundTripper(roundTripper http.RoundTripper) *MultiThreadDownloader {
d.roundTripper = roundTripper
return d
}

// Download starts to download the target URL
func (d *MultiThreadDownloader) Download(targetURL, targetFilePath string, thread int) (err error) {
// get the total size of the target file
var total int64
var rangeSupport bool
if total, rangeSupport, err = DetectSizeWithRoundTripper(targetURL, targetFilePath, true, d.noProxy, d.roundTripper); err != nil {
return
}

if rangeSupport {
unit := total / int64(thread)
offset := total - unit*int64(thread)
var wg sync.WaitGroup
var partItems []string
var m sync.Mutex

defer func() {
// remove all partial files
for _, part := range partItems {
_ = os.RemoveAll(part)
}
}()

fmt.Printf("start to download with %d threads, size: %d, unit: %d\n", thread, total, unit)
for i := 0; i < thread; i++ {
wg.Add(1)
go func(index int, wg *sync.WaitGroup) {
defer wg.Done()
output := fmt.Sprintf("%s-%d", targetFilePath, index)

m.Lock()
partItems = append(partItems, output)
m.Unlock()

end := unit*int64(index+1) - 1
if index == thread-1 {
// this is the last part
end += offset
}
start := unit * int64(index)

downloader := &ContinueDownloader{}
downloader.WithoutProxy(d.noProxy).
WithRoundTripper(d.roundTripper)
if downloadErr := downloader.DownloadWithContinue(targetURL, output,
int64(index), start, end, d.showProgress); downloadErr != nil {
fmt.Println(downloadErr)
}
}(i, &wg)
}

wg.Wait()
ProgressIndicator{}.Close()

// concat all these partial files
var f *os.File
if f, err = os.OpenFile(targetFilePath, os.O_CREATE|os.O_WRONLY, 0600); err == nil {
defer func() {
_ = f.Close()
}()

for i := 0; i < thread; i++ {
partFile := fmt.Sprintf("%s-%d", targetFilePath, i)
if data, ferr := os.ReadFile(partFile); ferr == nil {
if _, err = f.Write(data); err != nil {
err = fmt.Errorf("failed to write file: '%s'", partFile)
break
} else if !d.keepParts {
_ = os.RemoveAll(partFile)
}
} else {
err = fmt.Errorf("failed to read file: '%s'", partFile)
break
}
}
}
} else {
fmt.Println("cannot download it using multiple threads, failed to one")
downloader := &ContinueDownloader{}
downloader.WithoutProxy(d.noProxy)
downloader.WithRoundTripper(d.roundTripper)
err = downloader.DownloadWithContinue(targetURL, targetFilePath, -1, 0, 0, true)
}
return
// SuggestedFilenameAware is the interface for getting suggested filename
type SuggestedFilenameAware interface {
GetSuggestedFilename() string
}

// DownloadFileWithMultipleThreadKeepParts downloads the files with multiple threads
Expand All @@ -326,8 +223,14 @@ func DownloadFileWithMultipleThreadKeepParts(targetURL, targetFilePath string, t
type ContinueDownloader struct {
downloader *HTTPDownloader

roundTripper http.RoundTripper
noProxy bool
roundTripper http.RoundTripper
noProxy bool
insecureSkipVerify bool
}

// GetSuggestedFilename returns the suggested filename
func (c *ContinueDownloader) GetSuggestedFilename() string {
return c.downloader.GetSuggestedFilename()
}

// WithRoundTripper set WithRoundTripper
Expand All @@ -342,14 +245,21 @@ func (c *ContinueDownloader) WithoutProxy(noProxy bool) *ContinueDownloader {
return c
}

// WithInsecureSkipVerify set if skip the insecure verify
func (c *ContinueDownloader) WithInsecureSkipVerify(insecureSkipVerify bool) *ContinueDownloader {
c.insecureSkipVerify = insecureSkipVerify
return c
}

// DownloadWithContinue downloads the files continuously
func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, index, continueAt, end int64, showProgress bool) (err error) {
c.downloader = &HTTPDownloader{
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
NoProxy: c.noProxy,
RoundTripper: c.roundTripper,
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
NoProxy: c.noProxy,
RoundTripper: c.roundTripper,
InsecureSkipVerify: c.insecureSkipVerify,
}
if index >= 0 {
c.downloader.Title = fmt.Sprintf("Downloading part %d", index)
Expand All @@ -371,21 +281,16 @@ func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, inde
return
}

// DetectSize returns the size of target resource
//
// Deprecated, use DetectSizeWithRoundTripper instead
func DetectSize(targetURL, output string, showProgress bool) (int64, bool, error) {
return DetectSizeWithRoundTripper(targetURL, output, showProgress, false, nil)
}

// DetectSizeWithRoundTripper returns the size of target resource
func DetectSizeWithRoundTripper(targetURL, output string, showProgress bool, noProxy bool, roundTripper http.RoundTripper) (total int64, rangeSupport bool, err error) {
func DetectSizeWithRoundTripper(targetURL, output string, showProgress, noProxy, insecureSkipVerify bool,
roundTripper http.RoundTripper) (total int64, rangeSupport bool, err error) {
downloader := HTTPDownloader{
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
RoundTripper: roundTripper,
NoProxy: false, // below HTTP request does not need proxy
TargetFilePath: output,
URL: targetURL,
ShowProgress: showProgress,
RoundTripper: roundTripper,
NoProxy: false, // below HTTP request does not need proxy
InsecureSkipVerify: insecureSkipVerify,
}

var detectOffset int64
Expand All @@ -400,6 +305,8 @@ func DetectSizeWithRoundTripper(targetURL, output string, showProgress bool, noP
contentLen := resp.Header.Get("Content-Length")
if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil {
total += detectOffset
} else {
rangeSupport = false
}
// always return false because we just want to get the header from response
return false
Expand Down
Loading