#!/usr/bin/env python
# pkgsync: Help synchronize packages between systems
#
# The script can "export" the packages installed on this system to JSON and
# "import" such a JSON file to list packages that are new, removed, upgraded,
# outdated or identical compared to this system.
from __future__ import print_function, with_statement

import glob
import json
import logging
import os
import re
import sqlite3
import sys
from contextlib import closing

RH_FAMILY = ('centos', 'ol', 'rhel', 'rockylinux')
DEB_FAMILY = ('debian', 'ubuntu')
DIST_FAMILIES = RH_FAMILY + DEB_FAMILY

logger = logging.getLogger()

# Text types accepted as JSON object keys: ``str`` everywhere, plus
# ``unicode`` on Python 2 (on Python 3 both entries are the same type).
_STRING_TYPES = (str, type(u''))

# Patterns used by the Debian version comparison.
_RE_DIGITS_NON_DIGITS = re.compile(r'\d+|\D+')
_RE_DIGITS = re.compile(r'\d+')
_RE_DIGIT = re.compile(r'\d')
_RE_ALPHA = re.compile(r'[A-Za-z]')


class HistoryFileNotFoundError(Exception):
    """Raised when neither a yum nor a dnf history database can be found."""


class DistroUnsupported(Exception):
    """Raised when the distribution is in neither RH_FAMILY nor DEB_FAMILY."""


def _os_release_id(lines):
    """Return the value of the ``ID`` field from os-release lines, or None.

    Blank lines, comments and lines without ``=`` are skipped; surrounding
    quotes are removed from the value.
    """
    for line in lines:
        line = line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        if key.strip() == 'ID':
            return value.strip().strip('"\'')
    return None


def detect_distro():
    """Get the Linux distribution

    Checks the vendor release files first and falls back to the ``ID`` field
    of /etc/os-release. Returns the distribution identifier, or None when
    nothing could be detected.
    """
    if os.path.exists('/etc/oracle-release'):
        return 'ol'
    if os.path.exists('/etc/centos-release'):
        return 'centos'
    if os.path.exists('/etc/redhat-release'):
        return 'rhel'
    if os.path.exists('/etc/os-release'):
        with open('/etc/os-release') as fi:
            return _os_release_id(fi.readlines())
    return None


def _pkg_field(pkg, field, default=None):
    """Read a field from a package given as a dict or as an object."""
    if isinstance(pkg, dict):
        return pkg.get(field, default)
    return getattr(pkg, field, default)


def _pkg_epoch(pkg):
    """Return the package epoch as an int; a missing/empty epoch counts as 0."""
    epoch = _pkg_field(pkg, 'epoch')
    if epoch is None or epoch == '':
        return 0
    return int(epoch)


def compare_debian_version(a, b):
    """Compare Debian package versions

    ``a`` and ``b`` are package dicts with ``epoch``, ``version`` and
    ``release`` entries.

    Returns 1 when ``a`` is older than ``b``, -1 when ``a`` is newer than
    ``b`` and 0 when they are equal.
    """
    def order(c):
        # '~' sorts before everything (even the end of the string, which is
        # weighted 0), letters sort before all other non-digit characters.
        if c == '~':
            return -1
        if _RE_DIGIT.match(c):
            return int(c) + 1
        if _RE_ALPHA.match(c):
            return ord(c)
        return ord(c) + 256

    def compare_str(s1, s2):
        lhs = [order(c) for c in s1]
        rhs = [order(c) for c in s2]
        for idx in range(max(len(lhs), len(rhs))):
            # Pad the shorter side with the integer 0 so both operands are
            # always ints (a str/int comparison raises on Python 3).
            left = lhs[idx] if idx < len(lhs) else 0
            right = rhs[idx] if idx < len(rhs) else 0
            if left < right:
                return 1
            if left > right:
                return -1
        return 0

    def compare_parts(p1, p2):
        lhs = _RE_DIGITS_NON_DIGITS.findall(p1 or '')
        rhs = _RE_DIGITS_NON_DIGITS.findall(p2 or '')
        for idx in range(max(len(lhs), len(rhs))):
            # A missing chunk is treated as '0'.
            left = lhs[idx] if idx < len(lhs) else '0'
            right = rhs[idx] if idx < len(rhs) else '0'
            if _RE_DIGITS.match(left) and _RE_DIGITS.match(right):
                val_left = int(left)
                val_right = int(right)
                if val_left < val_right:
                    return 1
                if val_left > val_right:
                    return -1
            else:
                rv = compare_str(left, right)
                if rv != 0:
                    return rv
        return 0

    e1 = _pkg_epoch(a)
    e2 = _pkg_epoch(b)
    if e1 < e2:
        return 1
    if e1 > e2:
        return -1

    rc = compare_parts(_pkg_field(a, 'version'), _pkg_field(b, 'version'))
    if rc != 0:
        return rc
    return compare_parts(_pkg_field(a, 'release'), _pkg_field(b, 'release'))


