#!/usr/bin/env python
#
# Transfer VM - VPX for exposing VDIs on XenServer
# Copyright (C) Citrix Systems, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
#
# XenAPI plugin for copying VDIs and VMs between pools.
#

import os
import os.path
import socket
import sys
from xml.dom import minidom

import XenAPI
import XenAPIPlugin

# The wildcard import supplies the helpers this module uses but does not
# define itself (e.g. PluginError, ArgumentError, validate_exists,
# get_this_host, unwrap_plugin_exceptions, ignore_failure) -- presumably all
# of them live in pluginlib; confirm there.
from pluginlib import *
# Logging is configured under the name 'copy' before ``log`` is imported.
# NOTE(review): the ordering looks deliberate (``log`` may only exist once
# configure_logging has run) -- confirm against pluginlib before reordering.
configure_logging('copy')
from pluginlib import log


# Size requested for the VDI that receives a remote VM's metadata (2 MB).
# There's a fixed 2 MB buffer inside puller, so raising this would involve
# raising that too.
METADATA_VDI_SIZE = 2 * (1 << 20)


class NoSuchHost(PluginError):
    """
    The given remote host was invalid.

    Raised by login() when connecting fails with socket.gaierror.
    """

    def __init__(self, *details):
        # Hand every positional argument straight to the base class.
        PluginError.__init__(self, *details)


class AttachFailed(PluginError):
    """
    We managed to expose the VDI, but failed to attach a new SR to it.

    Raised by get_unique_vdi() when the transfer SR does not contain exactly
    one VDI.
    """

    def __init__(self, *details):
        # Hand every positional argument straight to the base class.
        PluginError.__init__(self, *details)


class UnknownVDI(PluginError):
    """
    The bitmap we got described an unknown VDI.

    Not raised anywhere in this file.
    """

    def __init__(self, *details):
        # Hand every positional argument straight to the base class.
        PluginError.__init__(self, *details)


class ExposeDisappeared(PluginError):
    """
    The transfer record we were asking about is no longer there.

    Raised by get_expose_record() when the 'get_record' plugin call fails
    with an ArgumentError.
    """

    def __init__(self, *details):
        # Hand every positional argument straight to the base class.
        PluginError.__init__(self, *details)


class PullerFailed(PluginError):
    """
    A pull reported a status that is neither 'OK' nor parseable.

    Raised by all_acked() with the raw status string when
    parse_xmlrpc_value() cannot decode it.
    """

    def __init__(self, *details):
        # Hand every positional argument straight to the base class.
        PluginError.__init__(self, *details)


class InvalidInstructions(PluginError):
    """
    The import instructions could not be carried out as given.

    Raised by execute_instructions() for an unrecognised instruction and by
    find_expose_record() when no expose record covers a VDI.
    """

    def __init__(self, *details):
        # Hand every positional argument straight to the base class.
        PluginError.__init__(self, *details)


@log_exceptions
def get_vm_forest(session, args):
    """
    Get an entire VM snapshot forest into a local SR, using HTTP VHD download
    for the disk contents, with Transfer VMs on both sides.

    Plugin entry point.  Reads 'remote_vm_uuids' (comma-separated) and
    'local_sr_uuid' from args, along with the connection settings consumed
    by get_remote_connection.  A copy of args without those two keys is
    passed down as extra_args.  The remote session is logged out whether or
    not the transfer succeeds.
    """
    (protocol, remote_host, remote_port,
     remote_username, remote_password) = get_remote_connection(args)

    remote_vm_uuids = validate_exists(args, 'remote_vm_uuids').split(',')
    local_sr_uuid = validate_exists(args, 'local_sr_uuid')

    # Everything except the two keys handled here travels on as extra_args.
    extra_args = dict(args)
    for consumed in ('remote_vm_uuids', 'local_sr_uuid'):
        del extra_args[consumed]

    remote_session = login(protocol, remote_host, remote_port,
                           remote_username, remote_password)
    try:
        new_vms = get_vm_forest_(protocol, remote_host, remote_port,
                                 session, remote_session, remote_vm_uuids,
                                 local_sr_uuid, extra_args)
    finally:
        logout(remote_session)
    return new_vms


def get_vm_forest_(protocol, remote_host, remote_port, local_session,
                   remote_session, remote_vm_uuids, local_sr_uuid,
                   extra_args):
    """
    Pull the metadata for the (single) remote VM into a new local VDI, then
    run the forest import against it.

    If the import fails the metadata VDI is destroyed and the exception is
    re-raised.  Returns whatever get_vm_forest__ returns.
    """
    assert len(remote_vm_uuids) == 1 # Multiple VMs unimplemented
    the_vm_uuid = remote_vm_uuids[0]

    metadata = get_metadata_(protocol, remote_host, remote_port,
                             local_session, remote_session,
                             the_vm_uuid, local_sr_uuid)
    metadata_vdi_ref, metadata_vdi_uuid = metadata

    try:
        new_vms = get_vm_forest__(local_session, remote_session,
                                  remote_vm_uuids, local_sr_uuid, extra_args,
                                  metadata_vdi_uuid)
    except:
        # Don't leave the metadata VDI behind when the import fails.
        destroy_vdi(local_session, metadata_vdi_ref)
        raise
    return new_vms


def get_vm_forest__(local_session, remote_session, remote_vm_uuids,
                    local_sr_uuid, extra_args, metadata_vdi_uuid):
    """
    Carry out the forest import once the metadata VDI is in place.

    Fetches the import instructions for the metadata VDI, exposes the remote
    VMs and all of their snapshots, executes the instructions and finally
    asks the transfer plugin to remap the VM onto the new disks.

    Returns the comma-joined result of remap_vm.  On failure the VDIs
    recorded so far are destroyed; in every case each expose record is
    unexposed again.
    """
    instructions = get_import_instructions(local_session, metadata_vdi_uuid)

    # record handle -> expose record; filled in by expose_forest.
    expose_records = {}
    snapshot_uuids = all_snapshots(remote_session, remote_vm_uuids)
    expose_forest(remote_session, extra_args,
                  remote_vm_uuids + snapshot_uuids, expose_records)

    # remote VDI UUID -> local VDI UUID; filled in by execute_instructions.
    vdi_map = {}

    try:
        try:
            execute_instructions(local_session, local_sr_uuid, instructions,
                                 expose_records, vdi_map)

            log.debug('Remapping %s...', metadata_vdi_uuid)
            locations = convert_map_to_locations(local_session,
                                                 remote_session, vdi_map)
            new_vm = remap_vm(local_session, local_sr_uuid,
                              metadata_vdi_uuid, locations)
            log.debug('Remapping %s done', metadata_vdi_uuid)
            return ','.join([new_vm])
        except:
            destroy_all_vdis(local_session, vdi_map)
            raise
    finally:
        for record_handle in expose_records:
            unexpose(remote_session, record_handle)


def get_import_instructions(session, metadata_vdi_uuid):
    """
    Ask the 'transfer' plugin for the import instructions belonging to the
    given metadata VDI and return them as a list (see parse_instructions).
    """
    host_ref = get_this_host(session)
    plugin_args = {'vm_metadata_vdi_uuid': metadata_vdi_uuid}
    raw_instructions = unwrap_plugin_exceptions(
        session.xenapi.host.call_plugin,
        host_ref, 'transfer', 'get_import_instructions', plugin_args)
    return parse_instructions(raw_instructions)


def parse_instructions(s):
    """
    Split the instruction text into a list holding one entry per line.

    Nothing is stripped or filtered: an empty string yields [''].
    """
    separator = '\n'
    instructions = s.split(separator)
    return instructions


def expose_forest(session, extra_args, vm_uuids, expose_records):
    """
    Expose the given VMs through the 'transfer' plugin's 'expose_forest'
    call and store the resulting records in expose_records.

    session        -- session on the side that owns the VMs
    extra_args     -- additional plugin arguments; copied, never modified
    vm_uuids       -- list of VM UUIDs, sent comma-joined as 'vm_uuids'
    expose_records -- dict updated in place: record handle -> record dict

    The update is all-or-nothing.  If any record cannot be fetched, every
    handle returned by the plugin is unexposed again, expose_records is left
    untouched and the exception is re-raised.  Without this the caller could
    not clean up, because it only knows the handles that made it into
    expose_records.
    """
    host_ref = get_this_host(session)

    log.debug("Exposing forest for %s", str(vm_uuids))
    expose_args = dict(extra_args)
    expose_args['vm_uuids'] = ','.join(vm_uuids)
    expose_args['network_uuid'] = 'management'
    expose_args['read_only'] = 'true'

    record_handles_str = \
        unwrap_plugin_exceptions(
            session.xenapi.host.call_plugin,
            host_ref, 'transfer', 'expose_forest', expose_args)
    record_handles = record_handles_str.split(',')

    fetched = {}
    try:
        for record_handle in record_handles:
            fetched[record_handle] = \
                get_expose_record(session, host_ref, record_handle)
    except:
        # unexpose() has a nothrow guarantee, so the original error survives.
        for record_handle in record_handles:
            unexpose(session, record_handle)
        raise

    expose_records.update(fetched)


def all_snapshots(session, vm_uuids):
    """
    Return the UUIDs of every snapshot of the given VMs, as one flat list in
    the order the VMs were given.
    """
    snapshot_uuids = []
    for vm_uuid in vm_uuids:
        vm_ref = session.xenapi.VM.get_by_uuid(vm_uuid)
        for snapshot_ref in session.xenapi.VM.get_snapshots(vm_ref):
            snapshot_uuids.append(session.xenapi.VM.get_uuid(snapshot_ref))
    return snapshot_uuids


def convert_map_to_locations(local_session, remote_session, vdi_map):
    """
    Convert a UUID -> UUID VDI map into the form that remap_vm wants.

    Keys of vdi_map are looked up on the remote session and values on the
    local session.  The result maps each remote VDI's location to a dict
    holding the local VDI's 'ref' and 'location'.
    """
    locations = {}
    for remote_uuid, local_uuid in vdi_map.iteritems():
        remote_ref = remote_session.xenapi.VDI.get_by_uuid(remote_uuid)
        remote_rec = remote_session.xenapi.VDI.get_record(remote_ref)
        local_ref = local_session.xenapi.VDI.get_by_uuid(local_uuid)
        local_rec = local_session.xenapi.VDI.get_record(local_ref)
        entry = {'ref': local_ref, 'location': local_rec['location']}
        locations[remote_rec['location']] = entry
    return locations


def execute_instructions(local_session, local_sr_uuid, instructions,
                         expose_records, vdi_map):
    """
    Run each import instruction in order, updating vdi_map as we go.

    Recognised instructions start with 'create ', 'clone ', 'reuse ',
    'snap ' or 'leaf ', or are exactly 'pass' (which does nothing).
    Anything else raises InvalidInstructions.
    """
    # Instruction prefix -> handler taking the whole instruction string.
    handlers = (
        ('create ',
         lambda i: execute_create(local_session, local_sr_uuid,
                                  expose_records, i, vdi_map)),
        ('clone ',
         lambda i: execute_clone(local_session, expose_records, i, vdi_map)),
        ('reuse ',
         lambda i: execute_reuse(local_session, expose_records, i, vdi_map)),
        ('snap ',
         lambda i: execute_snap(local_session, i, vdi_map)),
        ('leaf ',
         lambda i: execute_leaf(local_session, i, vdi_map)),
    )

    for instruction in instructions:
        log.debug('Instruction is %s', instruction)
        if instruction == 'pass':
            continue
        for prefix, handler in handlers:
            if instruction.startswith(prefix):
                handler(instruction)
                break
        else:
            raise InvalidInstructions(
                "Invalid instruction '%s'" % instruction)


def execute_create(local_session, local_sr_uuid, expose_records, instruction,
                   vdi_map):
    """
    Handle 'create <vdi_uuid> <virtual_size>'.

    Creates a new VDI of the given size in the local SR, pulls the remote
    VDI's VHD into it and records remote UUID -> local UUID in vdi_map.

    The new VDI only enters vdi_map once the pull has succeeded, so the
    caller's cleanup cannot see it before that.  If anything fails in
    between, the VDI is destroyed here before the exception is re-raised.
    """
    _, vdi_uuid, virtual_size = instruction.split(' ')
    local_sr_ref = local_session.xenapi.SR.get_by_uuid(local_sr_uuid)
    log.debug('Creating dup of %s in %s...', vdi_uuid, local_sr_ref)
    dest_ref = create_vdi(local_session, local_sr_ref,
                          'Copy of %s' % vdi_uuid,
                          long(virtual_size), False)
    try:
        dest_uuid = local_session.xenapi.VDI.get_uuid(dest_ref)
        get_vdi_vhd_finding_record(local_session, expose_records, vdi_uuid,
                                   dest_uuid)
    except:
        # destroy_vdi() has a nothrow guarantee, so the original error
        # is the one that propagates.
        destroy_vdi(local_session, dest_ref)
        raise
    log.debug('Creating dup of %s in %s done.', vdi_uuid, local_sr_ref)
    vdi_map[vdi_uuid] = dest_uuid


def execute_clone(local_session, expose_records, instruction, vdi_map):
    """
    Handle 'clone <child_uuid> <parent_uuid>'.

    Clones the local VDI that vdi_map holds for the parent, pulls the
    child's VHD into the clone and records child UUID -> clone UUID in
    vdi_map.  The parent's entry is left as it is.

    The clone only enters vdi_map once the pull has succeeded.  If anything
    fails before that, the clone is destroyed here before the exception is
    re-raised, since the caller's cleanup does not know about it.
    """
    _, child_uuid, parent_uuid = instruction.split(' ')

    dest_uuid = vdi_map[parent_uuid]
    dest_ref = local_session.xenapi.VDI.get_by_uuid(dest_uuid)

    log.debug('Cloning %s for %s...', parent_uuid, child_uuid)
    new_dest_ref = local_session.xenapi.VDI.clone(dest_ref)
    try:
        new_dest_uuid = local_session.xenapi.VDI.get_uuid(new_dest_ref)
        log.debug('Cloning %s(%s) for %s created %s.', dest_uuid,
                  parent_uuid, child_uuid, new_dest_uuid)
        get_vdi_vhd_finding_record(local_session, expose_records, child_uuid,
                                   new_dest_uuid)
    except:
        # destroy_vdi() has a nothrow guarantee, so the original error
        # is the one that propagates.
        destroy_vdi(local_session, new_dest_ref)
        raise
    log.debug('%s loaded into %s.', child_uuid, new_dest_uuid)
    vdi_map[child_uuid] = new_dest_uuid


def execute_reuse(local_session, expose_records, instruction, vdi_map):
    """
    Handle 'reuse <child_uuid> <parent_uuid>'.

    Pulls the child's VHD into the local VDI currently mapped for the
    parent, then moves that mapping from the parent's UUID to the child's.
    """
    fields = instruction.split(' ')
    _, child_uuid, parent_uuid = fields

    reused_uuid = vdi_map[parent_uuid]
    log.debug('%s now will use %s(%s).', child_uuid, reused_uuid,
              parent_uuid)
    get_vdi_vhd_finding_record(local_session, expose_records, child_uuid,
                               reused_uuid)
    log.debug('%s reused %s(%s).', child_uuid, reused_uuid, parent_uuid)

    # The parent no longer owns the local VDI; the child takes it over.
    vdi_map.pop(parent_uuid)
    vdi_map[child_uuid] = reused_uuid


def execute_snap(session, instruction, vdi_map):
    """
    Handle 'snap <vdi_uuid>'.

    Replaces the local VDI mapped for vdi_uuid with a snapshot of itself
    (see snapshot_leaf), scans the SR, labels the snapshot 'Copy of
    <vdi_uuid>' and points vdi_map at it.

    vdi_map is updated straight after snapshot_leaf returns, because by then
    the old VDI has already been destroyed.  Were the update left to the
    end, a failure in SR.scan or set_name_label would leave the map pointing
    at a VDI that no longer exists and hide the snapshot from cleanup.
    """
    _, vdi_uuid = instruction.split(' ')
    dest_uuid = vdi_map[vdi_uuid]
    dest_ref = session.xenapi.VDI.get_by_uuid(dest_uuid)
    sr_ref = session.xenapi.VDI.get_SR(dest_ref)
    new_dest_ref, new_dest_uuid = \
        snapshot_leaf(session, dest_ref, dest_uuid)
    vdi_map[vdi_uuid] = new_dest_uuid

    log.debug('Scanning %s to trigger GC...', sr_ref)
    session.xenapi.SR.scan(sr_ref)
    log.debug('Scanning %s done.', sr_ref)

    session.xenapi.VDI.set_name_label(new_dest_ref, 'Copy of %s' % vdi_uuid)


def snapshot_leaf(session, dest_ref, dest_uuid):
    """
    Snapshot the given VDI and destroy the original.

    dest_uuid is only used for logging.  Returns the snapshot as a
    (ref, uuid) pair.
    """
    log.debug('Snapshotting %s to make leaf', dest_uuid)
    vdi_api = session.xenapi.VDI
    snapshot_ref = vdi_api.snapshot(dest_ref)
    vdi_api.destroy(dest_ref)
    snapshot_uuid = vdi_api.get_uuid(snapshot_ref)
    log.debug('Snapshotting %s created %s.', dest_uuid, snapshot_uuid)
    return snapshot_ref, snapshot_uuid


def execute_leaf(session, instruction, vdi_map):
    """
    Handle 'leaf <vdi_uuid>'.

    Leaves the mapped local VDI as it is and only labels it 'Copy of
    <vdi_uuid>'.  The vdi_map entry is rewritten with the same value.
    """
    fields = instruction.split(' ')
    _, vdi_uuid = fields
    leaf_uuid = vdi_map[vdi_uuid]
    log.debug('Leaving leaf %s writable.', leaf_uuid)
    leaf_ref = session.xenapi.VDI.get_by_uuid(leaf_uuid)

    label = 'Copy of %s' % vdi_uuid
    session.xenapi.VDI.set_name_label(leaf_ref, label)
    vdi_map[vdi_uuid] = leaf_uuid


def get_vdi_vhd_finding_record(local_session, expose_records, src_uuid,
                               dest_uuid):
    """
    Look up the expose record that covers src_uuid and pull that VDI's VHD
    into the local VDI dest_uuid.
    """
    record = find_expose_record(expose_records, src_uuid)
    get_vdi_vhd(local_session, record, src_uuid, dest_uuid)


def find_expose_record(expose_records, vdi_uuid):
    """
    Return the first expose record that has a 'url_path_<vdi_uuid>' key.

    Raises InvalidInstructions if none of the records has one.
    """
    wanted_key = 'url_path_%s' % vdi_uuid
    for record in expose_records.itervalues():
        if wanted_key not in record:
            continue
        return record
    raise InvalidInstructions("Can't find expose record for %s" % vdi_uuid)


def get_vdi_vhd(local_session, record, vdi_uuid, dest_uuid):
    """
    Pull the remote VDI vdi_uuid, as a VHD, into the local VDI dest_uuid.

    The source URL is the record's 'url_full_<vdi_uuid>' value with '.vhd'
    appended.  The record's 'ssl_cert' entry, if any, is passed along.
    """
    log.debug('Pulling %s as VHD into %s...', vdi_uuid, dest_uuid)
    base_url = record['url_full_%s' % vdi_uuid]
    vhd_url = '%s.vhd' % (base_url)
    certificate = record.get('ssl_cert')
    complete_pull(local_session, dest_uuid, vhd_url, certificate)
    log.debug('Pulling %s as VHD into %s done.', vdi_uuid, dest_uuid)


def complete_pull(local_session, dest_uuid, src_url, ssl_cert):
    """
    Start a pull of src_url into the local VDI dest_uuid and block until it
    has been acknowledged.  The puller is unexposed afterwards whether or
    not the wait succeeded.
    """
    handle = expose_puller(local_session, dest_uuid, src_url, ssl_cert)
    try:
        wait_for_ack(local_session, handle)
    finally:
        unexpose(local_session, handle)


def expose_puller(session, vdi_uuid, src_url, src_cert):
    """
    Call the 'transfer' plugin's 'expose' function in 'http_pull' mode so
    that src_url is pulled into the VDI vdi_uuid.

    'src_certs' is only sent when src_cert is non-empty.  Returns the value
    the plugin call produces (used by callers as the record handle).
    """
    host_ref = get_this_host(session)

    expose_args = {
        'transfer_mode': 'http_pull',
        'vdi_uuid': vdi_uuid,
        'network_uuid': 'management',
        'src_urls': src_url,
    }
    if src_cert:
        expose_args['src_certs'] = src_cert

    log.debug("Exposing pull from %s for VDI %s...", src_url, vdi_uuid)
    record_handle = unwrap_plugin_exceptions(
        session.xenapi.host.call_plugin,
        host_ref, 'transfer', 'expose', expose_args)
    log.debug("Exposing pull from %s for VDI %s done.", src_url, vdi_uuid)
    return record_handle


def wait_for_ack(session, record_handle):
    """
    Block until every device of the given transfer record has reported 'OK'.

    The record is re-read every 5 seconds.  There is no timeout: the loop
    ends only when all_acked() returns true or when all_acked() or
    get_expose_record() raises.
    """
    # Imported here because this file never imports ``time`` itself; the
    # original code only worked if the pluginlib wildcard import leaked it.
    import time

    log.debug("Waiting for %s to acknowledge...", record_handle)

    host_ref = get_this_host(session)

    while True:
        record = get_expose_record(session, host_ref, record_handle)
        if all_acked(record):
            return
        time.sleep(5)


def all_acked(record):
    """
    Tell whether every device listed in record['all_devices'] has a
    'status_<device>' entry equal to 'OK'.

    Devices are checked in order.  The first one without a status makes
    this return False.  A status other than 'OK' is raised as
    XenAPI.Failure if parse_xmlrpc_value() can decode it, and as
    PullerFailed carrying the raw status otherwise.
    """
    for device in record['all_devices'].split(','):
        status_key = 'status_%s' % device
        if status_key not in record:
            # Nothing reported yet for this device.
            return False
        status = record[status_key]
        if status == 'OK':
            continue
        try:
            failure = parse_xmlrpc_value(status)
        except:
            raise PullerFailed(status)
        raise XenAPI.Failure(failure)
    return True


@log_exceptions
def get_vm(session, args):
    """
    Copy a remote (source) VM over to a local (destination) SR, using Transfer
    VMs on the source, temporary iSCSI SRs locally, and VDI.copy (i.e. a full
    copy, pumped by xapi).

    Plugin entry point.  Reads 'remote_vm_uuid' and 'local_sr_uuid' from
    args, along with the connection settings consumed by
    get_remote_connection.  A copy of args without those two keys is passed
    down as extra_args.  The remote session is logged out whether or not
    the copy succeeds.
    """
    (protocol, remote_host, remote_port,
     remote_username, remote_password) = get_remote_connection(args)

    remote_vm_uuid = validate_exists(args, 'remote_vm_uuid')
    local_sr_uuid = validate_exists(args, 'local_sr_uuid')

    # Everything except the two keys handled here travels on as extra_args.
    extra_args = dict(args)
    for consumed in ('remote_vm_uuid', 'local_sr_uuid'):
        del extra_args[consumed]

    remote_session = login(protocol, remote_host, remote_port,
                           remote_username, remote_password)
    try:
        new_vm = get_vm_(protocol, remote_host, remote_port,
                         session, remote_session,
                         remote_vm_uuid, local_sr_uuid, extra_args)
    finally:
        logout(remote_session)
    return new_vm


def get_vm_(protocol, remote_host, remote_port, local_session, remote_session,
            remote_vm_uuid, local_sr_uuid, extra_args):
    """
    Copy all of the remote VM's disks into the local SR, then import its
    metadata and remap it onto the copies.

    If the import/remap step fails, the copied VDIs are destroyed and the
    exception is re-raised.
    """
    vdi_map = get_all_vdis(local_session, remote_session, remote_vm_uuid,
                           local_sr_uuid, extra_args)
    try:
        new_vm = import_vm_and_remap_vdis(protocol, remote_host, remote_port,
                                          local_session, remote_session,
                                          remote_vm_uuid, local_sr_uuid,
                                          vdi_map)
    except:
        destroy_all_vdis(local_session, vdi_map)
        raise
    return new_vm


def import_vm_and_remap_vdis(protocol, remote_host, remote_port,
                             local_session, remote_session, remote_vm_uuid,
                             local_sr_uuid, vdi_map):
    """
    Pull the remote VM's metadata into a new local VDI and have the
    transfer plugin remap the VM using vdi_map.

    If remapping fails, the metadata VDI is destroyed and the exception is
    re-raised.  Returns the result of remap_vm.
    """
    metadata_vdi_ref, metadata_vdi_uuid = \
        get_metadata_(protocol, remote_host, remote_port,
                      local_session, remote_session,
                      remote_vm_uuid, local_sr_uuid)
    try:
        remapped = remap_vm(local_session, local_sr_uuid, metadata_vdi_uuid,
                            vdi_map)
    except:
        destroy_vdi(local_session, metadata_vdi_ref)
        raise
    return remapped


def remap_vm(local_session, local_sr_uuid, metadata_vdi_uuid, vdi_map):
    """
    Call the 'transfer' plugin's 'remap_vm' function with the metadata VDI,
    the string form of vdi_map (see vdi_map_to_string) and the local SR.

    Returns whatever the plugin call returns.
    """
    host_ref = get_this_host(local_session)
    remap_args = {
        'vm_metadata_vdi_uuid': metadata_vdi_uuid,
        'vdi_map': vdi_map_to_string(vdi_map),
        'sr_uuid': local_sr_uuid,
    }
    plugin_result = unwrap_plugin_exceptions(
        local_session.xenapi.host.call_plugin,
        host_ref, 'transfer', 'remap_vm', remap_args)
    return plugin_result


def get_metadata_(protocol, remote_host, remote_port,
                  local_session, remote_session, remote_vm_uuid,
                  local_sr_uuid):
    """
    Create a VDI of METADATA_VDI_SIZE in the local SR and pull the remote
    VM's exported metadata into it.

    Returns the new VDI as a (ref, uuid) pair.  If anything fails after the
    VDI has been created, it is destroyed before the exception is re-raised.
    Callers only start cleaning up once this function has returned, so the
    VDI would otherwise be leaked.
    """
    log.debug('Creating metadata VDI...')

    local_sr_ref = local_session.xenapi.SR.get_by_uuid(local_sr_uuid)
    metadata_vdi = create_vdi(local_session, local_sr_ref,
                              'Remote metadata for %s' % remote_vm_uuid,
                              METADATA_VDI_SIZE, False)
    try:
        metadata_vdi_uuid = local_session.xenapi.VDI.get_uuid(metadata_vdi)

        log.debug('Creating metadata VDI done (%s).', metadata_vdi_uuid)

        pull_metadata(protocol, remote_host, remote_port,
                      local_session, remote_session, remote_vm_uuid,
                      metadata_vdi_uuid)
    except:
        # destroy_vdi() has a nothrow guarantee, so the original error
        # is the one that propagates.
        destroy_vdi(local_session, metadata_vdi)
        raise

    return metadata_vdi, metadata_vdi_uuid


def pull_metadata(protocol, remote_host, remote_port,
                  local_session, remote_session, remote_vm_uuid,
                  dest_vdi_uuid):
    """
    Pull the remote VM's export_metadata output into the local VDI
    dest_vdi_uuid.

    The URL carries the remote session handle as session_id.  For 'https'
    the remote host's server certificate is fetched and handed to the pull;
    for any other protocol an empty certificate string is used.
    """
    vm_ref = remote_session.xenapi.VM.get_by_uuid(remote_vm_uuid)
    url_template = \
        '%s://x:x@%s:%s/export_metadata?session_id=%s&include_vhd_parents=true&ref=%s'
    url = url_template % (protocol, remote_host, remote_port,
                          remote_session.handle, vm_ref)

    log.debug('Pulling metadata from %s...', url)

    ssl_cert = ''
    if protocol == 'https':
        remote_host_ref = get_this_host(remote_session)
        ssl_cert = \
            remote_session.xenapi.host.get_server_certificate(remote_host_ref)

    complete_pull(local_session, dest_vdi_uuid, url, ssl_cert)

    log.debug('Pulling metadata done.')


def vdi_map_to_string(vdi_map):
    """
    Render vdi_map as comma-separated '<key>=<location>' pairs, where
    <location> is the 'location' entry of each value.
    """
    pairs = ['%s=%s' % (old_location, new_vdi['location'])
             for old_location, new_vdi in vdi_map.iteritems()]
    return ','.join(pairs)


def get_all_vdis(local_session, remote_session, remote_vm_uuid,
                 local_sr_uuid, extra_args):
    """
    Copy every non-CD disk of the remote VM, plus its suspend VDI, into the
    local SR.

    Returns the map built up by get_vdi_and_map.  If any copy fails, the
    VDIs copied so far are destroyed and the exception is re-raised.
    """
    vm_ref = remote_session.xenapi.VM.get_by_uuid(remote_vm_uuid)
    vdi_map = {}

    def copy_and_record(vdi_ref):
        get_vdi_and_map(vdi_map, local_session, remote_session,
                        local_sr_uuid, extra_args, vdi_ref)

    try:
        for vbd in remote_session.xenapi.VM.get_VBDs(vm_ref):
            if remote_session.xenapi.VBD.get_type(vbd) != 'CD':
                copy_and_record(
                    ignore_failure(remote_session.xenapi.VBD.get_VDI, vbd))
        copy_and_record(
            ignore_failure(remote_session.xenapi.VM.get_suspend_VDI, vm_ref))
    except:
        destroy_all_vdis(local_session, vdi_map)
        raise
    return vdi_map


def get_vdi_and_map(vdi_map, local_session, remote_session, local_sr_uuid,
                    extra_args, vdi_ref):
    """
    Copy one remote VDI into the local SR and record it in vdi_map.

    Does nothing when vdi_ref is empty or its record cannot be read.
    Otherwise vdi_map gains an entry keyed by the remote VDI's location,
    holding the new local VDI's 'ref' and 'location'.
    """
    vdi_rec = None
    if vdi_ref:
        vdi_rec = ignore_failure(remote_session.xenapi.VDI.get_record,
                                 vdi_ref)
    if not vdi_rec:
        return

    remote_uuid = vdi_rec['uuid']
    new_vdi_ref = get_vdi_(local_session, remote_session,
                           remote_uuid, local_sr_uuid, extra_args)
    new_vdi_location = local_session.xenapi.VDI.get_location(new_vdi_ref)

    entry = {'ref': new_vdi_ref, 'location': new_vdi_location}
    vdi_map[vdi_rec['location']] = entry
    log.debug('VDI %s(%s) copied to %s', vdi_rec['uuid'],
              vdi_rec['location'], new_vdi_location)


def destroy_all_vdis(session, vdi_map):
    """
    Destroy every local VDI recorded in vdi_map.  Nothrow guarantee.

    Two shapes of map reach this function.  get_vdi_and_map stores dicts
    with a 'ref' entry, while the execute_* instruction handlers store plain
    local VDI UUID strings.  Indexing a UUID string with ['ref'] raised
    TypeError, which broke the nothrow guarantee and replaced the error
    being handled, so both shapes are accepted here.

    NOTE(review): this relies on ignore_failure() swallowing lookup errors
    and returning None, as destroy_vdi() already assumes -- confirm in
    pluginlib.
    """
    for vdi_bit in vdi_map.values():
        if isinstance(vdi_bit, dict):
            vdi_ref = vdi_bit.get('ref')
        else:
            vdi_ref = ignore_failure(session.xenapi.VDI.get_by_uuid, vdi_bit)
        if vdi_ref:
            destroy_vdi(session, vdi_ref)


def destroy_vdi(session, vdi_ref):
    """
    Destroy the given VDI, ignoring failures.  Nothrow guarantee.

    Nothing is attempted when the VDI's UUID cannot be looked up.
    """
    vdi_uuid = ignore_failure(session.xenapi.VDI.get_uuid, vdi_ref)
    if vdi_uuid is None:
        return
    log.debug('Destroying VDI %s(%s) ... ', vdi_uuid, vdi_ref)
    ignore_failure(session.xenapi.VDI.destroy, vdi_ref)
    log.debug('Destroying VDI %s(%s) done.', vdi_uuid, vdi_ref)


@log_exceptions
def get_metadata(session, args):
    """
    Get the metadata for the specified remote VM, and put it in a new VDI
    in the given local SR.

    Plugin entry point.  Reads 'remote_vm_uuid' and 'local_sr_uuid' from
    args, along with the connection settings consumed by
    get_remote_connection.  Returns the UUID of the new metadata VDI.

    The remote session is logged out whether or not the pull succeeds, as
    in the other entry points; previously it was never closed.
    """
    protocol, remote_host, remote_port, remote_username, remote_password = \
        get_remote_connection(args)

    remote_vm_uuid = validate_exists(args, 'remote_vm_uuid')
    local_sr_uuid = validate_exists(args, 'local_sr_uuid')

    remote_session = login(protocol, remote_host, remote_port,
                           remote_username, remote_password)
    try:
        _, u = get_metadata_(protocol, remote_host, remote_port,
                             session, remote_session,
                             remote_vm_uuid, local_sr_uuid)
        return u
    finally:
        logout(remote_session)


@log_exceptions
def get_vdi(session, args):
    """
    Copy a remote (source) VDI over to a local (destination) SR, using a
    Transfer VM on the source, a temporary iSCSI SR locally, and VDI.copy
    (i.e. a full copy, pumped by xapi).

    Plugin entry point.  Reads 'remote_vdi_uuid' and 'local_sr_uuid' from
    args, along with the connection settings consumed by
    get_remote_connection.  A copy of args without those two keys is passed
    down as extra_args.  Returns the UUID of the new local VDI.  The remote
    session is logged out whether or not the copy succeeds.
    """
    (protocol, remote_host, remote_port,
     remote_username, remote_password) = get_remote_connection(args)

    remote_vdi_uuid = validate_exists(args, 'remote_vdi_uuid')
    local_sr_uuid = validate_exists(args, 'local_sr_uuid')

    # Everything except the two keys handled here travels on as extra_args.
    extra_args = dict(args)
    for consumed in ('remote_vdi_uuid', 'local_sr_uuid'):
        del extra_args[consumed]

    remote_session = login(protocol, remote_host, remote_port,
                           remote_username, remote_password)
    try:
        dest_vdi_ref = get_vdi_(session, remote_session, remote_vdi_uuid,
                                local_sr_uuid, extra_args)
        dest_vdi_uuid = session.xenapi.VDI.get_uuid(dest_vdi_ref)
    finally:
        logout(remote_session)
    return dest_vdi_uuid


def get_vdi_(local_session, remote_session, remote_vdi_uuid, local_sr_uuid,
             extra_args):
    """
    Copy the remote VDI into the local SR and carry its descriptive fields
    over to the copy.

    The remote VDI is exposed read-only and copied with copy_vdi.  After
    that name_label, name_description, xenstore_data, other_config and tags
    are read from the remote record and set on the new VDI.

    Returns the ref of the new local VDI.  If reading the remote record or
    setting a field fails, the new VDI is destroyed before the exception is
    re-raised.  No caller has a reference to it at that point, so it would
    otherwise be leaked.
    """
    # For debugging, we can create small blank VDIs here instead of copying
#     local_sr_ref = local_session.xenapi.SR.get_by_uuid(local_sr_uuid)
#     return local_session.xenapi.VDI.create(
#         { 'name_label': 'Fake',
#           'name_description': '',
#           'SR': local_sr_ref,
#           'virtual_size': '100',
#           'type': 'User',
#           'sharable': False,
#           'read_only': False,
#           'xenstore_data': {},
#           'other_config': {},
#           'sm_config': {},
#           'tags': [] })

    result = with_exposed_vdi(local_session, remote_session, remote_vdi_uuid,
                              True, extra_args,
                              lambda src_vdi: copy_vdi(local_session, src_vdi,
                                                       local_sr_uuid))
    try:
        remote_vdi_ref = \
            remote_session.xenapi.VDI.get_by_uuid(remote_vdi_uuid)
        remote_vdi_rec = remote_session.xenapi.VDI.get_record(remote_vdi_ref)
        for field in ('name_label', 'name_description', 'xenstore_data',
                      'other_config', 'tags'):
            setter = getattr(local_session.xenapi.VDI, 'set_%s' % field)
            setter(result, remote_vdi_rec[field])
    except:
        # destroy_vdi() has a nothrow guarantee, so the original error
        # is the one that propagates.
        destroy_vdi(local_session, result)
        raise
    return result


