##########################
import numpy as np
import scipy.ndimage as ndimage
import sys
import astropy.io.fits as pyfits
import matplotlib
matplotlib.use('Agg')
from matplotlib import colors
from astropy import units as u
from astropy.coordinates import SkyCoord
from matplotlib.pyplot import MultipleLocator
import matplotlib.pyplot as plt
import astropy.wcs as wcs
from matplotlib.gridspec import GridSpec
import yaml
class MyShape(object):
    import astropy.io.fits as pyfits
    import astropy.wcs as wcs
    @staticmethod
    def getDistance(lon1, lat1, lon2, lat2):  # {{{
        """ 
        Calculate the great circle distance between two points  
        on the earth (specified in decimal degrees) 
        """
        lon1 = np.radians(lon1)
        lat1 = np.radians(lat1)
        lon2 = np.radians(lon2)
        lat2 = np.radians(lat2)
        dlon = lon2 - lon1   
        dlat = lat2 - lat1   
        a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2  
        c = 2 * np.arcsin(np.sqrt(a))   
        k = np.degrees(c)
        return k # }}}
    def __init__(self,fitsfile):
        """Open the background map *fitsfile* and reset every shape to its defaults.

        Sets up:

        * ``self.residHDU`` / ``self.wcsobj`` - the opened HDU list (kept open on
          the instance) and the WCS built from the same file;
        * ``self.wcsobjcoortype`` - ``"J2000"`` if the primary header's CTYPE1
          contains ``RA``, ``"galactic"`` if it contains ``GLON``; anything else
          prints a message and terminates the program with exit status 1;
        * ``self.ts_bin`` / ``self.tsy_bin`` - absolute CDELT1 / CDELT2;
        * default parameter dicts for every shape type, the default contour
          spec, the legend counter and an empty ``self.plotsetting``.
        """
        self.residHDU = pyfits.open(fitsfile)  # {{{
        self.wcsobj = wcs.WCS(fitsfile)
        # Always read the header from the instance's own HDU list.
        header = self.residHDU[0].header
        #### check the coor type of fits file.
        ctype1 = header['CTYPE1']
        if 'RA' in ctype1:
            self.wcsobjcoortype = "J2000"
        elif 'GLON' in ctype1:
            self.wcsobjcoortype = 'galactic'
        else:
            print('please Check the map, the coordinate system is uncertanty.')
            # sys.exit is always available (unlike the site-provided exit());
            # a non-zero status signals that this is an error.
            sys.exit(1)
        self.ts_bin = np.abs(header['CDELT1'])
        self.tsy_bin = np.abs(header['CDELT2'])
        self.rectangle = {"para":[0,0,0,0],"coor_type":"J2000"}
        self.circle = {"para":[0,0,0],"coor_type":"J2000"}
        self.point = {"para":[0,0],"coor_type":"J2000"}
        self.ellipse = {"para":[0,0,0,0,0],"coor_type":"J2000"}
        self.rectangle2 = {"para":[0,0,0,0,0],"coor_type":"J2000"}
        self.arrow = {"para":[0,0,0,0],"coor_type":"J2000"}
        self.counter = {"file_path":fitsfile,"levels":["location","95"],"smoothness":False}
        self.is_label_plot = 0
        self.coorsys = "J2000_to_J2000"
        self.plotsetting = {}  # }}}
    def get_coorsys(self,a,b): 
        self.coorsys = '%s_to_%s'%(a,b)
    def init_plotsetting(self):
        self.plotsetting = {"islabel":False,"color":"c",\
                "marker":"o","marker_size":30,"label_name":"?",\
                "rotation":0, "zorder":0,"line_width":1,"line_style":"solid",\
                "marker_full":True,"label":[],"label_color":"c",\
                "label_size":30,"label_font":"bold","label_line":False}
    def set_plot_para(self,plot_dict):
        self.init_plotsetting()
        self.plotsetting.update(plot_dict)
        if "label_color" not in plot_dict:
            self.plotsetting["label_color"]=self.plotsetting["color"]
    def set_shape_para(self,shape_type,shape_dict):
        if shape_type =="rectangle":# {{{
            self.rectangle.update(shape_dict)
        elif shape_type == "circle":
            self.circle.update(shape_dict)
        elif shape_type == "point":
            self.point.update(shape_dict)
        elif shape_type == "ellipse":
            self.ellipse.update(shape_dict)
        elif shape_type == "rectangle2":
            self.rectangle2.update(shape_dict)
        elif shape_type == "counter":
            self.counter.update(shape_dict)
        elif shape_type == "arrow":
            self.arrow.update(shape_dict)
        else:
            print("The shape %s is not defined!"%shape_type)
        if shape_type == "counter":
            #### check the coor type of fits file.
            temp_fits_hdu = pyfits.open(self.counter["file_path"])
            if temp_fits_hdu[0].header['CTYPE1'].find('RA')!=-1:
                shape_dict["coor_type"]="J2000"
            elif temp_fits_hdu[0].header['CTYPE1'].find('GLON')!=-1:
                shape_dict["coor_type"]="galactic"
            else:
                print('please Check the map, the coordinate system is uncertanty.')
                exit()# 
        
        self.get_coorsys(shape_dict["coor_type"],self.wcsobjcoortype)# }}}

    def plot_shape(self, shape_type):
        if shape_type =="rectangle":#{{{
            self.plot_rectangle()
        elif shape_type == "circle":
            self.plot_circle()
        elif shape_type == "point":
            self.plot_point()
        elif shape_type == "ellipse":
            self.plot_ellipse()
        elif shape_type == "rectangle2":
            self.plot_rectangle2()
        elif shape_type == "counter":
            self.plot_counter()
        elif shape_type == "arrow":
            self.plot_arrow()
        else:
            print("The shape %s is not defined!"%shape_type)#}}}
    def plot_point(self):
        """Draw the marker described by ``self.point`` on the current pyplot axes.

        ``self.point["para"]`` is ``[x, y]`` in the shape's own frame.  It is
        converted to the map frame according to ``self.coorsys`` and mapped to
        pixels with ``self.wcsobj``.  Appearance comes from ``self.plotsetting``:

        * ``islabel`` true: the marker gets a legend label and
          ``self.is_label_plot`` is incremented;
        * otherwise, if ``label`` holds ``[x, y]`` (map-frame world coordinates),
          ``label_name`` is written there, optionally joined to the marker by a
          leader line when ``label_line`` is true.
        """
        coor_x, coor_y = self.point["para"]  # {{{
        if self.coorsys == 'galactic_to_J2000':
            fk5 = SkyCoord(coor_x*u.deg, coor_y*u.deg, frame='galactic').fk5
            coor_x, coor_y = fk5.ra.value, fk5.dec.value
        elif self.coorsys == 'J2000_to_galactic':
            galactic = SkyCoord(coor_x*u.deg, coor_y*u.deg, frame='fk5').galactic
            coor_x, coor_y = galactic.l.value, galactic.b.value
        setting = self.plotsetting
        #### plot position
        coor_x_pixel, coor_y_pixel = self.wcsobj.wcs_world2pix(coor_x, coor_y, 1)

        scatter_kwargs = {
            'marker': setting['marker'],
            's': setting['marker_size']*20,
            'zorder': setting['zorder'],
        }
        if setting['marker_full']:
            scatter_kwargs['color'] = setting['color']
        else:
            # hollow marker: only the outline is coloured
            scatter_kwargs.update(facecolors='none', edgecolors=setting['color'],
                                  linewidth=0.8*setting['line_width'])
        if setting['islabel']:
            self.is_label_plot += 1
            scatter_kwargs['label'] = setting['label_name']
        # shifted by -0.5 pixel, like the other shapes drawn by this class
        plt.scatter(coor_x_pixel-0.5, coor_y_pixel-0.5, **scatter_kwargs)

        if setting['islabel'] or len(setting['label']) != 2:
            return
        label_x_pixel, label_y_pixel = self.wcsobj.wcs_world2pix(
            setting['label'][0], setting['label'][1], 1)
        plt.text(label_x_pixel, label_y_pixel, setting['label_name'],
                 color=setting['label_color'], fontsize=setting['label_size'],
                 fontweight=setting['label_font'], rotation=setting["rotation"],
                 verticalalignment='bottom' if setting["rotation"] > 0 else 'top')
        # The leader line needs the label position, so it is only drawn when a
        # label exists (label_line alone used to raise UnboundLocalError).
        if setting["label_line"]:
            plt.plot([coor_x_pixel, label_x_pixel], [coor_y_pixel, label_y_pixel],
                     color=setting['label_color'],
                     linewidth=0.8*setting['line_width'], ls="-")  # }}}

    def plot_ellipse(self):
        """Draw the (rotated) ellipse described by ``self.ellipse``.

        ``self.ellipse["para"]`` is ``[x, y, a, b, alpha]``.  ``x``/``y`` is the
        centre in the shape's own frame (``self.ellipse["coor_type"]``).
        ``a``/``b`` are the semi-axes in degrees and ``alpha`` is the rotation
        in degrees.

        The outline is sampled in a helper ARC-projected WCS that is centred on
        the ellipse and has pixel size ``self.ts_bin``.  It is then converted to
        the map frame according to ``self.coorsys`` and plotted in map pixels.

        Labelling follows ``self.plotsetting``: a legend entry when ``islabel``
        is set; otherwise optional text at ``label`` and an optional leader line.
        An unrecognised ``self.coorsys`` prints a warning and draws nothing.
        """
        ### plot the ellipse #{{{
        para = self.ellipse["para"]
        coor_x, coor_y = para[0], para[1]
        a_deg, b_deg, alpha_deg = para[2], para[3], para[4]
        aux_hdu = pyfits.PrimaryHDU()
        header = aux_hdu.header
        header['BITPIX'] = -32
        header['NAXIS'] = 2
        header['NAXIS1'] = 21
        header['NAXIS2'] = 21
        if self.ellipse["coor_type"] == 'J2000':
            header['CTYPE1'] = ('RA---ARC', 'Coordinate Type')
            header['CTYPE2'] = ('DEC--ARC', 'Coordinate Type')
        else:
            header['CTYPE1'] = ('GLON-ARC', 'Coordinate Type')
            header['CTYPE2'] = ('GLAT-ARC', 'Coordinate Type')
        header['CRVAL1'] = coor_x
        header['CRPIX1'] = 11
        header['CDELT1'] = -self.ts_bin
        header['CRVAL2'] = coor_y
        header['CRPIX2'] = 11
        header['CDELT2'] = self.ts_bin
        coor_obj = wcs.WCS(header)
        # Semi-axes in helper-WCS pixels: use the same scale as CDELT above so
        # degrees map to the right number of pixels.
        a = a_deg/self.ts_bin
        b = b_deg/self.ts_bin
        alpha = np.radians(alpha_deg)
        centre_x, centre_y = coor_obj.wcs_world2pix(coor_x, coor_y, 1)
        # append 2*pi so the sampled outline closes on itself
        theta = np.append(np.arange(0, 2*np.pi, 0.02), 2*np.pi)
        r_x = centre_x + (a*np.cos(theta)*np.cos(alpha) - b*np.sin(theta)*np.sin(alpha))
        r_y = centre_y - (b*np.sin(theta)*np.cos(alpha) + a*np.cos(theta)*np.sin(alpha))
        coor_x_set, coor_y_set = coor_obj.wcs_pix2world(r_x, r_y, 1)

        if self.coorsys in ('J2000_to_J2000', 'galactic_to_galactic'):
            map_x, map_y = coor_x_set, coor_y_set
        elif self.coorsys == 'J2000_to_galactic':
            galactic = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='fk5').galactic
            map_x, map_y = galactic.l.value, galactic.b.value
        elif self.coorsys == 'galactic_to_J2000':
            fk5 = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='galactic').fk5
            map_x, map_y = fk5.ra.value, fk5.dec.value
        else:
            print('Please check the coordinate transfer parameter "coorsys" ')
            return
        map_x_pixel, map_y_pixel = self.wcsobj.wcs_world2pix(map_x, map_y, 1)
        r_x, r_y = map_x_pixel-0.5, map_y_pixel-0.5

        setting = self.plotsetting
        if setting['islabel']:
            self.is_label_plot += 1
            plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                     linewidth=setting['line_width'], label=setting['label_name'])
            return
        if len(setting['label']) == 2:
            label_x_pixel, label_y_pixel = self.wcsobj.wcs_world2pix(
                setting['label'][0], setting['label'][1], 1)
            plt.text(label_x_pixel, label_y_pixel, setting['label_name'],
                     color=setting['label_color'], fontsize=setting['label_size'],
                     fontweight=setting['label_font'], rotation=setting["rotation"],
                     verticalalignment='bottom' if setting["rotation"] > 0 else 'top')
            # Leader line only makes sense when a label position exists.
            if setting["label_line"]:
                # NOTE(review): the centre is in the shape's frame but mapped with
                # the map WCS; exact only when both frames agree.
                anchor_x, anchor_y = self.wcsobj.wcs_world2pix(coor_x, coor_y, 1)
                plt.plot([anchor_x, label_x_pixel], [anchor_y, label_y_pixel],
                         color=setting['label_color'],
                         linewidth=0.8*setting['line_width'], ls="-")
        plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                 linewidth=setting['line_width'])  # }}}

    def plot_rectangle2(self):
        """Draw the rotated rectangle described by ``self.rectangle2``.

        ``self.rectangle2["para"]`` is ``[x, y, a, b, alpha]``.  ``x``/``y`` is
        the centre in the shape's own frame (``self.rectangle2["coor_type"]``).
        ``a``/``b`` are half-sizes in degrees and ``alpha`` is the rotation in
        degrees.

        The four corners are laid out in a helper AIT-projected WCS that is
        centred on the rectangle and has pixel size ``self.ts_bin``.  They are
        then converted to the map frame according to ``self.coorsys`` and
        joined in map pixels.

        Labelling follows ``self.plotsetting``: a legend entry when ``islabel``
        is set; otherwise optional text at ``label`` and an optional leader line
        to the centre.  An unrecognised ``self.coorsys`` prints a warning and
        draws nothing.
        """
        ### plot the rotated rectangle #{{{
        para = self.rectangle2["para"]
        coor_x, coor_y = para[0], para[1]
        a_deg, b_deg, alpha_deg = para[2], para[3], para[4]
        aux_hdu = pyfits.PrimaryHDU()
        header = aux_hdu.header
        header['BITPIX'] = -32
        header['NAXIS'] = 2
        header['NAXIS1'] = 21
        header['NAXIS2'] = 21
        if self.rectangle2["coor_type"] == 'J2000':
            header['CTYPE1'] = ('RA---AIT', 'Coordinate Type')
            header['CTYPE2'] = ('DEC--AIT', 'Coordinate Type')
        else:
            header['CTYPE1'] = ('GLON-AIT', 'Coordinate Type')
            header['CTYPE2'] = ('GLAT-AIT', 'Coordinate Type')
        header['CRVAL1'] = coor_x
        header['CRPIX1'] = 11
        header['CDELT1'] = -self.ts_bin
        header['CRVAL2'] = coor_y
        header['CRPIX2'] = 11
        header['CDELT2'] = self.ts_bin
        coor_obj = wcs.WCS(header)
        # Half-sizes in helper-WCS pixels: same scale as CDELT above.
        a = a_deg/self.ts_bin
        b = b_deg/self.ts_bin
        alpha = np.radians(alpha_deg)
        centre_x, centre_y = coor_obj.wcs_world2pix(coor_x, coor_y, 1)
        # closed path through the unit-square corners (-1,-1) (1,-1) (1,1) (-1,1) (-1,-1)
        corner_u = np.array([-1, 1, 1, -1, -1])
        corner_v = np.array([-1, -1, 1, 1, -1])
        r_x = centre_x + (a*corner_u*np.cos(alpha) - b*corner_v*np.sin(alpha))
        r_y = centre_y - (b*corner_v*np.cos(alpha) + a*corner_u*np.sin(alpha))
        coor_x_set, coor_y_set = coor_obj.wcs_pix2world(r_x, r_y, 1)

        if self.coorsys in ('J2000_to_J2000', 'galactic_to_galactic'):
            map_x, map_y = coor_x_set, coor_y_set
        elif self.coorsys == 'J2000_to_galactic':
            galactic = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='fk5').galactic
            map_x, map_y = galactic.l.value, galactic.b.value
        elif self.coorsys == 'galactic_to_J2000':
            fk5 = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='galactic').fk5
            map_x, map_y = fk5.ra.value, fk5.dec.value
        else:
            print('Please check the coordinate transfer parameter "coorsys" ')
            return
        map_x_pixel, map_y_pixel = self.wcsobj.wcs_world2pix(map_x, map_y, 1)
        r_x, r_y = map_x_pixel-0.5, map_y_pixel-0.5

        setting = self.plotsetting
        if setting['islabel']:
            self.is_label_plot += 1
            plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                     linewidth=setting['line_width'], label=setting['label_name'])
            return
        if len(setting['label']) == 2:
            label_x_pixel, label_y_pixel = self.wcsobj.wcs_world2pix(
                setting['label'][0], setting['label'][1], 1)
            plt.text(label_x_pixel, label_y_pixel, setting['label_name'],
                     color=setting['label_color'], fontsize=setting['label_size'],
                     fontweight=setting['label_font'], rotation=setting["rotation"],
                     verticalalignment='bottom' if setting["rotation"] > 0 else 'top')
            # Leader line only makes sense when a label position exists.
            if setting["label_line"]:
                # NOTE(review): the centre is in the shape's frame but mapped with
                # the map WCS; exact only when both frames agree.
                anchor_x, anchor_y = self.wcsobj.wcs_world2pix(coor_x, coor_y, 1)
                plt.plot([anchor_x, label_x_pixel], [anchor_y, label_y_pixel],
                         color=setting['label_color'],
                         linewidth=0.8*setting['line_width'], ls="-")
        plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                 linewidth=setting['line_width'])  # }}}

    def plot_counter(self):
        """Overlay contours of the FITS map ``self.counter["file_path"]``.

        Every pixel of that map is converted to world coordinates with the map's
        own WCS.  It is then transformed to the background map's frame according
        to ``self.coorsys`` and projected onto ``self.wcsobj`` pixels, so the
        contour map may use a different frame and grid.

        If ``self.counter['smoothness']`` is true, the data are first
        Gaussian-smoothed with ``self.counter['smoothkernel']``.  Contour levels
        come from ``self.counter['levels']`` (see ``_counter_levels``).  With
        ``plotsetting['islabel']`` a legend entry is added as well.
        """
        #### plot contour  #{{{
        with pyfits.open(self.counter["file_path"]) as residHDU_con:
            header_con = residHDU_con[0].header
            right_con = header_con['NAXIS1']
            top_con = header_con['NAXIS2']
            # copy so the data stays valid once the file is closed
            contour_data = np.array(residHDU_con[0].data)
        wcsobj_con = wcs.WCS(self.counter["file_path"])

        ### 1-based (FITS) pixel indices of every contour-map pixel.
        X_con, Y_con = np.meshgrid(np.arange(1, right_con + 1), np.arange(1, top_con + 1))
        coor_x_set, coor_y_set = wcsobj_con.wcs_pix2world(X_con, Y_con, 1)
        if self.coorsys == 'J2000_to_galactic':
            galactic = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='fk5').galactic
            coor_x_set, coor_y_set = galactic.l.value, galactic.b.value
        elif self.coorsys == 'galactic_to_J2000':
            fk5 = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='galactic').fk5
            coor_x_set, coor_y_set = fk5.ra.value, fk5.dec.value
        x_set, y_set = self.wcsobj.wcs_world2pix(coor_x_set, coor_y_set, 1)

        if self.counter['smoothness']:
            contour_data = ndimage.gaussian_filter(contour_data, self.counter['smoothkernel'])
        config_plot_levels = self._counter_levels(contour_data)

        setting = self.plotsetting
        labelled = setting['islabel']
        if labelled:
            self.is_label_plot += 1
        plt.contour(x_set-0.5, y_set-0.5, contour_data, colors=setting['color'],
                    levels=config_plot_levels, linestyles=setting['line_style'],
                    linewidths=setting['line_width'])
        if labelled:
            # proxy line at (0, 0) whose only job is to carry the legend label
            plt.plot(0, 0, color=setting['color'], ls=setting['line_style'],
                     label=setting['label_name'])
        print("name:%s; contours levels:%s"%(setting['label_name'], self.counter['levels']))

    def _counter_levels(self, contour_data):
        """Translate the ``self.counter['levels']`` spec into explicit contour levels.

        The first element of the spec selects the mode:

        * ``['set_level_number', lo, hi, n]``: ``n`` levels evenly spaced from
          ``lo`` to ``hi``;
        * ``['set_level_time', lo, hi, n]``: ``hi/n, 2*hi/n, ..., hi`` (``lo`` is
          ignored);
        * ``['set_equal_interval', lo, hi, d]``: ``np.arange(lo, hi, d)``;
        * ``['location', conf]`` with ``conf`` '68', '95' or '99':
          ``[peak - offset, peak]``;
        * anything else is returned unchanged and used as the level list.

        ``lo``/``hi`` may be ``'default'`` to take the min/max of the non-NaN
        data.  In ``set_level_number`` and ``set_equal_interval`` a ``hi`` below
        ``lo`` is replaced by ``3*lo``.

        Raises:
            ValueError: a ``'location'`` spec names an unsupported confidence.
        """
        levels = self.counter['levels']
        mode = levels[0]

        def finite_values():
            # NaN pixels must not poison min/max.
            return contour_data[~np.isnan(contour_data)]

        def resolve(value, reducer):
            # 'default' -> bound taken from the data, otherwise the configured value.
            if value == 'default':
                value = reducer(finite_values())
            return float(value)

        if mode == 'set_level_number':
            level_number = int(float(levels[3]))
            min_data = resolve(levels[1], np.min)
            max_data = resolve(levels[2], np.max)
            if max_data < min_data:
                max_data = 3*min_data
            return np.linspace(min_data, max_data, level_number).tolist()
        if mode == 'set_level_time':
            level_number = int(float(levels[3]))
            max_data = resolve(levels[2], np.max)
            min_data = max_data/level_number
            print("min_data is %s"%min_data)
            return (min_data*(np.arange(level_number)+1)).tolist()
        if mode == 'set_equal_interval':
            delta_data = float(levels[3])
            min_data = resolve(levels[1], np.min)
            max_data = resolve(levels[2], np.max)
            if max_data < min_data:
                max_data = 3*min_data
            return np.arange(min_data, max_data, delta_data).tolist()
        if mode == 'location':
            # Offsets below the peak.  NOTE(review): these look like 2-dof
            # delta-chi^2 / delta-TS thresholds (68% ~ 2.30, 95% ~ 5.99,
            # 99% ~ 9.21); confirm that 9.1 for '99' is intended.
            offsets = {'68': 2.3, '95': 6, '99': 9.1}
            confidence = str(levels[1])
            if confidence not in offsets:
                raise ValueError("Unsupported 'location' contour level %r; expected one of %s"
                                 % (levels[1], sorted(offsets)))
            max_data = float(np.max(finite_values()))
            return [max_data-offsets[confidence], max_data]
        return levels  # }}}

    def plot_rectangle(self):
        """Draw the axis-aligned rectangle described by ``self.rectangle``.

        ``self.rectangle["para"]`` is ``[x, y, half_width, half_height]``.
        ``x``/``y`` is the centre in the shape's own frame
        (``self.rectangle["coor_type"]``), and the half-sizes are in degrees.

        The corners are laid out in a helper AIT-projected WCS that is centred
        on the rectangle and has pixel size ``self.ts_bin``.  They are then
        converted to the map frame according to ``self.coorsys`` and joined in
        map pixels.

        Labelling follows ``self.plotsetting``: a legend entry when ``islabel``
        is set; otherwise optional text at ``label``.  An optional leader line
        is attached to the top/bottom edge when the label lies above/below the
        rectangle, and to the centre otherwise.  An unrecognised
        ``self.coorsys`` prints a warning and draws nothing.
        """
        ### plot the rectangle  #{{{
        para = self.rectangle["para"]
        coor_x, coor_y = para[0], para[1]
        a_size, b_size = para[2], para[3]
        aux_hdu = pyfits.PrimaryHDU()
        header = aux_hdu.header
        header['BITPIX'] = -32
        header['NAXIS'] = 2
        header['NAXIS1'] = 21
        header['NAXIS2'] = 21
        if self.rectangle["coor_type"] == 'J2000':
            header['CTYPE1'] = ('RA---AIT', 'Coordinate Type')
            header['CTYPE2'] = ('DEC--AIT', 'Coordinate Type')
        else:
            header['CTYPE1'] = ('GLON-AIT', 'Coordinate Type')
            header['CTYPE2'] = ('GLAT-AIT', 'Coordinate Type')
        header['CRVAL1'] = coor_x
        header['CRPIX1'] = 11
        header['CDELT1'] = -self.ts_bin
        header['CRVAL2'] = coor_y
        header['CRPIX2'] = 11
        header['CDELT2'] = self.ts_bin
        coor_obj = wcs.WCS(header)
        centre_x, centre_y = coor_obj.wcs_world2pix(coor_x, coor_y, 1)
        # Half-sizes in helper-WCS pixels: same scale as CDELT above.
        half_w = a_size/self.ts_bin
        half_h = b_size/self.ts_bin
        x_min, x_max = centre_x - half_w, centre_x + half_w
        y_min, y_max = centre_y - half_h, centre_y + half_h
        r_x = [x_min, x_max, x_max, x_min, x_min]
        r_y = [y_max, y_max, y_min, y_min, y_max]
        coor_x_set, coor_y_set = coor_obj.wcs_pix2world(r_x, r_y, 1)

        if self.coorsys in ('J2000_to_J2000', 'galactic_to_galactic'):
            map_x, map_y = coor_x_set, coor_y_set
        elif self.coorsys == 'J2000_to_galactic':
            galactic = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='fk5').galactic
            map_x, map_y = galactic.l.value, galactic.b.value
        elif self.coorsys == 'galactic_to_J2000':
            fk5 = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='galactic').fk5
            map_x, map_y = fk5.ra.value, fk5.dec.value
        else:
            print('Please check the coordinate transfer parameter "coorsys" ')
            return
        map_x_pixel, map_y_pixel = self.wcsobj.wcs_world2pix(map_x, map_y, 1)
        r_x, r_y = map_x_pixel-0.5, map_y_pixel-0.5

        setting = self.plotsetting
        if setting['islabel']:
            self.is_label_plot += 1
            plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                     linewidth=setting['line_width'], label=setting['label_name'])
            return
        if len(setting['label']) == 2:
            label_x_pixel, label_y_pixel = self.wcsobj.wcs_world2pix(
                setting['label'][0], setting['label'][1], 1)
            plt.text(label_x_pixel, label_y_pixel, setting['label_name'],
                     color=setting['label_color'], fontsize=setting['label_size'],
                     fontweight=setting['label_font'], rotation=setting["rotation"],
                     verticalalignment='bottom' if setting["rotation"] > 0 else 'top')
            # Leader line only makes sense when a label position exists.
            if setting["label_line"]:
                # The rectangle's vertical half-size is b_size (there is no
                # `size` in this method).
                anchor_lat = coor_y
                if setting['label'][1] >= coor_y+b_size:
                    anchor_lat = coor_y+b_size
                if setting['label'][1] <= coor_y-b_size:
                    anchor_lat = coor_y-b_size
                # NOTE(review): anchor is in the shape's frame but mapped with the
                # map WCS; exact only when both frames agree.
                anchor_x, anchor_y = self.wcsobj.wcs_world2pix(coor_x, anchor_lat, 1)
                plt.plot([anchor_x, label_x_pixel], [anchor_y, label_y_pixel],
                         color=setting['label_color'],
                         linewidth=0.8*setting['line_width'], ls="-")
        plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                 linewidth=setting['line_width'])  # }}}

    def plot_circle(self):
        """Draw the circle described by ``self.circle``.

        ``self.circle["para"]`` is ``[x, y, radius]``.  ``x``/``y`` is the centre
        in the shape's own frame (``self.circle["coor_type"]``), and the radius
        is in degrees.

        The outline is sampled in a helper ARC-projected WCS that is centred on
        the circle and has pixel size ``self.ts_bin``.  It is then converted to
        the map frame according to ``self.coorsys`` and plotted in map pixels.

        Labelling follows ``self.plotsetting``: a legend entry when ``islabel``
        is set; otherwise optional text at ``label``.  An optional leader line
        is attached to the top/bottom of the circle when the label lies
        above/below it, and to the centre otherwise.  An unrecognised
        ``self.coorsys`` prints a warning and draws nothing.
        """
        ### plot the circle #{{{
        para = self.circle["para"]
        coor_x, coor_y, size = para[0], para[1], para[2]
        aux_hdu = pyfits.PrimaryHDU()
        header = aux_hdu.header
        header['BITPIX'] = -32
        header['NAXIS'] = 2
        header['NAXIS1'] = 21
        header['NAXIS2'] = 21
        if self.circle["coor_type"] == 'J2000':
            header['CTYPE1'] = ('RA---ARC', 'Coordinate Type')
            header['CTYPE2'] = ('DEC--ARC', 'Coordinate Type')
        else:
            header['CTYPE1'] = ('GLON-ARC', 'Coordinate Type')
            header['CTYPE2'] = ('GLAT-ARC', 'Coordinate Type')
        header['CRVAL1'] = coor_x
        header['CRPIX1'] = 11
        header['CDELT1'] = -self.ts_bin
        header['CRVAL2'] = coor_y
        header['CRPIX2'] = 11
        header['CDELT2'] = self.ts_bin
        angle_step = 0.02
        coor_obj = wcs.WCS(header)
        centre_x, centre_y = coor_obj.wcs_world2pix(coor_x, coor_y, 1)
        # Radius in helper-WCS pixels: same scale as CDELT above.
        radius_pixel = size/self.ts_bin
        # append 2*pi so the sampled outline closes on itself
        theta = np.append(np.arange(0, 2*np.pi, angle_step), 2*np.pi)
        r_x = centre_x + radius_pixel*np.cos(theta)
        r_y = centre_y + radius_pixel*np.sin(theta)
        coor_x_set, coor_y_set = coor_obj.wcs_pix2world(r_x, r_y, 1)

        if self.coorsys in ('J2000_to_J2000', 'galactic_to_galactic'):
            map_x, map_y = coor_x_set, coor_y_set
        elif self.coorsys == 'J2000_to_galactic':
            galactic = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='fk5').galactic
            map_x, map_y = galactic.l.value, galactic.b.value
        elif self.coorsys == 'galactic_to_J2000':
            fk5 = SkyCoord(coor_x_set*u.deg, coor_y_set*u.deg, frame='galactic').fk5
            map_x, map_y = fk5.ra.value, fk5.dec.value
        else:
            print('Please check the coordinate transfer parameter "coorsys" ')
            return
        map_x_pixel, map_y_pixel = self.wcsobj.wcs_world2pix(map_x, map_y, 1)
        r_x, r_y = map_x_pixel-0.5, map_y_pixel-0.5

        setting = self.plotsetting
        if setting['islabel']:
            self.is_label_plot += 1
            plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                     linewidth=setting['line_width'], label=setting['label_name'])
            return
        if len(setting['label']) == 2:
            label_x_pixel, label_y_pixel = self.wcsobj.wcs_world2pix(
                setting['label'][0], setting['label'][1], 1)
            plt.text(label_x_pixel, label_y_pixel, setting['label_name'],
                     color=setting['label_color'], fontsize=setting['label_size'],
                     fontweight=setting['label_font'], rotation=setting["rotation"],
                     verticalalignment='bottom' if setting["rotation"] > 0 else 'top')
            # Leader line only makes sense when a label position exists.
            if setting["label_line"]:
                anchor_lat = coor_y
                if setting['label'][1] >= coor_y+size:
                    anchor_lat = coor_y+size
                if setting['label'][1] <= coor_y-size:
                    anchor_lat = coor_y-size
                # NOTE(review): anchor is in the shape's frame but mapped with the
                # map WCS; exact only when both frames agree.
                anchor_x, anchor_y = self.wcsobj.wcs_world2pix(coor_x, anchor_lat, 1)
                plt.plot([anchor_x, label_x_pixel], [anchor_y, label_y_pixel],
                         color=setting['label_color'],
                         linewidth=0.8*setting['line_width'], ls="-")
        plt.plot(r_x, r_y, color=setting["color"], ls=setting['line_style'],
                 linewidth=setting['line_width'])  # }}}

    #### plot arrow 
    def plot_arrow(self):
        """Draw an arrow, and optionally a text label, on the current axes.#{{{

        The arrow starts at ``self.arrow["para"][0:2]`` and ends at
        ``self.arrow["para"][2:4]`` (lon, lat pairs in degrees). Both ends are
        converted according to ``self.coorsys``:
        'J2000_to_J2000' / 'galactic_to_galactic' (no conversion),
        'J2000_to_galactic' (fk5 -> galactic) or 'galactic_to_J2000'
        (galactic -> fk5). They are then mapped to pixels with ``self.wcsobj``
        (origin=1).

        Styling comes from ``self.plotsetting`` ('line_width', 'color'). When
        ``self.plotsetting['label']`` holds two values, they are used as the
        label position in the map's own frame (no conversion), and
        ``label_name`` is written there.

        An unknown ``coorsys`` prints a warning and draws nothing.
        """
        valid_coorsys = ('J2000_to_J2000', 'galactic_to_galactic',
                         'J2000_to_galactic', 'galactic_to_J2000')
        if self.coorsys not in valid_coorsys:
            print('Please check the coordinate transfer parameter "coorsys" ')
            # Nothing to draw: previously this fell through and crashed on
            # an empty coordinate list.
            return

        def to_map_frame(lon, lat):
            """Convert one (lon, lat) pair in degrees into the map's frame."""
            if self.coorsys == 'J2000_to_galactic':
                gal = SkyCoord(lon*u.deg, lat*u.deg, frame='fk5').galactic
                return gal.l.value, gal.b.value
            if self.coorsys == 'galactic_to_J2000':
                equ = SkyCoord(lon*u.deg, lat*u.deg, frame='galactic').fk5
                return equ.ra.value, equ.dec.value
            # Same frame on both sides: nothing to convert.
            return lon, lat

        para = self.arrow["para"]
        start_lon, start_lat = to_map_frame(para[0], para[1])
        end_lon, end_lat = to_map_frame(para[2], para[3])
        start_x, start_y = self.wcsobj.wcs_world2pix(start_lon, start_lat, 1)
        end_x, end_y = self.wcsobj.wcs_world2pix(end_lon, end_lat, 1)

        # Draw on the current pyplot axes (the same axes plt.text uses below)
        # rather than on a module-level axes object.
        plt.gca().annotate("", xy=(end_x, end_y),
               xytext=(start_x, start_y),
               size=self.plotsetting["line_width"], va="center", ha="center",
               arrowprops=dict(color=self.plotsetting["color"], arrowstyle="simple"))

        label = self.plotsetting['label']
        if len(label) == 2:
            label_x_pixel, label_y_pixel = self.wcsobj.wcs_world2pix(label[0], label[1], 1)
            rotation = self.plotsetting["rotation"]
            # Rotated labels sit above their anchor, unrotated ones below it.
            plt.text(label_x_pixel, label_y_pixel, self.plotsetting['label_name'],
                color=self.plotsetting['label_color'], fontsize=self.plotsetting['label_size'],
                fontweight=self.plotsetting['label_font'],
                rotation=rotation, verticalalignment='bottom' if rotation > 0 else 'top')
                    #}}}



# Load the YAML configuration named on the command line and open the colour map.
yaml_file = sys.argv[1]
with open(yaml_file, encoding="utf-8") as config_stream:
    config = yaml.safe_load(config_stream)
residHDU = pyfits.open(config['Color_map']['map_path'])
wcsobj = wcs.WCS(config['Color_map']['map_path'])

#### check the coor type of fits file.{{{
# CTYPE1 containing 'RA' means an equatorial (J2000) map, 'GLON' a Galactic one.
ctype1 = residHDU[0].header['CTYPE1']
if 'RA' in ctype1:
    config['Color_map']['coor_type']='J2000'
elif 'GLON' in ctype1:
    config['Color_map']['coor_type']='galactic'
else:
    print('please Check the map, the coordinate system is uncertanty.')
    # Non-zero status so callers/shell scripts can detect the failure.
    sys.exit(1)# }}}

### Decide coordinate range.
# Absolute pixel sizes (degrees per pixel, taken from CDELT1/CDELT2).
ts_bin = np.abs(residHDU[0].header['CDELT1'])
tsy_bin = np.abs(residHDU[0].header['CDELT2'])
left,right,bottom,top = 0,residHDU[0].header['NAXIS1'],0,residHDU[0].header['NAXIS2']

# World coordinates of the map centre: with origin=1 the centre of an
# N-pixel axis is at pixel 0.5*N + 0.5.
coor_x_ref,coor_y_ref = wcsobj.wcs_pix2world(0.5*right+0.5,0.5*top+0.5,1)
# Width/height ratio of the map on the sky; the longitude extent is scaled
# by cos(latitude) at the map centre.
figure_x2y = float(right*ts_bin*np.cos(np.radians(coor_y_ref)))/float(top*tsy_bin)