def _compare_rpm_strings(first, second):
    """Compare two version (or release) strings following rpm's rpmvercmp.

    Returns 1 when ``first`` is older than ``second``, -1 when it is newer
    and 0 when both are equal. Note that this sign is the opposite of rpm's
    own function; it matches the convention used in the rest of this file.
    """
    one = first or ''
    two = second or ''
    if one == two:
        return 0

    i = j = 0
    len1, len2 = len(one), len(two)
    while i < len1 or j < len2:
        # Skip separators: anything that is neither alphanumeric nor '~'.
        while i < len1 and not (one[i].isalnum() or one[i] == '~'):
            i += 1
        while j < len2 and not (two[j].isalnum() or two[j] == '~'):
            j += 1
        c1 = one[i] if i < len1 else ''
        c2 = two[j] if j < len2 else ''

        # A tilde sorts before everything else, including the end of string.
        if c1 == '~' or c2 == '~':
            if c1 != '~':
                return -1
            if c2 != '~':
                return 1
            i += 1
            j += 1
            continue

        # One side ran out of characters: decided after the loop.
        if not c1 or not c2:
            break

        # Grab the next all-digit or all-letter segment; its type is chosen
        # by the first string.
        numeric = c1.isdigit()
        end1, end2 = i, j
        if numeric:
            while end1 < len1 and one[end1].isdigit():
                end1 += 1
            while end2 < len2 and two[end2].isdigit():
                end2 += 1
        else:
            while end1 < len1 and one[end1].isalpha():
                end1 += 1
            while end2 < len2 and two[end2].isalpha():
                end2 += 1
        seg1 = one[i:end1]
        seg2 = two[j:end2]

        if not seg1:
            # Only reachable for characters that are alphanumeric but neither
            # digits nor letters; the result is arbitrary, as in rpm.
            return 1
        if not seg2:
            # Segments of different types: the numeric one is always newer.
            return -1 if numeric else 1

        if numeric:
            # Ignore leading zeros; the number with more digits is larger.
            seg1 = seg1.lstrip('0')
            seg2 = seg2.lstrip('0')
            if len(seg1) != len(seg2):
                return -1 if len(seg1) > len(seg2) else 1
        if seg1 != seg2:
            return -1 if seg1 > seg2 else 1

        i, j = end1, end2

    if i >= len1 and j >= len2:
        # Only the separators differed.
        return 0
    # Whichever string still has characters left is the newer one.
    return -1 if i < len1 else 1


def compare_redhat_version(a, b):
    """Compare RedHat package versions

    ``a`` and ``b`` are package dicts (objects with equally named attributes
    are accepted too) providing ``epoch``, ``version`` and ``release``.

    Returns 1 when ``a`` is older than ``b``, -1 when ``a`` is newer than
    ``b`` and 0 when they are equal.
    """
    e1 = _pkg_epoch(a)
    e2 = _pkg_epoch(b)
    if e1 < e2:
        return 1
    if e1 > e2:
        return -1

    rc = _compare_rpm_strings(_pkg_field(a, 'version'),
                              _pkg_field(b, 'version'))
    if rc != 0:
        return rc
    return _compare_rpm_strings(_pkg_field(a, 'release'),
                                _pkg_field(b, 'release'))


