#!/usr/bin/python3

# Tool creating a maven repository out of rpms built by OBS
# Copyright (C) 2019  SUSE Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import argparse
from datetime import datetime
import errno
import logging
import os
import os.path
import re
import rpm
import shutil
import sys
import subprocess
import tempfile
import xml.etree.ElementTree as ET
import yaml

import osc.core
import osc.conf

def make_str(s):
    """Return ``s`` as the interpreter's native ``str`` type.

    On Python 3 ``bytes`` values are decoded as UTF-8; on Python 2 ``unicode``
    values are encoded as UTF-8. Any other value is returned untouched.
    """
    if sys.version_info[0] >= 3:
        return s.decode('utf-8') if isinstance(s, bytes) else s

    # Python 2: ``unicode`` only exists on this branch
    if isinstance(s, unicode):  # noqa: F821
        return s.encode('utf-8')
    return s

class Artifact:
    """A single jar to publish, described by one entry of the ``artifacts`` configuration list.

    Attributes (all derived from the configuration entry):
      artifact   -- maven artifactId, also used to name the deployed files
      group      -- maven groupId; falls back to the ``group`` constructor argument
      package    -- OBS package to look into; defaults to the artifact name
      arch       -- architecture suffix of the rpm to pick; defaults to ``noarch``
      project    -- OBS project, taken from the referenced repository definition
      repository -- OBS repository, taken from the referenced repository definition
      rpm        -- optional regular expression the rpm file name has to match
      jar        -- optional regular expression for the jar path inside the rpm
    """

    # Architecture passed to the OBS binary list and download calls
    OBS_ARCH = 'x86_64'
    # rpm file names containing any of these fragments are never considered
    EXCLUDED_PATTERNS = ('javadoc', 'examples', 'manual', 'test', 'demo')

    def __init__(self, config, repositories, group):
        """Build the artifact from its configuration entry.

        :param config: mapping with at least the ``artifact`` and ``repository`` keys
        :param repositories: mapping of repository name to a mapping holding
                             the ``project`` and ``repository`` keys
        :param group: default maven groupId used when ``config`` has no ``group``
        :raises KeyError: if a mandatory key is missing or the repository is not defined
        """
        self.artifact = config['artifact']
        self.group = config.get('group', group)
        self.package = config.get('package', self.artifact)
        self.arch = config.get('arch', 'noarch')

        repository_name = config['repository']
        if repository_name not in repositories:
            # Fail with a message naming the culprit rather than a bare KeyError('project')
            raise KeyError('Artifact %s refers to undefined repository %s' %
                           (self.artifact, repository_name))
        repository = repositories[repository_name]
        self.project = repository['project']
        self.repository = repository['repository']
        self.rpm = config.get('rpm')
        self.jar = config.get('jar')

    def get_binary(self, api):
        """Find the rpm holding the artifact in the OBS binary list.

        :param api: OBS API URL
        :return: the first matching entry of the binary list, or None if nothing matched
        """
        binaries = osc.core.get_binarylist(api, self.project, self.repository,
                                           self.OBS_ARCH, self.package, True)
        suffix = '.%s.rpm' % self.arch
        # An empty pattern matches every name
        name_pattern = self.rpm if self.rpm is not None else ''
        candidates = [binary for binary in binaries
                      if binary.name.endswith(suffix)
                      and not any(fragment in binary.name for fragment in self.EXCLUDED_PATTERNS)
                      and re.match(name_pattern, binary.name)]

        if not candidates:
            logging.error('Found no file for %s among:\n  %s',
                          self.artifact, '\n  '.join([binary.name for binary in binaries]))
            return None

        if len(candidates) > 1:
            logging.warning('Found more than one file for %s:\n  %s',
                            self.artifact, '\n  '.join([binary.name for binary in candidates]))
        return candidates[0]

    def fetch_binary(self, api, file, tmp):
        """Download ``file`` from OBS into the ``tmp`` folder.

        :param api: OBS API URL
        :param file: binary list entry as returned by get_binary()
        :param tmp: folder to download into
        :return: path of the downloaded rpm, or None if it is not there after the download
        """
        target_file = os.path.join(tmp, file.name)
        logging.info('Downloading %s', target_file)
        osc.core.get_binary_file(api, self.project, self.repository, self.OBS_ARCH,
                                 file.name, self.package, target_file, file.mtime)
        return target_file if os.path.isfile(target_file) else None

    def extract(self, rpm_file, tmp):
        """Extract the artifact jar from ``rpm_file`` into ``tmp``.

        Symbolic links are ignored: only regular entries below /usr/share
        matching the configured ``jar`` pattern are considered.

        :param rpm_file: path to the downloaded rpm
        :param tmp: folder to extract the jar into
        :return: ``(jar_path, version)`` tuple, or None if no jar could be extracted
        """
        ts = rpm.TransactionSet()
        ts.setVSFlags(rpm._RPMVSF_NOSIGNATURES)
        with open(rpm_file, 'rb') as rpm_fd:
            headers = ts.hdrFromFdno(rpm_fd)

        version = make_str(headers[rpm.RPMTAG_VERSION])
        # Entries with an empty link target are the ones that are not symlinks
        not_linked = [make_str(name)
                      for name, link_target in zip(headers[rpm.RPMTAG_FILENAMES],
                                                   headers[rpm.RPMTAG_FILELINKTOS])
                      if not link_target]

        logging.debug('not linked:\n  %s', '\n  '.join(not_linked))

        pattern = self.jar if self.jar is not None else ''
        end_pattern = r'[^/]*\.jar' if self.jar is None or not self.jar.endswith('.jar') else ''
        full_pattern = '^/usr/share/.*/%s%s$' % (pattern, end_pattern)

        logging.debug('full pattern: %s', full_pattern)

        to_extract = [f for f in not_linked if re.match(full_pattern, f)]

        if not to_extract:
            logging.error('Found no jar to extract in %s', rpm_file)
            return None
        if len(to_extract) > 1:
            logging.warning('Found several jars to extract in %s:\n  %s',
                            rpm_file, '\n  '.join(to_extract))

        jar_entry = to_extract[0]

        # try harder to guess the version number from the jar file since it may be different from the rpm
        matcher = re.search(r'%s-([0-9.]+)\.jar' % re.escape(self.artifact),
                            os.path.basename(jar_entry))
        if matcher:
            version = matcher.group(1)

        dst_path = os.path.join(tmp, os.path.basename(jar_entry))
        logging.info('extracting %s to %s', jar_entry, dst_path)

        # Equivalent of: rpm2cpio <rpm> | cpio -i --to-stdout ./<jar> > <dst>
        rpm2cpio = subprocess.Popen(('rpm2cpio', rpm_file), stdout=subprocess.PIPE)
        try:
            with open(dst_path, 'wb') as dst:
                cpio = subprocess.Popen(('cpio', '-i', '--to-stdout', '.%s' % jar_entry),
                                        stdin=rpm2cpio.stdout, stdout=dst, stderr=subprocess.PIPE)
                # Drop our copy of the pipe so rpm2cpio is not kept alive by us
                rpm2cpio.stdout.close()
                # communicate() drains stderr, so cpio cannot block on a full pipe
                _, cpio_err = cpio.communicate()
        finally:
            # Closing is idempotent; it lets rpm2cpio terminate even if cpio never started
            rpm2cpio.stdout.close()
            rpm2cpio.wait()

        if cpio.returncode != 0 or os.path.getsize(dst_path) == 0:
            logging.error('Failed to extract %s from %s: %s',
                          jar_entry, rpm_file, make_str(cpio_err).strip())
            return None

        return (dst_path, version)

    def deploy(self, jar, version, repo, mtime):
        """Install ``jar`` in the maven repository with a generated pom and metadata.

        The files are written to ``repo/group/artifact/version/`` and the
        ``repo/group/artifact/maven-metadata-local.xml`` file is created or updated.

        :param jar: path to the jar file to copy
        :param version: version to deploy the jar as
        :param repo: root folder of the maven repository
        :param mtime: modification time to set on the deployed jar; process()
                      compares it with the rpm's mtime to skip unchanged artifacts
        """
        artifact_folder = os.path.join(repo, self.group, self.artifact)
        version_folder = os.path.join(artifact_folder, version)
        try:
            os.makedirs(version_folder)
        except OSError as e:
            # An already existing folder is fine, anything else is not
            if e.errno != errno.EEXIST:
                raise

        jar_path = os.path.join(version_folder, '%s-%s.jar' % (self.artifact, version))
        logging.info('deploying %s to %s', jar, jar_path)
        shutil.copyfile(jar, jar_path)
        logging.debug('Setting mtime %d on %s', mtime, jar_path)
        os.utime(jar_path, (mtime, mtime))

        pom = """<?xml version="1.0" encoding="UTF-8"?>
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
         http://maven.apache.org/xsd/maven-4.0.0.xsd"
         xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
  <modelVersion>4.0.0</modelVersion>
  <groupId>%s</groupId>
  <artifactId>%s</artifactId>
  <version>%s</version>
  <description>POM was created from obs-to-maven</description>
</project>""" % (self.group, self.artifact, version)
        with open(os.path.join(version_folder, '%s-%s.pom' % (self.artifact, version)), 'w') as fd:
            fd.write(pom)

        # Maintain metadata file repo/group/artifact/maven-metadata-local.xml
        metadata_path = os.path.join(artifact_folder, "maven-metadata-local.xml")
        update_time = datetime.now().strftime('%Y%m%d%H%M%S')
        if self._update_metadata(metadata_path, version, update_time):
            return

        # Something wrong happened or we don't have the file: create a fresh one
        xml = """<?xml version="1.0" encoding="UTF-8"?>
<metadata>
  <groupId>%s</groupId>
  <artifactId>%s</artifactId>
  <versioning>
    <release>%s</release>
    <versions>
      <version>%s</version>
    </versions>
    <lastUpdated>%s</lastUpdated>
  </versioning>
</metadata>
""" % (self.group, self.artifact, version, version, update_time)
        with open(metadata_path, 'w') as fd:
            fd.write(xml)

    def _update_metadata(self, metadata_path, version, update_time):
        """Add ``version`` to an existing metadata file.

        :return: True if the file was updated, False if it is missing or
                 unusable and has to be written from scratch
        """
        if not os.path.isfile(metadata_path):
            return False

        try:
            doc = ET.parse(metadata_path)
        except ET.ParseError:
            logging.warning('Invalid XML file: creating a new one: %s', metadata_path)
            return False

        # Both nodes are children of <versioning>, not of the <metadata> root
        versions_node = doc.find('versioning/versions')
        lastupdated_node = doc.find('versioning/lastUpdated')
        if versions_node is None or lastupdated_node is None:
            logging.warning('Invalid XML file: creating a new one: %s', metadata_path)
            return False

        # Re-deploying a rebuilt rpm of the same version must not duplicate the entry
        known_versions = [node.text for node in versions_node.findall('version')]
        if version not in known_versions:
            ET.SubElement(versions_node, 'version').text = version

        # Keep <release> in sync with what a freshly written file would contain
        release_node = doc.find('versioning/release')
        if release_node is not None:
            release_node.text = version

        lastupdated_node.text = update_time
        doc.write(metadata_path, encoding='UTF-8', xml_declaration=True)
        return True

    def _deployed_mtimes(self, repo):
        """Return the mtimes of all jars found below a ``/<artifact>/`` folder of ``repo``."""
        marker = '%s%s%s' % (os.path.sep, self.artifact, os.path.sep)
        mtimes = []
        for root, _dirs, files in os.walk(repo):
            if marker not in root:
                continue
            mtimes.extend(os.stat(os.path.join(root, f)).st_mtime
                          for f in files if f.endswith('.jar'))
        return mtimes

    def process(self, api, repo, tmp):
        """Synchronize the artifact: download, extract and deploy it if it changed.

        Failures are logged and the artifact is skipped so that the other
        artifacts can still be processed.

        :param api: OBS API URL
        :param repo: root folder of the maven repository
        :param tmp: scratch folder for the downloaded rpm and extracted jar
        """
        logging.info('Processing artifact %s', self.artifact)
        file = self.get_binary(api)
        if file is None:
            # get_binary() already logged what was looked for
            return

        # Check if one of the artifact's jar has the same mtime. If so no need to update
        mtimes = self._deployed_mtimes(repo)
        if any(file.mtime == int(mtime) for mtime in mtimes):
            logging.info('Skipping artifact %s', self.artifact)
            return

        logging.debug('package mtime: %d, [%s]', file.mtime, ', '.join(['%f' % t for t in mtimes]))
        rpm_file = self.fetch_binary(api, file, tmp)
        if rpm_file is None:
            logging.error('Failed to download %s', file.name)
            return

        # Extract the jar; the version comes from the rpm header or the jar name
        extracted = self.extract(rpm_file, tmp)
        if extracted is None:
            return
        (jar, version) = extracted

        # Install in the repository
        self.deploy(jar, version, repo, file.mtime)