@log_exceptions
def put_vdi(session, args):
    """
    Upload the contents of the given local (source) VDI into the given remote
    (destination) VDI, using a Transfer VM exposing the destination VDI, a
    temporary iSCSI SR locally, and dd to copy the blocks.

    Plugin entry point.  Reads 'remote_vdi_uuid' and 'local_vdi_uuid' from
    args, along with the connection settings consumed by
    get_remote_connection.  A copy of args without those two keys is passed
    down as extra_args.  Returns 'OK'.  The remote session is logged out
    whether or not the upload succeeds.
    """
    (protocol, remote_host, remote_port,
     remote_username, remote_password) = get_remote_connection(args)

    remote_vdi_uuid = validate_exists(args, 'remote_vdi_uuid')
    local_vdi_uuid = validate_exists(args, 'local_vdi_uuid')

    # Everything except the two keys handled here travels on as extra_args.
    extra_args = dict(args)
    for consumed in ('remote_vdi_uuid', 'local_vdi_uuid'):
        del extra_args[consumed]

    remote_session = login(protocol, remote_host, remote_port,
                           remote_username, remote_password)
    try:
        src_vdi = session.xenapi.VDI.get_by_uuid(local_vdi_uuid)

        def upload_into(dest_vdi):
            return dd_vdi(session, src_vdi, dest_vdi)

        # The destination is exposed writable (read_only=False).
        with_exposed_vdi(session, remote_session, remote_vdi_uuid, False,
                         extra_args, upload_into)
        status = 'OK'
    finally:
        logout(remote_session)
    return status


def dd_vdi(session, src_vdi, dest_vdi):
    """
    Attach src_vdi in dom0 (with_vdi_in_dom0 flag True) and, while it is
    attached, hand its device on to dd_vdi_ together with dest_vdi.
    """
    def with_source_device(src_dev):
        return dd_vdi_(session, src_dev, dest_vdi)

    with_vdi_in_dom0(session, src_vdi, True, with_source_device)


def dd_vdi_(session, src_dev, dest_vdi):
    """
    Attach dest_vdi in dom0 (with_vdi_in_dom0 flag False) and, while it is
    attached, copy src_dev onto its device with dd_vdi__.
    """
    def with_dest_device(dest_dev):
        return dd_vdi__(src_dev, dest_dev)

    with_vdi_in_dom0(session, dest_vdi, False, with_dest_device)


def dd_vdi__(src_dev, dest_dev):
    """
    Copy /dev/<src_dev> onto /dev/<dest_dev> with dd, in 2M blocks.

    Raises PluginError when dd exits with a non-zero status.  The status
    used to be discarded, so a failed or partial copy was still reported to
    the caller of put_vdi as 'OK'.
    """
    status = os.system('dd if=/dev/%s of=/dev/%s bs=2M' %
                       (src_dev, dest_dev))
    if status != 0:
        raise PluginError('dd from /dev/%s to /dev/%s failed with status %s'
                          % (src_dev, dest_dev, status))


def login(protocol, remote_host, remote_port, remote_username, remote_password):
    try:
        remote_session = \
            XenAPI.Session('%s://%s:%s/' % (protocol, remote_host, remote_port))
        remote_session.login_with_password(remote_username, remote_password)
        return remote_session
    except socket.gaierror, exn:
        log.debug('gaierror logging in: %s', exn)
        raise NoSuchHost(remote_host)


def logout(session):
    """
    Nothrow guarantee.
    """
    try:
        session.xenapi.session.logout()
    except Exception, exn:
        log.warn('Ignoring exception when logging out: %s', exn)


def expose(session, vdi_uuid, read_only, extra_args):
    """
    Expose the given VDI using iSCSI.

    session    -- session on the side that owns the VDI
    vdi_uuid   -- UUID of the VDI to expose
    read_only  -- truthy to expose the VDI read-only
    extra_args -- additional plugin arguments; copied, never modified

    Returns the transfer record as a dict (see get_expose_record).  If the
    record cannot be fetched after the plugin call has succeeded, the VDI is
    unexposed again and the exception is re-raised.
    """

    log.debug("Exposing VDI %s, read_only=%s ... ", vdi_uuid, read_only)

    host_ref = get_this_host(session)

    # Work on a copy, as expose_forest does.  The caller's dict is reused
    # across several calls and must not pick up our per-call keys.
    expose_args = dict(extra_args)
    expose_args['transfer_mode'] = 'iscsi'
    expose_args['vdi_uuid'] = vdi_uuid
    expose_args['network_uuid'] = 'management'
    expose_args['read_only'] = read_only and 'true' or 'false'
    record_handle = \
        unwrap_plugin_exceptions(
            session.xenapi.host.call_plugin,
            host_ref, 'transfer', 'expose', expose_args)

    try:
        result = get_expose_record(session, host_ref, record_handle)
        log.debug("Exposing VDI %s, read_only=%s done.", vdi_uuid, read_only)
        return result
    except:
        unexpose(session, record_handle)
        raise


def get_expose_record(session, host_ref, record_handle):
    try:
        return record_to_dict(
            unwrap_plugin_exceptions(
                session.xenapi.host.call_plugin,
                host_ref, 'transfer', 'get_record',
                {'record_handle': record_handle}))
    except XenAPI.Failure, exn:
        if (len(exn.details) > 3 and
                exn.details[0] == 'XENAPI_PLUGIN_EXCEPTION' and
                exn.details[2] == 'ArgumentError'):
            log.debug('Transfer VM %s has gone away', record_handle)
            raise ExposeDisappeared()
        else:
            log.error("Exception when getting transfer record: %s", exn)
            raise
    except Exception, exn:
        log.error("Exception when getting transfer record: %s", exn)
        raise


def unexpose(session, record_handle):
    """
    Nothrow guarantee.
    """

    try:
        log.debug("Unexposing %s ... ", record_handle)

        host_ref = get_this_host(session)

        session.xenapi.host.call_plugin(
            host_ref, 'transfer', 'unexpose',
            {'record_handle': record_handle})

        log.debug("Unexposing %s done.", record_handle)
    except Exception, exn:
        log.warn('Ignoring exception %s when unexposing %s', exn,
                 record_handle)