def compare_version(v1, v2, distro):
    """Compare two packages with the rules of the given distribution.

    Returns 1 when ``v1`` is older than ``v2``, -1 when it is newer, 0 when
    equal. Raises DistroUnsupported for an unknown distribution.
    """
    if distro in RH_FAMILY:
        return compare_redhat_version(v1, v2)
    elif distro in DEB_FAMILY:
        return compare_debian_version(v1, v2)
    raise DistroUnsupported(distro)


def dict_factory(cursor, row):
    """sqlite3 row factory returning each row as a column-name -> value dict"""
    d = dict()
    for idx, col in enumerate(cursor.description):
        d[col[0]] = row[idx]
    return d


def get_yum_history_db(history_dir='/var/lib/yum/history'):
    """Get the latest yum history database

    Returns the path of the newest ``history-YYYY-MM-DD.sqlite`` file found
    in ``history_dir``, or None when there is no valid one.
    """
    pattern = os.path.join(history_dir, 'history-*-*-*.sqlite')
    # Newest first: the date embedded in the file name sorts lexically.
    databases = sorted(glob.glob(pattern), reverse=True)
    logger.debug('SQLite databases found: {0}'.format(', '.join(databases)))
    for db in databases:
        filename = os.path.basename(db)
        date = filename[filename.find('-') + 1:filename.rfind('.')]
        # validate filename
        parts = date.split('-', 4)
        if len(parts) != 3:
            continue
        try:
            [int(p) for p in parts]
        except ValueError:
            continue
        logger.info('Using database: {0}'.format(db))
        return db
    logger.info('No valid database found')
    return None


def _connect_history_db(path):
    """Open a package history database, read-only where supported.

    Python 2's sqlite3 has no URI support, so the file is opened normally
    there. On Python 3 a ``file:`` URI with ``mode=ro`` is used; it only
    takes effect when ``uri=True`` is passed to ``sqlite3.connect``.
    """
    if sys.version_info.major == 2:
        return sqlite3.connect(path)
    from urllib.parse import quote
    db_url = 'file:{0}?mode=ro'.format(quote(path))
    return sqlite3.connect(db_url, uri=True)


def list_packages_yum():
    """List packages installed by yum

    Replays the yum transaction history in chronological order. Returns a
    dict keyed by ``name:arch``, or None when no history database exists.
    """
    history_db = get_yum_history_db()
    if history_db is None:
        return None

    logger.debug('Connecting to {0}'.format(history_db))
    with closing(_connect_history_db(history_db)) as conn:
        conn.row_factory = dict_factory
        with closing(conn.cursor()) as cur:
            cur.execute("""SELECT name, arch, epoch, version, release, state
                           FROM trans_data_pkgs
                           JOIN pkgtups
                             ON trans_data_pkgs.pkgtupid = pkgtups.pkgtupid
                           JOIN trans_beg
                             ON trans_beg.tid = trans_data_pkgs.tid
                           ORDER BY timestamp ASC""")
            rows = cur.fetchall()

    logger.info('Found {0} rows'.format(len(rows)))
    results = dict()
    for row in rows:
        name = row['name']
        arch = row['arch']
        package = '{0}:{1}'.format(name, arch)
        state = row['state']
        if state in ('Install', 'True-Install', 'Dep-Install', 'Upgrade',
                     'Update', 'Obsoleting'):
            logger.debug('+++ {0}: {1}'.format(state, row))
            results[package] = dict(
                name=name,
                arch=arch,
                epoch=row['epoch'],
                release=row['release'],
                version=row['version'],
            )
        elif state in ('Erase', 'Obsoleted'):
            logger.debug('--- {0}: {1}'.format(state, row))
            results.pop(package, None)
        else:
            logger.debug('... {0}: {1}'.format(state, row))
    logger.info('Total packages: {0}'.format(len(results)))
    return results


def list_packages_dnf(history_db='/var/lib/dnf/history.sqlite'):
    """List packages installed by dnf

    Replays the dnf transaction history in chronological order. Returns a
    dict keyed by ``name:arch``, or None when ``history_db`` does not exist.
    """
    # https://github.com/rpm-software-management/libdnf/blob/9a0e17562b19586b3ffa70fa93eb961b558794c7/libdnf/transaction/Types.hpp
    # INSTALL = 1,        // a new package that was installed on the system
    # DOWNGRADE = 2,      // an older package version that replaced previously installed version
    # DOWNGRADED = 3,     // an original package version that was replaced
    # OBSOLETE = 4,       //
    # OBSOLETED = 5,      //
    # UPGRADE = 6,        //
    # UPGRADED = 7,       //
    # REMOVE = 8,         // a package that was removed from the system
    # REINSTALL = 9,      // a package that was reinstalled with the identical version
    # REINSTALLED = 10,   // a package that was reinstalled with the identical version (old repo, for example)
    # REASON_CHANGE = 11  // a package was kept on the system but its reason has changed
    #
    # NOTE(review): DOWNGRADE (2) and OBSOLETE (4) are not replayed below, so
    # a downgraded package keeps the version recorded before the downgrade.
    # Confirm whether that is intended.
    if not os.path.isfile(history_db):
        return None

    logger.info('Connecting to {0}'.format(history_db))
    with closing(_connect_history_db(history_db)) as conn:
        conn.row_factory = dict_factory
        with closing(conn.cursor()) as cur:
            cur.execute("""SELECT name, arch, epoch, version, release, action
                           FROM trans_item
                           JOIN trans ON trans_item.trans_id = trans.id
                           JOIN rpm ON rpm.item_id = trans_item.item_id
                           ORDER BY dt_end ASC""")
            rows = cur.fetchall()

    logger.info('Found {0} rows'.format(len(rows)))
    results = dict()
    for row in rows:
        name = row['name']
        arch = row['arch']
        action = row['action']
        package = '{0}:{1}'.format(name, arch)
        if action in (1, 6):  # INSTALL, UPGRADE
            results[package] = dict(
                name=name,
                arch=arch,
                epoch=row['epoch'],
                release=row['release'],
                version=row['version']
            )
        elif action in (5, 8):  # OBSOLETED, REMOVE
            results.pop(package, None)
    return results


def _parse_dpkg_stanzas(lines):
    """Parse the lines of a dpkg status file into a list of field dicts."""
    packages = list()
    package = dict()
    last_field = None
    for line in lines:
        line = line.rstrip('\n')
        # an empty line ends the current package entry
        if line == '':
            if package:
                packages.append(package)
                package = dict()
            last_field = None
            continue
        # continuation of the previous field
        if line[0] in ' \t':
            if last_field is not None:
                package[last_field] = (package[last_field] or '') + line
            continue
        if ':' in line:
            field, value = line.split(':', 1)
            value = value.strip()
        else:
            field, value = line, None
        package[field] = value
        last_field = field
    if package:
        # just in case the file doesn't end with a new line
        packages.append(package)
    return packages


def parse_dpkg_status(status_file='/var/lib/dpkg/status'):
    """Parse the dpkg status file

    Returns a list with one dict of ``field -> value`` per package entry.
    """
    with open(status_file, 'r') as fi:
        return _parse_dpkg_stanzas(fi.readlines())


