summaryrefslogtreecommitdiff
path: root/bs4
diff options
context:
space:
mode:
Diffstat (limited to 'bs4')
-rw-r--r--bs4/element.py364
-rw-r--r--bs4/tests/test_tree.py27
2 files changed, 201 insertions, 190 deletions
diff --git a/bs4/element.py b/bs4/element.py
index e84d70f..9ce4535 100644
--- a/bs4/element.py
+++ b/bs4/element.py
@@ -1288,9 +1288,23 @@ class Tag(PageElement):
def select(self, selector, _candidate_generator=None, limit=None):
"""Perform a CSS selection operation on the current element."""
- # Remove whitespace directly after the grouping operator ','
- # then split into tokens.
- tokens = re.sub(',[\s]*',',', selector).split()
+ # Handle grouping selectors if ',' exists, ie: p,a
+ if ',' in selector:
+ context = []
+ for partial_selector in selector.split(','):
+ partial_selector = partial_selector.strip()
+ if partial_selector == '':
+ raise ValueError('Invalid group selection syntax: %s' % selector)
+ candidates = self.select(partial_selector, limit=limit)
+ for candidate in candidates:
+ if candidate not in context:
+ context.append(candidate)
+
+ if limit and len(context) >= limit:
+ break
+ return context
+
+ tokens = selector.split()
current_context = [self]
if tokens[-1] in self._selector_combinators:
@@ -1300,198 +1314,192 @@ class Tag(PageElement):
if self._select_debug:
print 'Running CSS selector "%s"' % selector
- for index, token_group in enumerate(tokens):
+ for index, token in enumerate(tokens):
new_context = []
new_context_ids = set([])
- # Grouping selectors, ie: p,a
- grouped_tokens = token_group.split(',')
- if '' in grouped_tokens:
- raise ValueError('Invalid group selection syntax: %s' % token_group)
-
if tokens[index-1] in self._selector_combinators:
# This token was consumed by the previous combinator. Skip it.
if self._select_debug:
print ' Token was consumed by the previous combinator.'
continue
- for token in grouped_tokens:
- if self._select_debug:
- print ' Considering token "%s"' % token
- recursive_candidate_generator = None
- tag_name = None
-
- # Each operation corresponds to a checker function, a rule
- # for determining whether a candidate matches the
- # selector. Candidates are generated by the active
- # iterator.
- checker = None
-
- m = self.attribselect_re.match(token)
- if m is not None:
- # Attribute selector
- tag_name, attribute, operator, value = m.groups()
- checker = self._attribute_checker(operator, attribute, value)
-
- elif '#' in token:
- # ID selector
- tag_name, tag_id = token.split('#', 1)
- def id_matches(tag):
- return tag.get('id', None) == tag_id
- checker = id_matches
-
- elif '.' in token:
- # Class selector
- tag_name, klass = token.split('.', 1)
- classes = set(klass.split('.'))
- def classes_match(candidate):
- return classes.issubset(candidate.get('class', []))
- checker = classes_match
-
- elif ':' in token:
- # Pseudo-class
- tag_name, pseudo = token.split(':', 1)
- if tag_name == '':
- raise ValueError(
- "A pseudo-class must be prefixed with a tag name.")
- pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
- found = []
- if pseudo_attributes is None:
- pseudo_type = pseudo
- pseudo_value = None
- else:
- pseudo_type, pseudo_value = pseudo_attributes.groups()
- if pseudo_type == 'nth-of-type':
- try:
- pseudo_value = int(pseudo_value)
- except:
- raise NotImplementedError(
- 'Only numeric values are currently supported for the nth-of-type pseudo-class.')
- if pseudo_value < 1:
- raise ValueError(
- 'nth-of-type pseudo-class value must be at least 1.')
- class Counter(object):
- def __init__(self, destination):
- self.count = 0
- self.destination = destination
-
- def nth_child_of_type(self, tag):
- self.count += 1
- if self.count == self.destination:
- return True
- if self.count > self.destination:
- # Stop the generator that's sending us
- # these things.
- raise StopIteration()
- return False
- checker = Counter(pseudo_value).nth_child_of_type
- else:
+ if self._select_debug:
+ print ' Considering token "%s"' % token
+ recursive_candidate_generator = None
+ tag_name = None
+
+ # Each operation corresponds to a checker function, a rule
+ # for determining whether a candidate matches the
+ # selector. Candidates are generated by the active
+ # iterator.
+ checker = None
+
+ m = self.attribselect_re.match(token)
+ if m is not None:
+ # Attribute selector
+ tag_name, attribute, operator, value = m.groups()
+ checker = self._attribute_checker(operator, attribute, value)
+
+ elif '#' in token:
+ # ID selector
+ tag_name, tag_id = token.split('#', 1)
+ def id_matches(tag):
+ return tag.get('id', None) == tag_id
+ checker = id_matches
+
+ elif '.' in token:
+ # Class selector
+ tag_name, klass = token.split('.', 1)
+ classes = set(klass.split('.'))
+ def classes_match(candidate):
+ return classes.issubset(candidate.get('class', []))
+ checker = classes_match
+
+ elif ':' in token:
+ # Pseudo-class
+ tag_name, pseudo = token.split(':', 1)
+ if tag_name == '':
+ raise ValueError(
+ "A pseudo-class must be prefixed with a tag name.")
+ pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
+ found = []
+ if pseudo_attributes is None:
+ pseudo_type = pseudo
+ pseudo_value = None
+ else:
+ pseudo_type, pseudo_value = pseudo_attributes.groups()
+ if pseudo_type == 'nth-of-type':
+ try:
+ pseudo_value = int(pseudo_value)
+ except:
raise NotImplementedError(
- 'Only the following pseudo-classes are implemented: nth-of-type.')
-
- elif token == '*':
- # Star selector -- matches everything
- pass
- elif token == '>':
- # Run the next token as a CSS selector against the
- # direct children of each tag in the current context.
- recursive_candidate_generator = lambda tag: tag.children
- elif token == '~':
- # Run the next token as a CSS selector against the
- # siblings of each tag in the current context.
- recursive_candidate_generator = lambda tag: tag.next_siblings
- elif token == '+':
- # For each tag in the current context, run the next
- # token as a CSS selector against the tag's next
- # sibling that's a tag.
- def next_tag_sibling(tag):
- yield tag.find_next_sibling(True)
- recursive_candidate_generator = next_tag_sibling
-
- elif self.tag_name_re.match(token):
- # Just a tag name.
- tag_name = token
+ 'Only numeric values are currently supported for the nth-of-type pseudo-class.')
+ if pseudo_value < 1:
+ raise ValueError(
+ 'nth-of-type pseudo-class value must be at least 1.')
+ class Counter(object):
+ def __init__(self, destination):
+ self.count = 0
+ self.destination = destination
+
+ def nth_child_of_type(self, tag):
+ self.count += 1
+ if self.count == self.destination:
+ return True
+ if self.count > self.destination:
+ # Stop the generator that's sending us
+ # these things.
+ raise StopIteration()
+ return False
+ checker = Counter(pseudo_value).nth_child_of_type
else:
- raise ValueError(
- 'Unsupported or invalid CSS selector: "%s"' % token)
- if recursive_candidate_generator:
- # This happens when the selector looks like "> foo".
- #
- # The generator calls select() recursively on every
- # member of the current context, passing in a different
- # candidate generator and a different selector.
- #
- # In the case of "> foo", the candidate generator is
- # one that yields a tag's direct children (">"), and
- # the selector is "foo".
- next_token = tokens[index+1]
- def recursive_select(tag):
- if self._select_debug:
- print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs)
- print '-' * 40
- for i in tag.select(next_token, recursive_candidate_generator):
- if self._select_debug:
- print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs)
- yield i
- if self._select_debug:
- print '-' * 40
- _use_candidate_generator = recursive_select
- elif _candidate_generator is None:
- # By default, a tag's candidates are all of its
- # children. If tag_name is defined, only yield tags
- # with that name.
+ raise NotImplementedError(
+ 'Only the following pseudo-classes are implemented: nth-of-type.')
+
+ elif token == '*':
+ # Star selector -- matches everything
+ pass
+ elif token == '>':
+ # Run the next token as a CSS selector against the
+ # direct children of each tag in the current context.
+ recursive_candidate_generator = lambda tag: tag.children
+ elif token == '~':
+ # Run the next token as a CSS selector against the
+ # siblings of each tag in the current context.
+ recursive_candidate_generator = lambda tag: tag.next_siblings
+ elif token == '+':
+ # For each tag in the current context, run the next
+ # token as a CSS selector against the tag's next
+ # sibling that's a tag.
+ def next_tag_sibling(tag):
+ yield tag.find_next_sibling(True)
+ recursive_candidate_generator = next_tag_sibling
+
+ elif self.tag_name_re.match(token):
+ # Just a tag name.
+ tag_name = token
+ else:
+ raise ValueError(
+ 'Unsupported or invalid CSS selector: "%s"' % token)
+ if recursive_candidate_generator:
+ # This happens when the selector looks like "> foo".
+ #
+ # The generator calls select() recursively on every
+ # member of the current context, passing in a different
+ # candidate generator and a different selector.
+ #
+ # In the case of "> foo", the candidate generator is
+ # one that yields a tag's direct children (">"), and
+ # the selector is "foo".
+ next_token = tokens[index+1]
+ def recursive_select(tag):
if self._select_debug:
- if tag_name:
- check = "[any]"
- else:
- check = tag_name
- print ' Default candidate generator, tag name="%s"' % check
+ print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs)
+ print '-' * 40
+ for i in tag.select(next_token, recursive_candidate_generator):
+ if self._select_debug:
+ print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs)
+ yield i
if self._select_debug:
- # This is redundant with later code, but it stops
- # a bunch of bogus tags from cluttering up the
- # debug log.
- def default_candidate_generator(tag):
- for child in tag.descendants:
- if not isinstance(child, Tag):
- continue
- if tag_name and not child.name == tag_name:
- continue
- yield child
- _use_candidate_generator = default_candidate_generator
+ print '-' * 40
+ _use_candidate_generator = recursive_select
+ elif _candidate_generator is None:
+ # By default, a tag's candidates are all of its
+ # children. If tag_name is defined, only yield tags
+ # with that name.
+ if self._select_debug:
+ if tag_name:
+ check = "[any]"
else:
- _use_candidate_generator = lambda tag: tag.descendants
+ check = tag_name
+ print ' Default candidate generator, tag name="%s"' % check
+ if self._select_debug:
+ # This is redundant with later code, but it stops
+ # a bunch of bogus tags from cluttering up the
+ # debug log.
+ def default_candidate_generator(tag):
+ for child in tag.descendants:
+ if not isinstance(child, Tag):
+ continue
+ if tag_name and not child.name == tag_name:
+ continue
+ yield child
+ _use_candidate_generator = default_candidate_generator
else:
- _use_candidate_generator = _candidate_generator
+ _use_candidate_generator = lambda tag: tag.descendants
+ else:
+ _use_candidate_generator = _candidate_generator
- count = 0
- for tag in current_context:
- if self._select_debug:
- print " Running candidate generator on %s %s" % (
- tag.name, repr(tag.attrs))
- for candidate in _use_candidate_generator(tag):
- if not isinstance(candidate, Tag):
- continue
- if tag_name and candidate.name != tag_name:
- continue
- if checker is not None:
- try:
- result = checker(candidate)
- except StopIteration:
- # The checker has decided we should no longer
- # run the generator.
+ count = 0
+ for tag in current_context:
+ if self._select_debug:
+ print " Running candidate generator on %s %s" % (
+ tag.name, repr(tag.attrs))
+ for candidate in _use_candidate_generator(tag):
+ if not isinstance(candidate, Tag):
+ continue
+ if tag_name and candidate.name != tag_name:
+ continue
+ if checker is not None:
+ try:
+ result = checker(candidate)
+ except StopIteration:
+ # The checker has decided we should no longer
+ # run the generator.
+ break
+ if checker is None or result:
+ if self._select_debug:
+ print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs))
+ if id(candidate) not in new_context_ids:
+ # If a tag matches a selector more than once,
+ # don't include it in the context more than once.
+ new_context.append(candidate)
+ new_context_ids.add(id(candidate))
+ if limit and len(new_context) >= limit:
break
- if checker is None or result:
- if self._select_debug:
- print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs))
- if id(candidate) not in new_context_ids:
- # If a tag matches a selector more than once,
- # don't include it in the context more than once.
- new_context.append(candidate)
- new_context_ids.add(id(candidate))
- if limit and len(new_context) >= limit:
- break
- elif self._select_debug:
- print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs))
+ elif self._select_debug:
+ print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs))
current_context = new_context
diff --git a/bs4/tests/test_tree.py b/bs4/tests/test_tree.py
index a5cf9e9..6b2a123 100644
--- a/bs4/tests/test_tree.py
+++ b/bs4/tests/test_tree.py
@@ -1971,22 +1971,25 @@ class TestSoupSelector(TreeTest):
# Test the selector grouping operator (the comma)
def test_multiple_select(self):
- self.assertSelects('x, y',['xid','yid'])
+ self.assertSelects('x, y', ['xid', 'yid'])
def test_multiple_select_with_no_space(self):
- self.assertSelects('x,y',['xid','yid'])
+ self.assertSelects('x,y', ['xid', 'yid'])
def test_multiple_select_with_more_space(self):
- self.assertSelects('x, y',['xid', 'yid'])
+ self.assertSelects('x, y', ['xid', 'yid'])
+
+ def test_multiple_select_duplicated(self):
+ self.assertSelects('x, x', ['xid'])
def test_multiple_select_sibling(self):
- self.assertSelects('x, y ~ p[lang=fr]',['lang-fr'])
+ self.assertSelects('x, y ~ p[lang=fr]', ['xid', 'lang-fr'])
- def test_multiple_select(self):
- self.assertSelects('x, y > z', ['zida', 'zidb', 'zidab', 'zidac'])
+ def test_multiple_select_tag_and_direct_descendant(self):
+ self.assertSelects('x, y > z', ['xid', 'zidb'])
- def test_multiple_select_direct_descendant(self):
- self.assertSelects('div > x, y, z', ['xid', 'yid'])
+ def test_multiple_select_direct_descendant_and_tags(self):
+ self.assertSelects('div > x, y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac'])
def test_multiple_select_indirect_descendant(self):
self.assertSelects('div x,y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac'])
@@ -1995,14 +1998,14 @@ class TestSoupSelector(TreeTest):
self.assertRaises(ValueError, self.soup.select, ',x, y')
self.assertRaises(ValueError, self.soup.select, 'x,,y')
- def test_multiple_select(self):
- self.assertSelects('p[lang=en], p[lang=en-gb]',['lang-en','lang-en-gb'])
+ def test_multiple_select_attrs(self):
+ self.assertSelects('p[lang=en], p[lang=en-gb]', ['lang-en', 'lang-en-gb'])
def test_multiple_select_ids(self):
- self.assertSelects('x, y > z[id=zida], z[id=zidab], z[id=zidb]', ['zida', 'zidb','zidab'])
+ self.assertSelects('x, y > z[id=zida], z[id=zidab], z[id=zidb]', ['xid', 'zidb', 'zidab'])
def test_multiple_select_nested(self):
- self.assertSelects('body > div > x, y > z', ['zida', 'zidb', 'zidab', 'zidac'])
+ self.assertSelects('body > div > x, y > z', ['xid', 'zidb'])