Package aloha :: Module aloha_lib
[hide private]
[frames] | no frames]

Source Code for Module aloha.aloha_lib

   1  ################################################################################ 
   2  # 
   3  # Copyright (c) 2010 The MadGraph5_aMC@NLO Development team and Contributors 
   4  # 
   5  # This file is a part of the MadGraph5_aMC@NLO project, an application which  
   6  # automatically generates Feynman diagrams and matrix elements for arbitrary 
   7  # high-energy processes in the Standard Model and beyond. 
   8  # 
   9  # It is subject to the MadGraph5_aMC@NLO license which should accompany this  
  10  # distribution. 
  11  # 
  12  # For more information, visit madgraph.phys.ucl.ac.be and amcatnlo.web.cern.ch 
  13  # 
  14  ################################################################################ 
  15  ##   Diagram of Class 
  16  ## 
  17  ##    Variable (vartype:0)<--- ScalarVariable  
  18  ##                          | 
  19  ##                          +- LorentzObject  
  20  ##                                 
  21  ## 
  22  ##    list <--- AddVariable (vartype :1)    
  23  ##            
  24  ##    array <--- MultVariable  <--- MultLorentz (vartype:2)  
  25  ##            
  26  ##    list <--- LorentzObjectRepresentation (vartype :4) <-- ConstantObject 
  27  ##                                                               (vartype:5) 
  28  ## 
  29  ##    FracVariable (vartype:3) 
  30  ## 
  31  ##    MultContainer (vartype:6) 
  32  ## 
  33  ################################################################################ 
  34  ## 
   35  ##   Variable is in fact a factory which adds a reference to the variable name 
   36  ##   into the KERNEL (of class Computation), instantiates a real variable object 
   37  ##   (of class C_Variable / R_Variable for complex / real) and returns a MultVariable 
   38  ##   with a single element. 
  39  ## 
  40  ##   Lorentz Object works in the same way. 
  41  ## 
  42  ################################################################################ 
  43   
  44   
  45  from __future__ import division 
  46  from array import array 
  47  import collections 
  48  from fractions import Fraction 
  49  import numbers 
  50  import re 
  51  import aloha # define mode of writting 
  52   
  53  try: 
  54      import madgraph.various.misc as misc 
  55  except Exception: 
  56      import aloha.misc as misc 
class defaultdict(collections.defaultdict):
    """collections.defaultdict whose instances can themselves be used as a
    default factory.

    Calling an instance (with any arguments, which are ignored) returns a
    brand new ``defaultdict(int)``. This makes
    ``defaultdict(defaultdict(int))`` a two-level counter.
    """

    def __call__(self, *args):
        fresh_counter = defaultdict(int)
        return fresh_counter
62
class Computation(dict):
    """ a class to encapsulate all computation. Limit side effect

    The dict part maps a variable name to its integer id, and ``objs[id]``
    holds the corresponding variable object.
    """

    # functions considered as known (not recorded in unknow_fct)
    known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot', 'acot',
                 'theta_function', 'exp']

    def __init__(self):
        self.objs = []            # id -> registered object
        self.use_tag = set()      # tags (TMPn, FCTn, ...) used so far
        self.id = -1              # id given to the last registered object
        self.reduced_expr = {}    # str(expression) -> [Variable, 'TMPn']
        self.fct_expr = {}        # 'FCTn' -> (fct_tag, arguments)
        self.reduced_expr2 = {}   # tag -> reduced expression or (fct_tag, arguments)
        self.inverted_fct = {}    # str(fct_tag)+str(arguments) -> 'FCT(n)'
        self.has_pi = False # logical to check if pi is used in at least one fct
        self.unknow_fct = []      # (fct_tag, number of arguments) of non-standard functions
        dict.__init__(self)

    def clean(self):
        """Reset the computation to an empty state."""
        self.__init__()
        # dict.__init__ does not drop existing keys, hence the explicit clear
        self.clear()

    def add(self, name, obj):
        """Register obj under name and return the newly assigned id."""
        self.id += 1
        self.objs.append(obj)
        self[name] = self.id
        return self.id

    def get(self, name):
        """Return the object registered under name (KeyError if unknown)."""
        return self.objs[self[name]]

    def add_tag(self, tag):
        """Mark every tag of the iterable `tag` as used."""
        self.use_tag.update(tag)

    def get_ids(self, variables):
        """return the list of identification number associate to the
        given variables names. If a variable didn't exists, create it (in complex).
        """
        out = []
        for var in variables:
            try:
                var_id = self[var]
            except KeyError:
                assert var not in ['M', 'W']
                var_id = Variable(var).get_id()
            out.append(var_id)
        return out

    def add_expression_contraction(self, expression):
        """Replace a scalar expression by a TMPn Variable.

        Identical expressions (compared through str) share the same TMPn.
        Return 0 when the expression (or its simplified form) is zero.
        """
        str_expr = str(expression)
        if str_expr in self.reduced_expr:
            out, tag = self.reduced_expr[str_expr]
            self.add_tag((tag,))
            return out
        if expression == 0:
            return 0
        new_2 = expression.simplify()
        if new_2 == 0:
            return 0
        # Add a new variable
        tag = 'TMP%s' % len(self.reduced_expr)
        new = Variable(tag)
        self.reduced_expr[str_expr] = [new, tag]
        new_2 = new_2.factorize()
        self.reduced_expr2[tag] = new_2
        self.add_tag((tag,))
        return new

    def add_function_expression(self, fct_tag, *args):
        """Register the call fct_tag(*args) and return its 'FCT(n)' tag.

        Lorentz-type arguments are expanded and must be scalar (otherwise
        aloha.ALOHAERROR is raised). When fct_tag is a cmath function (given
        as 'cmath.xxx' or through known_fct) and every argument is a number,
        the value is computed directly and its str() is returned instead.
        """
        if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
                (fct_tag, len(args)) in self.unknow_fct):
            self.unknow_fct.append((fct_tag, len(args)))

        argument = []
        for expression in args:
            if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
                try:
                    expr = expression.expand().get_rep([0])
                except KeyError as error:
                    if error.args != ((0,),):
                        raise
                    raise aloha.ALOHAERROR('''Error in input format.
    Argument of function (or denominator) should be scalar.
    We found %s''' % expression)
                new = expr.simplify()
                if not isinstance(new, numbers.Number):
                    new = new.factorize()
                argument.append(new)
            else:
                argument.append(expression)

        # functions used inside the arguments have to be kept as well
        for arg in argument:
            for fct_id in re.findall(r'''\bFCT(\d*)\b''', str(arg)):
                self.add_tag(('FCT%s' % fct_id,))

        # check if the function is a pure numerical function.
        if (fct_tag.startswith('cmath.') or fct_tag in self.known_fct) and \
                all(isinstance(x, numbers.Number) for x in argument):
            import cmath
            if fct_tag.startswith('cmath.'):
                fct_name = fct_tag[len('cmath.'):]
            else:
                fct_name = fct_tag
            # Look the function up instead of eval-ing a string built from the
            # model. Names cmath does not provide ('/', 'cot', ...) simply fall
            # through to the symbolic FCT registration below.
            function = getattr(cmath, fct_name, None)
            if callable(function):
                try:
                    return str(function(*argument))
                except Exception as error:
                    print(error)
                    print("cmath.%s(%s)" % (fct_name, ','.join(repr(x) for x in argument)))

        key = str(fct_tag) + str(argument)
        if key in self.inverted_fct:
            tag = self.inverted_fct[key]
            fct_id = tag.split('(')[1][:-1]
            self.add_tag(('FCT%s' % fct_id,))
            return tag

        fct_id = len(self.fct_expr)
        tag = 'FCT%s' % fct_id
        self.inverted_fct[key] = 'FCT(%s)' % fct_id
        self.fct_expr[tag] = (fct_tag, argument)
        self.reduced_expr2[tag] = (fct_tag, argument)
        self.add_tag((tag,))
        return 'FCT(%s)' % fct_id


