#
# This file is part of OpenClone.
#
# Copyright (C) 2009  David Gnedt
#
# OpenClone is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenClone is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with OpenClone.  If not, see <http://www.gnu.org/licenses/>.
#

import base64
import logging
import os
import socket
from ZSI.ServiceContainer import GetSOAPContext

from OpenCloneEngineService_services import *
from OpenCloneEngineService_services_server import OpenCloneEngineService
from OpenCloneEngineService_services_types import ns0
from ..config import config
from ..database.database import Database
from ...common.transfer import udpcast
from ...common import nettools

# Factories for the status elements of the SOAP responses; calling one
# returns a status object whose value is set via set_attribute_value().
(register_status,
 logon_status,
 logoff_status,
 nextOperation_status,
 statusUpdate_status) = (response().new_status for response in (
    registerResponse,
    logonResponse,
    logoffResponse,
    nextOperationResponse,
    statusUpdateResponse))

# Python classes of the 'operation' element variants handed out by nextOperation.
(nextOperation_idleoperation,
 nextOperation_imageoperation,
 nextOperation_shutdownoperation,
 nextOperation_partitionoperation) = (definition('operation').pyclass for definition in (
    ns0.IdleOperation_Def,
    ns0.ImageOperation_Def,
    ns0.ShutdownOperation_Def,
    ns0.PartitionOperation_Def))

# Python classes of the 'devices' element variants.
(host_cpu_device,
 host_ram_device,
 host_harddisk_device) = (definition('devices').pyclass for definition in (
    ns0.CPUDevice_Def,
    ns0.RAMDevice_Def,
    ns0.HardDiskDevice_Def))

# Python classes for an MBR partition table and its partitions.
partitionoperation_mbr_partitiontable = ns0.MBRPartitiontable_Def('partitiontable').pyclass
partitiontable_mbr_partition = ns0.MBRPartition_Def('parts').pyclass

logger = logging.getLogger('engineservice')

