# -*- coding: utf-8 -*-
# h-client, a client for an h-source server (such as http://www.h-node.org/)
# Copyright (C) 2011  Antonio Gallo
# Copyright (C) 2011  Michał Masłowski  <mtjm@mtjm.eu>
#
#
# h-client is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# h-client is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with h-client.  If not, see <http://www.gnu.org/licenses/>.


"""
Hardware and operating system detection for h-client.
"""




import os
import re
import subprocess
import glob


#: Mapping from codenames used by distros to h-source distro codes.
#:
#: Keys are searched for as substrings of lowercased lines read from
#: ``/etc/*-release`` files by `codename_to_distrocode`, which returns
#: the value of the first key (in dictionary iteration order) that
#: occurs in the line.  If a new codename could be a substring of
#: another, keep the more specific one earlier in this literal.
DISTRO_CODENAMES = {
	"deltah": "gnewsense_2_3",
	"metad": "gnewsense_3_0",
	"awen":	"trisquel_3_5",
	"parabola": "parabola",
	"taranis": "trisquel_4_0",
	"slaine": "trisquel_4_5",
	"spartakus": "blag_140000",
	"dagda": "trisquel_5_0",
	"brigantia": "trisquel_5_5",
	"toutatis": "trisquel_6_0",
	}


class FileOpener(object):

	"""
	Open and list files.  Different implementations of this class
	would be used to test the distro detection code.
	"""

	@staticmethod
	def open(file_name):
		"""Open a named text file for reading.

		The file is decoded as UTF-8.  Undecodable bytes are replaced
		with U+FFFD instead of raising `UnicodeDecodeError` while the
		file is iterated.  Otherwise a single stray byte in a release
		file, or a non-UTF-8 locale, would abort distro detection,
		since callers only handle `IOError`.

		Raises `IOError` if the file is missing or cannot be read.
		"""
		return open(file_name, encoding="utf-8", errors="replace")

	@staticmethod
	def glob(pattern):
		"""List names of files matching the *pattern*.

		Like `glob.glob`.  Only a small number of fixed patterns will
		be used as arguments.
		"""
		return glob.glob(pattern)


def parse_os_release(file_object):
	"""Parse os-release data into a dictionary.

	*file_object* is any iterable of text lines, such as an open file.
	Returns a dictionary mapping variable names (e.g. ``ID`` or
	``VERSION_ID``) to their string values.

	See os-release(5) for the format: shell-compatible ``NAME=value``
	assignments, one per line, with ``#`` starting comment lines.
	Values may be enclosed in matching single or double quotes; inside
	double quotes the backslash escapes of ``$``, backquote, ``"`` and
	backslash are undone.  Blank lines, comments and lines without
	``=`` are skipped, and empty values are kept as empty strings.
	"""
	data = {}
	for line in file_object:
		# Also drop "\r" so CRLF files do not break quote stripping.
		line = line.rstrip("\r\n")
		if not line.strip() or line.lstrip().startswith("#"):
			continue
		name, separator, value = line.partition("=")
		if not separator:
			# Not an assignment; ignore it rather than fail the parse.
			continue
		# value[:1] is safe for empty values such as "VARIANT=".
		quote = value[:1]
		if quote in ('"', "'") and len(value) >= 2 and value.endswith(quote):
			value = value[1:-1]
			if quote == '"':
				# Undo shell-style escapes allowed inside double quotes.
				value = re.sub(r'\\([$`"\\])', r"\1", value)
		data[name] = value
	return data


def distro_from_os_release(data):
	"""Map parsed os-release data to an h-source distro code.

	*data* is a dictionary of os-release variables, as returned by
	`parse_os_release`.  Only Parabola and Debian are recognized; an
	empty string is returned for any other ``ID``.
	"""
	# TODO detect other distros if they have /etc/os-release.  The
	# reason for not using the data directly is different distro
	# naming in h-source.
	distro_id = data.get('ID', '')
	if distro_id == 'debian':
		# Releases carry VERSION_ID (giving e.g. 'debian_8'); testing
		# and sid do not, and map to plain 'debian'.
		version = data.get('VERSION_ID')
		return 'debian_' + version if version else 'debian'
	if distro_id == 'parabola':
		return 'parabola'
	# Most other distros would also use VERSION_ID.
	return ''


