logger = logging.getLogger(__name__)

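+ # Intermediate labels standing in for concrete PySpark types; see `remap` below.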
+ STRING = "STRING"
+ BOOL = "BOOLEAN"
+ FLOAT = "FLOAT"
+ INT = "INTEGER"
+ DATE = "DATE"
+ TS = "TIMESTAMP"
+
# Map LinkML types to PySpark types
TYPE_MAP = {
-     "boolean": BooleanType(),
-     "xsd:boolean": BooleanType(),
+     "boolean": BOOL,
+     "xsd:boolean": BOOL,
      # numerical
-     "decimal": FloatType(),
-     "double": FloatType(),
-     "float": FloatType(),
-     "integer": IntegerType(),
-     "long": FloatType(),
+     "decimal": FLOAT,
+     "double": FLOAT,
+     "float": FLOAT,
+     "integer": INT,
+     "long": FLOAT,
      # dates and times
-     "date": DateType(),
-     "dateTime": DateType(),
-     "time": TimestampType(),
-     "xsd:date": DateType(),
-     "xsd:dateTime": DateType(),
-     "xsd:time": TimestampType(),
-     "linkml:DateOrDatetime": DateType(),
+     "date": DATE,
+     "dateTime": DATE,
+     "time": TS,
+     "xsd:date": DATE,
+     "xsd:dateTime": DATE,
+     "xsd:time": TS,
+     "linkml:DateOrDatetime": DATE,
      # string-like
-     "anyURI": StringType(),
-     "language": StringType(),
-     "string": StringType(),
-     "shex:nonLiteral": StringType(),
-     "shex:iri": StringType(),
+     "anyURI": STRING,
+     "language": STRING,
+     "str": STRING,
+     "string": STRING,
+     "shex:nonLiteral": STRING,
+     "shex:iri": STRING,
+ }
+
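+ # Resolve the intermediate labels back to concrete PySpark DataType instances,
+ # e.g. TYPE_MAP["integer"] is INT ("INTEGER") and remap[INT] is IntegerType().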
+ remap = {
+     STRING: StringType(),
+     BOOL: BooleanType(),
+     FLOAT: FloatType(),
+     INT: IntegerType(),
+     DATE: DateType(),
+     TS: TimestampType(),
}


- def resolve_slot_range_class_relational(sv: SchemaView, class_name: str) -> set[DataType]:
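+ # SchemaView subclass carrying per-instance caches for resolved ranges; instance
+ # attributes replace the former class-level PROCESSED dict (removed below), so
+ # caches are no longer shared between SchemaView instances.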
+ class SchemaViewWithProcessed(SchemaView):
+     def __init__(self, *args, **kwargs) -> None:
+         self.PROCESSED = {}
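+         # PROCESSED is retained from the previous implementation; the resolution
+         # functions below now read and write RANGE_TO_TYPE instead.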
+         self.RANGE_TO_TYPE = {}
+         super().__init__(*args, **kwargs)
+
+
+ def resolve_slot_range_class_relational(sv: SchemaViewWithProcessed, class_name: str) -> set[str]:
    """Generate the appropriate slot range for a given class.

    :param sv: the schema, via SchemaView
@@ -66,19 +90,19 @@ def resolve_slot_range_class_relational(sv: SchemaView, class_name: str) -> set[
    if not class_id_slot.range:
        msg = f"Class {class_name} identifier {class_id_slot.name} has no range: defaulting to string"
        logger.warning(msg)
-         sv.PROCESSED[class_name] = StringType()
-         return {StringType()}
+         sv.RANGE_TO_TYPE[class_name] = STRING
+         return {STRING}

    if class_id_slot.range in sv.all_classes():
        msg = f"Class {class_id_slot.range} used as range for identifier slot of class {class_name}"
        logger.warning(msg)
-         sv.PROCESSED[class_name] = StringType()
-         return {StringType()}
+         sv.RANGE_TO_TYPE[class_name] = STRING
+         return {STRING}

    return resolve_slot_range(sv, class_name=class_name, slot_name=class_id_slot.name, slot_range=class_id_slot.range)


- def resolve_slot_range_type(sv: SchemaView, type_name: str) -> set[DataType]:
+ def resolve_slot_range_type(sv: SchemaViewWithProcessed, type_name: str) -> set[str]:
    """Generate the appropriate slot range for a given type.

    :param sv: the schema, via SchemaView
@@ -94,14 +118,14 @@ def resolve_slot_range_type(sv: SchemaView, type_name: str) -> set[DataType]:
        msg = f"type {type_name} lacks base and uri fields"
        logger.warning(msg)
        # add it to the mapping
-         sv.PROCESSED[type_name] = StringType()
-         return {StringType()}
+         sv.RANGE_TO_TYPE[type_name] = STRING
+         return {STRING}

    type_uri = type_uri.removeprefix("xsd:")
-     return {TYPE_MAP.get(type_uri, StringType())}
+     return {TYPE_MAP.get(type_uri, STRING)}


- def resolve_slot_range(sv: SchemaView, class_name: str, slot_name: str, slot_range: str) -> set[DataType]:
+ def resolve_slot_range(sv: SchemaViewWithProcessed, class_name: str, slot_name: str, slot_range: str) -> set[str]:
    """Generate the appropriate spark datatype for a given slot_range.

    :param sv: the schema, via SchemaView
@@ -115,8 +139,8 @@ def resolve_slot_range(sv: SchemaView, class_name: str, slot_name: str, slot_ran
    :return: set of spark datatype(s) to use
    :rtype: set[DataType]
    """
-     if slot_range in sv.PROCESSED:
-         return {sv.PROCESSED[slot_range]}
+     if slot_range in sv.RANGE_TO_TYPE:
+         return {sv.RANGE_TO_TYPE[slot_range]}

    if slot_range in sv.all_classes():
        return resolve_slot_range_class_relational(sv, slot_range)
@@ -126,20 +150,22 @@ def resolve_slot_range(sv: SchemaView, class_name: str, slot_name: str, slot_ran

    # resolve enums as strings for now
    if slot_range in sv.all_enums():
-         sv.PROCESSED[slot_range] = StringType()
-         return {StringType()}
+         sv.RANGE_TO_TYPE[slot_range] = STRING
+         return {STRING}

    if slot_range not in TYPE_MAP:
        msg = f"{class_name}.{slot_name} range {slot_range}: no type mapping found; using StringType()"
        logger.warning(msg)
        # add it to the mapping
-         sv.PROCESSED[slot_range] = StringType()
-         return {StringType()}
+         sv.RANGE_TO_TYPE[slot_range] = STRING
+         return {STRING}

    return {TYPE_MAP[slot_range]}


- def build_struct_for_class(sv: SchemaView, class_name: str) -> dict[str, tuple[str, DataType, bool]] | None:
+ def build_struct_for_class(
+     sv: SchemaViewWithProcessed, class_name: str
+ ) -> dict[str, tuple[str, DataType, bool]] | None:
    """Generate the appropriate Spark schema for a class in a LinkML schema.

    :param sv: the schema, via SchemaView
@@ -184,7 +210,7 @@ def build_struct_for_class(sv: SchemaView, class_name: str) -> dict[str, tuple[s
        if len(slot_range_resolved) > 1:
            msg = f"WARNING: {class_name}.{slot.name}: more than one possible slot range: {', '.join(slot_range_resolved)}"
            logger.warning(msg)
-             slot_range_resolved = {StringType()}
+             slot_range_resolved = {STRING}
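+             # (with string labels, the ', '.join(...) in the message above is also valid)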

        if len(slot_range_resolved) == 0:
            msg = f"ERROR: {class_name}.{slot.name} slot_range_set length is 0"
@@ -204,7 +230,7 @@ def build_struct_for_class(sv: SchemaView, class_name: str) -> dict[str, tuple[s


def generate_pyspark_from_sv(
-     sv: SchemaView, classes: list[str] | None = None
+     sv: SchemaViewWithProcessed, classes: list[str] | None = None
) -> dict[str, dict[str, tuple[DataType, bool]]]:
    """Generate pyspark tables from a LinkML schema.

@@ -215,7 +241,7 @@ def generate_pyspark_from_sv(
    :param classes: list of class names to parse; defaults to None
    :type classes: list[str] | None
    :return: dictionary containing annotations for each field in each class of the schema, excluding abstract classes and mixins
-     :rtype: ddict[str, dict[str, tuple[DataType, bool]]]
+     :rtype: dict[str, dict[str, tuple[DataType, bool]]]
    """
    spark_schemas = {}

@@ -227,10 +253,12 @@ def generate_pyspark_from_sv(
    return spark_schemas


- def write_output(sv: SchemaView, output_path: Path, spark_schemas: dict[str, dict[str, tuple[DataType, bool]]]) -> None:
+ def write_output(
+     sv: SchemaViewWithProcessed, output_path: Path, spark_schemas: dict[str, dict[str, tuple[DataType, bool]]]
+ ) -> None:
    indent = " " * 4
    # extract all the types from the StructFields
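+     # spark_schemas stores the intermediate string labels; remap resolves them here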
-     all_types = {dt for table_fields in spark_schemas.values() for dt, _ in table_fields.values()}
+     all_types = {remap[dt] for table_fields in spark_schemas.values() for dt, _ in table_fields.values()}
    header_material = [
        f'"""Automated conversion of {sv.schema.name} to PySpark."""',
        "",
@@ -247,7 +275,7 @@ def write_output(sv: SchemaView, output_path: Path, spark_schemas: dict[str, dic
            [
                f'{indent}"{table_name}": StructType([',
                *[
-                     f'{indent}{indent}StructField("{name}", {dtype}, nullable={nullable}),'
+                     f'{indent}{indent}StructField("{name}", {remap[dtype]}, nullable={nullable}),'
                    for name, (dtype, nullable) in table.items()
                ],
                f"{indent}" + "]),\n",
@@ -259,10 +287,6 @@ def write_output(sv: SchemaView, output_path: Path, spark_schemas: dict[str, dic
    print(f"PySpark schema written to {output_path}")


- class SchemaViewWithProcessed(SchemaView):
-     PROCESSED = {}
-
-
if __name__ == "__main__":
    sv = SchemaViewWithProcessed("./src/linkml/cdm_schema.yaml")