Mypy Update (#731)
* set up

* codice update

* mag update

* swapi update

* fix for failed build the docs

* build the docs try

* attempt with fully expanded

* trying with numpy.array

* using numpy.typing.ArrayLike

* sphinx update

* sphinx update

* adding autodoc

* it didn't work

* change to numpy.ndarray

* ultra update

* another try at build the docs

* putting docs class type in

* giving a return object name

* glows update

* idex update

* lo update

* swe update

* utils update

* cdf update

* tools update

* swe fix

* swe fixes

* pre-commit fix

* pre-commit will pass but docs will not

* hopefully the test should work

* fixing test failures

* undoing all swe changes

* undoing all swe changes

* undoing all swe changes

* undoing all swe changes

* undoing all swe changes

* docs change

* trying to get build the docs to work

* trying to get build the docs to work

* test

* test

* test

* test

* test

* mypy changes, hopefully won't break docs

* mypy passes

* test for docs to work

* changes for greg

* changes from greg

* reformat change

* changes from Matthew
daralynnrhode committed Aug 9, 2024
1 parent cb26199 commit e6908f8
Showing 20 changed files with 77 additions and 56 deletions.
8 changes: 2 additions & 6 deletions .pre-commit-config.yaml
@@ -41,9 +41,5 @@ repos:
     rev: 'v1.10.0'
     hooks:
       - id: mypy
-        pass_filenames: false
-        args: [ ., --strict, --explicit-package-bases,
-                --disable-error-code, import-untyped,
-                --disable-error-code, import-not-found,
-                --disable-error-code, no-untyped-call,
-                --disable-error-code, type-arg ]
+        exclude: .*(tests|docs).*
+        additional_dependencies: [ numpy==1.26.4 ]
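
For reference, the same check the hook performs can be reproduced from Python through mypy's documented programmatic API; a minimal sketch, assuming the package directory is imap_processing:

# Sketch: invoke mypy programmatically, mirroring the hook above.
# mypy.api.run returns (report, errors, exit_status).
from mypy import api

report, errors, exit_status = api.run(
    ["imap_processing", "--exclude", ".*(tests|docs).*"]
)
print(report or errors, f"exit status: {exit_status}")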
3 changes: 3 additions & 0 deletions docs/source/conf.py
@@ -113,6 +113,9 @@
     (r"py:.*", r".*InitVar*"),
     (r"py:.*", r".*.glows.utils.constants.TimeTuple.*"),
     (r"py:.*", r".*glows.utils.constants.DirectEvent.*"),
+    (r"py:.*", r".*numpy.int.*"),
+    (r"py:.*", r".*np.ndarray.*"),
+    (r"py:.*", r".*numpy._typing._array_like._ScalarType_co.*"),
 ]

 # Ignore the inherited members from the <instrument>APID IntEnum class
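
These (pattern, target) pairs suppress unresolvable cross-reference warnings; given the regex form, the list is presumably Sphinx's nitpick_ignore_regex (available since Sphinx 4.1). A minimal sketch of the option in a conf.py:

# Sketch: nitpick_ignore_regex silences "reference target not found"
# warnings for matching (domain:role, target) pairs when nitpicky = True.
nitpicky = True
nitpick_ignore_regex = [
    (r"py:.*", r".*np.ndarray.*"),  # unresolved numpy cross-references
]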
13 changes: 6 additions & 7 deletions imap_processing/cdf/utils.py
@@ -3,7 +3,6 @@
 import logging
 import re
 from pathlib import Path
-from typing import Optional

 import imap_data_access
 import numpy as np
@@ -25,8 +24,8 @@

 def met_to_j2000ns(
     met: np.typing.ArrayLike,
-    reference_epoch: Optional[np.datetime64] = IMAP_EPOCH,
-) -> np.typing.ArrayLike:
+    reference_epoch: np.datetime64 = IMAP_EPOCH,
+) -> np.typing.NDArray[np.int64]:
     """
     Convert mission elapsed time (MET) to nanoseconds from J2000.
@@ -56,10 +55,10 @@ def met_to_j2000ns(
     # to 32bit and overflow due to the nanosecond multiplication
     time_array = (np.asarray(met, dtype=float) * 1e9).astype(np.int64)
     # Calculate the time difference between our reference system and J2000
-    j2000_offset = (
-        (reference_epoch - J2000_EPOCH).astype("timedelta64[ns]").astype(np.int64)
-    )
-    return j2000_offset + time_array
+    j2000_offset: np.typing.NDArray[np.datetime64] = (
+        reference_epoch - J2000_EPOCH
+    ).astype("datetime64[ns]")
+    return j2000_offset.astype(np.int64) + time_array


 def load_cdf(
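
A self-contained sketch of the conversion the new signature describes; the epoch constants here are illustrative stand-ins for the module's IMAP_EPOCH and J2000_EPOCH:

# Sketch of a MET -> J2000 nanoseconds conversion with the new return type.
import numpy as np
import numpy.typing as npt

J2000_EPOCH = np.datetime64("2000-01-01T11:58:55.816", "ns")  # assumed value
IMAP_EPOCH = np.datetime64("2010-01-01T00:00:00", "ns")  # assumed placeholder


def met_to_j2000ns_sketch(
    met: npt.ArrayLike, reference_epoch: np.datetime64 = IMAP_EPOCH
) -> npt.NDArray[np.int64]:
    # Convert seconds to nanoseconds in int64 to avoid 32-bit overflow.
    time_array = (np.asarray(met, dtype=float) * 1e9).astype(np.int64)
    offset = (reference_epoch - J2000_EPOCH).astype("timedelta64[ns]")
    return offset.astype(np.int64) + time_array


print(met_to_j2000ns_sketch([0, 1.5]))  # nanoseconds relative to J2000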
4 changes: 3 additions & 1 deletion imap_processing/codice/codice_l1a.py
@@ -412,7 +412,9 @@ def process_codice_l1a(file_path: Path, data_version: str) -> xr.Dataset:
         apid = CODICEAPID.COD_HI_SECT_SPECIES_COUNTS
         table_id, plan_id, plan_step, view_id = (1, 0, 0, 6)

-        met0 = (np.datetime64("2024-04-29T00:00") - IMAP_EPOCH).astype("timedelta64[s]")
+        met0: np.timedelta64 = (np.datetime64("2024-04-29T00:00") - IMAP_EPOCH).astype(
+            "timedelta64[s]"
+        )
         met0 = met0.astype(np.int64)
         met = [met0, met0 + 1]  # Using this to match the other data products
         science_values = ""  # Currently don't have simulated data for this
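
For context on the annotation above: mypy pins a variable's type at its first binding, so a timedelta64-to-int64 reassignment is easier to type-check with an explicit declaration, or with a fresh name as in this hedged sketch (epoch value assumed):

# Sketch: keep the timedelta and its integer view in separate names so
# each variable has a single, stable type under mypy.
import numpy as np

IMAP_EPOCH = np.datetime64("2010-01-01T00:00:00")  # assumed placeholder

met0: np.timedelta64 = (np.datetime64("2024-04-29T00:00") - IMAP_EPOCH).astype(
    "timedelta64[s]"
)
met0_seconds = int(met0.astype(np.int64))  # plain int for later arithmetic
met = [met0_seconds, met0_seconds + 1]
print(met)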
7 changes: 4 additions & 3 deletions imap_processing/glows/l1b/glows_l1b_data.py
@@ -5,6 +5,7 @@
 import json
 from dataclasses import InitVar, dataclass, field
 from pathlib import Path
+from typing import Optional

 import numpy as np

@@ -250,11 +251,11 @@ class DirectEventL1B:
     # l1a_file_name: str # TODO: Add once L1A questions are answered
     # ancillary_data_files: np.ndarray # TODO: Add once L1A questions are answered
     # The following variables are created from the InitVar data
-    de_flags: np.ndarray = field(init=False, default=None)
+    de_flags: Optional[np.ndarray] = field(init=False, default=None)
     # TODO: First two values of DE are sec/subsec
-    direct_event_glows_times: np.ndarray = field(init=False, default=None)
+    direct_event_glows_times: Optional[np.ndarray] = field(init=False, default=None)
     # 3rd value is pulse length
-    direct_event_pulse_lengths: np.ndarray = field(init=False, default=None)
+    direct_event_pulse_lengths: Optional[np.ndarray] = field(init=False, default=None)
     # TODO: where does the multi-event flag go?

     def __post_init__(
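
A self-contained sketch of the pattern these hunks converge on — Optional array fields that start as None and are filled in __post_init__ (the class and field names are illustrative, not from the repository):

# Sketch: Optional ndarray dataclass fields with init=False. Under mypy,
# a None default requires Optional[...] in the annotation.
from dataclasses import InitVar, dataclass, field
from typing import Optional

import numpy as np


@dataclass
class EventBlock:  # illustrative name
    raw: InitVar[np.ndarray]
    flags: Optional[np.ndarray] = field(init=False, default=None)

    def __post_init__(self, raw: np.ndarray) -> None:
        self.flags = raw[:, 0] != 0  # derive flags from the first column


print(EventBlock(np.array([[1, 2], [0, 3]])).flags)  # [ True False]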
5 changes: 3 additions & 2 deletions imap_processing/idex/l1/idex_l1.py
@@ -9,6 +9,7 @@
 from enum import IntEnum

 import numpy as np
+import numpy.typing as npt
 import space_packet_parser
 import xarray as xr

@@ -421,7 +422,7 @@ def _parse_low_sample_waveform(self, waveform_raw: str) -> list[int]:
         ]
         return ints

-    def _calc_low_sample_resolution(self, num_samples: int) -> np.ndarray:
+    def _calc_low_sample_resolution(self, num_samples: int) -> npt.NDArray:
         """
         Calculate the resolution of the low samples.
@@ -447,7 +448,7 @@ def _calc_low_sample_resolution(self, num_samples: int) -> np.ndarray:
         )
         return time_low_sr_data

-    def _calc_high_sample_resolution(self, num_samples: int) -> np.ndarray:
+    def _calc_high_sample_resolution(self, num_samples: int) -> npt.NDArray:
         """
         Calculate the resolution of high samples.
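
The recurring swap here is from np.ndarray to numpy.typing's NDArray alias, the annotation-friendly spelling that can also be parameterized by dtype. A brief sketch:

# Sketch: npt.NDArray[np.float64] is shorthand for
# np.ndarray[Any, np.dtype[np.float64]].
import numpy as np
import numpy.typing as npt


def sample_times(num_samples: int, dt: float = 1.0 / 4096) -> npt.NDArray[np.float64]:
    # Evenly spaced timestamps; the dt default is an illustrative assumption.
    return np.arange(num_samples, dtype=np.float64) * dt


print(sample_times(4))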
5 changes: 3 additions & 2 deletions imap_processing/lo/l0/data_classes/science_counts.py
@@ -3,6 +3,7 @@
 from dataclasses import dataclass

 import numpy as np
+import numpy.typing as npt
 import space_packet_parser

 from imap_processing.ccsds.ccsds_data import CcsdsData
@@ -277,7 +278,7 @@ def _parse_section(
         binary_string: BinaryString,
         decompression: Decompress,
         data_shape: tuple[int, int],
-    ) -> np.array:
+    ) -> npt.NDArray:
         """
         Parse a single section of data in the science counts data binary.
@@ -322,7 +323,7 @@ def _extract_binary(
         section_length: int,
         bit_length: int,
         decompression: Decompress,
-    ) -> np.ndarray:
+    ) -> npt.NDArray:
         """
         Extract and decompress science count binary data section.
4 changes: 3 additions & 1 deletion imap_processing/lo/l0/data_classes/science_direct_events.py
@@ -179,7 +179,9 @@ def _decompress_data(self) -> None:

             # Case decoder indicates which parts of the data
             # are transmitted for each case.
-            case_decoder = CASE_DECODER[(case_number, self.MODE[de_idx])]
+            case_decoder = CASE_DECODER[(case_number, self.MODE[de_idx])]  # type: ignore[index]
+            # Todo Mypy Error: Invalid index type "tuple[int, ndarray[Any, Any]]" for
+            # "dict[tuple[int, int], TOFFields]"; expected type "tuple[int, int]"

             # Check the case decoder to see if the TOF field was
             # transmitted for this case. Then grab the bits from
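
One hedged way to retire that ignore later, assuming MODE is an integer array and the dict key really is tuple[int, int], is to narrow the element at the lookup site:

# Sketch: int(...) narrows a numpy integer element to a built-in int so the
# tuple matches dict[tuple[int, int], ...] without a type: ignore.
import numpy as np

CASE_DECODER: dict[tuple[int, int], str] = {(0, 1): "tof0+tof1"}  # illustrative
mode = np.array([1, 0], dtype=np.int64)
case_number, de_idx = 0, 0

case_decoder = CASE_DECODER[(case_number, int(mode[de_idx]))]
print(case_decoder)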
2 changes: 1 addition & 1 deletion imap_processing/lo/l0/data_classes/star_sensor.py
@@ -54,7 +54,7 @@ class StarSensor(LoBase):
     SHCOARSE: int
     COUNT: int
     DATA_COMPRESSED: str
-    DATA: np.array
+    DATA: np.ndarray

     # TODO: Because test data does not currently exist, the init function contents
     # must be commented out for the unit tests to run properly
4 changes: 3 additions & 1 deletion imap_processing/mag/l0/mag_l0_data.py
@@ -109,7 +109,9 @@ def __post_init__(self) -> None:
             # Convert string output from space_packet_parser to numpy array of
             # big-endian bytes
             self.VECTORS = np.frombuffer(
-                int(self.VECTORS, 2).to_bytes(len(self.VECTORS) // 8, "big"),
+                int(self.VECTORS, 2).to_bytes(len(self.VECTORS) // 8, "big"),  # type: ignore[arg-type]
+                # TODO Check MYPY Error: Argument 1 to "int" has incompatible type
+                # "Union[ndarray[Any, Any], str]"; expected "Union[str, bytes, bytearray]"
                 dtype=np.dtype(">b"),
             )

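
The ignore is needed because VECTORS is declared more broadly than the str that actually reaches int(). A standalone sketch of the conversion itself, with a made-up 16-bit field:

# Sketch: parse a binary string into big-endian signed bytes.
import numpy as np

vectors_bits = "0000000100000010"  # illustrative packet field
as_bytes = int(vectors_bits, 2).to_bytes(len(vectors_bits) // 8, "big")
vectors = np.frombuffer(as_bytes, dtype=np.dtype(">b"))
print(vectors)  # [1 2]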
9 changes: 5 additions & 4 deletions imap_processing/mag/l1a/mag_l1a_data.py
@@ -6,6 +6,7 @@
 from math import floor

 import numpy as np
+import numpy.typing as npt

 from imap_processing.cdf.utils import J2000_EPOCH, met_to_j2000ns

@@ -190,7 +191,7 @@ class MagL1a:
     is_mago: bool
     is_active: int
     shcoarse: int
-    vectors: np.array
+    vectors: np.ndarray
     starting_packet: InitVar[MagL1aPacketProperties]
     packet_definitions: dict[np.datetime64, MagL1aPacketProperties] = field(init=False)
     most_recent_sequence: int = field(init=False)
@@ -216,14 +217,14 @@ def __post_init__(self, starting_packet: MagL1aPacketProperties) -> None:
         self.most_recent_sequence = starting_packet.src_seq_ctr

     def append_vectors(
-        self, additional_vectors: np.array, packet_properties: MagL1aPacketProperties
+        self, additional_vectors: np.ndarray, packet_properties: MagL1aPacketProperties
     ) -> None:
         """
         Append additional vectors to the current vectors array.

         Parameters
         ----------
-        additional_vectors : numpy.array
+        additional_vectors : numpy.ndarray
             New vectors to append.
         packet_properties : MagL1aPacketProperties
             Additional vector definition to add to the l0_packets dictionary.
@@ -244,7 +245,7 @@ def append_vectors(
     @staticmethod
     def calculate_vector_time(
         vectors: np.ndarray, vectors_per_sec: int, start_time: TimeTuple
-    ) -> np.array:
+    ) -> npt.NDArray:
         """
         Add timestamps to the vector list, turning the shape from (n, 4) to (n, 5).
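
Worth noting for the recurring np.array -> np.ndarray swaps: np.array is a factory function, not a class, so mypy rejects it as an annotation; np.ndarray, or the npt.NDArray alias, is the valid spelling. A one-line illustration:

import numpy as np
import numpy.typing as npt

# Annotating with np.array fails ("Function ... is not valid as a type");
# the class np.ndarray or the alias npt.NDArray type-checks.
vectors: npt.NDArray[np.float64] = np.array([[1.0, 2.0, 3.0, 4.0]])
print(vectors.shape)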
17 changes: 9 additions & 8 deletions imap_processing/swapi/l1/swapi_l1.py
@@ -3,6 +3,7 @@
 import copy

 import numpy as np
+import numpy.typing as npt
 import xarray as xr

 from imap_processing import imap_module_directory
@@ -11,7 +12,7 @@
 from imap_processing.utils import packet_file_to_datasets


-def filter_good_data(full_sweep_sci: xr.Dataset) -> np.ndarray:
+def filter_good_data(full_sweep_sci: xr.Dataset) -> npt.NDArray:
     """
     Filter out bad data sweep indices.
@@ -29,7 +30,7 @@ def filter_good_data(full_sweep_sci: xr.Dataset) -> np.ndarray:
     Returns
     -------
-    numpy.ndarray
+    good_data_indices : numpy.ndarray
         Good data sweep indices.
     """
     # PLAN_ID for current sweep should all be one value and
@@ -70,7 +71,7 @@


 def decompress_count(
     count_data: np.ndarray, compression_flag: np.ndarray
-) -> np.ndarray:
+) -> npt.NDArray:
     """
     Will decompress counts based on compression indicators.
@@ -99,7 +100,7 @@
     Returns
     -------
-    numpy.ndarray
+    new_count : numpy.ndarray
         Array with decompressed counts.
     """
     # Decompress counts based on compression indicators
@@ -120,7 +121,7 @@
     return new_count


-def find_sweep_starts(packets: xr.Dataset) -> np.ndarray:
+def find_sweep_starts(packets: xr.Dataset) -> npt.NDArray:
     """
     Find index of where new cycle started.
@@ -138,7 +139,7 @@
     Returns
     -------
-    numpy.ndarray
+    indices_start : numpy.ndarray
         Array of indices of start cycle.
     """
     if packets["epoch"].size < 12:
@@ -175,7 +176,7 @@
     return np.where(valid)[0]


-def get_indices_of_full_sweep(packets: xr.Dataset) -> np.ndarray:
+def get_indices_of_full_sweep(packets: xr.Dataset) -> npt.NDArray:
     """
     Get indices of full cycles.
@@ -195,7 +196,7 @@
     Returns
     -------
-    numpy.ndarray
+    full_cycle_indices : numpy.ndarray
         1D array with indices of full cycle data.
     """
     indices_of_start = find_sweep_starts(packets)
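
The docstring edits adopt numpydoc's named-return form, "name : type" under Returns, which lets tools render the return value's name. A minimal sketch:

import numpy as np
import numpy.typing as npt


def first_n_evens(n: int) -> npt.NDArray[np.int64]:
    """
    Return the first n even numbers.

    Returns
    -------
    evens : numpy.ndarray
        1D array holding the first n even integers.
    """
    return np.arange(n, dtype=np.int64) * 2


print(first_n_evens(4))  # [0 2 4 6]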
4 changes: 2 additions & 2 deletions imap_processing/swe/l1a/swe_science.py
@@ -112,7 +112,7 @@ def swe_science(decom_data: list, data_version: str) -> xr.Dataset:
     science_array = []
     raw_science_array = []

-    metadata_arrays: np.array = collections.defaultdict(list)
+    metadata_arrays: dict[list] = collections.defaultdict(list)

     # We know we can only have 8 bit numbers input, so iterate over all
     # possibilities once up front
@@ -133,7 +133,7 @@
         # where 1260 = 180 x 7 CEMs
         # Take the "raw_counts" indices/counts mapping from
         # decompression_table and then reshape the return
-        uncompress_data = np.take(decompression_table, raw_counts).reshape(180, 7)
+        uncompress_data = np.take(decompression_table, raw_counts).reshape(180, 7)  # type: ignore[attr-defined]
         # Save raw counts data as well
         raw_counts = raw_counts.reshape(180, 7)
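
For reference, a self-contained sketch of annotating a defaultdict of metadata lists; the fully parameterized spelling is dict[str, list[int]] (dict takes two type arguments), so the one-argument form in the hunk above may itself draw a mypy complaint:

import collections

# defaultdict(list) creates the list on first access to a missing key.
metadata_arrays: dict[str, list[int]] = collections.defaultdict(list)
metadata_arrays["shcoarse"].append(123)
metadata_arrays["shcoarse"].append(124)
print(dict(metadata_arrays))  # {'shcoarse': [123, 124]}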
11 changes: 6 additions & 5 deletions imap_processing/swe/l1b/swe_l1b_science.py
@@ -4,6 +4,7 @@
 from typing import Any

 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 import xarray as xr

@@ -70,7 +71,7 @@ def read_lookup_table(table_index_value: int) -> Any:
         raise ValueError("Error: Invalid table index value")


-def deadtime_correction(counts: np.ndarray, acq_duration: int) -> np.ndarray:
+def deadtime_correction(counts: np.ndarray, acq_duration: int) -> npt.NDArray:
     """
     Calculate deadtime correction.
@@ -118,7 +119,7 @@
     return corrected_count


-def convert_counts_to_rate(data: np.ndarray, acq_duration: int) -> np.ndarray:
+def convert_counts_to_rate(data: np.ndarray, acq_duration: int) -> npt.NDArray:
     """
     Convert counts to rate using sampling time.
@@ -206,7 +207,7 @@ def apply_in_flight_calibration(data: np.ndarray) -> None:

 def populate_full_cycle_data(
     l1a_data: xr.Dataset, packet_index: int, esa_table_num: int
-) -> np.ndarray:
+) -> npt.NDArray:
     """
     Populate full cycle data array using esa lookup table and l1a_data.
@@ -277,7 +278,7 @@
     return full_cycle_data


-def find_cycle_starts(cycles: np.ndarray) -> np.ndarray:
+def find_cycle_starts(cycles: np.ndarray) -> npt.NDArray:
     """
     Find index of where new cycle started.
@@ -312,7 +313,7 @@
     return np.where(valid)[0]


-def get_indices_of_full_cycles(quarter_cycle: np.ndarray) -> np.ndarray:
+def get_indices_of_full_cycles(quarter_cycle: np.ndarray) -> npt.NDArray:
     """
     Get indices of full cycles.
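
Since only deadtime_correction's signature appears here, the following is a hedged sketch of a textbook non-paralyzable dead-time correction; the constant, units, and formula are assumptions, not necessarily what swe_l1b_science.py implements:

import numpy as np
import numpy.typing as npt

DEADTIME_S = 1.5e-6  # assumed dead time per event, seconds


def deadtime_correction_sketch(
    counts: np.ndarray, acq_duration_s: float
) -> npt.NDArray[np.float64]:
    # Non-paralyzable model: true_rate = rate / (1 - rate * tau).
    rate = counts / acq_duration_s
    corrected_rate = rate / (1.0 - rate * DEADTIME_S)
    return corrected_rate * acq_duration_s


print(deadtime_correction_sketch(np.array([1000.0, 50000.0]), 1.0))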
3 changes: 2 additions & 1 deletion imap_processing/ultra/l0/decom_tools.py
@@ -1,6 +1,7 @@
 """Ultra Decompression Tools."""

 import numpy as np
+import numpy.typing as npt
 import space_packet_parser

 from imap_processing.ultra.l0.ultra_utils import (
@@ -154,7 +155,7 @@ def decompress_image(
     binary_data: str,
     width_bit: int,
     mantissa_bit_length: int,
-) -> np.ndarray:
+) -> npt.NDArray:
     """
     Will decompress a binary string representing an image into a matrix of pixel values.
2 changes: 1 addition & 1 deletion imap_processing/ultra/l0/decom_ultra.py
@@ -31,7 +31,7 @@
 def append_tof_params(
     decom_data: dict,
     packet: Packet,
-    decompressed_data: list,
+    decompressed_data: np.ndarray,
     data_dict: dict,
     stacked_dict: dict,
 ) -> None: