#!/usr/bin/env python3
"""
oauth2_imap 1.0 - OAuth2 token acquisition and IMAP access verification.
Copyright (C) 2026. Steven Vertigan

This program comes with ABSOLUTELY NO WARRANTY.
This is free software, and you are welcome to redistribute it
under certain conditions (GPLv3).

https://opensource.org/license/gpl-3.0

Originally based on oauth2_imap (v1.40) by Gilles Lamiral.
Original: https://imapsync.lamiral.info/

Library usage
-------------
    import oauth2_imap

    result = oauth2_imap.run(
        user="foo@example.com",
        application="thunderbird",   # or "imapsync"
        provider="gmail",            # or "office365"; auto-detected if omitted
        client_id="...",
        client_secret="...",
        # any other keyword from the CLI options list below
    )
    # result is 0 on success, 1 on failure, None if setup failed.

CLI usage
---------
    python oauth2_imap.py [options] user@example.com

Options:
    --tests              Run all tests
    --testsone           Run a single targeted test
    --debug              Enable debug output
    --startover          Ignore existing token file and start fresh
    --provider           OAuth2 provider (gmail or office365)
    --authorize_uri      Override the authorization URI
    --token_uri          Override the token URI
    --redirect_uri       Override the redirect URI
    --scope_string       Override the OAuth2 scope string
    --application        Application name (thunderbird or imapsync); default: thunderbird
    --client_id          OAuth2 client ID
    --client_secret      OAuth2 client secret
    --token_file         Path to token file
    --local              Force plain HTTP localhost redirect
    --localssl           Force HTTPS localhost redirect
    --imap_server        IMAP server hostname
    --remotebrowser      Collect code from stdin instead of localhost
"""

import argparse
import base64
import hashlib
import imaplib
import json
import os
import random
import re
import secrets
import socket
import ssl
import string
import sys
import threading
import time
import webbrowser
#import unittest
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from urllib.parse import parse_qs, urlencode, urlparse, quote

import dns.resolver
import requests


# ---------------------------------------------------------------------------
# OAuth2 provider configuration factories
# ---------------------------------------------------------------------------

def oauth2_office365():
    """
    Return the generic Office 365 OAuth2 settings as a new dict.

    Only endpoints, redirect URI, scope and IMAP host are provided here;
    this variant carries no client_id / client_secret.
    """
    return dict(
        authorize_uri="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        token_uri="https://login.microsoftonline.com/common/oauth2/v2.0/token",
        redirect_uri="https://localhost",
        scope_string="offline_access https://outlook.office.com/IMAP.AccessAsUser.All",
        imap_server="outlook.office365.com",
    )


def oauth2_office365_imapsync():
    """
    Return the Office 365 OAuth2 settings for the "imapsync" application.

    The redirect URI points at the imapsync.lamiral.info callback, and the
    dict includes that application's client_id and client_secret.
    """
    return dict(
        authorize_uri="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        token_uri="https://login.microsoftonline.com/common/oauth2/v2.0/token",
        redirect_uri="https://imapsync.lamiral.info/cgi-bin/auth",
        scope_string="offline_access https://outlook.office.com/IMAP.AccessAsUser.All",
        client_id="c46947ca-867f-48b7-9231-64213fdd765e",
        client_secret="cdX8Q~jy-ynhroJTduZJNM4WulTngWeYcCIIgdkq",
        imap_server="outlook.office365.com",
    )


def oauth2_office365_thunderbird():
    """
    Return the Office 365 OAuth2 settings for the "thunderbird" application.

    The redirect URI is https://localhost and the client_secret is the
    empty string (present, but blank).
    """
    return dict(
        authorize_uri="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        token_uri="https://login.microsoftonline.com/common/oauth2/v2.0/token",
        scope_string="offline_access https://outlook.office.com/IMAP.AccessAsUser.All",
        redirect_uri="https://localhost",
        client_id="9e5f94bc-e8a4-4e73-b8be-63364c29d753",
        client_secret="",
        imap_server="outlook.office365.com",
    )


def oauth2_gmail():
    """
    Return the generic Gmail OAuth2 settings as a new dict.

    Only the endpoints, the redirect URI and the IMAP host are provided;
    this variant has no scope_string, client_id or client_secret.
    """
    return dict(
        authorize_uri="https://accounts.google.com/o/oauth2/auth",
        token_uri="https://accounts.google.com/o/oauth2/token",
        redirect_uri="http://localhost",
        imap_server="imap.gmail.com",
    )


def oauth2_gmail_imapsync():
    """
    Return the Gmail OAuth2 settings for the "imapsync" application.

    Uses a plain http://localhost redirect URI and includes that
    application's client_id and client_secret.
    """
    return dict(
        authorize_uri="https://accounts.google.com/o/oauth2/auth",
        token_uri="https://accounts.google.com/o/oauth2/token",
        redirect_uri="http://localhost",
        scope_string="https://mail.google.com/",
        client_id="108687549524-s5ijqmadnmi4qfgfgicuquftv8f8a3da.apps.googleusercontent.com",
        client_secret="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",
        imap_server="imap.gmail.com",
    )


def oauth2_gmail_thunderbird():
    """
    Return the Gmail OAuth2 settings for the "thunderbird" application.

    Uses an https://localhost redirect URI and includes that application's
    client_id and client_secret.
    """
    return dict(
        authorize_uri="https://accounts.google.com/o/oauth2/auth",
        token_uri="https://accounts.google.com/o/oauth2/token",
        scope_string="https://mail.google.com/",
        redirect_uri="https://localhost",
        client_id="406964657835-aq8lmia8j95dhl1a2bvharmfk3t1hgqj.apps.googleusercontent.com",
        client_secret="kSmqreRr0qwBWJgbf5Y-PjSU",
        imap_server="imap.gmail.com",
    )


# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------

def random_string(length=64):
    """
    Return a random lowercase alphabetic string of the given length.

    The value is used as the PKCE code verifier and as the OAuth2 "state"
    parameter, both of which must be unpredictable, so the characters are
    drawn with the ``secrets`` module (OS-level CSPRNG) rather than the
    seedable ``random`` generator.

    A length of 0 (or less) yields the empty string.
    """
    alphabet = string.ascii_lowercase
    return "".join(secrets.choice(alphabet) for _ in range(length))


def oauth2_code_challenge(code_verifier):
    """
    Compute the PKCE S256 challenge for *code_verifier*.

    The verifier is encoded as ASCII, hashed with SHA-256, and the digest
    is returned as URL-safe base64 text with the "=" padding removed.
    """
    hashed = hashlib.sha256(code_verifier.encode("ascii"))
    encoded = base64.urlsafe_b64encode(hashed.digest()).decode("ascii")
    return encoded.rstrip("=")


def merge_hashes(dest, *sources):
    """
    Copy the entries of every source dict into dest, modifying dest in place.

    Later sources win over earlier ones. A source that is None is ignored,
    and so is any individual entry whose value is None (falsy values such
    as 0 or "" are still copied).

    Returns dest, or None when dest itself is None.
    """
    if dest is None:
        return None
    for source in sources:
        if source is not None:
            dest.update({k: v for k, v in source.items() if v is not None})
    return dest


def string_to_file(text, filepath):
    """
    Write text to filepath, creating directories as needed.

    The file holds OAuth2 tokens, so it is created with mode 0o600 from
    the start and an already-existing file is tightened to 0o600 *before*
    the new content is written. (Writing first and chmod-ing afterwards
    leaves a window in which the secrets are readable under the default
    permissions.)

    Returns text on success, or None (after printing a message) when an
    OSError occurs.
    """
    path = Path(filepath)
    # O_BINARY exists only on Windows; it stops the C runtime from doing a
    # second newline translation underneath Python's text layer.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        fd = os.open(path, flags, 0o600)
        with open(fd, "w", encoding="utf-8") as handle:
            # The creation mode only applies to new files; fix up old ones.
            os.chmod(path, 0o600)
            handle.write(text)
        return text
    except OSError as e:
        print(f"string_to_file: failure writing to {filepath}: {e}")
        return None


def file_to_lines(filepath):
    """
    Return the lines of a UTF-8 text file, line endings included.

    Returns an empty list when the path is missing, is not a regular file,
    cannot be read, or does not contain valid UTF-8 (the last two cases
    also print a message).
    """
    path = Path(filepath)
    if not path.is_file():
        return []
    try:
        return path.read_text(encoding="utf-8").splitlines(keepends=True)
    except (OSError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is a ValueError, not an OSError, so it has to
        # be named explicitly to honour the "empty list on failure" contract.
        print(f"Error reading file {filepath}: {e}")
        return []


def nthline(filepath, n):
    """
    Fetch line number n (counting from 1) of the file at filepath.

    Trailing "\\n" / "\\r" characters are removed. An empty string is
    returned when n is out of range or the file yields no lines.
    """
    content = file_to_lines(filepath)
    if 1 <= n <= len(content):
        return content[n - 1].rstrip("\n\r")
    return ""


# ---------------------------------------------------------------------------
# Provider / domain detection
# ---------------------------------------------------------------------------

def find_domain_from_email(email):
    """
    Extract the domain part of an email address. Returns None if not valid.

    The address must consist of exactly one "@" with a non-empty part on
    each side. The whole string has to match, so inputs such as "a@b@c"
    are rejected instead of silently yielding "b".
    """
    if not email:
        return None
    match = re.fullmatch(r"[^@]+@([^@]+)", email)
    return match.group(1) if match else None


def find_provider_from_domain(domain):
    """
    Guess the mail provider of a domain from its MX records.

    The MX record with the lowest preference value is inspected: a host
    name containing "outlook.com" maps to 'office365', one containing
    "google.com" maps to 'gmail'. Anything else, an empty domain, or any
    error during the lookup gives None (errors are printed).
    """
    if not domain:
        return None

    markers = (("outlook.com", "office365"), ("google.com", "gmail"))
    try:
        records = dns.resolver.resolve(domain, "MX")
        if not records:
            print(f"Can not find MX records for {domain}")
            return None
        # min() returns the first record with the smallest preference.
        preferred = min(records, key=lambda record: record.preference)
        host = str(preferred.exchange).lower()
        for marker, provider in markers:
            if marker in host:
                return provider
    except Exception as err:
        print(f"Can not find MX records for {domain}: {err}")
    return None


def find_provider_from_email(email):
    """
    Guess the OAuth2 provider for an email address.

    The domain is taken from the address and handed to the MX-based
    detection; the result of that detection is returned unchanged.
    """
    domain = find_domain_from_email(email)
    return find_provider_from_domain(domain)


# ---------------------------------------------------------------------------
# OAuth2 parameter initialisation
# ---------------------------------------------------------------------------

def oauth2_parameters(oauth2):
    """
    Fill the oauth2 dict with the defaults for its application/provider.

    Nothing happens unless both "application" and "provider" are set. An
    exact (application, provider) match is preferred; otherwise the generic
    settings of a known provider are used. Unknown providers leave the
    dict untouched. The chosen defaults are merged into oauth2 in place.
    """
    application = oauth2.get("application")
    provider = oauth2.get("provider")
    if not (application and provider):
        return

    by_application_and_provider = {
        ("thunderbird", "office365"): oauth2_office365_thunderbird,
        ("thunderbird", "gmail"):     oauth2_gmail_thunderbird,
        ("imapsync",    "office365"): oauth2_office365_imapsync,
        ("imapsync",    "gmail"):     oauth2_gmail_imapsync,
    }
    by_provider = {
        "gmail":     oauth2_gmail,
        "office365": oauth2_office365,
    }

    build = (by_application_and_provider.get((application, provider))
             or by_provider.get(provider))
    if build is not None:
        merge_hashes(oauth2, build())


def set_provider(oauth2):
    """
    Make sure oauth2['provider'] holds a provider name.

    An already-set provider is kept as is. Otherwise the provider is
    guessed from the MX record of the user's email domain and stored in
    the dict.

    Returns the provider string, or None when oauth2 is None, no user is
    set, or the provider could not be guessed.
    """
    if oauth2 is None or oauth2.get("user") is None:
        return None

    if oauth2.get("provider"):
        print(f"The provider used comes from --provider {oauth2['provider']}")
        return oauth2["provider"]

    guessed = find_provider_from_email(oauth2["user"])
    if not guessed:
        print(
            f"Could not guess the oauth2 provider from the address {oauth2['user']}\n"
            "Specify it with --provider office365 or --provider gmail\n"
            "or bring all the details with the full set of options."
        )
        return None

    print(f"Found provider {guessed} from address {oauth2['user']}")
    oauth2["provider"] = guessed
    return oauth2["provider"]


def oauth2_checklist(oauth2):
    """
    Check that every mandatory OAuth2 parameter has a value.

    A parameter counts as missing only when it is absent or None; an empty
    string is accepted. The first missing parameter is reported on stdout.

    Returns True when all parameters are present, False otherwise (also
    when oauth2 is None).
    """
    if oauth2 is None:
        return False

    needed = ("authorize_uri", "token_uri", "scope_string",
              "client_id", "client_secret", "redirect_uri")
    missing = next((name for name in needed if oauth2.get(name) is None), None)
    if missing is not None:
        print(f"Abort. I need an oauth2 parameter called {missing}")
        return False
    return True


def set_local_from_redirect_uri(oauth2):
    """
    Derive oauth2["local"] from the redirect URI when it is not set yet.

    The flag becomes True only if the redirect URI is exactly
    "http://localhost". An existing "local" key (whatever its value) and a
    None oauth2 are left alone.
    """
    if oauth2 is not None and "local" not in oauth2:
        oauth2["local"] = oauth2.get("redirect_uri") == "http://localhost"


def set_localssl_from_redirect_uri(oauth2):
    """
    Derive oauth2["localssl"] from the redirect URI when it is not set yet.

    The flag becomes True only if the redirect URI is exactly
    "https://localhost". An existing "localssl" key (whatever its value)
    and a None oauth2 are left alone.
    """
    if oauth2 is not None and "localssl" not in oauth2:
        oauth2["localssl"] = oauth2.get("redirect_uri") == "https://localhost"


def oauth2_init(oauth2):
    """
    Prepare the oauth2 dict for a new authorization round.

    Generates a 128-character PKCE code verifier with its S256 challenge
    and a 32-character state value, chooses the token file path (the given
    one, or "tokens/oauth2_tokens_<user>.txt"), and derives the "local" /
    "localssl" flags from the redirect URI if they are not set.

    Returns the updated oauth2 dict, or None when no user is set.
    """
    user = oauth2.get("user")
    if not user:
        return None

    token_path = oauth2.get("token_file") or f"tokens/oauth2_tokens_{user}.txt"
    verifier = random_string(128)
    challenge = oauth2_code_challenge(verifier)
    state_value = random_string(32)

    set_local_from_redirect_uri(oauth2)
    set_localssl_from_redirect_uri(oauth2)

    # None of these values can be None, so a plain update is enough.
    oauth2.update(
        user=user,
        token_file=token_path,
        code_verifier=verifier,
        code_challenge=challenge,
        code_challenge_method="S256",
        state=state_value,
    )
    return oauth2


# ---------------------------------------------------------------------------
# Authorization URI construction
# ---------------------------------------------------------------------------

def oauth2_authorization_code_uri(oauth2):
    """
    Assemble the authorization URL for the current oauth2 settings.

    The query string carries the client id, a fixed tenant of "common",
    the scope, the user as login hint, the redirect URI, the PKCE
    challenge and the state. The URL is also stored in
    oauth2["authorization_code_uri"] before being returned.
    """
    query = urlencode([
        ("client_id",             oauth2["client_id"]),
        ("tenant",                "common"),
        ("scope",                 oauth2["scope_string"]),
        ("login_hint",            oauth2["user"]),
        ("response_type",         "code"),
        ("redirect_uri",          oauth2["redirect_uri"]),
        ("code_challenge",        oauth2["code_challenge"]),
        ("code_challenge_method", oauth2["code_challenge_method"]),
        ("state",                 oauth2["state"]),
    ])
    oauth2["authorization_code_uri"] = "?".join((oauth2["authorize_uri"], query))
    return oauth2["authorization_code_uri"]


# ---------------------------------------------------------------------------
# Browser launch
# ---------------------------------------------------------------------------

def launch_browser_on_url(url):
    """
    Attempt to open url in the default browser (best effort).

    Failures are only reported on stdout, never raised. That includes the
    case where ``webbrowser.open`` returns False, which is how it signals
    that no browser could be launched.
    """
    print(f"Opening URL in default browser: {url}")
    try:
        opened = webbrowser.open(url)
    except Exception as e:
        print(f"Could not open browser: {e}")
        return
    if not opened:
        print("Could not open browser: no usable browser was found")


def oauth2_invite_with_browser(oauth2):
    """
    Show the authorization link and, unless "remotebrowser" is set,
    try to open it in the local default browser.
    """
    link = oauth2_authorization_code_uri(oauth2)
    print(f"Go to the following link with your web browser:\n\n{link}\n")
    if oauth2.get("remotebrowser"):
        return
    launch_browser_on_url(link)


# ---------------------------------------------------------------------------
# Local HTTP server to capture the authorization code
# ---------------------------------------------------------------------------

class _CodeCapture:
    """
    Holder for the values captured by the local HTTP request handler.

    The attributes live on the class itself, so everything that uses this
    class sees the same values.
    """
    # Both start unset; the request handler stores the received "code" and
    # "state" query parameters here.
    code = state = None


def oauth2_lunch_httpd_localhost(oauth2):
    """
    Start a plain HTTP server on a random localhost port.

    On success the server is stored in oauth2["httpd"] and
    oauth2["redirect_uri"] is replaced by the actual
    "http://localhost:<port>" address the server is bound to.

    Returns the HTTPServer instance, or None on failure (oauth2 is None,
    or the socket could not be created/bound). Callers test the result for
    truthiness, so failures must not escape as exceptions.
    """
    if oauth2 is None:
        return None
    try:
        # Port 0 lets the operating system pick a free port.
        server = HTTPServer(("localhost", 0), _make_handler(oauth2, _CodeCapture))
    except OSError as e:
        print(f"Failed to start httpd: {e}")
        return None
    server.timeout = 120
    port = server.server_address[1]
    url  = f"http://localhost:{port}"
    print(f"Now waiting for the code, 120 sec max, at {url}")
    oauth2["httpd"]        = server
    oauth2["redirect_uri"] = url
    return server


def oauth2_lunch_httpd_localhost_ssl(oauth2):
    """
    Start an HTTPS server on a random localhost port using the files
    localhost.crt / localhost.key from the current directory.

    On success the server is stored in oauth2["httpd"] and
    oauth2["redirect_uri"] is replaced by the actual
    "https://localhost:<port>" address.

    Returns the HTTPServer instance, or None on failure. When the failure
    happens after the listening socket was created (typically because the
    certificate files are missing), the socket is closed again so it does
    not stay bound while the caller falls back to another method.
    """
    if oauth2 is None:
        return None
    server = None
    try:
        # Port 0 lets the operating system pick a free port.
        server = HTTPServer(("localhost", 0), _make_handler(oauth2, _CodeCapture))
        server.timeout = 120
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(certfile="localhost.crt", keyfile="localhost.key")
        server.socket = ctx.wrap_socket(server.socket, server_side=True)
        port = server.server_address[1]
        url  = f"https://localhost:{port}"
        print(f"Now waiting for the code, 120 sec max, at {url}")
        oauth2["httpd"]        = server
        oauth2["redirect_uri"] = url
        return server
    except Exception as e:
        print(f"Failed to start SSL httpd: {e}")
        if server is not None:
            server.server_close()
        return None


def _make_handler(oauth2, capture):
    """
    Build a request handler class bound to the given oauth2 dict and
    capture object.

    A GET request carrying a non-empty "code" query parameter stores the
    code and the "state" parameter both on capture and in oauth2 (keys
    "code" and "state_back"), then answers 200 with a plain-text
    confirmation. Any other GET request is answered with 403.
    """

    class _Handler(BaseHTTPRequestHandler):
        def log_message(self, fmt, *args):
            """Drop the default per-request access log output."""

        def do_GET(self):
            query = parse_qs(urlparse(self.path).query)
            received_code = query.get("code", [None])[0]
            received_state = query.get("state", [None])[0]

            if not received_code:
                self.send_response(403)
                self.end_headers()
                self.wfile.write(b"No code received.")
                return

            capture.code = received_code
            capture.state = received_state
            oauth2["code"] = received_code
            oauth2["state_back"] = received_state
            body = oauth2_collect_code_answer(received_code).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "text/plain")
            self.end_headers()
            self.wfile.write(body)

    return _Handler


def oauth2_collect_code_answer(code):
    """
    Build the plain-text page shown in the browser once the code has been
    received. The code is included so it can be copied by hand; every
    paragraph is followed by a blank line.
    """
    paragraphs = [
        "The authentication is ok, now you can go back where you started",
        "If it does not work, here is the code to copy and paste:",
        f"{code}",
        "Have a nice day!",
    ]
    return "".join(paragraph + "\n\n" for paragraph in paragraphs)


def oauth2_collect_code_localhost(oauth2):
    """
    Serve a single request on the local httpd stored in oauth2["httpd"]
    and return the "code" query parameter it carried.

    The wait is bounded by the server's own timeout. Returns '' when there
    is no server or no code was received.
    """
    httpd = oauth2.get("httpd")
    if not httpd:
        return ""

    # The capture holder is the shared _CodeCapture class; clear any code
    # left over from a previous attempt, then install a handler bound to it.
    _CodeCapture.code = None
    httpd.RequestHandlerClass = _make_handler(oauth2, _CodeCapture)

    # Blocks until one request was handled or the server timeout expired.
    httpd.handle_request()

    return _CodeCapture.code if _CodeCapture.code else ""


def oauth2_collect_code_localhost_insist(oauth2):
    """
    Make up to four attempts at collecting the code from the local httpd.

    The first code obtained is stored in oauth2["code"] and returned;
    '' is returned when all four attempts come back empty.
    """
    ordinals = ["first", "second", "third", "fourth"]
    for ordinal in ordinals:
        collected = oauth2_collect_code_localhost(oauth2)
        if not collected:
            continue
        print(f"\n\nSuccess collecting the code the {ordinal} time")
        oauth2["code"] = collected
        return collected

    print("\n\nFailed to collect the code on localhost.")
    return ""


def oauth2_collect_code_from_stdin(oauth2):
    """
    Prompt on the terminal for the authorization code.

    The pasted text, stripped of surrounding whitespace, is stored in
    oauth2["code"] and returned.
    """
    print("\n\nYou can find the code in the URL at the end of the authentication process.")
    pasted = input("\n\nPaste the code here and press ENTER: ")
    oauth2["code"] = pasted.strip()
    return oauth2["code"]


def oauth2_collect_code_remote(oauth2):
    """
    Fetch the code from the imapsync remote callback endpoint.

    The endpoint URL is built from oauth2["state"]. The server certificate
    is verified (the requests default): the response is an authorization
    code, so it must not be accepted from an unauthenticated peer.

    Returns the response body stripped of surrounding whitespace, or None
    when the request fails or the server answers with an error status.
    """
    code_url = f"https://imapsync.lamiral.info/imapsync_auth/{oauth2['state']}"
    try:
        resp = requests.get(code_url, timeout=10)
        if resp.ok:
            code = resp.text.strip()
            print(f"code from {code_url} = {code}")
            return code
        else:
            print(f"No code from {code_url}: {resp.status_code} {resp.reason}")
    except Exception as e:
        print(f"No code from {code_url}: {e}")
    return None


# ---------------------------------------------------------------------------
# Ask for authorization code (top-level dispatcher)
# ---------------------------------------------------------------------------

def oauth2_ask_authorization_code_localhost(oauth2):
    """
    Run the localhost variant of the authorization step.

    The user is invited to open the authorization link. With
    "remotebrowser" set, the code is then read from the terminal;
    otherwise it is collected from the local httpd, falling back to the
    terminal if that fails. Returns the collected code.
    """
    oauth2_invite_with_browser(oauth2)

    if oauth2.get("remotebrowser"):
        print("\n\nLet's collect the code from the terminal")
        return oauth2_collect_code_from_stdin(oauth2)

    print("\n\nNow I try to collect the code. Will give up in 120 seconds (maybe 4x).")
    captured = oauth2_collect_code_localhost_insist(oauth2)
    if not captured:
        print("\n\nFailed to collect the code on localhost.")
        return oauth2_collect_code_from_stdin(oauth2)

    print("\nGOOD!")
    return captured


def oauth2_ask_authorization_code_remote(oauth2):
    """
    Run the remote-callback variant of the authorization step.

    After the user confirms that the browser authentication is finished,
    the code is fetched from the remote endpoint and stored in
    oauth2["code"]. If that yields nothing, the code is asked for on the
    terminal instead. Returns the collected code.
    """
    oauth2_invite_with_browser(oauth2)
    input("\n\nThen, after the authentication is finished, press ENTER: ")
    print("\n\nThanks. Now I try to collect the code.")

    fetched = oauth2_collect_code_remote(oauth2)
    if not fetched:
        print("\n\nFailed to collect the code remotely")
        return oauth2_collect_code_from_stdin(oauth2)

    print("\n\nSuccess collecting the code")
    oauth2["code"] = fetched
    return fetched


def oauth2_ask_authorization_code(oauth2):
    """
    Obtain the authorization code using the method selected in oauth2.

    Selection order:
      1. "local" set     -> plain HTTP httpd on localhost
      2. "localssl" set  -> HTTPS httpd on localhost; if it cannot be
                            started, a plain HTTP httpd is tried instead
      3. neither         -> remote imapsync callback, then stdin

    Starting a local httpd rewrites oauth2["redirect_uri"] to the address
    it actually listens on, so after the SSL -> HTTP fallback the request
    uses an http://localhost:PORT redirect URI.
    NOTE(review): the original author states that Microsoft's identity
    platform accepts such localhost redirects regardless of scheme or port
    (citing RFC 8252) -- confirm against the provider documentation.

    Returns the code, or None when no local httpd could be started.
    """
    wants_plain_http = oauth2.get("local")
    if not wants_plain_http and not oauth2.get("localssl"):
        return oauth2_ask_authorization_code_remote(oauth2)

    if wants_plain_http:
        if not oauth2_lunch_httpd_localhost(oauth2):
            print("Failed to launch localhost httpd")
            return None
        print("Launched plain HTTP httpd")
        return oauth2_ask_authorization_code_localhost(oauth2)

    if oauth2_lunch_httpd_localhost_ssl(oauth2):
        print("Launched SSL httpd")
        return oauth2_ask_authorization_code_localhost(oauth2)

    # The HTTPS server could not be started; retry with plain HTTP.
    print("Falling back to plain HTTP localhost httpd")
    if not oauth2_lunch_httpd_localhost(oauth2):
        print("Failed to launch localhost httpd")
        return None
    print("Launched plain HTTP httpd (fallback)")
    return oauth2_ask_authorization_code_localhost(oauth2)


# ---------------------------------------------------------------------------
# Token request / refresh
# ---------------------------------------------------------------------------

def oauth2_request_tokens(oauth2):
    """
    Exchange the authorization code for access + refresh tokens.

    Posts the code, client credentials, redirect URI and PKCE verifier to
    oauth2["token_uri"]. The server certificate is verified (the requests
    default), because the request carries the client secret and the
    response carries the tokens.

    Tokens found in the response are stored in oauth2["access_token"] /
    oauth2["refresh_token"] and printed. Any failure (missing parameter,
    network error, non-JSON or non-object response) is printed and leaves
    the dict unchanged. Always returns None.
    """
    print("Exchanging the code for an access token and a refresh token...")
    try:
        resp = requests.post(
            oauth2["token_uri"],
            data={
                "code":          oauth2["code"],
                "client_id":     oauth2["client_id"],
                "client_secret": oauth2["client_secret"],
                "redirect_uri":  oauth2["redirect_uri"],
                "grant_type":    "authorization_code",
                "code_verifier": oauth2["code_verifier"],
            },
            timeout=11,
        )
        print(f"token_response: {resp.text}\n")
        data = resp.json()
    except Exception as e:
        print(f"Token request failed: {e}")
        return

    # Valid JSON is not necessarily an object; .get() below needs a dict.
    if not isinstance(data, dict):
        print("Token request failed: unexpected response format")
        return

    access_token  = data.get("access_token")
    refresh_token = data.get("refresh_token")

    if access_token:
        print(f"access token:\n{access_token}\n")
        oauth2["access_token"] = access_token
    if refresh_token:
        print(f"refresh token:\n{refresh_token}\n")
        oauth2["refresh_token"] = refresh_token
    else:
        print("No refresh token proposed\n")


def oauth2_refresh_tokens(oauth2):
    """
    Use the refresh token to obtain a new access (and possibly refresh) token.

    Posts the client credentials and oauth2["refresh_token"] to
    oauth2["token_uri"]. The server certificate is verified (the requests
    default), because the request carries the client secret and the
    refresh token.

    Tokens found in the response are stored in oauth2 and printed.

    Returns the new access token, or None when the request fails or the
    response contains no access token.
    """
    print("Refreshing the access and refresh tokens")
    try:
        resp = requests.post(
            oauth2["token_uri"],
            data={
                "client_id":     oauth2["client_id"],
                "client_secret": oauth2["client_secret"],
                "refresh_token": oauth2["refresh_token"],
                "grant_type":    "refresh_token",
            },
            timeout=11,
        )
        data = resp.json()
    except Exception as e:
        print(f"Token refresh failed: {e}")
        return None

    # Valid JSON is not necessarily an object; .get() below needs a dict.
    if not isinstance(data, dict):
        print("Token refresh failed: unexpected response format")
        return None

    access_token  = data.get("access_token")
    refresh_token = data.get("refresh_token")

    if access_token:
        print(f"access token:\n{access_token}\n")
        oauth2["access_token"] = access_token
    if refresh_token:
        print(f"refresh token:\n{refresh_token}\n")
        oauth2["refresh_token"] = refresh_token
    else:
        print("No refresh token proposed\n")

    return access_token


# ---------------------------------------------------------------------------
# Token file persistence
# ---------------------------------------------------------------------------

def oauth2_load_tokens_from_file(oauth2):
    """
    Read the stored tokens from oauth2["token_file"].

    Line 1 of the file is the access token and line 2 the refresh token;
    an empty or missing line is stored as None. With "startover" set the
    file is not touched at all.

    Returns the access token, or None if there is none.
    """
    if oauth2.get("startover"):
        return None

    path = oauth2["token_file"]
    print(f"Reading tokens from file {path}, if any\n")
    first_line, second_line = (nthline(path, number) for number in (1, 2))
    oauth2.update(
        access_token=first_line or None,
        refresh_token=second_line or None,
    )
    return oauth2["access_token"]


def oauth2_save_tokens_to_file(oauth2):
    """
    Write access + refresh tokens to file. Returns True on success.

    The access token goes on line 1 and the refresh token on line 2,
    followed by a few "#" comment lines. A token that is missing *or set
    to None* is written as an empty line: the token loader stores None for
    absent tokens, and a None entry would make the join below raise
    TypeError.
    """
    token_file = oauth2["token_file"]
    lines = "\n".join([
        oauth2.get("access_token")  or "",
        oauth2.get("refresh_token") or "",
        "# The first   line is the access  token",
        "# The second  line is the refresh token",
        f"# Account is {oauth2['user']}",
        f"# File generated on {time.ctime()} by oauth2_imap.py",
        "",
    ])
    print(f"Writing tokens to the file {token_file}")
    return bool(string_to_file(lines, token_file))


# ---------------------------------------------------------------------------
# IMAP access check
# ---------------------------------------------------------------------------

def oauth2_check_imap_access(oauth2):
    """
    Authenticate to the IMAP server using XOAUTH2 and list folders.

    Connects to oauth2["imap_server"] on port 993 over SSL, authenticates
    as oauth2["user"] with oauth2["access_token"], and lists the folders.
    The connection is always logged out again, whatever the outcome.

    Returns True on success, None on failure (missing user or token,
    connection error, rejected authentication, listing error, or no
    folders found).
    """
    user         = oauth2.get("user")
    access_token = oauth2.get("access_token")
    imap_server  = oauth2.get("imap_server")

    if not user:
        print("No user given. It is useless to try an authentication, is not it?")
        return None
    if not access_token:
        print("No access token given. It is useless to try an authentication, is not it?")
        return None

    # imaplib.authenticate() base64-encodes whatever the callback returns
    # before putting it on the wire, so the callback must return the raw
    # (un-encoded) auth string as bytes.  Returning a pre-encoded string
    # would cause double-encoding and a BAD Command Argument Error.
    auth_bytes = f"user={user}\x01auth=Bearer {access_token}\x01\x01".encode("ascii")

    try:
        imap = imaplib.IMAP4_SSL(imap_server, 993)
    except Exception as e:
        print(f"Can't connect to imap server {imap_server}: {e}\n")
        return None

    try:
        try:
            imap.authenticate("XOAUTH2", lambda _challenge: auth_bytes)
        except imaplib.IMAP4.error as e:
            print(f"Auth error: {e}\n")
            return None

        try:
            status, folder_data = imap.list()
        except (imaplib.IMAP4.error, OSError) as e:
            print(f"Can't list the folders on {imap_server}: {e}\n")
            return None

        # imaplib reports "no folders" as [None] and may return tuples for
        # literal responses, so normalise every entry to text first.
        folders = [
            f.decode("utf-8", errors="replace") if isinstance(f, bytes) else str(f)
            for f in (folder_data or [])
            if f is not None
        ]
        if status == "OK" and folders:
            print(f"Found {len(folders)} folders: {', '.join(folders)}\n")
            print(f"Success IMAP login to account {user} with access token in {oauth2.get('token_file')}")
            return True

        print("Found no folders. Bad sign.")
        return None
    finally:
        # Best-effort cleanup: a failing LOGOUT must not hide the result.
        try:
            imap.logout()
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Main approval flow
# ---------------------------------------------------------------------------

def oauth2_approval(oauth2):
    """
    Drive the whole token acquisition flow.

    Stored tokens are loaded and, if an access token exists, tried against
    the IMAP server. New tokens are then obtained in every case: via the
    refresh token when there is one, otherwise (or when the refresh fails)
    via a fresh authorization code. The new access token is checked
    against the IMAP server and the tokens are saved when that succeeds.

    Returns 0 when the final IMAP check succeeds, 1 otherwise.
    """
    stored_token = oauth2_load_tokens_from_file(oauth2)
    if stored_token and oauth2_check_imap_access(oauth2):
        print("Access token is ok but let's get a new one anyway.")

    refreshed = False
    if oauth2.get("refresh_token"):
        print("Found a refresh token. Refreshing the access token with it, and maybe the refresh token.")
        refreshed = bool(oauth2_refresh_tokens(oauth2))

    if not refreshed:
        oauth2_ask_authorization_code(oauth2)
        oauth2_request_tokens(oauth2)

    if not oauth2_check_imap_access(oauth2):
        print("Failure")
        return 1

    oauth2_save_tokens_to_file(oauth2)
    return 0


# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------

def parse_args(argv):
    """
    Parse the command-line arguments in argv and return the namespace.

    There is one optional positional argument (the user's email address).
    Every option is either an on/off switch defaulting to False or takes a
    value defaulting to None, except --application, which defaults to
    "thunderbird".
    """
    parser = argparse.ArgumentParser(
        description="OAuth2 token acquisition and IMAP access verification tool."
    )
    parser.add_argument("user", nargs="?", help="Email address (e.g. foo@example.com)")

    switches = {"tests", "testsone", "debug", "startover",
                "local", "localssl", "remotebrowser"}
    # Declaration order is kept because it determines the --help layout.
    option_names = (
        "tests", "testsone", "debug", "startover",
        "provider", "authorize_uri", "token_uri", "redirect_uri",
        "scope_string", "application", "client_id", "client_secret",
        "token_file", "local", "localssl", "imap_server", "remotebrowser",
    )
    for option in option_names:
        flag = f"--{option}"
        if option in switches:
            parser.add_argument(flag, action="store_true", default=False)
        elif option == "application":
            parser.add_argument(flag, default="thunderbird")
        else:
            parser.add_argument(flag, default=None)

    return parser.parse_args(argv)



# ---------------------------------------------------------------------------
# Public library entry point
# ---------------------------------------------------------------------------

def run(
    user,
    *,
    application="thunderbird",
    provider=None,
    authorize_uri=None,
    token_uri=None,
    redirect_uri=None,
    scope_string=None,
    client_id=None,
    client_secret=None,
    token_file=None,
    imap_server=None,
    local=None,
    localssl=None,
    remotebrowser=False,
    startover=False,
    debug=False,
    tests=False,
    testsone=False,
):
    """
    Library entry point: obtain OAuth2 tokens for *user* and check IMAP access.

    Each keyword mirrors the CLI flag of the same name.  Everything after
    *user* is keyword-only.

    Parameters
    ----------
    user : str
        Email address of the account, e.g. ``"foo@example.com"``.  A falsy
        value prints a usage line and returns ``None``.
    application : str, optional
        Stored as ``oauth2["application"]`` before ``oauth2_parameters()``
        runs.  ``"thunderbird"`` (default) or ``"imapsync"``.
    provider : str, optional
        Stored as ``oauth2["provider"]`` and handed to ``set_provider()``
        (``"gmail"`` or ``"office365"``).
    authorize_uri, token_uri, redirect_uri, scope_string : str, optional
        Overrides for the corresponding configuration entries.
    client_id, client_secret : str, optional
        OAuth2 client credentials overrides.
    token_file : str, optional
        Path of the token file override.
    imap_server : str, optional
        IMAP server hostname override.
    local, localssl : bool, optional
        Localhost redirect switches (plain HTTP / HTTPS).
    remotebrowser : bool, optional
        Switch asking for the code to be collected from stdin.
    startover : bool, optional
        Switch asking to ignore an existing token file.
    debug : bool, optional
        Pretty-print the ``oauth2`` dict once initialization has succeeded.
    tests, testsone : bool, optional
        Accepted so the CLI can forward its flags; this function does not
        read them.

    Returns
    -------
    int or None
        ``None`` when *user* is falsy or when ``set_provider()``,
        ``oauth2_checklist()`` or ``oauth2_init()`` reports failure;
        otherwise the value returned by ``oauth2_approval()`` (``0`` on
        success, ``1`` on failure).

    Notes
    -----
    String-like overrides are merged whenever they are not ``None``.
    Switches are merged only when truthy, so an unset or ``False`` switch
    never overwrites a value already in the configuration.
    """
    if not user:
        print("\nusage: oauth2_imap.run(user='foo@example.com', ...)\n")
        return None

    # Seed the state dict, then let the provider setup fill in its defaults.
    oauth2 = {"user": user, "provider": provider}
    if not set_provider(oauth2):
        return None

    oauth2["application"] = application
    oauth2_parameters(oauth2)

    # Value-carrying options: forwarded unless the caller left them at None.
    valued_options = (
        ("authorize_uri", authorize_uri),
        ("token_uri",     token_uri),
        ("redirect_uri",  redirect_uri),
        ("scope_string",  scope_string),
        ("client_id",     client_id),
        ("client_secret", client_secret),
        ("token_file",    token_file),
        ("imap_server",   imap_server),
    )
    # Switches: forwarded only when truthy.  argparse hands over False for an
    # unset flag, and that False must not clobber the populated defaults.
    switches = (
        ("local",         local),
        ("localssl",      localssl),
        ("remotebrowser", remotebrowser),
        ("startover",     startover),
    )

    caller_overrides = {}
    for key, value in valued_options:
        if value is not None:
            caller_overrides[key] = value
    for key, value in switches:
        if value:
            caller_overrides[key] = value
    merge_hashes(oauth2, caller_overrides)

    # Short-circuit keeps the original order: init runs only if the
    # checklist passed.
    if not (oauth2_checklist(oauth2) and oauth2_init(oauth2)):
        return None

    if debug:
        import pprint
        print("run(), oauth2:")
        pprint.pprint(oauth2)

    return oauth2_approval(oauth2)


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------

def main(argv=None):
    """
    Parse *argv* (defaults to ``sys.argv[1:]``) and delegate to :func:`run`.

    This keeps the CLI as a thin shim over the library function so that both
    paths share identical logic.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments without the program name.

    Returns
    -------
    int or None
        The value returned by :func:`run`, passed through unchanged so the
        caller can use it as the process exit status.  ``None`` when only
        the usage line was printed.
    """
    if argv is None:
        argv = sys.argv[1:]

    args = parse_args(argv)

    # An address is required unless one of the test flags was given.
    if not (args.tests or args.testsone or args.user):
        print(f"\nusage: {sys.argv[0]} foo@example.com\n")
        return None

    # Return run()'s status instead of dropping it; otherwise the process
    # cannot report an authentication failure.
    return run(
        user=args.user or "",
        application=args.application,
        provider=args.provider,
        authorize_uri=args.authorize_uri,
        token_uri=args.token_uri,
        redirect_uri=args.redirect_uri,
        scope_string=args.scope_string,
        client_id=args.client_id,
        client_secret=args.client_secret,
        token_file=args.token_file,
        imap_server=args.imap_server,
        local=args.local,
        localssl=args.localssl,
        remotebrowser=args.remotebrowser,
        startover=args.startover,
        debug=args.debug,
        tests=args.tests,
        testsone=args.testsone,
    )


if __name__ == "__main__":
    # sys.exit(None) exits with status 0, so the usage and setup-failure
    # paths behave as before.  Only an explicit 1 from run() makes the
    # process exit non-zero.
    sys.exit(main())