KERNEL = Computation()
#===============================================================================
# AddVariable
#===============================================================================
class AddVariable(list):
    """ A list of Variable/ConstantObject/... This object represent the operation
    between those object.

    The value represented is prefactor * (sum of the items). Items are
    typically MultVariable, but plain numbers are allowed as well.
    """

    #variable to fastenize class recognition
    vartype = 1

    def __init__(self, old_data=(), prefactor=1):
        """ initialization of the object with default value """
        self.prefactor = prefactor
        list.__init__(self, old_data)

    def simplify(self):
        """ apply rule of simplification

        Merge identical terms, move the constant part to the end, and pull
        the most common coefficient into self.prefactor. Return a number,
        a single simplified term or self.
        """
        # deal with one length object (a single number goes through the
        # general path below, it has no simplify method)
        if len(self) == 1 and hasattr(self[0], 'vartype'):
            return self.prefactor * self[0].simplify()
        constant = 0
        items = {}
        pos = -1
        for term in self[:]:
            pos += 1 # current position in the real self
            if not hasattr(term, 'vartype'):
                if isinstance(term, dict):
                    # allow term of type{(0,):x}
                    assert list(term.values()) == [0]
                    term = term[(0,)]
                constant += term
                del self[pos]
                pos -= 1
                continue
            tag = tuple(term.sort())
            if tag in items:
                orig_prefac = items[tag].prefactor # to assume to zero 0.33333 -0.3333
                items[tag].prefactor += term.prefactor
                if items[tag].prefactor and \
                   abs(items[tag].prefactor) / (abs(orig_prefac) + abs(term.prefactor)) < 1e-8:
                    items[tag].prefactor = 0
                del self[pos]
                pos -= 1
            else:
                # work on a copy so that the merge does not modify the input term
                items[tag] = term.__class__(term, term.prefactor)
                self[pos] = items[tag]

        # get the optimized prefactor
        countprefact = defaultdict(int)
        nbplus, nbminus = 0, 0
        if constant not in [0, 1, -1]:
            countprefact[constant] += 1
            if constant.real + constant.imag > 0:
                nbplus += 1
            else:
                nbminus += 1

        for var in items.values():
            if var.prefactor == 0:
                self.remove(var)
            else:
                nb = var.prefactor
                if nb in [1, -1]:
                    continue
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
        if countprefact and max(countprefact.values()) > 1:
            fact_prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
        else:
            fact_prefactor = 1
        if nbplus < nbminus:
            fact_prefactor *= -1
        self.prefactor *= fact_prefactor

        if fact_prefactor != 1:
            for i, a in enumerate(self):
                try:
                    a.prefactor /= fact_prefactor
                except AttributeError:
                    self[i] /= fact_prefactor

        if constant:
            self.append(constant / fact_prefactor)

        # deal with one/zero length object
        varlen = len(self)
        if varlen == 1:
            if hasattr(self[0], 'vartype'):
                return self.prefactor * self[0].simplify()
            else:
                #self[0] is a number
                return self.prefactor * self[0]
        elif varlen == 0:
            return 0 #ConstantObject()
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable"""
        out = defaultdict(int)
        for obj in self:
            for key, value in obj.split(variables_id).items():
                out[key] += self.prefactor * value
        return out

    def contains(self, variables):
        """returns true if one of the variables is in the expression"""
        # plain numbers contain no variable (and do not support `in`)
        return any(v in obj for obj in self
                   if not isinstance(obj, numbers.Number)
                   for v in variables)

    def get_all_var_names(self):
        """Return the names of the variables used in every term."""
        out = []
        for term in self:
            if hasattr(term, 'get_all_var_names'):
                out += term.get_all_var_names()
        return out

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
           Note that expression cann't be zero.
           Note that this should be canonical form (this should contains ONLY
           MULTVARIABLE) --so this should be called before a factorize.
        """
        new = self.__class__()
        for obj in self:
            assert isinstance(obj, MultVariable)
            tmp = obj.replace(id, expression)
            new += tmp
        new.prefactor = self.prefactor
        return new

    def expand(self, veto=()):
        """Pass from High level object to low level object

        Terms without an expand method (plain numbers) are kept as they are.
        """
        if not self:
            return self

        first = self[0].expand(veto) if hasattr(self[0], 'expand') else self[0]
        if self.prefactor == 1:
            new = first
        else:
            new = self.prefactor * first

        for item in self[1:]:
            if hasattr(item, 'expand'):
                if self.prefactor == 1:
                    new += item.expand(veto)
                else:
                    new += self.prefactor * item.expand(veto)
            elif self.prefactor == 1:
                new = new + item
            else:
                new = new + self.prefactor * item
        return new

    def __mul__(self, obj):
        """define the multiplication of
            - a AddVariable with a number
            - a AddVariable with an AddVariable
        other type of multiplication are define via the symmetric operation base
        on the obj class."""
        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            return self.__class__(self, self.prefactor * obj)
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        else:
            #force the program to look at obj + self
            return NotImplemented

    def __imul__(self, obj):
        """define the multiplication of
            - a AddVariable with a number
            - a AddVariable with an AddVariable
        other type of multiplication are define via the symmetric operation base
        on the obj class."""
        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            self.prefactor *= obj
            return self
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        else:
            #force the program to look at obj + self
            return NotImplemented

    def __neg__(self):
        # return a negated copy; the operand itself is left untouched
        return self.__class__(self, -self.prefactor)

    def __add__(self, obj):
        """Define all the different addition."""
        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            new = self.__class__(self, self.prefactor)
            new.append(obj / self.prefactor)
            return new
        elif obj.vartype == 2: # obj is a MultVariable
            new = AddVariable(self, self.prefactor)
            if self.prefactor == 1:
                new.append(obj)
            else:
                new.append((1 / self.prefactor) * obj)
            return new
        elif obj.vartype == 1: # obj is a AddVariable
            new = AddVariable(self, self.prefactor)
            for item in obj:
                new.append(obj.prefactor / self.prefactor * item)
            return new
        else:
            #force to look at obj + self
            return NotImplemented

    def __iadd__(self, obj):
        """Define all the different addition."""
        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            self.append(obj / self.prefactor)
            return self
        elif obj.vartype == 2: # obj is a MultVariable
            if self.prefactor == 1:
                self.append(obj)
            else:
                self.append((1 / self.prefactor) * obj)
            return self
        elif obj.vartype == 1: # obj is a AddVariable
            for item in obj:
                self.append(obj.prefactor / self.prefactor * item)
            return self
        else:
            #force to look at obj + self
            return NotImplemented

    def __sub__(self, obj):
        return self + (-1) * obj

    def __rsub__(self, obj):
        return (-1) * self + obj

    __radd__ = __add__
    __rmul__ = __mul__

    def __div__(self, obj):
        return self.__mul__(1 / obj)

    __truediv__ = __div__

    def __rdiv__(self, obj):
        # Dividing by an AddVariable has no representation here. The former
        # code called the non-existent self.__rmult__ (AttributeError);
        # NotImplemented lets Python raise the standard TypeError instead.
        return NotImplemented

    def __str__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += '( '
        text += ' + '.join([str(item) for item in self])
        text += ' )'
        return text

    def __repr__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += super(AddVariable, self).__repr__()
        return text

    def count_term(self):
        """Return (max_count, var_id) for the variable present in most terms.

        max_count <= 1 means that nothing can be factorized (var_id is then
        None). Ties are broken by correlation weight, then by the smallest
        string representation.
        """
        # Count the number of appearance of each variable and find the most
        # present one in order to factorize her
        count = defaultdict(int)
        correlation = defaultdict(defaultdict(int))
        for term in self:
            try:
                set_term = set(term)
            except TypeError:
                #constant term
                continue
            for val1 in set_term:
                count[val1] += 1
                # allow to find optimized factorization for identical count
                for val2 in set_term:
                    correlation[val1][val2] += 1

        maxnb = max(count.values()) if count else 0
        possibility = [v for v, val in count.items() if val == maxnb]
        if maxnb == 1:
            return 1, None
        elif len(possibility) == 1:
            return maxnb, possibility[0]

        max_wgt, maxvar, str_maxvar = 0, None, None
        for var in possibility:
            wgt = sum(w**2 for w in correlation[var].values()) / len(correlation[var])
            if wgt > max_wgt:
                maxvar = var
                max_wgt = wgt
                str_maxvar = str(KERNEL.objs[var])
            elif wgt == max_wgt:
                # keep the one with the lowest string expr
                new_str = str(KERNEL.objs[var])
                if new_str < str_maxvar:
                    maxvar = var
                    str_maxvar = new_str
        return maxnb, maxvar

    def factorize(self):
        """ try to factorize as much as possible the expression

        Note: the terms containing the factorized variable are modified in
        place (the variable is removed from them).
        """
        max_count, maxvar = self.count_term()
        if max_count <= 1:
            #no factorization possible
            return self

        # split in MAXVAR * NEWADD + CONSTANT
        newadd = AddVariable()
        constant = AddVariable()
        #fill NEWADD and CONSTANT
        for term in self:
            try:
                term.remove(maxvar)
            except (AttributeError, ValueError):
                # number (no remove) or term without maxvar
                constant.append(term)
            else:
                if len(term):
                    newadd.append(term)
                else:
                    newadd.append(term.prefactor)
        newadd = newadd.factorize()

        # optimize the prefactor
        if isinstance(newadd, AddVariable):
            countprefact = defaultdict(int)
            nbplus, nbminus = 0, 0
            for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
            # newadd may only hold plain numbers: keep its prefactor then
            if countprefact:
                newadd.prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
            if nbplus < nbminus:
                newadd.prefactor *= -1
            if newadd.prefactor != 1:
                for i, a in enumerate(newadd):
                    try:
                        a.prefactor /= newadd.prefactor
                    except AttributeError:
                        newadd[i] /= newadd.prefactor

        if len(constant) > 1:
            constant = constant.factorize()
        elif constant:
            constant = constant[0]
        else:
            out = MultContainer([KERNEL.objs[maxvar], newadd])
            out.prefactor = self.prefactor
            if newadd.prefactor != 1:
                out.prefactor *= newadd.prefactor
                newadd.prefactor = 1
            return out
        out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
                          self.prefactor)
        return out
588
class MultContainer(list):
    """Product of its items with an overall prefactor.

    Built by AddVariable.factorize as [variable, factorized sum].
    """

    vartype = 6

    def __init__(self, *args):
        self.prefactor = 1
        list.__init__(self, *args)

    def __str__(self):
        """ String representation """
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor, ' * '.join([str(t) for t in self]))
        else:
            text = '(%s)' % (' * '.join([str(t) for t in self]))
        return text

    def factorize(self):
        """Factorize every item in place and return self (like every other
        factorize method). Items without a factorize method, such as plain
        variable names, are kept unchanged."""
        self[:] = [term.factorize() if hasattr(term, 'factorize') else term
                   for term in self]
        return self
607
class MultVariable(array):
    """ A list of Variable with multiplication as operator between themselves.
    Represented by array for speed optimization

    The array stores KERNEL ids ('i' typecode); the numerical coefficient
    of the product is kept in ``prefactor``.
    """
    vartype = 2
    addclass = AddVariable

    # numerical types accepted as prefactor (`long` only exists in Python 2)
    try:
        _prefactor_types = (float, int, long, complex)
    except NameError:
        _prefactor_types = (float, int, complex)

    def __new__(cls, old=(), prefactor=1):
        # the content of an array has to be given at creation time
        return array.__new__(cls, 'i', old)

    def __init__(self, old=(), prefactor=1):
        """ initialization of the object with default value """
        #array.__init__(self, 'i', old) <- done already in new !!
        self.prefactor = prefactor
        assert isinstance(self.prefactor, self._prefactor_types)

    def get_id(self):
        """Return the id of the single variable of this product."""
        assert len(self) == 1
        return self[0]

    def sort(self):
        """Sort the ids in place and return self."""
        a = list(self)
        a.sort()
        self[:] = array('i', a)
        return self

    def simplify(self):
        """ simplify the product"""
        if not len(self):
            return self.prefactor
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable (self is modified in place)"""
        key = tuple([self.count(i) for i in variables_id])
        arg = [var_id for var_id in self if var_id not in variables_id]
        self[:] = array('i', arg)
        return SplitCoefficient([(key, self)])

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
           Note that expression cann't be zero.
        """
        assert hasattr(expression, 'vartype'), 'expression should be of type Add or Mult'

        if expression.vartype == 1: # AddVariable
            nb = self.count(id)
            if not nb:
                return self
            for i in range(nb):
                self.remove(id)
            new = self
            for i in range(nb):
                new *= expression
            return new
        elif expression.vartype == 2: # MultLorentz
            # be carefull about A -> A * B
            nb = self.count(id)
            for i in range(nb):
                self.remove(id)
            self.__imul__(expression)
            return self
        else:
            raise Exception('Cann\'t replace a Variable by %s' % type(expression))

    def get_all_var_names(self):
        """return the list of variable used in this multiplication"""
        return ['%s' % KERNEL.objs[n] for n in self]

    #Defining rule of Multiplication
    def __mul__(self, obj):
        """Define the multiplication with different object"""
        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                return self.__class__(self, obj * self.prefactor)
            else:
                return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            old, self.prefactor = self.prefactor, 1
            try:
                new[:] = [self * term for term in obj]
            finally:
                # always restore the prefactor, even if a product fails
                self.prefactor = old
            return new
        elif obj.vartype == 4:
            return NotImplemented

        return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)

    __rmul__ = __mul__

    def __imul__(self, obj):
        """Define the multiplication with different object"""
        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                self.prefactor *= obj
                return self
            else:
                return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            self.prefactor = 1
            new[:] = [self * term for term in obj]
            return new
        elif obj.vartype == 4:
            return NotImplemented

        self.prefactor *= obj.prefactor
        return array.__iadd__(self, obj)

    def __pow__(self, value):
        out = 1
        for i in range(value):
            out *= self
        return out

    def __add__(self, obj):
        """ define the adition with different object"""
        is_mult = hasattr(obj, 'vartype') and obj.vartype == 2
        # An empty MultVariable stands for its prefactor, so it is not a zero
        # that can be dropped.
        if not obj and not is_mult:
            return self
        elif not hasattr(obj, 'vartype') or is_mult:
            new = self.addclass([self, obj])
            return new
        else:
            #call the implementation of addition implemented in obj
            return NotImplemented

    __radd__ = __add__
    __iadd__ = __add__

    def __sub__(self, obj):
        return self + (-1) * obj

    def __neg__(self):
        # negated copy; the operand itself is left untouched
        return self.__class__(self, -self.prefactor)

    def __rsub__(self, obj):
        return (-1) * self + obj

    def __idiv__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED (in place)"""
        assert not hasattr(obj, 'vartype')
        self.prefactor /= obj
        return self

    __itruediv__ = __idiv__

    def __div__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED (return a new object)"""
        assert not hasattr(obj, 'vartype')
        return self.__class__(self, self.prefactor / obj)

    __truediv__ = __div__

    def __str__(self):
        """ String representation """
        t = ['%s' % KERNEL.objs[n] for n in self]
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor, ' * '.join(t))
        else:
            text = '(%s)' % (' * '.join(t))
        return text

    # repr must include the prefactor: lists of MultVariable are stringified
    # (through repr) to build cache keys, e.g. Computation.inverted_fct
    __repr__ = __str__
    __rep__ = __str__  # historical misspelled alias, kept for compatibility

    def factorize(self):
        return self
792
#===============================================================================
# FactoryVar
#===============================================================================
class C_Variable(str):
    """Variable name (str) flagged with type 'complex'."""
    type = 'complex'
    vartype = 0


class R_Variable(str):
    """Variable name (str) flagged with type 'double'."""
    type = 'double'
    vartype = 0


class ExtVariable(str):
    """Variable name (str) flagged with type 'parameter'."""
    type = 'parameter'
    vartype = 0
808
class FactoryVar(object):
    """This is the standard object for all the variable linked to expression.

    Instantiating it does not create a FactoryVar: the variable is registered
    in KERNEL (once per name) and a one-element ``mult_class`` holding its id
    is returned.
    """
    mult_class = MultVariable # The class for the multiplication

    def __new__(cls, name, baseclass, *args):
        """Factory class return a MultVariable."""
        if name not in KERNEL:
            # first use of this name: build and register the real object
            variable = baseclass(name, *args)
            variable.id = KERNEL.add(name, variable)
            return cls.mult_class([variable.id])
        return cls.mult_class([KERNEL[name]])
825
class Variable(FactoryVar):
    """Factory for a scalar variable (a C_Variable unless `type` says otherwise)."""

    def __new__(cls, name, type=C_Variable):
        # delegate to the generic factory; the result is a MultVariable
        return FactoryVar(name, type)
830
class DVariable(FactoryVar):
    """Factory for a variable that is real (R_Variable) by default.

    It is registered as complex (C_Variable) when aloha.complex_mass is set
    and the name starts with 'M', 'W' or 'OM', or when aloha.loop_mode is
    set and the name starts with 'P'.
    """

    def __new__(cls, name):
        is_complex = (aloha.complex_mass and
                      (name[0] in ['M', 'W'] or name.startswith('OM'))) or \
                     (aloha.loop_mode and name.startswith('P'))
        return FactoryVar(name, C_Variable if is_complex else R_Variable)
843
#===============================================================================
# Object for Analytical Representation of Lorentz object (not scalar one)
#===============================================================================


#===============================================================================
# MultLorentz
#===============================================================================
class MultLorentz(MultVariable):
    """Specific class for LorentzObject Multiplication

    As for MultVariable, the array holds KERNEL ids; the LorentzObject
    instances themselves are obtained through KERNEL.objs.
    """

    add_class = AddVariable # Define which class describe the addition

    def _find_contraction(self, attribute):
        """Shared implementation of find_lorentzcontraction/find_spincontraction.

        attribute is 'lorentz_ind' or 'spin_ind'. Return the mapping
        (pos_object1, indice1) <-> (pos_object2, indice2), in both directions.
        """
        # self stores ids: resolve them to the actual objects first
        factors = [KERNEL.objs[var_id] for var_id in self]
        out = {}
        for i, fact in enumerate(factors):
            for j, index in enumerate(getattr(fact, attribute, [])):
                # compare with the following elements of the multiplication
                for k in range(i + 1, len(factors)):
                    try:
                        pos = getattr(factors[k], attribute, []).index(index)
                    except ValueError:
                        continue
                    out[(i, j)] = (k, pos)
                    out[(k, pos)] = (i, j)
        return out

    def find_lorentzcontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the contraction in this Multiplication."""
        return self._find_contraction('lorentz_ind')

    def find_spincontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the contraction in this Multiplication."""
        return self._find_contraction('spin_ind')

    def neighboor(self, home):
        """return one not yet expanded object sharing a lorentz or spin index
        with home (None if there is none)"""
        for var in self.unused:
            obj = KERNEL.objs[var]
            if obj.has_component(home.lorentz_ind, home.spin_ind):
                return obj
        return None

    def expand(self, veto=()):
        """ expand each part of the product and combine them.
            Try to use a smart order in order to minimize the number of uncontracted indices.
            Veto forbids the use of sub-expression if it contains some of the variable in the
            expression. Veto contains the id of the vetoed variables
        """
        self.unused = self[:] # list of not expanded
        # made in a list the interesting starting point for the computation
        basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
        product_term = [] #store result of intermediate chains
        current = None # current point in the working chain (a KERNEL id)

        while self.unused:
            #Loop untill we have expand everything
            # (compare to None: 0 is a valid KERNEL id)
            if current is None:
                # First we need to have a starting point
                try:
                    # look in priority in basic_end_point (P/S/fermion/...)
                    current = basic_end_point.pop()
                except IndexError:
                    #take one of the remaining
                    current = self.unused.pop()
                else:
                    #check that this one is not already use
                    if current not in self.unused:
                        current = None
                        continue
                    #remove of the unuse (usualy done in the pop)
                    self.unused.remove(current)
                cur_obj = KERNEL.objs[current]
                # initialize the new chain
                product_term.append(cur_obj.expand())

            # We have a point -> find the next one: a term which is
            # contracted with the current chain and not yet expanded.
            var_obj = self.neighboor(product_term[-1])
            if var_obj is not None:
                product_term[-1] *= var_obj.expand()
                self.unused.remove(var_obj.id)
                continue

            current = None

        # Multiply all those current
        # For Fermion/Vector only one can carry index.
        out = self.prefactor
        for fact in product_term:
            if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
                scalar = fact.get_rep([0])
                if hasattr(scalar, 'vartype') and scalar.vartype == 1:
                    if not veto or not scalar.contains(veto):
                        scalar = scalar.simplify()
                        prefactor = 1

                        if hasattr(scalar, 'vartype') and scalar.prefactor not in [1, -1]:
                            prefactor = scalar.prefactor
                            scalar.prefactor = 1
                        new = KERNEL.add_expression_contraction(scalar)
                        fact.set_rep([0], prefactor * new)
            out *= fact
        return out

    def __copy__(self):
        """ create a shadow copy """
        new = MultLorentz(self)
        new.prefactor = self.prefactor
        return new
988
#===============================================================================
# LorentzObject
#===============================================================================
class LorentzObject(object):
    """ A symbolic Object for All Helas object. All Helas Object Should
    derivated from this class"""

    contract_first = 0
    mult_class = MultLorentz # The class for the multiplication
    add_class = AddVariable # The class for the addition

    class VariableError(Exception):
        """Raised when an object cannot build its representation."""

    def __init__(self, name, lor_ind, spin_ind, tags=()):
        """ initialization of the object with default value

        lor_ind and spin_ind are the lists of lorentz/spin indices; tags are
        registered in KERNEL.use_tag.
        """
        assert isinstance(lor_ind, list)
        assert isinstance(spin_ind, list)

        self.name = name
        self.lorentz_ind = lor_ind
        self.spin_ind = spin_ind
        KERNEL.add_tag(set(tags))

    def expand(self):
        """Expand the content information into LorentzObjectRepresentation.

        The representation is built once (create_representation) and cached.
        """
        try:
            return self.representation
        except AttributeError:
            self.create_representation()
            return self.representation

    def create_representation(self):
        # to be overridden by the concrete Helas objects
        raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)

    def has_component(self, lor_list, spin_list):
        """check if this Lorentz Object have some of those indices"""
        return any(i in self.lorentz_ind for i in lor_list) or \
            any(i in self.spin_ind for i in spin_list)

    def __str__(self):
        return '%s' % self.name
1033
class FactoryLorentz(FactoryVar):
    """Factory for Lorentz objects.

    Each call builds a unique name from the class name and the arguments,
    creates (or reuses) the matching ``object_class`` instance in KERNEL and
    returns a one-element MultLorentz.
    """

    mult_class = MultLorentz # The class for the multiplication
    object_class = LorentzObject # Define How to create the basic object.

    def __new__(cls, *args):
        unique_name = cls.get_unique_name(*args)
        return FactoryVar.__new__(cls, unique_name, cls.object_class, *args)

    @classmethod
    def get_unique_name(cls, *args):
        """default way to have a unique name"""
        fields = {'class': cls.__name__,
                  'args': '_'.join(args)}
        return '_L_%(class)s_%(args)s' % fields
1052
1053 1054 #=============================================================================== 1055 # LorentzObjectRepresentation 1056 #=============================================================================== 1057 -class LorentzObjectRepresentation(dict):
1058 """A concrete representation of the LorentzObject.""" 1059 1060 vartype = 4 # Optimization for instance recognition 1061
class LorentzObjectRepresentationError(Exception):
    """Error specific to LorentzObjectRepresentation (e.g. a scalar without
    a (0,) entry, or an addition of objects with different indices)."""
1064
def __init__(self, representation, lorentz_indices, spin_indices):
    """ initialize the lorentz object representation

    With indices, representation is the {indices_tuple: value} dict.
    Without indices (scalar), representation may be a dict that is empty or
    holds only the (0,) key, or directly the value to store at (0,).
    """
    self.lorentz_ind = lorentz_indices #lorentz indices
    self.nb_lor = len(lorentz_indices) #their number
    self.spin_ind = spin_indices #spin indices
    self.nb_spin = len(spin_indices) #their number
    self.nb_ind = self.nb_lor + self.nb_spin #total number of indices

    #store the representation
    if self.lorentz_ind or self.spin_ind:
        dict.__init__(self, representation)
    elif isinstance(representation, dict):
        if len(representation) == 0:
            self[(0,)] = 0
        elif len(representation) == 1 and (0,) in representation:
            self[(0,)] = representation[(0,)]
        else:
            raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.")
    else:
        # scalar given directly as a value (dicts are fully handled above)
        self[(0,)] = representation
1096
1097 - def __str__(self):
1098 """ string representation """ 1099 text = 'lorentz index :' + str(self.lorentz_ind) + '\n' 1100 text += 'spin index :' + str(self.spin_ind) + '\n' 1101 #text += 'other info ' + str(self.tag) + '\n' 1102 for ind in self.listindices(): 1103 ind = tuple(ind) 1104 text += str(ind) + ' --> ' 1105 text += str(self.get_rep(ind)) + '\n' 1106 return text
1107
1108 - def get_rep(self, indices):
1109 """return the value/Variable associate to the indices""" 1110 return self[tuple(indices)]
1111
1112 - def set_rep(self, indices, value):
1113 """assign 'value' at the indices position""" 1114 1115 self[tuple(indices)] = value
1116
1117 - def listindices(self):
1118 """Return an iterator in order to be able to loop easily on all the 1119 indices of the object.""" 1120 return IndicesIterator(self.nb_ind)
1121 1122 @staticmethod
1123 - def get_mapping(l1,l2, switch_order=[]):
1124 shift = len(switch_order) 1125 for value in l1: 1126 try: 1127 index = l2.index(value) 1128 except Exception: 1129 raise LorentzObjectRepresentation.LorentzObjectRepresentationError( 1130 "Invalid addition. Object doen't have the same lorentz "+ \ 1131 "indices : %s != %s" % (l1, l2)) 1132 else: 1133 switch_order.append(shift + index) 1134 return switch_order
1135 1136
1137 - def __add__(self, obj, fact=1):
1138 1139 if not obj: 1140 return self 1141 1142 if not hasattr(obj, 'vartype'): 1143 assert self.lorentz_ind == [] 1144 assert self.spin_ind == [] 1145 new = self[(0,)] + obj * fact 1146 out = LorentzObjectRepresentation(new, [], []) 1147 return out 1148 1149 assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation 1150 1151 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind: 1152 # if the order of indices are different compute a mapping 1153 switch_order = [] 1154 self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order) 1155 self.get_mapping(self.spin_ind, obj.spin_ind, switch_order) 1156 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))]) 1157 else: 1158 # no mapping needed (define switch as identity) 1159 switch = lambda ind : (ind) 1160 1161 # Some sanity check 1162 assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+self.spin_ind)) 1163 assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind)) 1164 1165 # define an empty representation 1166 new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind) 1167 1168 # loop over all indices and fullfill the new object 1169 if fact == 1: 1170 for ind in self.listindices(): 1171 value = obj.get_rep(ind) + self.get_rep(switch(ind)) 1172 new.set_rep(ind, value) 1173 else: 1174 for ind in self.listindices(): 1175 value = fact * obj.get_rep(switch(ind)) + self.get_rep(ind) 1176 new.set_rep(ind, value) 1177 1178 return new
1179
1180 - def __iadd__(self, obj, fact=1):
1181 1182 if not obj: 1183 return self 1184 1185 assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation 1186 1187 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind: 1188 1189 # if the order of indices are different compute a mapping 1190 switch_order = [] 1191 self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order) 1192 self.get_mapping(obj.spin_ind, self.spin_ind, switch_order) 1193 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))]) 1194 else: 1195 # no mapping needed (define switch as identity) 1196 switch = lambda ind : (ind) 1197 1198 # Some sanity check 1199 assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind)) 1200 assert tuple(switch(self.lorentz_ind) )== tuple(obj.lorentz_ind), '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind)) 1201 1202 # loop over all indices and fullfill the new object 1203 if fact == 1: 1204 for ind in self.listindices(): 1205 self[tuple(ind)] += obj.get_rep(switch(ind)) 1206 else: 1207 for ind in self.listindices(): 1208 self[tuple(ind)] += fact * obj.get_rep(switch(ind)) 1209 return self
1210
1211 - def __sub__(self, obj):
1212 return self.__add__(obj, fact= -1)
1213
1214 - def __rsub__(self, obj):
1215 return obj.__add__(self, fact= -1)
1216
1217 - def __isub__(self, obj):
1218 return self.__add__(obj, fact= -1)
1219
1220 - def __neg__(self):
1221 self *= -1 1222 return self
1223
1224 - def __mul__(self, obj):
1225 """multiplication performing directly the einstein/spin sommation. 1226 """ 1227 1228 if not hasattr(obj, 'vartype'): 1229 out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind) 1230 for ind in out.listindices(): 1231 out.set_rep(ind, obj * self.get_rep(ind)) 1232 return out 1233 1234 # Sanity Check 1235 assert(obj.__class__ == LorentzObjectRepresentation), \ 1236 '%s is not valid class for this operation' %type(obj) 1237 1238 # compute information on the status of the index (which are contracted/ 1239 #not contracted 1240 l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \ 1241 obj.lorentz_ind) 1242 s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \ 1243 obj.spin_ind) 1244 if not(sum_l_ind or sum_s_ind): 1245 # No contraction made a tensor product 1246 return self.tensor_product(obj) 1247 1248 # elsewher made a spin contraction 1249 # create an empty representation but with correct indices 1250 new_object = LorentzObjectRepresentation({}, l_ind, s_ind) 1251 #loop and fullfill the representation 1252 for indices in new_object.listindices(): 1253 #made a dictionary (pos -> index_value) for how call the object 1254 dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind) 1255 dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind) 1256 #add the new value 1257 new_object.set_rep(indices, \ 1258 self.contraction(obj, sum_l_ind, sum_s_ind, \ 1259 dict_l_ind, dict_s_ind)) 1260 1261 return new_object
1262 1263 __rmul__ = __mul__ 1264 __imul__ = __mul__ 1265
1266 - def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
1267 """ make the Lorentz/spin contraction of object self and obj. 1268 l_sum/s_sum are the position of the sum indices 1269 l_dict/s_dict are dict given the value of the fix indices (indices->value) 1270 """ 1271 out = 0 # initial value for the output 1272 len_l = len(l_sum) #store len for optimization 1273 len_s = len(s_sum) # same 1274 1275 # loop over the possibility for the sum indices and update the dictionary 1276 # (indices->value) 1277 for l_value in IndicesIterator(len_l): 1278 l_dict.update(self.pass_ind_in_dict(l_value, l_sum)) 1279 for s_value in IndicesIterator(len_s): 1280 #s_dict_final = s_dict.copy() 1281 s_dict.update(self.pass_ind_in_dict(s_value, s_sum)) 1282 1283 #return the indices in the correct order 1284 self_ind = self.combine_indices(l_dict, s_dict) 1285 obj_ind = obj.combine_indices(l_dict, s_dict) 1286 1287 # call the object 1288 factor = obj.get_rep(obj_ind) * self.get_rep(self_ind) 1289 1290 if factor: 1291 #compute the prefactor due to the lorentz contraction 1292 try: 1293 factor.prefactor *= (-1) ** (len(l_value) - l_value.count(0)) 1294 except Exception: 1295 factor *= (-1) ** (len(l_value) - l_value.count(0)) 1296 out += factor 1297 return out
1298
1299 - def tensor_product(self, obj):
1300 """ return the tensorial product of the object""" 1301 assert(obj.vartype == 4) #isinstance(obj, LorentzObjectRepresentation)) 1302 1303 new_object = LorentzObjectRepresentation({}, \ 1304 self.lorentz_ind + obj.lorentz_ind, \ 1305 self.spin_ind + obj.spin_ind) 1306 1307 #some shortcut 1308 lor1 = self.nb_lor 1309 lor2 = obj.nb_lor 1310 spin1 = self.nb_spin 1311 spin2 = obj.nb_spin 1312 1313 #define how to call build the indices first for the first object 1314 if lor1 == 0 == spin1: 1315 #special case for scalar 1316 selfind = lambda indices: [0] 1317 else: 1318 selfind = lambda indices: indices[:lor1] + \ 1319 indices[lor1 + lor2: lor1 + lor2 + spin1] 1320 1321 #then for the second 1322 if lor2 == 0 == spin2: 1323 #special case for scalar 1324 objind = lambda indices: [0] 1325 else: 1326 objind = lambda indices: indices[lor1: lor1 + lor2] + \ 1327 indices[lor1 + lor2 + spin1:] 1328 1329 # loop on the indices and assign the product 1330 for indices in new_object.listindices(): 1331 1332 fac1 = self.get_rep(tuple(selfind(indices))) 1333 fac2 = obj.get_rep(tuple(objind(indices))) 1334 new_object.set_rep(indices, fac1 * fac2) 1335 1336 return new_object
1337
1338 - def factorize(self):
1339 """Try to factorize each component""" 1340 for ind, fact in self.items(): 1341 if fact: 1342 self.set_rep(ind, fact.factorize()) 1343 1344 1345 return self
1346
1347 - def simplify(self):
1348 """Check if we can simplify the object (check for non treated Sum)""" 1349 1350 #Look for internal simplification 1351 for ind, term in self.items(): 1352 if hasattr(term, 'vartype'): 1353 self[ind] = term.simplify() 1354 #no additional simplification 1355 return self
1356 1357 @staticmethod
1358 - def compare_indices(list1, list2):
1359 """return two list, the first one contains the position of non summed 1360 index and the second one the position of summed index.""" 1361 #init object 1362 1363 # equivalent set call --slightly slower 1364 #return list(set(list1) ^ set(list2)), list(set(list1) & set(list2)) 1365 1366 1367 are_unique, are_sum = [], [] 1368 # loop over the first list and check if they are in the second list 1369 1370 for indice in list1: 1371 if indice in list2: 1372 are_sum.append(indice) 1373 else: 1374 are_unique.append(indice) 1375 # loop over the second list for additional unique item 1376 1377 for indice in list2: 1378 if indice not in are_sum: 1379 are_unique.append(indice) 1380 1381 # return value 1382 return are_unique, are_sum
1383 1384 @staticmethod
1385 - def pass_ind_in_dict(indices, key):
1386 """made a dictionary (pos -> index_value) for how call the object""" 1387 if not key: 1388 return {} 1389 out = {} 1390 for i, ind in enumerate(indices): 1391 out[key[i]] = ind 1392 return out
1393
1394 - def combine_indices(self, l_dict, s_dict):
1395 """return the indices in the correct order following the dicts rules""" 1396 1397 out = [] 1398 # First for the Lorentz indices 1399 for value in self.lorentz_ind: 1400 out.append(l_dict[value]) 1401 # Same for the spin 1402 for value in self.spin_ind: 1403 out.append(s_dict[value]) 1404 1405 return out
1406
1407 - def split(self, variables_id):
1408 """return a dict with the key being the power associated to each variables 1409 and the value being the object remaining after the suppression of all 1410 the variable""" 1411 1412 out = SplitCoefficient() 1413 zero_rep = {} 1414 for ind in self.listindices(): 1415 zero_rep[tuple(ind)] = 0 1416 1417 for ind in self.listindices(): 1418 # There is no function split if the element is just a simple number 1419 if isinstance(self.get_rep(ind), numbers.Number): 1420 if tuple([0]*len(variables_id)) in out: 1421 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind) 1422 else: 1423 out[tuple([0]*len(variables_id))] = \ 1424 LorentzObjectRepresentation(dict(zero_rep), 1425 self.lorentz_ind, self.spin_ind) 1426 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind) 1427 continue 1428 1429 for key, value in self.get_rep(ind).split(variables_id).items(): 1430 if key in out: 1431 out[key][tuple(ind)] += value 1432 else: 1433 out[key] = LorentzObjectRepresentation(dict(zero_rep), 1434 self.lorentz_ind, self.spin_ind) 1435 out[key][tuple(ind)] += value 1436 1437 return out
1438
#===============================================================================
# IndicesIterator
#===============================================================================
class IndicesIterator:
    """Iterate over every combination of `len` indices, each index taking a
    value between 0 and 3; the first position varies fastest.

    For len == 0 (scalar object) a single [0] is produced.
    The SAME list object is returned (mutated in place) at every step, so
    callers that keep an index must copy it (e.g. with tuple()).
    Works with both the Python 2 (`next`) and Python 3 (`__next__`)
    iterator protocols.
    """

    def __init__(self, len):
        """ create an iterator looping over the indices of a list of len "len"
        with each value can take value between 0 and 3 """

        self.len = len # number of indices
        if len:
            # The first position starts at -1 because next() increments an
            # index before returning the data.
            self.data = [-1] + [0] * (len - 1)
        else:
            # Special case for Scalar object: 'next' is overridden on the
            # instance (__next__ below dispatches through self.next)
            self.data = 0
            self.next = self.nextscalar

    def __iter__(self):
        return self

    def next(self):
        # odometer-like increment: bump the first position below 3,
        # resetting the positions before it to 0
        for i in range(self.len):
            if self.data[i] < 3:
                self.data[i] += 1
                return self.data
            else:
                self.data[i] = 0
        raise StopIteration

    def nextscalar(self):
        # yield [0] exactly once
        if self.data:
            raise StopIteration
        else:
            self.data = True
            return [0]

    def __next__(self):
        # Python 3 protocol; go through self.next so that the instance-level
        # override installed for scalars is honoured
        return self.next()
1480
class SplitCoefficient(dict):
    """Dictionary returned by LorentzObjectRepresentation.split: each key is
    a tuple of powers (one per split variable) and each value the remaining
    coefficient. It also carries a `tag` set, initialised empty."""

    def __init__(self, *args, **opt):
        super(SplitCoefficient, self).__init__(*args, **opt)
        self.tag = set()

    def get_max_rank(self):
        """Return the largest power found among the first four entries of
        any key (raises ValueError when the dictionary is empty)."""
        return max(max(powers[:4]) for powers in self.keys())
if __name__ == '__main__':
    # Micro-benchmark of compare_indices, only run when executed as a script.
    import cProfile

    def create():
        for step in range(10000):
            LorentzObjectRepresentation.compare_indices(range(step % 10),
                                                        [4, 3, 5])

    cProfile.run('create()')
1499 1500 cProfile.run('create()') 1501