def _split_debian_version(ver):
    """Split ``[epoch:]upstream[-revision]`` into (epoch, version, release).

    The epoch defaults to '0'. The revision is whatever follows the *last*
    hyphen, since the upstream version may itself contain hyphens; it is
    None when there is no hyphen.
    """
    epoch = '0'
    rest = ver
    if ':' in rest:
        epoch, rest = rest.split(':', 1)
    if '-' in rest:
        version, _, release = rest.rpartition('-')
    else:
        version, release = rest, None
    return epoch, version, release


def list_packages_dpkg(status_file='/var/lib/dpkg/status'):
    """Get list of packages installed on a Debian-based system

    Returns a dict keyed by ``name:arch`` holding the name, arch, epoch,
    version and release of every package whose status is
    ``install ok installed``.
    """
    results = dict()
    for package in parse_dpkg_status(status_file):
        name = package.get('Package')
        status = package.get('Status')
        if status != 'install ok installed':
            logger.debug('Package not installed, skipping: {0}'.format(name))
            continue

        arch = package.get('Architecture')
        epoch, version, release = _split_debian_version(
            package.get('Version') or '')
        pkgarch = '{0}:{1}'.format(name, arch)
        results[pkgarch] = dict(
            name=name,
            arch=arch,
            epoch=epoch,
            release=release,
            version=version
        )
    return results


def list_packages(distro):
    """Get list of packages installed on the system

    Raises HistoryFileNotFoundError when neither a yum nor a dnf history
    exists on a RedHat-like system, DistroUnsupported for an unknown distro.
    """
    if distro in RH_FAMILY:
        packages = list_packages_yum()
        if packages is None:
            packages = list_packages_dnf()
        if packages is None:
            raise HistoryFileNotFoundError()
    elif distro in DEB_FAMILY:
        packages = list_packages_dpkg()
    else:
        raise DistroUnsupported(distro)
    return packages


def package_full(d, distro):
    """Get a string of a package name, arch, version to pass to the package
    manager"""
    if distro in RH_FAMILY:
        return '{name}-{epoch}:{version}-{release}.{arch}'.format(**d)
    elif distro in DEB_FAMILY:
        package = ['{name}:{arch}'.format(**d), '=']
        # using explicit 0 epochs don't seem to work, so ignore it
        epoch = d.get('epoch')
        if epoch and epoch != '0':
            package.extend([epoch, ':'])
        version = d.get('version')
        package.append(version)
        # release is optional
        release = d.get('release')
        if release:
            package.extend(['-', release])
        return ''.join(package)
    raise DistroUnsupported(distro)


def package_name(d, distro):
    """Get the string of the package name and arch to pass to a package
    manager"""
    if distro in RH_FAMILY:
        return '{name}.{arch}'.format(**d)
    elif distro in DEB_FAMILY:
        return '{name}:{arch}'.format(**d)
    raise DistroUnsupported(distro)


def is_valid_input(d):
    """Check that ``d`` looks like an exported package list.

    It must be a dict mapping text keys to dicts that contain both a
    ``name`` and a ``version`` entry.
    """
    if not isinstance(d, dict):
        return False
    for k, v in d.items():
        if not isinstance(k, _STRING_TYPES) or not isinstance(v, dict):
            return False
        if 'name' not in v or 'version' not in v:
            return False
    return True


def _load_imported_packages(filename):
    """Load the imported package list from ``filename`` ('-' reads stdin).

    Exits with status 1 when interrupted or when the input is not valid.
    """
    try:
        if filename == '-':
            imported_packages = json.load(sys.stdin)
        else:
            with open(filename, 'r') as fi:
                imported_packages = json.load(fi)
    except KeyboardInterrupt:
        sys.exit(1)
    except ValueError:
        # json raises ValueError (JSONDecodeError on Python 3) on bad input
        imported_packages = None

    if not is_valid_input(imported_packages):
        logger.error('Input file is in an invalid format')
        sys.exit(1)
    return imported_packages


