[utils] Add convenience urljoin

This commit is contained in:
Sergey M․ 2016-12-13 02:23:49 +07:00
parent abf3494ac7
commit e34c33614d
No known key found for this signature in database
GPG Key ID: 2C393E0F18A9236D
2 changed files with 24 additions and 0 deletions

View File

@ -70,6 +70,7 @@ from youtube_dl.utils import (
lowercase_escape, lowercase_escape,
url_basename, url_basename,
base_url, base_url,
urljoin,
urlencode_postdata, urlencode_postdata,
urshift, urshift,
update_url_query, update_url_query,
@ -445,6 +446,19 @@ class TestUtil(unittest.TestCase):
self.assertEqual(base_url('http://foo.de/bar/baz'), 'http://foo.de/bar/') self.assertEqual(base_url('http://foo.de/bar/baz'), 'http://foo.de/bar/')
self.assertEqual(base_url('http://foo.de/bar/baz?x=z/x/c'), 'http://foo.de/bar/') self.assertEqual(base_url('http://foo.de/bar/baz?x=z/x/c'), 'http://foo.de/bar/')
def test_urljoin(self):
self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de/', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de/', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin(None, 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin(['foobar'], 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de/', None), None)
self.assertEqual(urljoin('http://foo.de/', ''), None)
self.assertEqual(urljoin('http://foo.de/', ['foobar']), None)
def test_parse_age_limit(self): def test_parse_age_limit(self):
self.assertEqual(parse_age_limit(None), None) self.assertEqual(parse_age_limit(None), None)
self.assertEqual(parse_age_limit(False), None) self.assertEqual(parse_age_limit(False), None)

View File

@ -1700,6 +1700,16 @@ def base_url(url):
return re.match(r'https?://[^?#&]+/', url).group() return re.match(r'https?://[^?#&]+/', url).group()
def urljoin(base, path):
if not isinstance(path, compat_str) or not path:
return None
if re.match(r'https?://', path):
return path
if not isinstance(base, compat_str) or not re.match(r'https?://', base):
return None
return compat_urlparse.urljoin(base, path)
class HEADRequest(compat_urllib_request.Request): class HEADRequest(compat_urllib_request.Request):
def get_method(self): def get_method(self):
return 'HEAD' return 'HEAD'