diff --git a/.gitignore b/.gitignore index 3e029c27..330ab01c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.env *.pyc vendor/* Assignee_Lawyer_Disambiguation/lib/alchemy/config.ini @@ -8,16 +9,14 @@ venv .vscode *.yml *.yaml -mydumper/ +mydumper/* +!mydumper/mydumper.cnf.template Development/config.ini -airflow/airflow-webserver.pid -airflow/airflow.db -airflow/lawyer.pickle -airflow/logs/scheduler/2019-01-15/update-db.py.log -airflow/logs/scheduler/2019-01-16/update-db.py.log -airflow/logs/scheduler/latest +airflow/* +!airflow/dags/* +!airflow/airflow.cfg.template airflow-metadata-db-disk -airflow/unittests.cfg +airflow_pipeline_env.sh Development/dev_config.ini Assignee_Lawyer_Disambiguation/lib/alchemy/alchemy_config.ini airflow-metadata-db-disk/* @@ -28,6 +27,8 @@ node_modules/ app-db-exploration QA_* output/ +go_live/*.json +go_live/*.csv upload_*/* *_qa_loc* **/*.log @@ -44,7 +45,9 @@ scratch **/*.err pgpubs_* resources/sql.conf +resources/us-patent-application-*.dtd config.ini airflow-metadata-* patent_db_disk/ TableToggle.json +Z_Frame_job-*.csv diff --git a/QA/DatabaseTester.py b/QA/DatabaseTester.py index 67edc247..a368f212 100644 --- a/QA/DatabaseTester.py +++ b/QA/DatabaseTester.py @@ -21,6 +21,7 @@ from lib import utilities from datetime import datetime, timedelta + class DatabaseTester(ABC): def __init__(self, config, database_section, start_date, end_date, where_vi=True): # super().__init__(config, database_section, start_date, end_date) @@ -32,38 +33,52 @@ def __init__(self, config, database_section, start_date, end_date, where_vi=True self.where_vi = where_vi # Indicator for Upload/Patents database - self.qa_connection_string = get_connection_string(config, database='QA_DATABASE', connection='APP_DATABASE_SETUP') - self.connection = pymysql.connect(host=config['DATABASE_SETUP']['HOST'], - user=config['DATABASE_SETUP']['USERNAME'], - password=config['DATABASE_SETUP']['PASSWORD'], - db=database_section, - charset='utf8mb4', - cursorclass=pymysql.cursors.SSCursor, defer_connect=True) + self.qa_connection_string = get_connection_string( + config, database="QA_DATABASE", connection="APP_DATABASE_SETUP" + ) + self.connection = pymysql.connect( + host=config["DATABASE_SETUP"]["HOST"], + user=config["DATABASE_SETUP"]["USERNAME"], + password=config["DATABASE_SETUP"]["PASSWORD"], + db=database_section, + charset="utf8mb4", + cursorclass=pymysql.cursors.SSCursor, + defer_connect=True, + ) # self.database_connection_string = get_connection_string(config, database_section) self.config = config self.database_section = database_section self.class_called = class_called - if self.class_called == 'TextQuarterlyMergeTest' and database_section == 'pgpubs_text': - database_type = 'pregrant' + if ( + self.class_called == "TextQuarterlyMergeTest" + and database_section == "pgpubs_text" + ): + database_type = "pregrant" else: try: database_type = self.database_section.split("_")[0] except IndexError: database_type = self.database_section - if self.class_called == 'TextQuarterlyMergeTest' and database_section == 'patent_text': + if ( + self.class_called == "TextQuarterlyMergeTest" + and database_section == "patent_text" + ): self.version = self.find_previous_tuesday() - elif self.class_called == 'TextQuarterlyMergeTest' and database_section == 'pgpubs_text': + elif ( + self.class_called == "TextQuarterlyMergeTest" + and database_section == "pgpubs_text" + ): self.version = self.find_previous_thursday() else: self.version = self.end_date.strftime("%Y-%m-%d") # Add Quarter Variable - df = pd.DataFrame(columns=['date']) + df = pd.DataFrame(columns=["date"]) df.loc[0] = [self.version] - df['quarter'] = pd.to_datetime(df.date).dt.to_period('Q') - quarter = str(df['quarter'][0]) + df["quarter"] = pd.to_datetime(df.date).dt.to_period("Q") + quarter = str(df["quarter"][0]) self.quarter = quarter[:4] + "-" + quarter[5] ##### @@ -72,43 +87,53 @@ def __init__(self, config, database_section, start_date, end_date, where_vi=True def find_previous_thursday(self): # Subtract days until you reach the previous Thursday - while self.end_date.weekday() != 3: # 3 corresponds to Thursday (0 is Monday, 1 is Tuesday, and so on) + while ( + self.end_date.weekday() != 3 + ): # 3 corresponds to Thursday (0 is Monday, 1 is Tuesday, and so on) self.end_date -= timedelta(days=1) return self.end_date def find_previous_tuesday(self): # Subtract days until you reach the previous Tuesday - while self.end_date.weekday() != 1: # 1 corresponds to Tuesday (0 is Monday, 2 is Wednesday, and so on) + while ( + self.end_date.weekday() != 1 + ): # 1 corresponds to Tuesday (0 is Monday, 2 is Wednesday, and so on) self.end_date -= timedelta(days=1) return self.end_date.strftime("%Y-%m-%d") def init_qa_dict(self): # Place Holder for saving QA counts - keys map to table names in patent_QA self.qa_data = { - 'DataMonitor_count': [], - 'DataMonitor_nullcount': [], - 'DataMonitor_patentyearlycount': [], - 'DataMonitor_categorycount': [], - 'DataMonitor_floatingentitycount': [], - 'DataMonitor_maxtextlength': [], - 'DataMonitor_prefixedentitycount': [], - 'DataMonitor_locationcount': [], - 'DataMonitor_indexcount': [] + "DataMonitor_count": [], + "DataMonitor_nullcount": [], + "DataMonitor_patentyearlycount": [], + "DataMonitor_categorycount": [], + "DataMonitor_floatingentitycount": [], + "DataMonitor_maxtextlength": [], + "DataMonitor_prefixedentitycount": [], + "DataMonitor_locationcount": [], + "DataMonitor_indexcount": [], } - def query_runner(self, query, single_value_return=True, where_vi=False, vi_comparison = '='): - query = sqlparse.format(query, reindent = True, keyword_case ='lower') + def query_runner( + self, query, single_value_return=True, where_vi=False, vi_comparison="=" + ): + query = sqlparse.format(query, reindent=True, keyword_case="lower") vi_comparison = vi_comparison.strip() - assert vi_comparison in ['=', '<', '>', '<=', '>=', '<>', '!='] + assert vi_comparison in ["=", "<", ">", "<=", ">=", "<>", "!="] if where_vi: - vi_date = self.end_date.strftime('%Y-%m-%d') - if 'where' and 'main_table' in query: - where_statement = f" and main_table.version_indicator {vi_comparison} '{vi_date}'" - elif 'where' in query: + vi_date = self.end_date.strftime("%Y-%m-%d") + if "where" and "main_table" in query: + where_statement = ( + f" and main_table.version_indicator {vi_comparison} '{vi_date}'" + ) + elif "where" in query: where_statement = f" and version_indicator {vi_comparison} '{vi_date}'" else: - where_statement = f" where version_indicator {vi_comparison} '{vi_date}'" - q = query+where_statement + where_statement = ( + f" where version_indicator {vi_comparison} '{vi_date}'" + ) + q = query + where_statement else: q = query logger.info(q) @@ -120,7 +145,9 @@ def query_runner(self, query, single_value_return=True, where_vi=False, vi_compa query_start_time = time() generic_cursor.execute(q) query_end_time = time() - logger.info(f"\t\tThis query took {query_end_time - query_start_time:.3f} seconds") + logger.info( + f"\t\tThis query took {query_end_time - query_start_time:.3f} seconds" + ) if single_value_return: count_value = generic_cursor.fetchall()[0][0] else: @@ -131,50 +158,64 @@ def query_runner(self, query, single_value_return=True, where_vi=False, vi_compa self.connection.close() return count_value - - def load_table_row_count(self, table_name, where_vi): query = f""" SELECT count(*) as table_count from {table_name}""" - count_value = self.query_runner(query, single_value_return=True, where_vi=where_vi) - if count_value < 1 and table_name not in ['rawuspc', 'uspc', 'government_organization']: + count_value = self.query_runner( + query, single_value_return=True, where_vi=where_vi + ) + if count_value < 1 and table_name not in [ + "rawuspc", + "uspc", + "government_organization", + ]: raise Exception("Empty table found:{table}".format(table=table_name)) - self.qa_data['DataMonitor_count'].append( + self.qa_data["DataMonitor_count"].append( { "database_type": self.database_type, - 'table_name': table_name, - 'update_version': self.version, - 'table_row_count': count_value, - 'quarter': self.quarter - }) - + "table_name": table_name, + "update_version": self.version, + "table_row_count": count_value, + "quarter": self.quarter, + } + ) def test_blank_count(self, table, table_config, where_vi): for field in table_config["fields"]: - if table_config["fields"][field]["data_type"] in ['varchar', 'mediumtext', 'text']: + if table_config["fields"][field]["data_type"] in [ + "varchar", + "mediumtext", + "text", + ]: count_query = f""" SELECT count(*) as blank_count from `{table}` where `{field}` = ''""" - count_value = self.query_runner(count_query, single_value_return=True, where_vi=where_vi) + count_value = self.query_runner( + count_query, single_value_return=True, where_vi=where_vi + ) if count_value != 0: exception_message = """ Blanks encountered in table found:{database}.{table} column {col}. Count: {count} - """.format(database=self.database_section, - table=table, col=field, - count=count_value) + """.format( + database=self.database_section, + table=table, + col=field, + count=count_value, + ) raise Exception(exception_message) - def test_null_byte(self, table, field, where_vi): nul_byte_query = f""" SELECT count(*) as count from `{table}` where INSTR(`{field}`, CHAR(0x00)) > 0""" - count_value = self.query_runner(nul_byte_query, single_value_return=True, where_vi=where_vi) + count_value = self.query_runner( + nul_byte_query, single_value_return=True, where_vi=where_vi + ) if count_value > 0: - # attempt automatic correction + # attempt automatic correction bad_char_fix_query = f""" UPDATE `{table}` SET `{field}` = REPLACE(REPLACE(REPLACE(`{field}`, CHAR(0x00), ''), CHAR(0x08), ' b'), CHAR(0x1A), 'Z') @@ -191,45 +232,55 @@ def test_null_byte(self, table, field, where_vi): finally: if self.connection.open: self.connection.close() - logger.info(f"attempted to correct newlines in {table}.{field}. re-performing newline detection query:") + logger.info( + f"attempted to correct newlines in {table}.{field}. re-performing newline detection query:" + ) logger.info(nul_byte_query) - count_value = self.query_runner(nul_byte_query, single_value_return=True, where_vi=where_vi) + count_value = self.query_runner( + nul_byte_query, single_value_return=True, where_vi=where_vi + ) if count_value > 0: exception_message = f"{count_value} rows with NUL Byte found in `{field}` of `{self.database_section}`.`{table}` after attempted correction." raise Exception(exception_message) - - def test_newlines(self, table, field, where_vi): skip = False - allowables = { # set of tables and fields where newlines are allowable in the field content - 'brf_sum_text' : ['summary_text'], - 'detail_desc_text' : ['description_text'], - 'claims' : ['claim_text'], - 'rel_app_text' : ['text'] + allowables = { # set of tables and fields where newlines are allowable in the field content + "brf_sum_text": ["summary_text"], + "detail_desc_text": ["description_text"], + "claims": ["claim_text"], + "rel_app_text": ["text"], } # autofixes = { # 'draw_desc_text' : ['draw_desc_text'], # 'rawassignee': ['orgnaization'] # } - if table in allowables: #non-text tables + if table in allowables: # non-text tables if field in allowables[table]: skip = True - elif re.match(".*_[0-9]{4}", table) and table[:-5] in allowables: #text-tables + elif re.match(".*_[0-9]{4}", table) and table[:-5] in allowables: # text-tables if field in allowables[table[:-5]]: skip = True if skip: - logger.info('newlines marked as permitted for this field. skipping newline test') - else: + logger.info( + "newlines marked as permitted for this field. skipping newline test" + ) + else: newline_query = f""" SELECT count(*) as count from `{table}` where INSTR(`{field}`, '\n') > 0""" - count_value = self.query_runner(newline_query, single_value_return=True, where_vi=where_vi) + count_value = self.query_runner( + newline_query, single_value_return=True, where_vi=where_vi + ) if count_value > 0: - logger.info(f"{count_value} rows with unwanted newlines found in {field} of {table} for {self.database_section}. Correcting records ...") - makelogquery = f"CREATE TABLE IF NOT EXISTS `{table}_newline_log` LIKE {table}" + logger.info( + f"{count_value} rows with unwanted newlines found in {field} of {table} for {self.database_section}. Correcting records ..." + ) + makelogquery = ( + f"CREATE TABLE IF NOT EXISTS `{table}_newline_log` LIKE {table}" + ) filllogquery = f"INSERT INTO `{table}_newline_log` SELECT * FROM `{table}` WHERE `{field}` LIKE '%\n%'" fixquery = f""" UPDATE `{table}` @@ -248,35 +299,40 @@ def test_newlines(self, table, field, where_vi): finally: if self.connection.open: self.connection.close() - logger.info(f"attempted to correct newlines in {table}.{field}. re-performing newline detection query:") - count_value = self.query_runner(newline_query, single_value_return=True, where_vi=where_vi) + logger.info( + f"attempted to correct newlines in {table}.{field}. re-performing newline detection query:" + ) + count_value = self.query_runner( + newline_query, single_value_return=True, where_vi=where_vi + ) if count_value > 0: exception_message = f"{count_value} rows with unwanted and unfixed newlines found in {field} of {table} for {self.database_section}" raise Exception(exception_message) - def load_category_counts(self, table, field): category_count_query = f""" SELECT `{field}` as value , count(*) as count from `{table}` group by 1""" - count_value = self.query_runner(category_count_query, single_value_return=False, where_vi=False) + count_value = self.query_runner( + category_count_query, single_value_return=False, where_vi=False + ) for count_row in count_value: value = count_row[0] if value is None: - value = 'NULL' - self.qa_data['DataMonitor_categorycount'].append( + value = "NULL" + self.qa_data["DataMonitor_categorycount"].append( { "database_type": self.database_type, - 'table_name': table, + "table_name": table, "column_name": field, - 'update_version': self.version, - 'value': value, - 'count': count_row[1], - 'quarter': self.quarter - }) - + "update_version": self.version, + "value": value, + "count": count_row[1], + "quarter": self.quarter, + } + ) def load_nulls(self, table, table_config, where_vi): for field in table_config["fields"]: @@ -284,25 +340,30 @@ def load_nulls(self, table, table_config, where_vi): SELECT count(*) as null_count from `{table}` where `{field}` is null """ - count_value = self.query_runner(count_query, single_value_return=True, where_vi=where_vi) - if not table_config["fields"][field]['null_allowed']: + count_value = self.query_runner( + count_query, single_value_return=True, where_vi=where_vi + ) + if not table_config["fields"][field]["null_allowed"]: if count_value != 0: raise Exception( "NULLs encountered in table found:{database}.{table} column {col}. Count: {" "count}".format( - database=self.database_section, table=table, + database=self.database_section, + table=table, col=field, - count=count_value)) - self.qa_data['DataMonitor_nullcount'].append( + count=count_value, + ) + ) + self.qa_data["DataMonitor_nullcount"].append( { "database_type": self.database_type, - 'table_name': table, + "table_name": table, "column_name": field, - 'update_version': self.version, - 'null_count': count_value, - 'quarter': self.quarter - }) - + "update_version": self.version, + "null_count": count_value, + "quarter": self.quarter, + } + ) def test_zero_dates(self, table, field, where_vi): zero_query = f""" @@ -312,32 +373,40 @@ def test_zero_dates(self, table, field, where_vi): OR `{field}` LIKE '____-00-__' OR `{field}` LIKE '____-__-00' """ - count_value = self.query_runner(zero_query, single_value_return=True, where_vi=where_vi) + count_value = self.query_runner( + zero_query, single_value_return=True, where_vi=where_vi + ) if count_value != 0: raise Exception( "zero date encountered in table found:{database}.{table} column {col}. Count: {" "count}".format( - database=self.database_section, table=table, col=field, - count=count_value)) - + database=self.database_section, + table=table, + col=field, + count=count_value, + ) + ) def test_null_version_indicator(self, table): - null_vi_query = \ -f"SELECT count(*) null_count " \ -f"from {table} " \ -f"where version_indicator is null" + null_vi_query = ( + f"SELECT count(*) null_count " + f"from {table} " + f"where version_indicator is null" + ) count_value = self.query_runner(null_vi_query, single_value_return=True) if count_value != 0: raise Exception( "Table {database}.{table} Has {count} Nulls in Version Indicator".format( - database=self.database_section, table=table, count=count_value)) - + database=self.database_section, table=table, count=count_value + ) + ) def test_white_space(self, table, field): - white_space_query = \ -f"SELECT count(*) " \ -f"from {table} W" \ -f"HERE CHAR_LENGTH(`{field}`) != CHAR_LENGTH(TRIM(`{field}`))" + white_space_query = ( + f"SELECT count(*) " + f"from {table} W" + f"HERE CHAR_LENGTH(`{field}`) != CHAR_LENGTH(TRIM(`{field}`))" + ) count_value = self.query_runner(white_space_query, single_value_return=True) if count_value != 0: logger.info("THE FOLLOWING QUERY NEEDS ADDRESSING") @@ -357,39 +426,51 @@ def check_for_indexes(self, table): print(f"Index count: {count_value} ({type(count_value)})") if int(count_value) == 0: current_db_query = "SELECT DATABASE();" - current_db = self.query_runner(current_db_query, single_value_return=True) + current_db = self.query_runner( + current_db_query, single_value_return=True + ) print(f"Connected to database: {current_db}") host_check_query = "SELECT @@hostname;" hostname = self.query_runner(host_check_query, single_value_return=True) print(f"Airflow is connected to MySQL host: {hostname}") logger.info(index_query) raise Exception(f"{self.database_section}.{table} has no indexes") - self.qa_data['DataMonitor_indexcount'].append( + self.qa_data["DataMonitor_indexcount"].append( { "database_type": self.database_type, - 'table_name': table, - 'update_version': self.version, - 'index_count': count_value, - 'quarter': self.quarter - }) + "table_name": table, + "update_version": self.version, + "index_count": count_value, + "quarter": self.quarter, + } + ) def test_rawassignee_org(self, table, where_vi=False): rawassignee_q = """ SELECT count(*) FROM rawassignee where name_first is not null and name_last is null""" - count_value = self.query_runner(rawassignee_q, single_value_return=True, where_vi=where_vi) + count_value = self.query_runner( + rawassignee_q, single_value_return=True, where_vi=where_vi + ) if count_value != 0: logger.info("THE FOLLOWING QUERY NEEDS ADDRESSING") logger.info(rawassignee_q) - raise Exception(f"{self.database_section}.{table} Has Wrong Organization values") - - - def test_related_floating_entities(self, table_name, table_config, where_vi=False, vi_comparison = '='): - if table_name not in self.exclusion_list and 'related_entities' in table_config: - for related_entity_config in table_config['related_entities']: - exists_query = f"""SHOW TABLES LIKE '{related_entity_config["related_table"]}'; """ - exists_table_count = self.query_runner(exists_query, single_value_return=False, where_vi=False) + raise Exception( + f"{self.database_section}.{table} Has Wrong Organization values" + ) + + def test_related_floating_entities( + self, table_name, table_config, where_vi=False, vi_comparison="=" + ): + if table_name not in self.exclusion_list and "related_entities" in table_config: + for related_entity_config in table_config["related_entities"]: + exists_query = ( + f"""SHOW TABLES LIKE '{related_entity_config["related_table"]}'; """ + ) + exists_table_count = self.query_runner( + exists_query, single_value_return=False, where_vi=False + ) if not exists_table_count: continue else: @@ -400,41 +481,57 @@ def test_related_floating_entities(self, table_name, table_config, where_vi=Fals where main_table.{main_table_id} is null and related_table.{related_table_id} is not null """.format( main_table=table_name, - related_table=related_entity_config['related_table'], - main_table_id=related_entity_config['main_table_id'], - related_table_id=related_entity_config['related_table_id']) - related_count = self.query_runner(related_query, single_value_return=True, where_vi=where_vi, vi_comparison=vi_comparison) + related_table=related_entity_config["related_table"], + main_table_id=related_entity_config["main_table_id"], + related_table_id=related_entity_config["related_table_id"], + ) + related_count = self.query_runner( + related_query, + single_value_return=True, + where_vi=where_vi, + vi_comparison=vi_comparison, + ) if related_count > 0: raise Exception( "There are rows for the id: {related_table_id} in {related_table} that do not have corresponding rows for the id: {" "main_table_id} in {main_table} for {db}".format( main_table=table_name, - related_table=related_entity_config['related_table'], - main_table_id=related_entity_config['main_table_id'], - related_table_id=related_entity_config['related_table_id'], - db=self.database_section) + related_table=related_entity_config["related_table"], + main_table_id=related_entity_config["main_table_id"], + related_table_id=related_entity_config[ + "related_table_id" + ], + db=self.database_section, + ) ) def load_main_floating_entity_count(self, table_name, table_config): - if table_name not in self.exclusion_list and 'related_entities' in table_config: - for related_entity_config in table_config['related_entities']: + if table_name not in self.exclusion_list and "related_entities" in table_config: + for related_entity_config in table_config["related_entities"]: ###### CHECKING IF THE RELATED TABLE HAS DATA - exists_query = f"""SHOW TABLES LIKE '{related_entity_config["related_table"]}'; """ - exists_table_count = self.query_runner(exists_query, single_value_return=False, where_vi=False) + exists_query = ( + f"""SHOW TABLES LIKE '{related_entity_config["related_table"]}'; """ + ) + exists_table_count = self.query_runner( + exists_query, single_value_return=False, where_vi=False + ) if not exists_table_count: continue else: ###### DYNAMICALLY PICKING THE LASTEST COLUMN FOR CHECKING FLOATING ENTITY COUNT year_columns = [] - if (table_name == 'persistent_assignee_disambig' and related_entity_config[ - 'related_table'] == 'assignee') or ( - table_name == 'persistent_inventor_disambig' and related_entity_config[ - 'related_table'] == 'inventor'): - columns = table_config['fields'].keys() + if ( + table_name == "persistent_assignee_disambig" + and related_entity_config["related_table"] == "assignee" + ) or ( + table_name == "persistent_inventor_disambig" + and related_entity_config["related_table"] == "inventor" + ): + columns = table_config["fields"].keys() for i in columns: words = i.split("_") for w in words: - if w[0] == '2': + if w[0] == "2": year_columns.append(w) last_year = max(year_columns) for k in columns: @@ -442,9 +539,13 @@ def load_main_floating_entity_count(self, table_name, table_config): winner_column = k related_entity_config["main_table_id"] = winner_column additional_where = "" - if 'custom_float_condition' in table_config and table_config[ - 'custom_float_condition'] is not None: - additional_where = "and " + table_config['custom_float_condition'] + if ( + "custom_float_condition" in table_config + and table_config["custom_float_condition"] is not None + ): + additional_where = ( + "and " + table_config["custom_float_condition"] + ) float_count_query = """ SELECT count(1) as count from {main_table} main @@ -455,19 +556,24 @@ def load_main_floating_entity_count(self, table_name, table_config): related_table=related_entity_config["related_table"], additional_where=additional_where, related_table_id=related_entity_config["related_table_id"], - main_table_id=related_entity_config["main_table_id"]) - related_table_count = self.query_runner(float_count_query, single_value_return=True, where_vi=False) - self.qa_data['DataMonitor_floatingentitycount'].append({ - "database_type": self.database_type, - 'update_version': self.version, - 'main_table': table_name, - 'related_table': related_entity_config["related_table"], - 'floating_count': related_table_count, - 'quarter': self.quarter - }) + main_table_id=related_entity_config["main_table_id"], + ) + related_table_count = self.query_runner( + float_count_query, single_value_return=True, where_vi=False + ) + self.qa_data["DataMonitor_floatingentitycount"].append( + { + "database_type": self.database_type, + "update_version": self.version, + "main_table": table_name, + "related_table": related_entity_config["related_table"], + "floating_count": related_table_count, + "quarter": self.quarter, + } + ) def load_entity_category_counts(self, table_name): - if table_name not in self.exclusion_list and self.category != '': + if table_name not in self.exclusion_list and self.category != "": if table_name == self.central_entity: count_query = f""" select {self.category}, count(1) @@ -480,109 +586,127 @@ def load_entity_category_counts(self, table_name): on related.{self.f_key} = main.{self.p_key} group by 1 """ - count_value = self.query_runner(count_query, single_value_return=False, where_vi=False) + count_value = self.query_runner( + count_query, single_value_return=False, where_vi=False + ) for count_row in count_value: - self.qa_data['DataMonitor_prefixedentitycount'].append( + self.qa_data["DataMonitor_prefixedentitycount"].append( { "database_type": self.database_type, - 'update_version': self.version, - 'patent_type': count_row[0], - 'table_name': table_name, - 'patent_count': count_row[1], - 'quarter': self.quarter - }) + "update_version": self.version, + "patent_type": count_row[0], + "table_name": table_name, + "patent_count": count_row[1], + "quarter": self.quarter, + } + ) def load_counts_by_location(self, table, field): row_query = "select count(1) from {tbl}".format(tbl=table) - if table == 'patent': - location_query = \ -f""" + if table == "patent": + location_query = f""" SELECT t.`{field}`, count(*) from {table} t join patent.country_codes cc on t.country = cc.`alpha-2` group by t.`{field}`""" else: - location_query = \ -f""" + location_query = f""" SELECT t.`{field}`, count(*) from {table} t join patent.country_codes cc on t.country = cc.`alpha-2` group by t.`{field}`""" - row_count = self.query_runner(row_query, single_value_return=True, where_vi=False) - count_value = self.query_runner(location_query, single_value_return=False, where_vi=False) + row_count = self.query_runner( + row_query, single_value_return=True, where_vi=False + ) + count_value = self.query_runner( + location_query, single_value_return=False, where_vi=False + ) for count_row in count_value: - self.qa_data['DataMonitor_locationcount'].append( + self.qa_data["DataMonitor_locationcount"].append( { "database_type": self.database_type, - 'update_version': self.version, - 'table_name': table, - 'table_row_count': row_count, - 'patent_id_count': count_row[1], - 'location': count_row[0], - 'quarter': self.quarter - }) + "update_version": self.version, + "table_name": table, + "table_row_count": row_count, + "patent_id_count": count_row[1], + "location": count_row[0], + "quarter": self.quarter, + } + ) def save_qa_data(self): qa_engine = create_engine(self.qa_connection_string) for qa_table in self.qa_data: qa_table_data = self.qa_data[qa_table] - if len(qa_table_data) == 0: + if len(qa_table_data) == 0: continue table_frame = pd.DataFrame(qa_table_data) - if qa_table == 'DataMonitor_topnentities': + if qa_table == "DataMonitor_topnentities": entity_set = f"""('{"', '".join(table_frame.entity_name.unique())}')""" - table_col = 'entity_name' + table_col = "entity_name" addl_condition = f"AND `{table_col}` IN {entity_set}" print_condition = f"for {entity_set} " - elif qa_table in ['DataMonitor_govtinterestsampler']: # table-specific QA tables that just identify records by update_version and db_type + elif qa_table in [ + "DataMonitor_govtinterestsampler" + ]: # table-specific QA tables that just identify records by update_version and db_type addl_condition = "" print_condition = "" - elif 'table_name' in table_frame.columns: + elif "table_name" in table_frame.columns: table_set = f"""('{"', '".join(table_frame.table_name.unique())}')""" table_col = "table_name" addl_condition = f"AND `{table_col}` IN {table_set}" print_condition = f"for {table_set} " - elif 'main_table' in table_frame.columns: # for floating entity table + elif "main_table" in table_frame.columns: # for floating entity table table_set = f"""('{"', '".join(table_frame.main_table.unique())}')""" table_col = "main_table" addl_condition = f"AND `{table_col}` IN {table_set}" print_condition = f"for {table_set} " else: - raise NotImplementedError(f"specification of existing rows to remove not implemented for {qa_table}.\ncolumns available: `{'`,`'.join(table_frame.columns)}`") + raise NotImplementedError( + f"specification of existing rows to remove not implemented for {qa_table}.\ncolumns available: `{'`,`'.join(table_frame.columns)}`" + ) try: - logger.info(f'removing prior {qa_table} {self.database_type} records {print_condition}on {self.version}') + logger.info( + f"removing prior {qa_table} {self.database_type} records {print_condition}on {self.version}" + ) clean_prior = f"DELETE FROM {qa_table} WHERE `update_version` = '{self.version}' AND `database_type` = '{self.database_type}' {addl_condition}" logger.info(clean_prior) qa_engine.execute(clean_prior) - logger.info(f'inserting new {qa_table} records for {self.version} and {self.database_type}') - table_frame.to_sql(name=qa_table, if_exists='append', con=qa_engine, index=False) + logger.info( + f"inserting new {qa_table} records for {self.version} and {self.database_type}" + ) + table_frame.to_sql( + name=qa_table, if_exists="append", con=qa_engine, index=False + ) except SQLAlchemyError as e: table_frame.to_csv("errored_qa_data" + qa_table, index=False) raise e def load_text_length(self, table_name, field_name): - text_length_query = \ -f"SELECT max(char_length(`{field_name}`)) " \ -f"from `{table_name}`;" + text_length_query = ( + f"SELECT max(char_length(`{field_name}`)) " f"from `{table_name}`;" + ) text_length = self.query_runner(text_length_query, single_value_return=True) - self.qa_data['DataMonitor_maxtextlength'].append({ - "database_type": self.database_type, - 'update_version': self.version, - 'table_name': table_name, - 'column_name': field_name, - 'max_text_length': text_length, - 'quarter': self.quarter - }) + self.qa_data["DataMonitor_maxtextlength"].append( + { + "database_type": self.database_type, + "update_version": self.version, + "table_name": table_name, + "column_name": field_name, + "max_text_length": text_length, + "quarter": self.quarter, + } + ) def test_patent_abstract_null(self, table, where_vi=False): - if self.central_entity == 'patent': + if self.central_entity == "patent": count_query = f""" SELECT count(*) as null_abstract_count from {self.central_entity} where abstract is null and type!='design' and type!='reissue' and id not in ('4820515', '4885173', '6095757', '6363330', '6571026', '6601394', '6602488', '6602501', '6602630', '6602899', '6603179', '6615064', '6744569', 'H002199', 'H002200', 'H002203', 'H002204', 'H002217', 'H002235') """ - elif self.central_entity == 'publication': + elif self.central_entity == "publication": count_query = f""" SELECT count(*) as null_abstract_count from {self.central_entity} p @@ -591,7 +715,8 @@ def test_patent_abstract_null(self, table, where_vi=False): count_value = self.query_runner(count_query, single_value_return=True) if count_value != 0: raise Exception( - f"NULLs (Non-design patents) encountered in table found:{self.database_section}.{table} column abstract. Count: {count_value}") + f"NULLs (Non-design patents) encountered in table found:{self.database_section}.{table} column abstract. Count: {count_value}" + ) def runStandardTests(self): self.init_qa_dict() @@ -601,60 +726,99 @@ def runStandardTests(self): logger.info(" -------------------------------------------------- ") logger.info(f"BEGINNING TESTS FOR {self.database_section}.{table}") logger.info(" -------------------------------------------------- ") - self.test_blank_count(table, self.table_config[table], where_vi=self.where_vi) - self.test_related_floating_entities(table, table_config=self.table_config[table], where_vi=self.where_vi) + self.test_blank_count( + table, self.table_config[table], where_vi=self.where_vi + ) + self.test_related_floating_entities( + table, table_config=self.table_config[table], where_vi=self.where_vi + ) self.load_nulls(table, self.table_config[table], where_vi=self.where_vi) self.test_null_version_indicator(table) self.load_table_row_count(table, where_vi=self.where_vi) self.load_main_floating_entity_count(table, self.table_config[table]) self.check_for_indexes(table) - #if table[:2] <= 'ot': # Use this to skip certain tables. Comment out when not in use. - logger.info(f"==============================================================================") + # if table[:2] <= 'ot': # Use this to skip certain tables. Comment out when not in use. + logger.info( + f"==============================================================================" + ) logger.info(f"BEGINNING TESTS FOR TABLE: {self.database_section}.{table} %") - logger.info(f"==============================================================================") - if self.class_called != "ReportingDBTester" and "PostProcessingQC" not in self.class_called: + logger.info( + f"==============================================================================" + ) + if ( + self.class_called != "ReportingDBTester" + and "PostProcessingQC" not in self.class_called + ): self.test_null_version_indicator(table) - if table == 'rawassignee': + if table == "rawassignee": self.test_rawassignee_org(table, where_vi=False) self.test_blank_count(table, self.table_config[table], where_vi=False) self.load_nulls(table, self.table_config[table], where_vi=False) - vi_cutoff_classes = ['DisambiguationTester', 'LawyerPostProcessingQC'] + vi_cutoff_classes = ["DisambiguationTester", "LawyerPostProcessingQC"] if "PostProcessingQC" not in self.class_called: - self.test_related_floating_entities(table_name=table, table_config=self.table_config[table], - where_vi=(True if self.class_called in vi_cutoff_classes else False), - vi_comparison=('<=' if self.class_called in vi_cutoff_classes else '=')) + self.test_related_floating_entities( + table_name=table, + table_config=self.table_config[table], + where_vi=( + True if self.class_called in vi_cutoff_classes else False + ), + vi_comparison=( + "<=" if self.class_called in vi_cutoff_classes else "=" + ), + ) self.load_main_floating_entity_count(table, self.table_config[table]) self.load_entity_category_counts(table) - if table == 'rawassignee': + if table == "rawassignee": self.test_rawassignee_org(table, where_vi=self.where_vi) if table == self.central_entity: self.test_patent_abstract_null(table) for field in self.table_config[table]["fields"]: - logger.info("==============================================================================") + logger.info( + "==============================================================================" + ) logger.info(f"\tBEGINNING TESTS FOR COLUMN: {table}.{field}") - logger.info("==============================================================================") - if self.table_config[table]["fields"][field]["data_type"] == 'date': + logger.info( + "==============================================================================" + ) + if self.table_config[table]["fields"][field]["data_type"] == "date": self.test_zero_dates(table, field, where_vi=self.where_vi) if self.table_config[table]["fields"][field]["category"]: self.load_category_counts(table, field) - if self.table_config[table]["fields"][field]['data_type'] in ['mediumtext', 'longtext', 'text']: + if self.table_config[table]["fields"][field]["data_type"] in [ + "mediumtext", + "longtext", + "text", + ]: self.load_text_length(table, field) - if self.table_config[table]["fields"][field]['data_type'] in ['mediumtext', 'longtext', 'text', 'varchar']: + if self.table_config[table]["fields"][field]["data_type"] in [ + "mediumtext", + "longtext", + "text", + "varchar", + ]: self.test_newlines(table, field, where_vi=self.where_vi) if self.table_config[table]["fields"][field]["location_field"]: self.load_counts_by_location(table, field) - if self.table_config[table]["fields"][field]['data_type'] == 'varchar' and 'id' not in field: + if ( + self.table_config[table]["fields"][field]["data_type"] == "varchar" + and "id" not in field + ): self.test_white_space(table, field) self.test_null_byte(table, field, where_vi=self.where_vi) logger.info(f"FINISHED WITH TABLE: {table}") counter += 1 - logger.info("==============================================================================") - logger.info(f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}") - logger.info("==============================================================================") + logger.info( + "==============================================================================" + ) + logger.info( + f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}" + ) + logger.info( + "==============================================================================" + ) self.save_qa_data() self.init_qa_dict() - def runDisambiguationTests(self): counter = 0 total_tables = len(self.table_config.keys()) @@ -669,9 +833,15 @@ def runDisambiguationTests(self): self.init_qa_dict() logger.info(f"FINISHED WITH TABLE: {table}") counter += 1 - logger.info("==============================================================================") - logger.info(f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}") - logger.info("==============================================================================") + logger.info( + "==============================================================================" + ) + logger.info( + f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}" + ) + logger.info( + "==============================================================================" + ) def runReportingTests(self): counter = 0 @@ -685,17 +855,28 @@ def runReportingTests(self): self.init_qa_dict() logger.info(f"FINISHED WITH TABLE: {table}") counter += 1 - logger.info("==============================================================================") - logger.info(f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}") - logger.info("==============================================================================") - -if __name__ == '__main__': + logger.info( + "==============================================================================" + ) + logger.info( + f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}" + ) + logger.info( + "==============================================================================" + ) + + +if __name__ == "__main__": # config = get_config() - config = get_current_config('granted_patent', **{ - "execution_date": datetime.date(2023, 10, 31) - }) + config = get_current_config( + "granted_patent", **{"execution_date": datetime.date(2023, 10, 31)} + ) # fill with correct run_id run_id = "backfill__2020-12-29T00:00:00+00:00" - pt = DatabaseTester(config, 'PatentsView_20231231', datetime.date(2023, 9, 30), datetime.date(2023, 12, 31)) + pt = DatabaseTester( + config, + "PatentsView_20231231", + datetime.date(2023, 9, 30), + datetime.date(2023, 12, 31), + ) pt.runTests() - diff --git a/QA/production/ProdDBTester.py b/QA/production/ProdDBTester.py index 63295b4a..7b177ab8 100644 --- a/QA/production/ProdDBTester.py +++ b/QA/production/ProdDBTester.py @@ -1,4 +1,7 @@ -from lib.download_check_delete_databases import query_for_all_tables_in_db, get_count_for_all_tables +from lib.download_check_delete_databases import ( + query_for_all_tables_in_db, + get_count_for_all_tables, +) from lib.configuration import get_current_config, get_unique_connection_string import datetime from QA.DatabaseTester import DatabaseTester @@ -11,40 +14,57 @@ class ProdDBTester(DatabaseTester): def __init__(self, config): - end_date = datetime.datetime.strptime(config['DATES']['END_DATE'], '%Y%m%d') - database_name = config['PATENTSVIEW_DATABASES']["REPORTING_DATABASE"] - super().__init__(config, database_name, datetime.date(year=1976, month=1, day=1),end_date) - self.connection = pymysql.connect(host=config['PROD_DATABASE_SETUP']['HOST'], - user=config['PROD_DATABASE_SETUP']['USERNAME'], - password=config['PROD_DATABASE_SETUP']['PASSWORD'], - db=database_name, - charset='utf8mb4', - cursorclass=pymysql.cursors.SSCursor, defer_connect=True) + end_date = datetime.datetime.strptime(config["DATES"]["END_DATE"], "%Y%m%d") + database_name = config["PATENTSVIEW_DATABASES"]["REPORTING_DATABASE"] + super().__init__( + config, database_name, datetime.date(year=1976, month=1, day=1), end_date + ) + + self.connection = pymysql.connect( + host=config["PROD_DATABASE_SETUP"]["HOST"], + user=config["PROD_DATABASE_SETUP"]["USERNAME"], + password=config["PROD_DATABASE_SETUP"]["PASSWORD"], + db=database_name, + charset="utf8mb4", + cursorclass=pymysql.cursors.SSCursor, + defer_connect=True, + ) self.database_type = "PROD_" + "PatentsView" def run_prod_db_tests(self): counter = 0 total_tables = len(self.table_config.keys()) self.init_qa_dict() - for table in self.table_config: + for table in ["cpc_current_group_application_year"]: # self.table_config: self.load_table_row_count(table, where_vi=False) self.check_for_indexes(table) self.save_qa_data() self.init_qa_dict() logger.info(f"FINISHED WITH TABLE: {table}") counter += 1 - logger.info(f"==============================================================================") - logger.info(f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}") - logger.info(f"==============================================================================") + logger.info( + f"==============================================================================" + ) + logger.info( + f"Currently Done With {counter} of {total_tables} | {counter/total_tables:.2%}" + ) + logger.info( + f"==============================================================================" + ) + def run_prod_db_qa(**kwargs): - config = get_current_config(type='granted_patent', schedule="quarterly", **kwargs) + config = get_current_config(type="granted_patent", schedule="quarterly", **kwargs) qc = ProdDBTester(config) qc.run_prod_db_tests() -if __name__ == '__main__': +if __name__ == "__main__": # check_reporting_db_row_count() - config = get_current_config('granted_patent', schedule='quarterly', **{"execution_date": datetime.date(2023, 4, 1)}) + config = get_current_config( + "granted_patent", + schedule="quarterly", + **{"execution_date": datetime.date(2025, 9, 1)}, + ) qc = ProdDBTester(config) qc.run_prod_db_tests() diff --git a/commands b/commands new file mode 100644 index 00000000..6567e587 --- /dev/null +++ b/commands @@ -0,0 +1,6 @@ +cp 20211230/patent/download/g_inventor_disambiguated.tsv /pv_export_volume/output/20211230_g_inventor_disambiguated.tsv + + +awk -F'\t' 'NR==1 || $6=="" || $6=="\"\""' 20220929_g_inventor_disambiguated.tsv > 20220929_output.tsv + +wc -l 20220929_output.tsv \ No newline at end of file diff --git a/gender_it b/gender_it index e8af867f..cc0f76d0 160000 --- a/gender_it +++ b/gender_it @@ -1 +1 @@ -Subproject commit e8af867f17354c28ae23de213db29213e49a8878 +Subproject commit cc0f76d008636a94d4f36601f3162df64f805015 diff --git a/lib/utilities.py b/lib/utilities.py index 2d08a6c4..28ad1ca0 100644 --- a/lib/utilities.py +++ b/lib/utilities.py @@ -32,93 +32,128 @@ def with_keys(d, keys): def class_db_specific_config(self, table_config, class_called): keep_tables = [] for i in table_config.keys(): - if class_called == 'DatabaseTester': - if "UploadTest" in table_config[i]['TestScripts']: + if class_called == "DatabaseTester": + if "UploadTest" in table_config[i]["TestScripts"]: keep_tables.append(i) - elif class_called == 'ElasticDBTester': + elif class_called == "ElasticDBTester": keep_tables.append(i) else: - if class_called in table_config[i]['TestScripts']: + if class_called in table_config[i]["TestScripts"]: keep_tables.append(i) self.table_config = with_keys(table_config, keep_tables) - if class_called[:4] == 'Text': + if class_called[:4] == "Text": pass else: if "PostProcessing" in str(self): tables_list = list(self.table_config.keys()) quarter_date = self.end_date.strftime("%Y%m%d") for table in tables_list: - if table in ['assignee', "assignee_disambiguation_mapping", 'location', "location_disambiguation_mapping", 'inventor', "inventor_disambiguation_mapping", "inventor_gender", "rawinventor_gender", "rawinventor_gender_agg"]: - self.table_config[f'{table}_{quarter_date}'] = self.table_config.pop(f'{table}') + if table in [ + "assignee", + "assignee_disambiguation_mapping", + "location", + "location_disambiguation_mapping", + "inventor", + "inventor_disambiguation_mapping", + "inventor_gender", + "rawinventor_gender", + "rawinventor_gender_agg", + ]: + self.table_config[f"{table}_{quarter_date}"] = ( + self.table_config.pop(f"{table}") + ) print(f"The following list of tables are run for {class_called}:") print(self.table_config.keys()) -def load_table_config(config, db='patent'): +def load_table_config(config, db="patent"): print(db) print(config["PATENTSVIEW_DATABASES"]["REPORTING_DATABASE"]) root = config["FOLDERS"]["project_root"] resources = config["FOLDERS"]["resources_folder"] - if db == 'patent': + if db == "patent": config_file = f"{root}/{resources}/{config['FILES']['table_config_granted']}" - elif db == 'pgpubs': + elif db == "pgpubs": config_file = f"{root}/{resources}/{config['FILES']['table_config_pgpubs']}" - elif db == 'patent_text' or db[:6] == 'upload': - config_file = f'{root}/{resources}/{config["FILES"]["table_config_text_granted"]}' - elif db == 'pgpubs_text' or db[:6] == 'pgpubs': - config_file = f'{root}/{resources}/{config["FILES"]["table_config_text_pgpubs"]}' + elif db == "patent_text" or db[:6] == "upload": + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_text_granted"]}' + ) + elif db == "pgpubs_text" or db[:6] == "pgpubs": + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_text_pgpubs"]}' + ) elif db == config["PATENTSVIEW_DATABASES"]["REPORTING_DATABASE"]: - config_file = f'{root}/{resources}/{config["FILES"]["table_config_reporting_db"]}' + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_reporting_db"]}' + ) elif db == "gender_attribution": - config_file = f'{root}/{resources}/{config["FILES"]["table_config_inventor_gender"]}' - elif db == 'bulk_exp_granted': - config_file = f'{root}/{resources}/{config["FILES"]["table_config_bulk_exp_granted"]}' - elif db == 'bulk_exp_pgpubs': - config_file = f'{root}/{resources}/{config["FILES"]["table_config_bulk_exp_pgpubs"]}' - elif db == 'elasticsearch_patent': - config_file = f'{root}/{resources}/{config["FILES"]["table_config_elasticsearch_patent"]}' - elif db == 'elasticsearch_pgpub': - config_file = f'{root}/{resources}/{config["FILES"]["table_config_elasticsearch_pgpub"]}' + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_inventor_gender"]}' + ) + elif db == "bulk_exp_granted": + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_bulk_exp_granted"]}' + ) + elif db == "bulk_exp_pgpubs": + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_bulk_exp_pgpubs"]}' + ) + elif db == "elasticsearch_patent": + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_elasticsearch_patent"]}' + ) + elif db == "elasticsearch_pgpub": + config_file = ( + f'{root}/{resources}/{config["FILES"]["table_config_elasticsearch_pgpub"]}' + ) print(f"reading table config from {config_file}") + config_file = config_file.replace("/project/", "") with open(config_file) as file: table_config = json.load(file) return table_config def get_relevant_attributes(self, class_called, database_section, config): - print(f"assigning class variables based on class {class_called} and database section {database_section}.") - if (class_called == "AssigneePostProcessingQC") or (class_called == "AssigneePostProcessingQCPhase2") : + print( + f"assigning class variables based on class {class_called} and database section {database_section}." + ) + if (class_called == "AssigneePostProcessingQC") or ( + class_called == "AssigneePostProcessingQCPhase2" + ): self.database_section = database_section - if self.database_section == 'patent': - self.table_config = load_table_config(config, db='patent') + if self.database_section == "patent": + self.table_config = load_table_config(config, db="patent") else: - self.table_config = load_table_config(config, db='pgpubs') - self.entity_table = 'rawassignee' - self.entity_id = 'uuid' - self.disambiguated_id = 'assignee_id' - self.disambiguated_table = 'assignee_'+ config['DATES']["END_DATE"] - self.disambiguated_data_fields = ['name_last', 'name_first', 'organization'] - self.aggregator = 'main.organization' + self.table_config = load_table_config(config, db="pgpubs") + self.entity_table = "rawassignee" + self.entity_id = "uuid" + self.disambiguated_id = "assignee_id" + self.disambiguated_table = "assignee_" + config["DATES"]["END_DATE"] + self.disambiguated_data_fields = ["name_last", "name_first", "organization"] + self.aggregator = "main.organization" self.category = "" self.central_entity = "" self.p_key = "" self.f_key = "" self.exclusion_list = [] - elif (class_called == "InventorGenderPostProcessingQC"): - self.table_config = load_table_config(config, db='gender_attribution') + elif class_called == "InventorGenderPostProcessingQC": + self.table_config = load_table_config(config, db="gender_attribution") - elif (class_called == "InventorPostProcessingQC") or (class_called == "InventorPostProcessingQCPhase2") : + elif (class_called == "InventorPostProcessingQC") or ( + class_called == "InventorPostProcessingQCPhase2" + ): self.database_section = database_section - if self.database_section == 'patent': - self.table_config = load_table_config(config, db='patent') + if self.database_section == "patent": + self.table_config = load_table_config(config, db="patent") else: - self.table_config = load_table_config(config, db='pgpubs') - self.entity_table = 'rawinventor' - self.entity_id = 'uuid' - self.disambiguated_id = 'inventor_id' - self.disambiguated_table = 'inventor' - self.disambiguated_data_fields = ['name_last', 'name_first', 'organization'] + self.table_config = load_table_config(config, db="pgpubs") + self.entity_table = "rawinventor" + self.entity_id = "uuid" + self.disambiguated_id = "inventor_id" + self.disambiguated_table = "inventor" + self.disambiguated_data_fields = ["name_last", "name_first", "organization"] # self.patent_exclusion_list.extend(['assignee', 'persistent_assignee_disambig']) # self.add_persistent_table_to_config(database_section) self.category = "" @@ -131,14 +166,24 @@ def get_relevant_attributes(self, class_called, database_section, config): elif class_called == "LawyerPostProcessingQC": self.database_section = database_section - self.table_config = load_table_config(config, db='patent') - self.entity_table = 'rawlawyer' - self.entity_id = 'uuid' - self.disambiguated_id = 'lawyer_id' - self.disambiguated_table = 'lawyer' - self.disambiguated_data_fields = ['name_last', 'name_first', "organization", "country"] + self.table_config = load_table_config(config, db="patent") + self.entity_table = "rawlawyer" + self.entity_id = "uuid" + self.disambiguated_id = "lawyer_id" + self.disambiguated_table = "lawyer" + self.disambiguated_data_fields = [ + "name_last", + "name_first", + "organization", + "country", + ] self.aggregator = 'case when main.organization is null then concat(main.name_last,", ",main.name_first) else main.organization end' - self.disambiguated_data_fields = ['name_last', 'name_first', "organization", "country"] + self.disambiguated_data_fields = [ + "name_last", + "name_first", + "organization", + "country", + ] self.category = "" self.central_entity = "" self.p_key = "" @@ -146,13 +191,13 @@ def get_relevant_attributes(self, class_called, database_section, config): self.exclusion_list = [] elif class_called == "LocationPostProcessingQC": - self.table_config = load_table_config(config, db='patent') - self.disambiguated_data_fields = ['city', 'state', 'country'] + self.table_config = load_table_config(config, db="patent") + self.disambiguated_data_fields = ["city", "state", "country"] self.aggregator = "concat(main.city, ', ', main.state, ',', main.country)" - self.entity_table = 'rawlocation' - self.entity_id = 'id' - self.disambiguated_id = 'location_id' - self.disambiguated_table = 'location' + self.entity_table = "rawlocation" + self.entity_id = "id" + self.disambiguated_id = "location_id" + self.disambiguated_table = "location" self.category = "" self.central_entity = "" self.p_key = "" @@ -160,7 +205,7 @@ def get_relevant_attributes(self, class_called, database_section, config): self.exclusion_list = [] elif class_called == "CPCTest": - self.table_config = load_table_config(config, db='patent') + self.table_config = load_table_config(config, db="patent") self.category = "" self.central_entity = "" self.p_key = "" @@ -168,7 +213,9 @@ def get_relevant_attributes(self, class_called, database_section, config): self.exclusion_list = [] elif class_called == "ReportingDBTester" or class_called == "ProdDBTester": - self.table_config = load_table_config(config, db = config["PATENTSVIEW_DATABASES"]["REPORTING_DATABASE"]) #db should be parameterized later, not hard-coded + self.table_config = load_table_config( + config, db=config["PATENTSVIEW_DATABASES"]["REPORTING_DATABASE"] + ) # db should be parameterized later, not hard-coded self.category = "" self.central_entity = "" self.p_key = "" @@ -176,7 +223,9 @@ def get_relevant_attributes(self, class_called, database_section, config): self.exclusion_list = [] elif class_called == "ElasticDBTester": - self.table_config = load_table_config(config, db = config['PATENTSVIEW_DATABASES']["ELASTICSEARCH_DB_TYPE"]) #db should be parameterized later, not hard-coded + self.table_config = load_table_config( + config, db=config["PATENTSVIEW_DATABASES"]["ELASTICSEARCH_DB_TYPE"] + ) # db should be parameterized later, not hard-coded self.category = "" self.central_entity = "" self.p_key = "" @@ -184,94 +233,106 @@ def get_relevant_attributes(self, class_called, database_section, config): self.exclusion_list = [] elif database_section == "patent" or ( - database_section[:6] == 'upload' and class_called[:6] in ('Upload','GovtIn', 'MergeT')): - self.exclusion_list = ['assignee', - 'cpc_group', - 'cpc_subgroup', - 'cpc_subsection', - 'government_organization', - 'inventor', - 'lawyer', - 'location', - 'location_assignee', - 'location_inventor', - 'location_nber_subcategory', - 'mainclass', - 'nber_category', - 'nber_subcategory', - 'rawlocation', - 'subclass', - 'usapplicationcitation', - 'uspatentcitation', - 'wipo_field'] - self.central_entity = 'patent' - self.category = 'type' - self.table_config = load_table_config(config, db='patent') + database_section[:6] == "upload" + and class_called[:6] in ("Upload", "GovtIn", "MergeT") + ): + self.exclusion_list = [ + "assignee", + "cpc_group", + "cpc_subgroup", + "cpc_subsection", + "government_organization", + "inventor", + "lawyer", + "location", + "location_assignee", + "location_inventor", + "location_nber_subcategory", + "mainclass", + "nber_category", + "nber_subcategory", + "rawlocation", + "subclass", + "usapplicationcitation", + "uspatentcitation", + "wipo_field", + ] + self.central_entity = "patent" + self.category = "type" + self.table_config = load_table_config(config, db="patent") self.p_key = "id" self.f_key = "patent_id" elif (database_section == "pregrant_publications") or ( - database_section[:6] == 'pgpubs' and class_called[:6] in ('Upload','GovtIn', 'MergeT')): + database_section[:6] == "pgpubs" + and class_called[:6] in ("Upload", "GovtIn", "MergeT") + ): # TABLES WITHOUT DOCUMENT_NUMBER ARE EXCLUDED FROM THE TABLE CONFIG self.central_entity = "publication" - self.category = 'kind' - self.exclusion_list = ['assignee', - 'clean_rawlocation', - 'government_organization', - 'inventor', - 'location_assignee', - 'location_inventor', - 'rawlocation', - 'rawlocation_geos_missed', - 'rawlocation_lat_lon'] - self.table_config = load_table_config(config, db='pgpubs') + self.category = "kind" + self.exclusion_list = [ + "assignee", + "clean_rawlocation", + "government_organization", + "inventor", + "location_assignee", + "location_inventor", + "rawlocation", + "rawlocation_geos_missed", + "rawlocation_lat_lon", + ] + self.table_config = load_table_config(config, db="pgpubs") self.p_key = "document_number" self.f_key = "document_number" - elif class_called[:4] == 'Text': + elif class_called[:4] == "Text": self.category = "" self.central_entity = "" self.p_key = "" self.f_key = "" self.exclusion_list = [] - if database_section[:6] == 'upload' or database_section == 'patent_text': + if database_section[:6] == "upload" or database_section == "patent_text": self.table_config = load_table_config(config, db=database_section) - elif database_section[:6] == 'pgpubs' or database_section == 'pgpubs_text': + elif database_section[:6] == "pgpubs" or database_section == "pgpubs_text": self.table_config = load_table_config(config, db=database_section) else: raise NotImplementedError - elif class_called[:19] == 'BulkDownloadsTester': - if 'granted' in database_section: - self.table_config = load_table_config(config, db='bulk_exp_granted') + elif class_called[:19] == "BulkDownloadsTester": + if "granted" in database_section: + self.table_config = load_table_config(config, db="bulk_exp_granted") self.central_entity = "patent" self.p_key = "patent_id" self.f_key = "patent_id" else: - self.table_config = load_table_config(config, db='bulk_exp_pgpubs') + self.table_config = load_table_config(config, db="bulk_exp_pgpubs") self.central_entity = "publication" self.p_key = "pgpub_id" self.f_key = "pgpub_id" - + self.category = "" self.exclusion_list = [] else: raise NotImplementedError + def update_to_granular_version_indicator(table, db): from lib.configuration import get_current_config, get_connection_string - config = get_current_config(type=db, **{"execution_date": datetime.date(2000, 1, 1)}) - cstr = get_connection_string(config, 'PROD_DB') + + config = get_current_config( + type=db, **{"execution_date": datetime.date(2000, 1, 1)} + ) + cstr = get_connection_string(config, "PROD_DB") engine = create_engine(cstr) - if db == 'granted_patent': - id = 'id' - fk = 'patent_id' - fact_table = 'patent' + if db == "granted_patent": + id = "id" + fk = "patent_id" + fact_table = "patent" else: - id = 'document_number' - fk = 'document_number' - fact_table = 'publications' + id = "document_number" + fk = "document_number" + fact_table = "publications" query = f""" update {table} update_table inner join {fact_table} p on update_table.{fk}=p.{id} @@ -283,6 +344,7 @@ def update_to_granular_version_indicator(table, db): query_end_time = time() print("This query took:", query_end_time - query_start_time, "seconds") + # Moved from AssigneePostProcessing - unused for now def add_persistent_table_to_config(self, database_section): columns_query = f""" @@ -300,11 +362,11 @@ def add_persistent_table_to_config(self, database_section): with self.connection.cursor() as crsr: crsr.execute(columns_query) column_data = pd.DataFrame.from_records( - crsr.fetchall(), - columns=['column', 'data_type', 'null_allowed', 'category']) + crsr.fetchall(), columns=["column", "data_type", "null_allowed", "category"] + ) table_config = { - 'persistent_assignee_disambig': { - 'fields': column_data.set_index('column').to_dict(orient='index') + "persistent_assignee_disambig": { + "fields": column_data.set_index("column").to_dict(orient="index") } } self.table_config.update(table_config) @@ -312,18 +374,19 @@ def add_persistent_table_to_config(self, database_section): def trim_whitespace(config): from lib.configuration import get_connection_string - cstr = get_connection_string(config, 'TEMP_UPLOAD_DB') - db_type = config['PATENTSVIEW_DATABASES']["TEMP_UPLOAD_DB"][:6] + + cstr = get_connection_string(config, "TEMP_UPLOAD_DB") + db_type = config["PATENTSVIEW_DATABASES"]["TEMP_UPLOAD_DB"][:6] engine = create_engine(cstr) print("REMOVING WHITESPACE WHERE IT EXISTS") - project_home = os.environ['PACKAGE_HOME'] - resources_file = "{root}/{resources}/columns_for_whitespace_trim.json".format(root=project_home, - resources=config["FOLDERS"][ - "resources_folder"]) + project_home = os.environ["PACKAGE_HOME"] + resources_file = "{root}/{resources}/columns_for_whitespace_trim.json".format( + root=project_home, resources=config["FOLDERS"]["resources_folder"] + ) cols_tables_whitespace = json.load(open(resources_file)) for table in cols_tables_whitespace.keys(): if db_type in cols_tables_whitespace[table]["TestScripts"]: - for column in cols_tables_whitespace[table]['fields']: + for column in cols_tables_whitespace[table]["fields"]: trim_whitespace_query = f""" update {table} set `{column}`= TRIM(`{column}`) @@ -335,7 +398,7 @@ def trim_whitespace(config): def xstr(s): if s is None: - return '' + return "" return str(s) @@ -348,23 +411,26 @@ def weekday_count(start_date, end_date): def id_generator(size=25, chars=string.ascii_lowercase + string.digits): - return ''.join(random.choice(chars) for _ in range(size)) + return "".join(random.choice(chars) for _ in range(size)) def download(url, filepath, api_key=None): - """ Download data from a URL with a handy progress bar """ + """Download data from a URL with a handy progress bar""" print("Downloading: {}".format(url)) headers = {"X-API-KEY": api_key} r = requests.get(url, headers=headers, stream=True) - content_length = r.headers.get('content-length') + content_length = r.headers.get("content-length") if not content_length: print("\tNo Content Length Attached. Attempting download without progress bar.") chunker = r.iter_content(chunk_size=1024) else: - chunker = progress.bar(r.iter_content(chunk_size=1024), expected_size=(int(content_length) / 1024) + 1) - with open(filepath, 'wb') as f: + chunker = progress.bar( + r.iter_content(chunk_size=1024), + expected_size=(int(content_length) / 1024) + 1, + ) + with open(filepath, "wb") as f: for chunk in chunker: if chunk: f.write(chunk) @@ -372,26 +438,29 @@ def download(url, filepath, api_key=None): def chunks(l, n): - '''Yield successive n-sized chunks from l. Useful for multi-processing''' + """Yield successive n-sized chunks from l. Useful for multi-processing""" chunk_list = [] for i in range(0, len(l), n): - chunk_list.append(l[i:i + n]) + chunk_list.append(l[i : i + n]) return chunk_list def better_title(text): title = " ".join( - [item if item not in ["Of", "The", "For", "And", "On"] else item.lower() for item in - str(text).title().split()]) - return re.sub('[' + string.punctuation + ']', '', title) + [ + item if item not in ["Of", "The", "For", "And", "On"] else item.lower() + for item in str(text).title().split() + ] + ) + return re.sub("[" + string.punctuation + "]", "", title) def write_csv(rows, outputdir, filename): - """ Write a list of lists to a csv file """ + """Write a list of lists to a csv file""" print(outputdir) print(os.path.join(outputdir, filename)) - writer = csv.writer(open(os.path.join(outputdir, filename), 'w', encoding='utf-8')) + writer = csv.writer(open(os.path.join(outputdir, filename), "w", encoding="utf-8")) writer.writerows(rows) @@ -433,40 +502,46 @@ def write_csv(rows, outputdir, filename): def mp_csv_writer(write_queue, target_file, header): - with open(target_file, 'w', newline='') as writefile: - filtered_writer = csv.writer(writefile, - delimiter=',', - quotechar='"', - quoting=csv.QUOTE_NONNUMERIC) + with open(target_file, "w", newline="") as writefile: + filtered_writer = csv.writer( + writefile, delimiter=",", quotechar='"', quoting=csv.QUOTE_NONNUMERIC + ) filtered_writer.writerow(header) while 1: message_data = write_queue.get() if len(message_data) != len(header): # "kill" is the special message to stop listening for messages - if message_data[0] == 'kill': + if message_data[0] == "kill": break else: print(message_data) - raise Exception("Header and data length don't match :{header}/{data_ln}".format(header=len(header), - data_ln=len( - message_data))) + raise Exception( + "Header and data length don't match :{header}/{data_ln}".format( + header=len(header), data_ln=len(message_data) + ) + ) filtered_writer.writerow(message_data) def log_writer(log_queue, log_prefix="uspto_parser"): - '''listens for messages on the q, writes to file. ''' - home_folder = os.environ['PACKAGE_HOME'] + """listens for messages on the q, writes to file.""" + home_folder = os.environ["PACKAGE_HOME"] logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) EXPANED_LOGFILE = datetime.datetime.now().strftime( - '{home_folder}/logs/{prefix}_expanded_log_%Y%m%d_%H%M%S.log'.format(home_folder=home_folder, - prefix=log_prefix)) + "{home_folder}/logs/{prefix}_expanded_log_%Y%m%d_%H%M%S.log".format( + home_folder=home_folder, prefix=log_prefix + ) + ) expanded_filehandler = logging.FileHandler(EXPANED_LOGFILE) expanded_filehandler.setLevel(logging.DEBUG) BASIC_LOGFILE = datetime.datetime.now().strftime( - '{home_folder}/logs/{prefix}_log_%Y%m%d_%H%M%S.log'.format(home_folder=home_folder, prefix=log_prefix)) + "{home_folder}/logs/{prefix}_log_%Y%m%d_%H%M%S.log".format( + home_folder=home_folder, prefix=log_prefix + ) + ) filehandler = logging.FileHandler(BASIC_LOGFILE) filehandler.setLevel(logging.INFO) @@ -478,7 +553,7 @@ def log_writer(log_queue, log_prefix="uspto_parser"): logger.addHandler(ch) while 1: message_data = log_queue.get() - if message_data["message"] == 'kill': + if message_data["message"] == "kill": logger.info("Kill Signal received. Exiting") break logger.log(message_data["level"], message_data["message"]) @@ -492,13 +567,13 @@ def save_zip_file(url, name, path, counter=0, log_queue=None, api_key=None): with requests.get(url, headers=headers, stream=True) as downloader: downloader.raise_for_status() zip_path = os.path.join(path, name) - with open(zip_path, 'wb') as f: + with open(zip_path, "wb") as f: for chunk in downloader.iter_content(chunk_size=8192): if chunk: f.write(chunk) # Extract and rename if revised - with zipfile.ZipFile(zip_path, 'r') as zip_ref: + with zipfile.ZipFile(zip_path, "r") as zip_ref: for zip_info in zip_ref.infolist(): z_nm, z_ext = os.path.splitext(name) f_nm, f_ext = os.path.splitext(zip_info.filename) @@ -506,7 +581,10 @@ def save_zip_file(url, name, path, counter=0, log_queue=None, api_key=None): tmp_dir = os.path.join(path, "tmp") os.makedirs(tmp_dir, exist_ok=True) zip_ref.extract(zip_info.filename, tmp_dir) - os.rename(os.path.join(tmp_dir, zip_info.filename), os.path.join(path, f"{z_nm}{f_ext}")) + os.rename( + os.path.join(tmp_dir, zip_info.filename), + os.path.join(path, f"{z_nm}{f_ext}"), + ) os.rmdir(tmp_dir) else: zip_ref.extract(zip_info.filename, path) @@ -517,7 +595,14 @@ def save_zip_file(url, name, path, counter=0, log_queue=None, api_key=None): print(f"{path} contains {os.listdir(path)}") -def get_files_to_download(product_id, api_key, execution_date_str = None, download_folder = None, log_queue = None, files_only=False): +def get_files_to_download( + product_id, + api_key, + execution_date_str=None, + download_folder=None, + log_queue=None, + files_only=False, +): headers = {"X-API-KEY": api_key, "accept": "application/json"} if execution_date_str is None: @@ -528,7 +613,7 @@ def get_files_to_download(product_id, api_key, execution_date_str = None, downlo "fileDataToDate": execution_date_str, } url = f"https://api.uspto.gov/api/v1/datasets/products/{product_id}" - + files_to_download = [] try: @@ -549,9 +634,7 @@ def get_files_to_download(product_id, api_key, execution_date_str = None, downlo file_url = file_info["fileDownloadURI"] if files_only: # If only URLs are needed, return them directly - files_to_download.append( - (file_url, filename) - ) + files_to_download.append((file_url, filename)) else: files_to_download.append( (file_url, filename, download_folder, idx, log_queue, api_key) @@ -568,8 +651,7 @@ def get_files_to_download(product_id, api_key, execution_date_str = None, downlo return files_to_download - -def download_xml_files(config, xml_template_setting_prefix='granted_patent'): +def download_xml_files(config, xml_template_setting_prefix="granted_patent"): from datetime import datetime product_id = config["USPTO_LINKS"]["product_identifier"] @@ -579,14 +661,14 @@ def download_xml_files(config, xml_template_setting_prefix='granted_patent'): download_folder = ( config["FOLDERS"]["granted_patent_bulk_xml_location"] if xml_template_setting_prefix == "granted_patent" - else config['FOLDERS']["pgpubs_bulk_xml_location"] + else config["FOLDERS"]["pgpubs_bulk_xml_location"] ) print(f"[DEBUG] Download folder: {download_folder}") execution_dt = config["DATES"]["END_DATE_DASH"] - print(f'this is the execution date: {execution_dt}') - print(f'this is the type of the execution date: {type(execution_dt)}') + print(f"this is the execution date: {execution_dt}") + print(f"this is the type of the execution date: {type(execution_dt)}") if isinstance(execution_dt, str): execution_dt = datetime.fromisoformat(execution_dt) @@ -606,7 +688,9 @@ def download_xml_files(config, xml_template_setting_prefix='granted_patent'): if parallelism > 1: pool = multiprocessing.Pool(parallelism) watcher = pool.apply_async(log_writer, (log_queue,)) - p_list = [pool.apply_async(save_zip_file, args=job) for job in files_to_download] + p_list = [ + pool.apply_async(save_zip_file, args=job) for job in files_to_download + ] for p in p_list: p.get() log_queue.put({"level": None, "message": "kill"}) @@ -618,11 +702,6 @@ def download_xml_files(config, xml_template_setting_prefix='granted_patent'): save_zip_file(*job) - - - - - # def download_xml_files(config, xml_template_setting_prefix='pgpubs'): # xml_template_setting = "{prefix}_bulk_xml_template".format(prefix=xml_template_setting_prefix) # xml_download_setting = "{prefix}_bulk_xml_location".format(prefix=xml_template_setting_prefix) @@ -710,16 +789,19 @@ def download_xml_files(config, xml_template_setting_prefix='granted_patent'): # pool.join() -def manage_ec2_instance(config, button='ON', identifier='xml_collector'): - instance_id = config['AWS_WORKER'][identifier] - ec2 = boto3.client('ec2', aws_access_key_id=config['AWS']['ACCESS_KEY_ID'], - aws_secret_access_key=config['AWS']['SECRET_KEY'], - region_name='us-east-1') - if button == 'ON': +def manage_ec2_instance(config, button="ON", identifier="xml_collector"): + instance_id = config["AWS_WORKER"][identifier] + ec2 = boto3.client( + "ec2", + aws_access_key_id=config["AWS"]["ACCESS_KEY_ID"], + aws_secret_access_key=config["AWS"]["SECRET_KEY"], + region_name="us-east-1", + ) + if button == "ON": response = ec2.start_instances(InstanceIds=[instance_id]) else: response = ec2.stop_instances(InstanceIds=[instance_id]) - return response['ResponseMetadata']['HTTPStatusCode'] == 200 + return response["ResponseMetadata"]["HTTPStatusCode"] == 200 def create_aws_boto3_session(): @@ -736,6 +818,7 @@ def create_aws_boto3_session(): print(f"Error creating AWS session: {e}") return None + def rds_free_space(identifier): """ Retrieve the FreeStorageSpace metric for an RDS instance. @@ -750,45 +833,46 @@ def rds_free_space(identifier): return None # Create CloudWatch client using the returned session - cloudwatch = session.client('cloudwatch', region_name='us-east-1') + cloudwatch = session.client("cloudwatch", region_name="us-east-1") from datetime import datetime, timedelta + response = cloudwatch.get_metric_data( MetricDataQueries=[ { - 'Id': 'fetching_FreeStorageSpace', - 'MetricStat': { - 'Metric': { - 'Namespace': 'AWS/RDS', - 'MetricName': 'FreeStorageSpace', - 'Dimensions': [ - { - "Name": "DBInstanceIdentifier", - "Value": identifier - } - ] + "Id": "fetching_FreeStorageSpace", + "MetricStat": { + "Metric": { + "Namespace": "AWS/RDS", + "MetricName": "FreeStorageSpace", + "Dimensions": [ + {"Name": "DBInstanceIdentifier", "Value": identifier} + ], }, - 'Period': 300, - 'Stat': 'Minimum' - } + "Period": 300, + "Stat": "Minimum", + }, } ], StartTime=(datetime.now() - timedelta(seconds=300 * 3)).timestamp(), EndTime=datetime.now().timestamp(), - ScanBy='TimestampDescending' + ScanBy="TimestampDescending", ) - return mean(response['MetricDataResults'][0]['Values']) + return mean(response["MetricDataResults"][0]["Values"]) def get_host_name(local=True): import requests from requests.exceptions import ConnectionError from airflow.utils import net + try: - host_key = 'local-hostname' + host_key = "local-hostname" if not local: - host_key = 'public-hostname' - r = requests.get("http://169.254.169.254/latest/meta-data/{hkey}".format(hkey=host_key)) + host_key = "public-hostname" + r = requests.get( + "http://169.254.169.254/latest/meta-data/{hkey}".format(hkey=host_key) + ) return r.text except ConnectionError: return net.get_host_ip_address() @@ -813,50 +897,64 @@ def archive_folder(source_folder, targets: list): print(file_name) shutil.copy(os.path.join(source_folder, file_name), targets[-1]) + def add_index_new_disambiguation_table(connection, table_name): from mysql.connector.errors import ProgrammingError + g_cursor = connection.cursor() - index_query = 'alter table {table_name} add primary key (uuid)'.format( - table_name=table_name) + index_query = "alter table {table_name} add primary key (uuid)".format( + table_name=table_name + ) print(index_query) try: g_cursor.execute(index_query) except ProgrammingError as e: from mysql.connector import errorcode + if not e.errno == errorcode.ER_MULTIPLE_PRI_KEY: raise + def link_view_to_new_disambiguation_table(connection, table_name, disambiguation_type): from mysql.connector.errors import ProgrammingError + g_cursor = connection.cursor() - index_query = 'alter table {table_name} add primary key (uuid)'.format( - table_name=table_name) + index_query = "alter table {table_name} add primary key (uuid)".format( + table_name=table_name + ) print(index_query) replace_view_query = """ CREATE OR REPLACE SQL SECURITY INVOKER VIEW {dtype}_disambiguation_mapping as SELECT uuid,{dtype}_id from {table_name} - """.format(table_name=table_name, dtype=disambiguation_type) + """.format( + table_name=table_name, dtype=disambiguation_type + ) try: g_cursor.execute(index_query) except ProgrammingError as e: from mysql.connector import errorcode + if not e.errno == errorcode.ER_MULTIPLE_PRI_KEY: raise print(replace_view_query) g_cursor.execute(replace_view_query) + def update_to_granular_version_indicator(table, db): from lib.configuration import get_current_config, get_connection_string - config = get_current_config(type=db, **{"execution_date": datetime.date(2000, 1, 1)}) - cstr = get_connection_string(config, 'PROD_DB') + + config = get_current_config( + type=db, **{"execution_date": datetime.date(2000, 1, 1)} + ) + cstr = get_connection_string(config, "PROD_DB") engine = create_engine(cstr) - if db == 'granted_patent': - id = 'id' - fk = 'patent_id' - fact_table = 'patent' + if db == "granted_patent": + id = "id" + fk = "patent_id" + fact_table = "patent" else: - id = 'document_number' - fk = 'document_number' - fact_table = 'publication' + id = "document_number" + fk = "document_number" + fact_table = "publication" query = f""" update {table} update_table inner join {fact_table} p on update_table.{fk}=p.{id} @@ -868,11 +966,13 @@ def update_to_granular_version_indicator(table, db): query_end_time = time() print("This query took:", query_end_time - query_start_time, "seconds") + def update_version_indicator(table, db, **kwargs): from lib.configuration import get_current_config, get_connection_string + config = get_current_config(type=db, schedule="quarterly", **kwargs) - ed = process_date(config['DATES']["end_date"], as_string=True) - cstr = get_connection_string(config, 'PROD_DB') + ed = process_date(config["DATES"]["end_date"], as_string=True) + cstr = get_connection_string(config, "PROD_DB") engine = create_engine(cstr) query = f""" update {table} update_table @@ -888,5 +988,8 @@ def update_version_indicator(table, db, **kwargs): if __name__ == "__main__": # update_to_granular_version_indicator('uspc_current', 'granted_patent') print("HI") - config = get_current_config("granted_patent", schedule='quarterly', **{"execution_date": datetime.date(2022, 6, 30)}) - + config = get_current_config( + "granted_patent", + schedule="quarterly", + **{"execution_date": datetime.date(2022, 6, 30)}, + ) diff --git a/lib/utils b/lib/utils index 23e24ec5..b3aa98b6 160000 --- a/lib/utils +++ b/lib/utils @@ -1 +1 @@ -Subproject commit 23e24ec5fb1310f653345cce995f39eb4288430d +Subproject commit b3aa98b60fe026f9b078ceb88b75c7747d8f732c diff --git a/mydumper/Dockerfile b/mydumper/Dockerfile index 2cbd87d4..8d9473b7 100644 --- a/mydumper/Dockerfile +++ b/mydumper/Dockerfile @@ -1,18 +1,23 @@ -# Use an official Debian image as the base -FROM debian:bullseye +FROM patentsview/airflow:0.5 +USER root -# Install mydumper and libmariadb3 using package manager -RUN apt-get update && \ - apt-get install -y libmariadb-dev && apt-get install -y libmariadb3 && apt-get install -y libmariadb-dev-compat +# Add MySQL repo key into a keyring and (re)declare the source with signed-by +RUN set -eux; \ + install -d -m 0755 /usr/share/keyrings; \ + wget -qO- https://repo.mysql.com/RPM-GPG-KEY-mysql-2023 | gpg --dearmor -o /usr/share/keyrings/mysql-archive-keyring.gpg; \ + echo "deb [arch=amd64 signed-by=/usr/share/keyrings/mysql-archive-keyring.gpg] http://repo.mysql.com/apt/debian bullseye mysql-apt-config mysql-8.0 mysql-tools" > /etc/apt/sources.list.d/mysql.list; \ + apt-get update; \ + apt-get install -y --no-install-recommends libglib2.0-0 wget; \ + rm -rf /var/lib/apt/lists/* -# Use the mysqlboy/docker-mydumper image as a base -FROM mydumper/mydumper:latest +# mydumper install +RUN set -eux; \ + cd /tmp; \ + wget -q https://github.com/maxbube/mydumper/releases/download/v0.9.3/mydumper_0.9.3-41.stretch_amd64.deb; \ + dpkg -i mydumper_0.9.3-41.stretch_amd64.deb || apt-get -f install -y; \ + rm -f mydumper_0.9.3-41.stretch_amd64.deb -# Copy mydumper configuration -COPY mydumper.cnf /etc/ +COPY mydumper.cnf /etc/mydumper.cnf +USER airflow +CMD ["mydumper", "--version"] -# Set the command to run when the container starts -#CMD ["mydumper", "-c", "/etc/mydumper.cnf"] -#CMD ["mydumper", "--version"] -#CMD ["mydumper"] -CMD ["echo", "Container started without running mydumper"] \ No newline at end of file diff --git a/mydumper/docker-compose-mydumper.yaml b/mydumper/docker-compose-mydumper.yaml index bdb88acd..438991ef 100644 --- a/mydumper/docker-compose-mydumper.yaml +++ b/mydumper/docker-compose-mydumper.yaml @@ -1,16 +1,14 @@ -version: '3' services: mydumper_service: build: context: . dockerfile: Dockerfile environment: - - S3_ACCESS_KEY_ID - - S3_SECRET_ACCESS_KEY - - S3_ENDPOINT_URL - - S3_BUCKET + - S3_ACCESS_KEY_ID + - S3_SECRET_ACCESS_KEY + - S3_ENDPOINT_URL + - S3_BUCKET volumes: - - {LOCAL_PATH}:/DatabaseBackups - - {LOCAL_PATH}:/project - command: - - mydumper \ No newline at end of file + - /PatentDataVolume/DatabaseBackups:/DatabaseBackups + - /airflow/PatentsView-DB:/project + command: mydumper --host=$MYSQL_HOST --user=$MYSQL_USER --password=$MYSQL_PASSWORD diff --git a/mydumper/mydumper.cnf b/mydumper/mydumper.cnf deleted file mode 100755 index 9d48ae64..00000000 --- a/mydumper/mydumper.cnf +++ /dev/null @@ -1,12 +0,0 @@ -[mydumper] -host = -user = -password = -port = 3306 - -[myloader] -host = -user = -password = -port = -protocol= diff --git a/updater/disambiguation/hierarchical_clustering_disambiguation b/updater/disambiguation/hierarchical_clustering_disambiguation index ee259a77..aca8eb09 160000 --- a/updater/disambiguation/hierarchical_clustering_disambiguation +++ b/updater/disambiguation/hierarchical_clustering_disambiguation @@ -1 +1 @@ -Subproject commit ee259a7760d5b7cf6fcd4c95e58aa18465712fbd +Subproject commit aca8eb091527a8e991a59b4413cbd14dcb2b5d98