#!/usr/bin/env python3

import time
import send_lib
import sys
import re
import requests
import json

# constants
SERVICE_NAME = "o365_groups_mu"
DEBUG = 0  # set to a truthy value to make print_debug() print its messages
TIMEOUT = 60  # 1 min; passed as the timeout of the requests calls in call_api()
TENANT_REGEX_FORMAT = re.compile(r'[\w-]+')  # format the destination (tenant) is checked against
RETRY_AFTER_DEFAULT = 5  # 5 sec; used when a 429 response carries no Retry-After header
# How many elements are permitted by the API for filtering (combines number of elements
# in 'in' clause and also 'startsWith' clause).
MAX_ELEMENTS_IN_QUERY = 15
ADD_MEMBERS_CHUNK_SIZE = 20  # max size of chunk of members to be added
BASE_URL = 'https://graph.microsoft.com/v1.0'
PERUN_MANAGED = '[synced from Perun] '  # prefix of descriptions of groups managed by Perun
# "{tenant}" is replaced with the destination before the first token request
auth_endpoint = "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token"

# credentials, filled from the service configuration when run as a script
client_id = None
client_secret = None
users_mapping = {}  # dict {userPrincipalName: objectId}
access_token = None  # set by request_access_token()

#  statistics (printed by print_stats() at the end of the propagation)
groups_not_found = set()  # groupIds of groups whose members could not be fetched
users_not_found = set()  # our members for which no id exists in users_mapping
groups_updated = 0  # groups in which at least one member was added or removed
members_added = 0
members_removed = 0


def print_debug(message):
	"""
	Prints the message to standard output, but only when DEBUG is truthy.
	"""
	if not DEBUG:
		return
	print(message)


def print_stats():
	"""
	Prints summary of the propagation: number of updated groups, added and removed members.
	Groups and users which were not found are listed only if there are any.
	"""
	parts = [
		"Propagation finished.\n",
		f"Updated memberships in {groups_updated} group(s).\n",
		f"Added {members_added} members in total.\n",
		f"Removed {members_removed} members in total.\n",
	]
	if groups_not_found:
		parts.append(f"Groups not found: {groups_not_found}.\n")
	if users_not_found:
		parts.append(f"Users not found: {users_not_found}.")
	print("".join(parts))


def request_access_token():
	"""
	Requests new access token from the authorization endpoint (client credentials grant)
	and stores it in the global access_token.
	Calls send_lib.die_with_error if the endpoint responds with an error status
	or if the response does not contain any access token.
	"""
	global access_token
	post_data = {
		'client_id': client_id,
		'client_secret': client_secret,
		'scope': 'https://graph.microsoft.com/.default',
		'grant_type': 'client_credentials'
	}
	# timeout keeps the propagation from hanging forever on an unresponsive endpoint
	response = requests.post(auth_endpoint, data=post_data, timeout=TIMEOUT)
	print_debug(f'Response from auth endpoint ended with status code {response.status_code}')
	if not response.ok:
		send_lib.die_with_error("Communication with authorization endpoint failed.")
	token = response.json().get('access_token')
	if not token:
		send_lib.die_with_error("Authorization endpoint did not return an access token.")
	access_token = token


def call_api(url: str, params=None, type='GET', retry_no=0):
	"""
	Calls the API. Handles slowing requests down if limits are hit and renews access token if needed.
	Follows '@odata.nextLink' and returns data of all pages merged together.

	@param url: full URL of the endpoint
	@param params: query parameters for GET, JSON body for DELETE, PATCH and POST
	@param type: HTTP method, one of 'GET', 'DELETE', 'PATCH', 'POST'
	@param retry_no: how many times this request was already retried, more than 3 retries end with error
	@return: data if present (content of "value" if the response has it, whole parsed response otherwise),
			None if response was OK and returned no data, or if the API reported "Invalid object" (object not found),
			False if request for adding members failed and request needs to be resent with fixed dataset.
	"""
	if retry_no > 3:
		send_lib.die_with_error("Exceeded max number of retries.")
	# built on every call, so a retry after token renewal uses the new token
	headers = {"Authorization": "Bearer " + access_token}
	if type == 'GET':
		response = requests.get(url=url, params=params, headers=headers, timeout=TIMEOUT)
	elif type == 'DELETE':
		response = requests.delete(url=url, json=params, headers=headers, timeout=TIMEOUT)
	elif type == 'PATCH':
		response = requests.patch(url=url, json=params, headers=headers, timeout=TIMEOUT)
	elif type == 'POST':
		response = requests.post(url=url, json=params, headers=headers, timeout=TIMEOUT)
	else:
		send_lib.die_with_error(f"Invalid request type {type}")

	if response.ok:
		try:
			body = response.json()
		except ValueError:  # JSON decoding errors are subclasses of ValueError
			return None  # some requests don't return anything
		data = body["value"] if "value" in body else body
		if '@odata.nextLink' in body:
			next_page = call_api(body['@odata.nextLink'], None)
			if next_page:  # next page may come back empty (None)
				data.extend(next_page)
		return data

	if response.status_code == 401:  # access token expired
		print_debug("Renewing access token.")
		request_access_token()
		# result of the retry has to be returned, otherwise the caller receives None
		return call_api(url, params, type, retry_no + 1)

	if response.status_code == 429:  # too many requests
		# header value is a string, it has to be converted before sleeping
		try:
			retry_after = max(0.0, float(response.headers.get('Retry-After', RETRY_AFTER_DEFAULT)))
		except ValueError:
			retry_after = RETRY_AFTER_DEFAULT
		print_debug(f'"Too many requests" in response, waiting {retry_after} sec.')
		time.sleep(retry_after)
		return call_api(url, params, type, retry_no + 1)

	if response.status_code == 400:  # might be non-existing object
		raw_content = response.content  # kept for the error report below
		try:
			error_body = response.json()
		except ValueError:
			error_body = {}
		error = error_body.get("error") if isinstance(error_body, dict) else None
		error_message = error.get("message") if isinstance(error, dict) else None
		if not isinstance(error_message, str):
			error_message = ""
		if error_message.startswith("Invalid object"):
			print_debug(f'Object not found - {params}.')
			return None
		if "already exist" in error_message:
			print_debug(f'Some member of the chunk is already member of the group.')
			return False
		if "Invalid URL format specified in @odata.bind for members" in error_message:
			print_debug(f'Invalid object id in the request (might have been removed).')
			return False
		send_lib.die_with_error(f"Request ended with error 400, response: {raw_content}")
	else:
		send_lib.die_with_error(f"Request ended with error, response: {response.content}")


