Merged
colly.go: 113 changes (31 additions & 82 deletions)
@@ -24,7 +24,6 @@ import (
     "fmt"
     "hash/fnv"
     "io"
-    "io/ioutil"
     "log"
     "net/http"
     "net/http/cookiejar"
@@ -562,40 +561,28 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
     if err != nil {
         return err
     }
-    if err := c.requestCheck(u, parsedURL, method, requestData, depth, checkRevisit); err != nil {
-        return err
-    }
-
     if hdr == nil {
         hdr = http.Header{}
     }
     if _, ok := hdr["User-Agent"]; !ok {
         hdr.Set("User-Agent", c.UserAgent)
     }
-    rc, ok := requestData.(io.ReadCloser)
-    if !ok && requestData != nil {
-        rc = ioutil.NopCloser(requestData)
+    req, err := http.NewRequest(method, parsedURL.String(), requestData)
+    if err != nil {
+        return err
     }
+    req.Header = hdr
     // The Go HTTP API ignores "Host" in the headers, preferring the client
     // to use the Host field on Request.
-    host := parsedURL.Host
     if hostHeader := hdr.Get("Host"); hostHeader != "" {
-        host = hostHeader
-    }
-    req := &http.Request{
-        Method: method,
-        URL: parsedURL,
-        Proto: "HTTP/1.1",
-        ProtoMajor: 1,
-        ProtoMinor: 1,
-        Header: hdr,
-        Body: rc,
-        Host: host,
+        req.Host = hostHeader
     }
     // note: once 1.13 is minimum supported Go version,
     // replace this with http.NewRequestWithContext
     req = req.WithContext(c.Context)
-    setRequestBody(req, requestData)
+    if err := c.requestCheck(u, parsedURL, method, req.GetBody, depth, checkRevisit); err != nil {
Member commented:

Probably I'm missing something, but I can't find where the req.GetBody function is implemented. Pls help. =]

@WGH- (Collaborator, Author) replied on Mar 8, 2022:

GetBody is set by NewRequest: https://cs.opensource.google/go/go/+/refs/tags/go1.17.8:src/net/http/request.go;l=889-930;drc=refs%2Ftags%2Fgo1.17.8

This code fragment used to be copy-pasted in Colly (setRequestBody) and was removed in my previous commit.

+        return err
+    }
     u = parsedURL.String()
     c.wg.Add(1)
     if c.Async {
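
As the review thread above notes, req.GetBody is populated by http.NewRequest itself, so colly no longer needs to set it by hand. A minimal standalone sketch of that behaviour, with made-up URLs and payloads:

package main

import (
    "fmt"
    "io"
    "net/http"
    "strings"
)

func main() {
    // For *strings.Reader, *bytes.Reader and *bytes.Buffer bodies, NewRequest
    // fills in ContentLength and a GetBody callback that returns a fresh copy
    // of the body on every call.
    req, err := http.NewRequest("POST", "http://example.com/login", strings.NewReader("name=colly"))
    if err != nil {
        panic(err)
    }
    fmt.Println(req.ContentLength, req.GetBody != nil) // 10 true

    // Each call yields an independent reader, so the body can be read once
    // for the revisit hash and again when the request is actually sent.
    for i := 0; i < 2; i++ {
        rc, _ := req.GetBody()
        b, _ := io.ReadAll(rc)
        rc.Close()
        fmt.Printf("copy %d: %s\n", i, b)
    }

    // A generic io.Reader body leaves GetBody nil, which is why the new
    // requestCheck has to tolerate getBody == nil.
    req2, _ := http.NewRequest("POST", "http://example.com/upload", io.LimitReader(strings.NewReader("payload"), 7))
    fmt.Println(req2.GetBody == nil) // true
}

This is also why the setRequestBody helper deleted below was redundant: NewRequest already special-cases the same three body types, and for any other reader GetBody simply stays nil.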
@@ -605,38 +592,6 @@
     return c.fetch(u, method, depth, requestData, ctx, hdr, req)
 }
 
-func setRequestBody(req *http.Request, body io.Reader) {
-    if body != nil {
-        switch v := body.(type) {
-        case *bytes.Buffer:
-            req.ContentLength = int64(v.Len())
-            buf := v.Bytes()
-            req.GetBody = func() (io.ReadCloser, error) {
-                r := bytes.NewReader(buf)
-                return ioutil.NopCloser(r), nil
-            }
-        case *bytes.Reader:
-            req.ContentLength = int64(v.Len())
-            snapshot := *v
-            req.GetBody = func() (io.ReadCloser, error) {
-                r := snapshot
-                return ioutil.NopCloser(&r), nil
-            }
-        case *strings.Reader:
-            req.ContentLength = int64(v.Len())
-            snapshot := *v
-            req.GetBody = func() (io.ReadCloser, error) {
-                r := snapshot
-                return ioutil.NopCloser(&r), nil
-            }
-        }
-        if req.GetBody != nil && req.ContentLength == 0 {
-            req.Body = http.NoBody
-            req.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil }
-        }
-    }
-}
-
 func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ctx *Context, hdr http.Header, req *http.Request) error {
     defer c.wg.Done()
     if ctx == nil {
@@ -715,7 +670,7 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
     return err
 }
 
-func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, requestData io.Reader, depth int, checkRevisit bool) error {
+func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, getBody func() (io.ReadCloser, error), depth int, checkRevisit bool) error {
     if u == "" {
         return ErrMissingURL
     }
@@ -731,19 +686,23 @@ func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, re
         }
     }
     if checkRevisit && !c.AllowURLRevisit {
-        h := fnv.New64a()
-        h.Write([]byte(u))
-
-        var uHash uint64
-        if method == "GET" {
-            uHash = h.Sum64()
-        } else if requestData != nil {
-            h.Write(streamToByte(requestData))
-            uHash = h.Sum64()
-        } else {
+        // TODO weird behaviour, it allows CheckHead to work correctly,
+        // but it should probably better be solved with
+        // "check-but-not-save" flag or something
+        if method != "GET" && getBody == nil {
             return nil
         }
 
+        var body io.ReadCloser
+        if getBody != nil {
+            var err error
+            body, err = getBody()
+            if err != nil {
+                return err
+            }
+            defer body.Close()
+        }
+        uHash := requestHash(u, body)
         visited, err := c.store.IsVisited(uHash)
         if err != nil {
             return err
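
Condensed into a standalone sketch, the new revisit check asks GetBody for a fresh copy of the body and folds URL and body into a single FNV-1a hash. hashRequest below is an illustrative stand-in for the unexported requestCheck/requestHash pair, not colly's actual API:

package main

import (
    "fmt"
    "hash/fnv"
    "io"
    "net/http"
    "strings"
)

// hashRequest mirrors the shape of the new requestCheck logic: obtain a fresh
// body via GetBody (if there is one) and hash URL plus body with FNV-1a.
func hashRequest(req *http.Request) (uint64, error) {
    var body io.ReadCloser
    if req.GetBody != nil {
        var err error
        body, err = req.GetBody()
        if err != nil {
            return 0, err
        }
        defer body.Close()
    }
    h := fnv.New64a()
    h.Write([]byte(req.URL.String()))
    if body != nil {
        if _, err := io.Copy(h, body); err != nil {
            return 0, err
        }
    }
    return h.Sum64(), nil
}

func main() {
    get, _ := http.NewRequest("GET", "http://example.com/", nil)
    post, _ := http.NewRequest("POST", "http://example.com/", strings.NewReader("q=colly"))

    h1, _ := hashRequest(get)
    h2, _ := hashRequest(post)
    fmt.Println(h1 != h2) // true: same URL, different body, different hash
}

Because GetBody returns a new ReadCloser on every call, hashing the body here does not consume the body that is later sent over the wire.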
@@ -1343,14 +1302,8 @@ func (c *Collector) parseSettingsFromEnv() {
 }
 
 func (c *Collector) checkHasVisited(URL string, requestData map[string]string) (bool, error) {
-    h := fnv.New64a()
-    h.Write([]byte(URL))
-
-    if requestData != nil {
-        h.Write(streamToByte(createFormReader(requestData)))
-    }
-
-    return c.store.IsVisited(h.Sum64())
+    hash := requestHash(URL, createFormReader(requestData))
+    return c.store.IsVisited(hash)
 }
 
 // SanitizeFileName replaces dangerous characters in a string
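
checkHasVisited now routes through the same requestHash helper. A hedged sketch of the equivalent computation, assuming the unexported createFormReader simply URL-encodes the map into a reader (that assumption is modelled here as formReader):

package main

import (
    "fmt"
    "hash/fnv"
    "io"
    "net/url"
    "strings"
)

// formReader approximates colly's createFormReader for illustration only:
// URL-encode the map and wrap the result in a reader.
func formReader(data map[string]string) io.Reader {
    form := url.Values{}
    for k, v := range data {
        form.Add(k, v)
    }
    return strings.NewReader(form.Encode())
}

func main() {
    data := map[string]string{"name": "colly"}

    // Hash URL plus encoded form body, the same shape as requestHash.
    h := fnv.New64a()
    h.Write([]byte("http://example.com/login"))
    io.Copy(h, formReader(data))

    // Roughly the key checkHasVisited would look up in the storage backend.
    fmt.Printf("visit hash: %d\n", h.Sum64())
}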
@@ -1462,15 +1415,11 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {
     return false
 }
 
-func streamToByte(r io.Reader) []byte {
-    buf := new(bytes.Buffer)
-    buf.ReadFrom(r)
-
-    if strReader, k := r.(*strings.Reader); k {
-        strReader.Seek(0, 0)
-    } else if bReader, kb := r.(*bytes.Reader); kb {
-        bReader.Seek(0, 0)
+func requestHash(url string, body io.Reader) uint64 {
+    h := fnv.New64a()
+    h.Write([]byte(url))
+    if body != nil {
+        io.Copy(h, body)
     }
-
-    return buf.Bytes()
+    return h.Sum64()
 }
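
For reference, a small self-contained check of the new helper's properties: the hash is deterministic for identical input, a nil body hashes the URL alone (the GET path in requestCheck), and io.Copy streams the body into the hash instead of buffering it the way the removed streamToByte did. requestHash is re-declared here only so the snippet runs on its own:

package main

import (
    "fmt"
    "hash/fnv"
    "io"
    "strings"
)

// requestHash reproduced for illustration; io.Copy streams the body into the
// hash rather than reading it all into memory first.
func requestHash(url string, body io.Reader) uint64 {
    h := fnv.New64a()
    h.Write([]byte(url))
    if body != nil {
        io.Copy(h, body)
    }
    return h.Sum64()
}

func main() {
    u := "http://example.com/login"

    // Deterministic: the same URL and body always produce the same hash.
    a := requestHash(u, strings.NewReader("name=colly"))
    b := requestHash(u, strings.NewReader("name=colly"))
    fmt.Println(a == b) // true

    // A nil body hashes the URL alone, which is what the GET path relies on.
    h := fnv.New64a()
    h.Write([]byte(u))
    fmt.Println(requestHash(u, nil) == h.Sum64()) // true
}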