HloModule xla_computation_loss_fn, entry_computation_layout={(f32[], f32[], f32[64,16]{1,0}, f32[64,16]{1,0}, f32[], /*index=5*/f32[], f32[64,16]{1,0}, f32[64,16]{1,0}, f32[], f32[], /*index=10*/f32[64,16]{1,0}, f32[64,16]{1,0}, f32[], f32[], f32[64,16]{1,0}, /*index=15*/f32[64,16]{1,0}, f32[64]{0}, f32[64]{0}, f32[64,64]{1,0}, f32[64,64]{1,0}, /*index=20*/f32[], f32[], f32[64,16]{1,0}, f32[64,16]{1,0}, f32[], /*index=25*/f32[], f32[64,16]{1,0}, f32[64,16]{1,0}, f32[], f32[], /*index=30*/f32[64,16]{1,0}, f32[64,16]{1,0}, f32[], f32[], f32[64,16]{1,0}, /*index=35*/f32[64,16]{1,0}, f32[64]{0}, f32[64]{0}, f32[64,64]{1,0}, f32[64,64]{1,0}, /*index=40*/f32[64,1]{1,0}, f32[128,64]{1,0}, f32[16,64]{1,0}, s32[541,3]{1,0}, s32[541]{0})->(f32[])}

// Scalar f32 addition of the two parameters.
// Referenced as the to_apply combiner of reduce.100 in the entry computation.
region_0.96 {
  Arg_0.97 = f32[] parameter(0)
  Arg_1.98 = f32[] parameter(1)
  ROOT add.99 = f32[] add(Arg_0.97, Arg_1.98)
}

// Element-wise maximum of the f32[541,3,64] input and a zero-filled tensor of the same shape.
// Invoked by call.113 in the entry computation.
relu.108 {
  Arg_0.109 = f32[541,3,64]{2,1,0} parameter(0)
  constant.110 = f32[] constant(0)
  // Scalar zero expanded to the full input shape so it can be compared element-wise.
  broadcast.111 = f32[541,3,64]{2,1,0} broadcast(constant.110), dimensions={}
  ROOT maximum.112 = f32[541,3,64]{2,1,0} maximum(Arg_0.109, broadcast.111)
}

// Scalar f32 addition of the two parameters.
// Referenced as the to_apply combiner of reduce.146 in the entry computation.
region_1.142 {
  Arg_0.143 = f32[] parameter(0)
  Arg_1.144 = f32[] parameter(1)
  ROOT add.145 = f32[] add(Arg_0.143, Arg_1.144)
}

// Element-wise maximum of the f32[541,3,64] input and a zero-filled tensor of the same shape.
// Invoked by call.159 in the entry computation.
relu.154 {
  Arg_0.155 = f32[541,3,64]{2,1,0} parameter(0)
  constant.156 = f32[] constant(0)
  // Scalar zero expanded to the full input shape so it can be compared element-wise.
  broadcast.157 = f32[541,3,64]{2,1,0} broadcast(constant.156), dimensions={}
  ROOT maximum.158 = f32[541,3,64]{2,1,0} maximum(Arg_0.155, broadcast.157)
}

// Scalar f32 addition of the two parameters.
// Referenced as the to_apply combiner of reduce.184 in the entry computation.
region_2.180 {
  Arg_0.181 = f32[] parameter(0)
  Arg_1.182 = f32[] parameter(1)
  ROOT add.183 = f32[] add(Arg_0.181, Arg_1.182)
}