def codename_to_distrocode(codename_string):
	"""Map a string containing a distro codename to an h-source code.

	Each key of `DISTRO_CODENAMES` is looked for as a substring of
	*codename_string*, in the mapping's iteration order, and the code
	of the first one found is returned.  If no codename occurs in the
	string, an empty string is returned.
	"""
	matches = (
		distro_code
		for codename, distro_code in DISTRO_CODENAMES.items()
		if codename in codename_string)
	return next(matches, "")


def _distro_from_os_release_file(opener, file_name):
	"""Return the distro code read from one os-release file, or ``""``.

	A missing, unreadable or undecodable file yields ``""`` so that
	the caller can fall back to other sources.
	"""
	try:
		with opener.open(file_name) as os_release:
			return distro_from_os_release(parse_os_release(os_release))
	except (IOError, UnicodeDecodeError):
		return ""


def _distro_from_release_file(opener, file_name):
	"""Return the distro code found in a non-standard ``*-release`` file.

	Only lines containing ``CODENAME`` are examined, except in
	``/etc/system-release`` where every line is.  An unreadable or
	undecodable file (e.g. a directory matched by the glob) yields
	``""`` instead of aborting the whole detection.
	"""
	try:
		with opener.open(file_name) as release_file:
			for line in release_file:
				if "CODENAME" in line or file_name == "/etc/system-release":
					code = codename_to_distrocode(line.rstrip("\n").lower())
					if code:
						return code
	except (IOError, UnicodeDecodeError):
		pass
	return ""


def user_distribution(opener=FileOpener):
	"""
	Return the h-source distro code of the user distro.

	The optional argument specifies an object having ``open`` and
	``glob`` attributes used to find files containing the distro
	versions.

	``/etc/os-release`` is tried first, then ``/usr/lib/os-release``,
	then codename lines of the other ``/etc/*-release`` files.  Files
	that cannot be read are skipped.

	An empty string is returned if the distro is not known.
	"""
	for file_name in ("/etc/os-release", "/usr/lib/os-release"):
		release = _distro_from_os_release_file(opener, file_name)
		if release:
			return release
	for file_name in opener.glob("/etc/*-release"):
		if file_name == "/etc/os-release":
			# Already handled above.
			continue
		code = _distro_from_release_file(opener, file_name)
		if code:
			return code
	return ""


#: Version string of the currently running kernel: the ``release``
#: field (index 2) of ``os.uname()``.  It is read once at import time
#: and assigned to the ``kernel`` attribute of every device built by
#: `createDevices`.
KERNEL_VERSION = os.uname()[2]


#: Regular expression parsing an ``lspci -nn`` value that combines a
#: string name with a bracketed hexadecimal ID, e.g.
#: ``Host bridge [0600]``.
_CLASS = re.compile(r"^(.+)\s+\[([0-9a-f]+)\]$")


def parse_lspci(lspci_output):
	"""Iterate parsed output of ``lspci -vmmnnk``.

	*lspci_output* is the command output as `str`, or as `bytes`
	(decoded as UTF-8, with undecodable bytes replaced).

	For each device a dictionary is yielded, with the keys being the
	tags of ``lspci`` output.  The values are strings, except for the
	following keys:

	  ``Class``, ``Vendor``, ``Device``, ``SVendor``, ``SDevice``
	    a tuple of string and numeric names
	  ``Module``
	    a list of strings naming all modules

	See the ``lspci(8)`` man page for description of the tags found.

	Raises `TypeError` if *lspci_output* is neither `str` nor `bytes`.
	Raises `AssertionError` on malformed output: a line without a
	``Tag:<TAB>`` prefix, a duplicated tag other than ``Module``, or an
	ID value without a bracketed numeric part.  Explicit raises are
	used instead of ``assert`` so the checks survive ``python -O``.
	"""
	if isinstance(lspci_output, bytes):
		lspci_output = lspci_output.decode("utf-8", "replace")
	if not isinstance(lspci_output, str):
		raise TypeError("lspci output must be str or bytes, not %s"
						% type(lspci_output).__name__)

	device = {}
	for line in lspci_output.split("\n"):
		if not line:
			# Blank lines separate device records.
			if device:
				yield device
				device = {}
			continue
		tag, separator, value = line.partition("\t")
		if not separator or not tag.endswith(":"):
			raise AssertionError("malformed lspci line %r" % line)

		tag = tag[:-1]
		if tag != "Module" and tag in device:
			raise AssertionError("duplicated device tag %r" % tag)
		if tag in ("Class", "Vendor", "Device", "SVendor", "SDevice"):
			match = _CLASS.match(value)
			if match is None:
				raise AssertionError("no numeric ID in %s value %r"
									 % (tag, value))
			device[tag] = (match.group(1), int(match.group(2), 16))
		elif tag == "Module":
			# Several modules may be listed for one device.
			device.setdefault(tag, []).append(value)
		else:
			device[tag] = value
	if device:
		yield device


