Compare commits

...

2 Commits

Author SHA1 Message Date
Matthew Howle
bfdbe6d78f Raise DistroUnsupported for unsupported distros 2021-09-24 14:38:42 -04:00
Matthew Howle
c84e47111e Add RedHat version compare 2021-09-24 14:38:11 -04:00

166
pkgsync
View File

@@ -25,6 +25,10 @@ class HistoryFileNotFoundError(Exception):
pass pass
class DistroUnsupported(Exception):
pass
def detect_distro(): def detect_distro():
"""Get the Linux distribution""" """Get the Linux distribution"""
distro = None distro = None
@@ -47,8 +51,8 @@ def detect_distro():
return distro return distro
def compare_version(a, b): def compare_debian_version(a, b):
"""Compare package versions""" """Compare Debian package versions"""
re_digits_non_digits = re.compile(r'\d+|\D+') re_digits_non_digits = re.compile(r'\d+|\D+')
re_digits = re.compile(r'\d+') re_digits = re.compile(r'\d+')
re_digit = re.compile(r'\d') re_digit = re.compile(r'\d')
@@ -142,6 +146,155 @@ def compare_version(a, b):
return 0 return 0
def compare_redhat_version(a, b):
"""Compare RedHat package versions"""
def compare_parts(part1, part2):
if part1 == part2:
return 0
p1 = list(part1)
p2 = list(part2)
while p1 or p2:
for c in list(p1):
if not c.isalnum() and not c == '~':
p1.pop(0)
else:
break
for c in list(p2):
if not c.isalnum() and not c == '~':
p2.pop(0)
else:
break
try:
c1 = p1[0]
except IndexError:
c1 = ''
try:
c2 = p2[0]
except IndexError:
c2 = ''
if c1 == '~' and c2 == '~':
p1.pop(0)
p2.pop(0)
elif c1 == '~':
return 1
elif c2 == '~':
return -1
if not c1 and not c2:
break
lhs_is_digit = c1.isdigit()
if c1.isdigit():
r1 = []
for cx in list(p1):
if cx.isdigit():
r1.append(p1.pop(0))
else:
break
r2 = []
for c in list(p2):
if c.isdigit():
r2.append(p2.pop(0))
else:
break
if r1:
while r1 and r1[0] == '0':
r1.pop(0)
if r2:
while r2 and r2[0] == '0':
r2.pop(0)
if r1 == r2:
continue
if not r2:
if lhs_is_digit:
return -1
else:
return 1
if len(r1) != len(r2):
if len(r1) > len(r2):
return -1
else:
return 1
elif r1 == r2:
return 0
elif r1 > r2:
return -1
else:
return 1
else:
r1 = []
for cx in list(p1):
if cx.isalpha():
r1.append(p1.pop(0))
else:
break
r2 = []
for cy in list(p2):
if cy.isalpha():
r2.append(p2.pop(0))
else:
break
if not r2:
if lhs_is_digit:
return -1
else:
return 1
if len(r1) != len(r2):
if len(r1) > len(r2):
return -1
else:
return 1
elif r1 == r2:
continue
elif r1 > r2:
return -1
else:
return 1
return 0
e1 = int(a.epoch)
e2 = int(b.epoch)
if e1 < e2:
return 1
if e1 > e2:
return -1
v1 = a.version
v2 = b.version
rc = compare_parts(v1, v2)
if rc != 0:
return rc
r1 = a.release
r2 = b.release
rc = compare_parts(r1, r2)
if rc != 0:
return rc
return 0
def compare_version(v1, v2, distro):
if distro in RH_FAMILY:
return compare_redhat_version(v1, v2)
elif distro in DEB_FAMILY:
return compare_debian_version(v1, v2)
raise DistroUnsupported(distro)
def dict_factory(cursor, row): def dict_factory(cursor, row):
d = dict() d = dict()
for idx, col in enumerate(cursor.description): for idx, col in enumerate(cursor.description):
@@ -382,6 +535,8 @@ def list_packages(distro):
raise HistoryFileNotFoundError() raise HistoryFileNotFoundError()
elif distro in DEB_FAMILY: elif distro in DEB_FAMILY:
packages = list_packages_dpkg() packages = list_packages_dpkg()
else:
raise DistroUnsupported(distro)
return packages return packages
@@ -406,7 +561,8 @@ def package_full(d, distro):
if release: if release:
package.extend(['-', release]) package.extend(['-', release])
return ''.join(package) return ''.join(package)
return '{name}'.format(**d) # need a better default raise DistroUnsupported(distro)
def package_name(d, distro): def package_name(d, distro):
@@ -415,7 +571,7 @@ def package_name(d, distro):
return '{name}.{arch}'.format(**d) return '{name}.{arch}'.format(**d)
elif distro in DEB_FAMILY: elif distro in DEB_FAMILY:
return '{name}:{arch}'.format(**d) return '{name}:{arch}'.format(**d)
return '{name}'.format(**d) raise DistroUnsupported(distro)
def is_valid_input(d): def is_valid_input(d):
@@ -480,7 +636,7 @@ def main(options):
current = current_packages.get(pkg) current = current_packages.get(pkg)
imported = imported_packages.get(pkg) imported = imported_packages.get(pkg)
rc = compare_version(current, imported) rc = compare_version(current, imported, distro)
if rc == 1 and options.upgraded: if rc == 1 and options.upgraded:
logger.debug('Upgrade: {0} (current) - {1} (imported)'.format(current, imported)) logger.debug('Upgrade: {0} (current) - {1} (imported)'.format(current, imported))