# Int
>>> from pyspark.sql import functions as sf
>>> from pyspark.sql.types import LongType
>>> df = spark.createDataFrame([(1,), (2,), (3,)], ["x"])
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
|    6|
+-----+

>>> df = spark.createDataFrame([(None,), (2,), (None,)], ["x"])
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
|    2|
+-----+

>>> LONG_MAX = pow(2,63) -1
>>> df = spark.createDataFrame([(LONG_MAX,), (1,)], ["x"]).withColumn("x", sf.col("x").cast(LongType()))
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
| NULL|
+-----+

# Float
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([(1.5,), (2.5,), (3.0,)], ["x"])
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
|  7.0|
+-----+

>>> df = spark.createDataFrame([(1e308,), (1e308,)], ["x"])
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+--------+
|   sum_x|
+--------+
|Infinity|
+--------+

>>> df = spark.createDataFrame([(float('nan'),), (1.0,)], ["x"])
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
|  NaN|
+-----+

>>> df = spark.createDataFrame([(float('inf'),), (1.0,)], ["x"])
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+--------+
|   sum_x|
+--------+
|Infinity|
+--------+

# Decimal
>>> from decimal import Decimal
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([(Decimal("1.23"),), (Decimal("4.77"),)], "x DECIMAL(10,2)")
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
| 6.00|
+-----+
>>> df.select(sf.try_sum("x").alias("sum_x")).printSchema()
root
 |-- sum_x: decimal(20,2) (nullable = true)

>>> df = spark.createDataFrame([(Decimal("1.00"),), (None,), (Decimal("2.50"),)], "x DECIMAL(10,2)")
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
| 3.50|
+-----+

>>> df.select(sf.try_sum("x").alias("sum_x")).printSchema()
root
 |-- sum_x: decimal(20,2) (nullable = true)

>>> from pyspark.sql import functions as sf
>>> from decimal import Decimal
>>> df = spark.createDataFrame([(Decimal("90000"),), (Decimal("20000"),)], "x DECIMAL(5,0)")
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+------+
| sum_x|
+------+
|110000|
+------+

>>> df.select(sf.try_sum("x").alias("sum_x")).printSchema()
root
 |-- sum_x: decimal(15,0) (nullable = true)

# like test builtin
# spark.conf.set("spark.sql.ansi.enabled", "true")
>>> from pyspark.sql import functions as sf
>>> from decimal import Decimal
>>> origin = spark.conf.get("spark.sql.ansi.enabled")
>>> spark.conf.set("spark.sql.ansi.enabled", "true")
>>> df = spark.createDataFrame([(Decimal("1" * 38),)] * 10, "x DECIMAL(38,0)")
>>> df.select(sf.try_sum("x").alias("sum_x")).show()
+-----+
|sum_x|
+-----+
| NULL|
+-----+

>>> df.select(sf.try_sum("x").alias("sum_x")).printSchema()
root
 |-- sum_x: decimal(38,0) (nullable = true)

# GroupBy
>>> from pyspark.sql import functions as sf, types as T
>>> LONG_MAX = 2**63 - 1
>>> data = [("bad",  LONG_MAX),("bad",  1),("ok",   10),("ok",   None),("ok",   5),]
>>> df = spark.createDataFrame(data, schema=T.StructType([T.StructField("g", T.StringType(), True),T.StructField("x", T.LongType(),   True),]))
>>> out = df.groupBy("g").agg(sf.try_sum("x").alias("sum_x")).orderBy("g")
>>> out.show()
+---+-----+
|  g|sum_x|
+---+-----+
|bad| NULL|
| ok|   15|
+---+-----+

>>> out.printSchema()
root
 |-- g: string (nullable = true)
 |-- sum_x: long (nullable = true)