def record_to_dict(xml):
    """
    Parse the XML text and return the attributes of its first
    <transfer_record> element as a dict.

    The parsed document is unlinked again before returning.
    """
    doc = minidom.parseString(xml)
    try:
        record_element = doc.getElementsByTagName('transfer_record')[0]
        attributes = dict(record_element.attributes.items())
    finally:
        doc.unlink()
    return attributes


def with_exposed_vdi(local_session, remote_session, remote_vdi_uuid,
                     read_only, extra_args, f):
    """
    Expose the remote VDI over iSCSI, attach it locally through a temporary
    SR, and call f with the ref of the VDI found in that SR.

    Returns whatever f returns.  The temporary SR is forgotten and the
    remote VDI unexposed afterwards, whether or not f succeeded.
    """
    transfer_record = expose(remote_session, remote_vdi_uuid, read_only,
                             extra_args)

    try:
        # Connection settings for the temporary SR, in the order
        # make_sr_for_transfer takes them.
        connection = [transfer_record[key]
                      for key in ('ip', 'port', 'iscsi_iqn',
                                  'username', 'password')]
        sr = make_sr_for_transfer(local_session, *connection)
        try:
            remote_vdi = get_unique_vdi(local_session, sr)
            introduce_vdi(local_session, remote_vdi)
            return f(remote_vdi)
        finally:
            forget_sr(local_session, sr)
    finally:
        unexpose(remote_session, transfer_record['record_handle'])


def make_sr_for_transfer(session, target, port, target_iqn, username,
                         password):
    """
    Create an 'iscsi' SR on this host for the given target, port and IQN,
    using the given CHAP credentials, and tag it with
    other_config['HideFromXenCenter'] = 'true'.

    Returns the ref of the new SR.
    """
    name = 'Transfer SR for LUN %s:%s/%s' % (target, port, target_iqn)
    log.debug("Making %s ... ", name)
    host_ref = get_this_host(session)

    device_config = {
        'target': target,
        'targetIQN': target_iqn,
        'chapuser': username,
        'chappassword': password,
        'port': port,
    }
    sr = session.xenapi.SR.create(host_ref, device_config,
                                  '0', name, '', 'iscsi', '', False, {})
    session.xenapi.SR.add_to_other_config(sr, 'HideFromXenCenter', 'true')

    log.debug("%s is %s.", name, sr)
    return sr


def forget_sr(session, sr):
    """
    Nothrow guarantee.
    """
    log.debug("Forgetting SR %s ... ", sr)

    pbds = []
    try:
        pbds = session.xenapi.SR.get_PBDs(sr)
    except Exception, exn:
        log.warn('Ignoring exception %s when getting PBDs for %s', exn, sr)
    for pbd in pbds:
        try:
            session.xenapi.PBD.unplug(pbd)
        except Exception, exn:
            log.warn('Ignoring exception %s when unplugging PBD %s', exn, pbd)
    try:
        session.xenapi.SR.forget(sr)
        log.debug("Forgetting SR %s done.", sr)
    except Exception, exn:
        log.warn('Ignoring exception %s when forgetting SR %s', exn, sr)


def get_unique_vdi(session, sr):
    """
    Return the one VDI contained in the given SR.

    Raises AttachFailed unless the SR holds exactly one VDI.
    """
    vdi_refs = session.xenapi.SR.get_VDIs(sr)
    if len(vdi_refs) == 1:
        return vdi_refs[0]
    raise AttachFailed()


def introduce_vdi(session, vdi_ref):
    """
    Call VDI.introduce with the values taken from the VDI's own current
    record.
    """
    vdi_rec = session.xenapi.VDI.get_record(vdi_ref)
    # Record fields in the positional order VDI.introduce is called with.
    introduce_fields = (
        'uuid',
        'name_label',
        'name_description',
        'SR',
        'type',
        'sharable',
        'read_only',
        'other_config',
        'location',
        'xenstore_data',
        'sm_config',
    )
    session.xenapi.VDI.introduce(
        *[vdi_rec[field] for field in introduce_fields])


def copy_vdi(session, src_vdi_ref, dest_sr_uuid):
    """
    Copy the given source VDI to the given destination SR, using
    VDI.copy (i.e. a full copy, pumped by xapi).

    Returns the ref of the new VDI.
    """
    log.debug("Copying VDI %s to %s ... ", src_vdi_ref, dest_sr_uuid)
    target_sr = session.xenapi.SR.get_by_uuid(dest_sr_uuid)
    copied_vdi = session.xenapi.VDI.copy(src_vdi_ref, target_sr)
    log.debug("Copied VDI %s to %s make VDI %s.", src_vdi_ref, dest_sr_uuid,
              copied_vdi)
    return copied_vdi


def get_remote_connection(args):
    """
    Read the remote connection settings from the plugin arguments.

    Returns the tuple (protocol, remote_host, remote_port, remote_username,
    remote_password).
    """
    protocol = get_protocol(args)
    host = validate_exists(args, 'remote_host')
    port = get_port(args, protocol)
    username = validate_exists(args, 'remote_username')
    password = validate_exists(args, 'remote_password')
    return (protocol, host, port, username, password)

def get_protocol(args):
    """
    Return the 'protocol' plugin argument, with 'https' passed to
    validate_exists as the default.

    Raises ArgumentError for anything other than 'http' or 'https'.
    """
    protocol = validate_exists(args, 'protocol', 'https')
    if protocol in ('http', 'https'):
        return protocol
    raise ArgumentError('Invalid protocol %s' % protocol)


def get_port(args, protocol):
    """
    Return the 'remote_port' plugin argument.  The default passed to
    validate_exists is 443 for 'https' and 80 for anything else.
    """
    if protocol == 'https':
        default_port = 443
    else:
        default_port = 80
    return validate_exists(args, 'remote_port', default_port)


if __name__ == '__main__':
    # Plugin function name -> implementation, as handed to the dispatcher.
    XenAPIPlugin.dispatch(dict(get_vdi=get_vdi,
                               put_vdi=put_vdi,
                               get_vm=get_vm,
                               get_vm_forest=get_vm_forest,
                               get_metadata=get_metadata))
