from socket import socket, AF_INET, SOCK_STREAM
import re
import logging

from parallels.hosting_check import ServiceIssue
from parallels.hosting_check import Severity
from parallels.hosting_check import ServiceIssueType
from parallels.hosting_check.messages import MSG
from parallels.hosting_check import NonZeroExitCodeException

logger = logging.getLogger(__name__)


class ServiceChecker(object):
	def __init__(self):
		self._command_cache = {}
		self._netstat_tool_cache = {}

	def check(self, services_to_check):
		issues = []
		for service_to_check in services_to_check:
			self._check_single_service(service_to_check, issues)
		return issues

	def _check_single_service(self, service_to_check, issues):
		try:
			try:
				service_to_check.runner.connect()
				for process in service_to_check.service.processes:
					started_process = self._check_processes(service_to_check, process, issues)
					if started_process:
						self._check_ports(service_to_check, process, issues)
			finally:
				service_to_check.runner.disconnect()
		except KeyboardInterrupt:
			# for compatibility with python 2.4
			raise
		except Exception, e:
			issues.append(
				ServiceIssue(
					severity=Severity.WARNING,
					category=ServiceIssueType.SERVICE_INTERNAL_ERROR,
					problem=MSG(
						ServiceIssueType.SERVICE_INTERNAL_ERROR,
						description = service_to_check.description,
						type = service_to_check.service.type,
						error_message = str(e)
					)
				)
			)
			return

	def _check_processes(self, service_to_check, process, issues):
		for process_name in process.names:
			try:
				if service_to_check.service.is_windows:
					cmd = 'cmd /c "tasklist"'
				else:
					cmd = "ps ax" 
				result = self._cached_sh(service_to_check.runner, cmd)

				if result.find(process_name) == -1:
					continue
				return True
			except KeyboardInterrupt:
				# for compatibility with python 2.4
				raise
			except NonZeroExitCodeException, e:
				logger.debug(u"Exception:", exc_info=e)
				issues.append(
					ServiceIssue(
						severity=Severity.WARNING,
						category=ServiceIssueType.SERVICE_INTERNAL_ERROR,
						problem=MSG(
							ServiceIssueType.SERVICE_INTERNAL_ERROR,
							description = service_to_check.description,
							type = service_to_check.service.type,
							error_message = str(e)
						)
					)
				)
		issues.append(
			ServiceIssue(
				severity=Severity.ERROR,
				category=ServiceIssueType.SERVICE_NOT_STARTED,
				problem=MSG(
					ServiceIssueType.SERVICE_NOT_STARTED,
					description=service_to_check.description,
					type=service_to_check.service.type
				)
			)
		)
		return False

	def _check_ports(self, service_to_check, process, issues):
		for port in process.ports:
			try:
				if service_to_check.service.is_windows:
					cmd = 'cmd /c "netstat -an"'
				else:
					netstat_tool = self._get_netstat_tool(service_to_check.runner)
					if netstat_tool is not None:
						cmd = "%s -tpln" % netstat_tool
					else:
						# If no netstat-like tool is installed, silently skip the checks
						cmd = None

				if cmd is not None:
					result = self._cached_sh(service_to_check.runner, cmd)
					if re.search(":%s\s" % port, result) is None:
						issues.append(
							ServiceIssue(
								severity=Severity.ERROR,
								category=ServiceIssueType.SERVICE_PORT_IS_CLOSED,
								problem=MSG(
									ServiceIssueType.SERVICE_PORT_IS_CLOSED,
									description=service_to_check.description,
									port=port,
									service=service_to_check.service.type
								)
							)
						)
						continue

				if service_to_check.service.check_connection:
					result = self._check_connection(service_to_check.host, port)

					if not result:
						issues.append(
							ServiceIssue(
								severity=Severity.ERROR,
								category=ServiceIssueType.SERVICE_CONNECTION_ERROR,
								problem=MSG(
									ServiceIssueType.SERVICE_CONNECTION_ERROR,
									description=service_to_check.description,
									port=port,
									service=service_to_check.service.type
								)
							)
						)
			except KeyboardInterrupt:
				# for compatibility with python 2.4
				raise
			except NonZeroExitCodeException, e:
				logger.debug(u"Exception:", exc_info=e)
				issues.append(
					ServiceIssue(
						severity=Severity.WARNING,
						category=ServiceIssueType.SERVICE_INTERNAL_ERROR,
						problem=MSG(
							ServiceIssueType.SERVICE_INTERNAL_ERROR,
							description = service_to_check.description,
							type = service_to_check.service.type,
							error_message = str(e)
						)
					)
				)

	def _get_netstat_tool(self, runner):
		if runner not in self._netstat_tool_cache:
			self._netstat_tool_cache[runner] = None

			for tool in ['netstat', 'ss']:
				if self._tool_exists(runner, tool):
					self._netstat_tool_cache[runner] = tool

		return self._netstat_tool_cache[runner]

	@staticmethod
	def _tool_exists(runner, tool):
		exit_code, _, _ = runner.sh_unchecked('which {tool}', dict(tool=tool))
		return exit_code == 0

	def _cached_sh(self, runner, command):
		if (runner, command) not in self._command_cache:
			self._command_cache[(runner, command)] = runner.sh(command)
		return self._command_cache[(runner, command)]

	def _check_connection(self, host, port):
		s = socket(AF_INET, SOCK_STREAM)
		result = s.connect_ex((host, port))
		s.close()
		return result == 0