class OpenCloneEngineServiceImpl(OpenCloneEngineService):
    """Implementation of the OpenClone engine SOAP service.

    Each public SOAP method opens its own ``Database`` session, identifies the
    calling host by the MAC address it sends, updates the stored state and
    returns the response objects for the generated ``OpenCloneEngineService``
    skeleton.  Every response carries a status value of 0 on success and 1 on
    failure.  The database session is closed even if a request raises.
    """

    # Field names copied from the SOAP objects' get_attribute_<name>()
    # accessors onto the database rows' attributes of the same name.
    __CPU_FIELDS = ('processor_no', 'vendor_id', 'model_name', 'mhz',
                    'bogomips', 'cache', 'core_id', 'cpu_cores')
    __RAM_FIELDS = ('size',)
    __HARDDISK_FIELDS = ('address', 'size', 'sector_size', 'model_no',
                         'serial_no', 'firmware_rev', 'wwn', 'cylinder',
                         'heads', 'sectors')
    __MBR_PARTITION_FIELDS = ('fs', 'os', 'no', 'record_type', 'bootable',
                              'partition_type', 'start_lba', 'sectors',
                              'start_cylinder', 'start_head', 'start_sector',
                              'end_cylinder', 'end_head', 'end_sector')

    def __init__(self):
        OpenCloneEngineService.__init__(self, impl=self)

    # ------------------------------------------------------------------
    # Generic helpers
    # ------------------------------------------------------------------

    @staticmethod
    def __pop_matching(candidates, predicate):
        """Remove and return the first item of ``candidates`` satisfying ``predicate``.

        Returns None, leaving ``candidates`` unchanged, if nothing matches.
        """
        for index, candidate in enumerate(candidates):
            if predicate(candidate):
                del candidates[index]
                return candidate
        return None

    @staticmethod
    def __copy_attributes(target, source, names):
        """Set ``target.<name>`` to ``source.get_attribute_<name>()`` for each name."""
        for name in names:
            setattr(target, name, getattr(source, 'get_attribute_' + name)())

    @staticmethod
    def __delete_partitiontable(db, partitiontable):
        """Delete ``partitiontable`` and every partition it contains from ``db``."""
        for part in partitiontable.parts:
            db.delete(part)
        db.delete(partitiontable)

    @staticmethod
    def __image_file_paths(data_path, image_partition):
        """Return ``(imgdir, imgfile)`` where the image of ``image_partition`` is stored."""
        imgdir = '%s%simage%d' % (data_path, os.sep, image_partition.disk.image.id)
        imgfile = '%s%sdisk%d_partition%d' % (imgdir, os.sep, image_partition.disk.id, image_partition.id)
        return imgdir, imgfile

    @staticmethod
    def __server_ip_for_client():
        """Return the local IP address of the interface that routes to the SOAP peer."""
        # getnameinfo() takes NI_* flags; NI_NUMERICHOST yields the numeric
        # client address needed for the routing-table lookup.
        client_ip = socket.getnameinfo(GetSOAPContext().connection.getpeername(), socket.NI_NUMERICHOST)[0]
        entry = nettools.RoutingTable().lookup(client_ip)
        dev = nettools.getDeviceInfo(entry['Iface'], dest=entry['Destination'], mask=entry['Mask'])
        return dev['ip']

    # ------------------------------------------------------------------
    # Device inventory
    # ------------------------------------------------------------------

    def __update_devices(self, db, db_host, host):
        """Synchronise ``db_host.devices`` with the devices reported by ``host``.

        Existing rows are reused where possible so their identities survive:
        CPU and RAM devices are matched by type only, hard disks by address.
        Stored rows not matched by any reported device are deleted, together
        with the partition table (and its partitions) of a removed hard disk.

        Args:
            db: open ``Database`` session.
            db_host: database host row whose ``devices`` list is rebuilt.
            host: SOAP host object whose ``Devices`` were sent by the client.

        Raises:
            Exception: on an unknown device, partition table or partition type.
        """
        old_devs = list(db_host.devices)
        del db_host.devices[:]
        for dev in host.Devices:
            if isinstance(dev.typecode, ns0.CPUDevice_Def):
                cpu = self.__pop_matching(old_devs, lambda d: isinstance(d, db.newCPUDevice))
                if cpu is None:
                    cpu = db.newCPUDevice()
                self.__copy_attributes(cpu, dev, self.__CPU_FIELDS)
                db_host.devices.append(cpu)

            elif isinstance(dev.typecode, ns0.RAMDevice_Def):
                ram = self.__pop_matching(old_devs, lambda d: isinstance(d, db.newRAMDevice))
                if ram is None:
                    ram = db.newRAMDevice()
                self.__copy_attributes(ram, dev, self.__RAM_FIELDS)
                db_host.devices.append(ram)

            elif isinstance(dev.typecode, ns0.HardDiskDevice_Def):
                address = dev.get_attribute_address()
                hd = self.__pop_matching(
                    old_devs,
                    lambda d: isinstance(d, db.newHardDiskDevice) and d.address == address)
                if hd is None:
                    hd = db.newHardDiskDevice()
                self.__copy_attributes(hd, dev, self.__HARDDISK_FIELDS)
                self.__update_partitiontable(db, hd, dev)
                db_host.devices.append(hd)

            else:
                raise Exception('Unknown device type')

        # Delete rows of devices that were not reported again, including the
        # partition table of a hard disk that has disappeared.
        for dev in old_devs:
            if isinstance(dev, db.newHardDiskDevice) and dev.partitiontable is not None:
                self.__delete_partitiontable(db, dev.partitiontable)
            db.delete(dev)

    def __update_partitiontable(self, db, hd, dev):
        """Synchronise ``hd.partitiontable`` with the partition table reported in ``dev``.

        If no partition table is reported, the stored one is deleted.
        """
        # NOTE(review): the SOAP object presumably only carries the
        # ``_partitiontable`` attribute when the client sent the element; an
        # explicit None is treated the same as a missing element.
        reported = dev.Partitiontable if hasattr(dev, '_partitiontable') else None
        if reported is None:
            if hd.partitiontable is not None:
                self.__delete_partitiontable(db, hd.partitiontable)
                # Drop the reference so the deleted row is not reused later.
                hd.partitiontable = None

        elif isinstance(reported.typecode, ns0.MBRPartitiontable_Def):
            self.__update_mbr_partitiontable(db, hd, reported)

        else:
            raise Exception('Unknown partition table type')

    def __update_mbr_partitiontable(self, db, hd, reported):
        """Copy a reported MBR partition table (base64-encoded fields) onto ``hd``.

        Partitions are matched to stored rows by their number.  Stored
        partitions that are no longer reported are deleted.
        """
        if hd.partitiontable is None:
            hd.partitiontable = db.newMBRPartitiontable()
        table = hd.partitiontable

        table.disk_signature = base64.b64decode(reported.get_attribute_disk_signature())
        table.bootloader = base64.b64decode(reported.get_attribute_bootloader())
        table.partitions = base64.b64decode(reported.get_attribute_partitions())
        # The extended and unused areas are optional in the report.
        extended = reported.get_attribute_extended()
        table.extended = base64.b64decode(extended) if extended is not None else None
        unused = reported.get_attribute_unused()
        table.unused = base64.b64decode(unused) if unused is not None else None

        old_parts = list(table.parts)
        del table.parts[:]
        for part in reported.Parts:
            if not isinstance(part.typecode, ns0.MBRPartition_Def):
                raise Exception('Unknown partition type')

            no = part.get_attribute_no()
            p = self.__pop_matching(
                old_parts,
                lambda d: isinstance(d, db.newMBRPartition) and d.no == no)
            if p is None:
                p = db.newMBRPartition()
            self.__copy_attributes(p, part, self.__MBR_PARTITION_FIELDS)
            table.parts.append(p)

        for part in old_parts:
            db.delete(part)

    # ------------------------------------------------------------------
    # Operation construction
    # ------------------------------------------------------------------

    def __create_operation(self, db, operation, context_id):
        """Build the SOAP operation object sent to a client for ``operation``.

        Args:
            db: open ``Database`` session (used for the operation row classes).
            operation: database operation row (image, partition or shutdown).
            context_id: id sent back by the client in its status updates.

        Returns:
            The SOAP operation object, or None for an unhandled operation type.

        Raises:
            Exception: on invalid transfer/image mode, missing data path or
            image file, or an unknown partition table type.
        """
        if isinstance(operation, db.newImageOperation):
            return self.__create_image_operation(operation, context_id)

        if isinstance(operation, db.newPartitionOperation):
            return self.__create_partition_operation(db, operation, context_id)

        if isinstance(operation, db.newShutdownOperation):
            op = nextOperation_shutdownoperation()
            op.set_attribute_id(operation.id)
            op.set_attribute_context_id(context_id)
            op.set_attribute_mode(operation.mode)
            return op

        return None

    def __create_image_operation(self, operation, context_id):
        """Build an image operation and start the matching UDPcast endpoint.

        In backup mode a UDPcast receiver writes the new image file.  In
        restore mode a UDPcast sender serves the existing image file to all
        running task hosts that have not yet passed this operation.
        """
        op = nextOperation_imageoperation()
        op.set_attribute_id(operation.id)
        op.set_attribute_context_id(context_id)
        op.set_attribute_mode(operation.mode)
        op.set_attribute_address(operation.address)

        # TODO: Put the server (unicast) or multicast IP address into the URL
        # instead of the transfer mode name.
        url = 'udpcast://'
        if operation.transfer_mode == 'unicast':
            url += operation.transfer_mode

        elif operation.transfer_mode == 'multicast':
            if operation.mode == 'backup':
                raise Exception('Multicast transfer mode not supported by backup mode')
            url += operation.transfer_mode

        else:
            raise Exception('Unknown transfer mode "%s"' % operation.transfer_mode)

        data_path = config.get('storage', 'data_path')
        if not os.path.exists(data_path):
            raise Exception('Data path "%s" does not exist' % data_path)

        imgdir, imgfile = self.__image_file_paths(data_path, operation.image_partition)

        # UDPcast runs on the local address that routes to the requesting client.
        interface = self.__server_ip_for_client()

        if operation.mode == 'backup':
            if os.path.exists(imgfile):
                raise Exception('Image file "%s" already exists' % imgfile)

            if not os.path.exists(imgdir):
                # umask 2 (octal 002) keeps the new directory group-writable;
                # always restore the previous process umask.
                old_mask = os.umask(2)
                try:
                    os.mkdir(imgdir)
                finally:
                    os.umask(old_mask)

            uc = udpcast.manager.getReceiver(operation.id, imgfile, interface)

        elif operation.mode == 'restore':
            if not os.path.exists(imgfile):
                raise Exception('Image file "%s" not found' % imgfile)

            # TODO: Check, which clients are on the same interface/subnet
            operations = operation.task.operations
            cur_index = operations.index(operation)
            client_count = 0
            for taskhost in operation.task.taskhosts:
                if taskhost.status == 'running' and (taskhost.cur_operation is None or operations.index(taskhost.cur_operation) <= cur_index):
                    client_count += 1

            logger.info('Starting UDPcast in send mode on interface %s for %d clients', interface, client_count)
            logger.debug('udpcast client_count: %d  taskhosts: %d', client_count, len(operation.task.taskhosts))

            uc = udpcast.manager.getSender(operation.id, imgfile, interface, client_count)

        else:
            raise Exception('Unknown image method "%s"' % operation.mode)

        url += ':%d/' % uc.port
        op.set_attribute_url(url)
        op.set_attribute_program(operation.image_partition.format)
        op.set_attribute_compression(operation.image_partition.compression)
        return op

    def __create_partition_operation(self, db, operation, context_id):
        """Build a partition operation carrying the stored MBR partition table."""
        op = nextOperation_partitionoperation()
        op.set_attribute_id(operation.id)
        op.set_attribute_context_id(context_id)
        op.set_attribute_program(operation.program)
        op.set_attribute_address(operation.address)
        op.set_attribute_restore_bootloader(operation.restore_bootloader)
        op.set_attribute_restore_unused(operation.restore_unused)

        table = operation.partitiontable
        if not isinstance(table, db.newMBRPartitiontable):
            raise Exception('Unknown partitiontable type "%s"' % type(table))

        op.Partitiontable = partitionoperation_mbr_partitiontable()
        op.Partitiontable.set_attribute_disk_signature(base64.b64encode(table.disk_signature))
        op.Partitiontable.set_attribute_bootloader(base64.b64encode(table.bootloader))
        op.Partitiontable.set_attribute_partitions(base64.b64encode(table.partitions))
        if table.extended is not None:
            op.Partitiontable.set_attribute_extended(base64.b64encode(table.extended))
        if table.unused is not None:
            op.Partitiontable.set_attribute_unused(base64.b64encode(table.unused))

        # TODO: Set partitions
        op.Partitiontable.Parts = []
        return op

    # ------------------------------------------------------------------
    # SOAP methods
    # ------------------------------------------------------------------

    def authorize(self, auth_info, post, action):
        """Authorise a SOAP request.  Every request is currently allowed."""
        # TODO: Implement access control.
        return 1

    def register(self, host):
        """Register a host that is not yet known.

        Fails (status 1) if a host with the same MAC address already exists.
        If the client sends no IP, the peer address of the SOAP connection is
        stored.  The new host is marked running and its devices are recorded.
        """
        logger.debug('register(%s)', host)
        status = register_status()
        status.set_attribute_value(0)
        mac = host.Hostid.get_attribute_mac()
        db = Database()
        try:
            logger.info('Register request from host %s', mac)

            if db.queryHostByMac(mac) is not None:
                logger.error('Host %s tried to re-register', mac)
                status.set_attribute_value(1)

            else:
                ip = host.get_attribute_ip()
                if ip is None:
                    ip = GetSOAPContext().connection.getpeername()[0]

                db_host = db.newHost(mac=mac, ip=ip, hostname=host.get_attribute_hostname(), serial_no=host.get_attribute_serial_no(), running=True, last_boot_time=db.now())
                self.__update_devices(db, db_host, host)
                db.save(db_host)
                db.commit()
                logger.info('Host %s successfully registered', mac)

        finally:
            db.close()

        return status

    def logon(self, host):
        """Mark a registered host as running and refresh its inventory.

        Fails (status 1) if the host is unknown.  The IP, hostname and serial
        number are only filled in if they are not stored yet.
        """
        logger.debug('logon(%s)', host)
        status = logon_status()
        status.set_attribute_value(0)
        mac = host.Hostid.get_attribute_mac()
        db = Database()
        try:
            logger.info('Logon request from host %s', mac)

            db_host = db.queryHostByMac(mac)
            if db_host is None:
                logger.error('Logon failed (Host %s doesn\'t exist)', mac)
                status.set_attribute_value(1)

            else:
                # A repeated logon is accepted; it is only reported.
                if db_host.running is True:
                    logger.warning('Host %s already logged on', mac)

                db_host.running = True
                db_host.last_boot_time = db.now()

                if db_host.ip is None:
                    ip = host.get_attribute_ip()
                    if ip is None:
                        ip = GetSOAPContext().connection.getpeername()[0]
                    db_host.ip = ip

                if db_host.hostname is None and host.get_attribute_hostname() is not None:
                    db_host.hostname = host.get_attribute_hostname()

                if db_host.serial_no is None and host.get_attribute_serial_no() is not None:
                    db_host.serial_no = host.get_attribute_serial_no()

                self.__update_devices(db, db_host, host)
                db.commit()
                logger.info('Host %s successfully logged on', mac)

        finally:
            db.close()

        return status

    def logoff(self, hostid):
        """Mark a running host as no longer running.

        Fails (status 1) if the host is unknown or not running.
        """
        logger.debug('logoff(%s)', hostid)
        status = logoff_status()
        status.set_attribute_value(0)
        mac = hostid.get_attribute_mac()
        db = Database()
        try:
            logger.info('Logoff request from host %s', mac)

            host = db.queryHostByMac(mac)
            if host is None:
                logger.error('Logoff failed (Host %s doesn\'t exist)', mac)
                status.set_attribute_value(1)

            elif host.running is not True:
                logger.error('Logoff failed (Host %s isn\'t running)', mac)
                status.set_attribute_value(1)

            else:
                host.running = False
                db.commit()
                logger.info('Host %s successfully logged off', mac)

        finally:
            db.close()

        return status

    def nextOperation(self, hostid):
        """Hand the calling host its next operation.

        Returns:
            ``(status, operation)``.  On failure (unknown or not running host,
            status 1) ``operation`` is None.  If nothing is pending, an idle
            operation with sleeptime 15000 is returned.
        """
        logger.debug('nextOperation(%s)', hostid)
        status = nextOperation_status()
        status.set_attribute_value(0)
        operation = None
        mac = hostid.get_attribute_mac()
        db = Database()
        try:
            logger.info('Next operation request from host %s', mac)

            host = db.queryHostByMac(mac)
            if host is None:
                logger.error('Next operation failed (Host %s doesn\'t exist)', mac)
                status.set_attribute_value(1)

            elif host.running is not True:
                logger.error('Next operation failed (Host %s isn\'t running)', mac)
                status.set_attribute_value(1)

            else:
                operation = self.__schedule_next_operation(db, host, mac)

            if operation is None:
                operation_name = 'None'
            else:
                operation_name = operation.typecode.__class__.__name__
            logger.info('Next operation on host %s is %s', mac, operation_name)

        finally:
            db.close()

        return status, operation

    def __schedule_next_operation(self, db, host, mac):
        """Advance the first running host task with pending work and build its operation.

        A task without a current operation starts with its first operation.
        Otherwise the task's ``next_operation`` is started; it is a restart
        when it equals the current operation.  A task whose
        ``next_operation`` is None has no more work and is skipped.  Returns
        an idle operation if no task yields an operation.
        """
        for hosttask in host.hosttasks:
            logger.debug('Current operation: %s (%s) Next operation: %s (%s)', hosttask.cur_operation_id, hosttask.cur_operation, hosttask.next_operation_id, hosttask.next_operation)

            # Skip failed and unknown status hosttasks
            if hosttask.status != 'running':
                continue

            if hosttask.cur_operation is None and len(hosttask.task.operations) > 0:
                # Start first operation
                hosttask.cur_operation = hosttask.next_operation = hosttask.task.operations[0]

            elif hosttask.next_operation is not None:
                if hosttask.cur_operation == hosttask.next_operation:
                    logger.warning('Restarting operation %d from task %s (%d) on host %s', hosttask.cur_operation_id, hosttask.task.description, hosttask.task_id, mac)
                # Start next operation
                hosttask.cur_operation = hosttask.next_operation

            else:
                # Last operation reached, continue with next task
                continue

            operation = self.__create_operation(db, hosttask.cur_operation, hosttask.task_id)
            if operation is not None:
                hosttask.percentage = 0.0
                hosttask.speed = None
                db.commit()
                return operation

        idle = nextOperation_idleoperation()
        idle.set_attribute_sleeptime(15000)
        return idle

    def statusUpdate(self, statusupdate):
        """Record progress or the result of the host's current operation.

        Returns:
            ``(status, nextupdate)`` where ``nextupdate`` is always 10000.
            Status is 1 if the host is unknown or not running, or if the host
            task for the sent context id does not exist.
        """
        logger.debug('statusUpdate(%s)', statusupdate)
        status = statusUpdate_status()
        status.set_attribute_value(0)
        nextupdate = 10000
        mac = statusupdate.Hostid.get_attribute_mac()
        db = Database()
        try:
            logger.info('Status update request from host %s', mac)

            host = db.queryHostByMac(mac)
            if host is None:
                logger.error('Status update failed (Host %s doesn\'t exist)', mac)
                status.set_attribute_value(1)

            elif host.running is not True:
                logger.error('Status update failed (Host %s isn\'t running)', mac)
                status.set_attribute_value(1)

            else:
                context_id = statusupdate.get_attribute_context_id()
                hosttask = db.queryHostTask(host.id, context_id)
                if hosttask is None:
                    logger.error('Status update failed (Task %s doesn\'t exist on host %s)', context_id, mac)
                    status.set_attribute_value(1)

                else:
                    self.__apply_status_update(db, hosttask, statusupdate, mac)
                    db.commit()
                    logger.info('Host %s successfully updated the status', mac)

        finally:
            db.close()

        return status, nextupdate

    def __apply_status_update(self, db, hosttask, statusupdate, mac):
        """Apply a status update to ``hosttask`` (the caller commits).

        Percentage and speed are copied when present.  A result value of 0
        records the backup image size (for backup image operations) and
        advances ``next_operation``.  After the last operation the task is
        marked finished.  Any other result value marks the task failed.
        """
        percentage = statusupdate.get_attribute_percentage()
        if percentage is not None:
            hosttask.percentage = percentage

        speed = statusupdate.get_attribute_speed()
        if speed is not None:
            hosttask.speed = speed

        result = statusupdate.Status
        if result is None or result.get_attribute_value() is None:
            return

        if result.get_attribute_value() != 0:
            logger.error('Task %s (%d) failed on host %s', hosttask.task.description, hosttask.task_id, mac)
            hosttask.status = 'failed'
            return

        current = hosttask.cur_operation

        # Set image size if operation was imaging in backup mode
        if isinstance(current, db.newImageOperation) and current.mode == 'backup':
            imgfile = self.__image_file_paths(config.get('storage', 'data_path'), current.image_partition)[1]
            if os.path.exists(imgfile):
                current.image_partition.image_size = os.path.getsize(imgfile)

        # Find the operation following the current one (matched by identity).
        operations = list(hosttask.task.operations)
        for index, candidate in enumerate(operations):
            if candidate is current:
                if index + 1 < len(operations):
                    hosttask.next_operation = operations[index + 1]
                else:
                    # Last operation reached
                    logger.info('Host %s finished task %s (%d)', mac, hosttask.task.description, hosttask.task_id)
                    hosttask.status = 'finished'
                    hosttask.cur_operation_id = None
                    hosttask.next_operation_id = None
                    hosttask.percentage = None
                    # TODO: Set average speed
                break