class Configuration:
    """Settings read from the YAML configuration file.

    Attributes:
      api       -- OBS API URL, from the ``api`` key or the osc default
      repo      -- path to the output maven repository
      artifacts -- list of Artifact objects built from the ``artifacts`` key
    """

    def __init__(self, config_path, repo):
        """Load the configuration and initialize osc.

        :param config_path: path to the YAML configuration file; if it is not
                            a file, defaults are used and no artifact is defined
        :param repo: path to the output maven repository
        """
        data = {}
        if os.path.isfile(config_path):
            with open(config_path, 'r') as f:
                # safe_load() returns None for an empty document
                data = yaml.safe_load(f) or {}
        else:
            logging.warning('Configuration file %s not found: using defaults', config_path)

        self.api = data.get('api', osc.conf.DEFAULTS['apiurl'])
        self.repo = repo
        repositories = data.get('repositories', {})
        default_group = data.get('group', 'suse')
        self.artifacts = [Artifact(artifact, repositories, default_group)
                          for artifact in data.get('artifacts', [])]

        # Load OSC configuration, setup HTTP authentication
        osc.conf.get_config(override_apiurl=self.api)

def main():
    """Command line entry point: read the configuration and process every artifact.

    Downloads and extractions happen in a temporary folder that is removed
    when the run ends, whether it succeeded or not.
    """
    parser = argparse.ArgumentParser(
        description="OBS to Maven repository synchronization tool",
        conflict_handler='resolve',
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('config', help="Path to the YAML configuration file")
    parser.add_argument('out', help="Path to the output maven repository")

    parser.add_argument(
        "-d", "--debug",
        help="Show debug messages",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
        default=logging.INFO)

    args = parser.parse_args()

    logging.basicConfig(level=args.loglevel)
    logging.debug('Reading configuration')
    config = Configuration(args.config, args.out)
    tmp = tempfile.mkdtemp(prefix="obsmvn-")
    try:
        for artifact in config.artifacts:
            artifact.process(config.api, config.repo, tmp)
    finally:
        # Do not leave downloaded rpms behind when an artifact fails
        shutil.rmtree(tmp)

# Only run the synchronization when executed as a script, not when imported
if __name__ == '__main__':
    main()
