summaryrefslogtreecommitdiff
path: root/bs4
diff options
context:
space:
mode:
Diffstat (limited to 'bs4')
-rw-r--r--bs4/element.py350
-rw-r--r--bs4/tests/test_tree.py46
2 files changed, 228 insertions, 168 deletions
diff --git a/bs4/element.py b/bs4/element.py
index da9afdf..197722d 100644
--- a/bs4/element.py
+++ b/bs4/element.py
@@ -1203,192 +1203,206 @@ class Tag(PageElement):
_select_debug = False
def select(self, selector, _candidate_generator=None):
"""Perform a CSS selection operation on the current element."""
- tokens = selector.split()
+
+ # Remove whitespace directly after the grouping operator ','
+ # then split into tokens.
+ tokens = re.sub(',[\s]*',',', selector).split()
current_context = [self]
if tokens[-1] in self._selector_combinators:
raise ValueError(
'Final combinator "%s" is missing an argument.' % tokens[-1])
+
if self._select_debug:
print 'Running CSS selector "%s"' % selector
- for index, token in enumerate(tokens):
- if self._select_debug:
- print ' Considering token "%s"' % token
- recursive_candidate_generator = None
- tag_name = None
+
+ for index, token_group in enumerate(tokens):
+ new_context = []
+ new_context_ids = set([])
+
+ # Grouping selectors, ie: p,a
+ grouped_tokens = token_group.split(',')
+ if '' in grouped_tokens:
+ raise ValueError('Invalid group selection syntax: %s' % token_group)
+
if tokens[index-1] in self._selector_combinators:
# This token was consumed by the previous combinator. Skip it.
if self._select_debug:
print ' Token was consumed by the previous combinator.'
continue
- # Each operation corresponds to a checker function, a rule
- # for determining whether a candidate matches the
- # selector. Candidates are generated by the active
- # iterator.
- checker = None
-
- m = self.attribselect_re.match(token)
- if m is not None:
- # Attribute selector
- tag_name, attribute, operator, value = m.groups()
- checker = self._attribute_checker(operator, attribute, value)
-
- elif '#' in token:
- # ID selector
- tag_name, tag_id = token.split('#', 1)
- def id_matches(tag):
- return tag.get('id', None) == tag_id
- checker = id_matches
-
- elif '.' in token:
- # Class selector
- tag_name, klass = token.split('.', 1)
- classes = set(klass.split('.'))
- def classes_match(candidate):
- return classes.issubset(candidate.get('class', []))
- checker = classes_match
-
- elif ':' in token:
- # Pseudo-class
- tag_name, pseudo = token.split(':', 1)
- if tag_name == '':
- raise ValueError(
- "A pseudo-class must be prefixed with a tag name.")
- pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
- found = []
- if pseudo_attributes is not None:
- pseudo_type, pseudo_value = pseudo_attributes.groups()
- if pseudo_type == 'nth-of-type':
- try:
- pseudo_value = int(pseudo_value)
- except:
- raise NotImplementedError(
- 'Only numeric values are currently supported for the nth-of-type pseudo-class.')
- if pseudo_value < 1:
- raise ValueError(
- 'nth-of-type pseudo-class value must be at least 1.')
- class Counter(object):
- def __init__(self, destination):
- self.count = 0
- self.destination = destination
-
- def nth_child_of_type(self, tag):
- self.count += 1
- if self.count == self.destination:
- return True
- if self.count > self.destination:
- # Stop the generator that's sending us
- # these things.
- raise StopIteration()
- return False
- checker = Counter(pseudo_value).nth_child_of_type
- else:
- raise NotImplementedError(
- 'Only the following pseudo-classes are implemented: nth-of-type.')
-
- elif token == '*':
- # Star selector -- matches everything
- pass
- elif token == '>':
- # Run the next token as a CSS selector against the
- # direct children of each tag in the current context.
- recursive_candidate_generator = lambda tag: tag.children
- elif token == '~':
- # Run the next token as a CSS selector against the
- # siblings of each tag in the current context.
- recursive_candidate_generator = lambda tag: tag.next_siblings
- elif token == '+':
- # For each tag in the current context, run the next
- # token as a CSS selector against the tag's next
- # sibling that's a tag.
- def next_tag_sibling(tag):
- yield tag.find_next_sibling(True)
- recursive_candidate_generator = next_tag_sibling
-
- elif self.tag_name_re.match(token):
- # Just a tag name.
- tag_name = token
- else:
- raise ValueError(
- 'Unsupported or invalid CSS selector: "%s"' % token)
- if recursive_candidate_generator:
- # This happens when the selector looks like "> foo".
- #
- # The generator calls select() recursively on every
- # member of the current context, passing in a different
- # candidate generator and a different selector.
- #
- # In the case of "> foo", the candidate generator is
- # one that yields a tag's direct children (">"), and
- # the selector is "foo".
- next_token = tokens[index+1]
- def recursive_select(tag):
- if self._select_debug:
- print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs)
- print '-' * 40
- for i in tag.select(next_token, recursive_candidate_generator):
+ for token in grouped_tokens:
+ if self._select_debug:
+ print ' Considering token "%s"' % token
+ recursive_candidate_generator = None
+ tag_name = None
+
+ # Each operation corresponds to a checker function, a rule
+ # for determining whether a candidate matches the
+ # selector. Candidates are generated by the active
+ # iterator.
+ checker = None
+
+ m = self.attribselect_re.match(token)
+ if m is not None:
+ # Attribute selector
+ tag_name, attribute, operator, value = m.groups()
+ checker = self._attribute_checker(operator, attribute, value)
+
+ elif '#' in token:
+ # ID selector
+ tag_name, tag_id = token.split('#', 1)
+ def id_matches(tag):
+ return tag.get('id', None) == tag_id
+ checker = id_matches
+
+ elif '.' in token:
+ # Class selector
+ tag_name, klass = token.split('.', 1)
+ classes = set(klass.split('.'))
+ def classes_match(candidate):
+ return classes.issubset(candidate.get('class', []))
+ checker = classes_match
+
+ elif ':' in token:
+ # Pseudo-class
+ tag_name, pseudo = token.split(':', 1)
+ if tag_name == '':
+ raise ValueError(
+ "A pseudo-class must be prefixed with a tag name.")
+ pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
+ found = []
+ if pseudo_attributes is not None:
+ pseudo_type, pseudo_value = pseudo_attributes.groups()
+ if pseudo_type == 'nth-of-type':
+ try:
+ pseudo_value = int(pseudo_value)
+ except:
+ raise NotImplementedError(
+ 'Only numeric values are currently supported for the nth-of-type pseudo-class.')
+ if pseudo_value < 1:
+ raise ValueError(
+ 'nth-of-type pseudo-class value must be at least 1.')
+ class Counter(object):
+ def __init__(self, destination):
+ self.count = 0
+ self.destination = destination
+
+ def nth_child_of_type(self, tag):
+ self.count += 1
+ if self.count == self.destination:
+ return True
+ if self.count > self.destination:
+ # Stop the generator that's sending us
+ # these things.
+ raise StopIteration()
+ return False
+ checker = Counter(pseudo_value).nth_child_of_type
+ else:
+ raise NotImplementedError(
+ 'Only the following pseudo-classes are implemented: nth-of-type.')
+
+ elif token == '*':
+ # Star selector -- matches everything
+ pass
+ elif token == '>':
+ # Run the next token as a CSS selector against the
+ # direct children of each tag in the current context.
+ recursive_candidate_generator = lambda tag: tag.children
+ elif token == '~':
+ # Run the next token as a CSS selector against the
+ # siblings of each tag in the current context.
+ recursive_candidate_generator = lambda tag: tag.next_siblings
+ elif token == '+':
+ # For each tag in the current context, run the next
+ # token as a CSS selector against the tag's next
+ # sibling that's a tag.
+ def next_tag_sibling(tag):
+ yield tag.find_next_sibling(True)
+ recursive_candidate_generator = next_tag_sibling
+
+ elif self.tag_name_re.match(token):
+ # Just a tag name.
+ tag_name = token
+ else:
+ raise ValueError(
+ 'Unsupported or invalid CSS selector: "%s"' % token)
+ if recursive_candidate_generator:
+ # This happens when the selector looks like "> foo".
+ #
+ # The generator calls select() recursively on every
+ # member of the current context, passing in a different
+ # candidate generator and a different selector.
+ #
+ # In the case of "> foo", the candidate generator is
+ # one that yields a tag's direct children (">"), and
+ # the selector is "foo".
+ next_token = tokens[index+1]
+ def recursive_select(tag):
+ if self._select_debug:
+ print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs)
+ print '-' * 40
+ for i in tag.select(next_token, recursive_candidate_generator):
+ if self._select_debug:
+ print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs)
+ yield i
if self._select_debug:
- print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs)
- yield i
+ print '-' * 40
+ _use_candidate_generator = recursive_select
+ elif _candidate_generator is None:
+ # By default, a tag's candidates are all of its
+ # children. If tag_name is defined, only yield tags
+ # with that name.
if self._select_debug:
- print '-' * 40
- _use_candidate_generator = recursive_select
- elif _candidate_generator is None:
- # By default, a tag's candidates are all of its
- # children. If tag_name is defined, only yield tags
- # with that name.
- if self._select_debug:
- if tag_name:
- check = "[any]"
+ if tag_name:
+ check = "[any]"
+ else:
+ check = tag_name
+ print ' Default candidate generator, tag name="%s"' % check
+ if self._select_debug:
+ # This is redundant with later code, but it stops
+ # a bunch of bogus tags from cluttering up the
+ # debug log.
+ def default_candidate_generator(tag):
+ for child in tag.descendants:
+ if not isinstance(child, Tag):
+ continue
+ if tag_name and not child.name == tag_name:
+ continue
+ yield child
+ _use_candidate_generator = default_candidate_generator
else:
- check = tag_name
- print ' Default candidate generator, tag name="%s"' % check
- if self._select_debug:
- # This is redundant with later code, but it stops
- # a bunch of bogus tags from cluttering up the
- # debug log.
- def default_candidate_generator(tag):
- for child in tag.descendants:
- if not isinstance(child, Tag):
- continue
- if tag_name and not child.name == tag_name:
- continue
- yield child
- _use_candidate_generator = default_candidate_generator
+ _use_candidate_generator = lambda tag: tag.descendants
else:
- _use_candidate_generator = lambda tag: tag.descendants
- else:
- _use_candidate_generator = _candidate_generator
+ _use_candidate_generator = _candidate_generator
+
+ for tag in current_context:
+ if self._select_debug:
+ print " Running candidate generator on %s %s" % (
+ tag.name, repr(tag.attrs))
+ for candidate in _use_candidate_generator(tag):
+ if not isinstance(candidate, Tag):
+ continue
+ if tag_name and candidate.name != tag_name:
+ continue
+ if checker is not None:
+ try:
+ result = checker(candidate)
+ except StopIteration:
+ # The checker has decided we should no longer
+ # run the generator.
+ break
+ if checker is None or result:
+ if self._select_debug:
+ print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs))
+ if id(candidate) not in new_context_ids:
+ # If a tag matches a selector more than once,
+ # don't include it in the context more than once.
+ new_context.append(candidate)
+ new_context_ids.add(id(candidate))
+ elif self._select_debug:
+ print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs))
- new_context = []
- new_context_ids = set([])
- for tag in current_context:
- if self._select_debug:
- print " Running candidate generator on %s %s" % (
- tag.name, repr(tag.attrs))
- for candidate in _use_candidate_generator(tag):
- if not isinstance(candidate, Tag):
- continue
- if tag_name and candidate.name != tag_name:
- continue
- if checker is not None:
- try:
- result = checker(candidate)
- except StopIteration:
- # The checker has decided we should no longer
- # run the generator.
- break
- if checker is None or result:
- if self._select_debug:
- print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs))
- if id(candidate) not in new_context_ids:
- # If a tag matches a selector more than once,
- # don't include it in the context more than once.
- new_context.append(candidate)
- new_context_ids.add(id(candidate))
- elif self._select_debug:
- print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs))
current_context = new_context
diff --git a/bs4/tests/test_tree.py b/bs4/tests/test_tree.py
index de9543d..8f629d9 100644
--- a/bs4/tests/test_tree.py
+++ b/bs4/tests/test_tree.py
@@ -1554,6 +1554,14 @@ class TestSoupSelector(TreeTest):
<span class="span3"></span>
</span>
</div>
+<x id="xid">
+<z id="zida"/>
+<z id="zidab"/>
+<z id="zidac"/>
+</x>
+<y id="yid">
+<z id="zidb"/>
+</y>
<p lang="en" id="lang-en">English</p>
<p lang="en-gb" id="lang-en-gb">English UK</p>
<p lang="en-us" id="lang-en-us">English US</p>
@@ -1827,3 +1835,41 @@ class TestSoupSelector(TreeTest):
def test_sibling_combinator_wont_select_same_tag_twice(self):
self.assertSelects('p[lang] ~ p', ['lang-en-gb', 'lang-en-us', 'lang-fr'])
+
+ # Test the selector grouping operator (the comma)
+ def test_multiple_select(self):
+ self.assertSelects('x, y',['xid','yid'])
+
+ def test_multiple_select_with_no_space(self):
+ self.assertSelects('x,y',['xid','yid'])
+
+ def test_multiple_select_with_more_space(self):
+ self.assertSelects('x, y',['xid', 'yid'])
+
+ def test_multiple_select_sibling(self):
+ self.assertSelects('x, y ~ p[lang=fr]',['lang-fr'])
+
+ def test_multiple_select(self):
+ self.assertSelects('x, y > z', ['zida', 'zidb', 'zidab', 'zidac'])
+
+ def test_multiple_select_direct_descendant(self):
+ self.assertSelects('div > x, y, z', ['xid', 'yid'])
+
+ def test_multiple_select_indirect_descendant(self):
+ self.assertSelects('div x,y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac'])
+
+ def test_invalid_multiple_select(self):
+ self.assertRaises(ValueError, self.soup.select, ',x, y')
+ self.assertRaises(ValueError, self.soup.select, 'x,,y')
+
+ def test_multiple_select(self):
+ self.assertSelects('p[lang=en], p[lang=en-gb]',['lang-en','lang-en-gb'])
+
+ def test_multiple_select_ids(self):
+ self.assertSelects('x, y > z[id=zida], z[id=zidab], z[id=zidb]', ['zida', 'zidb','zidab'])
+
+ def test_multiple_select_nested(self):
+ self.assertSelects('body > div > x, y > z', ['zida', 'zidb', 'zidab', 'zidac'])
+
+
+