def prepare_user_ids(our_group, their_members):
	"""
	Saves fetched member's names and ids to users mapping. Fetches missing user's ids.
	If user does not exist, it will be silently skipped and mapping won't exist for him.

	@param our_group: group from our data, its "groupId" and "members" (list of logins) are read
	@param their_members: members fetched from the API, each with "userPrincipalName" and "id"
	"""
	print_debug(f'Preparing user ids for group {our_group["groupId"]}')
	# update mapping
	for their_member in their_members:
		users_mapping[their_member["userPrincipalName"]] = their_member["id"]
	# retrieve missing members
	missing_members = [member for member in our_group["members"] if member not in users_mapping]
	# prepare chunks to stay within API request limits and fetch missing users
	url = f'{BASE_URL}/users'
	for i in range(0, len(missing_members), MAX_ELEMENTS_IN_QUERY):
		print_debug(f'Fetching user IDs for logins bulk #{i // MAX_ELEMENTS_IN_QUERY + 1}')
		chunk = missing_members[i:i + MAX_ELEMENTS_IN_QUERY]
		# single quote inside of a quoted filter value has to be doubled,
		# otherwise login like o'brien@... breaks the whole filter
		names = ','.join("'{}'".format(str(login).replace("'", "''")) for login in chunk)
		params = {
			"$filter": f"userPrincipalName in ({names})",
			"$select": "userPrincipalName,id"
		}
		users = call_api(url, params)
		# call_api returns None/False when the request did not yield any data
		for user in users or []:
			users_mapping[user["userPrincipalName"]] = user["id"]


def fetch_their_members(our_group):
	"""
	Returns result of the API call listing members of the group (by our_group["groupId"]),
	asking only for userPrincipalName and id of each member.
	"""
	members_url = f'{BASE_URL}/groups/{our_group["groupId"]}/members'
	return call_api(members_url, {'$select': 'userPrincipalName,id'})


def check_group_description(our_group):
	"""
	Checks, if group has description suggesting its members are managed by Perun.
	Adds prefix and URL to resource if this info is missing.
	Replaces the URL if the description is already prefixed but ends with a different one.

	@param our_group: group from our data, its "groupId" and "perunResourceUrl" are read
	"""
	group_id = our_group["groupId"]
	url = f'{BASE_URL}/groups/{group_id}'
	separator = " | Manage in Perun: "
	resource_url = our_group["perunResourceUrl"]
	response = call_api(url)
	if not response:
		# group could not be read, there is nothing to compare the description with
		print_debug(f"Description could not be checked for group {group_id}")
		return
	# description is optional, group without one is handled as group with empty description
	their_description = response.get("description") or ""
	if not their_description.startswith(PERUN_MANAGED):
		new_description = PERUN_MANAGED + their_description + separator + resource_url
	elif not their_description.endswith(resource_url):
		# Resource ID changed in Perun, update description
		new_description = their_description.split(separator)[0] + separator + resource_url
	else:
		return
	params = {"description": new_description}
	call_api(url, params, 'PATCH')
	print_debug(f"Description updated for group {group_id}")