# ROI centre in J2000 (RA, Dec): taken from the config when given, otherwise
# the map centre (converted from Galactic when the map is Galactic).
if 'ROI_coor' in config['Output_map']:# {{{
    roi_ra = config['Output_map']['ROI_coor'][0]
    roi_dec = config['Output_map']['ROI_coor'][1]
elif config['Color_map']['coor_type']=='galactic':
    # For a Galactic map the centre is (l, b).
    roi_l,roi_b = coor_x_ref,coor_y_ref
    l_b = SkyCoord(roi_l*u.deg,roi_b*u.deg,frame='galactic')
    roi_ra,roi_dec=l_b.fk5.ra.value,l_b.fk5.dec.value
else:
    roi_ra,roi_dec = coor_x_ref,coor_y_ref # }}}
# Font size shared by the axis labels, tick labels and colour-bar text below.
word_size = 30 

### Plot Color map.
# data_set aliases the HDU data, so the in-place edits below modify it too.
data_set = residHDU[0].data
if config['Color_map'].get("color_bar_label")=="$\\sqrt{\\rm TS}$":# {{{
    # Signed square root: -sqrt(-x) for x<0, sqrt(x) for x>0, 0 and NaN kept.
    data_set[...] = np.sign(data_set)*np.sqrt(np.abs(data_set))# }}}