#: Regular expression matching a row of ``lsusb`` output.
_LSUSB_ROW = re.compile(r"^\s*((?:[a-zA-Z:]+ )*[a-zA-Z:]+)\s+((?:[^\s](?:.*[^\s])?)?)\s*$")
#: Regular expression matching value of ``idVendor`` and similar fields.
_NAMED_USBID = re.compile(r"^0x([a-zA-Z0-9]{4})\s*(.*)\s*$")
#: Regular expression matching value of ``bInterfaceClass`` and similar fields.
_USBCLASS = re.compile(r"^([0-9]+)\s*(.*)\s*$")


def parse_lsusb(lsusb_output):
	"""Iterate parsed output of ``lsusb -v``.

	*lsusb_output* is the command output as `str` or `bytes`.  Bytes
	are decoded as UTF-8 with undecodable bytes replaced, because the
	output may contain arbitrary strings reported by the devices.

	For each device (a record starting with a ``Bus`` line) a
	dictionary is yielded.  It may contain these keys:

	* ``vendorId``, ``vendorName``, ``productId`` and ``productName``,
	  from ``idVendor``/``idProduct``.  The IDs are the four
	  characters printed after ``0x``.
	* ``classId``, ``className``, ``subclassId``, ``subclassName``,
	  ``protocolId`` and ``protocolName``, from the
	  ``bInterfaceClass``/``bInterfaceSubClass``/``bInterfaceProtocol``
	  fields.  The IDs are formatted as two lowercase hex digits.

	A key is absent when its field was not found.  For a device with
	several interfaces, the values of the last interface listed win.

	Raises `TypeError` if *lsusb_output* is neither `str` nor `bytes`.
	"""
	if isinstance(lsusb_output, bytes):
		lsusb_output = lsusb_output.decode("utf-8", "replace")
	if not isinstance(lsusb_output, str):
		raise TypeError("lsusb output must be str or bytes, not %s"
						% type(lsusb_output).__name__)

	# Interface fields of interest and the key prefix used for each.
	interface_keys = {
		"bInterfaceClass": "class",
		"bInterfaceSubClass": "subclass",
		"bInterfaceProtocol": "protocol",
	}
	device = {}
	for line in lsusb_output.split("\n"):
		line = line.strip()
		if not line:
			continue
		if line.startswith("Bus "):
			# A "Bus ..." line starts the description of the next device.
			if device:
				yield device
			device = {}
			continue
		match = _LSUSB_ROW.match(line)
		if not match:
			continue
		tag, value = match.group(1), match.group(2)
		if tag in ("idVendor", "idProduct"):
			key = tag[2:].lower()
			res = _NAMED_USBID.match(value)
			if res:
				device[key + "Id"] = res.group(1)
				device[key + "Name"] = res.group(2)
		elif tag in interface_keys:
			key = interface_keys[tag]
			res = _USBCLASS.match(value)
			if res:
				# lsusb prints these numbers in decimal.
				device[key + "Id"] = "%02x" % int(res.group(1))
				device[key + "Name"] = res.group(2).strip()
	if device:
		yield device


def _pci_driver_description(pci_device):
	"""Describe the kernel driver and modules of a parsed PCI device.

	*pci_device* is a dictionary yielded by `parse_lspci`.  The
	``Driver`` value is followed by the ``Module`` names in parentheses
	when they differ from it.  Without a driver, the module names alone
	are returned, and ``""`` when neither is known.
	"""
	driver = pci_device.get("Driver", "")
	if "Module" not in pci_device:
		return driver
	modules = ", ".join(pci_device["Module"])
	if driver and driver != modules:
		noun = "module" if len(pci_device["Module"]) == 1 else "modules"
		return "%s (%s %s)" % (driver, modules, noun)
	return modules


