>>> from pyspark.sql import functions as sf
>>> from pyspark.sql.types import LongType

# Basic usage: the mean of the integer values 1, 2 and 3 is shown as 2.0.
>>> df = spark.createDataFrame([(1,), (2,), (3,)], ["x"])
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+-----+
|avg_x|
+-----+
|  2.0|
+-----+

# The result column is reported as a nullable double.
>>> df.select(sf.try_avg("x").alias("avg_x")).printSchema()
root
 |-- avg_x: double (nullable = true)

# NULL inputs are skipped: only the single non-NULL value (2) contributes.
>>> df = spark.createDataFrame([(None,), (2,), (None,)], ["x"])
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+-----+
|avg_x|
+-----+
|  2.0|
+-----+

>>> LONG_MAX = 2**63 - 1
>>> df = (
...     spark.createDataFrame([(LONG_MAX,), (1,)], ["x"])
...       .withColumn("x", sf.col("x").cast(LongType()))
... )
>>> df.select(sf.try_avg("x").alias("avg_x")).show(truncate=False)
+--------------------+
|avg_x               |
+--------------------+
|4.611686018427388e18|
+--------------------+

>>> df.select(sf.try_avg("x").alias("avg_x")).show(truncate=False)
+--------------------+
|avg_x               |
+--------------------+
|4.611686018427388e18|
+--------------------+

# Floating-point input: the mean of 1.5, 2.5 and 3.0 is printed at full
# double precision.
>>> df = spark.createDataFrame([(1.5,), (2.5,), (3.0,)], ["x"])
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+------------------+
|             avg_x|
+------------------+
|2.3333333333333335|
+------------------+


# Two values of 1e308 add up to more than the largest finite double; the
# expected result is Infinity rather than NULL.
>>> df = spark.createDataFrame([(1e308,), (1e308,)], ["x"])
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+--------+
|   avg_x|
+--------+
|Infinity|
+--------+

# A NaN input propagates: the expected result is NaN.
>>> df = spark.createDataFrame([(float('nan'),), (1.0,)], ["x"])
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+-----+
|avg_x|
+-----+
|  NaN|
+-----+

# Decimal input: the expected output keeps the two fractional digits of the
# DECIMAL(10,2) input column.
>>> from decimal import Decimal
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame(
...     [(Decimal("1.23"),), (Decimal("4.77"),)],
...     "x DECIMAL(10,2)"
... )
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+-----+
|avg_x|
+-----+
| 3.00|
+-----+

# For comparison, the output recorded for PySpark (presumably upstream Spark;
# TODO confirm) shows four additional fractional digits:
# +--------+
# |   avg_x|
# +--------+
# |3.000000|
# +--------+

# Decimal input with a NULL row: the NULL is skipped, so the mean is taken
# over 1.00 and 2.50 only.
>>> df = spark.createDataFrame(
...     [(Decimal("1.00"),), (None,), (Decimal("2.50"),)],
...     "x DECIMAL(10,2)"
... )
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+-----+
|avg_x|
+-----+
| 1.75|
+-----+

# For comparison, the output recorded for PySpark (presumably upstream Spark;
# TODO confirm) shows four additional fractional digits:
# +--------+
# |   avg_x|
# +--------+
# |1.750000|
# +--------+

# The sum of the inputs (110000) needs six digits, more than the DECIMAL(5,0)
# input type holds, yet the mean 55000 is still the expected result.
>>> from decimal import Decimal
>>> df = spark.createDataFrame(
...     [(Decimal("90000"),), (Decimal("20000"),)],
...     "x DECIMAL(5,0)"
... )
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+-----+
|avg_x|
+-----+
|55000|
+-----+

# For comparison, the output recorded for PySpark (presumably upstream Spark;
# TODO confirm) shows four additional fractional digits:
# +----------+
# |     avg_x|
# +----------+
# |55000.0000|
# +----------+

