#! /usr/bin/env python3

import yoda, random
from math import isclose

# Continuous axis: n_bins linearly spaced bins covering [x_min, x_max],
# built with yoda.linspace and wrapped in the estimate stored at /bar.
n_bins = 20
x_min = 0.
x_max = 100.
linspace = yoda.linspace(n_bins, x_min, x_max)
e1 = yoda.BinnedEstimate1D(
    linspace,
    path="/bar",
    title="Linearly spaced Estimate1D",
)

# Discrete axis: one bin per string label, stored at /baz.
named_bins = ["a", "b", "c"]
e2 = yoda.BinnedEstimate1D(
    named_bins,
    path="/baz",
    title="Named Estimate1D",
)

# Fill e1 bins 0..numBins() inclusive (bin 0 is the underflow bin) with a
# deterministic value and asymmetric (down, up) error pair, and check that
# each bin reads back exactly what was stored.
for i in range(e1.numBins() + 1):
    val = -(i / 2) ** 2 / 4
    errs = [-val**2. + 1, val**2.]
    estimate = yoda.Estimate()
    estimate.set(val, errs)
    e1.set(i, estimate)
    assert isclose(e1.bin(i).val(), val)
    assert isclose(e1.bin(i).errDown(), errs[0])
    assert isclose(e1.bin(i).errUp(), errs[1])

# Masked bins no longer count towards numBins().
mask = [1, 8, 11, 13]
e1.maskBins(mask)
assert e1.numBins() == n_bins - len(mask)

# Merging pairs of adjacent bins halves the original bin count.
e1.rebinXBy(2)
assert e1.numBins() == n_bins // 2

# Give each e2 bin (in bins() order) the value -index with a symmetric 0.1
# error, verifying the stored estimate straight after writing it.
# NOTE(review): set()/bin() take the enumerate() position of bins() as the
# bin index; confirm that position matches the global index on a discrete
# axis (if index 0 were a flow bin, the last label would never be filled).
for idx, _unused_bin in enumerate(e2.bins()):
    central, sym_err = -idx, 0.1
    est = yoda.Estimate()
    est.set(central, sym_err)
    e2.set(idx, est)
    stored = e2.bin(idx)
    assert isclose(stored.val(), central)
    assert isclose(stored.errDown(), sym_err)
    assert isclose(stored.errUp(), sym_err)

# Masking the "b" label removes exactly one bin from the count.
e2.maskBinAt("b")
assert e2.numBins() == len(named_bins) - 1

# After rebinXBy(2), rebinned bin k (k = 1..10) should hold
# val(2k-1) + val(2k), where val(i) = -i**2/16 is the fill formula used above.
# For example, -0.3125 = val(1) + val(2). These sums include the contents of
# the masked bins 1, 8, 11 and 13.
expected_vals = [-0.3125, -1.5625, -3.8125, -7.0625, -11.3125,
                 -16.5625, -22.8125, -30.0625, -38.3125, -47.5625]
got_vals = e1.vals()
print(got_vals)
# A length check catches extra or missing bins that an index loop would miss.
assert len(got_vals) == len(expected_vals)
for got, want in zip(got_vals, expected_vals):
    # Use a tolerance, consistent with the other float checks in this file.
    assert isclose(got, want)

# Rebinning must not change the outer axis range.
print(e1.xMin(), "-", e1.xMax())
assert isclose(e1.xMin(), x_min)
assert isclose(e1.xMax(), x_max)

# Round-trip both estimates through the native YODA format and then the FLAT
# format: write, read back, and print every analysis object that was read.
for writer, fname in ((yoda.write, "e1d.yoda"), (yoda.writeFLAT, "e1d.dat")):
    writer([e1, e2], fname)
    aos = yoda.read(fname)
    # Only the objects are needed here, not their path keys.
    for ao in aos.values():
        print(ao)

# Check that mkScatter() carries the bin count, the values and the errors of
# e1 over to the scatter. Each failure exits with its own status code (11-13).
# raise SystemExit is used rather than the site-injected exit() builtin,
# which is meant for interactive use and is absent under `python -S`.
s1 = e1.mkScatter()
if e1.numBins() != s1.numPoints():
    print(f"FAIL mkScatter() #bin={e1.numBins()} -> #point={s1.numPoints()}")
    raise SystemExit(11)

# vals()[0] is the first in-range bin, i.e. bin(1), which maps to point(0).
if not isclose(e1.vals()[0], s1.point(0).y()):
    print(f"FAIL mkScatter() bin1 value={e1.vals()[0]} -> point0 value={s1.point(0).y()}")
    raise SystemExit(12)

# bin(0) is underflow for e1 so have to compare bin(1) to point(0)
eneg, epos = e1.bin(1).totalErr()
sneg, spos = s1.point(0).yErrs()
# The down error is compared by magnitude, since the two sides may use
# different sign conventions for it.
if not isclose(abs(eneg), sneg) or not isclose(epos, spos):
    print(f"FAIL mkScatter() bin1 err=({eneg},{epos}) -> point0 err=({sneg},{spos})")
    raise SystemExit(13)

