#!/usr/bin/env python

import os
import sys

from psycopg2 import connect as pg_connect
from psycopg2 import sql
from psycopg2.extras import execute_batch
from pyathena import connect as athena_connect

# Connection settings, read from environment variables (values come from your
# connection sheet).
athena_schema = os.getenv('ATHENA_SCHEMA')
work_group = os.getenv('WORK_GROUP')
aws_access_key_id = os.getenv('ACCESS_KEY')
aws_secret_access_key = os.getenv('SECRET_KEY')
# Not validated below: when unset, None is passed to pg_connect as `database`.
database_name = os.getenv('DATABASE_NAME')

# PostgreSQL schema that receives the copied tables.
pg_schema = 'public'


# Abort when a required setting is missing. sys.exit() with a string argument
# prints it to stderr and exits with status 1, so callers and shells see the
# failure (a bare sys.exit() would report success with status 0).
if not athena_schema:
    sys.exit("Schema name not found. Please refer to your connection sheet.")

if not work_group:
    sys.exit("Work Group not found.  Please refer to your connection sheet.")

if not aws_access_key_id:
    sys.exit("Access key not found.  Please refer to your connection sheet.")

if not aws_secret_access_key:
    sys.exit("Secret key not found.  Please refer to your connection sheet.")


def get_new_cursor(region_name='us-west-2'):
    """Open a new Athena connection and return a cursor on it.

    Every call creates a separate connection. It uses the schema, work group
    and AWS credentials read from the environment when the script starts.

    Args:
        region_name: AWS region that hosts the Athena work group.
            Defaults to 'us-west-2'.

    Returns:
        A cursor on the newly created pyathena connection.
    """
    connection = athena_connect(
        region_name=region_name,
        schema_name=athena_schema,
        work_group=work_group,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
    return connection.cursor()


# Cursor over the names of the tables in the Athena schema. The schema name is
# passed as a pyformat parameter rather than interpolated into the SQL text.
schema_cursor = get_new_cursor()
schema_cursor.execute(
    "SELECT table_name FROM information_schema.tables WHERE table_schema = %(schema)s",
    {'schema': athena_schema},
)

# A second, independent Athena cursor. schema_cursor is still being iterated
# while each table's SELECT runs, so the per-table queries need their own cursor.
table_cursor = get_new_cursor()

# Local PostgreSQL target. Adjust these settings to match your installation:
#   user     - 'postgres' is the default; it may be your own username instead.
#   password - empty by default; add password='<password>' if you enforce one.
#   port     - 5432 by default; add port='5432' (or another value) to override.
conn = pg_connect(
    database=database_name,
    host='localhost',
    user='postgres',
)

# Athena result-column type names that PostgreSQL would reject or
# misinterpret, mapped to the closest PostgreSQL type. Types not listed here
# (varchar, integer, bigint, boolean, date, timestamp, decimal, ...) are
# already valid PostgreSQL type names and are used unchanged.
_ATHENA_TO_PG_TYPES = {
    'tinyint': 'smallint',          # PostgreSQL has no 1-byte integer type
    'double': 'double precision',   # bare "double" is not a PostgreSQL type
    'string': 'text',
    'char': 'text',                 # bare "char" means char(1) in PostgreSQL
    'varbinary': 'bytea',
    # NOTE(review): complex types are stored as text. This assumes pyathena
    # hands their values back as strings; confirm for your pyathena version.
    'array': 'text',
    'map': 'text',
    'row': 'text',
}


def _pg_column_type(athena_type):
    """Return the PostgreSQL column type to use for an Athena column type."""
    type_name = str(athena_type)
    return _ATHENA_TO_PG_TYPES.get(type_name.lower(), type_name)


def _athena_quote(identifier):
    """Double-quote an identifier for an Athena query, doubling embedded quotes."""
    return '"' + identifier.replace('"', '""') + '"'


try:
    # `with conn` runs the whole copy as one transaction: it is committed on
    # success and rolled back if any statement raises. It does NOT close the
    # connection; the `finally` below does that.
    with conn, conn.cursor() as pg_cursor:
        for schema_row in schema_cursor:
            table_name = schema_row[0]
            label = f'{pg_schema}.{table_name}'
            # Quote identifiers so that reserved words (e.g. a column named
            # "order") and unusual names are accepted by PostgreSQL.
            target = sql.SQL('{}.{}').format(
                sql.Identifier(pg_schema), sql.Identifier(table_name)
            )

            table_cursor.execute(f"SELECT * FROM {_athena_quote(table_name)}")
            # col[0] / col[1] of each entry are used as column name and type.
            description = table_cursor.description

            print(f'Dropping {label}...')
            pg_cursor.execute(sql.SQL("DROP TABLE IF EXISTS {};").format(target))

            columns_ddl = sql.SQL(',\n').join([
                sql.SQL('  {} {}').format(
                    sql.Identifier(col[0]), sql.SQL(_pg_column_type(col[1]))
                )
                for col in description
            ])
            print(f'Creating {label}...')
            pg_cursor.execute(
                sql.SQL("CREATE TABLE {} (\n{}\n);").format(target, columns_ddl)
            )

            insert_rows = sql.SQL("insert into {} ({}) \nvalues ({})").format(
                target,
                sql.SQL(',').join([sql.Identifier(col[0]) for col in description]),
                sql.SQL(',').join([sql.Placeholder()] * len(description)),
            )
            print(f'Dumping data into {label}...')
            # execute_batch sends the rows in pages instead of issuing one
            # statement per row as cursor.executemany does.
            execute_batch(pg_cursor, insert_rows.as_string(pg_cursor), table_cursor)

            print()
finally:
    conn.close()

print("Done importing data!")