if "Mask_map" in config: 
    # Copy the mask out so the FITS file can be closed right away.# {{{
    with pyfits.open(config['Mask_map']['mask_map_path']) as mask_hdul:
        mask_array = np.array(mask_hdul[0].data)
    # Pixels where the mask is 0 are replaced by mask_value (default NaN).
    sub_mask = np.where(mask_array==0)
    data_set[sub_mask] = config['Mask_map'].get('mask_value', np.nan) # }}}

# Layout: row 0 colour bar, row 1 sky map, row 2 counts histogram.
gs = GridSpec(3,1,height_ratios=[1,10,9],wspace=0) 
fig = plt.figure(figsize=(figure_x2y*10,18),dpi=100)
# Separate header for the WCSAxes projection with CRPIX shifted by half a
# pixel, presumably to match extent=[0, NAXIS1, 0, NAXIS2] used by imshow below.
resid_plot = pyfits.open(config['Color_map']['map_path'])
resid_plot[0].header["CRPIX1"]= residHDU[0].header["CRPIX1"]+0.5
resid_plot[0].header["CRPIX2"]= residHDU[0].header["CRPIX2"]+0.5
wcsobj_plot = wcs.WCS(resid_plot[0].header)
ax1=fig.add_subplot(gs[1,0],projection=wcsobj_plot)
#plt.xlabel('R.A.',fontsize=word_size)
plt.ylabel('Decl.',fontsize=word_size)

# Optional Gaussian smoothing; a missing 'smoothness' key means no smoothing.
if config['Color_map'].get('smoothness', False):
    # Zero the NaNs so the filter does not spread them, then restore them.#{{{
    sub = np.where(np.isnan(data_set))
    data_set[sub] = 0
    data_set = ndimage.gaussian_filter(data_set,config['Color_map']['smoothkernel'])# Smooth the data
    data_set[sub] = np.nan#}}}

mycolor = ['black','midnightblue','blue','firebrick','red',"orange",'gold','lightyellow',"whitesmoke"]
cmap_color_def = colors.LinearSegmentedColormap.from_list('my_list',mycolor)
color_bar_set = cmap_color_def


ax1_temp = ax1.imshow(data_set,aspect='auto',cmap=color_bar_set,extent=[left,right,bottom,top])

ax1.text(55,70, config["Output_map"]["title"][0],fontsize=float(config['Output_map']['title'][1]),color="w")
#ax1.text(70,70, config["Output_map"]["title"][0],fontsize=float(config['Output_map']['title'][1]),color="w")


# Overlay every configured source shape on the map.
plotshape_obj = MyShape(config['Color_map']['map_path'])
for value in config["source_dict"].values():
    plotshape_obj.set_shape_para(value["shape_type"],value["shape_para"])
    plotshape_obj.set_plot_para(value["plot_setting"])
    plotshape_obj.plot_shape(value["shape_type"])