def _select_packages(options, current_packages, imported_packages, distro):
    """Pick the packages to print according to the selected import mode."""
    upgraded = getattr(options, 'upgraded', False)
    outdated = getattr(options, 'outdated', False)
    same = getattr(options, 'same', False)
    packages = list()

    if upgraded or outdated or same:
        common = sorted(set(imported_packages).intersection(current_packages))
        for pkg in common:
            current = current_packages.get(pkg)
            imported = imported_packages.get(pkg)
            rc = compare_version(current, imported, distro)
            if rc == 1 and upgraded:
                logger.debug('Upgrade: {0} (current) - {1} (imported)'
                             .format(current, imported))
                packages.append(imported)
            elif rc == -1 and outdated:
                logger.debug('Outdated: {0} (current) - {1} (imported)'
                             .format(current, imported))
                packages.append(current)
            elif rc == 0 and same:
                logger.debug('Same: {0} (current) - {1} (imported)'
                             .format(current, imported))
                packages.append(current)
    elif getattr(options, 'new', False):
        new_packages = sorted(
            set(imported_packages).difference(current_packages))
        packages = [imported_packages.get(p) for p in new_packages]
    elif getattr(options, 'removed', False):
        removed_packages = sorted(
            set(current_packages).difference(imported_packages))
        packages = [current_packages.get(p) for p in removed_packages]
    return packages


def main(options):
    """Run the ``export`` or ``import`` command described by ``options``."""
    logger.debug(options)

    if options.distro:
        distro = options.distro
    else:
        distro = detect_distro()
        logger.info('Detected distribution: {0}'.format(distro))

    try:
        current_packages = list_packages(distro)
    except HistoryFileNotFoundError:
        logger.error('Could not find package history')
        sys.exit(1)
    except DistroUnsupported:
        logger.error('Unsupported distribution: {0}'.format(distro))
        sys.exit(1)

    if options.command == 'export':
        if options.file == '-':
            json.dump(current_packages, sys.stdout, indent=2)
        else:
            logger.info('Writing to {0}'.format(options.file))
            with open(options.file, 'w') as fo:
                json.dump(current_packages, fo)
    elif options.command == 'import':
        if options.name_only:
            pkg_name_fn = package_name
        else:
            pkg_name_fn = package_full

        imported_packages = _load_imported_packages(options.file)
        packages = _select_packages(options, current_packages,
                                    imported_packages, distro)

        # logging interpolates with %-style arguments, not str.format
        logger.debug('Packages: %s', packages)
        try:
            for pkg in packages:
                name = pkg_name_fn(pkg, distro)
                print(name)
        except IOError:
            logger.debug('Broken pipe')
            sys.exit(0)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--distro', choices=DIST_FAMILIES,
                        help='Force distribution')
    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True  # https://bugs.python.org/issue9253

    parser_export = subparsers.add_parser('export')
    parser_export.add_argument('file', help='export packages to this file')

    parser_import = subparsers.add_parser('import')
    parser_import.add_argument('file', help='JSON file containing packages')
    parser_import.add_argument('--name-only', action='store_true',
                               help='Output package names and architecture (no version)')
    pkg_group = parser_import.add_mutually_exclusive_group()
    pkg_group.add_argument('--upgraded', action='store_true',
                           help='Only show upgraded packages (default)')
    pkg_group.add_argument('--new', dest='new', action='store_true',
                           help='Only show new packages')
    pkg_group.add_argument('--removed', action='store_true',
                           help='Only show removed packages')
    pkg_group.add_argument('--outdated', action='store_true',
                           help='Only show outdated packages')
    pkg_group.add_argument('--same', action='store_true',
                           help='Only show matching packages')

    args = parser.parse_args()

    stderr_handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(stderr_handler)
    if args.debug:
        logger.setLevel(logging.DEBUG)
    elif args.verbose:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)

    if args.command == 'import':
        # default to --upgraded when no selection was made
        if not any((args.upgraded, args.new, args.removed, args.outdated,
                    args.same)):
            args.upgraded = True

    main(args)