def createDevices(pci_devices=None, usb_devices=None):
	"""Return a dictionary of device objects.

	The optional *pci_devices* and *usb_devices* are iterables of
	dictionaries as yielded by `parse_lspci` and `parse_lsusb`.  When
	they are not given, ``lspci -vmmnnk`` and ``lsusb -v`` are run and
	their output is parsed.

	Keys of the result are ``"p_VENDOR:PRODUCT"`` for PCI devices and
	``"u_VENDOR:PRODUCT"`` for USB devices.  Values are lists
	``[device, class_code, "insert", "0"]``.

	Devices are skipped when `get_device_type_for_class` returns no
	type for their class.  USB devices are also skipped when
	vendor/product or interface class information is missing.
	"""
	from hclient.devices import Device, get_device_type_for_class
	# Fields parse_lsusb sets only when lsusb printed them; a USB device
	# cannot be described without all of them.
	usb_required = ("vendorId", "productId", "productName",
					"classId", "subclassId", "protocolId")
	devices = {}
	# Start both commands before waiting for either, so they run
	# concurrently.
	if pci_devices is None:
		lspci = subprocess.Popen(("lspci", "-vmmnnk"), stdout=subprocess.PIPE,
								 stderr=subprocess.PIPE)
	if usb_devices is None:
		lsusb = subprocess.Popen(("lsusb", "-v"), stdout=subprocess.PIPE,
								 stderr=subprocess.PIPE)
	# Collect both outputs before parsing anything, so that an error
	# while handling PCI devices cannot leave lsusb unreaped.
	if pci_devices is None:
		pci_devices = parse_lspci(lspci.communicate()[0])
	if usb_devices is None:
		usb_devices = parse_lsusb(lsusb.communicate()[0])
	for pci_device in pci_devices:
		dev_type = get_device_type_for_class(pci_device["Class"][1])
		if not dev_type:
			continue
		dev = Device(dev_type)
		dev.setBus("PCI")
		dev.setVendorId("%04x" % pci_device["Vendor"][1])
		dev.setProductId("%04x" % pci_device["Device"][1])
		if "SVendor" in pci_device:
			dev.setSubVendorId("%04x" % pci_device["SVendor"][1])
			dev.setSubsystemVendor(pci_device["SVendor"][0])
		if "SDevice" in pci_device:
			dev.setSubProductId("%04x" % pci_device["SDevice"][1])
			dev.setSubsystemName(pci_device["SDevice"][0])
		dev.setModel(pci_device["Device"][0])
		dev.kernel = KERNEL_VERSION
		if "Rev" in pci_device:
			dev.setModel("%s (rev %s)" % (dev.getModel(), pci_device["Rev"]))
		# TODO get X.Org driver for graphics cards
		dev.setDriver(_pci_driver_description(pci_device))
		key = "p_%s:%s" % (dev.getVendorId(), dev.getProductId())
		# TODO it would be more elegant to use ints here
		device_class = "%04x" % pci_device["Class"][1]
		devices[key] = [dev, device_class, "insert", "0"]
	for device in usb_devices:
		if not all(field in device for field in usb_required):
			# lsusb did not print some needed field; skip this device
			# instead of failing the whole detection with a KeyError.
			continue
		classcode = "".join((device["classId"],
							 device["subclassId"],
							 device["protocolId"]))
		dev_type = get_device_type_for_class(int(classcode, 16))
		if not dev_type:
			continue
		dev = Device(dev_type)
		dev.setBus("USB")
		dev.setVendorId(device["vendorId"])
		dev.interface = dev_type.interfaces.index("USB")
		dev.setProductId(device["productId"])
		dev.setModel(device["productName"])
		dev.kernel = KERNEL_VERSION
		# TODO get kernel modules, SANE driver for scanners or CUPS
		# driver for printers
		key = "u_%s:%s" % (dev.getVendorId(), dev.getProductId())
		devices[key] = [dev, classcode, "insert", "0"]
	return devices


# Local Variables:
# indent-tabs-mode: t
# python-guess-indent: nil
# python-indent: 4
# tab-width: 4
# End:
