Should the CTA Shed Some Wait? ArcPy-Based Code


This page contains all the Python modules used in my ArcPy-based network dataset creation code. Module List: arcpy_config, code_playground, create_network_dataset_oop, general_tools, network_types, tools_for_place. You can click on the name of a module to expand it and look at the code. The readme included contains most of the relevant information about how to understand and play around with my code.

Network Datasets with ArcPy

by Aaron Rumph | Read Me Last Updated: 11/17/25

About

Easily Create Network Datasets with Open Data!

This repository contains a few modules that allow for creating network datasets automatically with ArcPy using OpenStreetMap data (courtesy of osmnx). At the moment, you can create network datasets for any place in the world if “place_only” is specified for geographic scope (see How to Use); the code was designed with cities in mind, but it can also handle counties or any place that Nominatim will recognize, so possibly neighborhoods or zip codes. If you would like, you can also create a network dataset for the entire county, MSA, or CSA a place is in, provided it is in the U.S. (see How to Use). At the moment, the only network types supported are “walk” and “transit” (see How to Use), with both using pedestrian network data from OSM. Currently, if transit is selected, the code will automatically query the TransitLand API for transit data for the place and use the GTFS data it finds to create the network dataset. I am working on a “bring your own” feature so that you can use your own transit data to create a transit network dataset; for the moment, though, be warned that 1) the transit data will automatically be found and used, and 2) only the original input place is used to find the transit data (see How to Use and Current Known Issues). So if you give it a place like “San Francisco” it will work, but “Kensington, CA, USA” will likely turn up no results for transit data. As such, I recommend using the largest place possible: while the code can generate a network dataset for the CSA that Kensington (Bay Area) is in, it will not be able to find the transit data for Kensington, so I recommend using San Francisco as your input place instead. It’s the same CSA either way, but this way you’ll get transit data for every agency that serves SF.

Also, fair warning, while I have tried to optimize my code as much as possible, it may still take a while to run. There’s only so much I can do. Though I’d like to think I’m not horrible at programming, I’m sure I’ve done some stupid things. ArcPy tends to be pretty slow (so whenever possible, I’ve tried to avoid using it). Also, the sheer amount of data used in creating a network dataset (particularly if you’re using elevation data or if you’re doing CSAs or MSAs) means it’s just going to take a while no matter what.

The goal of this readme is to make my code accessible to others, primarily planning students/planners. As such, not much technical knowledge (at least when it comes to Python, ArcGIS on the other hand…) is presumed. If you have lots of experience working with Python (or any other programming language), you may want to skip the IDE Set Up section.

To give a sense, here are some (pretty accurate) estimated runtimes for various scenarios based on my own experiments:

  • Berkeley, CA city limits, elevation-enabled walking network dataset (10,000 nodes, 30,000 edges): 2.5 minutes
  • San Francisco, CA city limits, elevation-enabled walking network dataset (80,000 nodes, 250,000 edges): 20 minutes
  • San Francisco, CA city limits, elevation-enabled transit network dataset (80,000 nodes, 250,000 edges): 45 minutes
  • San Francisco, CA CSA, no-elevation transit network dataset (1,000,000 nodes, 4,000,000 edges): 1 hour
  • Chicago, IL city limits, no-elevation walking network dataset (300,000 nodes, 1,000,000 edges): 5.5 minutes

Generate Isochrones from Addresses or Coordinates!

Additionally, this repository allows you to create isochrones from addresses or coordinates using network datasets created as described above (see How to Use). I have made some attempts to make this as user-friendly as possible, so I would say go with the default parameters for the most part, but there is a lot to customize if you would like! Also, I mostly added this feature as a way of showing off my code to others without having to open ArcGIS (I have, only half-jokingly, said that my goal is to never have to open ArcGIS and do anything manually again), so feel free to use it as you see fit! I have made some editorial decisions around the default parameters, like choosing what I think is the best color scheme for isochrones, what the default cut-offs should be, etc., but go wild!

Set Up

Most Important Bit

Before you can run any of my code, you need three things:

1) A working Python IDE you can actually run the code in.

Any IDE will do just fine, but I am partial to PyCharm. You could run the code straight from the command line, but I would not recommend it, because it is harder to browse the documentation for my code that way. Also, theoretically, you could run it straight in the Python window in ArcGIS Pro, but I am not sure how that would work with the code split up into modules.

2) ArcGIS Pro and the associated Python environment that Esri provides (Most Important of the Three!).

Because ArcGIS is closed source, ArcPy is also closed source, so when you try to use it, Esri will check that you have a valid license and access to the necessary extensions (in this case, just Network Analyst). To do this, they make you use the environment they provide (or a clone of it) to do anything with ArcPy. Now, because running this code involves using packages that are not included in the standard arcgispro-py3 environment, you will need to clone the environment that ArcGIS Pro provides. To do this, you can either go to options in ArcGIS Pro proper, and clone it there, or you can use the Python Command Prompt that comes with ArcGIS (can just search “Python Command Prompt” in start menu) and enter:

    conda create --name arcgispro-py3-nd-clone --clone arcgispro-py3
If you are planning to use the command line to run the code, you will need to activate the environment before running the code. To do this, enter:
    conda activate arcgispro-py3-nd-clone

Also, you will need to install the required packages for this project that don’t come already installed in the clone (see Requirements.txt). Make a note of the path to your environment as you’ll need it to set up the interpreter in the IDE.
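For reference, the install step from the Python Command Prompt might look something like this (this assumes the requirements file sits at the repository root under the name Requirements.txt, as referenced above; adjust the name/path to match your checkout):

```shell
# activate the cloned environment first
conda activate arcgispro-py3-nd-clone
# then install the extra packages into the clone
pip install -r Requirements.txt
```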

If you are using an IDE to run the code:

You will need to configure the IDE to use the cloned environment. This will vary depending on the IDE you are using, but it is pretty straightforward in most cases (if not, use Google/Bing/your IDE's documentation/AskJeeves), or see below.

If you do not know how to do this and are using an IDE for the first time, see the IDE Set Up section below.

Now, you should be set up to use ArcPy!

3) API Keys

You will need API keys for exactly two things. First, if you would like to use the transit network dataset feature, you will need an API key for the TransitLand API (you can get one for free here). Second, no matter what, you will need an API key for the Census Bureau API from here (technically not if you only ever use place_only for geographic scope, but trust me, you want one, and it's free).

Once you have those two API keys, make a .env file (I would recommend calling it api_keys.env and putting it in the same directory as the rest of the code, because that is what the code is already set up for, but you can configure it however you normally configure your .env files) and add the following lines:

TRANSITLAND_API_KEY=your_transitland_api_key_here
CENSUS_BUREAU_API_KEY=your_census_bureau_api_key_here

Note: if you are new to .env files, make sure that there are no spaces around the equals sign (yes, this did trip me up; how could you tell?).

Finally, for geocoding with Nominatim's API, you will need to use a user agent (see Nominatim's policy here). You don't need to sign up for anything; just pick a name that you feel best describes your use case and use that. Then, you must add it to your api_keys.env (or equivalent) file with:

NOMINATIM_USER_AGENT=your_user_agent_here

Also note that if you use spaces, you should use quotes around your user agent. For example:

NOMINATIM_USER_AGENT="Your User Agent Here"
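To make the KEY=VALUE format concrete (including the quoting rule above), here is a minimal, purely illustrative sketch of how a .env line gets parsed; the real work is done by the python-dotenv library, not this code:

```python
def parse_env_line(line: str):
    """Parse one KEY=VALUE line from a .env file; returns (key, value) or None."""
    line = line.strip()
    if not line or line.startswith("#"):
        return None  # skip blank lines and comments
    key, sep, value = line.partition("=")
    if not sep:
        return None  # not a KEY=VALUE line
    # real .env loaders differ in how they treat spaces around '=',
    # which is why writing KEY=VALUE with no spaces is the safe bet
    key, value = key.strip(), value.strip()
    # strip optional surrounding quotes (needed when the value contains spaces)
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
        value = value[1:-1]
    return key, value
```

For example, `parse_env_line('NOMINATIM_USER_AGENT="My User Agent"')` yields the key with the quotes stripped off the value.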

IDE Set Up and Downloading Code (skip if familiar)

The easiest way to play around with my code is just to clone it from this GitHub repository and then import it into your IDE. If you already have an IDE, I’m going to assume you know how to do this. If not, I would recommend PyCharm as your IDE because it’s free and is just what I use (you can download it here).

Once you’ve downloaded and installed PyCharm, you should see a screen that looks like this:


Now, you can just hit the Clone Repository button in the top right, and enter the following URL:

https://github.com/aaronrumph/transit_network_datsets_with_ArcPy

Then, you can just hit the Clone button, and you will have a copy of the files necessary to play around with my code! Next, you need to configure the project to use the right interpreter. If you have not already done the steps explained in item 2) of the Set Up section, you will need to do that now. Hit Ctrl+Alt+S to open the settings, and then navigate to Python and, under that, Interpreter. You should see a window that looks like this:

Click Add Interpreter, and then Local Interpreter. Then, click Select existing, and navigate to the path of the cloned arcgispro-py3 environment that you created in item 2) of the Set Up section (arcgispro-py3-nd-clone). By default, your path should be something like:

...\Users\<your_name>\AppData\Local\ESRI\conda\envs\arcgispro-py3-nd-clone\python.exe

You can then set that as your interpreter and after it loads, you should be good to go!

How to Use

Honestly, the best way to see how it works (as frustrating as that sounds) is just to play around with it and read the documentation for the classes, methods, and functions. My suggestion is to run the code_playground.py file and play around with the starting defaults I've provided to see how it works.

If, however, you don’t have time for that, simply create an ArcProject object:

your_arcgis_project = ArcProject(name="demonstration", project_location="dir_where_the_project_folder_will_go")

and a Place object:

your_place = Place(arcgis_project=your_arcgis_project, place_name="San Francisco", geographic_scope="place_only")

and then call the create_network_dataset_from_place method:

your_place.create_network_dataset_from_place(network_type="walk", use_elevation=False)

or the generate_isochrones_from_place method:

your_place.generate_isochrones_from_place(isochrone_name="demonstration_isochrone", addresses=[your_addr_1, your_addr_2], points=[your_point_1, your_point_2],
                                          network_type="walk", use_elevation=False)

Features

1) Create Network Datasets using free, open data!

  • Create walking or transit network datasets for various geographies
  • Create network datasets for a bounding box
  • Create network datasets that take elevation into account

2) Generate Isochrones using the Network Datasets

  • Generate isochrones from addresses
  • Generate isochrones from points

Development Process

Nothing to see here yet, sorry!

Current Known Issues

  • If, when trying to create a transit network dataset, you give a place that doesn’t return any agencies from TransitLand’s API, the code will crash and burn when it doesn’t find any.
  • At the moment, the code only looks for transit agencies that serve the main place given, not all of those in the CSA, MSA, or county.
  • The color ramp for isochrones is the opposite of what I would like (darker means less time rather than lighter meaning less time).

Eventual Goals

Some things I would like to add/change about this code:

Open Source!

At the moment, my code uses ArcPy, which is great in that it’s easy and Esri gives you a lot of very powerful tools fresh out of the box. That said, there are two problems: first, you have to have an ArcGIS Pro license (expensive) as well as access to the Network Analyst extension; second, while it is nice not having to write all the code for creating a network dataset (which is just a fancy graph) myself, Esri’s code is not exactly the fastest or best at times.

As such, I hope to eventually do all of that stuff myself. Fortunately, there are many, many other people who are committed to/working on open source transportation network stuff, and I hope to expand on their work! Some of this stuff already exists and it’s just a matter of adapting it for this use case (e.g., there are a million implementations of Dijkstra’s to use for shortest path), whereas some of it will require lots of work to create a solution from scratch (e.g., I have not come across any open source algorithms that can create graphs that represent a transit network based on GTFS data).
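To illustrate how off-the-shelf the shortest-path piece really is, here is a textbook Dijkstra over a weighted adjacency dict; the toy graph and its walk times are made up for illustration and are not from this repository:

```python
import heapq

def dijkstra(graph: dict, source: str) -> dict:
    """Shortest total weight from source to every reachable node.

    graph maps node -> {neighbor: edge_weight}; returns node -> total weight.
    """
    dist = {source: 0.0}
    pq = [(0.0, source)]  # priority queue of (distance so far, node)
    while pq:
        d, node = heapq.heappop(pq)
        if d > dist.get(node, float("inf")):
            continue  # stale queue entry; a shorter path was already found
        for neighbor, weight in graph.get(node, {}).items():
            new_d = d + weight
            if new_d < dist.get(neighbor, float("inf")):
                dist[neighbor] = new_d
                heapq.heappush(pq, (new_d, neighbor))
    return dist

# toy pedestrian graph: walk times in minutes between made-up intersections
toy_graph = {
    "a": {"b": 2.0, "c": 5.0},
    "b": {"c": 1.0, "d": 4.0},
    "c": {"d": 1.0},
    "d": {},
}
```

Here `dijkstra(toy_graph, "a")` finds the a→b→c→d route (4 minutes) rather than the direct but slower edges. The GTFS-to-graph step is the genuinely hard part; this routing step is not.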

Other Network Analysis Tools

For now, the only two features I have are creating network datasets and generating isochrones, and while these are all well and good, I would like to (and indeed, need to, for a project I’m working on for class) add support for Origin-Destination Cost Matrices, Shortest Path Routing, and more.

Other Network Types

Unfortunately, my code in its current form only allows for the creation/use of transit and walking network datasets. When I have some more time, I would love to add support for driving and biking (and biking + transit).

Bring Your Own GTFS Data/Improved Transit Networks

As mentioned previously, when building a transit network dataset at the moment, you can only use the automatically retrieved GTFS data for the specified place. I would like to, and should soon, add the ability to ‘bring your own’ GTFS data to use in building the transit network dataset. My goal is to make the user interface for this as friendly as possible, so that it is usable by other people.

Additionally, I would like to fix the current behavior where it only gets the transit agencies that serve the ‘core’ place rather than the entire selected geographic scope.

A CLI/GUI

Long term, I would like to create a GUI or CLI that I can distribute to people so that they can use this code without having to install an IDE or clone their arcgispro-py3 env themselves. This may be dependent on the Open Source goal, as otherwise it'd be a mess of permissions (especially with ArcGIS/Esri), but I believe it's the ultimate long-term goal.

Python
"""
This module sets up the environment for using ArcPy by modifying the system PATH variable
to look in the right place and configures ArcPy for parallel processing. The paths included here are the defaults, but if
you installed ArcGIS Pro in a weird way, or have messed with the Program Files too much, you may have to change them.
"""
import os


def set_up_arcpy_env():
    arcgis_bin = r"C:\Program Files\ArcGIS\Pro\bin"
    arcgis_extensions = r"C:\Program Files\ArcGIS\Pro\bin\Extensions"
    os.environ["PATH"] = arcgis_bin + os.pathsep + arcgis_extensions + os.pathsep + os.environ.get("PATH", "")
    import arcpy
    arcpy.env.parallelProcessingFactor = "100%"

Python
"""
This module contains helpful helper functions that can be used in any given module with no circular import
problems and does not use arcpy at all to avoid any problems with license checks
"""
import base64
import logging
import random
import re
import pluscodes
import requests

# getting API key from .env file
import os
from dotenv import load_dotenv
load_dotenv("api_keys.env")
census_bureau_api_key = os.getenv("CENSUS_BUREAU_API_KEY")
from geopy.geocoders import Nominatim

# nominatim user agent
nominatim_user_agent = os.getenv("NOMINATIM_USER_AGENT")



# regex matching patterns for use in checking other stuff
regex_matching_patterns = {
    # regex pattern to match valid longitudes, built using a regex generator because regex is hard:
    "longitude":
        r"^[-+]?(180(\.0+)?|1[0-7]\d(\.\d+)?|\d{1,2}(\.\d+)?)$",

    # regex pattern to match valid latitudes
    "latitude":
        r"^[-+]?([1-8]?\d(\.\d+)?|90(\.0+)?)$"}

def time_function(func):
    """
    Decorator to time functions
    :param func:
    :return:
    """
    import time
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        logging.info(f"Function {func.__name__} took {turn_seconds_into_minutes(elapsed_time)} to complete.")
        return result
    return wrapper


class ReferencePlace:
    def __init__(self, place_name:str=None, bound_box:tuple[str|float, str|float, str|float, str|float]=None):
        """
        ReferencePlace class acts like a sort of dictionary and contains the place name and/or bounding box for a place
        Attributes:
            place_name | str: the place name in the form "place, (division), country"
            bound_box | tuple[str,str,str,str]: (longitude_min, latitude_min, longitude_max, latitude_max)
            plus_codes_for_bbox_corners | tuple[str]: plus codes for each corner of the bounding box

        """
        self.place_name = place_name
        self.bound_box = bound_box

        if self.bound_box:
            self.pretty_name = self.bound_box
        else:
            self.pretty_name = self.place_name


        """ 
        Using plus codes as a way to represent the bounding box provided. This is mostly done to shorten file names. 
        The downside is that 1. Plus codes are not as legible as coordinates (I know that SF, for instance, is around 37
        something degrees north), and 2. Plus codes can be longer than the bounding box coordinates (e.g. SF is around
        12 characters long, while the bounding box coordinates are around 11 characters long). But better to have a 
        standardized way of doing it than to have a mix of formats and plus codes are shorter in cases where precise 
        coordinates are used.
        """
        self.plus_codes_for_bbox_corners:list = []
        self.out_name:str = ""

        # check that either a bound box or a place name has been provided:
        if self.place_name is None and self.bound_box is None:
            raise ValueError("Reference place arguments (place_name and bound_box) cannot both be undefined")

        # this part is just exception handling because everything else follows from this, so need correct ReferencePlace
        # first checking that bound_box was provided
        if self.bound_box is not None:
            # set out_name to reflect bounding box
            self.out_name = f"Bounding box: {self.bound_box}"
            # check if ALL the provided values for the bbox were floats
            if all(isinstance(coord, float) for coord in self.bound_box):
                # in case the bounding box is not in the correct format (i.e., floats were passed instead), rewrite it
                _rewrite_bound_box: list = [str(float_coord) for float_coord in self.bound_box]
                self.bound_box = tuple(_rewrite_bound_box)

            # next checking that each element in bounding box is 1. a str, 2. only float-ables, and 3. between -180 and 180
            for coord in self.bound_box:
                # check first that coord is string
                if not isinstance(coord, str):
                    # since already converted bbox from floats to strs (if it was floats), just raise error for other types
                    raise ValueError("Bounding box coordinates must be given as strings")

                # easy check first to see that coordinate in bounding box does not contain bad characters
                if not re.match(r"-?d+(?:.d+)?", coord):
                    raise ValueError("Bounding box coordinates must be strings containing "
                                     "only digits and decimal points")

            # now need to check that the provided bounding box is in fact (left, bottom, right, top)
            if not re.match(regex_matching_patterns["longitude"], self.bound_box[0]): #
                raise ValueError("First value passed in bound_box must be a valid longitude string")
            if not re.match(regex_matching_patterns["longitude"], self.bound_box[2]):
                raise ValueError("Third value passed in bound_box must be a valid longitude string")
            if not re.match(regex_matching_patterns["latitude"], self.bound_box[1]):
                raise ValueError("Second value passed in bound_box must be a valid latitude string")
            if not re.match(regex_matching_patterns["latitude"], self.bound_box[3]):
                raise ValueError("Fourth value passed in bound_box must be a valid latitude string")

            # if have made it this far without an error then bound box is safe to use and can now create plus codes
            self.create_plus_codes_for_bbox()

        # place name was provided but no bound box
        else:
            self.out_name = f"Place:{self.place_name}"

    def create_plus_codes_for_bbox(self):
        """
        Creates plus codes for the bounding box of the reference place using the bottom left corner and top right corner

        :return self.plus_codes_for_bbox_corners: tuple[str,str] | where each string in tuple is a pluscode
        """

        # always first check that bounding box was actually provided
        if self.bound_box is None:
            raise ValueError("Cannot generate plus codes for the bounding box because none was provided!")

        else:
            # need to convert bbox coordinates to floats to encode as plus codes
            lat_long_bottom_left_corner = (float(self.bound_box[1]), float(self.bound_box[0]))
            lat_long_top_right_corner = (float(self.bound_box[3]), float(self.bound_box[2]))

            # using pluscodes module to encode (lat, lon) to plus codes
            plus_code_bottom_left = pluscodes.encode(*lat_long_bottom_left_corner)
            plus_code_top_right = pluscodes.encode(*lat_long_top_right_corner)

            # now update attribute
            self.plus_codes_for_bbox_corners.append(plus_code_bottom_left)
            self.plus_codes_for_bbox_corners.append(plus_code_top_right)

            # now want to make tuple so can't accidentally modify
            self.plus_codes_for_bbox_corners = tuple(self.plus_codes_for_bbox_corners)

            # method output
            return self.plus_codes_for_bbox_corners


def create_snake_name(name:str | ReferencePlace):
    """
    Creates a snake_case name for a place, for use in a bunch of other stuff. Can either take a string (a simple place
    name like 'San Francisco, California, USA') or a ReferencePlace instance (see ReferencePlace class) with a place
    name and/or a bounding box
    :param name:
    :return:
    """
    # first, if passed simple place name, just replace commas and spaces.
    if isinstance(name, str):
        return name.replace(" ", "_").replace(",", "").replace(r"/", "").lower()

    # else if passed ReferencePlace
    elif isinstance(name, ReferencePlace):
        # in case where bounding box is not provided
        if name.bound_box is None:
            concatenated_place_name = name.place_name.replace(" ", "_").replace(
                                            ",", "").replace(r"/", "").lower()
            return concatenated_place_name

        # in case where bounding box is given but place name is not
        elif name.place_name is None:
            concatenated_plus_codes = "_".join(name.plus_codes_for_bbox_corners)
            return concatenated_plus_codes

        # case where both bounding box and place name are provided, by default will use bounding box
        else:
            concatenated_plus_codes = "_".join(name.plus_codes_for_bbox_corners)
            return concatenated_plus_codes

    # in case where tried to pass something other than a string or a ReferencePlace
    else:
        raise ValueError("Name must be either a string or a ReferencePlace instance")

# dictionary for converting state abbreviation to fips and name
state_fips_and_abbreviations = {
    "AL": {"fips": "01", "name": "Alabama"},
    "AK": {"fips": "02", "name": "Alaska"},
    "AZ": {"fips": "04", "name": "Arizona"},
    "AR": {"fips": "05", "name": "Arkansas"},
    "CA": {"fips": "06", "name": "California"},
    "CO": {"fips": "08", "name": "Colorado"},
    "CT": {"fips": "09", "name": "Connecticut"},
    "DE": {"fips": "10", "name": "Delaware"},
    "DC": {"fips": "11", "name": "District of Columbia"},
    "FL": {"fips": "12", "name": "Florida"},
    "GA": {"fips": "13", "name": "Georgia"},
    "HI": {"fips": "15", "name": "Hawaii"},
    "ID": {"fips": "16", "name": "Idaho"},
    "IL": {"fips": "17", "name": "Illinois"},
    "IN": {"fips": "18", "name": "Indiana"},
    "IA": {"fips": "19", "name": "Iowa"},
    "KS": {"fips": "20", "name": "Kansas"},
    "KY": {"fips": "21", "name": "Kentucky"},
    "LA": {"fips": "22", "name": "Louisiana"},
    "ME": {"fips": "23", "name": "Maine"},
    "MD": {"fips": "24", "name": "Maryland"},
    "MA": {"fips": "25", "name": "Massachusetts"},
    "MI": {"fips": "26", "name": "Michigan"},
    "MN": {"fips": "27", "name": "Minnesota"},
    "MS": {"fips": "28", "name": "Mississippi"},
    "MO": {"fips": "29", "name": "Missouri"},
    "MT": {"fips": "30", "name": "Montana"},
    "NE": {"fips": "31", "name": "Nebraska"},
    "NV": {"fips": "32", "name": "Nevada"},
    "NH": {"fips": "33", "name": "New Hampshire"},
    "NJ": {"fips": "34", "name": "New Jersey"},
    "NM": {"fips": "35", "name": "New Mexico"},
    "NY": {"fips": "36", "name": "New York"},
    "NC": {"fips": "37", "name": "North Carolina"},
    "ND": {"fips": "38", "name": "North Dakota"},
    "OH": {"fips": "39", "name": "Ohio"},
    "OK": {"fips": "40", "name": "Oklahoma"},
    "OR": {"fips": "41", "name": "Oregon"},
    "PA": {"fips": "42", "name": "Pennsylvania"},
    "RI": {"fips": "44", "name": "Rhode Island"},
    "SC": {"fips": "45", "name": "South Carolina"},
    "SD": {"fips": "46", "name": "South Dakota"},
    "TN": {"fips": "47", "name": "Tennessee"},
    "TX": {"fips": "48", "name": "Texas"},
    "UT": {"fips": "49", "name": "Utah"},
    "VT": {"fips": "50", "name": "Vermont"},
    "VA": {"fips": "51", "name": "Virginia"},
    "WA": {"fips": "53", "name": "Washington"},
    "WV": {"fips": "54", "name": "West Virginia"},
    "WI": {"fips": "55", "name": "Wisconsin"},
    "WY": {"fips": "56", "name": "Wyoming"},
    "AS": {"fips": "60", "name": "American Samoa"},
    "GU": {"fips": "66", "name": "Guam"},
    "MP": {"fips": "69", "name": "Northern Mariana Islands"},
    "PR": {"fips": "72", "name": "Puerto Rico"},
    "UM": {"fips": "74", "name": "U.S. Minor Outlying Islands"},
    "VI": {"fips": "78", "name": "U.S. Virgin Islands"}
}

def get_reference_places_for_scope(place_name:str, geographic_scope:str):
    """
    using nested functions because I am tired and want to
    """

    logging.info(f"Getting reference places for {place_name} {geographic_scope}")
    # first can just return city proper if try passing "place_only" as scope
    if geographic_scope == "place_only":
        return [ReferencePlace(place_name)]

    def get_lat_lon(place_name: str):
        """
        This function takes a place name (like "San Francisco, California, USA") and returns a list of ReferencePlaces
        corresponding to places in the same MSA as the place given.
        :param place_name:
        :return:
        """

        # first get the lat and lon of the place using the nominatim api
        nominatim_url = "https://nominatim.openstreetmap.org/search"
        # q is the free-form (unstructured) query parameter
        nominatim_params = {
            "q": place_name,
            "limit": "10",
            "format": "json"}
        # need to give valid User-Agent header
        nominatim_headers = {"User-Agent": nominatim_user_agent}

        # now get response from API
        nominatim_response = requests.get(nominatim_url, params=nominatim_params, headers=nominatim_headers)
        nominatim_response.raise_for_status()
        nominatim_data = nominatim_response.json()

        # extract lat and lon from response
        place_lat = nominatim_data[0]["lat"]
        place_lon = nominatim_data[0]["lon"]
        place_lat_lon:dict = {"lat": place_lat, "lon": place_lon}

        return place_lat_lon

    # the url for the geocoding API
    census_bureau_geocoding_url = "https://geocoding.geo.census.gov/geocoder/geographies/coordinates"

    # adding in support for just the county level for a given place
    def get_county_from_lat_lon(place_lat_lon:dict):

        census_bureau_geocoding_params = {
            "x": place_lat_lon["lon"],
            "y": place_lat_lon["lat"],
            "format": "json",
            "benchmark": "Public_AR_Current",   # just the weird current version shorthand
            "vintage": "4",                     # same as above
            "layers": "82"}  # 80 is the layer code for county

        census_bureau_response = requests.get(census_bureau_geocoding_url, params=census_bureau_geocoding_params)
        census_bureau_response.raise_for_status()

        # turn into JSON
        census_bureau_data = census_bureau_response.json()
        returned_geographies = census_bureau_data["result"]["geographies"]
        name_for_county = returned_geographies["Counties"][0]["NAME"]

        # going to return a set of counties for consistency with msa and csa
        set_of_counties = {name_for_county}

        return set_of_counties

    # now that have lat and lon in place, can get MSA (and state) geoids using Census geocoder
    def get_msa_from_lat_lon(place_lat_lon:dict):
        """

        :param place_lat_lon:
        :return: geoids | {"msa_geoid": msa_geoid, "state_geoid": state_geoid}
        """
        # using the Census Bureau's geocoding API (because it's free and gives geoids back) to get the MSA and state geoids
        census_bureau_geocoding_params = {
            "x": place_lat_lon["lon"],
            "y": place_lat_lon["lat"],
            "format": "json",
            "benchmark": "Public_AR_Current",   # just the weird current version shorthand
            "vintage": "4",                     # same as above
            "layers": "93,80"}  # 93 is the layer code for MSA and 80 is the layer code for state

        census_bureau_response = requests.get(census_bureau_geocoding_url, params=census_bureau_geocoding_params)
        census_bureau_response.raise_for_status()

        # parse the response JSON and pull out the returned geographies
        census_bureau_data = census_bureau_response.json()
        returned_geographies = census_bureau_data["result"]["geographies"]

        # get state geoid out of returned geographies
        state_geoid = returned_geographies["States"][0]["GEOID"]
        # can safely use first result because msas are mutually exclusive
        msa_geoid = returned_geographies["Metropolitan Statistical Areas"][0]["GEOID"]
        # get state name out of returned geographies
        state_name = returned_geographies["States"][0]["BASENAME"]

        # returning a dict: {"msa_geoid": geoid (str), "state_geoid": geoid (str)}
        geoids = {"msa_geoid": msa_geoid, "state_geoid": state_geoid, "state_name": state_name}
        return geoids

    # basically same function as above, just for csa
    def get_csa_from_lat_lon(place_lat_lon:dict):
        census_bureau_geocoding_params = {
            "x": place_lat_lon["lon"],
            "y": place_lat_lon["lat"],
            "format": "json",
            "benchmark": "Public_AR_Current",   # just the weird current version shorthand
            "vintage": "4",                     # same as above
            "layers": "97,80"}  # 97 is csa code and 80 is state

        census_bureau_response = requests.get(census_bureau_geocoding_url, params=census_bureau_geocoding_params)
        census_bureau_response.raise_for_status()

        # parse the response JSON and pull out the returned geographies
        census_bureau_data = census_bureau_response.json()
        returned_geographies = census_bureau_data["result"]["geographies"]
        csa_data = returned_geographies["Combined Statistical Areas"][0]


        # can safely use first result for csa
        csa_geoid = csa_data["GEOID"]
        # get name of csa (for finding states)
        csa_name = csa_data["NAME"]

        # now the hard part, turning the name into a list of states in the csa. e.g., Chicago-Naperville, IL-IN-WI CSA
        csa_comma_separated_name = csa_name.split(", ") # note the space, so first char of last item is letter not space

        # since every csa is named "..., STATE ABBREVIATION(s) CSA", can just take last element of comma separated
        csa_states_with_csa = csa_comma_separated_name[-1]
        # now slice string so that we have just the state abbreviation(s) (i.e., take all but last four characters)
        concatenated_state_abbreviations = csa_states_with_csa[:-4]
        # the list of state names to return
        state_abbreviations = []
        # now check if multiple states in csa
        if "-" in concatenated_state_abbreviations:
            for state_abbreviation in concatenated_state_abbreviations.split("-"):
                state_abbreviations.append(state_abbreviation)
        else:
            state_abbreviations.append(concatenated_state_abbreviations)

        # using the state_fips_and_abbreviations dict, can now get state geoids and names
        state_geoids = []
        state_names = []
        for state_abbreviation in state_abbreviations:
            state_geoid = state_fips_and_abbreviations[state_abbreviation]["fips"]
            state_geoids.append(state_geoid)
            state_name = state_fips_and_abbreviations[state_abbreviation]["name"]
            state_names.append(state_name)


        # return dictionary with info on the csa (its geoid, plus the names and geoids of its states)
        csa_info = {"csa_geoid": csa_geoid, "state_geoids": state_geoids, "state_names": state_names}
        return csa_info
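The state-extraction step above can be isolated into a tiny pure-string function to make the slicing logic clearer. This is a sketch of the same parsing, not part of the module; since `str.split("-")` on a single abbreviation just returns a one-element list, the multi-state branch collapses away:

```python
def states_from_csa_name(csa_name: str) -> list[str]:
    """Extract state abbreviations from a CSA name such as
    'Chicago-Naperville, IL-IN-WI CSA' (same parsing as get_csa_from_lat_lon)."""
    # the last comma-separated piece is e.g. 'IL-IN-WI CSA'
    states_with_suffix = csa_name.split(", ")[-1]
    # drop the trailing ' CSA' (four characters)
    abbreviations = states_with_suffix[:-4]
    # split("-") handles both single- and multi-state CSAs
    return abbreviations.split("-")

states_from_csa_name("Chicago-Naperville, IL-IN-WI CSA")  # ['IL', 'IN', 'WI']
states_from_csa_name("Fresno-Madera-Hanford, CA CSA")     # ['CA']
```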

    def get_msas_from_csa(csa_info):
        # get csa info out of the provided dict
        csa_geoid = csa_info["csa_geoid"]
        state_geoids = csa_info["state_geoids"]
        state_names = csa_info["state_names"]
        # the list of msa_geoid_dict dicts to return
        msa_geoid_list = []

        # the geoinfo census bureau api url needed for query
        geoinfo_url = (f"https://api.census.gov/data/2023/geoinfo?get=NAME&for=metropolitan%20statistical%20area/"
                       f"micropolitan%20statistical%20area:*&in=combined%20statistical%20area:{csa_geoid}"
                       f"&key={census_bureau_api_key}")

        # query the API
        geoinfo_response = requests.get(geoinfo_url)
        geoinfo_response.raise_for_status()
        geoinfo_data = geoinfo_response.json()

        # go through each msa returned (the first is always just the format so can skip)
        for msa in geoinfo_data[1:]:
            # sanity check that this row is not the header row (msa[0] == "NAME")
            if msa[0] != "NAME":
                # geoid is 3rd element of list returned by api
                msa_geoid = msa[2]
                # need to create a dictionary for each msa with state name and geoid
                msa_geoid_dict = {"msa_geoid": msa_geoid, "state_geoids": state_geoids, "state_names": state_names}
                # now add to list
                msa_geoid_list.append(msa_geoid_dict)
        return msa_geoid_list

    def get_counties_from_msa(msa_geoids:list[dict]):
        """ Returns a set of county names (e.g. "Alameda County, California") for the msa provided"""

        # the set of counties to return
        counties_in_msa = set()

        # iterating through msas provided
        for msa in msa_geoids:
            # get data out of the msa dict
            msa_geoid = msa["msa_geoid"]
            state_geoids = msa["state_geoids"]
            state_names = msa["state_names"]

            # we do not know which state(s) the msa spans, so try each state with each msa
            for state_geoid, state_name in zip(state_geoids, state_names):
                logging.debug(f"Trying state geoid {state_geoid} for state {state_name} for MSA {msa_geoid}")
                # easier to just define new url for each msa rather than try to set parameters with base url
                geoinfo_counties_url = (f"https://api.census.gov/data/2023/geoinfo?"
                                        f"get=NAME&for=county:*&in=metropolitan%20"
                                        f"statistical%20area/micropolitan%20statistical%20area:{msa_geoid}%20"
                                        f"state%20(or%20part):{state_geoid}&key={census_bureau_api_key}")

                # query the API
                geoinfo_counties_response = requests.get(geoinfo_counties_url)

                # check if the response code is 200, meaning actually has data (otherwise no data and will fail)
                if geoinfo_counties_response.status_code == 200:
                    logging.debug(f"State geoid {state_geoid} for state {state_name} for MSA {msa_geoid} was
                    successful")
                    # since actually has data, can turn response in json
                    geoinfo_counties_data = geoinfo_counties_response.json()
                    # go through each county returned (the first is always just the format so can skip)
                    for county in geoinfo_counties_data[1:]:
                        # sanity check that this row is not the header row (county[0] == "NAME")
                        if county[0] != "NAME":
                            county_no_state_name = county[0].split(";")[0]
                            counties_in_msa.add(f"{county_no_state_name}, {state_name}")

        return counties_in_msa

    # now main part of code
    if geographic_scope == "msa":
        this_lat_lon = get_lat_lon(place_name=place_name)
        this_msa = get_msa_from_lat_lon(this_lat_lon)
        these_county_names = get_counties_from_msa([this_msa]) # remember, am passing a list of msas

    elif geographic_scope == "csa":
        this_lat_lon = get_lat_lon(place_name=place_name)
        this_csa = get_csa_from_lat_lon(this_lat_lon)
        these_msas = get_msas_from_csa(this_csa)
        these_county_names = get_counties_from_msa(these_msas)

    elif geographic_scope == "county":
        this_lat_lon = get_lat_lon(place_name=place_name)
        these_county_names = get_county_from_lat_lon(this_lat_lon)

    else:
        raise Exception("Geographic scope must be either 'city', 'county', 'msa', or 'csa'")

    # list to return
    county_reference_places = [ReferencePlace(place_name=place_name)] # starting off with original place name RefPlace

    # go through county names and make ReferencePlaces
    for county_name in these_county_names:
        county_reference_places.append(ReferencePlace(place_name=county_name))

    logging.info(f"Found {len(county_reference_places) - 1} counties for {place_name} in {geographic_scope}")

    # return list of ReferencePlaces (counties)
    return county_reference_places

# function to check that a point is a valid coordinate
def check_if_valid_coordinate_point(point: tuple) -> None:
    """
    Used to check that a point is a valid coordinate (i.e. a tuple or list of two values that are
    both floats and match lat and lon requirements)

    :param point: tuple or list of two values (latitude, longitude)
    :return: None
    """

    # in case a point is not a list or tuple
    if not isinstance(point, (tuple, list)):
        raise ValueError("each point must be a tuple (or list) of (latitude, longitude)")

    # otherwise (point is either a list or tuple)
    else:
        # check whether the tuple or list has exactly two values
        if len(point) != 2:
            raise ValueError("each point must be a tuple (or list) of (latitude, longitude). You passed "
                             "a tuple or list with fewer or more than two values")
        # point has two values, so check that they are valid lat lon values
        else:
            # check that each coordinate is a float
            for coordinate in point:
                if not isinstance(coordinate, float):
                    raise ValueError("points must contain two float values")

            # check that lat and lon values are within valid ranges
            if point[0] < -90 or point[0] > 90:
                raise ValueError("first value must be a valid latitude (between -90 and 90)")
            if point[1] < -180 or point[1] > 180:
                raise ValueError("second value for point must be a valid longitude (between -180 and 180)")

def get_coordinates_from_address(address: list[str] | str):
    """
    Takes a single address or list of addresses as input and then turns it into a tuple of (latitude, longitude)
    using nominatim API
    """

    # set up for query
    nominatim_url = "https://nominatim.openstreetmap.org/search"

    # first in case where only one address provided
    if isinstance(address, str):
        logging.info(f"Geocoding input address {address}")
        # again using q for open ended search which will return a bunch of characteristics of the address
        nominatim_params = {
            "q": address,
            "limit": "1",
            "format": "json"}

        # Nominatim requires a User-Agent header identifying the application (any descriptive string works)
        nominatim_headers = {"User-Agent": nominatim_user_agent}

        # now can actually query the API
        nominatim_response = requests.get(nominatim_url, params=nominatim_params, headers=nominatim_headers)
        nominatim_response.raise_for_status()
        nominatim_data = nominatim_response.json()

        # if nominatim couldn't find anywhere, then address was improperly formed
        if len(nominatim_data) == 0:
            raise ValueError(f"Address {address} could not be geocoded, "
                             f"please double check that the address is correct")

        # get lat and lon out of response
        nominatim_content = nominatim_data[0]
        address_lat = float(nominatim_content["lat"])
        address_lon = float(nominatim_content["lon"])

        # now define address coordinate tuple to return
        address_coordinates = (address_lat, address_lon)
        logging.info(f"Successfully geocoded address: {address}, lat, lon = {address_coordinates}")
        return address_coordinates

    # in case where multiple addresses provided
    elif isinstance(address, list):
        # list of coordinate tuples to return
        all_address_coordinates = []
        # geocode each address in turn
        for individual_address in address:
            logging.info(f"Geocoding input address {individual_address}")
            # again using q for open ended search which will return a bunch of characteristics of the address
            nominatim_params = {
                "q": individual_address,
                "limit": "1",
                "format": "json"}

            # Nominatim requires a User-Agent header identifying the application (any descriptive string works)
            nominatim_headers = {"User-Agent": nominatim_user_agent}

            # now can actually query the API
            nominatim_response = requests.get(nominatim_url, params=nominatim_params, headers=nominatim_headers)
            nominatim_response.raise_for_status()
            nominatim_data = nominatim_response.json()

            # if nominatim couldn't find anywhere, then address was improperly formed
            if len(nominatim_data) == 0:
                raise ValueError(
                    f"Address {individual_address} could not be geocoded, "
                    f"please double check that the address is correct")

            # get lat and lon out of response
            nominatim_content = nominatim_data[0]
            ind_address_lat = float(nominatim_content["lat"])
            ind_address_lon = float(nominatim_content["lon"])

            individual_address_coordinates:tuple = (ind_address_lat, ind_address_lon)

            # now define address coordinate tuple to return
            all_address_coordinates.append(individual_address_coordinates)

            # need to sleep between requests because of the Nominatim API rate limit
            from time import sleep
            sleep(1.1)
        logging.info(f"Successfully geocoded {len(address)} addresses")
        return all_address_coordinates

    # in weird third case
    else:
        raise ValueError("Must provide either a string or a list of strings")

def generate_random_base64_value(input_number) -> str:
    # generate random scenario id from a random number between 0 and input_number
    random_float = random.random()
    random_integer_value = str(int(input_number * random_float))
    random_base64_value = base64.b64encode(random_integer_value.encode()).decode()
    # str.replace returns a new string, so must reassign; swapping "/" out keeps the id path-safe
    random_base64_value = random_base64_value.replace("/", "xx")

    return random_base64_value
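A self-contained sketch of the same id scheme, useful for checking its properties (the function name is illustrative, not from the module):

```python
import base64
import random

def random_base64_id(upper_bound: int = 1_000_000_000) -> str:
    """Same scheme as generate_random_base64_value: random int -> base64 id.
    str.replace returns a new string, so the result must be reassigned or returned."""
    random_integer_value = str(int(upper_bound * random.random()))
    encoded = base64.b64encode(random_integer_value.encode()).decode()
    return encoded.replace("/", "xx")  # "/" is swapped out so the id stays path-safe

scenario_id = random_base64_id()
```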

def turn_seconds_into_minutes(seconds: float) -> str:
    """ 
    Simple helper function that formats seconds into minutes and seconds as desired for logging
    """
    minutes = int(seconds / 60)
    remaining_seconds = seconds % 60
    if minutes == 0:
        return f"{remaining_seconds:.2f} seconds"
    else:
        return f"{minutes} minutes, {remaining_seconds:.2f} seconds"



if __name__ == "__main__":
    # demonstration of plus codes for bounding box
    san_francisco_ref_place = ReferencePlace(place_name="San Francisco, California, USA",
                                             bound_box=(-122.4194, 37.7749, -122.3731, 37.8091))
    print(san_francisco_ref_place.plus_codes_for_bbox_corners)

Python
"""
This module is used to define the network types and their attributes used for creation of network dataset
"""
import os
from pathlib import Path

network_dataset_template_dir = os.path.join(Path(__file__).parent,"nd_templates")

def get_template_path(template_name):
    """ Takes network name and gives path to corresponding template """
    return os.path.join(network_dataset_template_dir, f"{template_name}_nd_template.xml")
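The naming convention the helper encodes can be shown with a minimal standalone copy. Here the template directory is written as a bare relative path for illustration (in the module it is resolved next to the file via `Path(__file__).parent`):

```python
import os

# the template directory sits next to the module; shown as a relative path here
template_dir = "nd_templates"

def template_path_for(template_name: str) -> str:
    """Mirror of get_template_path: '<dir>/<name>_nd_template.xml'."""
    return os.path.join(template_dir, f"{template_name}_nd_template.xml")

template_path_for("walk_no_z")  # e.g. 'nd_templates/walk_no_z_nd_template.xml' on POSIX
```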


network_types_attributes = {"walk_no_z": {
                                "osmnx_network_type":"walk",
                                "network_dataset_template_name": get_template_path("walk_no_z"),
                                "isochrone_travel_mode":"Walking"},
                            "walk_z": {
                                "osmnx_network_type":"walk",
                                "network_dataset_template_name": get_template_path("walk_z"),
                                "isochrone_travel_mode": "Walking"},
                            "bike_no_z": {
                                "osmnx_network_type":"bike",
                                "network_dataset_template_name": get_template_path("bike_no_z")},
                            "bike_z": {
                                "osmnx_network_type":"bike",
                                "network_dataset_template_name": get_template_path("bike_z")},
                            "transit_no_z": {
                                "osmnx_network_type":"walk",
                                "network_dataset_template_name": get_template_path("transit_no_z"),
                                "isochrone_travel_mode": "Public transit time"},
                            "transit_z": {
                                "osmnx_network_type":"walk",
                                "network_dataset_template_name": get_template_path("transit_z"),
                                "isochrone_travel_mode": "Public transit time"},
                            "drive": {
                                "osmnx_network_type":"drive",
                                "network_dataset_template_name": get_template_path("drive")},
                            "transit_plus_biking_no_z" : {
                                "osmnx_network_type":"bike",
                                "network_dataset_template_name": get_template_path("transit_plus_biking_no_z")},
                            "transit_plus_biking_z" : {
                                "osmnx_network_type":"bike",
                                "network_dataset_template_name": get_template_path("transit_plus_biking_z")}
                            }

Python
"""
This module does most of the heavy lifting. Includes classes for ArcGIS Project, StreetNetwork, ElevationMapper, 
CacheFolder, Cache, TransitNetwork, StreetFeatureClasses, FeatureDataset, and GeoDatabase.
Originally, had written it as procedural (?) code, but decided to use OOP approach because 1. I like OOP better, and
2. Made more sense given that working with ArcGIS objects.
"""

import math
import os
import shutil
import sys
from pathlib import Path

from arcgis import features



# Setting up python environment/making sure env points to extensions correctly
import arcpy_config
arcpy_config.set_up_arcpy_env()

# if you want to add modules, MUST (!!!!!!!) come after this block (ensures extensions work in cloned environment)
# this is because arcpy is extremely fragile
import arcpy
import arcpy_init
#

import pickle
import logging
import osmnx as ox  # used for getting streets data
import geopandas as gpd
import networkx as nx
import time  # used for checking runtimes of functions/methods
import requests  # used for USGS API querying
import multiprocessing as mp  # used for querying in bulk
import asyncio # used for querying USGS api asynchronously
import aiohttp # same as requests (basically) but for asyncio
from itertools import repeat
import re
import platform
import isodate
from datetime import datetime
from zipfile import ZipFile
import random
from shapely.geometry import Point
import numpy as np


# local module(s)
import transit_data_for_arcgis
import network_types
from general_tools import *
from gtfs_tools import *
from gtfs_tools import transit_land_api_key

# making sure that using windows because otherwise cannot use arcpy and ArcGIS
if platform.system() != "Windows":
    raise OSError("Cannot run this module because not using Windows. ArcGIS and ArcPy require Windows")

# logging setup
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

# osmnx setup
ox.settings.timeout = 3000
ox.settings.max_query_area_size = 5000000000000
ox.settings.use_cache = True

# arcgis project template to allow for creating project
arcgis_project_template = arcpy.mp.ArcGISProject("project_template.aprx")

# simple thing to make sure network analyst is checked in/out
network_analyst_extension_checked_out = False


### need to fix arcproject setup/saving etc. ###

# ArcGIS project setup
class ArcProject:
    """ Class for ArcGIS Projects.
        Parameters:
            name: str,
            project_location: str | path to directory where project should exist, by default,
                the project will be within the same directory as this module.
            reset: bool (clears project if true).

        Methods:
            set_up_project(self) (see documentation below)
    """

    def __init__(self, name, project_location=Path(__file__).parent, reset=False):
        self.name = name
        self.project_location = project_location
        self.reset = reset
        self.project_dir_path = os.path.join(self.project_location, self.name)
        self.path = os.path.join(self.project_dir_path, rf"{self.name}.aprx")
        self.maps = []
        self.layouts = []
        self.project_file = self.path
        self.arcObject = None

        # set up on initialization
        self.set_up_project()

    # method to create project if doesn't already exist, and to activate project
    def set_up_project(self):
        """
        Checks if a project already exists with the same name at the selected path, if so,
        sets project as active workspace. Otherwise, creates a blank project with the desired name in the
        desired location.
        :return:
            self.path | str
        """
        if arcpy.Exists(self.path):
            logging.info("Project already exists, opening it")
            self.arcObject = arcpy.mp.ArcGISProject(self.path)
        else:
            logging.info("Project does not exist, creating it")
            os.makedirs(self.project_dir_path)
            arcgis_project_template.saveACopy(self.path)
            self.arcObject = arcpy.mp.ArcGISProject(self.path)
        # set project as current workspace
        arcpy.env.workspace = self.project_dir_path

        return self.path
    # add more methods here. Add map, etc.

    def save_project(self):
        """Saves the arcproject"""
        try:
            # have to release the project file (delete the reference) before it can be saved
            del self.arcObject
            # reopen the project
            self.arcObject = arcpy.mp.ArcGISProject(self.path)
            # and now can save it
            self.arcObject.save()
            logging.info("Project saved and reopened successfully")

        except OSError as e:
            logging.warning(f"Could not save project (file may locked or project open: {e}")

# global functions needed at global level to do multiprocessing for API queries
def init_worker(counter, lock):
    # initializer function for multiprocess processes, need in order to have a global variable accessible to all workers
    global shared_counter, shared_lock
    shared_counter = counter
    shared_lock = lock

# functions for dealing with checking network analyst extension in and out
def check_out_network_analyst_extension():
    """
    Checks out network analyst extension
    :return: True - successful, exception if False
    """
    # need to check that network analyst extension is actually available to use, and then check it out
    global network_analyst_extension_checked_out
    if arcpy.CheckExtension("Network") == "Available":
        arcpy.CheckOutExtension("Network")
        network_analyst_extension_checked_out = True
        logging.info("Network Analyst extension checked out")
        return True
    else:
        raise Exception("Network Analyst extension is not available")


def check_network_analyst_extension_back_in():
    """
    Checks network analyst extension back in
    :return: True - successful, exception if False
    """
    # can only check the extension back in if the module-level flag says it was checked out
    global network_analyst_extension_checked_out
    if network_analyst_extension_checked_out:
        arcpy.CheckInExtension("Network")
        network_analyst_extension_checked_out = False
        logging.info("Network Analyst extension checked back in")
        return True
    else:
        raise Exception("Network Analyst extension is not checked out")


def add_points_arcgis(feature_dataset_path: str, fc_name: str, point_coordinates: tuple | list[tuple]) -> str:
    """
    Adds points to a feature class in a feature dataset
    :param feature_dataset_path: str
    :param fc_name: str
    :param point_coordinates: tuple | list[tuple]
    :return: str | fc_path (the path to the feature class where the points were added)
    """
    # fc_path
    fc_path = os.path.join(feature_dataset_path, fc_name)

    # check to make sure valid (lat, lon) points provided
    if isinstance(point_coordinates, list):
        for point in point_coordinates:
            check_if_valid_coordinate_point(point)
    elif isinstance(point_coordinates, tuple):
        check_if_valid_coordinate_point(point_coordinates)
    else:
        raise Exception("Invalid point coordinates provided")

    # by default will overwrite existing feature class with same name in same loc (low risk of accidental collision)
    arcpy.env.overwriteOutput = True
    # create empty feature class
    arcpy.management.CreateFeatureclass(out_path=feature_dataset_path, out_name=fc_name, geometry_type="POINT",
                                        spatial_reference=arcpy.SpatialReference(4326))
    # fields for feature class
    fields = ["SHAPE@XY"]

    # in case where list of points provided, add each one
    if isinstance(point_coordinates, list):
        with arcpy.da.InsertCursor(fc_path, fields) as cursor:
            for point in point_coordinates:
                # flip to be (x, y) rather than lat, lon (y, x)
                lat, lon = point[0], point[1]
                point_xy = (lon, lat)

                # insert point_xy
                cursor.insertRow([point_xy])

    # in case where single point provided, add it
    elif isinstance(point_coordinates, tuple):
        with arcpy.da.InsertCursor(fc_path, fields) as cursor:
            # flip to be (x, y) rather than lat, lon (y, x)
            lat, lon = point_coordinates[0], point_coordinates[1]
            point_xy = (lon, lat)

            # insert point_xy
            cursor.insertRow([point_xy])

    return fc_path


class CacheFolder:
    """ Class for cache folder, takes param
        network_snake_name: str (MUST be in snake case) which will be the name of cache folder

        Attributes:
            network_snake_name | str: snake name of the network (passing snake name rather than StreetNetwork object
                because CacheFolder class must precede StreetNetwork class in the code)
            env_dir_path | Path: path to the environment directory
            path | str: path to the cache folder

        Methods:
            check_if_cache_folder_exists(): Returns True if cache folder already exists for the city.
            set_up_cache_folder(): Return True if there is already a cache folder for city. If not, creates one.
            reset_cache_folder(): Completely resets the cache folder for the city (highly inadvisable because it
                deletes the OSM data and elevation data)
    """

    def __init__(self, snake_name_with_scope):
        self.snake_name_with_scope = snake_name_with_scope
        self.env_dir_path = Path(__file__).parent
        self.path = os.path.join(self.env_dir_path, "place_caches", f"{self.snake_name_with_scope}_cache")

    def check_if_cache_folder_exists(self):
        """ Returns True if cache folder already exists for the city."""
        if os.path.exists(self.path):
            return True
        else:
            return False

    def set_up_cache_folder(self):
        """Return True if there is already a cache folder for city. If not, creates one."""
        if os.path.exists(self.path):
            raise Exception(f"There is already a cache folder for {self.snake_name_with_scope}")
        else:
            os.makedirs(self.path)

    def reset_cache_folder(self):
        # completely reset the cache folder for the city by deleting it and recreating it empty
        if not os.path.exists(self.path):
            raise Exception(f"Cannot reset the cache folder for {self.snake_name_with_scope} "
                            f"because no such folder exists"
                            )
        else:
            # makedirs with exist_ok=True is a no-op on an existing folder, so delete first, then recreate
            shutil.rmtree(self.path)
            os.makedirs(self.path)


# simple Cache class with obvious methods (read, write, check if exists)
class Cache:
    """
    Class for cache for use in saving street network data.

    Attributes:
        cache_folder | CacheFolder obj: cache folder for the street network
        cache_name | str: name of the cache
        cache_path | str: path to the cache

    Methods:
        check_if_cache_already_exists(): checks if cache already exists
        read_cache_data(): reads cached data from cache file
        write_cache_data(): writes desired data to cache file
    """

    def __init__(self, cache_folder: CacheFolder, cache_name):
        self.cache_folder = cache_folder
        self.cache_name = cache_name
        self.cache_path = os.path.join(self.cache_folder.path, self.cache_name)

    def check_if_cache_already_exists(self):
        if os.path.exists(self.cache_path):
            return True
        else:
            return False

    def read_cache_data(self):
        if self.check_if_cache_already_exists():
            with open(self.cache_path, "rb") as cache_file:
                cache_data = pickle.load(cache_file)
            return cache_data
        else:
            raise Exception("Cannot get cache data because there is no cache")

    def write_cache_data(self, data_to_cache):
        with open(self.cache_path, "wb") as cache_file:
            pickle.dump(data_to_cache, cache_file)
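The read/write pair above is a plain pickle round trip. A minimal stand-in using a throwaway directory (the payload dict is illustrative, not the module's actual graph data):

```python
import os
import pickle
import tempfile

# a minimal stand-in for Cache's pickle round trip, using a throwaway directory
with tempfile.TemporaryDirectory() as cache_folder_path:
    cache_path = os.path.join(cache_folder_path, "graph_cache")

    # write_cache_data: pickle the object to the cache file
    with open(cache_path, "wb") as cache_file:
        pickle.dump({"nodes": 42, "edges": 99}, cache_file)

    # read_cache_data: load it back
    with open(cache_path, "rb") as cache_file:
        cached = pickle.load(cache_file)

cached  # {'nodes': 42, 'edges': 99}
```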


# StreetNetwork class very bare bones, just gets street network graph and makes the associated GeoDataFrames
class StreetNetwork:
    """
    Class representing street network for a location

    Attributes:
        geographic_scope (str): Geographic scope of the street network {"place_only", "msa", "csa"}
        reference_place_list: list of reference places to get network for (len will be one if using city limits;
        otherwise, will contain all the places in the MSA if MSA desired and in the CSA if CSA desired)
        network_type (str): type of network being created {"walk_no_z", "walk_z", "bike_no_z", "bike_z",
        "transit_no_z", "transit_z", "drive", "transit_plus_biking_no_z", "transit_plus_biking_z"}
        bound_boxes (list): Bounding boxes passed (if bounding box used, len of list will be one)
        snake_name (str): Snake name of the city
        cache_folder (CacheFolder): Cache folder for the street network
        graph_cache (Cache): Cache for the street network graph
        nodes_cache (Cache): Cache for the street network nodes
        edges_cache (Cache): Cache for the street network edges
        network_graph (networkx.Graph): Street network graph
        network_nodes (geopandas.GeoDataFrame): Street network nodes
        network_edges (geopandas.GeoDataFrame): Street network edges
        elevation_enabled (bool): Whether elevation is enabled
        osmnx_type (str): Type of street network for osmnx ("walk", "bike", "drive", "drive_service", "all", "all_public")

        Methods:
            get_street_network_from_osm(timer_on=True, reset=False): Gets street network from OpenStreetMaps or
            from cache
    """

    def __init__(self, geographic_scope: str, reference_place_list: list[ReferencePlace], network_type="walk_no_z"):
        #### need to make __init__ method cleaner ####
        self.geographic_scope = geographic_scope
        self.reference_place_list = reference_place_list
        self.network_type = network_type

        # get place name and bbox out of reference place
        self.place_names = [reference_place.place_name for reference_place in reference_place_list]
        self.bound_boxes = [reference_place.bound_box for reference_place in reference_place_list]
        self.main_reference_place = reference_place_list[0]

        if self.main_reference_place.bound_box:
            self.geographic_scope = "bbox"

        # create snake name for StreetNetwork (the first item in the list is always the main place)
        self.snake_name = create_snake_name(self.main_reference_place)
        # create snake name with geographic scope encoded
        self.snake_name_with_scope = f"{self.snake_name}_{self.geographic_scope}"
        # link cache folder
        self.cache_folder = CacheFolder(self.snake_name_with_scope)
        # decode the network type into proper network type designation for osmnx query
        self.osmnx_type = network_types.network_types_attributes[self.network_type]["osmnx_network_type"]

        # check if there is a cache folder for desired street network
        if not self.cache_folder.check_if_cache_folder_exists():
            self.cache_folder.set_up_cache_folder()

        # setting up caches for this street network
        self.graph_cache = Cache(self.cache_folder, "graph_cache")
        self.nodes_cache = Cache(self.cache_folder, "nodes_cache")
        self.edges_cache = Cache(self.cache_folder, "edges_cache")

        # placeholder for now, but will update in get_street_network_from_osm method
        # so can access gdfs and graph when passing instance as argument
        self.network_graph = None
        self.network_nodes = None
        self.network_edges = None
        self.elevation_enabled = False

    def get_street_network_graph_from_osmnx(self, timer_on=True, reset=False):
        """
        Main function for class that gets street network edges and nodes, either from cache, if cache exists, or from
        OpenStreetMaps (through osmnx)
        :param timer_on: logs run time for method if, on by default
        :param reset: if True, will reset cache and get street network from OSM
        :return: street network graph, and the nodes and egdes that make up the graph in geodataframes
        """
        process_start_time = time.perf_counter()
        logging.info("Getting street network")

        # first if not using cache or no cache data
        if reset or (not self.graph_cache.check_if_cache_already_exists()):
            logging.info("Getting street network from OSM")

            # if using city, getting OSM data if not using cache or if no cache exists
            if self.main_reference_place.bound_box is None:

                # if there is more than one reference place (not using city limits) need to combine street networks
                if len(self.reference_place_list) > 1:
                    network_graphs = [] # list of the networkx graph objects representing city street networks
                    for reference_place in self.reference_place_list:
                        logging.info(f"Getting street network for {reference_place.pretty_name}")

                        # getting each individual graph one at a time first
                        single_network_graph = ox.graph_from_place(reference_place.place_name,
                                                                   network_type=self.osmnx_type, retain_all=True,
                                                                   truncate_by_edge=True)
                                                                   # need to use truncate_by_edge when using
                                                                   # multiple reference places to avoid gaps at borders

                        # add each individual graph to the list so can combine
                        network_graphs.append(single_network_graph)

                    logging.info("Combining street networks")
                    # using compose all to combine the various street grids
                    network_graph = nx.compose_all(network_graphs)

                # when using city limits, only need street grid for main place
                elif len(self.reference_place_list) == 1:
                    logging.info(f"Getting street network for {self.main_reference_place.pretty_name} city proper")
                    network_graph = ox.graph_from_place(self.main_reference_place.place_name,
                                                        network_type=self.osmnx_type, retain_all=True)

                else:
                    raise Exception("Cannot get street network because no place was specified")
                
                # turn graph into gdfs
                network_nodes, network_edges = ox.graph_to_gdfs(network_graph, nodes=True, edges=True)


            # if using bound box, getting OSM Data
            else:
                logging.info(f"Getting street network for {self.main_reference_place.pretty_name}")
                network_graph = ox.graph_from_bbox(bbox=self.main_reference_place.bound_box,
                                                   network_type=self.osmnx_type, retain_all=True)
                # again graph to gdfs
                network_nodes, network_edges = ox.graph_to_gdfs(network_graph, nodes=True, edges=True)

        # otherwise using cached data
        else:
            logging.info("Using cached street network")
            network_graph = self.graph_cache.read_cache_data()
            network_nodes, network_edges = self.nodes_cache.read_cache_data(), self.edges_cache.read_cache_data()

        # just because I want to keep track of everything
        if timer_on:
            logging.info(f"Got street network in {time.perf_counter() - process_start_time:.1f} seconds")
            logging.info(f"Street network for {self.main_reference_place.place_name} {self.geographic_scope} had "
                         f"{len(network_nodes)} nodes and "
                         f"{len(network_edges)} edges")

        self.graph_cache.write_cache_data(network_graph)
        self.nodes_cache.write_cache_data(network_nodes)
        self.edges_cache.write_cache_data(network_edges)
        self.network_graph, self.network_nodes, self.network_edges = network_graph, network_nodes, network_edges
        return network_graph, network_nodes, network_edges

    # methods to count nodes and edges in street network. No real reason to exist other than interest
    def count_nodes(self):
        if self.network_nodes is None:
            raise Exception("Cannot count nodes because there are no nodes")
        else:
            return len(self.network_nodes)

    def count_edges(self):
        if self.network_edges is None:
            raise Exception("Cannot count edges because there are no edges")
        else:
            return len(self.network_edges)

    # adding a method that gets the PBF osm data because need for R5
    def get_osm_pbf_data(self, reset=False):
        """
        Downloads the necessary OSM data in PBF form for use in R5 routing based on geographic scope
        :param reset: if True, re-downloads the data even if cached
        :return: not yet implemented
        """
        raise NotImplementedError("PBF download for R5 routing is not implemented yet")


# ElevationMapper class adds elevation to nodes and grades to edges (coming soon?) using USGS EPQS API
class ElevationMapper:
    """ WARNINGS!:
            1. Adding elevation takes forever, ~130ms per node with 14 threads (1.8s per node per thread!)
                because the API is so slow
            2. Uses multiprocessing (scary) so use the "if __name__ == __main__" guard!!!!!
        Attributes:
            street_network | StreetNetwork obj, threads_available | int, reset | bool

        Methods:
            add_elevation_data_to_nodes(): gets elevation data from USGS EPQS API and adds to nodes geodataframe
            add_elevation_data_to_edges(): calculates grades for edges based on elevation data (must be called after
                add_elevation_data_to_nodes)
    """

    def __init__(self, street_network: StreetNetwork, concurrent_requests_desired:int=100, reset=False):
        self.street_network = street_network
        self.concurrent_requests_desired = concurrent_requests_desired
        self.reset = reset
        (self.street_network_graph, self.street_network_nodes, self.street_network_edges) = (
            street_network.network_graph, street_network.network_nodes, street_network.network_edges
        )
        self.node_counter = 0
        self.elevation_cache = Cache(street_network.cache_folder, "elevation")
        self.nodes_without_elevation = 0

    # decided to switch to using asyncio rather than multiprocessing because the work is I/O bound
    async def get_elevation_data_single_query(self, async_session:aiohttp.ClientSession,
                                              idx, lon, lat, semaphore: asyncio.Semaphore, node_counter,
                                              total_nodes, start_time):
        usgs_url = "https://epqs.nationalmap.gov/v1/json"
        params = {"x": float(lon), "y": float(lat), "units": "Meters", "wkid": 4326}

        # using semaphore to limit number of concurrent requests
        async with semaphore:
            # using try in case of exception in getting elevation for node
            try:
                # essentially the same as requests.get but async. Timeout is set to 100s; 30s should be enough,
                # but the API is slow, so a generous timeout avoids spurious failures
                async with async_session.get(usgs_url, params=params, timeout=100) as response:

                    if response.status != 200:
                        raise Exception(f"HTTP {response.status}: {await response.text()}")

                    # wait on response then get the "value" output received from the API
                    elevation_data_response = await response.json()
                    elevation = elevation_data_response.get("value")

                    # check whether received actual data or the default no-data value, which for USGS is -1000000
                    if elevation == -1000000:
                        elevation = None

                    # update counter for progress tracking
                    node_counter["nodes_completed"] += 1
                    current_node = node_counter["nodes_completed"]
                    time_elapsed = time.perf_counter() - start_time
                    querying_rate = current_node / time_elapsed if time_elapsed > 0 else 0
                    estimated_total_time = total_nodes / querying_rate if querying_rate > 0 else 0

                    # convert the elapsed time and estimated time into minutes and seconds for prettier printing
                    time_elapsed_in_minutes_and_seconds = {"minutes": time_elapsed // 60, "seconds": time_elapsed % 60}
                    estimated_total_time_in_minutes_and_seconds = {"minutes": estimated_total_time // 60,
                                                                   "seconds": estimated_total_time % 60}
                    # printing progress bar/tracker thingy
                    print(f"\rElevation: {current_node}/{total_nodes} nodes processed. "
                          f"{time_elapsed_in_minutes_and_seconds['minutes']:.0f} minutes: "
                          f"{time_elapsed_in_minutes_and_seconds['seconds']:.0f} seconds/"
                          f"{estimated_total_time_in_minutes_and_seconds['minutes']:.0f} minutes: "
                          f"{estimated_total_time_in_minutes_and_seconds['seconds']:.0f} seconds"
                          f" (elapsed time/estimated total time). Rate = {querying_rate} nodes/second",
                          end="", flush=True)

                    # now that we have the elevation, return the node index (from iterating the nodes) and its elevation
                    return idx, elevation

            # in case of exception, just returning None as the elevation for the node
            # (both branches above return, so any failure tally is logged in get_all_elevation_data instead)
            except Exception as e:
                self.nodes_without_elevation += 1
                return idx, None
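The rate/ETA arithmetic inside the progress tracker above can be isolated into a small helper; `progress_stats` below is an illustrative stand-alone version, not part of the module:

```python
def progress_stats(done, total, elapsed_seconds):
    # rate drives the ETA; both guard against divide-by-zero on the first tick
    rate = done / elapsed_seconds if elapsed_seconds > 0 else 0.0
    eta = total / rate if rate > 0 else 0.0
    minutes, seconds = divmod(eta, 60)
    return rate, int(minutes), round(seconds)

# 50 of 200 nodes in 25 s -> 2 nodes/s, so ~100 s total (1 min 40 s)
rate, m, s = progress_stats(50, 200, 25.0)
```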

    # this function just tells the little async worker thingies how and when to query the API
    async def get_all_elevation_data(self):
        """
        Gets elevation data for all nodes in street network using the async function get_elevation_data_single_query.
        :return:
        """
        # nodes to go through
        nodes_list = [(idx, float(row["x"]), float(row["y"])) for idx, row in self.street_network_nodes.iterrows()]

        # the semaphore is the host of the restaurant in the Guido asyncio restaurant analogy
        semaphore = asyncio.Semaphore(self.concurrent_requests_desired)

        # have to manually define TCP connector to force close otherwise leave zombie connections
        connector = aiohttp.TCPConnector(limit=100, limit_per_host=100, ttl_dns_cache=300, enable_cleanup_closed=True,
                                        force_close=True)

        node_counter = {"nodes_completed": 0}
        start_time = time.perf_counter()
        total_nodes = len(nodes_list)


        # open new aiohttp session (basically does same as response = response.get...)
        async with aiohttp.ClientSession(connector=connector) as session:

            # only asyncio task is getting elevation data for each node
            tasks = [self.get_elevation_data_single_query(async_session=session, idx=idx, lon=lon, lat=lat,
                                                          semaphore=semaphore, node_counter=node_counter,
                                                          total_nodes=total_nodes, start_time=start_time)
                    for idx, lon, lat in nodes_list]

            # results from querying elevation API for each node
            results = await asyncio.gather(*tasks)
            if self.nodes_without_elevation:
                logging.warning(f"Couldn't get elevation for {self.nodes_without_elevation}/{total_nodes} nodes")
            return dict(results)
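The gather-behind-a-semaphore pattern above can be sketched with stdlib-only stand-ins; `query_one` below is a hypothetical placeholder for the real EPQS request:

```python
import asyncio

async def query_one(idx, sem):
    # the semaphore caps how many coroutines run this section at once
    async with sem:
        await asyncio.sleep(0)  # stand-in for the aiohttp request
        return idx, idx * 10    # (node index, fake elevation)

async def query_all(n, max_concurrent):
    sem = asyncio.Semaphore(max_concurrent)
    tasks = [query_one(i, sem) for i in range(n)]
    # gather returns results in task-submission order, not completion order,
    # so the resulting dict lines up with the input node order
    return dict(await asyncio.gather(*tasks))

results = asyncio.run(query_all(5, 3))
```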

    # this method is the main method that actually sends out for the elevation data, and then adds it to street network
    def add_elevation_data_to_nodes(self):

        logging.info(f"Adding elevation to the street network for "
                     f"{self.street_network.main_reference_place.place_name} {self.street_network.geographic_scope}")

        if not self.reset and self.elevation_cache.check_if_cache_already_exists():
            elevation_data = self.elevation_cache.read_cache_data()
        else:
            # run the async function to get the elevation data for all the nodes
            elevation_data = asyncio.run(self.get_all_elevation_data())

        # make list of elevations (gather preserves task order, so this matches the node order)
        z_values: list = list(elevation_data.values())
        # update the local street network gdf within the ElevationMapper
        self.street_network_nodes["z"] = z_values
        self.street_network.network_nodes["z"] = z_values

        # cache elevation_data dict {<node_idx>: <elevation>}
        self.elevation_cache.write_cache_data(elevation_data)

        # now add to the original graph
        for node_idx, elevation in elevation_data.items():
            if node_idx in self.street_network_graph.nodes:
                elevation_value = float(elevation) if elevation is not None else None
                self.street_network_graph.nodes[node_idx]["elevation"] = elevation_value
                self.street_network_graph.nodes[node_idx]["z"] = elevation_value

        # update the elevation enabled flag
        self.street_network.elevation_enabled = True

        # update the geometry of the nodes to include the elevation
        self.street_network_nodes["geometry"] = [
            Point(row.geometry.x, row.geometry.y, z if z is not None else 0)
            for (idx, row), z in zip(self.street_network_nodes.iterrows(), z_values)
        ]

        # change original network nodes geometry to include elevation
        self.street_network.network_nodes["geometry"] = [
            Point(row.geometry.x, row.geometry.y, z if z is not None else 0)
            for (idx, row), z in zip(self.street_network_nodes.iterrows(), z_values)
        ]

        # print extra blank line after progress bar thingy
        print()
        logging.info(f"Got elevation data for {len(self.street_network_nodes)} nodes")
        logging.info("Elevation data added to the street network")

    def add_grades_to_edges(self):
        """
        Calculates and adds grade (slope) to edges based on node elevations.
        Grade is calculated as: (elevation_change / edge_length) * 100
        Positive grade = uphill, Negative grade = downhill, Grade = 0 means flat ground
        """
        number_of_edges_without_elevation_data = 0

        process_start_time = time.perf_counter()
        logging.info(
            f"Adding grades to edges for {self.street_network.main_reference_place.pretty_name} street network")

        # double-check street network is elevation enabled
        if not self.street_network.elevation_enabled:
            raise Exception("Cannot add grades to edges, elevation data not available for nodes")

        # iterating through all edges in the street network graph
        for start_node, end_node, key, data in self.street_network_graph.edges(keys=True, data=True):
            # elevation for start and end nodes (start_node and end_node respectively)
            start_elevation = self.street_network_graph.nodes[start_node].get("elevation")
            end_elevation = self.street_network_graph.nodes[end_node].get("elevation")
            edge_length = data.get("length", 0)

            # only calculate grade if elevation for start and end node available
            if start_elevation is not None and end_elevation is not None and edge_length > 0:
                elevation_change = end_elevation - start_elevation
                grade = (elevation_change / edge_length) * 100  # doing grade as percentage (100% = 45 degree slope)

                # updating graph data based on calculated grades
                data["grade"] = round(grade, 2)  # rounding because otherwise crazy number

                data["grade_magnitude"] = round(abs(grade), 2)  #
                data["direction"] = 1 if grade > 0 else 0
            else:
                number_of_edges_without_elevation_data += 1
                # handling missing data
                data["grade"] = None
                data["grade_magnitude"] = None
                data["direction"] = None

        logging.warning(f"Missing elevation or length data for {number_of_edges_without_elevation_data}/"
                    f"{len(self.street_network_graph.edges())} edges total")

        # update the edges geodataframe
        self.street_network.network_nodes, self.street_network.network_edges = ox.graph_to_gdfs(
            self.street_network_graph, nodes=True, edges=True
        )

        process_run_time = time.perf_counter() - process_start_time
        logging.info(f"Successfully added grades to edges in {turn_seconds_into_minutes(process_run_time)}")


### need to review GeoDatabase and ArcProject classes code to make sure no errors
class GeoDatabase:
    def __init__(self, arcgis_project: ArcProject, street_network: StreetNetwork):
        self.project = arcgis_project
        self.street_network = street_network
        self.project_file_path = arcgis_project.path
        self.gdb_name = self.street_network.snake_name_with_scope
        self.gdb_path = os.path.join(self.project.project_dir_path, f"{self.gdb_name}.gdb")

    def save_gdb(self):
        # using try to avoid errors in case of file lock
        try:
            # save the project
            self.project.save_project()

            # now can update the list of geodatabases and set created one as the default

            # get the list of dictionaries representing the databases in the project
            gdb_dictionary_for_adding = {"databasePath": self.gdb_path, "isDefaultDatabase": True}
            # get current list of gdbs for project
            current_gdbs = self.project.arcObject.databases
            # go through and make sure that none of the current gdbs will be default
            for gdb_dict in current_gdbs:
                gdb_dict["isDefaultDatabase"] = False

            # in case the desired gdb is not showing up in the current gdbs list
            if self.gdb_path not in [gdb_dict["databasePath"] for gdb_dict in current_gdbs]:
                current_gdbs.append(gdb_dictionary_for_adding)

            # now can safely set the geodatabase for this call as default
            for gdb_dict in current_gdbs:
                if gdb_dict["databasePath"] == self.gdb_path:
                    gdb_dict["isDefaultDatabase"] = True

            # now update project
            self.project.arcObject.updateDatabases(current_gdbs)

            # housekeeping
            logging.info("Project saved successfully and default geodatabase set")

        # will get OSError if not allowed to save because of lock
        except OSError as e:
            logging.warning(f"Could not save project (file may be locked or project open): {e}")

    def set_up_gdb(self, reset=False):
        """ Creates geodatabase if one does not exist, and resets it if reset desired."""
        # making sure not trying to set up gdb before project
        if not arcpy.Exists(self.project.path):
            raise FileNotFoundError(f"Cannot set up geodatabase in the provided project {self.project.name} "
                                    f"because it doesn't exist")

        # in case where gdb does not exist and reset is not desired
        if not reset and not arcpy.Exists(self.gdb_path):
            logging.info("No existing geodatabase found, creating new geodatabase")
            arcpy.management.CreateFileGDB(self.project.project_dir_path, self.gdb_name)

        # in case where gdb does not exist but reset is desired (erroneously)
        elif not arcpy.Exists(self.gdb_path) and reset:
            logging.warning(f"No geodatabase with the name {self.gdb_name} exists, but you indicated you would like "
                            f"to reset. If this was a mistake, stop the script (in case other stuff gets reset).")
            logging.info("No existing geodatabase found, creating new geodatabase")
            arcpy.management.CreateFileGDB(self.project.project_dir_path, self.gdb_name)

        # in case where gdb exists and reset is desired
        elif arcpy.Exists(self.gdb_path):
            logging.info("Existing geodatabase found, deleting and creating new geodatabase")

            # need to delete using Arc...
            arcpy.env.overwriteOutput = True
            arcpy.Delete_management(self.gdb_path)

            # then create
            arcpy.management.CreateFileGDB(self.project.project_dir_path, self.gdb_name)

        # in case where gdb exists and reset is not desired
        else:
            raise Exception(f"A geodatabase {self.gdb_name} already exists but reset was not requested")

        # modifying current arcpy env
        arcpy.env.workspace = self.gdb_path
        # also set the new gdb as the project default directly (updateDatabases alone wasn't sticking; needs debugging)
        self.project.arcObject.defaultGeodatabase = self.gdb_path
        # save project to make sure that default gdb actually gets set
        self.save_gdb()

class FeatureDataset:  # add scenario_id so can do multiple scenarios of same network type
    """
        The feature dataset is a container for various feature classes, and is where the network dataset will go.
        
        Attributes:
            gdb (GeoDatabase): The geodatabase where the feature dataset will be created.
            street_network (StreetNetwork): The street network that will be used to create the feature dataset.
            scenario_id (str): The ID of the scenario that the feature dataset will be created for.
            network_type (str): The type of network that the feature dataset will be created for.
            reset (bool): Whether to reset the feature dataset if it already exists.
            name (str): The name of the feature dataset.
            path (str): The path to the feature dataset.
            
        Methods:
            create_feature_dataset(): Creates the feature dataset.
            reset_feature_dataset(): Resets the feature dataset.
    """
    
    def __init__(self, gdb: GeoDatabase, street_network: StreetNetwork, scenario_id:str, network_type: str = "walk_no_z", reset=False):
        self.gdb = gdb
        self.street_network = street_network
        self.scenario_id = scenario_id
        self.network_type = network_type
        self.reset = reset
        self.name = f"{self.network_type}_{scenario_id}_fd"
        self.path = os.path.join(self.gdb.gdb_path, f"{self.name}")

        # if reset true, then overwrite existing features
        arcpy.env.overwriteOutput = True

    def create_feature_dataset(self):
        # check if FD exists, and reset if needed, else create
        logging.info("Creating feature dataset")

        # making sure don't accidentally try to create when desired feature dataset already exists
        if arcpy.Exists(self.path) and not self.reset:
            raise Exception(f"The feature dataset {self.path} already exists, either reset it or use it")

        elif arcpy.Exists(self.path):
            arcpy.Delete_management(self.path)
        # creating feature dataset
        arcpy.management.CreateFeatureDataset(self.gdb.gdb_path, self.name,
                                              spatial_reference=arcpy.SpatialReference(4326))
        logging.info("Feature dataset successfully created")

    def reset_feature_dataset(self):
        # just a method to reset the desired feature dataset
        if not arcpy.Exists(self.path):
            raise Exception(f"Cannot delete feature dataset at {self.path} because it doesn't exist")
        else:
            arcpy.Delete_management(self.path)
            arcpy.management.CreateFeatureDataset(self.gdb.gdb_path, self.name,
                                                  spatial_reference=arcpy.SpatialReference(4326))


class StreetFeatureClasses:
    """
    StreetFeatureClasses is nodes and edges for the street network

    Attributes:
        feature_dataset (FeatureDataset): Feature dataset containing street nodes and edges
        street_network (StreetNetwork): Street network containing necessary nodes and edges geodataframe
        use_elevation (bool): Whether to use elevation in calculations
        reset (bool): Whether to reset the feature classes

    Methods:
        create_empty_feature_classes(): Creates empty feature classes for street nodes and edges
        add_street_network_data_to_feature_classes(): Adds street network data to empty feature classes created
    """

    def __init__(self, feature_dataset: FeatureDataset, street_network: StreetNetwork, use_elevation=False,
                 reset=False):
        self.feature_dataset = feature_dataset
        self.street_network = street_network
        self.use_elevation = use_elevation  # determines whether to add z values to nodes in feature classes
        self.reset = reset

        # paths for the two feature classes
        self.nodes_fc_path = os.path.join(self.feature_dataset.path, "nodes_fc")
        self.edges_fc_path = os.path.join(self.feature_dataset.path, "edges_fc")

        # useful shorthand to have rather than writing self.street_network.network_nodes etc all the time
        self.nodes = street_network.network_nodes
        self.edges = street_network.network_edges

        # the cache folder
        self.cache_folder = self.street_network.cache_folder

        # error handling for using elevation when creating feature classes for street network
        if "z" not in self.street_network.network_nodes and self.use_elevation:
            raise Exception("Cannot use elevation because input StreetNetwork object has no z values")

    def calculate_walk_times(self) -> None:
        """
        Adds necessary walk time columns to geodataframes for edges. In case elevation not enabled on network
        dataset, then will calculate walk time as (length in meters/ 85 meters per minute). In case elevation
        enabled, then will calculate walk time using Tobler's hiking function and calculate against and along walk
        times.
         """
        logging.info("Calculating walk times for edges")

        def calculate_flat_ground_walk_time(edges_gdf) -> None:
            logging.info("Not using elevation; will calculate walk time as (length in meters/ 85 meters per minute)")
            # love pandas/geopandas because all I have to do to calculate a new field is this!
            edges_gdf["walk_time"] = (edges_gdf["length"] / 85)

        def calculate_walk_time_with_elevation(edges_gdf) -> None:
            logging.info("Using elevation; will use Tobler's to calculate walk time")

            # in case no grade then don't just fail in the background
            if "grade" not in edges_gdf.columns:
                logging.warning("Input gdf has no grade column, will calculate flat ground walk times")
                calculate_flat_ground_walk_time(edges_gdf)
            else:
                # Am using Tobler's hiking function here (see wikipedia). FT represents "from-to" or along for a given
                # edge, and TF represents "to-from" or against for a given edge.

                # first calculate speed for along
                speed_FT_km_per_hour = 6 * (np.exp(-3.5 * np.abs((edges_gdf["grade"]/100) + 0.05)))
                speed_FT_m_per_min = (speed_FT_km_per_hour * 1000) / 60

                # now calculate speed for against
                speed_TF_km_per_hour = 6 * (np.exp(-3.5 * np.abs((-(edges_gdf["grade"]/100)) + 0.05)))
                speed_TF_m_per_min = (speed_TF_km_per_hour * 1000) / 60

                # the graded walk times for edges are just the length divided by the respective speeds above
                edges_gdf["walk_time_graded_FT"] = edges_gdf["length"] / speed_FT_m_per_min
                edges_gdf["walk_time_graded_TF"] = edges_gdf["length"] / speed_TF_m_per_min

        # now can calculate walk times
        if self.use_elevation:
            calculate_walk_time_with_elevation(self.edges)
        else:
            calculate_flat_ground_walk_time(self.edges)
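Tobler's hiking function used in `calculate_walk_time_with_elevation` can be checked in isolation with the stdlib; `tobler_speed_m_per_min` below is an illustrative helper, not part of the module:

```python
import math

def tobler_speed_m_per_min(grade_percent):
    # Tobler's hiking function: W = 6 * exp(-3.5 * |slope + 0.05|) km/h,
    # where slope is rise/run as a fraction (so a grade of 5% is 0.05)
    slope = grade_percent / 100
    km_per_hour = 6 * math.exp(-3.5 * abs(slope + 0.05))
    return km_per_hour * 1000 / 60

# flat ground comes out to ~84 m/min, close to the 85 m/min flat-walk fallback
flat = tobler_speed_m_per_min(0)
walk_time_minutes = 100 / flat  # minutes to walk a 100 m flat edge
```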

    # now going to convert gdfs to geojson (faster than creating feature class using insert cursor)
    def convert_geodataframes_to_geojson(self) -> tuple[str, str]:
        logging.info("Converting geodataframes to GeoJSONs")
        process_start_time = time.perf_counter()

        # create paths for the GeoJSON files
        nodes_geojson_path = os.path.join(self.cache_folder.path, "nodes_fc.geojson")
        edges_geojson_path = os.path.join(self.cache_folder.path, "edges_fc.geojson")

        # make sure that the required directories exist
        os.makedirs(os.path.dirname(nodes_geojson_path), exist_ok=True)
        os.makedirs(os.path.dirname(edges_geojson_path), exist_ok=True)

        # write the gdfs to geojsons
        self.nodes.to_file(nodes_geojson_path, driver="GeoJSON")
        self.edges.to_file(edges_geojson_path, driver="GeoJSON")

        # housekeeping
        process_run_time = time.perf_counter() - process_start_time
        logging.info(f"Finished converting geodataframes to GeoJSONs in {turn_seconds_into_minutes(process_run_time)}")

        return nodes_geojson_path, edges_geojson_path

    # now can convert the geojsons for nodes and edges into feature classes
    def convert_geojsons_to_feature_class(self, outputted_geojsons: tuple[str, str]):
        """
        Takes the street network geojson created by convert_geodataframes_to_geojson and turns them
        into geodatabase feature classes
        """
        logging.info("Converting GeoJSONs to GeoDatabase feature classes")
        process_start_time = time.perf_counter()

        # get the paths to the geojson out from the provided tuple
        nodes_geojson_path = outputted_geojsons[0]
        edges_geojson_path = outputted_geojsons[1]

        # just a reminder that the object has the desired fc paths as attributes
        # (self.nodes_fc_path and self.edges_fc_path)

        # convert nodes from geojson to feature class using arcpy
        logging.info("Converting nodes GeoJSON to feature class")
        arcpy.conversion.JSONToFeatures(in_json_file=nodes_geojson_path, out_features=self.nodes_fc_path,
                                        geometry_type="POINT")
        # do the same for edges
        logging.info("Converting edges GeoJSON to feature class")
        arcpy.conversion.JSONToFeatures(in_json_file=edges_geojson_path, out_features=self.edges_fc_path,
                                        geometry_type="POLYLINE")

        # housekeeping
        process_run_time = time.perf_counter() - process_start_time
        logging.info(f"Finished converting GeoJSONs to geodatabase feature classes in {turn_seconds_into_minutes(process_run_time)}")

        ### MIGHT WANT TO ADD BIT HERE TO DELETE THE geojsons? ###

    def map_street_network_to_feature_classes(self):
        """
        The only method needed to call for this feature class, maps the street network edges and
        nodes to feature classes.
        """
        # first check if cached GeoJSONs

        nodes_geojson_path = os.path.join(self.cache_folder.path, "nodes_fc.geojson")
        edges_geojson_path = os.path.join(self.cache_folder.path, "edges_fc.geojson")

        # in case where cached data exists
        if os.path.exists(nodes_geojson_path) and os.path.exists(edges_geojson_path) and not self.reset:

            logging.info("Cached GeoJSONs found, converting to feature classes")
            process_start_time = time.perf_counter()

            # just need to convert the geojsons to feature classes
            input_geojsons = (nodes_geojson_path, edges_geojson_path)
            self.convert_geojsons_to_feature_class(input_geojsons)

            # housekeeping
            process_run_time = time.perf_counter() - process_start_time
            logging.info(f"Mapped the street network to feature classes in {turn_seconds_into_minutes(process_run_time)}")

        else:
            logging.info("Cached data not available, mapping street network to feature classes")
            process_start_time = time.perf_counter()
            # first, calculate walk times while still a dataframe
            self.calculate_walk_times()

            # next, convert gdfs to geojsons and then geojsons to feature classes
            outputted_geojsons = self.convert_geodataframes_to_geojson()
            self.convert_geojsons_to_feature_class(outputted_geojsons)

            # housekeeping
            process_run_time = time.perf_counter() - process_start_time
            logging.info(f"Mapped the street network to feature classes in {turn_seconds_into_minutes(process_run_time)}")
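The cache check at the top of `map_street_network_to_feature_classes` (reuse GeoJSONs already on disk unless a reset is requested) can be sketched in isolation. This is an illustrative stand-alone helper using the same file names, not part of the real module:

```python
import os
import tempfile


def cached_geojsons_available(cache_dir: str, reset: bool = False) -> bool:
    """Return True when both cached GeoJSONs exist and a reset was not requested."""
    nodes_path = os.path.join(cache_dir, "nodes_fc.geojson")
    edges_path = os.path.join(cache_dir, "edges_fc.geojson")
    return os.path.exists(nodes_path) and os.path.exists(edges_path) and not reset


# quick demonstration against a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    assert not cached_geojsons_available(tmp)  # nothing cached yet
    for name in ("nodes_fc.geojson", "edges_fc.geojson"):
        open(os.path.join(tmp, name), "w").close()
    assert cached_geojsons_available(tmp)  # both files present
    assert not cached_geojsons_available(tmp, reset=True)  # reset forces a rebuild
```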


class TransitNetwork:
    def __init__(self, geographic_scope, feature_dataset: FeatureDataset, reference_place_list: list[ReferencePlace],
                 modes: list = None, agencies_to_include: list[TransitAgency] = None,
                 own_gtfs_data_paths: list[str] = None):
        """
        Transit network class for place
        
        Attributes:
            geographic_scope: GeographicScope | geographic scope of the transit network {"place_only", "msa", "csa"}
            feature_dataset: FeatureDataset | feature dataset where the transit network will be created
            reference_place_list: list[ReferencePlace] | list of reference places for the transit network
            modes: list | modes to be included in the transit network {"all", "bus", "heavy_rail", "light_rail",
                "regional_rail", "ferry", "gondola", "funicular", "trolleybus", "monorail"}
                (for more, see gtfs_tools.route_types documentation)
            agencies_to_include: list[TransitAgency] | agencies to include (defaults to all agencies that serve
                the place)
            own_gtfs_data_paths: list[str] | paths to your own GTFS data (if None, the TransitLand API is queried)

        **Methods:**
            get_agencies_for_place: creates a list of transit agencies that serve the place (see method doc)


        """
        self.geographic_scope = geographic_scope
        self.feature_dataset = feature_dataset
        self.reference_place_list = reference_place_list
    
        self.place_names = [reference_place.place_name for reference_place in reference_place_list]
        self.bound_box = [reference_place.bound_box for reference_place in reference_place_list]
        self.main_reference_place = reference_place_list[0]

        if self.main_reference_place.bound_box:
            self.geographic_scope = "bbox"

        self.snake_name = create_snake_name(self.main_reference_place)
        self.snake_name_with_scope = f"{self.snake_name}_{self.geographic_scope}"
        self.modes = modes

        # link to cache folder
        self.cache_folder = CacheFolder(self.snake_name_with_scope)

        # if no agencies are specified, will default to all agencies that serve the place
        self.agencies_to_include = agencies_to_include
        if self.agencies_to_include is None:
            logging.info("A list of transit agencies to include was not specified, so all agencies that serve"
                         " the place will be used")
            self.agencies_to_include = self.get_agencies_for_place()

        # adding the ability to bring your own gtfs data
        self.own_gtfs_data_paths = own_gtfs_data_paths
        if self.own_gtfs_data_paths is None:
            logging.info(f"Own GTFS data not provided, will query the TransitLand API to get transit agencies that "
                         f"serve {self.main_reference_place.pretty_name}")

        # once the first three methods have been run (get_agencies_for_place, get_gtfs_for_transit_agencies,
        # unzip_gtfs_data), this dictionary will contain the paths of unzipped gtfs data available for place
        # for desired Agencies
        self.gtfs_folders = None
        self.agency_feed_valid_dates = {}

    # using requests instead of aiohttp here because it's only a single request
    def get_agencies_for_place(self):
        """
        Takes a place name of format 'city, state, country'
        and returns a list of transit agencies that serve the place

        :return: list[TransitAgency] | list of transit agencies (TransitAgencyObjects) that serve the place
        """
        logging.info(f"Getting agencies that serve {self.main_reference_place.pretty_name}")
        # list of TransitAgency objects to be returned
        agencies_for_place = []

        # in case where using place (with geographic scope) rather than bounding box
        if self.main_reference_place.bound_box is None:

            # iterate through reference places to get the agencies that serve them
            for reference_place in self.reference_place_list:
                # transit land's API only requires the city name (although this seems stupid)
                place_short_name = reference_place.place_name.split(",")[0]

                transit_land_response = requests.get(f"https://transit.land/api/v2/rest/agencies?api_key="
                                                     f"{transit_land_api_key}"
                                                     f"&city_name={place_short_name}")
                transit_land_response.raise_for_status()

                # json containing the agencies
                transit_land_data = transit_land_response.json()

                # going through the agency dicts provided by the api and using them as kwargs for TransitAgency object
                agencies_found_before = len(agencies_for_place)
                for agency_data in transit_land_data["agencies"]:
                    # fill out TransitAgency objects using the data from the API
                    temp_agency = TransitAgency(**agency_data)
                    agencies_for_place.append(temp_agency)
                logging.info(f"Found {len(agencies_for_place) - agencies_found_before} agencies that serve "
                             f"{reference_place.pretty_name}")

        # in case where using bounding box
        else:
            # the bbox is concatenated into a string to be used in the query to transitland's API
            bbox_query_string = ",".join(str(coord) for coord in self.main_reference_place.bound_box)
            transit_land_response = requests.get(f"https://transit.land/api/v2/rest/agencies?api_key="
                                                 f"{transit_land_api_key}"
                                                 f"&bbox={bbox_query_string}")
            transit_land_response.raise_for_status()
            transit_land_data = transit_land_response.json()

            for agency_data in transit_land_data["agencies"]:
                # fill out TransitAgency objects using the data from the API
                temp_agency = TransitAgency(**agency_data)
                agencies_for_place.append(temp_agency)

            logging.info(f"Found {len(agencies_for_place)} agencies that serve {self.main_reference_place.pretty_name}")

        # now can set self.agencies_that_serve_place
        return agencies_for_place

    def get_gtfs_for_transit_agencies(self):
        """
        Gets the latest static GTFS data for the desired agencies (self.agencies_to_include).

        :return: gtfs_zip_folders | dict with TransitAgencies as keys and the (zipped) file paths where the
            gtfs feeds are written as values. Also populates self.agency_feed_valid_dates
            (dict{TransitAgency: {"last_updated": ..., "valid_until": ...}}) with the date each agency's feed
            was last fetched and the last date it is valid for (needed because feeds that are out of date or
            no longer valid have to be dealt with).
        """
        # by this point self.agencies_to_include defaults to every agency serving the place (handled in __init__)
        logging.info(f"Getting GTFS data for {len(self.agencies_to_include)} transit agencies")
        # outputs for the method
        gtfs_zip_folders = {}


        # next, iterating through the list of desired agencies and getting their data
        for agency in self.agencies_to_include:
            # onestop_id (used to query for feed) is in feed_version["feed"]["onestop_id"]
            feed_version = agency.feed_version
            onestop_id = feed_version["feed"]["onestop_id"]
            transit_land_api_url = (f"https://transit.land/api/v2/rest/feeds/{onestop_id}/download_latest_feed_version"
                                    f"?api_key={transit_land_api_key}")

            # standard API query
            response = requests.get(transit_land_api_url)
            response.raise_for_status()

            # the file path where the zipped gtfs data will be saved to (yes complicated, but best for organization)
            file_path = os.path.join(self.cache_folder.path, "gtfs_caches", f"{onestop_id}", "zipped_gtfs",
                                     f"{onestop_id}.zip")

            # set up folders in case they don't already exist
            os.makedirs(os.path.dirname(file_path), exist_ok=True)

            # writing the content that was returned from transit land to a zip file
            with open(file_path, "wb") as gtfs_zipped_file:
                gtfs_zipped_file.write(response.content)

            # now need to see 1. when the feed was fetched (and that it's not too out of date) and 2. whether it's valid
            feed_version_query_response = requests.get(f"https://transit.land/api/v2/rest/feeds/{onestop_id}"
                                                       f"?api_key={transit_land_api_key}")
            feed_version_query_response.raise_for_status()
            feed_version_query_response_json = feed_version_query_response.json()

            # when the latest static feed was fetched by transit land
            latest_feed_fetch_date = (feed_version_query_response_json["feeds"][0]["feed_versions"]
                                      [0]["fetched_at"])
            # the latest date in the calendar that the data is valid for (can either extend or not use)
            latest_feed_valid_until = (feed_version_query_response_json["feeds"][0]["feed_versions"]
                                      [0]["latest_calendar_date"])

            # adding the zipped gtfs folder to the path dict
            gtfs_zip_folders[agency] = file_path
            # creating a dictionary with these dates needed for each agency
            self.agency_feed_valid_dates[agency] = {"last_updated": latest_feed_fetch_date,
                                               "valid_until": latest_feed_valid_until}

        logging.info("Successfully downloaded GTFS data for desired agencies")
        return gtfs_zip_folders

    def unzip_gtfs_data(self):
        """
        Unzips the GTFS data that was downloaded from Transit Land.
        :return: unzipped_gtfs_filepaths: dict{TransitAgency: path of unzipped gtfs folder}
        """
        agency_zip_folders = self.get_gtfs_for_transit_agencies()
        logging.info("Unzipping the downloaded GTFS data")
        unzipped_gtfs_filepaths = {}

        # go through each agency in the provided agency_zip_folders
        for agency in agency_zip_folders:

            # the path where each unzipped gtfs folder will be saved to
            onestop_id_directory = os.path.dirname(os.path.dirname(agency_zip_folders[agency]))
            unzipped_gtfs_filepath = os.path.join(onestop_id_directory,
                                                  "unzipped_gtfs")

            # set up folders in case they don't already exist
            os.makedirs(unzipped_gtfs_filepath, exist_ok=True)

            # using zipfile module to extract all .txt files provided
            with ZipFile(agency_zip_folders[agency], "r") as zipped_file:
                zipped_file.extractall(path=unzipped_gtfs_filepath)

            # adding the unzipped gtfs folder to the path dict
            unzipped_gtfs_filepaths[agency] = unzipped_gtfs_filepath

        # set the gtfs_folders attribute and return the unzipped folder paths
        self.gtfs_folders = unzipped_gtfs_filepaths
        logging.info("Successfully unzipped the downloaded GTFS data")
        return unzipped_gtfs_filepaths

    def check_whether_data_valid(self):
        """
        Checks whether the downloaded GTFS data is still valid for the desired agencies.
        :return: still_valid_gtfs_data | dict{TransitAgency: True} containing only the agencies whose feeds
            are still valid
        """
        logging.info("Checking whether the downloaded data is still valid (and can therefore be used to create a "
                     "network dataset in ArcGIS)")

        # a dictionary that says, for each agency, whether the downloaded GTFS data is still valid
        still_valid_gtfs_data = {}

        for agency in self.agency_feed_valid_dates:
            # check if the current date is past the "valid_until" date
            valid_until_date_iso = isodate.parse_date(self.agency_feed_valid_dates[agency]["valid_until"])
            current_date_iso = datetime.now().date()

            if valid_until_date_iso < current_date_iso:
                logging.warning(f"The data for {agency.agency_name} is no longer valid (valid until {valid_until_date_iso})")
            else:
                still_valid_gtfs_data[agency] = True

        logging.info(f"Data was valid for {len(still_valid_gtfs_data)}/{len(self.agency_feed_valid_dates)} agencies")
        return still_valid_gtfs_data

    def create_public_transit_data_model(self) -> None:
        """
        Creates a public transit data model in the feature dataset for this scenario using the valid GTFS data.
        :return:  None
        """


        # only using valid gtfs data to create network dataset because otherwise gets screwy
        valid_gtfs_data = self.check_whether_data_valid() # output of method is dict of transit agencies and validities
        logging.info(f"Creating public transit data model for {self.main_reference_place.pretty_name} using "
                     f"{len(valid_gtfs_data)} agencies")

        gtfs_folders_to_use = []
        for agency in valid_gtfs_data:
            if valid_gtfs_data[agency]:
                gtfs_folders_to_use.append(self.gtfs_folders[agency])

        # create a Public Transit Data Model using arcpy (using interpolate because some minor agencies have
        # low-quality GTFS data; while this is annoying, you can just calculate arrival times with interpolation)
        arcpy.transit.GTFSToPublicTransitDataModel(in_gtfs_folders=gtfs_folders_to_use,
                                                   target_feature_dataset=self.feature_dataset.path,
                                                   interpolate="INTERPOLATE", make_lve_shapes="MAKE_LVESHAPES")


    def connect_network_to_streets(self):
        """Connects the public transit data model to the street edges feature class (still being fleshed out)."""
        edges_fc_path = os.path.join(self.feature_dataset.path, "edges_fc")
        arcpy.transit.ConnectPublicTransitDataModelToStreets(target_feature_dataset=self.feature_dataset.path,
                                                             in_streets_features=edges_fc_path)


# because switching to using r5py, need to compose everything into one graph/network
class R5Network:
    def __init__(self, transit_network: TransitNetwork, street_network: StreetNetwork):
        """
        R5Network class for creating R5 instance using street network and transit network GTFS data.
        """
        self.transit_network = transit_network
        self.street_network = street_network

    def check_input_acceptability(self):
        """
        Checks that the input street network and transit network are valid for creating R5 instance.
        :return: None
        """
        logging.info("Checking input street network and transit network to see if an R5 network can be created")
        # check that street network has necessary data
        if self.street_network.network_graph is None:
            raise Exception("Street network does not have a network graph, cannot create R5 instance")
        # check that transit network has necessary gtfs data
        if self.transit_network.gtfs_folders is None:
            raise Exception("Transit network does not have GTFS data, cannot create R5 instance")
        logging.info("Input street network and transit network are valid for creating R5 network")



class NetworkDataset:
    """
    Network dataset class for use in ArcGIS Pro.

    Attributes:
        feature_dataset (FeatureDataset): The feature dataset to create the network dataset in.
        network_type (str): The type of network dataset to create. Default is "walk_no_z".
        use_elevation (bool): Whether to use elevation in calculating bike or walk times. Default is False.
        reset (bool): Whether to reset the network dataset. Default is False.
        street_network (str): The path to the street network feature class.
        name (str): The name of the network dataset. By default will be the network type with "_nd" appended.
        path (str): The path to the network dataset (always C:\...\<feature_dataset_name>\<network_dataset_name>.nd)

    """

    def __init__(self, feature_dataset: FeatureDataset, network_type: str = "walk_no_z", use_elevation=False,
                 reset=False):
        self.feature_dataset = feature_dataset
        self.network_type = network_type
        self.use_elevation = use_elevation
        self.reset = reset
        self.street_network = self.feature_dataset.street_network
        self.name = f"{self.network_type}_nd"
        self.path = os.path.join(self.feature_dataset.path, self.name)
        self.nodes_fc_path = os.path.join(self.feature_dataset.path, "nodes_fc")
        self.edges_fc_path = os.path.join(self.feature_dataset.path, "edges_fc")
        self.has_been_created = False

        # check that network type exists before looking up its template (avoids a raw KeyError)
        if self.network_type not in network_types.network_types_attributes:
            raise Exception(f"Network type {self.network_type} does not exist "
                            f"(see network_types.py documentation for valid network types)")

        # network_types module contains a dictionary containing templates and names for diff network types
        self.template_name = network_types.network_types_attributes[self.network_type]["network_dataset_template_name"]
        self.template_path = os.path.join(Path(__file__).parent, self.template_name)

        if not os.path.exists(self.template_path):
            raise Exception(f"Template for network type {self.network_type} does not exist yet, sorry!")

        # check not trying to use elevation without elevation enabled street network
        if use_elevation and not self.street_network.elevation_enabled:
            raise Exception("Cannot use elevation because the street network is not elevation enabled")

    def create_network_dataset(self):
        """
        Creates network dataset of type (self.network type)
        :return
            Path of network dataset
        """
        process_start_time = time.perf_counter()
        logging.info(
            f"Creating {self.network_type} network dataset for {self.street_network.main_reference_place.pretty_name}")  

        # making sure network analyst checked out
        if not network_analyst_extension_checked_out:
            check_out_network_analyst_extension()

        # making sure not trying to obliviously create a network dataset if one with same name already exists
        if arcpy.Exists(self.path) and not self.reset:
            raise Exception(f"Cannot create new network dataset {self.name} because one already exists "
                            f"with that name at the desired path {self.path}")

        # main code block for the actual creating of the network dataset   
        try:
            if self.reset:
                arcpy.Delete_management(self.path)
                logging.info(
                    f"Network dataset {self.name} for {self.street_network.main_reference_place.pretty_name}"
                    f" {self.street_network.geographic_scope} already existed, so it was deleted to be recreated")

            # check whether to use elevation in network dataset, raise error if no elevation data
            if self.use_elevation:
                # first need to check that the street network is actually elevation enabled
                if not self.street_network.elevation_enabled:
                    raise Exception("The street network provided does not have elevation data")

            # check that both nodes and edges feature classes exist in dataset
            if not arcpy.Exists(self.nodes_fc_path):
                raise Exception("Nodes feature class for street network does not exist in feature dataset")
            if not arcpy.Exists(self.edges_fc_path):
                raise Exception("Edges feature class for street network does not exist in feature dataset")

            arcpy.na.CreateNetworkDatasetFromTemplate(network_dataset_template=self.template_path,
                                                      output_feature_dataset=self.feature_dataset.path)
            logging.info(f"Successfully created {self.network_type} network dataset from template")

            # mark that has been created
            self.has_been_created = True
            logging.info(f"Successfully created network dataset in "
                         f"{turn_seconds_into_minutes(time.perf_counter() - process_start_time)}")
            return self.path

        finally:
            # have to check extension back in when done running
            arcpy.CheckInExtension("Network")

    def build_network_dataset(self, rebuild=False):
        """
        Builds the network dataset that has been created
        :param rebuild: bool (True if rebuilding network_dataset desired)
        :return: Network dataset path: str
        """
        process_start_time = time.perf_counter()
        logging.info("Building network dataset")
        # checking out network analyst extension if not already
        if not network_analyst_extension_checked_out:
            check_out_network_analyst_extension()

        # making sure that the network dataset actually exists (unless explicitly rebuilding an existing one)
        if not self.has_been_created and not rebuild:
            raise Exception("The network dataset has not been created yet!")
        arcpy.na.BuildNetwork(self.path)

        # checking extensions out has weird behavior so always need these checks to see that it isn't oddly out/in
        if network_analyst_extension_checked_out:
            check_network_analyst_extension_back_in()
        process_run_time = time.perf_counter() - process_start_time
        logging.info(f"Network dataset successfully built in {turn_seconds_into_minutes(process_run_time)}")

        # save the gdb?
        self.feature_dataset.gdb.save_gdb()
        return self.path

# making sure that network analyst extension is checked back in after done running
if network_analyst_extension_checked_out:
    check_network_analyst_extension_back_in()
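The feed check in `check_whether_data_valid` boils down to a single date comparison. A minimal stand-alone sketch using only the standard library (the real method parses TransitLand's `latest_calendar_date` with `isodate`; `date.fromisoformat` here is a stand-in for the same ISO date string):

```python
from datetime import date


def feed_still_valid(latest_calendar_date: str, today: date = None) -> bool:
    """Return True while the feed's calendar extends through today.

    latest_calendar_date is an ISO date string such as "2026-03-01".
    """
    valid_until = date.fromisoformat(latest_calendar_date)
    if today is None:
        today = date.today()
    # a feed is kept through the last day its calendar covers
    return valid_until >= today


# a feed whose calendar ended yesterday is rejected; one ending today or later is kept
assert not feed_still_valid("2024-01-01", today=date(2024, 1, 2))
assert feed_still_valid("2024-01-02", today=date(2024, 1, 2))
```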


"""
This module exists to take all the various tools already created in the create_network_dataset_oop.py file and
organize them into a class that can be used to create network datasets from OSM street network data. Also used
to get transit data for a place and use that to create a transit network dataset.
"""
from Geoenrichment import travel_mode

# yes I know it's bad practice to use import * but in this case, I've made sure that it won't cause any problems
# (can safely map namespace of create_network_dataset_oop to this module because were developed in tandem)
from create_network_dataset_oop import *
from gtfs_tools import *
from general_tools import *


# standard library modules
import logging
import os



# logging setup
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
# don't want to display debugging messages when running this script
logging.getLogger("requests").setLevel(logging.INFO)
logging.getLogger("urllib3").setLevel(logging.INFO)
logging.getLogger("general_tools").setLevel(logging.INFO)

# set up arcpy environment
arcpy_config.set_up_arcpy_env()

# if you want to add modules (non std lib), MUST (!!!!!!!) come after this block
# (ensures extensions work in cloned environment)
import arcpy
import arcpy_init
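Throughout both modules, a place's artifacts (cache folder, geodatabase, feature dataset) are keyed by a snake-case name plus the geographic scope, and network types get an elevation suffix. A rough sketch of those conventions — the exact normalization `create_snake_name` applies is an assumption here, and the suffix logic mirrors `create_network_dataset_from_place`:

```python
import re


def sketch_snake_name(place_name: str) -> str:
    # assumed normalization: lowercase, collapse punctuation/whitespace into underscores
    return re.sub(r"[^a-z0-9]+", "_", place_name.lower()).strip("_")


def network_type_with_elevation(network_type: str, use_elevation: bool) -> str:
    # mirrors create_network_dataset_from_place: "_z" with elevation, "_no_z" without
    return f"{network_type}{'_z' if use_elevation else '_no_z'}"


snake = sketch_snake_name("San Francisco, CA, USA")
assert snake == "san_francisco_ca_usa"
assert f"{snake}_csa" == "san_francisco_ca_usa_csa"  # shape of snake_name_with_scope
assert network_type_with_elevation("walk", False) == "walk_no_z"
assert network_type_with_elevation("bike", True) == "bike_z"
```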


class Place:
    """
    Class that represents a place and contains methods to create network datasets for it

    Parameters:
        arcgis_project (ArcProject): ArcGIS project object
        place_name (str): Name of place
        bound_box (tuple): Bounding box of place | (longitude_min, latitude_min, longitude_max, latitude_max)
        geographic_scope (str): Geographic scope of place | {"place_only", "county", "msa", "csa", "specified"}
            specified means a list of places to include in the network dataset (must be well-formed place names that
            OSM will recognize).
        scenario_id (str): Scenario ID | if you would like to keep track of the same GeoDatabase/Network Dataset
            across multiple runs, then you should use the same scenario_id for each run! By default, will be
            a random (partially modified) base64 value.
        specified_places_to_include (list[str]): List of places to include in the network dataset if using
            "specified" geographic scope

    Methods:
        use_scope_for_place(geographic_scope="place_only") -> None: Sets the geographic scope for the place
        create_network_dataset_from_place(network_type="walk", use_elevation=False, full_reset=False,
                                          elevation_reset=False) -> None: Creates network dataset for place of specified
                                          type
        get_agencies_for_place() -> None: Gets agencies that serve the place (for transit network datasets)
        generate_isochrones_for_place() -> None: Generates isochrones for the place

    """
    def __init__(self, arcgis_project: ArcProject, place_name: str | None = None,
                 bound_box: tuple[str | float, str | float, str | float, str | float] | None = None,
                 geographic_scope: str = "place_only",
                 scenario_id: str = None, specified_places_to_include: list[str] = None):


        if place_name is None and bound_box is None:
            raise ValueError("Must provide either a place or bounding box")
        # parameters
        self.arcgis_project = arcgis_project
        self.place_name = place_name
        self.bound_box = bound_box
        self.geographic_scope = geographic_scope
        self.scenario_id = scenario_id
        self.specified_places_to_include = specified_places_to_include

        if self.scenario_id is None:
            self.scenario_id = generate_random_base64_value(1000000000)

        if self.bound_box:
            self.geographic_scope = "bbox"

        # important attributes not passed
        self.main_reference_place = ReferencePlace(place_name=self.place_name, bound_box=self.bound_box)
        self.snake_name = create_snake_name(self.main_reference_place)
        self.snake_name_with_scope = f"{self.snake_name}_{self.geographic_scope}"
        # cache folder for place
        self.cache_folder = CacheFolder(self.snake_name_with_scope)

        # list of reference places for creating networks
        self.reference_place_list = []

        # the corresponding gdb path:
        self.gdb_path = os.path.join(self.arcgis_project.project_dir_path, f"{self.snake_name_with_scope}.gdb")
        self.network_dataset_for_place = None

        # attributes to check whether certain things exist
        self.street_network_data_exists = False
        self.elevation_data_exists = False
        self.gdb_exists = False
        self.feature_dataset_exists = False
        self.streets_feature_classes_exists = False


        # set up cache folder when initializing instance if one doesn't already exist
        if not self.cache_folder.check_if_cache_folder_exists():
            self.cache_folder.set_up_cache_folder()

        # agencies that serve the place (will be set later by the get_agencies_for_place method)
        self.agencies_that_serve_place = None

        # always use scope for place
        self.use_scope_for_place(geographic_scope=geographic_scope)

    def use_scope_for_place(self, geographic_scope="place_only"):
        """
        Sets the geographic scope for the place and builds the corresponding list of ReferencePlaces.
        :param geographic_scope: str | {"place_only", "county", "msa", "csa", "specified"}
        :return: reference_place_list | list[ReferencePlace]
        """
        # can set the geographic scope to something new
        self.geographic_scope = geographic_scope

        # if using bounding box
        if self.bound_box:
            # the list of ReferencePlaces to use in creating the network (just the bounding box)
            reference_place_list = [self.main_reference_place]

        # if using city limits and reference place
        elif self.place_name and self.geographic_scope == "place_only":
            # the list of ReferencePlaces to use in creating the network
            reference_place_list = [self.main_reference_place]

        # if using specified places
        elif self.place_name and self.geographic_scope == "specified":
            logging.info(f"Using the provided list of other places to include {self.specified_places_to_include}")
            # check that if specified is selected that specified_places_to_include is not None
            if self.specified_places_to_include is None:
                raise ValueError("Geographic scope is 'specified', but no specified places to include were provided")
            # for each place name in specified_places_to_include, create a ReferencePlace object and add it to the list
            reference_place_list = [ReferencePlace(place_name=place_name) for
                                    place_name in self.specified_places_to_include]
            # also need to add the original place to the list (in first position so that the network is seen as being
            # 'centered' on the main place)
            reference_place_list.insert(0, self.main_reference_place)

        # if using county, msa, or csa
        elif self.place_name and self.geographic_scope in ("county", "msa", "csa"):
            reference_place_list = get_reference_places_for_scope(self.place_name, self.geographic_scope)

        else:
            raise ValueError("Geographic scope not recognized, please use "
                            "'place_only', 'county', 'msa', 'csa', or 'specified'")

        # now set the list
        self.reference_place_list = reference_place_list
        return reference_place_list

    def create_network_dataset_from_place(self, network_type="walk", use_elevation=False, full_reset=False,
                                          elevation_reset=False) -> None:
        # still need to figure out what to do with a bounding box rather than a place
        """
        Creates a network dataset for the specified place using OSM street network data.
        :param network_type: str | {"walk", "drive", "bike", "transit"}
        :param use_elevation: bool | whether to use elevation data for the streets in the network (in order to
            calculate grade-adjusted walk/bike times). If True, will query the USGS EPQS API to get elevations,
            but be warned this can be very slow (even though I made it as fast as I could with asyncio)
        :param full_reset: bool | WARNING!!: If True, will nuke the entire geodatabase (if it exists) for the place.
            Use with caution!
        :param elevation_reset: bool | If True, will reset the elevation data for the network dataset (if it exists).
            To be avoided because requires querying the EPQS API again, which can be very slow. However, if you are
            seeing that many edges in the network dataset have no elevation data (will see a logging message in the
            console), you may want to try this.
        :return: None | You will have to manually open ArcGIS Pro, and open the catalog to add your network dataset to
            your map
        """
        logging.info(f"Creating network dataset for {self.main_reference_place.pretty_name}")

        if not use_elevation:
            # add no_z to network_type if not using elevation to work with network_types_attributes dict (see module)
            network_type += "_no_z"
        else:
            # add _z to network_type if using elevation to work with network_types_attributes dict (see module)
            network_type += "_z"

        # create StreetNetwork object for this place
        street_network_for_place = StreetNetwork(self.geographic_scope, self.reference_place_list,
                                                 network_type=network_type)

        # prepare geodatabase and create feature dataset
        geodatabase_for_place = GeoDatabase(self.arcgis_project, street_network=street_network_for_place)
        geodatabase_for_place.set_up_gdb(reset=full_reset)
        feature_dataset_for_place = FeatureDataset(geodatabase_for_place, street_network_for_place,
                                                   scenario_id=self.scenario_id, network_type=network_type,
                                                   reset=full_reset)
        feature_dataset_for_place.create_feature_dataset()

        # elevation handling for street network
        if not use_elevation:
            street_network_for_place.get_street_network_graph_from_osmnx(reset=full_reset)
            # take street network map to feature classes
            street_feature_classes_for_place = StreetFeatureClasses(feature_dataset_for_place, street_network_for_place,
                                                                    use_elevation=False,
                                                                    reset=full_reset)
        else:
            # using elevation
            street_network_for_place.get_street_network_graph_from_osmnx(reset=full_reset)

            # create new ElevationMapper object for this street network and then getting elevations and nodes
            elevation_mapper_for_place = ElevationMapper(street_network=street_network_for_place, reset=elevation_reset)
            elevation_mapper_for_place.add_elevation_data_to_nodes()
            elevation_mapper_for_place.add_grades_to_edges()

            # worth using reset here just to be safe because still need to debug
            street_feature_classes_for_place = StreetFeatureClasses(feature_dataset_for_place, street_network_for_place,
                                                                    use_elevation=True,
                                                                    reset=full_reset)

        # testing switch over to gdf -> shp -> fc method
        street_feature_classes_for_place.map_street_network_to_feature_classes()

        # if using transit then need to create a Transit Network and get it setup
        if network_type.startswith("transit"):
            transit_network_for_place = TransitNetwork(geographic_scope=self.geographic_scope,
                                                       feature_dataset=feature_dataset_for_place,
                                                       reference_place_list=[self.main_reference_place])
            # download gtfs data for place
            transit_network_for_place.unzip_gtfs_data()
            # create public transit model from downloaded GTFS data
            transit_network_for_place.create_public_transit_data_model()
            transit_network_for_place.connect_network_to_streets()

        # create and build network dataset from streets feature classes
        network_dataset_for_place = NetworkDataset(feature_dataset_for_place, network_type=network_type,
                                                   reset=full_reset)
        # set self.network_dataset_for_place
        self.network_dataset_for_place = network_dataset_for_place

        # create and build network dataset
        network_dataset_for_place.create_network_dataset()
        network_dataset_for_place.build_network_dataset()
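
The `_z`/`_no_z` suffixing above drives the feature dataset and network dataset names that `generate_isochrone` and the OD matrix methods later check for. A minimal sketch of that naming convention (`dataset_names` is a hypothetical helper, not part of the module):

```python
def dataset_names(network_type: str, use_elevation: bool, scenario_id: str) -> tuple[str, str]:
    # mirror the convention used above: "_z" when elevation is enabled, "_no_z" otherwise
    tag = "z" if use_elevation else "no_z"
    feature_dataset_name = f"{network_type}_{tag}_{scenario_id}_fd"
    network_dataset_name = f"{network_type}_{tag}_nd"
    return feature_dataset_name, network_dataset_name

# e.g. a non-elevation transit network for a (made-up) scenario "base"
print(dataset_names("transit", False, "base"))  # ('transit_no_z_base_fd', 'transit_no_z_nd')
```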


    def get_agencies_for_place(self):

        """ Takes a place name of format 'city, state, country'
        and returns a list of transit agencies that serve the place

        :return: list[TransitAgency] | list of transit agencies (TransitAgencyObjects) that serve the place
        """
        # list of TransitAgency objects to be returned
        agencies_for_place = []

        # transit land's API only requires the city name (although this seems stupid)
        place_short_name = self.main_reference_place.place_name.split(",")[0]                                                    # fix so can use bounding box too
        transit_land_response = requests.get(f"https://transit.land/api/v2/rest/agencies?apikey={transit_land_api_key}"
                                             f"&city_name={place_short_name}")
        transit_land_response.raise_for_status()
        transit_land_data = transit_land_response.json()

        # going through the agency dicts provided by the api and using them as kwargs for TransitAgency object
        for agency_data in transit_land_data["agencies"]:
            temp_agency = TransitAgency(**agency_data)
            agencies_for_place.append(temp_agency)

        # now can set self.agencies_that_serve_place
        self.agencies_that_serve_place = agencies_for_place

        return self.agencies_that_serve_place
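
The loop above unpacks each agency dict from the TransitLand response directly into a `TransitAgency` via `**kwargs`. A self-contained sketch of that pattern, using a hypothetical stand-in class and a made-up, truncated response:

```python
# hypothetical minimal stand-in for the TransitAgency class; the real class
# accepts the full set of fields returned by the TransitLand agencies endpoint
class TransitAgencyStub:
    def __init__(self, **kwargs):
        self.agency_name = kwargs.get("agency_name")
        self.onestop_id = kwargs.get("onestop_id")

# shape of a (truncated, made-up) TransitLand /agencies response
sample_response = {"agencies": [{"agency_name": "SFMTA", "onestop_id": "o-9q8y-sfmta"}]}
agencies = [TransitAgencyStub(**agency_data) for agency_data in sample_response["agencies"]]
print(agencies[0].agency_name)  # SFMTA
```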

    def generate_isochrone(self, isochrone_name:str=None, addresses:list[str] | str=None,
                           points:list[tuple[float, float]] | tuple[float, float]=None, network_type:str="walk",
                           use_elevation:bool=False, cutoffs_minutes:list[float]=None,
                           travel_direction:str="TO_FACILITIES", day_of_week:str="Today", date:str=None,
                           analysis_time:str="9:00 AM", open_on_complete=True, reset_network_dataset:bool=False):
        """
        This method creates an isochrone for a given set of addresses or points using the specified network type and
        cutoffs. This will mostly just be used to demonstrate the functionality of the network dataset. Can pass either
        well-formed address(es) (see addresses parameter) or tuple(s) of (latitude, longitude) (see points parameter).
        Can specify what kind of network to use; by default, will use the non-elevation-enabled walk network dataset.
        If one already exists for the place, it will be used. If not, a new one matching the desired analysis
        will be generated. Can choose to specify cutoffs in minutes (see cutoffs_minutes parameter) or
        let the method use the default cutoffs (10, 20, 30, 40, 50, and 60 minute cutoffs).

        :param isochrone_name: the name for the isochrone service area layer
        :param addresses: the address(es) to use for the isochrone analysis
        :param points: the points to use for the isochrone
        :param network_type: the network type to be built or used for the analysis
        :param use_elevation: whether the network should use elevation or not
        :param cutoffs_minutes: the isochrone cutoffs to be used (in minutes)
        :param travel_direction: whether to do analysis_time to or from facilities {"FROM_FACILITIES", "TO_FACILITIES"}
        :param day_of_week: the day of the week to use for the analysis {"Today", "Monday", "Tuesday", "Wednesday",
                                                                        "Thursday", "Friday", "Saturday", "Sunday"}
        :param date: date to use for the analysis (e.g. "11/15/2025")
        :param analysis_time: time to use for the analysis (e.g. "9:00 AM")
        :param open_on_complete: whether to open the project automatically after the analysis is complete
        :param reset_network_dataset: WARNING! whether to reset the network dataset before running the analysis
            (doing this will wreck any other isochrones that used the old, conflicting network dataset if there was one)

        :return: None
        """
        logging.info("Generating isochrone")

        # if no name for service area layer provided, generate random base64 value
        if isochrone_name is None:
            isochrone_name = generate_random_base64_value(1000000000)

        # check that travel_direction is valid
        if travel_direction not in ["FROM_FACILITIES", "TO_FACILITIES"]:
            raise ValueError("travel_direction must be one of the following: FROM_FACILITIES, TO_FACILITIES")

        # set default cutoffs to use for service area analysis
        if cutoffs_minutes is None:
            # in case network desired is transit
            if network_type == "transit":
                cutoffs_minutes = [10, 20, 30, 40, 50, 60]
            # in case network desired is walk
            elif network_type == "walk":
                cutoffs_minutes = [5, 10, 15, 20, 25, 30]
            # in case a different type of network is desired
            else:
                cutoffs_minutes = [10, 20, 30, 40, 50, 60]

        # some exception handling to make sure input parameters are valid
        # check that either address(es) or point(s) provided
        if addresses is None and points is None:
            raise ValueError("must provide either address(es) or point(s)")

        # checking that correct type(s) for addresses (valid address checking handled by get_coordinates_from_address)
        if not isinstance(addresses, list) and not isinstance(addresses, str) and addresses is not None:
            raise ValueError("addresses must be a list of strings or a single address string")

        # checking that points are valid
        if not isinstance(points, list) and not isinstance(points, tuple) and points is not None:
            raise ValueError("points must be a list of tuples or a single tuple of (latitude, longitude)")
        # in case list of points provided
        elif isinstance(points, list):
            # check that each point is a tuple.
            for point in points:
                check_if_valid_coordinate_point(point)
        # in case single point provided
        elif isinstance(points, tuple):
            check_if_valid_coordinate_point(points)

        # check desired network type, if provided, is valid
        if not isinstance(network_type, str):
            raise ValueError("network_type must be a string")
        # check whether desired network type can be used
        if network_type not in ["transit", "walk"]:
            raise ValueError("network_type must be one of the following: transit, walk")

        # check that use_elevation is either True or False
        if not isinstance(use_elevation, bool):
            raise ValueError("use_elevation must be a boolean")

        # check that cutoffs_minutes, if provided, is a list of integers
        if not isinstance(cutoffs_minutes, list) and cutoffs_minutes is not None:
            raise ValueError("cutoffs_minutes must be a list of numbers or None")

        # now check whether network dataset of corresponding type already exists
        if use_elevation:
            using_elevation_tag = "z"
        else:
            using_elevation_tag = "no_z"
        feature_dataset_would_be_named = f"{network_type}_{using_elevation_tag}_{self.scenario_id}_fd"
        network_dataset_path = os.path.join(self.gdb_path, feature_dataset_would_be_named,
                                            f"{network_type}_{using_elevation_tag}_nd")

        # the next part involves preparing the input points and addresses for the 'facilities' sublayer for the analysis
        # points and addresses (which have been geocoded to points) will go into the input feature class, which will
        # then be used as the input for the facilities sublayer for the analysis
        logging.info(f"Processing inputted points/addresses: {points} {addresses}")
        # the list of input points that will go into the input feature class (includes points AND addresses as points)
        input_point_coordinates = []

        # add points to list of point coordinates
        if points is not None:
            if isinstance(points, list):
                for point in points:
                    input_point_coordinates.append(point)
            elif isinstance(points, tuple):
                input_point_coordinates.append(points)
            else:
                raise ValueError("points must be a list of tuples or a single tuple of (latitude, longitude)")

        # turn address into points (can pass either list of addresses or single address so no need to check type)
        if isinstance(addresses, list):
            address_coordinates = get_coordinates_from_address(addresses)
            for address_coordinate in address_coordinates:
                # add to the list of input
                input_point_coordinates.append(address_coordinate)
        elif isinstance(addresses, str):
            address_coordinates = get_coordinates_from_address(addresses)
            input_point_coordinates.append(address_coordinates)

        # now actually dealing with network dataset for the analysis
        logging.info("Checking whether or not a network dataset required for the analysis already exists")

        # check that 1) network dataset exists and 2) reset_network_dataset is False (no ND reset desired)
        if arcpy.Exists(network_dataset_path) and not reset_network_dataset:
            logging.info(f"Using existing network dataset {network_type}_{using_elevation_tag}")
        # in case no matching network dataset could be found or reset desired
        else:
            logging.info(f"No network dataset found. Creating {network_type}_{using_elevation_tag}")
            # passing the reset_network_dataset parameter to the create_network_dataset_from_place method just in case
            self.create_network_dataset_from_place(network_type=network_type, use_elevation=use_elevation,
                                                   full_reset=reset_network_dataset,
                                                   elevation_reset=reset_network_dataset)

        # create points feature classes
        input_fc_path = add_points_arcgis(feature_dataset_path=os.path.join(self.gdb_path,
                                                                            feature_dataset_would_be_named),
                                          fc_name=f"{self.snake_name}_test_points",
                                          point_coordinates=input_point_coordinates)

        # set travel mode
        _travel_mode = network_types.network_types_attributes[f"{network_type}_{using_elevation_tag}"][
                                                                                                "isochrone_travel_mode"]

        # set analysis_time to analyze
        if date is None:
            # use ArcGIS magic dates for day of week
            day_map = {
                "Today": "12/30/1899",
                "Sunday": "12/31/1899",
                "Monday": "1/1/1900",
                "Tuesday": "1/2/1900",
                "Wednesday": "1/3/1900",
                "Thursday": "1/4/1900",
                "Friday": "1/5/1900",
                "Saturday": "1/6/1900"
            }
            time_to_analyze = f"{day_map[day_of_week]} {analysis_time}"
        else:
            time_to_analyze = f"{date} {analysis_time}"

        # now can create service area layer
        check_out_network_analyst_extension()
        logging.info("Creating service area analysis layer")
        result_object = arcpy.na.MakeServiceAreaAnalysisLayer(network_data_source=network_dataset_path,
                                                              layer_name=isochrone_name, travel_mode=_travel_mode,
                                                              travel_direction=travel_direction, # gives warning because expects literal but fine
                                                              time_of_day=time_to_analyze, cutoffs=cutoffs_minutes,
                                                              geometry_at_overlaps="DISSOLVE")
        # get layer object out
        layer_object = result_object.getOutput(0)

        # get facilities and polygons names
        sublayer_names = arcpy.na.GetNAClassNames(layer_object)
        facilities_layer_name = sublayer_names["Facilities"]
        polygons_layer_name = sublayer_names["SAPolygons"]

        # add facilities to layer
        logging.info("Adding inputted points/addresses to service area analysis layer")
        arcpy.na.AddLocations(layer_object, facilities_layer_name, in_table=input_fc_path)

        # solve the layer
        logging.info("Solving service area analysis layer")
        arcpy.na.Solve(layer_object)

        # save the service area as layer file ### NEED TO FIX LAYER NAME BECAUSE RIGHT NOW IT IS FULL PATH??
        layers_dir = os.path.join(self.arcgis_project.project_dir_path, "layers")
        os.makedirs(layers_dir, exist_ok=True)
        output_layer_file = os.path.join(layers_dir, f"{isochrone_name}.lyrx")
        arcpy.SaveToLayerFile_management(in_layer=layer_object, out_layer=output_layer_file, is_relative_path="ABSOLUTE")

        # setting up the maps for the project
        maps = self.arcgis_project.arcObject.listMaps("Map")
        if not maps:
            self.arcgis_project.arcObject.createMap("Map")
            self.arcgis_project.arcObject.save()
            maps = self.arcgis_project.arcObject.listMaps("Map")

        # add the layer to the map
        aprxMap = maps[0]
        # now make LayerFile Object and add to map
        isochrone_layer_file = arcpy.mp.LayerFile(output_layer_file)
        aprxMap.addDataFromPath(isochrone_layer_file)

        # color to make all pretty like
        sa_polygons = aprxMap.listLayers(polygons_layer_name)[0]
        polygon_symbology = sa_polygons.symbology
        polygon_symbology.updateRenderer('GraduatedColorsRenderer')
        polygon_symbology.renderer.classificationField = "ToBreak"
        polygon_symbology.renderer.breakCount = len(cutoffs_minutes)
        polygon_symbology.renderer.classificationMethod = "NaturalBreaks"
        polygon_symbology.renderer.colorRamp = self.arcgis_project.arcObject.listColorRamps("Inferno")[0]

        # have to manually reverse colors because stupid
        breaks = polygon_symbology.renderer.classBreaks
        colors = [brk.symbol.color for brk in breaks]
        for i, brk in enumerate(breaks):
            brk.symbol.color = colors[len(colors) - 1 - i]

        # now can actually set the polygon's symbology
        sa_polygons.symbology = polygon_symbology

        # saving hopefully makes changes persist
        self.arcgis_project.arcObject.save()

        # housekeeping
        logging.info("Isochrones generated successfully")

        if open_on_complete:
            logging.info("Now opening ArcGIS Pro")
            # need to clear the arcObject to ensure it's not locked so can open automatically because lazy
            del self.arcgis_project.arcObject
            # sleepytime! (I'm so incredibly sick of debugging this, and I'm getting a bit loopy)
            from time import sleep as sleepytime
            sleepytime(1)
            # now can open project automatically?
            os.startfile(self.arcgis_project.path)

    # will add support for route as well
    def generate_route(self, route_name:str=None, addresses:list[str] | str=None,
                           points:list[tuple[float, float]] | tuple[float, float]=None, network_type:str="transit",
                           use_elevation:bool=False, day_of_week:str="Today", date:str=None, analysis_time:str="9:00 AM", 
                       open_on_complete=True):
        pass

    # important!! method that generates origin-destination cost matrix
    @time_function
    def generate_od_cost_matrix(self, matrix_name:str=None, origins_fc_name:str=None,
                                destinations_fc_name:str=None, network_type:str="transit",
                                use_elevation:bool=False, day_of_week:str="Today", analysis_date:str=None,
                                analysis_time:str="9:00 AM", open_on_complete:bool=False,
                                reset_network_dataset:bool=False):
        """
        :param matrix_name: (str) the name for the od-cost matrix that will be generated
        :param origins_fc_name: (str) the name of the feature class containing the points to be used as origins in
            generating the matrix. Note that name should be relative path and feature class must be points!
        :param destinations_fc_name: (str) the name of the feature class containing the points to be used as destinations
            in generating the matrix. Note that name should be relative path and feature class must be points!
        :param network_type: (str) the type of network to be used for the matrix
        :param use_elevation: (bool) if True, will try to use an elevation enabled network dataset
        :param day_of_week: (str) the day of the week to use for the analysis {"Today", "Monday", "Tuesday",
            "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
        :param analysis_date: (str) date to use for the analysis (e.g. "11/15/2025")
        :param analysis_time: (str) time to use for the analysis (e.g. "9:00 AM")
        :param open_on_complete: (bool) whether to open the project automatically after the analysis is complete
        :param reset_network_dataset: (bool) WARNING! whether to reset the network dataset before running the analysis
        :return: None
        """
        # set the environment/workspace
        arcpy.env.workspace = self.gdb_path
        arcpy.env.overwriteOutput = True

        # first, lots of type checking of course
        if not isinstance(matrix_name, str) and matrix_name is not None:
            raise ValueError("Matrix name must be a string or None")

        if not isinstance(origins_fc_name, str) and origins_fc_name is not None:
            raise ValueError("The name for the feature class to be used for the origins must be a string")

        if not isinstance(destinations_fc_name, str) and destinations_fc_name is not None:
            raise ValueError("The name for the feature class to be used for the destinations must be a string")

        # also check that necessary parameters passed
        if origins_fc_name is None or destinations_fc_name is None:
            raise ValueError("Must provide the names of both the origins and destinations feature classes")

        # set matrix name if none provided
        if matrix_name is None:
            matrix_name = generate_random_base64_value(100000000)

        # now check whether network dataset of corresponding type already exists
        if use_elevation:
            using_elevation_tag = "z"
        else:
            using_elevation_tag = "no_z"
        feature_dataset_would_be_named = f"{network_type}_{using_elevation_tag}_{self.scenario_id}_fd"
        network_dataset_path = os.path.join(self.gdb_path, feature_dataset_would_be_named,
                                            f"{network_type}_{using_elevation_tag}_nd")

        # check that 1) network dataset exists and 2) reset_network_dataset is False (no ND reset desired)
        if arcpy.Exists(network_dataset_path) and not reset_network_dataset:
            logging.info(f"Using existing network dataset {network_type}_{using_elevation_tag}")
        # in case no matching network dataset could be found or reset desired
        else:
            logging.info(f"No network dataset found. Creating {network_type}_{using_elevation_tag}")
            # passing the reset_network_dataset parameter to the create_network_dataset_from_place method just in case
            self.create_network_dataset_from_place(network_type=network_type, use_elevation=use_elevation,
                                                   full_reset=reset_network_dataset,
                                                   elevation_reset=reset_network_dataset)


        # set travel mode for the analysis
        _analysis_travel_mode = network_types.network_types_attributes[f"{network_type}_{using_elevation_tag}"][
            "isochrone_travel_mode"]

        # set analysis_time to analyze
        if analysis_date is None:
            # use ArcGIS magic dates for day of week
            day_map = {
                "Today": "12/30/1899",
                "Sunday": "12/31/1899",
                "Monday": "1/1/1900",
                "Tuesday": "1/2/1900",
                "Wednesday": "1/3/1900",
                "Thursday": "1/4/1900",
                "Friday": "1/5/1900",
                "Saturday": "1/6/1900"
            }
            time_to_analyze = f"{day_map[day_of_week]} {analysis_time}"
        else:
            time_to_analyze = f"{analysis_date} {analysis_time}"

        # generate the actual cost matrix layer
        check_out_network_analyst_extension()
        logging.info("Creating OD Cost Matrix Analysis Layer")
        result_object = arcpy.na.MakeODCostMatrixAnalysisLayer(network_data_source=network_dataset_path,
                                                               layer_name=matrix_name, travel_mode=_analysis_travel_mode,
                                                               cutoff=None,
                                                               time_of_day=time_to_analyze,
                                                               time_zone="LOCAL_TIME_AT_LOCATIONS",
                                                               line_shape="NO_LINES")

        # get layer object out
        layer_object = result_object.getOutput(0)

        # get the names of the sublayers
        sublayer_names = arcpy.na.GetNAClassNames(layer_object)
        origins_layer_name = sublayer_names["Origins"]
        destination_layer_name = sublayer_names["Destinations"]

        # now need to add locations for origins and destinations
        logging.info("Adding origins to the matrix layer")
        arcpy.na.AddLocations(in_network_analysis_layer=layer_object, sub_layer=origins_layer_name,
                              in_table=origins_fc_name, search_tolerance="10000 Meters")
        logging.info("Adding destinations to the matrix layer")
        arcpy.na.AddLocations(in_network_analysis_layer=layer_object, sub_layer=destination_layer_name,
                              in_table=destinations_fc_name, search_tolerance="10000 Meters")

        # now can solve the layer
        logging.info("Solving the OD Cost Matrix")
        arcpy.na.Solve(layer_object)

        # the path where will save a copy of the layer file
        output_layer_file = os.path.join(self.cache_folder.path, "od_cost_matrices", f"{network_type}",
                                         f"{matrix_name}.lyrx")
        # make directories in case don't exist yet
        os.makedirs(os.path.dirname(output_layer_file), exist_ok=True)
        # now can save
        logging.info(f"Saving a copy of the OD cost matrix to {output_layer_file}")
        layer_object.saveACopy(output_layer_file)

        logging.info(f"Successfully solved matrix {matrix_name} for {self.main_reference_place.pretty_name} "
                     f"{self.geographic_scope} {analysis_time} {day_of_week}")

        # setting up the maps for the project
        logging.info("Adding OD Cost Matrix layer to map")
        maps = self.arcgis_project.arcObject.listMaps("Map")
        if not maps:
            self.arcgis_project.arcObject.createMap("Map")
            self.arcgis_project.arcObject.save()
            maps = self.arcgis_project.arcObject.listMaps("Map")

        # add the layer to the map
        aprxMap = maps[0]
        # now make LayerFile Object and add to map
        od_matrix_layer_file = arcpy.mp.LayerFile(output_layer_file)
        aprxMap.addDataFromPath(od_matrix_layer_file)

    @time_function
    def generate_large_od_cost_matrix(self, matrix_name:str=None, origins_fc_path:str=None,
                                      destinations_fc_path:str=None, network_type:str= "transit",
                                      use_elevation:bool=False, day_of_week:str="Today", analysis_date:str=None,
                                      analysis_time:str="9:00 AM", open_on_complete:bool=False,
                                      reset_network_dataset:bool=False, number_of_threads:int=None,
                                      max_inputs_per_chunk:int=None):
        """
        This method is intended for generating OD cost matrices for extremely large networks/datasets. Compared with
        the standard generate_od_cost_matrix method it should be faster as it uses multiprocessing, breaking
        the problem into smaller chunks and then calculating each at the same time. Uses the SolveLargeODCostMatrix
        tool published by Esri. See more:
        https://github.com/Esri/large-network-analysis-tools?tab=readme-ov-file#Solve-Large-OD-Cost-Matrix-tool.

        :param matrix_name: (str) the name for the od-cost matrix that will be generated
        :param origins_fc_path: (str) the full path to the feature class containing the points to be used as origins in
            generating the matrix. Feature class must be points!
        :param destinations_fc_path: (str) the full path to the feature class containing the points to be used as
            destinations in generating the matrix. Feature class must be points!
        :param network_type: (str) the type of network to be used for the matrix {"transit", "walk", "drive", "bike"}
        :param use_elevation: (bool) if True, will try to use an elevation enabled network dataset
        :param day_of_week: (str) {"Today", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
        :param analysis_date: (str) date to use for the analysis (e.g. "11/15/2025")
        :param analysis_time: (str) time to use for the analysis (e.g. "9:00 AM")
        :param open_on_complete: (bool) whether to open the project automatically after the analysis is complete
        :param reset_network_dataset: (bool) WARNING! whether to reset the network dataset before running the analysis
        :param number_of_threads: (int) the number of parallel processes to use for the analysis,
            defaults to number of cores - 1 for safety
        :param max_inputs_per_chunk: (int) the maximum number of inputs to be used per chunk
        :return:
        """
        logging.info("Generating OD Cost Matrix")

        # setup environment
        arcpy.env.workspace = self.gdb_path
        arcpy.env.overwriteOutput = True

        # add type checking here

        # need to import the toolbox published by Esri
        logging.info("Importing LargeNetworkAnalysisTools toolbox to arcpy")
        path_to_toolbox = os.path.join(Path(__file__).parent, "large_network_analysis_tools",
                                       "LargeNetworkAnalysisTools.pyt")
        arcpy.ImportToolbox(path_to_toolbox)
        logging.info("Successfully imported toolbox")

        # now the rest of setup necessary before can use the tool
        logging.info("Preparing necessary data to generate the matrix")

        # set matrix name if none provided
        if matrix_name is None:
            # replace slashes, colons, and spaces so the name is a valid layer/file name
            matrix_name = (f"{analysis_date}_{day_of_week}_{analysis_time}_od_matrix"
                           .replace("/", "-").replace(":", "-").replace(" ", "_"))

        # now check whether network dataset of corresponding type already exists
        using_elevation_tag = "z" if use_elevation else "no_z"
        feature_dataset_would_be_named = f"{network_type}_{using_elevation_tag}_{self.scenario_id}_fd"
        network_dataset_path = os.path.join(self.gdb_path, feature_dataset_would_be_named,
                                            f"{network_type}_{using_elevation_tag}_nd")

        # check that 1) network dataset exists and 2) reset_network_dataset is False (no ND reset desired)
        if arcpy.Exists(network_dataset_path) and not reset_network_dataset:
            logging.info(f"Using existing network dataset {network_type}_{using_elevation_tag}")
        # in case no matching network dataset could be found or reset desired
        else:
            # want to check to make sure before creating new network dataset just in case (because assuming nd is big)
            are_you_sure_this_is_right_q = input(f"No network dataset {network_type}_{using_elevation_tag} was found "
                                                 f"for {self.main_reference_place.pretty_name} {self.geographic_scope}. "
                                                 f"If there is one, it will be reset. "
                                                 f"Enter 'Yes'/'Y' to continue or 'No'/'N' to abort: ")
            if are_you_sure_this_is_right_q in ("Yes", "Y"):
                logging.info(f"Creating {network_type}_{using_elevation_tag}")
                # passing the reset_network_dataset parameter to the create_network_dataset_from_place method just in case
                self.create_network_dataset_from_place(network_type=network_type, use_elevation=use_elevation,
                                                       full_reset=reset_network_dataset,
                                                       elevation_reset=reset_network_dataset)

        # set travel mode for the analysis
        _analysis_travel_mode = network_types.network_types_attributes[f"{network_type}_{using_elevation_tag}"][
            "isochrone_travel_mode"]

        # set analysis_time to analyze
        if analysis_date is None:
            # use ArcGIS magic dates for day of week
            day_map = {
                "Today": "12/30/1899",
                "Sunday": "12/31/1899",
                "Monday": "1/1/1900",
                "Tuesday": "1/2/1900",
                "Wednesday": "1/3/1900",
                "Thursday": "1/4/1900",
                "Friday": "1/5/1900",
                "Saturday": "1/6/1900"
            }
            time_to_analyze = f"{day_map[day_of_week]} {analysis_time}"
        else:
            time_to_analyze = f"{analysis_date} {analysis_time}"

        # set the number of inputs per chunk and max processes based on number of cores and number of features
        num_cores = max(1, (os.cpu_count() or 2) - 1)  # cpu_count includes hyperthreading; minus 1 to be safe

        # count the features for origins
        origins_count_result_object = arcpy.GetCount_management(origins_fc_path)
        origins_feature_count = int(origins_count_result_object.getOutput(0))

        # count the features for destinations
        destinations_count_result_object = arcpy.GetCount_management(destinations_fc_path)
        destinations_feature_count = int(destinations_count_result_object.getOutput(0))

        # chunk size: honor the max_inputs_per_chunk parameter, falling back to 1000
        num_feature_per_chunk = max_inputs_per_chunk if max_inputs_per_chunk else 1000

        # IMPORTANT: outputting to CSVs because it's faster; this is the path where the CSVs will be written
        csv_output_folder_path = os.path.join(self.cache_folder.path, "large_od_cost_matrix_csvs", network_type,
                                              matrix_name)
        # make sure the necessary parent dirs exist (NOTE: the tool errors if the output folder itself already exists,
        # so create only the parent dirs)
        csv_output_folder_parent_dir = os.path.dirname(csv_output_folder_path)
        os.makedirs(csv_output_folder_parent_dir, exist_ok=True)

        # check out the network analyst extension (just in case)
        check_out_network_analyst_extension()

        # now have access to the LargeNetworkAnalysisTools module, next basically same as generate_od_cost_matrix
        arcpy.LargeNetworkAnalysisTools.SolveLargeODCostMatrix(Origins=origins_fc_path,
                                                               Destinations=destinations_fc_path,
                                                               Network_Data_Source=network_dataset_path,
                                                               Travel_Mode=_analysis_travel_mode,
                                                               Time_Units="Minutes",
                                                               Distance_Units="Meters",
                                                               Max_Inputs_Per_Chunk=num_feature_per_chunk,
                                                               Max_Processes=num_cores,
                                                               Output_Format="CSV files",
                                                               Output_Folder=csv_output_folder_path,
                                                               Time_Of_Day=time_to_analyze)

        logging.info(f"Successfully solved large OD cost matrix {matrix_name} for "
                     f"{self.main_reference_place.pretty_name} {self.geographic_scope} {analysis_time} {day_of_week}")





# # # # # # # # # # # # # # # # # # Testing Area :::: DO NOT REMOVE "if __name__ ..." # # # # # # # # # # # # # # # # #

if __name__ == "__main__":
    arc_package_project = ArcProject("upp_461_final")
    chicago = Place(arc_package_project, place_name="Chicago, Illinois, USA", geographic_scope="csa",
                    scenario_id="AM_Peak")
    taz_centroids_path = r"PATH_TO_CENTROIDS"
    chicago.generate_large_od_cost_matrix(matrix_name="PM_Peak", origins_fc_path=taz_centroids_path,
                                          destinations_fc_path=taz_centroids_path, network_type="transit",
                                          use_elevation=False, day_of_week="Monday",
                                          analysis_time="5:00 PM", open_on_complete=True)



    ### NOTE FOR DEBUGGING/FIXING CODE ###
    """
    create_network_dataset_from_place() fails if
        1. The Place object (or an equivalent one with the same attributes) has already had a network dataset created,
        and
        2. The use_elevation attribute of the previously created network dataset does not match the use_elevation
            attribute of the current call.

    This is because I switched from caching the edges gdf to caching it as GeoJSON, and I now precalculate walk times
    (as opposed to calculating the field IN Arc, as before). The easy solution is to calculate both "flat_walk_time"
    and graded walk times, though that requires editing the template XML.

    Also: at the moment, when using "specified", MSA, or CSA as the geographic scope, the graphs that osmnx returns
    only include edges that are fully inside a given place. Because my workflow gets the graph for each place in the
    list (each county for CSAs/MSAs, or each specified place for "specified") and then composes them, the resulting
    network dataset has gaps at the borders between those places. I'm planning to fix this by getting the polygon for
    each place from osmnx using the geocode_to_gdf function, combining the polygons using shapely, and then using
    graph_from_polygon instead of graph_from_place, but that will be next.
    """
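The border-gap fix described in the note above can be sketched roughly as follows. This is only an illustration of the planned approach, not code from the repo: the boxes are stand-ins for the geometries that osmnx's geocode_to_gdf would return for each county, and the final graph_from_polygon call is shown as a comment since it requires a network request.

```python
# Sketch of the planned border-gap fix: union the individual place polygons
# first, then build ONE graph from the combined polygon, so edges crossing
# place borders are kept instead of being dropped per-place.
from shapely.geometry import box
from shapely.ops import unary_union

# two overlapping "county" polygons (stand-ins for geocode_to_gdf geometries)
county_polygons = [box(0, 0, 2, 2), box(1.5, 0, 3.5, 2)]

# one combined polygon covering both counties, with no interior border
combined = unary_union(county_polygons)

print(combined.area)  # 4 + 4 - 1 (overlap) = 7.0

# with osmnx, the single-graph build would then be (not run here):
# import osmnx as ox
# G = ox.graph_from_polygon(combined, network_type="walk")
```

Because the graph is built from one polygon rather than composed from per-place graphs, edges that straddle a county line are inside the query area and survive.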

""" This is the main user-interface module for those who want to experiment with my code """

from tools_for_place import *

if __name__ == "__main__":
    your_arcgis_project = ArcProject(name="code_demonstration")  # input the name of the ArcGIS Pro project you want
                                                                 # to use (it will be created if it doesn't exist)

    your_place = Place(arcgis_project=your_arcgis_project, place_name="Berkeley, California, USA",
                       geographic_scope="place_only")  # input the place you want to create a network dataset for
                                                       # (see the documentation for Place for more information)

    your_place.create_network_dataset_from_place()  # input any parameters you would like to change, or run the default
                                                    # and see what happens! (see documentation
                                                    # for create_network_dataset_from_place for more information)

    your_place.generate_isochrone(addresses=["2534 Durant Ave, Berkeley, CA, USA", "1561 Solano Ave, Berkeley, CA, USA"])
                                                    # input any parameters you would like to change, or run the default
                                                    # and see what happens!
                                                    # (see documentation for generate_isochrone for more information)
