import numpy as np
import logging
from .spibitbang1 import spibitbang1
from queuetypes import *
from .hwdev import hwdev

def ApplyMask(value,width=8,bitoffset=0,previous=0):
    """Merge a bit field into a previous register value.

    The low ``width`` bits of ``value`` are placed at bit ``bitoffset``
    (a non-positive offset means no shift); every bit of ``previous`` that
    lies outside that field is kept unchanged.

    >>> ApplyMask(0b101, width=3, bitoffset=2, previous=0xFF)
    247
    """
    shift=bitoffset if bitoffset>0 else 0
    field=((1<<width)-1)<<shift
    placed=(value<<shift)&field
    # previous minus its field bits == previous with the field cleared
    return placed+(previous-(previous&field))
def UnMask(value,width=8,bitoffset=0):
    """Extract the ``width``-bit field that starts at bit ``bitoffset``.

    A non-positive ``bitoffset`` is treated as 0 (no shift).

    >>> UnMask(0b10110100, width=3, bitoffset=2)
    5
    """
    shift=bitoffset if bitoffset>0 else 0
    return (value>>shift)&((1<<width)-1)

def int2bytes(i):
    """Split ``i`` into a list of bytes, most significant first.

    Values <= 255 (including negative ones) are returned unchanged as a
    one-element list.

    >>> int2bytes(0x1234)
    [18, 52]
    """
    low_first=[]
    while i>255:
        low_first.append(i%256)
        i>>=8
    low_first.append(i)
    return low_first[::-1]

def GetField(D,name,dev_number,default=None):
    """Return ``D[name]`` (or ``default`` when the key is missing).

    If the stored value is a list, the entry at index ``dev_number`` is
    returned instead of the whole list; other values (including tuples)
    are returned as-is.
    """
    field=D.get(name,default)
    if isinstance(field,list):
        return field[dev_number]
    return field
def Find(L,name,value):
    """Return the first item of ``L`` whose ``item[name] == value``.

    Returns ``False`` (not ``None``) when no item matches.
    """
    return next((item for item in L if item[name]==value),False)

