import geopandas as gpd
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed

import geopandas as gpd
import pandas as pd
from concurrent.futures import ProcessPoolExecutor
import os


# def prepare_tasks(under, over, task_que):
#     task_list = []
#     # We want to ensure minimal data is being moved around, so we prepare data slices here
#     # that will be passed directly to each worker process
#     for under_index, over_indices in task_que.items():
#         # Extract only the necessary rows from 'under' and 'over' and create lightweight copies
#         under_slice = under.iloc[[under_index]].copy()  # Copy to ensure continuity outside the loop
#         over_slice = over.iloc[over_indices].copy()  # Copy to ensure continuity outside the loop
#         task_list.append({
#             'under_data': under_slice,
#             'over_data': over_slice
#         })
#     return task_list
#
#
# def process_intersection(task, under_data, over_data, how="intersection"):
#     # Extract the relevant portions of the dataframes
#     ref_poly = under_data.iloc[[task['under_index']]]
#     comp_poly = over_data.iloc[task['over_indices']]
#     # Perform the overlay operation
#     intersection = ref_poly.overlay(comp_poly, how=how, keep_geom_type=False)
#     return intersection
#
#
# def geom_overlay(under, over, how="intersection"):
#     if under.crs != over.crs:
#         print(f"Error: CRS mismatch \nUnder CRS: {under.crs}, \nOver CRS: {over.crs}")
#         return None
#     print(f"under : {len(under)}, over : {len(over)}")
#     if len(under) > len(over):
#         under, over = over, under
#
#     print("creating spatial index queries...")
#     sindex_query_result = under.sindex.query(over["geometry"], predicate="intersects")
#     task_que = {}
#     for key, value in zip(sindex_query_result[1], sindex_query_result[0]):
#         if key in task_que:
#             task_que[key].append(value)
#         else:
#             task_que[key] = [value]
#
#     print("preparing the task...")
#     print(len(task_que))
#     task_list = prepare_tasks(under, over, task_que)
#
#     print("executing...")
#     with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
#         future_to_task = {executor.submit(process_intersection, task, how): task for task in task_list}
#         total_tasks = len(future_to_task)
#         completed_tasks = 0
#
#         for future in as_completed(future_to_task):
#             completed_tasks += 1
#             print(f"Completed {completed_tasks} of {total_tasks} tasks")
#             result = future.result()  # Collect the result from the future
#
#     resulting_intersection_geom = pd.concat([future.result() for future in future_to_task])
#     return resulting_intersection_geom


def geom_overlay(under, over, how="intersection"):
    """Overlay two polygon layers one reference feature at a time.

    A spatial-index query first pairs every feature of the reference layer
    with the features of the other layer that intersect it. ``overlay`` is
    then run once per reference feature against only its own candidates,
    and the partial results are concatenated.

    Args:
        under: Reference GeoDataFrame.
        over: Comparative GeoDataFrame. Must use the same CRS as ``under``.
        how: Overlay method forwarded to ``GeoDataFrame.overlay``.

    Returns:
        The concatenation of the per-feature overlay results, or an empty
        list when no feature of one layer intersects a feature of the other
        (kept as a list for backward compatibility with existing callers).

    Raises:
        ValueError: If the two layers do not share the same CRS.

    Note:
        Only features that take part in at least one intersecting pair are
        ever passed to ``overlay``, so features without a partner never
        appear in the output, whatever ``how`` is.
    """
    # Validate before touching anything else so a bad call fails fast and
    # the message reports the layers exactly as the caller passed them.
    if under.crs != over.crs:
        raise ValueError(
            f"CRS mismatch: under: {under.crs}, over: {over.crs}"
        )

    # Iterating over the smaller layer means fewer overlay calls. Swapping
    # the operands is only safe for methods whose geometry result does not
    # depend on operand order; "difference" and "identity" must keep it.
    symmetric_methods = ("intersection", "union", "symmetric_difference")
    if how in symmetric_methods and len(under) > len(over):
        under, over = over, under

    print("creating_spatial index...")
    # query() yields two parallel index arrays: the first refers to the
    # geometries passed in (over), the second to the indexed layer (under).
    # `.geometry` is used so a renamed active geometry column still works.
    over_hits, under_hits = under.sindex.query(
        over.geometry, predicate="intersects"
    )

    # Group candidates by reference feature:
    # {position in under: [positions in over that intersect it]}
    task_que = {}
    for under_pos, over_pos in zip(under_hits, over_hits):
        task_que.setdefault(under_pos, []).append(over_pos)

    total = len(task_que)
    intersections = []
    print("executing...")
    for done, (under_pos, candidates) in enumerate(task_que.items(), start=1):
        # A list indexer keeps the single row as a frame instead of
        # collapsing it to a Series; the frame is re-wrapped with the layer
        # CRS set explicitly before it is handed to overlay.
        ref_poly = gpd.GeoDataFrame(under.iloc[[under_pos]], crs=under.crs)
        comp_poly = over.iloc[candidates]
        intersections.append(
            ref_poly.overlay(comp_poly, how=how, keep_geom_type=False)
        )
        print(f"{done}/{total}")

    if not intersections:
        return []
    return pd.concat(intersections)