def remove_their_members(our_group, their_members, scope):
	"""
	Removes members that are not present in Perun.
	If "keepExternalAccounts" of the group is truthy, only accounts ending with scope are removed
	(other accounts are kept), otherwise all members missing in Perun are removed.

	@param our_group: group from our data, its "groupId", "members" and "keepExternalAccounts" are read
	@param their_members: userPrincipalNames of current members of the group in O365
	@param scope: suffix by which accounts are recognized as not external
	"""
	global members_removed
	# set makes the membership test constant-time instead of scanning the list for each member
	our_members = set(our_group["members"])
	our_missing_members = [mmbr for mmbr in their_members if mmbr not in our_members]
	if our_group["keepExternalAccounts"]:
		our_missing_members = [mmbr for mmbr in our_missing_members if mmbr.endswith(scope)]
	for missing_member in our_missing_members:
		print_debug(f'Removing user {missing_member} from group {our_group["groupId"]}')
		url = f'{BASE_URL}/groups/{our_group["groupId"]}/members/{users_mapping[missing_member]}/$ref'
		call_api(url, None, 'DELETE')
		members_removed += 1  # also counts users removed by someone else during propagation!


def add_our_members(our_group, their_members):
	"""
	Adds missing members to group in O365,
	silently skips users whose mapping does not exist (were not found) and adds them to statistics

	@param our_group: group from our data, its "groupId" and "members" are read
	@param their_members: userPrincipalNames of current members of the group in O365
	"""
	# set makes the membership test constant-time instead of scanning the list for each member
	their_known_members = set(their_members)
	names = []
	for missing_member in our_group["members"]:
		if missing_member in their_known_members:
			continue
		uid = users_mapping.get(missing_member)
		if uid is not None:
			names.append(f'{BASE_URL}/directoryObjects/{uid}')
		else:
			users_not_found.add(missing_member)

	if names:
		print_debug(f'{len(names)} members to be added to group {our_group["groupId"]}')
	# members are added in chunks of limited size
	for i in range(0, len(names), ADD_MEMBERS_CHUNK_SIZE):
		print_debug(f'Adding members of chunk #{i // ADD_MEMBERS_CHUNK_SIZE + 1}')
		chunk = names[i:i + ADD_MEMBERS_CHUNK_SIZE]
		call_add_members(chunk, our_group["groupId"])


def call_add_members(chunk, group_id):
	"""
	Adds chunk of members (directory object URLs) to the group with one PATCH request.
	If the request is refused (call_api returns False), the chunk is halved and both halves
	are sent again, until the refused chunk has a single member, which is then skipped.
	Successfully sent chunks are counted into members_added.
	"""
	global members_added
	if not chunk:
		return

	result = call_api(f"{BASE_URL}/groups/{group_id}", {"members@odata.bind": chunk}, 'PATCH')
	if result is not False:
		members_added += len(chunk)
		return

	if len(chunk) == 1:
		print_debug(f'User {chunk[0]} is already member of group {group_id}')
		return
	middle = len(chunk) // 2
	for part in (chunk[:middle], chunk[middle:]):
		call_add_members(part, group_id)


if __name__ == "__main__":
	# Entry point of the script: arguments are facility, destination (tenant) and destination type.
	# check input format
	send_lib.check_input_fields(sys.argv, True)
	facility = sys.argv[1]
	destination = sys.argv[2]
	destination_type = sys.argv[3]
	send_lib.check_destination_type_allowed(destination_type, send_lib.DESTINATION_TYPE_SERVICE_SPECIFIC)
	send_lib.check_destination_format(destination, destination_type, TENANT_REGEX_FORMAT)

	# read configuration, obtain access token
	auth_endpoint = auth_endpoint.replace("{tenant}", destination)
	config_properties = send_lib.get_custom_config_properties(SERVICE_NAME, destination, ["clientId", "clientSecret"])
	if not config_properties:
		send_lib.die_with_error("Configuration not found or properties missing!")
	client_id = config_properties[0]
	client_secret = config_properties[1]
	request_access_token()

	# load our data; the file is closed right away, it is not needed during the API calls
	gen_folder = send_lib.get_gen_folder(facility, SERVICE_NAME)
	with open(f'{gen_folder}/{SERVICE_NAME}', 'rb') as f:
		our_complete_data = json.load(f)
	scope = our_complete_data["scope"]
	our_complete_groups = our_complete_data["groups"]

	# process data for each group
	for group_ours in our_complete_groups:
		their_group_members = fetch_their_members(group_ours)
		if their_group_members is None:
			groups_not_found.add(group_ours["groupId"])
			continue

		check_group_description(group_ours)
		prepare_user_ids(group_ours, their_group_members)

		# from now on only names of their members are needed
		their_group_members = [mmbr['userPrincipalName'] for mmbr in their_group_members]
		removed_members_before = members_removed
		remove_their_members(group_ours, their_group_members, scope)
		added_members_before = members_added
		add_our_members(group_ours, their_group_members)
		# group counts as updated if any of the counters moved
		if members_added > added_members_before or members_removed > removed_members_before:
			groups_updated += 1
	print_stats()
