summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- NEWS.txt 36
-rw-r--r-- bs4/element.py 320
-rw-r--r-- bs4/tests/test_tree.py 20
3 files changed, 243 insertions, 133 deletions
diff --git a/NEWS.txt b/NEWS.txt
index 7abc700..c65058c 100644
--- a/NEWS.txt
+++ b/NEWS.txt
@@ -1,12 +1,16 @@
= 4.2.0 (Unreleased) =
-* In an HTML document, the contents of a <script> or <style> tag will
- no longer undergo entity substitution by default. XML documents work
- the same way they did before. [bug=1085953]
+* The Tag.select() method now supports a much wider variety of CSS
+ selectors.
-* Methods like get_text() and properties like .strings now only give
- you strings that are visible in the document--no comments or
- processing commands. [bug=1050164]
+ - Added support for the adjacent sibling combinator (+) and the
+ general sibling combinator (~). Tests by "liquider". [bug=1082144]
+
+ - The combinators (>, +, and ~) can now combine with any supported
+ selector, not just one that selects based on tag name.
+
+ - Added limited support for the "nth-of-type" pseudo-class. Code
+ by Sven Slootweg. [bug=1109952]
* The BeautifulSoup class is now aliased to "_s" and "_soup", making
it quicker to type the import statement in an interactive session:
@@ -21,26 +25,28 @@
* Added the 'diagnose' submodule, which includes several useful
functions for reporting problems and doing tech support.
- * diagnose(data) tries the given markup on every installed parser,
+ - diagnose(data) tries the given markup on every installed parser,
reporting exceptions and displaying successes. If a parser is not
installed, diagnose() mentions this fact.
- * lxml_trace(data, html=True) runs the given markup through lxml's
+ - lxml_trace(data, html=True) runs the given markup through lxml's
XML parser or HTML parser, and prints out the parser events as
they happen. This helps you quickly determine whether a given
problem occurs in lxml code or Beautiful Soup code.
- * htmlparser_trace(data) is the same thing, but for Python's
+ - htmlparser_trace(data) is the same thing, but for Python's
built-in HTMLParser class.
-* The prettify() method now leaves the contents of <pre> tags
- alone. [bug=1095654]
+* In an HTML document, the contents of a <script> or <style> tag will
+ no longer undergo entity substitution by default. XML documents work
+ the same way they did before. [bug=1085953]
-* Added support for the "nth-of-type" CSS selector. Code by Sven
- Slootweg. [bug=1109952]
+* Methods like get_text() and properties like .strings now only give
+ you strings that are visible in the document--no comments or
+ processing commands. [bug=1050164]
-* The CSS selector ">" can now find a tag by means other than the
- tag name. Code by Sven Slootweg. [bug=1109952]
+* The prettify() method now leaves the contents of <pre> tags
+ alone. [bug=1095654]
* Fix a bug in the html5lib treebuilder which sometimes created
disconnected trees. [bug=1039527]
diff --git a/bs4/element.py b/bs4/element.py
index f4d5c40..21e040a 100644
--- a/bs4/element.py
+++ b/bs4/element.py
@@ -567,6 +567,14 @@ class PageElement(object):
value =" ".join(value)
return value
+ def _tag_name_matches_and(self, function, tag_name):
+     """Combine a tag-name test with another predicate.
+
+     :param function: A callable that takes a tag and returns a
+         truth value.
+     :param tag_name: The tag name a candidate must have. When this
+         is empty or None, no name test is added.
+     :return: `function` itself when there is no tag name; otherwise
+         a new callable that is true only when the tag's name equals
+         `tag_name` and `function(tag)` is also true.
+     """
+     if tag_name:
+         def _match(tag):
+             # The name comparison runs first, so `function` is only
+             # called for tags with the right name.
+             return tag.name == tag_name and function(tag)
+         return _match
+     return function
+
def _attribute_checker(self, operator, attribute, value=''):
"""Create a function that performs a CSS selector operation.
@@ -608,122 +616,6 @@ class PageElement(object):
else:
return lambda el: el.has_attr(attribute)
- def select(self, selector, recursive=True):
- """Perform a CSS selection operation on the current element."""
- tokens = selector.split()
- current_context = [self]
- for index, token in enumerate(tokens):
- if tokens[index - 1] == '>':
- # already found direct descendants in last step. skip this
- # step.
- continue
- m = self.attribselect_re.match(token)
- if m is not None:
- # Attribute selector
- tag, attribute, operator, value = m.groups()
- if not tag:
- tag = True
- checker = self._attribute_checker(operator, attribute, value)
- found = []
- for context in current_context:
- found.extend(
- [el for el in
- context.find_all(tag, recursive=recursive)
- if checker(el)])
- current_context = found
- continue
-
- if '#' in token:
- # ID selector
- tag, id = token.split('#', 1)
- if tag == "":
- tag = True
- if len(current_context) == 0:
- # No match.
- return []
- el = current_context[0].find(tag, {'id': id})
- if el is None:
- return [] # No match
- current_context = [el]
- continue
-
- if '.' in token:
- # Class selector
- tag_name, klass = token.split('.', 1)
- if not tag_name:
- tag_name = True
- classes = set(klass.split('.'))
- found = []
- def classes_match(tag):
- if tag_name is not True and tag.name != tag_name:
- return False
- if not tag.has_attr('class'):
- return False
- return classes.issubset(tag['class'])
- for context in current_context:
- found.extend(context.find_all(classes_match, recursive=recursive))
- current_context = found
- continue
-
- if ':' in token:
- # Pseudoselector
- tag_name, pseudo = token.split(':', 1)
- if not tag_name:
- raise ValueError(
- "A pseudoselector must be prefixed with a tag name.")
- pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
- found = []
- if pseudo_attributes is not None:
- pseudo_type, pseudo_value = pseudo_attributes.groups()
- if pseudo_type == 'nth-of-type':
- try:
- pseudo_value = int(pseudo_value)
- except:
- raise NotImplementedError(
- 'Only numeric values are supported for the nth-of-type pseudoselector for now.')
- if pseudo_value < 1:
- raise ValueError(
- 'nth-of-type pseudoselector value must be at least 1.')
- pseudo_value = pseudo_value - 1
- for context in current_context:
- all_nodes = context.find_all(tag_name, recursive=recursive)
- if pseudo_value < len(all_nodes):
- found.extend([all_nodes[pseudo_value]])
- current_context = found
- continue
- else:
- raise NotImplementedError(
- 'Only the nth-of-type pseudoselector is supported for now.')
-
- if token == '*':
- # Star selector
- found = []
- for context in current_context:
- found.extend(context.find_all(True, recursive=recursive))
- current_context = found
- continue
-
- if token == '>':
- # Child selector
- tag = tokens[index + 1]
- if not tag:
- tag = True
-
- found = []
- for context in current_context:
- found.extend(context.select(tag, recursive=False))
- current_context = found
- continue
-
- # Here we should just have a regular tag
- if not self.tag_name_re.match(token):
- return []
- found = []
- for context in current_context:
- found.extend(context.find_all(token, recursive=recursive))
- current_context = found
- return current_context
-
# Old non-property versions of the generators, for backwards
# compatibility with BS3.
def nextGenerator(self):
@@ -1292,6 +1184,202 @@ class Tag(PageElement):
yield current
current = current.next_element
+ # CSS selector code
+
+ _selector_combinators = ['>', '+', '~']
+ _select_debug = False
+ def select(self, selector, _candidate_generator=None):
+     """Perform a CSS selection operation on the current element.
+
+     The selector is split on whitespace and its tokens are applied
+     from left to right. The context starts as a list holding only
+     this element; each token replaces the context with the
+     candidates that pass that token's test.
+
+     :param selector: A string of whitespace-separated tokens. A
+         token may be a tag name, '*', an attribute selector (anything
+         attribselect_re matches), a '#id' or '.class' selector with
+         an optional tag-name prefix, a 'tag:nth-of-type(N)'
+         pseudo-class, or one of the combinators '>', '+' and '~'
+         followed by another token.
+     :param _candidate_generator: Internal. A function that takes a
+         tag and returns an iterable of candidate elements. It
+         replaces the default "all descendants" generator when
+         select() calls itself to handle a combinator.
+     :return: A list of the Tag objects that matched. An empty or
+         whitespace-only selector returns a list holding only this
+         element.
+     :raise ValueError: If the selector ends with a combinator, a
+         pseudo-class has no tag name, an nth-of-type value is less
+         than 1, or a token is not recognized.
+     :raise NotImplementedError: For a parenthesized pseudo-class
+         other than nth-of-type, or a non-numeric nth-of-type value.
+     """
+     tokens = selector.split()
+     current_context = [self]
+
+     # Check that there is at least one token first, so an empty
+     # selector does not fail with an IndexError on tokens[-1].
+     if tokens and tokens[-1] in self._selector_combinators:
+         raise ValueError(
+             'Final combinator "%s" is missing an argument.' % tokens[-1])
+     if self._select_debug:
+         print 'Running CSS selector "%s"' % selector
+     for index, token in enumerate(tokens):
+         if self._select_debug:
+             print ' Considering token "%s"' % token
+         recursive_candidate_generator = None
+         tag_name = None
+         # The explicit index test avoids relying on tokens[-1] when
+         # index is 0.
+         if index > 0 and tokens[index-1] in self._selector_combinators:
+             # This token was consumed by the previous combinator. Skip it.
+             if self._select_debug:
+                 print ' Token was consumed by the previous combinator.'
+             continue
+         # Each operation corresponds to a checker function, a rule
+         # for determining whether a candidate matches the
+         # selector. Candidates are generated by the active
+         # iterator.
+         checker = None
+         # Optional callable run before each context tag is scanned,
+         # for checkers that keep state between candidates.
+         reset_checker = None
+
+         m = self.attribselect_re.match(token)
+         if m is not None:
+             # Attribute selector
+             tag_name, attribute, operator, value = m.groups()
+             checker = self._attribute_checker(operator, attribute, value)
+
+         elif '#' in token:
+             # ID selector
+             tag_name, id_value = token.split('#', 1)
+             def id_matches(tag):
+                 return tag.get('id', None) == id_value
+             checker = id_matches
+
+         elif '.' in token:
+             # Class selector. "a.b.c" requires every listed class.
+             tag_name, klass = token.split('.', 1)
+             classes = set(klass.split('.'))
+             def classes_match(candidate):
+                 return classes.issubset(candidate.get('class', []))
+             checker = classes_match
+
+         elif ':' in token:
+             # Pseudo-class
+             tag_name, pseudo = token.split(':', 1)
+             if tag_name == '':
+                 raise ValueError(
+                     "A pseudo-class must be prefixed with a tag name.")
+             pseudo_attributes = re.match(
+                 r'([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
+             if pseudo_attributes is not None:
+                 pseudo_type, pseudo_value = pseudo_attributes.groups()
+                 if pseudo_type == 'nth-of-type':
+                     try:
+                         pseudo_value = int(pseudo_value)
+                     except ValueError:
+                         raise NotImplementedError(
+                             'Only numeric values are currently supported for the nth-of-type pseudo-class.')
+                     if pseudo_value < 1:
+                         raise ValueError(
+                             'nth-of-type pseudo-class value must be at least 1.')
+                     class Counter(object):
+                         def __init__(self, destination):
+                             self.count = 0
+                             self.destination = destination
+
+                         def reset(self):
+                             self.count = 0
+
+                         def nth_child_of_type(self, tag):
+                             self.count += 1
+                             if self.count == self.destination:
+                                 return True
+                             if self.count > self.destination:
+                                 # Stop the generator that's sending us
+                                 # these things.
+                                 raise StopIteration()
+                             return False
+                     counter = Counter(pseudo_value)
+                     checker = counter.nth_child_of_type
+                     # Restart the count for every context tag, so a
+                     # count from one context cannot leak into the
+                     # next one.
+                     reset_checker = counter.reset
+                 else:
+                     raise NotImplementedError(
+                         'Only the following pseudo-classes are implemented: nth-of-type.')
+
+         elif token == '*':
+             # Star selector -- matches everything
+             pass
+         elif token == '>':
+             # Run the next token as a CSS selector against the
+             # direct children of each tag in the current context.
+             recursive_candidate_generator = lambda tag: tag.children
+         elif token == '~':
+             # Run the next token as a CSS selector against the
+             # siblings of each tag in the current context.
+             recursive_candidate_generator = lambda tag: tag.next_siblings
+         elif token == '+':
+             # For each tag in the current context, run the next
+             # token as a CSS selector against the tag's next
+             # sibling that's a tag.
+             def next_tag_sibling(tag):
+                 # May yield None when there is no such sibling; the
+                 # isinstance() test in the main loop discards it.
+                 yield tag.find_next_sibling(True)
+             recursive_candidate_generator = next_tag_sibling
+
+         elif self.tag_name_re.match(token):
+             # Just a tag name.
+             tag_name = token
+         else:
+             raise ValueError(
+                 'Unsupported or invalid CSS selector: "%s"' % token)
+
+         if recursive_candidate_generator:
+             # This happens when the selector looks like "> foo".
+             #
+             # The generator calls select() recursively on every
+             # member of the current context, passing in a different
+             # candidate generator and a different selector.
+             #
+             # In the case of "> foo", the candidate generator is
+             # one that yields a tag's direct children (">"), and
+             # the selector is "foo".
+             next_token = tokens[index+1]
+             def recursive_select(tag):
+                 if self._select_debug:
+                     print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs)
+                     print '-' * 40
+                 for i in tag.select(next_token, recursive_candidate_generator):
+                     if self._select_debug:
+                         print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs)
+                     yield i
+                 if self._select_debug:
+                     print '-' * 40
+             _use_candidate_generator = recursive_select
+         elif _candidate_generator is None:
+             # By default, a tag's candidates are all of its
+             # children. If tag_name is defined, only yield tags
+             # with that name.
+             if self._select_debug:
+                 # Show the tag name when there is one, "[any]" otherwise.
+                 check = tag_name or "[any]"
+                 print ' Default candidate generator, tag name="%s"' % check
+                 # This is redundant with later code, but it stops
+                 # a bunch of bogus tags from cluttering up the
+                 # debug log.
+                 def default_candidate_generator(tag):
+                     for child in tag.descendants:
+                         if not isinstance(child, Tag):
+                             continue
+                         if tag_name and not child.name == tag_name:
+                             continue
+                         yield child
+                 _use_candidate_generator = default_candidate_generator
+             else:
+                 _use_candidate_generator = lambda tag: tag.descendants
+         else:
+             _use_candidate_generator = _candidate_generator
+
+         new_context = []
+         for tag in current_context:
+             if reset_checker is not None:
+                 reset_checker()
+             if self._select_debug:
+                 print " Running candidate generator on %s %s" % (
+                     tag.name, repr(tag.attrs))
+             for candidate in _use_candidate_generator(tag):
+                 if not isinstance(candidate, Tag):
+                     continue
+                 if tag_name and candidate.name != tag_name:
+                     continue
+                 if checker is not None:
+                     try:
+                         result = checker(candidate)
+                     except StopIteration:
+                         # The checker has decided we should no longer
+                         # run the generator.
+                         break
+                 if checker is None or result:
+                     if self._select_debug:
+                         print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs))
+                     new_context.append(candidate)
+                 elif self._select_debug:
+                     print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs))
+
+         current_context = new_context
+
+     if self._select_debug:
+         print "Final verdict:"
+         for i in current_context:
+             print " %s %s" % (i.name, i.attrs)
+     return current_context
+
# Old names for backwards compatibility
def childGenerator(self):
return self.children
diff --git a/bs4/tests/test_tree.py b/bs4/tests/test_tree.py
index ac60aa1..77d4199 100644
--- a/bs4/tests/test_tree.py
+++ b/bs4/tests/test_tree.py
@@ -1585,7 +1585,7 @@ class TestSoupSelector(TreeTest):
self.assertEqual(len(self.soup.select('del')), 0)
def test_invalid_tag(self):
- self.assertEqual(len(self.soup.select('tag%t')), 0)
+ self.assertRaises(ValueError, self.soup.select, 'tag%t')
def test_header_tags(self):
self.assertSelectMultiple(
@@ -1637,7 +1637,7 @@ class TestSoupSelector(TreeTest):
def test_child_selector(self):
self.assertSelects('.s1 > a', ['s1a1', 's1a2'])
self.assertSelects('.s1 > a span', ['s1a2s1'])
-
+
def test_child_selector_id(self):
self.assertSelects('.s1 > a#s1a2 span', ['s1a2s1'])
@@ -1786,3 +1786,19 @@ class TestSoupSelector(TreeTest):
def test_overspecified_child_id(self):
self.assertSelects(".fancy #inner", ['inner'])
self.assertSelects(".normal #inner", [])
+
+ def test_adjacent_sibling_selector(self):
+ # The '+' combinator applies the next token to the next sibling
+ # tag of each element matched so far; it can be chained and
+ # combined with id and class selectors.
+ self.assertSelects('#p1 + h2', ['header2'])
+ self.assertSelects('#p1 + h2 + p', ['pmulti'])
+ self.assertSelects('#p1 + #header2 + .class1', ['pmulti'])
+ # No match is expected here, so the result is an empty list.
+ self.assertEqual([], self.soup.select('#p1 + p'))
+
+ def test_general_sibling_selector(self):
+ # The '~' combinator applies the next token to the following
+ # siblings of each element matched so far; it can be mixed
+ # with '+' and with attribute selectors.
+ self.assertSelects('#p1 ~ h2', ['header2', 'header3'])
+ self.assertSelects('#p1 ~ #header2', ['header2'])
+ self.assertSelects('#p1 ~ h2 + a', ['me'])
+ self.assertSelects('#p1 ~ h2 + [rel="me"]', ['me'])
+ # No match is expected here, so the result is an empty list.
+ self.assertEqual([], self.soup.select('#inner ~ h2'))
+
+ def test_dangling_combinator(self):
+ # A selector that ends with a combinator has nothing for the
+ # combinator to apply to, so select() must raise ValueError.
+ self.assertRaises(ValueError, self.soup.select, 'h1 >')