From b02f0ab72a2a6c808d644a9683b081683db408d3 Mon Sep 17 00:00:00 2001 From: WGH Date: Fri, 9 Oct 2020 21:06:52 +0300 Subject: [PATCH] Add context.Context support This allows one to pass context.Context to colly.Collector, which in turn will be passed to stdlib's http.Request. This enables colly scraping to be cancelled cleanly at any time. --- colly.go | 19 ++++++++++++++- colly_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/colly.go b/colly.go index b5b9f4dc0..d8d6e1ee9 100644 --- a/colly.go +++ b/colly.go @@ -109,7 +109,11 @@ type Collector struct { CheckHead bool // TraceHTTP enables capturing and reporting request performance for crawler tuning. // When set to true, the Response.Trace will be filled in with an HTTPTrace object. - TraceHTTP bool + TraceHTTP bool + // Context is set on the page requests the Collector builds while scraping; + // cancelling it aborts in-flight requests cleanly. It must not be nil. + Context context.Context + store storage.Storage debugger debug.Debugger robotsMap map[string]*robotstxt.RobotsData @@ -357,6 +361,14 @@ func TraceHTTP() CollectorOption { } } +// StdlibContext sets the context.Context attached to the Collector's HTTP +// requests, allowing clean cancellation of scraping. ctx must not be nil. +func StdlibContext(ctx context.Context) CollectorOption { + return func(c *Collector) { + c.Context = ctx + } +} + // ID sets the unique identifier of the Collector.
func ID(id uint32) CollectorOption { return func(c *Collector) { @@ -412,6 +424,7 @@ func (c *Collector) Init() { c.IgnoreRobotsTxt = true c.ID = atomic.AddUint32(&collectorCounter, 1) c.TraceHTTP = false + c.Context = context.Background() // default; scrape passes it to Request.WithContext, which panics on nil } // Appengine will replace the Collector's backend http.Client @@ -567,6 +580,9 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c Body: rc, Host: host, } + // note: once Go 1.13 is the minimum supported version, build the request + // with http.NewRequestWithContext instead of calling WithContext here. + req = req.WithContext(c.Context) setRequestBody(req, requestData) u = parsedURL.String() c.wg.Add(1) @@ -1239,6 +1255,7 @@ func (c *Collector) Clone() *Collector { ParseHTTPErrorResponse: c.ParseHTTPErrorResponse, UserAgent: c.UserAgent, TraceHTTP: c.TraceHTTP, + Context: c.Context, store: c.store, backend: c.backend, debugger: c.debugger, diff --git a/colly_test.go b/colly_test.go index e8f7ef8da..bb3f3305f 100644 --- a/colly_test.go +++ b/colly_test.go @@ -17,6 +17,7 @@ package colly import ( "bufio" "bytes" + "context" "fmt" "net/http" "net/http/httptest" @@ -26,6 +27,7 @@ import ( "regexp" "strings" "testing" + "time" "github.com/PuerkitoBio/goquery" @@ -166,6 +168,31 @@ func newTestServer() *httptest.Server { } }) + mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + i := 0 // ticks written so far; the handler returns after 10 of them (about 1s) + + for { + select { + case <-r.Context().Done(): // the client gave up (e.g. its request context expired) + return + case t := <-ticker.C: + fmt.Fprintf(w, "%s\n", t) + if flusher, ok := w.(http.Flusher); ok { // push each tick to the client immediately + flusher.Flush() + } + i++ + if i == 10 { + return + } + } + } + }) + return httptest.NewServer(mux) } @@ -1128,6 +1155,43 @@ func TestCollectorDepth(t *testing.T) { } } +func TestCollectorContext(t *testing.T) { + // "/slow" takes 1 second to return the response. + // If the context aborts the transfer after 0.5 seconds, as it should, + // OnError will be called and the test passes.
Otherwise, the test fails. + + ts := newTestServer() + defer ts.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) // expires about halfway through /slow + defer cancel() + + c := NewCollector(StdlibContext(ctx)) + + onErrorCalled := false + + c.OnResponse(func(resp *Response) { + t.Error("OnResponse was called, expected OnError") + }) + + c.OnError(func(resp *Response, err error) { + onErrorCalled = true + if err != context.DeadlineExceeded { // NOTE(review): exact match; a wrapped error (e.g. *url.Error) would fail here, and errors.Is needs Go 1.13 + t.Errorf("OnError got err=%#v, expected context.DeadlineExceeded", err) + } + }) + + err := c.Visit(ts.URL + "/slow") + if err != context.DeadlineExceeded { + t.Errorf("Visit return err=%#v, expected context.DeadlineExceeded", err) + } + + if !onErrorCalled { + t.Error("OnError was not called") + } + +} + func BenchmarkOnHTML(b *testing.B) { ts := newTestServer() defer ts.Close()