// Entry computation: declares parameters 0 through 44 and returns a 1-tuple holding one f32 scalar.
// Outline: two row gathers are summed (add.74), the same "scaled mix + two dot/bias layers" stage is
// applied twice (add.120, add.166), the result is projected to one channel, and a squared-difference
// average against parameter 44 is returned.
ENTRY main.187 {
  // Parameters 1-3, 5-7, 9-11, 13-15, 21-23, 25-27, 29-31 and 33-35 are declared here but are not
  // consumed by any instruction in this computation.
  Arg_1.2 = f32[] parameter(1)
  Arg_2.3 = f32[64,16]{1,0} parameter(2)
  Arg_3.4 = f32[64,16]{1,0} parameter(3)
  Arg_5.6 = f32[] parameter(5)
  Arg_6.7 = f32[64,16]{1,0} parameter(6)
  Arg_7.8 = f32[64,16]{1,0} parameter(7)
  Arg_9.10 = f32[] parameter(9)
  Arg_10.11 = f32[64,16]{1,0} parameter(10)
  Arg_11.12 = f32[64,16]{1,0} parameter(11)
  Arg_13.14 = f32[] parameter(13)
  Arg_14.15 = f32[64,16]{1,0} parameter(14)
  Arg_15.16 = f32[64,16]{1,0} parameter(15)
  Arg_21.22 = f32[] parameter(21)
  Arg_22.23 = f32[64,16]{1,0} parameter(22)
  Arg_23.24 = f32[64,16]{1,0} parameter(23)
  Arg_25.26 = f32[] parameter(25)
  Arg_26.27 = f32[64,16]{1,0} parameter(26)
  Arg_27.28 = f32[64,16]{1,0} parameter(27)
  Arg_29.30 = f32[] parameter(29)
  Arg_30.31 = f32[64,16]{1,0} parameter(30)
  Arg_31.32 = f32[64,16]{1,0} parameter(31)
  Arg_33.34 = f32[] parameter(33)
  Arg_34.35 = f32[64,16]{1,0} parameter(34)
  Arg_35.36 = f32[64,16]{1,0} parameter(35)
  // Scalar parameter 20 times the 3x3 identity (constant.58); consumed by dot.123 in the second stage.
  Arg_20.21 = f32[] parameter(20)
  broadcast.121 = f32[3,3]{1,0} broadcast(Arg_20.21), dimensions={}
  constant.58 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } })
  multiply.122 = f32[3,3]{1,0} multiply(broadcast.121, constant.58)
  // Scalar parameter 0 times the 3x3 identity; consumed by dot.77 in the first stage.
  Arg_0.1 = f32[] parameter(0)
  broadcast.75 = f32[3,3]{1,0} broadcast(Arg_0.1), dimensions={}
  multiply.76 = f32[3,3]{1,0} multiply(broadcast.75, constant.58)
  // Gather 64-wide rows of parameter 42 (16 rows) using the s32[541,3] indices in parameter 43.
  // Indices below zero have 16 (the row count) added before the gather.
  Arg_42.43 = f32[16,64]{1,0} parameter(42)
  Arg_43.44 = s32[541,3]{1,0} parameter(43)
  constant.54 = s32[] constant(0)
  broadcast.55 = s32[541,3]{1,0} broadcast(constant.54), dimensions={}
  compare.59 = pred[541,3]{1,0} compare(Arg_43.44, broadcast.55), direction=LT
  constant.52 = s32[] constant(16)
  broadcast.53 = s32[541,3]{1,0} broadcast(constant.52), dimensions={}
  add.60 = s32[541,3]{1,0} add(Arg_43.44, broadcast.53)
  select.61 = s32[541,3]{1,0} select(compare.59, add.60, Arg_43.44)
  reshape.62 = s32[541,3,1]{2,1,0} reshape(select.61)
  gather.63 = f32[541,3,64]{2,1,0} gather(Arg_42.43, reshape.62), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,64}
  // Gather rows 0, 1, 2 of parameter 41 (128 rows), using an iota as the index vector.
  // The same negative-index adjustment is emitted here (adding 128); iota values are never
  // negative, so the select always keeps the iota value.
  Arg_41.42 = f32[128,64]{1,0} parameter(41)
  iota.64 = s32[3]{0} iota(), iota_dimension=0
  constant.50 = s32[] constant(0)
  broadcast.51 = s32[3]{0} broadcast(constant.50), dimensions={}
  compare.65 = pred[3]{0} compare(iota.64, broadcast.51), direction=LT
  constant.48 = s32[] constant(128)
  broadcast.49 = s32[3]{0} broadcast(constant.48), dimensions={}
  add.66 = s32[3]{0} add(iota.64, broadcast.49)
  select.67 = s32[3]{0} select(compare.65, add.66, iota.64)
  reshape.68 = s32[3,1]{1,0} reshape(select.67)
  gather.69 = f32[3,64]{1,0} gather(Arg_41.42, reshape.68), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,64}
  // Shape-only plumbing: the [3,64] rows are replicated across the leading 541 dimension.
  reshape.70 = f32[1,3,64]{2,1,0} reshape(gather.69)
  broadcast.71 = f32[1,3,64]{2,1,0} broadcast(reshape.70), dimensions={0,1,2}
  reshape.72 = f32[3,64]{1,0} reshape(broadcast.71)
  broadcast.73 = f32[541,3,64]{2,1,0} broadcast(reshape.72), dimensions={1,2}
  // Sum of the two gathered tensors; this is the input to the first stage.
  add.74 = f32[541,3,64]{2,1,0} add(gather.63, broadcast.73)
  // First stage, part 1: four dots of (scalar * identity) with add.74, contracting the size-3
  // dimension. Each result is transposed back to [541,3,64] and given a trailing unit axis.
  // Because the left operand is a scalar multiple of the identity, each dot is mathematically
  // add.74 scaled by that scalar (parameters 0, 4, 8, 12).
  dot.77 = f32[3,541,64]{2,1,0} dot(multiply.76, add.74), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.78 = f32[541,3,64]{2,0,1} transpose(dot.77), dimensions={1,0,2}
  reshape.91 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.78)
  Arg_4.5 = f32[] parameter(4)
  broadcast.79 = f32[3,3]{1,0} broadcast(Arg_4.5), dimensions={}
  multiply.80 = f32[3,3]{1,0} multiply(broadcast.79, constant.58)
  dot.81 = f32[3,541,64]{2,1,0} dot(multiply.80, add.74), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.82 = f32[541,3,64]{2,0,1} transpose(dot.81), dimensions={1,0,2}
  reshape.92 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.82)
  Arg_8.9 = f32[] parameter(8)
  broadcast.83 = f32[3,3]{1,0} broadcast(Arg_8.9), dimensions={}
  multiply.84 = f32[3,3]{1,0} multiply(broadcast.83, constant.58)
  dot.85 = f32[3,541,64]{2,1,0} dot(multiply.84, add.74), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.86 = f32[541,3,64]{2,0,1} transpose(dot.85), dimensions={1,0,2}
  reshape.93 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.86)
  Arg_12.13 = f32[] parameter(12)
  broadcast.87 = f32[3,3]{1,0} broadcast(Arg_12.13), dimensions={}
  multiply.88 = f32[3,3]{1,0} multiply(broadcast.87, constant.58)
  dot.89 = f32[3,541,64]{2,1,0} dot(multiply.88, add.74), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.90 = f32[541,3,64]{2,0,1} transpose(dot.89), dimensions={1,0,2}
  reshape.94 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.90)
  // Stack the four results on the trailing axis, sum over it and divide by 4 (their mean).
  concatenate.95 = f32[541,3,64,4]{3,2,1,0} concatenate(reshape.91, reshape.92, reshape.93, reshape.94), dimensions={3}
  constant.57 = f32[] constant(0)
  reduce.100 = f32[541,3,64]{2,1,0} reduce(concatenate.95, constant.57), dimensions={3}, to_apply=region_0.96
  constant.46 = f32[] constant(4)
  broadcast.47 = f32[541,3,64]{2,1,0} broadcast(constant.46), dimensions={}
  divide.101 = f32[541,3,64]{2,1,0} divide(reduce.100, broadcast.47)
  // First stage, part 2: add.74 x parameter 18 plus parameter 16 (broadcast over the last axis),
  // then relu.108, then x parameter 19 plus parameter 17.
  Arg_18.19 = f32[64,64]{1,0} parameter(18)
  dot.102 = f32[541,3,64]{2,1,0} dot(add.74, Arg_18.19), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  Arg_16.17 = f32[64]{0} parameter(16)
  reshape.103 = f32[1,1,64]{2,1,0} reshape(Arg_16.17)
  broadcast.104 = f32[1,1,64]{2,1,0} broadcast(reshape.103), dimensions={0,1,2}
  reshape.105 = f32[64]{0} reshape(broadcast.104)
  broadcast.106 = f32[541,3,64]{2,1,0} broadcast(reshape.105), dimensions={2}
  add.107 = f32[541,3,64]{2,1,0} add(dot.102, broadcast.106)
  call.113 = f32[541,3,64]{2,1,0} call(add.107), to_apply=relu.108
  Arg_19.20 = f32[64,64]{1,0} parameter(19)
  dot.114 = f32[541,3,64]{2,1,0} dot(call.113, Arg_19.20), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  Arg_17.18 = f32[64]{0} parameter(17)
  reshape.115 = f32[1,1,64]{2,1,0} reshape(Arg_17.18)
  broadcast.116 = f32[1,1,64]{2,1,0} broadcast(reshape.115), dimensions={0,1,2}
  reshape.117 = f32[64]{0} reshape(broadcast.116)
  broadcast.118 = f32[541,3,64]{2,1,0} broadcast(reshape.117), dimensions={2}
  add.119 = f32[541,3,64]{2,1,0} add(dot.114, broadcast.118)
  // First-stage output: sum of the averaged branch (divide.101) and the two-layer branch (add.119).
  add.120 = f32[541,3,64]{2,1,0} add(divide.101, add.119)
  // Second stage, part 1: same four-way scaled-identity dot pattern, now applied to add.120 with
  // scalar parameters 20, 24, 28 and 32.
  dot.123 = f32[3,541,64]{2,1,0} dot(multiply.122, add.120), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.124 = f32[541,3,64]{2,0,1} transpose(dot.123), dimensions={1,0,2}
  reshape.137 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.124)
  Arg_24.25 = f32[] parameter(24)
  broadcast.125 = f32[3,3]{1,0} broadcast(Arg_24.25), dimensions={}
  multiply.126 = f32[3,3]{1,0} multiply(broadcast.125, constant.58)
  dot.127 = f32[3,541,64]{2,1,0} dot(multiply.126, add.120), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.128 = f32[541,3,64]{2,0,1} transpose(dot.127), dimensions={1,0,2}
  reshape.138 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.128)
  Arg_28.29 = f32[] parameter(28)
  broadcast.129 = f32[3,3]{1,0} broadcast(Arg_28.29), dimensions={}
  multiply.130 = f32[3,3]{1,0} multiply(broadcast.129, constant.58)
  dot.131 = f32[3,541,64]{2,1,0} dot(multiply.130, add.120), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.132 = f32[541,3,64]{2,0,1} transpose(dot.131), dimensions={1,0,2}
  reshape.139 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.132)
  Arg_32.33 = f32[] parameter(32)
  broadcast.133 = f32[3,3]{1,0} broadcast(Arg_32.33), dimensions={}
  multiply.134 = f32[3,3]{1,0} multiply(broadcast.133, constant.58)
  dot.135 = f32[3,541,64]{2,1,0} dot(multiply.134, add.120), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  transpose.136 = f32[541,3,64]{2,0,1} transpose(dot.135), dimensions={1,0,2}
  reshape.140 = f32[541,3,64,1]{3,2,1,0} reshape(transpose.136)
  // Stack, sum over the trailing axis and divide by 4, reusing constant.57 and broadcast.47.
  concatenate.141 = f32[541,3,64,4]{3,2,1,0} concatenate(reshape.137, reshape.138, reshape.139, reshape.140), dimensions={3}
  reduce.146 = f32[541,3,64]{2,1,0} reduce(concatenate.141, constant.57), dimensions={3}, to_apply=region_1.142
  divide.147 = f32[541,3,64]{2,1,0} divide(reduce.146, broadcast.47)
  // Second stage, part 2: add.120 x parameter 38 plus parameter 36, relu.154, then x parameter 39
  // plus parameter 37.
  Arg_38.39 = f32[64,64]{1,0} parameter(38)
  dot.148 = f32[541,3,64]{2,1,0} dot(add.120, Arg_38.39), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  Arg_36.37 = f32[64]{0} parameter(36)
  reshape.149 = f32[1,1,64]{2,1,0} reshape(Arg_36.37)
  broadcast.150 = f32[1,1,64]{2,1,0} broadcast(reshape.149), dimensions={0,1,2}
  reshape.151 = f32[64]{0} reshape(broadcast.150)
  broadcast.152 = f32[541,3,64]{2,1,0} broadcast(reshape.151), dimensions={2}
  add.153 = f32[541,3,64]{2,1,0} add(dot.148, broadcast.152)
  call.159 = f32[541,3,64]{2,1,0} call(add.153), to_apply=relu.154
  Arg_39.40 = f32[64,64]{1,0} parameter(39)
  dot.160 = f32[541,3,64]{2,1,0} dot(call.159, Arg_39.40), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  Arg_37.38 = f32[64]{0} parameter(37)
  reshape.161 = f32[1,1,64]{2,1,0} reshape(Arg_37.38)
  broadcast.162 = f32[1,1,64]{2,1,0} broadcast(reshape.161), dimensions={0,1,2}
  reshape.163 = f32[64]{0} reshape(broadcast.162)
  broadcast.164 = f32[541,3,64]{2,1,0} broadcast(reshape.163), dimensions={2}
  add.165 = f32[541,3,64]{2,1,0} add(dot.160, broadcast.164)
  // Second-stage output: sum of the averaged branch (divide.147) and the two-layer branch (add.165).
  add.166 = f32[541,3,64]{2,1,0} add(divide.147, add.165)
  // Project the 64-wide last axis to width 1 with parameter 40, then keep only index 2 (the last
  // entry) of the size-3 axis, leaving one value per leading-axis row (f32[541] after reshape.173).
  Arg_40.41 = f32[64,1]{1,0} parameter(40)
  dot.167 = f32[541,3,1]{2,1,0} dot(add.166, Arg_40.41), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  slice.168 = f32[541,1,1]{2,1,0} slice(dot.167), slice={[0:541], [2:3], [0:1]}
  reshape.169 = f32[541,1]{1,0} reshape(slice.168)
  broadcast.172 = f32[541,1]{1,0} broadcast(reshape.169), dimensions={0,1}
  reshape.173 = f32[541]{0} reshape(broadcast.172)
  // The 541 projected values are laid along dimension 0 of a [541,541] tensor ...
  broadcast.174 = f32[541,541]{1,0} broadcast(reshape.173), dimensions={0}
  // ... while parameter 44 (s32, converted to f32) is laid along dimension 1.
  Arg_44.45 = s32[541]{0} parameter(44)
  convert.170 = f32[541]{0} convert(Arg_44.45)
  reshape.171 = f32[1,541]{1,0} reshape(convert.170)
  broadcast.175 = f32[1,541]{1,0} broadcast(reshape.171), dimensions={0,1}
  reshape.176 = f32[541]{0} reshape(broadcast.175)
  broadcast.177 = f32[541,541]{1,0} broadcast(reshape.176), dimensions={1}
  // Element (i, j) of the difference is value[i] - param44[j], i.e. every pairing of the two
  // length-541 vectors, not only the matching positions i == j.
  // NOTE(review): if an element-wise mean squared error over 541 entries was intended, this
  // looks like an unintended (541,1) vs (541,) broadcast in the program that produced this
  // HLO - confirm against that source before relying on the value.
  subtract.178 = f32[541,541]{1,0} subtract(broadcast.174, broadcast.177)
  multiply.179 = f32[541,541]{1,0} multiply(subtract.178, subtract.178)
  // Sum of all squared differences divided by 292681 (= 541 * 541, the element count).
  reduce.184 = f32[] reduce(multiply.179, constant.57), dimensions={0,1}, to_apply=region_2.180
  constant.56 = f32[] constant(292681)
  divide.185 = f32[] divide(reduce.184, constant.56)
  ROOT tuple.186 = (f32[]) tuple(divide.185)
}

