#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import decimal
import time
import math
import datetime
import calendar
import json
import re
import base64
from array import array
import ctypes
from collections.abc import Iterable
from functools import reduce
from typing import (
    cast,
    overload,
    Any,
    Callable,
    ClassVar,
    Dict,
    Iterator,
    List,
    Optional,
    Union,
    Tuple,
    Type,
    TypeVar,
    TYPE_CHECKING,
)

from pyspark.util import is_remote_only, JVM_INT_MAX
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.utils import (
    get_active_spark_context,
    escape_meta_characters,
    StringConcat,
)
from pyspark.sql.variant_utils import VariantUtils
from pyspark.errors import (
    PySparkNotImplementedError,
    PySparkTypeError,
    PySparkValueError,
    PySparkIndexError,
    PySparkRuntimeError,
    PySparkAttributeError,
    PySparkKeyError,
)

if TYPE_CHECKING:
    import numpy as np
    from py4j.java_gateway import GatewayClient, JavaGateway, JavaClass

T = TypeVar("T")
U = TypeVar("U")

__all__ = [
    "DataType",
    "NullType",
    "CharType",
    "StringType",
    "VarcharType",
    "BinaryType",
    "BooleanType",
    "DateType",
    "TimestampType",
    "TimestampNTZType",
    "DecimalType",
    "DoubleType",
    "FloatType",
    "ByteType",
    "IntegerType",
    "LongType",
    "DayTimeIntervalType",
    "YearMonthIntervalType",
    "CalendarIntervalType",
    "Row",
    "ShortType",
    "ArrayType",
    "MapType",
    "StructField",
    "StructType",
    "VariantType",
    "VariantVal",
]
class DataType:
    """Base class for data types."""

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def __hash__(self) -> int:
        return hash(str(self))

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)
    def needConversion(self) -> bool:
        """
        Does this type need conversion between Python object and internal SQL object.

        This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType.
        """
        return False
    def toInternal(self, obj: Any) -> Any:
        """
        Converts a Python object into an internal SQL object.
        """
        return obj
    def fromInternal(self, obj: Any) -> Any:
        """
        Converts an internal SQL object into a native Python object.
        """
        return obj
    def _as_nullable(self) -> "DataType":
        return self
    @classmethod
    def fromDDL(cls, ddl: str) -> "DataType":
        """
        Creates :class:`DataType` for a given DDL-formatted string.

        .. versionadded:: 4.0.0

        Parameters
        ----------
        ddl : str
            DDL-formatted string representation of types, e.g.
            :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct
            type can omit the ``struct<>`` for the compatibility reason with
            ``spark.createDataFrame`` and Python UDFs.

        Returns
        -------
        :class:`DataType`

        Examples
        --------
        Create a StructType by the corresponding DDL formatted string.

        >>> from pyspark.sql.types import DataType
        >>> DataType.fromDDL("b string, a int")
        StructType([StructField('b', StringType(), True), StructField('a', IntegerType(), True)])

        Create a single DataType by the corresponding DDL formatted string.

        >>> DataType.fromDDL("decimal(10,10)")
        DecimalType(10,10)

        Create a StructType by the legacy string format.

        >>> DataType.fromDDL("b: string, a: int")
        StructType([StructField('b', StringType(), True), StructField('a', IntegerType(), True)])
        """
        return _parse_datatype_string(ddl)
    @classmethod
    def _data_type_build_formatted_string(
        cls,
        dataType: "DataType",
        prefix: str,
        stringConcat: StringConcat,
        maxDepth: int,
    ) -> None:
        if isinstance(dataType, (ArrayType, StructType, MapType)):
            dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)

    # The method typeName() is not always the same as the Scala side.
    # Add this helper method to make TreeString() compatible with Scala side.
    @classmethod
    def _get_jvm_type_name(cls, dataType: "DataType") -> str:
        if isinstance(
            dataType,
            (
                DecimalType,
                CharType,
                VarcharType,
                DayTimeIntervalType,
                YearMonthIntervalType,
            ),
        ):
            return dataType.simpleString()
        else:
            return dataType.typeName()
# This singleton pattern does not work with pickle, you will get
# another object after pickle and unpickle
class DataTypeSingleton(type):
    """Metaclass for DataType"""

    _instances: ClassVar[Dict[Type["DataTypeSingleton"], "DataTypeSingleton"]] = {}

    def __call__(cls: Type[T]) -> T:
        if cls not in cls._instances:  # type: ignore[attr-defined]
            cls._instances[cls] = super(  # type: ignore[misc, attr-defined]
                DataTypeSingleton, cls
            ).__call__()
        return cls._instances[cls]  # type: ignore[attr-defined]
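# Illustrative note (not part of the module): because of this metaclass, the
# parameter-less types below compare equal and are also the *same* object, e.g.
#
#   >>> IntegerType() is IntegerType()
#   True
#
# As the comment above says, the singleton property is not preserved across
# pickling: pickle.loads(pickle.dumps(IntegerType())) may be a distinct
# (although still equal) instance.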
class NullType(DataType, metaclass=DataTypeSingleton):
    """Null type.

    The data type representing None, used for the types that cannot be inferred.
    """
class AtomicType(DataType):
    """An internal type used to represent everything that is not
    null, UDTs, arrays, structs, and maps."""


class NumericType(AtomicType):
    """Numeric data types."""


class IntegralType(NumericType, metaclass=DataTypeSingleton):
    """Integral data types."""

    pass


class FractionalType(NumericType):
    """Fractional data types."""
class StringType(AtomicType):
    """String data type.

    Parameters
    ----------
    collation : str
        name of the collation, default is UTF8_BINARY.
    """

    providerSpark = "spark"
    providerICU = "icu"
    providers = [providerSpark, providerICU]

    def __init__(self, collation: str = "UTF8_BINARY"):
        self.collation = collation
    @classmethod
    def collationProvider(cls, collationName: str) -> str:
        # TODO: do this properly like on the scala side
        if collationName.startswith("UTF8"):
            return StringType.providerSpark
        return StringType.providerICU
    # For backwards compatibility and compatibility with other readers all string types
    # are serialized in json as regular strings and the collation info is written to
    # struct field metadata
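    # Illustrative sketch (not part of the module; the collation names are just
    # examples): a collated string type is created by passing the collation name,
    # and collationProvider() only inspects the name prefix, e.g.
    #
    #   >>> StringType("UTF8_LCASE").collation
    #   'UTF8_LCASE'
    #   >>> StringType.collationProvider("UTF8_LCASE")
    #   'spark'
    #   >>> StringType.collationProvider("UNICODE_CI")
    #   'icu'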
class CharType(AtomicType):
    """Char data type

    Parameters
    ----------
    length : int
        the length limitation.
    """

    def __init__(self, length: int):
        self.length = length
class VarcharType(AtomicType):
    """Varchar data type

    Parameters
    ----------
    length : int
        the length limitation.
    """

    def __init__(self, length: int):
        self.length = length
    def fromInternal(self, ts: int) -> datetime.datetime:
        if ts is not None:
            # using int to avoid precision loss in float
            return datetime.datetime.fromtimestamp(ts // 1000000).replace(
                microsecond=ts % 1000000
            )
class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton):
    """Timestamp (datetime.datetime) data type without timezone information."""
    def fromInternal(self, ts: int) -> datetime.datetime:
        if ts is not None:
            # using int to avoid precision loss in float
            return datetime.datetime.fromtimestamp(ts // 1000000, datetime.timezone.utc).replace(
                microsecond=ts % 1000000, tzinfo=None
            )
class DecimalType(FractionalType):
    """Decimal (decimal.Decimal) data type.

    The DecimalType must have fixed precision (the maximum total number of digits)
    and scale (the number of digits on the right of the dot). For example, (5, 2) can
    support values in the range [-999.99, 999.99].

    The precision can be up to 38, and the scale must be less than or equal to the precision.

    When creating a DecimalType, the default precision and scale is (10, 0). When inferring
    schema from decimal.Decimal objects, it will be DecimalType(38, 18).

    Parameters
    ----------
    precision : int, optional
        the maximum (i.e. total) number of digits (default: 10)
    scale : int, optional
        the number of digits on the right side of the dot. (default: 0)
    """

    def __init__(self, precision: int = 10, scale: int = 0):
        self.precision = precision
        self.scale = scale
        self.hasPrecisionInfo = True  # this is a public API
class LongType(IntegralType):
    """Long data type, representing signed 64-bit integers.

    If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
    please use :class:`DecimalType`.
    """
class YearMonthIntervalType(AnsiIntervalType):
    """YearMonthIntervalType, represents year-month intervals of the SQL standard

    Notes
    -----
    This data type doesn't support collection: df.collect/take/head.
    """

    YEAR = 0
    MONTH = 1

    _fields = {
        YEAR: "year",
        MONTH: "month",
    }

    _inverted_fields = dict(zip(_fields.values(), _fields.keys()))

    def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
        if startField is None and endField is None:
            # Default matched to scala side.
            startField = YearMonthIntervalType.YEAR
            endField = YearMonthIntervalType.MONTH
        elif startField is not None and endField is None:
            endField = startField

        fields = YearMonthIntervalType._fields
        if startField not in fields.keys() or endField not in fields.keys():
            raise PySparkRuntimeError(
                errorClass="INVALID_INTERVAL_CASTING",
                messageParameters={"start_field": str(startField), "end_field": str(endField)},
            )
        self.startField = startField
        self.endField = endField

    def _str_repr(self) -> str:
        fields = YearMonthIntervalType._fields
        start_field_name = fields[self.startField]
        end_field_name = fields[self.endField]
        if start_field_name == end_field_name:
            return "interval %s" % start_field_name
        else:
            return "interval %s to %s" % (start_field_name, end_field_name)

    simpleString = _str_repr

    jsonValue = _str_repr
    def needConversion(self) -> bool:
        # If PYSPARK_YM_INTERVAL_LEGACY is not set, needConversion is true,
        # 'df.collect' fails with PySparkNotImplementedError;
        # otherwise, no conversion is needed, and 'df.collect' returns the internal integers.
        return not os.environ.get("PYSPARK_YM_INTERVAL_LEGACY") == "1"
class CalendarIntervalType(DataType, metaclass=DataTypeSingleton):
    """The data type representing calendar intervals.

    The calendar interval is stored internally in three components:
    - an integer value representing the number of `months` in this interval.
    - an integer value representing the number of `days` in this interval.
    - a long value representing the number of `microseconds` in this interval.
    """
class ArrayType(DataType):
    """Array data type.

    Parameters
    ----------
    elementType : :class:`DataType`
        :class:`DataType` of each element in the array.
    containsNull : bool, optional
        whether the array can contain null (None) values.

    Examples
    --------
    >>> from pyspark.sql.types import ArrayType, StringType, StructField, StructType

    The below example demonstrates how to create :class:`ArrayType`:

    >>> arr = ArrayType(StringType())

    The array can contain null (None) values by default:

    >>> ArrayType(StringType()) == ArrayType(StringType(), True)
    True
    >>> ArrayType(StringType(), False) == ArrayType(StringType())
    False
    """

    def __init__(self, elementType: DataType, containsNull: bool = True):
        assert isinstance(elementType, DataType), "elementType %s should be an instance of %s" % (
            elementType,
            DataType,
        )
        self.elementType = elementType
        self.containsNull = containsNull
class MapType(DataType):
    """Map data type.

    Parameters
    ----------
    keyType : :class:`DataType`
        :class:`DataType` of the keys in the map.
    valueType : :class:`DataType`
        :class:`DataType` of the values in the map.
    valueContainsNull : bool, optional
        indicates whether values can contain null (None) values.

    Notes
    -----
    Keys in a map data type are not allowed to be null (None).

    Examples
    --------
    >>> from pyspark.sql.types import IntegerType, FloatType, MapType, StringType

    The below example demonstrates how to create :class:`MapType`:

    >>> map_type = MapType(StringType(), IntegerType())

    The values of the map can contain null (``None``) values by default:

    >>> (MapType(StringType(), IntegerType())
    ...        == MapType(StringType(), IntegerType(), True))
    True
    >>> (MapType(StringType(), IntegerType(), False)
    ...        == MapType(StringType(), FloatType()))
    False
    """

    def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
        assert isinstance(keyType, DataType), "keyType %s should be an instance of %s" % (
            keyType,
            DataType,
        )
        assert isinstance(valueType, DataType), "valueType %s should be an instance of %s" % (
            valueType,
            DataType,
        )
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull
class StructField(DataType):
    """A field in :class:`StructType`.

    Parameters
    ----------
    name : str
        name of the field.
    dataType : :class:`DataType`
        :class:`DataType` of the field.
    nullable : bool, optional
        whether the field can be null (None) or not.
    metadata : dict, optional
        a dict from string to simple type that can be serialized to JSON automatically

    Examples
    --------
    >>> from pyspark.sql.types import StringType, StructField
    >>> (StructField("f1", StringType(), True)
    ...      == StructField("f1", StringType(), True))
    True
    >>> (StructField("f1", StringType(), True)
    ...      == StructField("f2", StringType(), True))
    False
    """

    def __init__(
        self,
        name: str,
        dataType: DataType,
        nullable: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        assert isinstance(dataType, DataType), "dataType %s should be an instance of %s" % (
            dataType,
            DataType,
        )
        assert isinstance(name, str), "field name %s should be a string" % (name)
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata or {}
class StructType(DataType):
    """Struct type, consisting of a list of :class:`StructField`.

    This is the data type representing a :class:`Row`.

    Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s.
    A contained :class:`StructField` can be accessed by its name or position.

    Examples
    --------
    >>> from pyspark.sql.types import *
    >>> struct1 = StructType([StructField("f1", StringType(), True)])
    >>> struct1["f1"]
    StructField('f1', StringType(), True)
    >>> struct1[0]
    StructField('f1', StringType(), True)

    >>> struct1 = StructType([StructField("f1", StringType(), True)])
    >>> struct2 = StructType([StructField("f1", StringType(), True)])
    >>> struct1 == struct2
    True
    >>> struct1 = StructType([StructField("f1", CharType(10), True)])
    >>> struct2 = StructType([StructField("f1", CharType(10), True)])
    >>> struct1 == struct2
    True
    >>> struct1 = StructType([StructField("f1", VarcharType(10), True)])
    >>> struct2 = StructType([StructField("f1", VarcharType(10), True)])
    >>> struct1 == struct2
    True
    >>> struct1 = StructType([StructField("f1", StringType(), True)])
    >>> struct2 = StructType([StructField("f1", StringType(), True),
    ...     StructField("f2", IntegerType(), False)])
    >>> struct1 == struct2
    False

    The below example demonstrates how to create a DataFrame based on a struct created
    using :class:`StructType` and :class:`StructField`:

    >>> data = [("Alice", ["Java", "Scala"]), ("Bob", ["Python", "Scala"])]
    >>> schema = StructType([
    ...     StructField("name", StringType()),
    ...     StructField("languagesSkills", ArrayType(StringType())),
    ... ])
    >>> df = spark.createDataFrame(data=data, schema=schema)
    >>> df.printSchema()
    root
     |-- name: string (nullable = true)
     |-- languagesSkills: array (nullable = true)
     |    |-- element: string (containsNull = true)
    >>> df.show()
    +-----+---------------+
    | name|languagesSkills|
    +-----+---------------+
    |Alice|  [Java, Scala]|
    |  Bob|[Python, Scala]|
    +-----+---------------+
    """

    def __init__(self, fields: Optional[List[StructField]] = None):
        if not fields:
            self.fields = []
            self.names = []
        else:
            self.fields = fields
            self.names = [f.name for f in fields]
            assert all(
                isinstance(f, StructField) for f in fields
            ), "fields should be a list of StructField"
        # Precalculated list of fields that need conversion with fromInternal/toInternal functions
        self._needConversion = [f.needConversion() for f in self]
        self._needSerializeAnyField = any(self._needConversion)

    @overload
    def add(
        self,
        field: str,
        data_type: Union[str, DataType],
        nullable: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "StructType":
        ...

    @overload
    def add(self, field: StructField) -> "StructType":
        ...
    def add(
        self,
        field: Union[str, StructField],
        data_type: Optional[Union[str, DataType]] = None,
        nullable: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "StructType":
        """
        Construct a :class:`StructType` by adding new elements to it, to define the schema.
        The method accepts either:

            a) A single parameter which is a :class:`StructField` object.
            b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
               metadata (optional). The data_type parameter may be either a String or a
               :class:`DataType` object.

        Parameters
        ----------
        field : str or :class:`StructField`
            Either the name of the field or a :class:`StructField` object
        data_type : :class:`DataType`, optional
            If present, the DataType of the :class:`StructField` to create
        nullable : bool, optional
            Whether the field to add should be nullable (default True)
        metadata : dict, optional
            Any additional metadata (default None)

        Returns
        -------
        :class:`StructType`

        Examples
        --------
        >>> from pyspark.sql.types import IntegerType, StringType, StructField, StructType
        >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        >>> struct2 = StructType([StructField("f1", StringType(), True),
        ...     StructField("f2", StringType(), True, None)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add(StructField("f1", StringType(), True))
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add("f1", "string", True)
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        """
        if isinstance(field, StructField):
            self.fields.append(field)
            self.names.append(field.name)
        else:
            if isinstance(field, str) and data_type is None:
                raise PySparkValueError(
                    errorClass="ARGUMENT_REQUIRED",
                    messageParameters={
                        "arg_name": "data_type",
                        "condition": "passing name of struct_field to create",
                    },
                )
            if isinstance(data_type, str):
                data_type_f = _parse_datatype_json_value(data_type)
            else:
                data_type_f = data_type
            self.fields.append(StructField(field, data_type_f, nullable, metadata))
            self.names.append(field)
        # Precalculated list of fields that need conversion with fromInternal/toInternal functions
        self._needConversion = [f.needConversion() for f in self]
        self._needSerializeAnyField = any(self._needConversion)
        return self
    def __iter__(self) -> Iterator[StructField]:
        """Iterate the fields"""
        return iter(self.fields)

    def __len__(self) -> int:
        """Return the number of fields."""
        return len(self.fields)

    def __getitem__(self, key: Union[str, int]) -> StructField:
        """Access fields by name or slice."""
        if isinstance(key, str):
            for field in self:
                if field.name == key:
                    return field
            raise PySparkKeyError(
                errorClass="KEY_NOT_EXISTS", messageParameters={"key": str(key)}
            )
        elif isinstance(key, int):
            try:
                return self.fields[key]
            except IndexError:
                raise PySparkIndexError(
                    errorClass="INDEX_OUT_OF_RANGE",
                    messageParameters={"arg_name": "StructType", "index": str(key)},
                )
        elif isinstance(key, slice):
            return StructType(self.fields[key])
        else:
            raise PySparkTypeError(
                errorClass="NOT_INT_OR_SLICE_OR_STR",
                messageParameters={"arg_name": "key", "arg_type": type(key).__name__},
            )
    @classmethod
    def fromJson(cls, json: Dict[str, Any]) -> "StructType":
        """
        Constructs :class:`StructType` from a schema defined in JSON format.

        Below is a JSON schema it must adhere to::

            {
              "title":"StructType",
              "description":"Schema of StructType in json format",
              "type":"object",
              "properties":{
                 "fields":{
                    "description":"Array of struct fields",
                    "type":"array",
                    "items":{
                        "type":"object",
                        "properties":{
                           "name":{
                              "description":"Name of the field",
                              "type":"string"
                           },
                           "type":{
                              "description": "Type of the field. Can either be
                                              another nested StructType or primitive type",
                              "type":"object/string"
                           },
                           "nullable":{
                              "description":"If nulls are allowed",
                              "type":"boolean"
                           },
                           "metadata":{
                              "description":"Additional metadata to supply",
                              "type":"object"
                           },
                           "required":[
                              "name",
                              "type",
                              "nullable",
                              "metadata"
                           ]
                        }
                    }
                 }
              }
            }

        Parameters
        ----------
        json : dict or a dict-like object e.g. JSON object
            This "dict" must have "fields" key that returns an array of fields
            each of which must have specific keys (name, type, nullable, metadata).

        Returns
        -------
        :class:`StructType`

        Examples
        --------
        >>> json_str = '''
        ...  {
        ...      "fields": [
        ...          {
        ...              "metadata": {},
        ...              "name": "Person",
        ...              "nullable": true,
        ...              "type": {
        ...                  "fields": [
        ...                      {
        ...                          "metadata": {},
        ...                          "name": "name",
        ...                          "nullable": false,
        ...                          "type": "string"
        ...                      },
        ...                      {
        ...                          "metadata": {},
        ...                          "name": "surname",
        ...                          "nullable": false,
        ...                          "type": "string"
        ...                      }
        ...                  ],
        ...                  "type": "struct"
        ...              }
        ...          }
        ...      ],
        ...      "type": "struct"
        ...  }
        ...  '''
        >>> import json
        >>> scheme = StructType.fromJson(json.loads(json_str))
        >>> scheme.simpleString()
        'struct<Person:struct<name:string,surname:string>>'
        """
        return StructType([StructField.fromJson(f) for f in json["fields"]])
    def fieldNames(self) -> List[str]:
        """
        Returns all field names in a list.

        Examples
        --------
        >>> from pyspark.sql.types import StringType, StructField, StructType
        >>> struct = StructType([StructField("f1", StringType(), True)])
        >>> struct.fieldNames()
        ['f1']
        """
        return list(self.names)
    def needConversion(self) -> bool:
        # We need convert Row()/namedtuple into tuple()
        return True
    def toInternal(self, obj: Tuple) -> Tuple:
        if obj is None:
            return

        if isinstance(obj, VariantVal):
            raise PySparkValueError("Rows cannot be of type VariantVal")

        if self._needSerializeAnyField:
            # Only calling toInternal function for fields that need conversion
            if isinstance(obj, dict):
                return tuple(
                    f.toInternal(obj.get(n)) if c else obj.get(n)
                    for n, f, c in zip(self.names, self.fields, self._needConversion)
                )
            elif isinstance(obj, (tuple, list)):
                return tuple(
                    f.toInternal(v) if c else v
                    for f, v, c in zip(self.fields, obj, self._needConversion)
                )
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return tuple(
                    f.toInternal(d.get(n)) if c else d.get(n)
                    for n, f, c in zip(self.names, self.fields, self._needConversion)
                )
            else:
                raise PySparkValueError(
                    errorClass="UNEXPECTED_TUPLE_WITH_STRUCT",
                    messageParameters={"tuple": str(obj)},
                )
        else:
            if isinstance(obj, dict):
                return tuple(obj.get(n) for n in self.names)
            elif isinstance(obj, (list, tuple)):
                return tuple(obj)
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return tuple(d.get(n) for n in self.names)
            else:
                raise PySparkValueError(
                    errorClass="UNEXPECTED_TUPLE_WITH_STRUCT",
                    messageParameters={"tuple": str(obj)},
                )
    def fromInternal(self, obj: Tuple) -> "Row":
        if obj is None:
            return
        if isinstance(obj, Row):
            # it's already converted by pickler
            return obj

        values: Union[Tuple, List]
        if self._needSerializeAnyField:
            # Only calling fromInternal function for fields that need conversion
            values = [
                f.fromInternal(v) if c else v
                for f, v, c in zip(self.fields, obj, self._needConversion)
            ]
        else:
            values = obj

        return _create_row(self.names, values)
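    # Illustrative sketch (not part of the module; the field names are examples
    # only): toInternal flattens a dict/Row/object into a plain tuple of internal
    # SQL values, converting only the fields whose types need conversion, and
    # fromInternal rebuilds a Row from such a tuple, e.g.
    #
    #   schema = StructType([StructField("name", StringType()), StructField("age", LongType())])
    #   internal = schema.toInternal({"name": "Alice", "age": 11})   # ('Alice', 11)
    #   schema.fromInternal(internal)                                # Row(name='Alice', age=11)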
class UserDefinedType(DataType):
    """User-defined type (UDT).

    .. note:: WARN: Spark Internal Use Only
    """

    @classmethod
    def typeName(cls) -> str:
        return cls.__name__.lower()

    @classmethod
    def sqlType(cls) -> DataType:
        """
        Underlying SQL storage type for this UDT.
        """
        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": "sqlType()"},
        )

    @classmethod
    def module(cls) -> str:
        """
        The Python module of the UDT.
        """
        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": "module()"},
        )

    @classmethod
    def scalaUDT(cls) -> str:
        """
        The class name of the paired Scala UDT (could be '', if there
        is no corresponding one).
        """
        return ""

    def needConversion(self) -> bool:
        return True

    @classmethod
    def _cachedSqlType(cls) -> DataType:
        """
        Cache the sqlType() into class, because it's heavily used in `toInternal`.
        """
        if not hasattr(cls, "_cached_sql_type"):
            cls._cached_sql_type = cls.sqlType()  # type: ignore[attr-defined]
        return cls._cached_sql_type  # type: ignore[attr-defined]

    def toInternal(self, obj: Any) -> Any:
        if obj is not None:
            return self._cachedSqlType().toInternal(self.serialize(obj))

    def fromInternal(self, obj: Any) -> Any:
        v = self._cachedSqlType().fromInternal(obj)
        if v is not None:
            return self.deserialize(v)

    def serialize(self, obj: Any) -> Any:
        """
        Converts a user-type object into a SQL datum.
        """
        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": "toInternal()"},
        )

    def deserialize(self, datum: Any) -> Any:
        """
        Converts a SQL datum into a user-type object.
        """
        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": "fromInternal()"},
        )

    def simpleString(self) -> str:
        return "udt"

    def json(self) -> str:
        return json.dumps(self.jsonValue(), separators=(",", ":"), sort_keys=True)

    def jsonValue(self) -> Dict[str, Any]:
        if self.scalaUDT():
            assert self.module() != "__main__", "UDT in __main__ cannot work with ScalaUDT"
            schema = {
                "type": "udt",
                "class": self.scalaUDT(),
                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
                "sqlType": self.sqlType().jsonValue(),
            }
        else:
            ser = CloudPickleSerializer()
            b = ser.dumps(type(self))
            schema = {
                "type": "udt",
                "pyClass": "%s.%s" % (self.module(), type(self).__name__),
                "serializedClass": base64.b64encode(b).decode("utf8"),
                "sqlType": self.sqlType().jsonValue(),
            }
        return schema

    @classmethod
    def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType":
        pyUDT = str(json["pyClass"])  # convert unicode to str
        split = pyUDT.rfind(".")
        pyModule = pyUDT[:split]
        pyClass = pyUDT[split + 1 :]
        m = __import__(pyModule, globals(), locals(), [pyClass])
        if not hasattr(m, pyClass):
            s = base64.b64decode(json["serializedClass"].encode("utf-8"))
            UDT = CloudPickleSerializer().loads(s)
        else:
            UDT = getattr(m, pyClass)
        return UDT()

    def __eq__(self, other: Any) -> bool:
        return type(self) == type(other)
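# Illustrative sketch (not part of the module; the point class and module path
# are assumptions): a concrete UDT typically overrides sqlType(), module(),
# serialize() and deserialize(), e.g.
#
#   class ExamplePointUDT(UserDefinedType):
#       @classmethod
#       def sqlType(cls):
#           # underlying storage: a fixed-size array of doubles
#           return ArrayType(DoubleType(), False)
#
#       @classmethod
#       def module(cls):
#           return "my_package.udts"          # where the class can be imported from
#
#       def serialize(self, obj):
#           return [float(obj.x), float(obj.y)]
#
#       def deserialize(self, datum):
#           return ExamplePoint(datum[0], datum[1])   # hypothetical plain Python class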
class VariantVal:
    """
    A class to represent a Variant value in Python.

    .. versionadded:: 4.0.0

    Parameters
    ----------
    value : bytes
        The bytes representing the value component of the Variant.
    metadata : bytes
        The bytes representing the metadata component of the Variant.

    Methods
    -------
    toPython()
        Convert the VariantVal to a Python data structure.

    Examples
    --------
    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
    >>> v = df.select(sf.parse_json(df.json).alias("var")).head().var
    >>> v.toPython()
    {'a': 1}
    >>> v.toJson()
    '{"a":1}'
    """

    def __init__(self, value: bytes, metadata: bytes):
        self.value = value
        self.metadata = metadata

    def __str__(self) -> str:
        return VariantUtils.to_json(self.value, self.metadata)

    def __repr__(self) -> str:
        return "VariantVal(%r, %r)" % (self.value, self.metadata)
    def toPython(self) -> Any:
        """
        Convert the VariantVal to a Python data structure.

        Returns
        -------
        Any
            A Python object that represents the Variant.
        """
        return VariantUtils.to_python(self.value, self.metadata)
    def toJson(self, zone_id: str = "UTC") -> str:
        """
        Convert the VariantVal to a JSON string. The zone ID represents the time zone that the
        timestamp should be printed in. It is defaulted to UTC. The list of valid zone IDs can be
        found by importing the `zoneinfo` module and running
        :code:`zoneinfo.available_timezones()`.

        Returns
        -------
        str
            A JSON string that represents the Variant.
        """
        return VariantUtils.to_json(self.value, self.metadata, zone_id)
    @classmethod
    def parseJson(cls, json_str: str) -> "VariantVal":
        """
        Convert a JSON string to a :class:`VariantVal`.

        :return: a :class:`VariantVal` representing the nested structure of the JSON string
        """
        (value, metadata) = VariantUtils.parse_json(json_str)
        return VariantVal(value, metadata)
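# Illustrative usage (sketch, not part of the module): a VariantVal can be built
# from a JSON string without a SparkSession and converted back, e.g.
#
#   >>> v = VariantVal.parseJson('{"a" : 1}')
#   >>> v.toPython()
#   {'a': 1}
#   >>> v.toJson()
#   '{"a":1}'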
_atomic_types:List[Type[DataType]]=[StringType,CharType,VarcharType,BinaryType,BooleanType,DecimalType,FloatType,DoubleType,ByteType,ShortType,IntegerType,LongType,DateType,TimestampType,TimestampNTZType,NullType,VariantType,YearMonthIntervalType,DayTimeIntervalType,]_complex_types:List[Type[Union[ArrayType,MapType,StructType]]]=[ArrayType,MapType,StructType,]_all_complex_types:Dict[str,Type[Union[ArrayType,MapType,StructType]]]={"array":ArrayType,"map":MapType,"struct":StructType,}# Datatypes that can be directly parsed by mapping a json string without regex.# This dict should be only used in json parsing.# Note that:# 1, CharType and VarcharType are not listed here, since they need regex;# 2, DecimalType can be parsed by both mapping ('decimal') and regex ('decimal(10, 2)');# 3, CalendarIntervalType is not an atomic type, but can be mapped by 'interval';_all_mappable_types:Dict[str,Type[DataType]]={"string":StringType,"binary":BinaryType,"boolean":BooleanType,"decimal":DecimalType,"float":FloatType,"double":DoubleType,"byte":ByteType,"short":ShortType,"integer":IntegerType,"long":LongType,"date":DateType,"timestamp":TimestampType,"timestamp_ntz":TimestampNTZType,"void":NullType,"variant":VariantType,"interval":CalendarIntervalType,}_LENGTH_CHAR=re.compile(r"char\(\s*(\d+)\s*\)")_LENGTH_VARCHAR=re.compile(r"varchar\(\s*(\d+)\s*\)")_FIXED_DECIMAL=re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")_INTERVAL_DAYTIME=re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")_INTERVAL_YEARMONTH=re.compile(r"interval (year|month)( to (year|month))?")_COLLATIONS_METADATA_KEY="__COLLATIONS"def_drop_metadata(d:Union[DataType,StructField])->Union[DataType,StructField]:assertisinstance(d,(DataType,StructField))ifisinstance(d,StructField):returnStructField(d.name,_drop_metadata(d.dataType),d.nullable,None)elifisinstance(d,StructType):returnStructType([cast(StructField,_drop_metadata(f))forfind.fields])elifisinstance(d,ArrayType):returnArrayType(_drop_metadata(d.elementType),d.containsNull)elifisinstance(d,MapType):returnMapType(_drop_metadata(d.keyType),_drop_metadata(d.valueType),d.valueContainsNull)returnddef_parse_datatype_string(s:str)->DataType:""" Parses the given data type string to a :class:`DataType`. The data type string format equals :class:`DataType.simpleString`, except that the top level struct type can omit the ``struct<>``. Since Spark 2.3, this also supports a schema in a DDL-formatted string and case-insensitive strings. Examples -------- >>> _parse_datatype_string("int ") IntegerType() >>> _parse_datatype_string("INT ") IntegerType() >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ") StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)]) >>> _parse_datatype_string("a DOUBLE, b STRING") StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)]) >>> _parse_datatype_string("a DOUBLE, b CHAR( 50 )") StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)]) >>> _parse_datatype_string("a DOUBLE, b VARCHAR( 50 )") StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)]) >>> _parse_datatype_string("a: array< short>") StructType([StructField('a', ArrayType(ShortType(), True), True)]) >>> _parse_datatype_string(" map<string , string > ") MapType(StringType(), StringType(), True) >>> # Error cases >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ParseException:... 
>>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ParseException:... >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ParseException:... >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ParseException:... """frompyspark.sql.utilsimportis_remoteifis_remote():frompyspark.sql.connect.sessionimportSparkSessionreturnSparkSession.active()._parse_ddl(s)else:returnget_active_spark_context()._parse_ddl(s)def_parse_datatype_json_string(json_string:str)->DataType:"""Parses the given data type JSON string. Examples -------- >>> import pickle >>> def check_datatype(datatype): ... pickled = pickle.loads(pickle.dumps(datatype)) ... assert datatype == pickled ... scala_datatype = spark._jsparkSession.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype ... >>> for cls in _all_mappable_types.values(): ... check_datatype(cls()) >>> # Simple ArrayType. >>> simple_arraytype = ArrayType(StringType(), True) >>> check_datatype(simple_arraytype) >>> # Simple MapType. >>> simple_maptype = MapType(StringType(), LongType()) >>> check_datatype(simple_maptype) >>> # Simple StructType. >>> simple_structtype = StructType([ ... StructField("a", DecimalType(), False), ... StructField("b", BooleanType(), True), ... StructField("c", LongType(), True), ... StructField("d", BinaryType(), False)]) >>> check_datatype(simple_structtype) >>> # Complex StructType. >>> complex_structtype = StructType([ ... StructField("simpleArray", simple_arraytype, True), ... StructField("simpleMap", simple_maptype, True), ... StructField("simpleStruct", simple_structtype, True), ... StructField("boolean", BooleanType(), False), ... StructField("chars", CharType(10), False), ... StructField("words", VarcharType(10), False), ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) >>> check_datatype(complex_structtype) >>> # Complex ArrayType. >>> complex_arraytype = ArrayType(complex_structtype, True) >>> check_datatype(complex_arraytype) >>> # Complex MapType. >>> complex_maptype = MapType(complex_structtype, ... 
complex_arraytype, False) >>> check_datatype(complex_maptype) """return_parse_datatype_json_value(json.loads(json_string))def_parse_datatype_json_value(json_value:Union[dict,str],fieldPath:str="",collationsMap:Optional[Dict[str,str]]=None,)->DataType:ifnotisinstance(json_value,dict):ifjson_valuein_all_mappable_types.keys():ifcollationsMapisnotNoneandfieldPathincollationsMap:_assert_valid_type_for_collation(fieldPath,json_value,collationsMap)collation_name=collationsMap[fieldPath]returnStringType(collation_name)return_all_mappable_types[json_value]()elif_FIXED_DECIMAL.match(json_value):m=_FIXED_DECIMAL.match(json_value)returnDecimalType(int(m.group(1)),int(m.group(2)))# type: ignore[union-attr]elif_INTERVAL_DAYTIME.match(json_value):m=_INTERVAL_DAYTIME.match(json_value)inverted_fields=DayTimeIntervalType._inverted_fieldsfirst_field=inverted_fields.get(m.group(1))# type: ignore[union-attr]second_field=inverted_fields.get(m.group(3))# type: ignore[union-attr]iffirst_fieldisnotNoneandsecond_fieldisNone:returnDayTimeIntervalType(first_field)returnDayTimeIntervalType(first_field,second_field)elif_INTERVAL_YEARMONTH.match(json_value):m=_INTERVAL_YEARMONTH.match(json_value)inverted_fields=YearMonthIntervalType._inverted_fieldsfirst_field=inverted_fields.get(m.group(1))# type: ignore[union-attr]second_field=inverted_fields.get(m.group(3))# type: ignore[union-attr]iffirst_fieldisnotNoneandsecond_fieldisNone:returnYearMonthIntervalType(first_field)returnYearMonthIntervalType(first_field,second_field)elif_LENGTH_CHAR.match(json_value):m=_LENGTH_CHAR.match(json_value)returnCharType(int(m.group(1)))# type: ignore[union-attr]elif_LENGTH_VARCHAR.match(json_value):m=_LENGTH_VARCHAR.match(json_value)returnVarcharType(int(m.group(1)))# type: ignore[union-attr]else:raisePySparkValueError(errorClass="CANNOT_PARSE_DATATYPE",messageParameters={"error":str(json_value)},)else:tpe=json_value["type"]iftpein_all_complex_types:ifcollationsMapisnotNoneandfieldPathincollationsMap:_assert_valid_type_for_collation(fieldPath,tpe,collationsMap)complex_type=_all_complex_types[tpe]ifcomplex_typeisArrayType:returnArrayType.fromJson(json_value,fieldPath,collationsMap)elifcomplex_typeisMapType:returnMapType.fromJson(json_value,fieldPath,collationsMap)returnStructType.fromJson(json_value)eliftpe=="udt":returnUserDefinedType.fromJson(json_value)else:raisePySparkValueError(errorClass="UNSUPPORTED_DATA_TYPE",messageParameters={"data_type":str(tpe)},)def_assert_valid_type_for_collation(fieldPath:str,fieldType:Any,collationMap:Dict[str,str])->None:iffieldPathincollationMapandfieldType!="string":raisePySparkTypeError(errorClass="INVALID_JSON_DATA_TYPE_FOR_COLLATIONS",messageParameters={"jsonType":fieldType},)def_assert_valid_collation_provider(provider:str)->None:ifprovider.lower()notinStringType.providers:raisePySparkValueError(errorClass="COLLATION_INVALID_PROVIDER",messageParameters={"provider":provider,"supportedProviders":", ".join(StringType.providers),},)# Mapping Python types to Spark SQL DataType_type_mappings={type(None):NullType,bool:BooleanType,int:LongType,float:DoubleType,str:StringType,bytearray:BinaryType,decimal.Decimal:DecimalType,datetime.date:DateType,datetime.datetime:TimestampType,# can be TimestampNTZTypedatetime.time:TimestampType,# can be TimestampNTZTypedatetime.timedelta:DayTimeIntervalType,bytes:BinaryType,}# Mapping Python array types to Spark SQL DataType# We should be careful here. The size of these types in python depends on C# implementation. 
We need to make sure that this conversion does not lose any# precision. Also, JVM only support signed types, when converting unsigned types,# keep in mind that it require 1 more bit when stored as signed types.## Reference for C integer size, see:# ISO/IEC 9899:201x specification, chapter 5.2.4.2.1 Sizes of integer types <limits.h>.# Reference for python array typecode, see:# https://docs.python.org/2/library/array.html# https://docs.python.org/3.6/library/array.html# Reference for JVM's supported integral types:# http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.3.1_array_signed_int_typecode_ctype_mappings={"b":ctypes.c_byte,"h":ctypes.c_short,"i":ctypes.c_int,"l":ctypes.c_long,}_array_unsigned_int_typecode_ctype_mappings={"B":ctypes.c_ubyte,"H":ctypes.c_ushort,"I":ctypes.c_uint,"L":ctypes.c_ulong,}def_int_size_to_type(size:int,)->Optional[Union[Type[ByteType],Type[ShortType],Type[IntegerType],Type[LongType]]]:""" Return the Catalyst datatype from the size of integers. """ifsize<=8:returnByteTypeelifsize<=16:returnShortTypeelifsize<=32:returnIntegerTypeelifsize<=64:returnLongTypeelse:returnNone# The list of all supported array typecodes, is stored here_array_type_mappings:Dict[str,Type[DataType]]={# Warning: Actual properties for float and double in C is not specified in C.# On almost every system supported by both python and JVM, they are IEEE 754# single-precision binary floating-point format and IEEE 754 double-precision# binary floating-point format. And we do assume the same thing here for now."f":FloatType,"d":DoubleType,}# compute array typecode mappings for signed integer typesfor_typecodein_array_signed_int_typecode_ctype_mappings.keys():size=ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[_typecode])*8dt=_int_size_to_type(size)ifdtisnotNone:_array_type_mappings[_typecode]=dt# compute array typecode mappings for unsigned integer typesfor_typecodein_array_unsigned_int_typecode_ctype_mappings.keys():# JVM does not have unsigned types, so use signed types that is at least 1# bit larger to storesize=ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[_typecode])*8+1dt=_int_size_to_type(size)ifdtisnotNone:_array_type_mappings[_typecode]=dt# Type code 'u' in Python's array is deprecated since version 3.3, and will be# removed in version 4.0. 
See: https://docs.python.org/3/library/array.htmlifsys.version_info[0]<4:_array_type_mappings["u"]=StringTypedef_from_numpy_type(nt:"np.dtype")->Optional[DataType]:"""Convert NumPy type to Spark data type."""importnumpyasnpifnt==np.dtype("bool"):returnBooleanType()elifnt==np.dtype("int8"):returnByteType()elifnt==np.dtype("int16"):returnShortType()elifnt==np.dtype("int32"):returnIntegerType()elifnt==np.dtype("int64"):returnLongType()elifnt==np.dtype("float32"):returnFloatType()elifnt==np.dtype("float64"):returnDoubleType()elifnt.type==np.dtype("str"):returnStringType()returnNonedef_infer_type(obj:Any,infer_dict_as_struct:bool=False,infer_array_from_first_element:bool=False,infer_map_from_first_pair:bool=False,prefer_timestamp_ntz:bool=False,)->DataType:"""Infer the DataType from obj"""ifobjisNone:returnNullType()ifhasattr(obj,"__UDT__"):returnobj.__UDT__dataType=_type_mappings.get(type(obj))ifdataTypeisDecimalType:# the precision and scale of `obj` may be different from row to row.returnDecimalType(38,18)ifdataTypeisTimestampTypeandprefer_timestamp_ntzandobj.tzinfoisNone:returnTimestampNTZType()ifdataTypeisDayTimeIntervalType:returnDayTimeIntervalType()ifdataTypeisYearMonthIntervalType:returnYearMonthIntervalType()ifdataTypeisCalendarIntervalType:returnCalendarIntervalType()elifdataTypeisnotNone:returndataType()ifisinstance(obj,dict):ifinfer_dict_as_struct:struct=StructType()forkey,valueinobj.items():ifkeyisnotNoneandvalueisnotNone:struct.add(key,_infer_type(value,infer_dict_as_struct,infer_array_from_first_element,infer_map_from_first_pair,prefer_timestamp_ntz,),True,)returnstructelifinfer_map_from_first_pair:forkey,valueinobj.items():ifkeyisnotNoneandvalueisnotNone:returnMapType(_infer_type(key,infer_dict_as_struct,infer_array_from_first_element,infer_map_from_first_pair,prefer_timestamp_ntz,),_infer_type(value,infer_dict_as_struct,infer_array_from_first_element,infer_map_from_first_pair,prefer_timestamp_ntz,),True,)returnMapType(NullType(),NullType(),True)else:key_type:DataType=NullType()value_type:DataType=NullType()forkey,valueinobj.items():ifkeyisnotNone:key_type=_merge_type(key_type,_infer_type(key,infer_dict_as_struct,infer_array_from_first_element,infer_map_from_first_pair,prefer_timestamp_ntz,),)ifvalueisnotNone:value_type=_merge_type(value_type,_infer_type(value,infer_dict_as_struct,infer_array_from_first_element,infer_map_from_first_pair,prefer_timestamp_ntz,),)returnMapType(key_type,value_type,True)elifisinstance(obj,list):iflen(obj)>0:ifinfer_array_from_first_element:returnArrayType(_infer_type(obj[0],infer_dict_as_struct,infer_array_from_first_element,prefer_timestamp_ntz,),True,)else:returnArrayType(reduce(_merge_type,(_infer_type(v,infer_dict_as_struct,infer_array_from_first_element,prefer_timestamp_ntz,)forvinobj),),True,)returnArrayType(NullType(),True)elifisinstance(obj,array):ifobj.typecodein_array_type_mappings:returnArrayType(_array_type_mappings[obj.typecode](),False)else:raisePySparkTypeError(errorClass="UNSUPPORTED_DATA_TYPE",messageParameters={"data_type":f"array({obj.typecode})"},)else:try:return_infer_schema(obj,infer_dict_as_struct=infer_dict_as_struct,infer_array_from_first_element=infer_array_from_first_element,prefer_timestamp_ntz=prefer_timestamp_ntz,)exceptTypeError:raisePySparkTypeError(errorClass="UNSUPPORTED_DATA_TYPE",messageParameters={"data_type":type(obj).__name__},)def_infer_schema(row:Any,names:Optional[List[str]]=None,infer_dict_as_struct:bool=False,infer_array_from_first_element:bool=False,infer_map_from_first_pair:bool=False,prefer_timestamp_ntz
:bool=False,)->StructType:"""Infer the schema from dict/namedtuple/object"""items:Iterable[Tuple[str,Any]]ifisinstance(row,dict):items=sorted(row.items())elifisinstance(row,(tuple,list)):ifhasattr(row,"__fields__"):# Rowitems=zip(row.__fields__,tuple(row))elifhasattr(row,"_fields"):# namedtupleitems=zip(row._fields,tuple(row))else:ifnamesisNone:names=["_%d"%iforiinrange(1,len(row)+1)]eliflen(names)<len(row):names.extend("_%d"%iforiinrange(len(names)+1,len(row)+1))items=zip(names,row)elifhasattr(row,"__dict__"):# objectitems=sorted(row.__dict__.items())else:raisePySparkTypeError(errorClass="CANNOT_INFER_SCHEMA_FOR_TYPE",messageParameters={"data_type":type(row).__name__},)fields=[]fork,vinitems:try:fields.append(StructField(k,_infer_type(v,infer_dict_as_struct,infer_array_from_first_element,infer_map_from_first_pair,prefer_timestamp_ntz,),True,))exceptTypeError:raisePySparkTypeError(errorClass="CANNOT_INFER_TYPE_FOR_FIELD",messageParameters={"field_name":k},)returnStructType(fields)def_has_nulltype(dt:DataType)->bool:"""Return whether there is a NullType in `dt` or not"""ifisinstance(dt,StructType):returnany(_has_nulltype(f.dataType)forfindt.fields)elifisinstance(dt,ArrayType):return_has_nulltype((dt.elementType))elifisinstance(dt,MapType):return_has_nulltype(dt.keyType)or_has_nulltype(dt.valueType)else:returnisinstance(dt,NullType)def_has_type(dt:DataType,dts:Union[type,Tuple[type,...]])->bool:"""Return whether there are specified types"""ifisinstance(dt,dts):returnTrueelifisinstance(dt,StructType):returnany(_has_type(f.dataType,dts)forfindt.fields)elifisinstance(dt,ArrayType):return_has_type(dt.elementType,dts)elifisinstance(dt,MapType):return_has_type(dt.keyType,dts)or_has_type(dt.valueType,dts)else:returnFalse@overloaddef_merge_type(a:StructType,b:StructType,name:Optional[str]=None)->StructType:...@overloaddef_merge_type(a:ArrayType,b:ArrayType,name:Optional[str]=None)->ArrayType:...@overloaddef_merge_type(a:MapType,b:MapType,name:Optional[str]=None)->MapType:...@overloaddef_merge_type(a:DataType,b:DataType,name:Optional[str]=None)->DataType:...def_merge_type(a:Union[StructType,ArrayType,MapType,DataType],b:Union[StructType,ArrayType,MapType,DataType],name:Optional[str]=None,)->Union[StructType,ArrayType,MapType,DataType]:ifnameisNone:defnew_msg(msg:str)->str:returnmsgdefnew_name(n:str)->str:return"field %s"%nelse:defnew_msg(msg:str)->str:return"%s: %s"%(name,msg)defnew_name(n:str)->str:return"field %s in %s"%(n,name)ifisinstance(a,NullType):returnbelifisinstance(b,NullType):returnaelifisinstance(a,TimestampType)andisinstance(b,TimestampNTZType):returnaelifisinstance(a,TimestampNTZType)andisinstance(b,TimestampType):returnbelifisinstance(a,AtomicType)andisinstance(b,StringType):returnbelifisinstance(a,StringType)andisinstance(b,AtomicType):returnaeliftype(a)isnottype(b):# TODO: type cast (such as int -> long)raisePySparkTypeError(errorClass="CANNOT_MERGE_TYPE",messageParameters={"data_type1":type(a).__name__,"data_type2":type(b).__name__},)# same typeifisinstance(a,StructType):nfs=dict((f.name,f.dataType)forfincast(StructType,b).fields)fields=[StructField(f.name,_merge_type(f.dataType,nfs.get(f.name,NullType()),name=new_name(f.name)))forfina.fields]names=set([f.nameforfinfields])forninnfs:ifnnotinnames:fields.append(StructField(n,nfs[n]))returnStructType(fields)elifisinstance(a,ArrayType):returnArrayType(_merge_type(a.elementType,cast(ArrayType,b).elementType,name="element in array 
%s"%name),True,)elifisinstance(a,MapType):returnMapType(_merge_type(a.keyType,cast(MapType,b).keyType,name="key of map %s"%name),_merge_type(a.valueType,cast(MapType,b).valueType,name="value of map %s"%name),True,)else:returnadef_need_converter(dataType:DataType)->bool:ifisinstance(dataType,StructType):returnTrueelifisinstance(dataType,ArrayType):return_need_converter(dataType.elementType)elifisinstance(dataType,MapType):return_need_converter(dataType.keyType)or_need_converter(dataType.valueType)elifisinstance(dataType,NullType):returnTrueelse:returnFalsedef_create_converter(dataType:DataType)->Callable:"""Create a converter to drop the names of fields in obj"""ifnot_need_converter(dataType):returnlambdax:xifisinstance(dataType,ArrayType):conv=_create_converter(dataType.elementType)returnlambdarow:[conv(v)forvinrow]elifisinstance(dataType,MapType):kconv=_create_converter(dataType.keyType)vconv=_create_converter(dataType.valueType)returnlambdarow:dict((kconv(k),vconv(v))fork,vinrow.items())elifisinstance(dataType,NullType):returnlambdax:Noneelifnotisinstance(dataType,StructType):returnlambdax:x# dataType must be StructTypenames=[f.nameforfindataType.fields]converters=[_create_converter(f.dataType)forfindataType.fields]convert_fields=any(_need_converter(f.dataType)forfindataType.fields)defconvert_struct(obj:Any)->Optional[Tuple]:ifobjisNone:returnNoneifisinstance(obj,(tuple,list)):ifconvert_fields:returntuple(conv(v)forv,convinzip(obj,converters))else:returntuple(obj)ifisinstance(obj,dict):d=objelifhasattr(obj,"__dict__"):# objectd=obj.__dict__else:raisePySparkTypeError(errorClass="UNSUPPORTED_DATA_TYPE",messageParameters={"data_type":type(obj).__name__},)ifconvert_fields:returntuple([conv(d.get(name))forname,convinzip(names,converters)])else:returntuple([d.get(name)fornameinnames])returnconvert_struct_acceptable_types={BooleanType:(bool,),ByteType:(int,),ShortType:(int,),IntegerType:(int,),LongType:(int,),FloatType:(float,),DoubleType:(float,),DecimalType:(decimal.Decimal,),StringType:(str,),CharType:(str,),VarcharType:(str,),BinaryType:(bytearray,bytes),DateType:(datetime.date,datetime.datetime),TimestampType:(datetime.datetime,),TimestampNTZType:(datetime.datetime,),DayTimeIntervalType:(datetime.timedelta,),ArrayType:(list,tuple,array),MapType:(dict,),StructType:(tuple,list,dict),VariantType:(bool,int,float,decimal.Decimal,str,bytearray,bytes,datetime.date,datetime.datetime,datetime.timedelta,tuple,list,dict,array,),}def_make_type_verifier(dataType:DataType,nullable:bool=True,name:Optional[str]=None,)->Callable:""" Make a verifier that checks the type of obj against dataType and raises a TypeError if they do not match. This verifier also checks the value of obj against datatype and raises a ValueError if it's not within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it will become infinity when cast to Java float, if it overflows. Examples -------- >>> _make_type_verifier(StructType([]))(None) >>> _make_type_verifier(StringType())("") >>> _make_type_verifier(LongType())(0) >>> _make_type_verifier(LongType())(1 << 64) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkValueError:... >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3))) >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkTypeError:... 
>>> _make_type_verifier(MapType(StringType(), IntegerType()))({}) >>> _make_type_verifier(StructType([]))(()) >>> _make_type_verifier(StructType([]))([]) >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkValueError:... >>> # Check if numeric values are within the allowed range. >>> _make_type_verifier(ByteType())(12) >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkValueError:... >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkValueError:... >>> _make_type_verifier( ... ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkValueError:... >>> _make_type_verifier( # doctest: +IGNORE_EXCEPTION_DETAIL ... MapType(StringType(), IntegerType()) ... )({None: 1}) Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkValueError:... >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False) >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkValueError:... """ifnameisNone:defnew_msg(msg:str)->str:returnmsgdefnew_name(n:str)->str:return"field %s"%nelse:defnew_msg(msg:str)->str:return"%s: %s"%(name,msg)defnew_name(n:str)->str:return"field %s in %s"%(n,name)defverify_nullability(obj:Any)->bool:ifobjisNone:ifnullable:returnTrueelse:ifnameisnotNone:raisePySparkValueError(errorClass="FIELD_NOT_NULLABLE_WITH_NAME",messageParameters={"field_name":str(name),},)raisePySparkValueError(errorClass="FIELD_NOT_NULLABLE",messageParameters={},)else:returnFalse_type=type(dataType)defassert_acceptable_types(obj:Any)->None:assert_typein_acceptable_types,new_msg("unknown datatype: %s for object %r"%(dataType,obj))defverify_acceptable_types(obj:Any)->None:# subclass of them can not be fromInternal in JVMiftype(obj)notin_acceptable_types[_type]:ifnameisnotNone:raisePySparkTypeError(errorClass="FIELD_DATA_TYPE_UNACCEPTABLE_WITH_NAME",messageParameters={"field_name":str(name),"data_type":str(dataType),"obj":repr(obj),"obj_type":str(type(obj)),},)raisePySparkTypeError(errorClass="FIELD_DATA_TYPE_UNACCEPTABLE",messageParameters={"data_type":str(dataType),"obj":repr(obj),"obj_type":str(type(obj)),},)ifisinstance(dataType,(StringType,CharType,VarcharType)):# StringType, CharType and VarcharType can work with any 
typesdefverify_value(obj:Any)->None:passelifisinstance(dataType,UserDefinedType):verifier=_make_type_verifier(dataType.sqlType(),name=name)defverify_udf(obj:Any)->None:ifnot(hasattr(obj,"__UDT__")andobj.__UDT__==dataType):ifnameisnotNone:raisePySparkValueError(errorClass="FIELD_TYPE_MISMATCH_WITH_NAME",messageParameters={"field_name":str(name),"obj":str(obj),"data_type":str(dataType),},)raisePySparkValueError(errorClass="FIELD_TYPE_MISMATCH",messageParameters={"obj":str(obj),"data_type":str(dataType),},)verifier(dataType.toInternal(obj))verify_value=verify_udfelifisinstance(dataType,ByteType):defverify_byte(obj:Any)->None:assert_acceptable_types(obj)verify_acceptable_types(obj)lower_bound=-128upper_bound=127ifobj<lower_boundorobj>upper_bound:raisePySparkValueError(errorClass="VALUE_OUT_OF_BOUNDS",messageParameters={"arg_name":"obj","lower_bound":str(lower_bound),"upper_bound":str(upper_bound),"actual":str(obj),},)verify_value=verify_byteelifisinstance(dataType,ShortType):defverify_short(obj:Any)->None:assert_acceptable_types(obj)verify_acceptable_types(obj)lower_bound=-32768upper_bound=32767ifobj<lower_boundorobj>upper_bound:raisePySparkValueError(errorClass="VALUE_OUT_OF_BOUNDS",messageParameters={"arg_name":"obj","lower_bound":str(lower_bound),"upper_bound":str(upper_bound),"actual":str(obj),},)verify_value=verify_shortelifisinstance(dataType,IntegerType):defverify_integer(obj:Any)->None:assert_acceptable_types(obj)verify_acceptable_types(obj)lower_bound=-2147483648upper_bound=2147483647ifobj<lower_boundorobj>upper_bound:raisePySparkValueError(errorClass="VALUE_OUT_OF_BOUNDS",messageParameters={"arg_name":"obj","lower_bound":str(lower_bound),"upper_bound":str(upper_bound),"actual":str(obj),},)verify_value=verify_integerelifisinstance(dataType,LongType):defverify_long(obj:Any)->None:assert_acceptable_types(obj)verify_acceptable_types(obj)lower_bound=-9223372036854775808upper_bound=9223372036854775807ifobj<lower_boundorobj>upper_bound:raisePySparkValueError(errorClass="VALUE_OUT_OF_BOUNDS",messageParameters={"arg_name":"obj","lower_bound":str(lower_bound),"upper_bound":str(upper_bound),"actual":str(obj),},)verify_value=verify_longelifisinstance(dataType,ArrayType):element_verifier=_make_type_verifier(dataType.elementType,dataType.containsNull,name="element in array %s"%name)defverify_array(obj:Any)->None:assert_acceptable_types(obj)verify_acceptable_types(obj)foriinobj:element_verifier(i)verify_value=verify_arrayelifisinstance(dataType,MapType):key_verifier=_make_type_verifier(dataType.keyType,False,name="key of map %s"%name)value_verifier=_make_type_verifier(dataType.valueType,dataType.valueContainsNull,name="value of map 
%s"%name)defverify_map(obj:Any)->None:assert_acceptable_types(obj)verify_acceptable_types(obj)fork,vinobj.items():key_verifier(k)value_verifier(v)verify_value=verify_mapelifisinstance(dataType,StructType):verifiers=[]forfindataType.fields:verifier=_make_type_verifier(f.dataType,f.nullable,name=new_name(f.name))verifiers.append((f.name,verifier))defverify_struct(obj:Any)->None:assert_acceptable_types(obj)ifisinstance(obj,dict):forf,verifierinverifiers:verifier(obj.get(f))elifisinstance(obj,(tuple,list)):iflen(obj)!=len(verifiers):ifnameisnotNone:raisePySparkValueError(errorClass="FIELD_STRUCT_LENGTH_MISMATCH_WITH_NAME",messageParameters={"field_name":str(name),"object_length":str(len(obj)),"field_length":str(len(verifiers)),},)raisePySparkValueError(errorClass="FIELD_STRUCT_LENGTH_MISMATCH",messageParameters={"object_length":str(len(obj)),"field_length":str(len(verifiers)),},)forv,(_,verifier)inzip(obj,verifiers):verifier(v)elifhasattr(obj,"__dict__"):d=obj.__dict__forf,verifierinverifiers:verifier(d.get(f))else:ifnameisnotNone:raisePySparkTypeError(errorClass="FIELD_DATA_TYPE_UNACCEPTABLE_WITH_NAME",messageParameters={"field_name":str(name),"data_type":str(dataType),"obj":repr(obj),"obj_type":str(type(obj)),},)raisePySparkTypeError(errorClass="FIELD_DATA_TYPE_UNACCEPTABLE",messageParameters={"data_type":str(dataType),"obj":repr(obj),"obj_type":str(type(obj)),},)verify_value=verify_structelifisinstance(dataType,VariantType):defverify_variant(obj:Any)->None:# The variant data type can take in any type.passverify_value=verify_variantelse:defverify_default(obj:Any)->None:assert_acceptable_types(obj)verify_acceptable_types(obj)verify_value=verify_defaultdefverify(obj:Any)->None:ifnotverify_nullability(obj):verify_value(obj)returnverify# This is used to unpickle a Row from JVMdef_create_row_inbound_converter(dataType:DataType)->Callable:returnlambda*a:dataType.fromInternal(a)def_create_row(fields:Union["Row",List[str]],values:Union[Tuple[Any,...],List[Any]])->"Row":row=Row(*values)row.__fields__=fieldsreturnrow
class Row(tuple):
    """
    A row in :class:`DataFrame`.
    The fields in it can be accessed:

    * like attributes (``row.key``)
    * like dictionary values (``row[key]``)

    ``key in row`` will search through row keys.

    Row can be used to create a row object by using named arguments.
    It is not allowed to omit a named argument to represent that the value is
    None or missing. This should be explicitly set to None in this case.

    .. versionchanged:: 3.0.0
        Rows created from named arguments no longer have
        field names sorted alphabetically and will be ordered in the position as
        entered.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(name='Alice', age=11)
    >>> row['name'], row['age']
    ('Alice', 11)
    >>> row.name, row.age
    ('Alice', 11)
    >>> 'name' in row
    True
    >>> 'wrong_key' in row
    False

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row('name', 'age')>
    >>> 'name' in Person
    True
    >>> 'wrong_key' in Person
    False
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    This form can also be used to create rows as tuple values, i.e. with unnamed
    fields.

    >>> row1 = Row("Alice", 11)
    >>> row2 = Row(name="Alice", age=11)
    >>> row1 == row2
    True
    """

    @overload
    def __new__(cls, *args: str) -> "Row":
        ...

    @overload
    def __new__(cls, **kwargs: Any) -> "Row":
        ...

    def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row":
        if args and kwargs:
            raise PySparkValueError(
                errorClass="CANNOT_SET_TOGETHER",
                messageParameters={"arg_list": "args and kwargs"},
            )
        if kwargs:
            # create row objects
            row = tuple.__new__(cls, list(kwargs.values()))
            row.__fields__ = list(kwargs.keys())
            return row
        else:
            # create row class or objects
            return tuple.__new__(cls, args)

[docs]
    def asDict(self, recursive: bool = False) -> Dict[str, Any]:
        """
        Return as a dict.

        Parameters
        ----------
        recursive : bool, optional
            turns nested Rows into dicts (default: False).

        Notes
        -----
        If a row contains duplicate field names, e.g., the rows of a join
        between two :class:`DataFrame` objects that both have fields with the
        same names, one of the duplicate fields will be selected by ``asDict``.
        ``__getitem__`` will also return one of the duplicate fields, but the
        value it returns might differ from the one returned by ``asDict``.

        Examples
        --------
        >>> from pyspark.sql import Row
        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
        True
        >>> row = Row(key=1, value=Row(name='a', age=2))
        >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}
        True
        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
        True
        """
        if not hasattr(self, "__fields__"):
            raise PySparkTypeError(
                errorClass="CANNOT_CONVERT_TYPE",
                messageParameters={
                    "from_type": "Row",
                    "to_type": "dict",
                },
            )

        if recursive:

            def conv(obj: Any) -> Any:
                if isinstance(obj, Row):
                    return obj.asDict(True)
                elif isinstance(obj, list):
                    return [conv(o) for o in obj]
                elif isinstance(obj, dict):
                    return dict((k, conv(v)) for k, v in obj.items())
                else:
                    return obj

            return dict(zip(self.__fields__, (conv(o) for o in self)))
        else:
            return dict(zip(self.__fields__, self))
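
    # Illustrative example of the duplicate-field caveat noted above (the concrete
    # values shown are one possible outcome, not a guarantee of which duplicate wins):
    #
    #   >>> r = Row("id", "id")(1, 2)
    #   >>> r.asDict()      # a plain dict can hold only one "id" entry
    #   {'id': 2}
    #   >>> r["id"]         # __getitem__ resolves to one of the two positions
    #   1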

    def __contains__(self, item: Any) -> bool:
        if hasattr(self, "__fields__"):
            return item in self.__fields__
        else:
            return super(Row, self).__contains__(item)

    # let the object act like a class
    def __call__(self, *args: Any) -> "Row":
        """create new Row object"""
        if len(args) > len(self):
            raise PySparkValueError(
                errorClass="TOO_MANY_VALUES",
                messageParameters={
                    "expected": str(len(self)),
                    "item": "fields",
                    "actual": str(len(args)),
                },
            )
        return _create_row(self, args)

    def __getitem__(self, item: Any) -> Any:
        if isinstance(item, (int, slice)):
            return super(Row, self).__getitem__(item)
        try:
            # this can be slow when there are many fields,
            # but it is not hit in normal cases
            idx = self.__fields__.index(item)
            return super(Row, self).__getitem__(idx)
        except IndexError:
            raise PySparkKeyError(
                errorClass="KEY_NOT_EXISTS", messageParameters={"key": str(item)}
            )
        except ValueError:
            raise PySparkValueError(item)

    def __getattr__(self, item: str) -> Any:
        if item.startswith("__"):
            raise PySparkAttributeError(
                errorClass="ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": item}
            )
        try:
            # this can be slow when there are many fields,
            # but it is not hit in normal cases
            idx = self.__fields__.index(item)
            return self[idx]
        except IndexError:
            raise PySparkAttributeError(
                errorClass="ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": item}
            )
        except ValueError:
            raise PySparkAttributeError(
                errorClass="ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": item}
            )

    def __setattr__(self, key: Any, value: Any) -> None:
        if key != "__fields__":
            raise PySparkRuntimeError(
                errorClass="READ_ONLY",
                messageParameters={"object": "Row"},
            )
        self.__dict__[key] = value

    def __reduce__(
        self,
    ) -> Union[str, Tuple[Any, ...]]:
        """Returns a tuple so Python knows how to pickle Row."""
        if hasattr(self, "__fields__"):
            return (_create_row, (self.__fields__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self) -> str:
        """Printable representation of Row used in Python REPL."""
        if hasattr(self, "__fields__"):
            return "Row(%s)" % ", ".join(
                "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))
            )
        else:
            return "<Row(%s)>" % ", ".join(repr(field) for field in self)
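

# Illustrative sketch of the __reduce__ hook above (assumes a standard pickle round
# trip; nothing here is executed by the module):
#
#   >>> import pickle
#   >>> r = Row(name="Alice", age=11)
#   >>> pickle.loads(pickle.dumps(r))
#   Row(name='Alice', age=11)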


class DateConverter:
    def can_convert(self, obj: Any) -> bool:
        return isinstance(obj, datetime.date)

    def convert(self, obj: datetime.date, gateway_client: "GatewayClient") -> "JavaGateway":
        from py4j.java_gateway import JavaClass

        Date = JavaClass("java.sql.Date", gateway_client)
        return Date.valueOf(obj.strftime("%Y-%m-%d"))


class DatetimeConverter:
    def can_convert(self, obj: Any) -> bool:
        return isinstance(obj, datetime.datetime)

    def convert(self, obj: datetime.datetime, gateway_client: "GatewayClient") -> "JavaGateway":
        from py4j.java_gateway import JavaClass

        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
        seconds = (
            calendar.timegm(obj.utctimetuple()) if obj.tzinfo else time.mktime(obj.timetuple())
        )
        t = Timestamp(int(seconds) * 1000)
        t.setNanos(obj.microsecond * 1000)
        return t


class DatetimeNTZConverter:
    def can_convert(self, obj: Any) -> bool:
        from pyspark.sql.utils import is_timestamp_ntz_preferred

        return (
            isinstance(obj, datetime.datetime)
            and obj.tzinfo is None
            and is_timestamp_ntz_preferred()
        )

    def convert(self, obj: datetime.datetime, gateway_client: "GatewayClient") -> "JavaGateway":
        from py4j.java_gateway import JavaClass

        seconds = calendar.timegm(obj.utctimetuple())
        DateTimeUtils = JavaClass(
            "org.apache.spark.sql.catalyst.util.DateTimeUtils",
            gateway_client,
        )
        return DateTimeUtils.microsToLocalDateTime(int(seconds) * 1000000 + obj.microsecond)


class DayTimeIntervalTypeConverter:
    def can_convert(self, obj: Any) -> bool:
        return isinstance(obj, datetime.timedelta)

    def convert(self, obj: datetime.timedelta, gateway_client: "GatewayClient") -> "JavaGateway":
        from py4j.java_gateway import JavaClass

        IntervalUtils = JavaClass(
            "org.apache.spark.sql.catalyst.util.IntervalUtils",
            gateway_client,
        )
        return IntervalUtils.microsToDuration(
            (math.floor(obj.total_seconds()) * 1000000) + obj.microseconds
        )


class NumpyScalarConverter:
    def can_convert(self, obj: Any) -> bool:
        from pyspark.testing.utils import have_numpy

        if have_numpy:
            import numpy as np

            return isinstance(obj, np.generic)
        return False

    def convert(self, obj: "np.generic", gateway_client: "GatewayClient") -> Any:
        return obj.item()


class NumpyArrayConverter:
    def _from_numpy_type_to_java_type(
        self, nt: "np.dtype", gateway: "JavaGateway"
    ) -> Optional["JavaClass"]:
        """Convert NumPy type to Py4J Java type."""
        import numpy as np

        if nt in [np.dtype("int8"), np.dtype("int16")]:
            # Mapping int8 to gateway.jvm.byte causes
            # TypeError: 'bytes' object does not support item assignment
            return gateway.jvm.short
        elif nt == np.dtype("int32"):
            return gateway.jvm.int
        elif nt == np.dtype("int64"):
            return gateway.jvm.long
        elif nt == np.dtype("float32"):
            return gateway.jvm.float
        elif nt == np.dtype("float64"):
            return gateway.jvm.double
        elif nt == np.dtype("bool"):
            return gateway.jvm.boolean
        elif nt.type == np.dtype("str"):
            return gateway.jvm.String
        return None

    def can_convert(self, obj: Any) -> bool:
        from pyspark.testing.utils import have_numpy

        if have_numpy:
            import numpy as np

            return isinstance(obj, np.ndarray) and obj.ndim == 1
        return False

    def convert(self, obj: "np.ndarray", gateway_client: "GatewayClient") -> "JavaGateway":
        from pyspark import SparkContext

        gateway = SparkContext._gateway
        assert gateway is not None
        plist = obj.tolist()

        jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
        if jtpe is None:
            raise PySparkTypeError(
                errorClass="UNSUPPORTED_NUMPY_ARRAY_SCALAR",
                messageParameters={"dtype": str(obj.dtype)},
            )
        jarr = gateway.new_array(jtpe, len(obj))
        for i in range(len(plist)):
            jarr[i] = plist[i]
        return jarr
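

# Illustrative note on the registration order below: Py4J tries input converters in
# order and uses the first one whose can_convert() accepts the object. Naive datetimes
# may therefore be claimed by DatetimeNTZConverter before DatetimeConverter, and plain
# dates only fall through to DateConverter because datetime is a subclass of date;
# NumpyArrayConverter is prepended so 1-D ndarrays are not picked up by py4j's default
# list handling first. (This description of Py4J dispatch is an inference from the
# comments below, not something this module enforces.)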


if not is_remote_only():
    from py4j.protocol import register_input_converter

    # datetime is a subclass of date, we should register DatetimeConverter first
    register_input_converter(DatetimeNTZConverter())
    register_input_converter(DatetimeConverter())
    register_input_converter(DateConverter())
    register_input_converter(DayTimeIntervalTypeConverter())
    register_input_converter(NumpyScalarConverter())
    # NumPy array satisfies py4j.java_collections.ListConverter,
    # so prepend NumpyArrayConverter
    register_input_converter(NumpyArrayConverter(), prepend=True)


def _test() -> None:
    import doctest
    from pyspark.sql import SparkSession

    globs = globals()
    globs["spark"] = SparkSession.builder.getOrCreate()
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
    )
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()