class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    The instance's attribute namespace *is* the dictionary itself, so
    ``d.key`` and ``d['key']`` always refer to the same entry.
    """
    def __init__(self,*args,**kwargs):
        super().__init__(*args,**kwargs)
        # Share storage between attribute access and item access.
        self.__dict__=self

def DevRegList(D):
    """Build a flat table of device registers from a configuration object.

    ``D.drivers`` is iterated with ``enumerate``; each entry must support
    item assignment and provide ``'name'`` and ``'type'``.
    ``D.device_registers`` entries provide ``'name'`` and ``'registers'``,
    and optionally ``'dim'``, ``'address'`` and ``'driver'``.
    ``'address'`` and ``'driver'`` may be per-instance lists.

    Side effect: every driver entry gets ``'drv_id'`` set to its index in
    ``D.drivers``.

    A device with ``dim`` N > 1 is expanded into N instances named
    ``<name>1`` .. ``<name>N``.

    Returns ``(devreglist, store)``:
    * ``devreglist`` maps ``"<instance>.<register>"`` to an AttrDict with
      keys Addr, Register_R, Register_W, store, devtype and devid.
    * ``store`` is the number of register entries flagged ``store``.
      Each such entry gets a 1-based running index in its ``store`` field;
      all other entries get 0.

    Raises ValueError if a device names a driver that is not in
    ``D.drivers``.
    """
    #todo only count the drivers registers!!
    for i,drv in enumerate(D.drivers):
        drv['drv_id']=i
    devreglist={}
    store=0
    for dev in D.device_registers:
        N=dev.get('dim',1)
        name=dev['name']
        for n in range(N):
            addr=GetField(dev,'address',n,0)
            drvname=GetField(dev,'driver',n)
            devid=0
            devtype=0
            if drvname:
                # Look the driver up once and fail with a clear message
                # instead of "'bool' object is not subscriptable".
                drv=Find(D.drivers,'name',drvname)
                if drv is False:
                    raise ValueError("Unknown driver %r for device %r" % (drvname,name))
                devid=drv['drv_id']
                devtype=drv['type']
            name2=name+str(n+1) if N>1 else name
            for reg in dev['registers']:
                # 'address' may be [read_register, write_register] or a
                # single value used for both.
                regR=GetField(reg,'address',0,0)
                regW=GetField(reg,'address',1,0)
                if reg.get('store',False):
                    store+=1
                    storex=store
                else:
                    storex=0
                devregname=name2+'.'+reg['name']
                devreglist[devregname]=AttrDict({"Addr":addr,"Register_R":regR,"Register_W":regW,"store":storex,"devtype":devtype,"devid":devid})
    return devreglist,store

def GetSteps(V1):
    """Return ``(Step, Step2)`` describing the data layout of variable ``V1``.

    ``Step`` is the number of device registers: ``len(V1['devreg'])`` when
    it is a list, otherwise 1.
    ``Step2`` is ``V1['dim']`` times the bytes per value (from ``width``,
    default 8 bits), integer-divided by ``Step``.
    """
    devregs=V1['devreg']
    Step=len(devregs) if isinstance(devregs,list) else 1
    dim=V1['dim']
    nbytes=(V1.get('width',8)+7)//8
    return Step,dim*nbytes//Step


class i2c_array(hwdev):
    """Array of ``N`` identical devices, accessed one element at a time.

    ``config['parameters'][0]`` .. ``config['parameters'][1]`` (inclusive)
    are the channels of the elements.  Element ``RCUi`` is selected by
    passing ``RCU_Switch1[RCUi]`` to the parent's ``SetSW1``.  Several
    elements can be selected at once through the parent's ``SetChannel``.
    The parent is ``self.conf['parentcls']``; ``self.conf`` is presumably
    set by ``hwdev.__init__`` (confirm).
    """
    def __init__(self,config):
        hwdev.__init__(self,config)
        pars=config['parameters']
        # One channel per array element (inclusive range).
        self.RCU_Switch1=range(pars[0],pars[1]+1)
        self.N=len(self.RCU_Switch1)
        # Index of the element currently being accessed.  SetVarValue and
        # GetVarValue read it for logging and for the store array.
        # Initialise it so that a first call to Setdevreg on a non-stored
        # register (which never sets it) cannot raise AttributeError.
        self.RCUi=0

    def load(self):
        """Not implemented yet; only prints a TODO marker."""
        print("TODO: load")
        #Inst1=Vars.Instr(Vars.DevType.Instr,Vars.RCU_init,0,[])  #Read the current status of GPIO IOs
        #self.SetVar(Inst1)

    def OPCUASetVariable(self,varid,var1,data,mask):
        """Write ``data`` to ``var1`` on the masked elements and read it back.

        Returns a one-element list holding the OPCUAset message with the
        read-back values and the per-value success mask.

        Returns None in two cases:
        * ``var1['rw']=='variable'`` (nothing is written);
        * ``data`` or ``mask`` does not fit the variable layout
          (see SetGetVarValueMask).
        """
        if var1['rw']=='variable': return
        logging.info(str(("Set Var",var1['name'],data,mask)))
        result=self.SetGetVarValueMask(var1,data,mask)
        if result is None:
            # Length check failed (already printed).  Unpacking None used to
            # raise TypeError here.
            logging.warning(str(("Set Var failed",var1['name'])))
            return None
        data,mask2=result
        return [OPCUAset(varid,InstType.varSet,data,mask2)]

    def OPCUAReadVariable(self,varid,var1,mask):
        """Read ``var1`` on the masked elements (empty mask = all elements).

        Returns a one-element list holding the OPCUAset message.  Returns
        None if the mask length does not fit the variable layout.
        """
        logging.info(str(("Read Var",var1['name'],mask)))
        if len(mask)==0: mask=[True]*self.N
        result=self.GetVarValueMask(var1,mask)
        if result is None:
            # Length check failed (already printed).  Unpacking None used to
            # raise TypeError here.
            logging.warning(str(("Read Var failed",var1['name'])))
            return None
        data,mask2=result
        return [OPCUAset(varid,InstType.varSet,data,mask2)]

    def i2csetget(self,*args,**kwargs):
        """Forward to the parent's ``i2csetget`` and return its result."""
        # The result was previously discarded, so callers always saw None
        # (i.e. failure).
        return self.conf['parentcls'].i2csetget(*args,**kwargs)

    def SetSwitch(self,RCUi):
        """Select the single element ``RCUi`` via the parent's SetSW1."""
        self.conf['parentcls'].SetSW1(self.RCU_Switch1[RCUi])

    def SetSwitchMask(self,mask):
        """Select every element whose ``mask`` entry is true.

        The parent's SetChannel receives a bitmask with bit
        ``RCU_Switch1[RCUi]`` set for each selected element.
        ``mask`` must have at least N entries.
        """
        m=0
        for RCUi in range(self.N):
            if mask[RCUi]: m|=1<<self.RCU_Switch1[RCUi]
        self.conf['parentcls'].SetChannel(m)

    def SetGetVarValueMask(self,var1,data,mask):
        """Write ``data`` to ``var1`` on the masked elements, then read back.

        ``data``: if it has exactly Step entries it is repeated N times.
        The result must have Step*Step2 entries (Step, Step2 from GetSteps).

        ``mask``:
        * with N entries, it is expanded into a new list with one flag per
          (element, register);
        * otherwise it must already have Step*N entries, and the caller's
          list is updated in place.

        Returns ``(values, mask)``: ``values`` holds the read-back data and
        ``mask`` is False wherever the write or the read failed.  Returns
        None (after printing a message) on a length mismatch.
        """
        Step,Step2=GetSteps(var1)
        value1=[0]*Step*Step2
        if len(data)==Step:
            data=data*self.N
        if not(len(data)==Step*Step2):
            print("Check data length!")
            return
        Step2//=self.N  # from here on: data entries per (element, register) slot
        if (len(mask)==self.N):
            mask=[m for m in mask for x in range(Step)]
        if not(len(mask)==Step*self.N):
            print("Check mask length!")
            return
        width=var1.get('width',8)
        for RCUi in range(self.N):
            for Vari in range(Step):
                k=RCUi*Step+Vari
                if not(mask[k]): continue
                i0=k*Step2
                i1=(k+1)*Step2
                devreg=var1['devreg'][Vari]
                bitoffset=GetField(var1,'bitoffset',Vari,0)
                self.SetSwitch(RCUi)
                self.RCUi=RCUi
                # SetVarValue gets a copy (slice), so 'data' is not modified.
                mask[k]=self.SetVarValue(devreg,width,bitoffset,data[i0:i1])
                if not(mask[k]): continue
                value2=value1[i0:i1]
                mask[k]=self.GetVarValue(devreg,width,bitoffset,value2)
                value1[i0:i1]=value2
        return value1,mask

    def GetVarValueMask(self,var1,mask):
        """Read ``var1`` on the masked elements.

        ``mask`` is handled as in SetGetVarValueMask.

        Returns ``(values, mask)``: ``mask`` is False where the read failed,
        and unread slots stay 0.  Returns None (after printing a message) on
        a mask length mismatch.
        """
        Step,Step2=GetSteps(var1)
        value1=[0]*Step*Step2
        Step2//=self.N  # data entries per (element, register) slot
        if (len(mask)==self.N):
            mask=[m for m in mask for x in range(Step)]
        if not(len(mask)==Step*self.N):
            print("Check mask length!")
            return
        width=var1.get('width',8)
        for RCUi in range(self.N):
            for Vari in range(Step):
                k=RCUi*Step+Vari
                if not(mask[k]): continue
                i0=k*Step2
                i1=(k+1)*Step2
                devreg=var1['devreg'][Vari]
                bitoffset=GetField(var1,'bitoffset',Vari,0)
                self.SetSwitch(RCUi)
                value2=value1[i0:i1]
                self.RCUi=RCUi
                mask[k]=self.GetVarValue(devreg,width,bitoffset,value2)
                value1[i0:i1]=value2
        return value1,mask

    def getstorearray(self,devreg):
        """Return ``devreg['storearray']``, the per-element value cache.

        If the cache is missing or empty, it is first created as N zeros.
        """
        storearray=devreg.get('storearray')
        if not(storearray):
            storearray=[0]*self.N
            devreg['storearray']=storearray
        return storearray

    def Setdevreg(self,devreg,value,mask=()):
        """Write ``value`` to ``devreg`` on all masked elements at once.

        ``mask`` must have at least N entries (the empty default raises
        IndexError when N > 0).  For stored registers, ``value[0]`` is
        recorded in the cache of every masked element.  The register is
        written as 8 bits at offset 0.

        NOTE(review): always returns True; the SetVarValue result is
        ignored.
        """
        if devreg.get('store'):
            storearray=self.getstorearray(devreg)
            for RCUi in range(self.N):
                if mask[RCUi]:
                    storearray[RCUi]=value[0]
                    self.RCUi=RCUi
            logging.debug(str(("Stored values:",self.getstorearray(devreg))))
        self.SetSwitchMask(mask)
        self.SetVarValue(devreg,8,0,value)
        return True

    def SetVarValue(self,devreg,width,bitoffset,value):
        """Write the list ``value`` to ``devreg['register_W']``.

        The write goes to ``devreg['addr']`` through
        ``devreg['drivercls'].i2csetget``.

        * Read-only registers (register_W == -1) are skipped and True is
          returned.
        * For stored registers, ``value[0]`` is first merged (ApplyMask)
          into the cached value of element ``self.RCUi``, so bits outside
          the width/bitoffset field are preserved.  The cache is updated
          and ``value`` is modified in place.

        Returns the driver's result.
        """
        if devreg['register_W']==-1: return True #We can not set it, only read it e.g. temperature
        logging.debug(str(("RCU1 Set ",self.RCUi,devreg['addr'],value)))
        if devreg['store']:
            storearray=self.getstorearray(devreg)
            previous=storearray[self.RCUi]
            value[0]=ApplyMask(value[0],width,bitoffset,previous)
            storearray[self.RCUi]=value[0]
            logging.debug("Stored value:"+str(value[0]))
        return devreg['drivercls'].i2csetget(devreg['addr'],value,reg=devreg['register_W'])

    def GetVarValue(self,devreg,width,bitoffset,value):
        """Read ``devreg`` into the list ``value`` (in place) and extract the field.

        Returns False if a driver call fails or ``value[0]`` is None, and
        True otherwise.  For stored registers, the raw ``value[0]`` (before
        field extraction) is cached for element ``self.RCUi``.
        """
        logging.debug(str(("RCU1 Get ",self.RCUi,devreg['addr'],value)))
        callback=devreg['drivercls'].i2csetget
        nbytes=int((width+bitoffset+7)//8)  # bytes spanned by the field
        reg=devreg['register_R']
        if reg>255: #This is for the monitor ADC
            # The register number is sent as bytes first, then read back.
            # NOTE(review): meaning of read=2/read=3 and of [250] depends on
            # the driver's i2csetget; confirm.
            if not(callback(devreg['addr'],int2bytes(reg),read=2)): return False
            callback(0,[250],read=3)
            if not(callback(devreg['addr'],value,read=1)): return False
        else:
            if not(callback(devreg['addr'],value,reg=reg,read=1)): return False
        if value[0] is None: return False
        if devreg['store']:
            storearray=self.getstorearray(devreg)
            storearray[self.RCUi]=value[0]
            logging.debug(str(("Store buffer",self.RCUi,value[0])))
        if (width!=nbytes*8) or (bitoffset>0):
            if (width<8):
                # Sub-byte field: extract it from every returned entry.
                for i in range(len(value)):
                    value[i]=UnMask(value[i],width,bitoffset)
            else:
                # Multi-byte field: only value[0] is masked, to the bits
                # left over beyond whole bytes.
                value[0]=UnMask(value[0],width-(nbytes-1)*8,bitoffset)
        return True