diff --git a/misc/goutil.go b/misc/goutil.go index 29b71d2..f7d68fc 100644 --- a/misc/goutil.go +++ b/misc/goutil.go @@ -2,6 +2,7 @@ package misc import ( "bufio" + "bytes" "errors" "flag" "fmt" @@ -313,30 +314,33 @@ func Notify(head string, body string) error { return nil } -func DownloadAll(url string) ([]byte, error) { - resp, err := http.Get(url) - if err != nil { - return nil, errors.New("Download: " + err.Error()) - } - defer resp.Body.Close() - data, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, errors.New("Download: " + err.Error()) - } - return data, err +func DownloadAll(url string, limit int64) ([]byte, error) { + var buf bytes.Buffer + err := download(url, &buf, limit) + return buf.Bytes(), err } -func DownloadToFile(url string, dest string) error { - resp, err := http.Get(url) - if err != nil { - return errors.New("DownloadToFile: " + err.Error()) - } - defer resp.Body.Close() +func DownloadToFile(url string, dest string, limit int64) error { file, err := os.Create(dest) if err != nil { return errors.New("DownloadToFile: " + err.Error()) } defer file.Close() - _, err = io.Copy(file, resp.Body) + return download(url, file, limit) +} + +func download(url string, dest io.Writer, limit int64) error { + resp, err := http.Get(url) + if err != nil { + return errors.New("download: " + err.Error()) + } + defer resp.Body.Close() + var reader io.Reader + if limit > 0 { + reader = io.LimitReader(resp.Body, limit) + } else { + reader = resp.Body + } + _, err = io.Copy(dest, reader) return err }