### plot a guide line of constant Galactic latitude b = -18.5 deg (l = 134..144 deg)
lon_coor = np.linspace(134,144,30)
lat_coor = -18.5 
l_b = SkyCoord(lon_coor*u.deg,lat_coor*u.deg,frame='galactic')
ra,dec = l_b.fk5.ra.value,l_b.fk5.dec.value
if config['Color_map']['coor_type']=='galactic':
    # A Galactic map's WCS takes (l, b) directly.
    pixel_x,pixel_y = wcsobj.wcs_world2pix(lon_coor,np.full_like(lon_coor,lat_coor),1)
else:
    pixel_x,pixel_y = wcsobj.wcs_world2pix(ra,dec,1)
plt.plot(pixel_x,pixel_y,ls="--",color="darkgray",linewidth=3)
# Raw string: '\c' is an invalid escape sequence in a normal string literal.
plt.text(pixel_x[4],pixel_y[4]+2.5,r"b=-18.5$^\circ$",fontsize=20,color="darkgray",rotation=12)



# Map panel: hide the bottom (first world) axis ticks and labels; the
# histogram panel below carries the R.A. axis instead.
ax1.tick_params(axis='both',which='major',labelsize=word_size)
#ax1.coords[0].set_major_formatter('d')
ax1.coords[0].set_ticklabel_visible(False)
ax1.coords[0].set_ticks_visible(False)

plt.xlim(left,right)
plt.ylim(bottom,top)
if plotshape_obj.is_label_plot != 0 :
    label_control = config['Output_map']['label_control']
    plt.legend(loc=label_control['label_position']
            ,fontsize=label_control['fontsize']
            ,ncol=label_control['label_column'])

# Horizontal colour bar in the top GridSpec row, ticks drawn on top.
cax1 = fig.add_subplot(gs[0,0])
cbar=plt.colorbar(ax1_temp, cax=cax1, pad=0.01,orientation="horizontal",ticklocation="top")
cbar.ax.tick_params(labelsize=word_size)
# 'color_bar_label' is optional in the config (see the sqrt(TS) check above).
cbar.set_label(config['Color_map'].get('color_bar_label',''),fontsize=word_size)




# Bottom panel: counts histogram read from a whitespace-separated table.
# Columns used: 0 bin centre, 1 bin half-width, 2 counts, 3 counts error,
# 4.. one model component per entry of label_each.
ax2=fig.add_subplot(gs[2,0])
data2 = np.genfromtxt(config["Counts_hist"]["file_path"])
#data2 = np.genfromtxt("25_100TeV_hist_data.txt")
xx_mid = data2[:,0]
xx_err = data2[:,1]
yy_bin = data2[:,2]
yy_bin_err = data2[:,3]
plt.errorbar(xx_mid,yy_bin,xerr=xx_err,yerr=yy_bin_err,marker='o',capsize=3,color='k',ls='None')
label_each = ["J0216+4239","J0207+4300","rectangle"]
color_each = ["b","c","orange"]
for i,(label,color) in enumerate(zip(label_each,color_each)):
    plt.bar(xx_mid,data2[:,4+i],width=0.4,color=color,label=label,alpha=0.5)
# Sum of all model components drawn above.
yy_bin_total = data2[:,4:4+len(label_each)].sum(axis=1)
plt.step(xx_mid+0.2,yy_bin_total,color="r",label="total")

plt.legend(loc=0,fontsize=config['Output_map']['label_control']['fontsize'],ncol=2)
plt.xlim(config["Counts_hist"]["x_lim"][0],config["Counts_hist"]["x_lim"][1])
plt.ylim(config["Counts_hist"]["y_lim"][0],config["Counts_hist"]["y_lim"][1])
x_set = [28,30,32,34,36]
# Raw f-string: '\c' is an invalid escape sequence in a normal string literal.
x_set_label = [rf'{x}$^\circ$' for x in x_set]
plt.xticks(x_set,x_set_label)
ax2.tick_params(axis='both',which='major',labelsize=word_size)
# R.A. increases to the left, as on the sky map.
plt.gca().invert_xaxis()

plt.ylabel('Counts',fontsize=word_size)
plt.xlabel('R.A.',fontsize=word_size)
plt.tight_layout()
#plt.subplots_adjust(wspace=0.05,hspace=0.005)
plt.savefig(config['Output_map']['name'],bbox_inches='tight')
plt.close()
print('ok')
