diff --git a/scrapy_autounit/middleware.py b/scrapy_autounit/middleware.py
index 551b727..ef1f636 100644
--- a/scrapy_autounit/middleware.py
+++ b/scrapy_autounit/middleware.py
@@ -6,7 +6,6 @@
 
 from scrapy.exceptions import NotConfigured
 from scrapy.commands.genspider import sanitize_module_name
-from scrapy.spiders import CrawlSpider
 
 from .utils import (
     add_sample,
@@ -19,6 +18,7 @@
     create_dir,
     parse_callback_result,
     clear_fixtures,
+    get_filter_attrs,
 )
 
 logger = logging.getLogger(__name__)
@@ -73,15 +73,12 @@ def from_crawler(cls, crawler):
         return cls(crawler)
 
     def process_spider_input(self, response, spider):
-        filter_args = {'crawler', 'settings', 'start_urls'}
-        if isinstance(spider, CrawlSpider):
-            filter_args |= {'rules', '_rules'}
         response.meta['_autounit'] = pickle.dumps({
             'request': parse_request(response.request, spider),
             'response': response_to_dict(response),
             'spider_args': {
                 k: v for k, v in spider.__dict__.items()
-                if k not in filter_args
+                if k not in get_filter_attrs(spider)
             },
             'middlewares': get_middlewares(spider),
         })
@@ -98,7 +95,7 @@ def process_spider_output(self, response, result, spider):
         callback_name = request['callback']
         spider_attr_out = {
             k: v for k, v in spider.__dict__.items()
-            if k not in ('crawler', 'settings', 'start_urls')
+            if k not in get_filter_attrs(spider)
         }
 
         data = {
diff --git a/scrapy_autounit/utils.py b/scrapy_autounit/utils.py
index 24f3077..be12882 100644
--- a/scrapy_autounit/utils.py
+++ b/scrapy_autounit/utils.py
@@ -10,6 +10,7 @@
 import six
 from scrapy import signals
 from scrapy.crawler import Crawler
+from scrapy.spiders import CrawlSpider
 from scrapy.exceptions import NotConfigured
 from scrapy.http import Request, Response
 from scrapy.item import Item
@@ -100,6 +101,13 @@ def get_or_create_test_dir(base_path, spider_name, callback_name, extra=None):
     return test_dir, test_name
 
 
+def get_filter_attrs(spider):
+    attrs = {'crawler', 'settings', 'start_urls'}
+    if isinstance(spider, CrawlSpider):
+        attrs |= {'rules', '_rules'}
+    return attrs
+
+
 def add_sample(index, test_dir, test_name, data):
     encoding = data['response']['encoding']
     filename = 'fixture%s.bin' % str(index)
@@ -182,6 +190,10 @@ def parse_request(request, spider):
     _request = request_to_dict(request, spider=spider)
     if not _request['callback']:
         _request['callback'] = 'parse'
+    elif isinstance(spider, CrawlSpider):
+        rule = request.meta.get('rule')
+        if rule is not None:
+            _request['callback'] = spider.rules[rule].callback
 
     clean_headers(_request['headers'], spider.settings)
 
@@ -375,7 +387,7 @@ def test(self):
         )
         result_attr_in = {
             k: v for k, v in spider.__dict__.items()
-            if k not in ('crawler', 'settings', 'start_urls')
+            if k not in get_filter_attrs(spider)
         }
         self.assertEqual(spider_args_in, result_attr_in,
                          'Input arguments not equal!')
@@ -427,7 +439,7 @@ def test(self):
         # Spider attributes get updated after the yield
         result_attr_out = {
             k: v for k, v in spider.__dict__.items()
-            if k not in ('crawler', 'settings', 'start_urls')
+            if k not in get_filter_attrs(spider)
         }
         self.assertEqual(data['spider_args_out'], result_attr_out,
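
Note: the refactor above centralizes attribute filtering, which was previously
duplicated across middleware.py and the generated test template, into the new
get_filter_attrs helper. Below is a minimal sketch of the helper's behavior;
the spider classes are hypothetical, for illustration only, while
get_filter_attrs is the function added to scrapy_autounit/utils.py above.

    from scrapy.spiders import Spider, CrawlSpider
    from scrapy_autounit.utils import get_filter_attrs

    class PlainSpider(Spider):
        name = 'plain'  # hypothetical spider, illustration only

    class RuledSpider(CrawlSpider):
        name = 'ruled'  # hypothetical CrawlSpider subclass

    # Regular spiders filter only the attributes injected by Scrapy itself.
    assert get_filter_attrs(PlainSpider()) == {
        'crawler', 'settings', 'start_urls'}

    # CrawlSpider subclasses additionally drop 'rules'/'_rules', so the
    # compiled Rule objects are not recorded or compared as spider arguments.
    assert get_filter_attrs(RuledSpider()) == {
        'crawler', 'settings', 'start_urls', 'rules', '_rules'}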