#
# Copyright 2009 Canonical Ltd.
#
# Written by:
#     Gustavo Niemeyer <gustavo.niemeyer@canonical.com>
#     Sidnei da Silva <sidnei.da.silva@canonical.com>
#
# This file is part of the Image Store Proxy.
#
# This program is free software: you can redistribute it and/or modify it 
# under the terms of the GNU General Public License version 3, as published 
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but 
# WITHOUT ANY WARRANTY; without even the implied warranties of 
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR 
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along 
# with this program.  If not, see <http://www.gnu.org/licenses/>.
#
from hashlib import sha256
import commands
import time
import os

from urllib import quote
from twisted.internet import reactor

from imagestore.lib.tests.mocker import ARGS, KWARGS
from imagestore.lib.fetch import fetch, PyCurlError, HTTPCodeError
from imagestore.lib.twistedutil import mergeDeferreds

from imagestore.tests.helpers import ServiceTestCase
from imagestore.downloadservice import (
    DownloadService, DownloadServiceError, DownloadFileTask, REMOVE_FILE_DELAY)

from imagestore.tests.helpers import ServiceTestCase


class DownloadServiceTest(ServiceTestCase):
    """Tests for DownloadService.

    Covers fetching files, checksum verification, decompression of .gz and
    .tar.gz payloads, and removal of stale files when the service starts.
    """

    def setUp(self):
        # Each test gets its own scratch directory as the download store.
        self.basePath = self.makeDir()
        self.service = DownloadService(reactor, self.basePath)

    # Helpers (leading underscore so trial does not collect them as tests).

    def _sha256OfFile(self, path):
        """Return the hex SHA-256 digest of the file at C{path}.

        The file is read in binary mode because the tests hash compressed
        (gzip / tar.gz) payloads.
        """
        fileObj = open(path, "rb")
        try:
            return sha256(fileObj.read()).hexdigest()
        finally:
            fileObj.close()

    def _readFile(self, path):
        """Return the full contents of the file at C{path}."""
        fileObj = open(path)
        try:
            return fileObj.read()
        finally:
            fileObj.close()

    def _runCommand(self, args):
        """Run the program described by the list C{args}; return (status, output).

        stderr is merged into stdout, as with commands.getstatusoutput.
        Arguments are passed straight to the program without a shell, so
        paths containing spaces or shell metacharacters are handled safely.
        """
        import subprocess
        process = subprocess.Popen(args, stdout=subprocess.PIPE,
                                   stderr=subprocess.STDOUT,
                                   universal_newlines=True)
        output = process.communicate()[0]
        return process.returncode, output

    def _assertTaskFails(self, task, messagePrefix):
        """Run C{task} and assert it fails with a DownloadServiceError.

        The error message must start with C{messagePrefix}.
        """
        deferred = self.service.addTask(task)

        def errback(failure):
            self.assertStartsWith(failure.getErrorMessage(), messagePrefix)
            # Propagate so assertFailure can check the exception type.
            return failure

        deferred.addErrback(errback)
        self.assertFailure(deferred, DownloadServiceError)
        return self.runServicesAndWaitForDeferred([self.service], deferred)

    # Tests.

    def testDownloadFile(self):
        image1 = self.createImage(1, withFiles=True)

        tasks = {}
        deferreds = {}
        for imageFile in image1["files"]:
            kind = imageFile["kind"]
            tasks[kind] = task = DownloadFileTask(
                imageFile["url"], imageFile["size-in-bytes"],
                imageFile["sha256"])
            deferreds[kind] = self.service.addTask(task)

        def makeCallback(task):
            # A factory binds each task explicitly, so the closures do not
            # share a late-bound loop variable.
            def callback(result):
                expectedPath = os.path.join(self.basePath, task.sha256)
                self.assertEquals(expectedPath, result)
                self.assertEquals(os.path.getsize(expectedPath), task.size)
            return callback

        # The fixture image is expected to provide exactly these kinds;
        # a missing kind raises KeyError here, as before.
        for kind in ("kernel", "ramdisk", "image"):
            deferreds[kind].addCallback(makeCallback(tasks[kind]))
        return self.runServicesAndWaitForDeferred(
            [self.service], mergeDeferreds(deferreds.values()))

    def testDownloadFileFailsChecksum(self):
        image1 = self.createImage(1, withFiles=True)
        byKind = dict((imageFile["kind"], imageFile)
                      for imageFile in image1["files"])

        # Force a failed checksum verification by downloading the "image"
        # file's URL while expecting the "kernel" file's checksum.
        kernelFile = byKind["kernel"]
        imageFile = byKind["image"]
        task = DownloadFileTask(imageFile["url"], imageFile["size-in-bytes"],
                                kernelFile["sha256"])
        deferred = self.service.addTask(task)

        def callback(result):
            self.fail("Should have raised an exception")

        def errback(failure):
            failure.trap(DownloadServiceError)
            # getErrorMessage() is used instead of the deprecated
            # BaseException.message, matching the other tests here.
            self.assertStartsWith(failure.getErrorMessage(),
                                  "Checksum mismatch on downloaded file "
                                  "(expected b90aa07301e80d19137e9f69c5a")

            # The bad file should be removed after checking.
            self.assertEquals(os.listdir(self.basePath), [])

        deferred.addCallback(callback)
        deferred.addErrback(errback)
        return self.runServicesAndWaitForDeferred(
            [self.service], mergeDeferreds([deferred]))

    def testDownloadServicePreservesRecentFilesWhenStarting(self):
        path = self.makeFile("content", dirname=self.basePath)
        # Slightly younger than the removal threshold, so it must survive.
        mtime = time.time() - REMOVE_FILE_DELAY + 10
        os.utime(path, (mtime, mtime))

        self.service.start()

        def stopped(result):
            self.assertTrue(os.path.isfile(path))

        deferred = self.service.stop()
        deferred.addCallback(stopped)
        return deferred

    def testDownloadServiceRemovesOldFilesWhenStarting(self):
        path = self.makeFile("content", dirname=self.basePath)
        # Just older than the removal threshold, so it must be removed.
        mtime = time.time() - REMOVE_FILE_DELAY - 1
        os.utime(path, (mtime, mtime))

        self.service.start()

        def stopped(result):
            self.assertFalse(os.path.isfile(path))

        deferred = self.service.stop()
        deferred.addCallback(stopped)
        return deferred

    def testDownloadFileUncompressesGZFiles(self):
        path = self.makeFile("content", dirname=self.basePath)
        status, output = self._runCommand(["gzip", path])
        self.assertEquals(status, 0, "gzip failed:\n" + output)

        path += ".gz"
        checksum = self._sha256OfFile(path)

        task = DownloadFileTask("file://" + path, os.path.getsize(path),
                                checksum)
        deferred = self.service.addTask(task)

        def callback(result):
            expectedPath = os.path.join(self.basePath, checksum)
            self.assertEquals(expectedPath, result)
            self.assertEquals(self._readFile(expectedPath), "content")

        deferred.addCallback(callback)

        return self.runServicesAndWaitForDeferred([self.service], deferred)

    def testDownloadFileUncompressesGZFilesWithFailure(self):
        # Plain text with a .gz suffix: decompression must fail.
        path = self.makeFile("content", dirname=self.basePath, suffix=".gz")
        checksum = self._sha256OfFile(path)

        task = DownloadFileTask("file://" + path, os.path.getsize(path),
                                checksum)
        return self._assertTaskFails(task, "Uncompression of file failed:")

    def testDownloadFileUncompressesTarGZFiles(self):
        # Note that path is an absolute path, so the command below will
        # actually create a tarball which contains an entry under a
        # subdirectory.  The system should be able to handle these cases
        # too.
        path = self.makeFile("content", dirname=self.basePath)
        tarPath = path + ".tar.gz"
        status, output = self._runCommand(["tar", "czvf", tarPath, path])
        self.assertEquals(status, 0, "tar failed:\n" + output)

        checksum = self._sha256OfFile(tarPath)

        # We'll try to screw up the algorithm by pre-creating the extract
        # directory with dummy data.  It should be removed before
        # uncompression takes place.
        extractPath = os.path.join(self.basePath, checksum) + ".tar.gz.extract"
        os.mkdir(extractPath)
        open(os.path.join(extractPath, "REMOVE-ME"), "w").close()

        task = DownloadFileTask("file://" + tarPath, os.path.getsize(tarPath),
                                checksum)
        deferred = self.service.addTask(task)

        def callback(result):
            expectedPath = os.path.join(self.basePath, checksum)
            self.assertEquals(expectedPath, result)
            self.assertEquals(self._readFile(expectedPath), "content")
            self.assertFalse(os.path.exists(expectedPath + ".tar.gz"))

        deferred.addCallback(callback)

        return self.runServicesAndWaitForDeferred([self.service], deferred)

    def testDownloadFileUncompressesTarGZFilesWithFailure(self):
        # Plain text with a .tar.gz suffix: extraction must fail.
        path = self.makeFile("content", dirname=self.basePath,
                             suffix=".tar.gz")
        checksum = self._sha256OfFile(path)

        task = DownloadFileTask("file://" + path, os.path.getsize(path),
                                checksum)
        return self._assertTaskFails(task, "Uncompression of file failed:")

    def testDownloadFileUncompressesTarGZFilesWithMultipleFiles(self):
        path1 = self.makeFile("content1", dirname=self.basePath)
        path2 = self.makeFile("content2", dirname=self.basePath)

        tarPath = self.makeFile(suffix=".tar.gz")

        status, output = self._runCommand(["tar", "czvf", tarPath,
                                           path1, path2])
        self.assertEquals(status, 0, "tar failed:\n" + output)

        checksum = self._sha256OfFile(tarPath)

        task = DownloadFileTask("file://" + tarPath,
                                os.path.getsize(tarPath), checksum)
        return self._assertTaskFails(task,
                                     "Uncompression of file failed: "
                                     "tar.gz contains more than one file")

    def testDownloadFileVerifiesChecksumForValidity(self):
        path = self.makeFile("content", dirname=self.basePath, suffix=".gz")

        # A checksum that is not a valid hex digest must be rejected.
        task = DownloadFileTask("file://" + path, os.path.getsize(path),
                                "bad hash")
        return self._assertTaskFails(task, "Invalid checksum: bad hash")