>>> from pyspark.sql import functions as sf
>>> from decimal import Decimal
>>> origin = spark.conf.get("spark.sql.ansi.enabled")
>>> spark.conf.set("spark.sql.ansi.enabled", "true")

>>> df = spark.createDataFrame(
...     [(Decimal("1" * 38),)] * 10,
...     "x DECIMAL(38,0)"
... )
>>> df.select(sf.try_avg("x").alias("avg_x")).show()
+-----+
|avg_x|
+-----+
| NULL|
+-----+


# Grouped aggregation: try_avg is evaluated per value of "g".
# Group "bad" pairs LONG_MAX with 1; group "ok" contains a NULL that is
# skipped, leaving the mean of 10 and 5.
>>> from pyspark.sql import functions as sf, types as T
>>> LONG_MAX = 2**63 - 1
>>> data = [
...     ("bad",  LONG_MAX),
...     ("bad",  1),
...     ("ok",   10),
...     ("ok",   None),
...     ("ok",   5),
... ]
>>> df = spark.createDataFrame(
...     data,
...     schema=T.StructType([
...         T.StructField("g", T.StringType(), True),
...         T.StructField("x", T.LongType(),   True),
...     ])
... )
>>> out = (
...     df.groupBy("g")
...       .agg(sf.try_avg("x").alias("avg_x"))
...       .orderBy("g")
... )
>>> out.show(truncate=False)
+---+--------------------+
|g  |avg_x               |
+---+--------------------+
|bad|4.611686018427388e18|
|ok |7.5                 |
+---+--------------------+

# The grouping key stays a string and the aggregate is a nullable double.
>>> out.printSchema()
root
 |-- g: string (nullable = true)
 |-- avg_x: double (nullable = true)

# Grouped aggregation over groups "bad" (LONG_MAX and 1) and "ok" (10, NULL,
# 5), rendered with the default show(), which right-aligns the cells.
>>> from pyspark.sql import functions as sf, types as T

>>> LONG_MAX = 2**63 - 1
>>> data = [
...     ("bad",  LONG_MAX),
...     ("bad",  1),
...     ("ok",   10),
...     ("ok",   None),
...     ("ok",   5),
... ]
>>> df = spark.createDataFrame(
...     data,
...     schema=T.StructType([
...         T.StructField("g", T.StringType(), True),
...         T.StructField("x", T.LongType(),   True),
...     ])
... )
>>> out = (
...     df.groupBy("g")
...       .agg(sf.try_avg("x").alias("avg_x"))
...       .orderBy("g")
... )
>>> out.show()
+---+--------------------+
|  g|               avg_x|
+---+--------------------+
|bad|4.611686018427388e18|
| ok|                 7.5|
+---+--------------------+

# The grouping key stays a string and the aggregate is a nullable double.
>>> out.printSchema()
root
 |-- g: string (nullable = true)
 |-- avg_x: double (nullable = true)

# Interval input, queried through SQL.
# The month total 2147483647 + 1 is one past the largest signed 32-bit value
# (presumably the limit of the interval's month count; TODO confirm); the
# expected result is NULL.
>>> spark.sql("SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col)").show()
+------------+
|try_avg(col)|
+------------+
|        NULL|
+------------+

# The mean of 7 months and 1 month is 4 months, displayed as 0-4.
>>> spark.sql("SELECT try_avg(col) FROM VALUES (interval '7 months'), (interval '1 months') AS tab(col)").show(truncate=False)
+----------------------------+
|try_avg(col)                |
+----------------------------+
|INTERVAL '0-4' YEAR TO MONTH|
+----------------------------+

# The NULL row is skipped; the 7.5-month mean of 10 and 5 months is displayed
# as 0-8 (8 months).
>>> spark.sql("SELECT try_avg(col) FROM VALUES (interval '10 months'), null, (interval '5 months') AS tab(col)").show(truncate=False)
+----------------------------+
|try_avg(col)                |
+----------------------------+
|INTERVAL '0-8' YEAR TO MONTH|
+----------------------------+
