1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45 from __future__ import division
46 from array import array
47 import collections
48 from fractions import Fraction
49 import numbers
50 import re
51 import aloha
52
53 try:
54 import madgraph.various.misc as misc
55 except Exception:
56 import aloha.misc as misc
62
64 """ a class to encapsulate all computation. Limit side effect """
65
67 self.objs = []
68 self.use_tag = set()
69 self.id = -1
70 self.reduced_expr = {}
71 self.fct_expr = {}
72 self.reduced_expr2 = {}
73 self.inverted_fct = {}
74 self.has_pi = False
75 self.unknow_fct = []
76 dict.__init__(self)
77
81
def add(self, name, obj):
    """Register *obj* under *name* and return the id allocated for it.

    The running counter ``self.id`` is bumped, *obj* is appended to
    ``self.objs`` (so ``self.objs[new_id] is obj``) and the mapping
    name -> id is stored in the dictionary itself.
    """
    self.id += 1
    self.objs.append(obj)
    new_id = self.id
    self[name] = new_id
    return new_id
87
def get(self, name):
    """Return the object registered under *name*.

    Raises KeyError when *name* was never registered.
    """
    position = self[name]
    return self.objs[position]
90
93
95 """return the list of identification number associate to the
96 given variables names. If a variable didn't exists, create it (in complex).
97 """
98 out = []
99 for var in variables:
100 try:
101 id = self[var]
102 except KeyError:
103 assert var not in ['M','W']
104 id = Variable(var).get_id()
105 out.append(id)
106 return out
107
108
110
111 str_expr = str(expression)
112 if str_expr in self.reduced_expr:
113 out, tag = self.reduced_expr[str_expr]
114 self.add_tag((tag,))
115 return out
116 if expression == 0:
117 return 0
118 new_2 = expression.simplify()
119 if new_2 == 0:
120 return 0
121
122 tag = 'TMP%s' % len(self.reduced_expr)
123 new = Variable(tag)
124 self.reduced_expr[str_expr] = [new, tag]
125 new_2 = new_2.factorize()
126 self.reduced_expr2[tag] = new_2
127 self.add_tag((tag,))
128
129
130 return new
131
132 known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot', 'acot',
133 'theta_function', 'exp']
135
136
137 if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
138 (fct_tag, len(args)) in self.unknow_fct):
139 self.unknow_fct.append( (fct_tag, len(args)) )
140
141 argument = []
142 for expression in args:
143 if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
144 try:
145 expr = expression.expand().get_rep([0])
146 except KeyError, error:
147 if error.args != ((0,),):
148 raise
149 else:
150 raise aloha.ALOHAERROR, '''Error in input format.
151 Argument of function (or denominator) should be scalar.
152 We found %s''' % expression
153 new = expr.simplify()
154 if not isinstance(new, numbers.Number):
155 new = new.factorize()
156 argument.append(new)
157 else:
158 argument.append(expression)
159
160 for arg in argument:
161 val = re.findall(r'''\bFCT(\d*)\b''', str(arg))
162 for v in val:
163 self.add_tag(('FCT%s' % v,))
164
165
166 if (fct_tag.startswith('cmath.') or fct_tag in self.known_fct) and \
167 all(isinstance(x, numbers.Number) for x in argument):
168 import cmath
169 if fct_tag.startswith('cmath.'):
170 module = ''
171 else:
172 module = 'cmath.'
173 try:
174 return str(eval("%s%s(%s)" % (module,fct_tag, ','.join(`x` for x in argument))))
175 except Exception, error:
176 print error
177 print "cmath.%s(%s)" % (fct_tag, ','.join(`x` for x in argument))
178 if str(fct_tag)+str(argument) in self.inverted_fct:
179 tag = self.inverted_fct[str(fct_tag)+str(argument)]
180 v = tag.split('(')[1][:-1]
181 self.add_tag(('FCT%s' % v,))
182 return tag
183 else:
184 id = len(self.fct_expr)
185 tag = 'FCT%s' % id
186 self.inverted_fct[str(fct_tag)+str(argument)] = 'FCT(%s)' % id
187 self.fct_expr[tag] = (fct_tag, argument)
188 self.reduced_expr2[tag] = (fct_tag, argument)
189 self.add_tag((tag,))
190
191 return 'FCT(%s)' % id
192
193 KERNEL = Computation()
199 """ A list of Variable/ConstantObject/... This object represent the operation
200 between those object."""
201
202
203 vartype = 1
204
def __init__(self, old_data=[], prefactor=1):
    """Build the sum from the terms of *old_data*.

    *prefactor* is a global factor multiplying the whole sum. The terms
    of *old_data* are copied into the list by ``list.__init__``; the
    default list is therefore never mutated.
    """
    self.prefactor = prefactor
    list.__init__(self, old_data)
211
213 """ apply rule of simplification """
214
215
216 if len(self) == 1:
217 return self.prefactor * self[0].simplify()
218 constant = 0
219 items = {}
220 pos = -1
221 for term in self[:]:
222 pos += 1
223 if not hasattr(term, 'vartype'):
224 if isinstance(term, dict):
225
226 assert term.values() == [0]
227 term = term[(0,)]
228 constant += term
229 del self[pos]
230 pos -= 1
231 continue
232 tag = tuple(term.sort())
233 if tag in items:
234 orig_prefac = items[tag].prefactor
235 items[tag].prefactor += term.prefactor
236 if items[tag].prefactor and \
237 abs(items[tag].prefactor) / (abs(orig_prefac)+abs(term.prefactor)) < 1e-8:
238 items[tag].prefactor = 0
239 del self[pos]
240 pos -=1
241 else:
242 items[tag] = term.__class__(term, term.prefactor)
243 self[pos] = items[tag]
244
245
246 countprefact = defaultdict(int)
247 nbplus, nbminus = 0,0
248 if constant not in [0, 1,-1]:
249 countprefact[constant] += 1
250 if constant.real + constant.imag > 0:
251 nbplus += 1
252 else:
253 nbminus += 1
254
255 for var in items.values():
256 if var.prefactor == 0:
257 self.remove(var)
258 else:
259 nb = var.prefactor
260 if nb in [1,-1]:
261 continue
262 countprefact[abs(nb)] +=1
263 if nb.real + nb.imag > 0:
264 nbplus += 1
265 else:
266 nbminus += 1
267 if countprefact and max(countprefact.values()) >1:
268 fact_prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
269 else:
270 fact_prefactor = 1
271 if nbplus < nbminus:
272 fact_prefactor *= -1
273 self.prefactor *= fact_prefactor
274
275 if fact_prefactor != 1:
276 for i,a in enumerate(self):
277 try:
278 a.prefactor /= fact_prefactor
279 except AttributeError:
280 self[i] /= fact_prefactor
281
282 if constant:
283 self.append(constant/ fact_prefactor )
284
285
286 varlen = len(self)
287 if varlen == 1:
288 if hasattr(self[0], 'vartype'):
289 return self.prefactor * self[0].simplify()
290 else:
291
292 return self.prefactor * self[0]
293 elif varlen == 0:
294 return 0
295 return self
296
def split(self, variables_id):
    """Split the sum according to the powers of the given variables.

    Every term is split with its own ``split(variables_id)``. The pieces
    are accumulated in a ``defaultdict(int)`` keyed by the power tuple.
    Each piece is multiplied by the global prefactor of the sum.
    """
    result = defaultdict(int)
    for term in self:
        pieces = term.split(variables_id)
        for power, remainder in pieces.items():
            result[power] += self.prefactor * remainder
    return result
307
309 """returns true if one of the variables is in the expression"""
310
311 return any((v in obj for obj in self for v in variables ))
312
313
315
316 out = []
317 for term in self:
318 if hasattr(term, 'get_all_var_names'):
319 out += term.get_all_var_names()
320 return out
321
322
323
def replace(self, id, expression):
    """Return a new sum where the object *id* is substituted by *expression*.

    *expression* must not be zero. The sum must be in canonical form,
    i.e. every term is a MultVariable, so this has to be called before
    ``factorize``. The result has the same class and prefactor as
    ``self``; each term's own ``replace`` does the substitution.
    """
    result = self.__class__()
    for term in self:
        assert isinstance(term, MultVariable)
        result += term.replace(id, expression)
    result.prefactor = self.prefactor
    return result
338
339
341 """Pass from High level object to low level object"""
342
343 if not self:
344 return self
345 if self.prefactor == 1:
346 new = self[0].expand(veto)
347 else:
348 new = self.prefactor * self[0].expand(veto)
349
350 for item in self[1:]:
351 if self.prefactor == 1:
352 try:
353 new += item.expand(veto)
354 except AttributeError:
355 new = new + item
356
357 else:
358 new += (self.prefactor) * item.expand(veto)
359 return new
360
362 """define the multiplication of
363 - a AddVariable with a number
364 - a AddVariable with an AddVariable
365 other type of multiplication are define via the symmetric operation base
366 on the obj class."""
367
368
369 if not hasattr(obj, 'vartype'):
370 if not obj:
371 return 0
372 return self.__class__(self, self.prefactor*obj)
373 elif obj.vartype == 1:
374 new = self.__class__([],self.prefactor * obj.prefactor)
375 new[:] = [i*j for i in self for j in obj]
376 return new
377 else:
378
379 return NotImplemented
380
382 """define the multiplication of
383 - a AddVariable with a number
384 - a AddVariable with an AddVariable
385 other type of multiplication are define via the symmetric operation base
386 on the obj class."""
387
388 if not hasattr(obj, 'vartype'):
389 if not obj:
390 return 0
391 self.prefactor *= obj
392 return self
393 elif obj.vartype == 1:
394 new = self.__class__([], self.prefactor * obj.prefactor)
395 new[:] = [i*j for i in self for j in obj]
396 return new
397 else:
398
399 return NotImplemented
400
402 self.prefactor *= -1
403 return self
404
406 """Define all the different addition."""
407
408 if not hasattr(obj, 'vartype'):
409 if not obj:
410 return self
411 new = self.__class__(self, self.prefactor)
412 new.append(obj/self.prefactor)
413 return new
414 elif obj.vartype == 2:
415 new = AddVariable(self, self.prefactor)
416 if self.prefactor == 1:
417 new.append(obj)
418 else:
419 new.append((1/self.prefactor)*obj)
420 return new
421 elif obj.vartype == 1:
422 new = AddVariable(self, self.prefactor)
423 for item in obj:
424 new.append(obj.prefactor/self.prefactor * item)
425 return new
426 else:
427
428 return NotImplemented
429
431 """Define all the different addition."""
432
433 if not hasattr(obj, 'vartype'):
434 if not obj:
435 return self
436 self.append(obj/self.prefactor)
437 return self
438 elif obj.vartype == 2:
439 if self.prefactor == 1:
440 self.append(obj)
441 else:
442 self.append((1/self.prefactor)*obj)
443 return self
444 elif obj.vartype == 1:
445 for item in obj:
446 self.append(obj.prefactor/self.prefactor * item)
447 return self
448 else:
449
450 return NotImplemented
451
453 return self + (-1) * obj
454
456 return (-1) * self + obj
457
458 __radd__ = __add__
459 __rmul__ = __mul__
460
461
464
465 __truediv__ = __div__
466
468 return self.__rmult__(1/obj)
469
471 text = ''
472 if self.prefactor != 1:
473 text += str(self.prefactor) + ' * '
474 text += '( '
475 text += ' + '.join([str(item) for item in self])
476 text += ' )'
477 return text
478
480 text = ''
481 if self.prefactor != 1:
482 text += str(self.prefactor) + ' * '
483 text += super(AddVariable,self).__repr__()
484 return text
485
487
488
489 count = defaultdict(int)
490 correlation = defaultdict(defaultdict(int))
491 for i,term in enumerate(self):
492 try:
493 set_term = set(term)
494 except TypeError:
495
496 continue
497 for val1 in set_term:
498 count[val1] +=1
499
500 for val2 in set_term:
501 correlation[val1][val2] += 1
502
503 maxnb = max(count.values()) if count else 0
504 possibility = [v for v,val in count.items() if val == maxnb]
505 if maxnb == 1:
506 return 1, None
507 elif len(possibility) == 1:
508 return maxnb, possibility[0]
509
510
511
512
513 max_wgt, maxvar = 0, None
514 for var in possibility:
515 wgt = sum(w**2 for w in correlation[var].values())/len(correlation[var])
516 if wgt > max_wgt:
517 maxvar = var
518 max_wgt = wgt
519 str_maxvar = str(KERNEL.objs[var])
520 elif wgt == max_wgt:
521
522 new_str = str(KERNEL.objs[var])
523 if new_str < str_maxvar:
524 maxvar = var
525 str_maxvar = new_str
526 return maxnb, maxvar
527
529 """ try to factorize as much as possible the expression """
530
531 max, maxvar = self.count_term()
532 if max <= 1:
533
534 return self
535 else:
536
537 newadd = AddVariable()
538 constant = AddVariable()
539
540 for term in self:
541 try:
542 term.remove(maxvar)
543 except Exception:
544 constant.append(term)
545 else:
546 if len(term):
547 newadd.append(term)
548 else:
549 newadd.append(term.prefactor)
550 newadd = newadd.factorize()
551
552
553 if isinstance(newadd, AddVariable):
554 countprefact = defaultdict(int)
555 nbplus, nbminus = 0,0
556 for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
557 countprefact[abs(nb)] +=1
558 if nb.real + nb.imag > 0:
559 nbplus += 1
560 else:
561 nbminus += 1
562
563 newadd.prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
564 if nbplus < nbminus:
565 newadd.prefactor *= -1
566 if newadd.prefactor != 1:
567 for i,a in enumerate(newadd):
568 try:
569 a.prefactor /= newadd.prefactor
570 except AttributeError:
571 newadd[i] /= newadd.prefactor
572
573
574 if len(constant) > 1:
575 constant = constant.factorize()
576 elif constant:
577 constant = constant[0]
578 else:
579 out = MultContainer([KERNEL.objs[maxvar], newadd])
580 out.prefactor = self.prefactor
581 if newadd.prefactor != 1:
582 out.prefactor *= newadd.prefactor
583 newadd.prefactor = 1
584 return out
585 out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
586 self.prefactor)
587 return out
588
590
591 vartype = 6
592
596
598 """ String representation """
599 if self.prefactor !=1:
600 text = '(%s * %s)' % (self.prefactor, ' * '.join([str(t) for t in self]))
601 else:
602 text = '(%s)' % (' * '.join([str(t) for t in self]))
603 return text
604
606 self[:] = [term.factorize() for term in self]
607
610 """ A list of Variable with multiplication as operator between themselves.
611 Represented by array for speed optimization
612 """
613 vartype=2
614 addclass = AddVariable
615
def __new__(cls, old=[], prefactor=1):
    """Allocate the signed-int array ('i') holding the variable ids of *old*.

    *prefactor* is ignored here; it is stored by ``__init__``, which
    receives the same arguments.
    """
    typecode = 'i'
    return array.__new__(cls, typecode, old)
618
619
def __init__(self, old=[], prefactor=1):
    """Attach the numerical *prefactor* to the product.

    The ids in *old* are already stored by ``__new__`` (the array is filled
    at allocation time), so only the prefactor is handled here.

    The prefactor must be a float, a complex or an integer.
    ``numbers.Integral`` covers ``int`` (and ``bool``) as well as the
    Python-2-only ``long``. Naming ``long`` directly would raise NameError
    on Python 3.
    """
    self.prefactor = prefactor
    assert isinstance(self.prefactor, (float, numbers.Integral, complex))
625
627 assert len(self) == 1
628 return self[0]
629
631 a = list(self)
632 a.sort()
633 self[:] = array('i',a)
634 return self
635
637 """ simplify the product"""
638 if not len(self):
639 return self.prefactor
640 return self
641
def split(self, variables_id):
    """Split the monomial according to the powers of the given variables.

    The key is the tuple of multiplicities of each id of *variables_id*
    in the product. The product is modified IN PLACE: those ids are
    removed, and the reduced object itself becomes the value of the single
    entry of the returned SplitCoefficient.
    """
    powers = tuple([self.count(var_id) for var_id in variables_id])
    kept = [var_id for var_id in self if var_id not in variables_id]
    self[:] = array('i', kept)
    return SplitCoefficient([(powers, self)])
651
def replace(self, id, expression):
    """Substitute the object *id* by *expression* in this product.

    Every occurrence of *id* is removed, and the product is multiplied by
    *expression* once per removed occurrence (x**n -> expression**n).
    If *id* does not occur, the product is returned unchanged; this
    matters because AddVariable.replace calls this on every term.

    *expression* must be non-zero and have a ``vartype``:
      - vartype 1 (sum): the product distributes over the sum and a new
        sum object is returned; ``self`` is still modified along the way.
      - vartype 2 (product): ``self`` is modified in place and returned.

    Raises:
        AssertionError: if *expression* has no ``vartype`` attribute.
        Exception: for any other vartype.
    """
    assert hasattr(expression, 'vartype') , 'expression should be of type Add or Mult'

    if expression.vartype not in (1, 2):
        raise Exception('Cann\'t replace a Variable by %s' % type(expression))

    nb = self.count(id)
    if not nb:
        # Nothing to substitute: leave the product untouched.
        return self
    for _ in range(nb):
        self.remove(id)

    if expression.vartype == 1:
        new = self
        for _ in range(nb):
            # The first multiplication turns the product into a sum object.
            new *= expression
        return new

    # vartype 2: multiply in place once per removed occurrence so the
    # power of the substituted variable is preserved.
    for _ in range(nb):
        self.__imul__(expression)
    return self
688
689
691 """return the list of variable used in this multiplication"""
692 return ['%s' % KERNEL.objs[n] for n in self]
693
694
695
696
698 """Define the multiplication with different object"""
699
700 if not hasattr(obj, 'vartype'):
701 if obj:
702 return self.__class__(self, obj*self.prefactor)
703 else:
704 return 0
705 elif obj.vartype == 1:
706 new = obj.__class__([], self.prefactor*obj.prefactor)
707 old, self.prefactor = self.prefactor, 1
708 new[:] = [self * term for term in obj]
709 self.prefactor = old
710 return new
711 elif obj.vartype == 4:
712 return NotImplemented
713
714 return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)
715
716 __rmul__ = __mul__
717
719 """Define the multiplication with different object"""
720
721 if not hasattr(obj, 'vartype'):
722 if obj:
723 self.prefactor *= obj
724 return self
725 else:
726 return 0
727 elif obj.vartype == 1:
728 new = obj.__class__([], self.prefactor * obj.prefactor)
729 self.prefactor = 1
730 new[:] = [self * term for term in obj]
731 return new
732 elif obj.vartype == 4:
733 return NotImplemented
734
735 self.prefactor *= obj.prefactor
736 return array.__iadd__(self, obj)
737
739 out = 1
740 for i in range(value):
741 out *= self
742 return out
743
744
746 """ define the adition with different object"""
747
748 if not obj:
749 return self
750 elif not hasattr(obj, 'vartype') or obj.vartype == 2:
751 new = self.addclass([self, obj])
752 return new
753 else:
754
755 return NotImplemented
756 __radd__ = __add__
757 __iadd__ = __add__
758
760 return self + (-1) * obj
761
763 self.prefactor *=-1
764 return self
765
767 return (-1) * self + obj
768
770 """ ONLY NUMBER DIVISION ALLOWED"""
771 assert not hasattr(obj, 'vartype')
772 self.prefactor /= obj
773 return self
774
775 __div__ = __idiv__
776 __truediv__ = __div__
777
778
780 """ String representation """
781 t = ['%s' % KERNEL.objs[n] for n in self]
782 if self.prefactor != 1:
783 text = '(%s * %s)' % (self.prefactor,' * '.join(t))
784 else:
785 text = '(%s)' % (' * '.join(t))
786 return text
787
788 __rep__ = __str__
789
792
800
804
808
811 """This is the standard object for all the variable linked to expression.
812 """
813 mult_class = MultVariable
814
def __new__(cls, name, baseclass, *args):
    """Factory: return a one-element ``cls.mult_class`` product for *name*.

    If *name* is already registered in KERNEL, its existing id is reused.
    Otherwise ``baseclass(name, *args)`` is instantiated and registered in
    KERNEL, and the new id is stored on the object as ``.id``.
    """
    if name not in KERNEL:
        variable = baseclass(name, *args)
        new_id = KERNEL.add(name, variable)
        variable.id = new_id
        return cls.mult_class([new_id])
    return cls.mult_class([KERNEL[name]])
825
830
843
844
845
846
847
848
849
850
851
852
853
854
855 -class MultLorentz(MultVariable):
856 """Specific class for LorentzObject Multiplication"""
857
858 add_class = AddVariable
859
861 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
862 the contraction in this Multiplication."""
863
864 out = {}
865 len_mult = len(self)
866
867 for i, fact in enumerate(self):
868
869 for j in range(len(fact.lorentz_ind)):
870
871 for k in range(i+1,len_mult):
872 fact2 = self[k]
873 try:
874 l = fact2.lorentz_ind.index(fact.lorentz_ind[j])
875 except Exception:
876 pass
877 else:
878 out[(i, j)] = (k, l)
879 out[(k, l)] = (i, j)
880 return out
881
883 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
884 the contraction in this Multiplication."""
885
886 out = {}
887 len_mult = len(self)
888
889 for i, fact in enumerate(self):
890
891 for j in range(len(fact.spin_ind)):
892
893 for k in range(i+1, len_mult):
894 fact2 = self[k]
895 try:
896 l = fact2.spin_ind.index(fact.spin_ind[j])
897 except Exception:
898 pass
899 else:
900 out[(i, j)] = (k, l)
901 out[(k, l)] = (i, j)
902
903 return out
904
906 """return one variable which are contracted with var and not yet expanded"""
907
908 for var in self.unused:
909 obj = KERNEL.objs[var]
910 if obj.has_component(home.lorentz_ind, home.spin_ind):
911 return obj
912 return None
913
914
915
916
918 """ expand each part of the product and combine them.
919 Try to use a smart order in order to minimize the number of uncontracted indices.
920 Veto forbids the use of sub-expression if it contains some of the variable in the
921 expression. Veto contains the id of the vetoed variables
922 """
923
924 self.unused = self[:]
925
926 basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
927 product_term = []
928 current = None
929
930 while self.unused:
931
932 if not current:
933
934 try:
935
936 current = basic_end_point.pop()
937 except Exception:
938
939 current = self.unused.pop()
940 else:
941
942 if current not in self.unused:
943 current = None
944 continue
945
946 self.unused.remove(current)
947 cur_obj = KERNEL.objs[current]
948
949 product_term.append(cur_obj.expand())
950
951
952 var_obj = self.neighboor(product_term[-1])
953
954
955 if var_obj:
956 product_term[-1] *= var_obj.expand()
957 cur_obj = var_obj
958 self.unused.remove(cur_obj.id)
959 continue
960
961 current = None
962
963
964
965
966 out = self.prefactor
967 for fact in product_term[:]:
968 if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
969 scalar = fact.get_rep([0])
970 if hasattr(scalar, 'vartype') and scalar.vartype == 1:
971 if not veto or not scalar.contains(veto):
972 scalar = scalar.simplify()
973 prefactor = 1
974
975 if hasattr(scalar, 'vartype') and scalar.prefactor not in [1,-1]:
976 prefactor = scalar.prefactor
977 scalar.prefactor = 1
978 new = KERNEL.add_expression_contraction(scalar)
979 fact.set_rep([0], prefactor * new)
980 out *= fact
981 return out
982
984 """ create a shadow copy """
985 new = MultLorentz(self)
986 new.prefactor = self.prefactor
987 return new
988
993 """ A symbolic Object for All Helas object. All Helas Object Should
994 derivated from this class"""
995
996 contract_first = 0
997 mult_class = MultLorentz
998 add_class = AddVariable
999
def __init__(self, name, lor_ind, spin_ind, tags=[]):
    """Store the name and the Lorentz/spin index lists of the object.

    Both index arguments must be real ``list`` instances (checked with
    asserts). *tags* is passed to ``KERNEL.add_tag`` as a set; the
    default list is never mutated.
    """
    assert isinstance(lor_ind, list)
    assert isinstance(spin_ind, list)

    self.name = name
    self.lorentz_ind, self.spin_ind = lor_ind, spin_ind
    KERNEL.add_tag(set(tags))
1009
1011 """Expand the content information into LorentzObjectRepresentation."""
1012
1013 try:
1014 return self.representation
1015 except Exception:
1016 self.create_representation()
1017 return self.representation
1018
1020 raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)
1021
1023 """check if this Lorentz Object have some of those indices"""
1024
1025 if any([id in self.lorentz_ind for id in lor_list]) or \
1026 any([id in self.spin_ind for id in spin_list]):
1027 return True
1028
1029
1030
1032 return '%s' % self.name
1033
1035 """ A symbolic Object for All Helas object. All Helas Object Should
1036 derivated from this class"""
1037
1038 mult_class = MultLorentz
1039 object_class = LorentzObject
1040
1044
1045 @classmethod
1047 """default way to have a unique name"""
1048 return '_L_%(class)s_%(args)s' % \
1049 {'class':cls.__name__,
1050 'args': '_'.join(args)
1051 }
1052
1058 """A concrete representation of the LorentzObject."""
1059
1060 vartype = 4
1061
1063 """Specify error for LorentzObjectRepresentation"""
1064
def __init__(self, representation, lorentz_indices, spin_indices):
    """Initialize the explicit component representation.

    Args:
        representation: for a tensor (at least one Lorentz or spin index),
            a mapping index-tuple -> component. For a scalar (no index),
            either such a mapping, which must be empty or contain only the
            key (0,), or the scalar value itself.
        lorentz_indices: list of Lorentz index labels.
        spin_indices: list of spin index labels.

    Raises:
        LorentzObjectRepresentationError: a scalar mapping holds keys other
            than (0,).
    """
    self.lorentz_ind = lorentz_indices
    self.nb_lor = len(lorentz_indices)
    self.spin_ind = spin_indices
    self.nb_spin = len(spin_indices)
    self.nb_ind = self.nb_lor + self.nb_spin

    if self.lorentz_ind or self.spin_ind:
        dict.__init__(self, representation)
    elif isinstance(representation, dict):
        # Scalar given as a mapping: only the (0,) component is allowed.
        if len(representation) == 0:
            self[(0,)] = 0
        elif len(representation) == 1 and (0,) in representation:
            self[(0,)] = representation[(0,)]
        else:
            raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.")
    else:
        # Scalar given directly as a value. (The former dict check that
        # lived here was unreachable: dicts are handled just above.)
        self[(0,)] = representation
1096
1098 """ string representation """
1099 text = 'lorentz index :' + str(self.lorentz_ind) + '\n'
1100 text += 'spin index :' + str(self.spin_ind) + '\n'
1101
1102 for ind in self.listindices():
1103 ind = tuple(ind)
1104 text += str(ind) + ' --> '
1105 text += str(self.get_rep(ind)) + '\n'
1106 return text
1107
1109 """return the value/Variable associate to the indices"""
1110 return self[tuple(indices)]
1111
def set_rep(self, indices, value):
    """Store *value* as the component addressed by *indices*.

    *indices* may be any sequence; it is converted to a tuple so it can be
    used as the dictionary key.
    """
    key = tuple(indices)
    self[key] = value
1116
1118 """Return an iterator in order to be able to loop easily on all the
1119 indices of the object."""
1120 return IndicesIterator(self.nb_ind)
1121
1122 @staticmethod
1124 shift = len(switch_order)
1125 for value in l1:
1126 try:
1127 index = l2.index(value)
1128 except Exception:
1129 raise LorentzObjectRepresentation.LorentzObjectRepresentationError(
1130 "Invalid addition. Object doen't have the same lorentz "+ \
1131 "indices : %s != %s" % (l1, l2))
1132 else:
1133 switch_order.append(shift + index)
1134 return switch_order
1135
1136
1138
1139 if not obj:
1140 return self
1141
1142 if not hasattr(obj, 'vartype'):
1143 assert self.lorentz_ind == []
1144 assert self.spin_ind == []
1145 new = self[(0,)] + obj * fact
1146 out = LorentzObjectRepresentation(new, [], [])
1147 return out
1148
1149 assert(obj.vartype == 4 == self.vartype)
1150
1151 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
1152
1153 switch_order = []
1154 self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order)
1155 self.get_mapping(self.spin_ind, obj.spin_ind, switch_order)
1156 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
1157 else:
1158
1159 switch = lambda ind : (ind)
1160
1161
1162 assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+self.spin_ind))
1163 assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind))
1164
1165
1166 new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind)
1167
1168
1169 if fact == 1:
1170 for ind in self.listindices():
1171 value = obj.get_rep(ind) + self.get_rep(switch(ind))
1172 new.set_rep(ind, value)
1173 else:
1174 for ind in self.listindices():
1175 value = fact * obj.get_rep(switch(ind)) + self.get_rep(ind)
1176 new.set_rep(ind, value)
1177
1178 return new
1179
1181
1182 if not obj:
1183 return self
1184
1185 assert(obj.vartype == 4 == self.vartype)
1186
1187 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
1188
1189
1190 switch_order = []
1191 self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order)
1192 self.get_mapping(obj.spin_ind, self.spin_ind, switch_order)
1193 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
1194 else:
1195
1196 switch = lambda ind : (ind)
1197
1198
1199 assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind))
1200 assert tuple(switch(self.lorentz_ind) )== tuple(obj.lorentz_ind), '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind))
1201
1202
1203 if fact == 1:
1204 for ind in self.listindices():
1205 self[tuple(ind)] += obj.get_rep(switch(ind))
1206 else:
1207 for ind in self.listindices():
1208 self[tuple(ind)] += fact * obj.get_rep(switch(ind))
1209 return self
1210
1212 return self.__add__(obj, fact= -1)
1213
1215 return obj.__add__(self, fact= -1)
1216
1218 return self.__add__(obj, fact= -1)
1219
1221 self *= -1
1222 return self
1223
1225 """multiplication performing directly the einstein/spin sommation.
1226 """
1227
1228 if not hasattr(obj, 'vartype'):
1229 out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind)
1230 for ind in out.listindices():
1231 out.set_rep(ind, obj * self.get_rep(ind))
1232 return out
1233
1234
1235 assert(obj.__class__ == LorentzObjectRepresentation), \
1236 '%s is not valid class for this operation' %type(obj)
1237
1238
1239
1240 l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \
1241 obj.lorentz_ind)
1242 s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \
1243 obj.spin_ind)
1244 if not(sum_l_ind or sum_s_ind):
1245
1246 return self.tensor_product(obj)
1247
1248
1249
1250 new_object = LorentzObjectRepresentation({}, l_ind, s_ind)
1251
1252 for indices in new_object.listindices():
1253
1254 dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind)
1255 dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind)
1256
1257 new_object.set_rep(indices, \
1258 self.contraction(obj, sum_l_ind, sum_s_ind, \
1259 dict_l_ind, dict_s_ind))
1260
1261 return new_object
1262
1263 __rmul__ = __mul__
1264 __imul__ = __mul__
1265
def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
    """Contract ``self`` with ``obj`` over the summed indices.

    Args:
        obj: the other representation.
        l_sum, s_sum: labels of the summed Lorentz / spin indices.
        l_dict, s_dict: label -> value mappings for the fixed indices.
            They are updated IN PLACE with the values of the summed
            indices while looping, so the caller's dicts are modified.

    Every non-zero product of components is weighted by (-1)**k, where k
    is the number of summed Lorentz indices whose current value is
    non-zero. NOTE(review): presumably the mostly-minus metric; confirm.
    The sign goes on ``factor.prefactor`` when the product has one,
    otherwise on the product itself.

    Returns:
        The sum of the weighted products (0 if all of them vanish).
    """
    total = 0
    nb_lorentz_sum = len(l_sum)
    nb_spin_sum = len(s_sum)

    for l_value in IndicesIterator(nb_lorentz_sum):
        l_dict.update(self.pass_ind_in_dict(l_value, l_sum))
        # The sign only depends on the Lorentz configuration.
        sign = (-1) ** (len(l_value) - l_value.count(0))
        for s_value in IndicesIterator(nb_spin_sum):
            s_dict.update(self.pass_ind_in_dict(s_value, s_sum))

            self_ind = self.combine_indices(l_dict, s_dict)
            obj_ind = obj.combine_indices(l_dict, s_dict)
            factor = obj.get_rep(obj_ind) * self.get_rep(self_ind)
            if not factor:
                continue
            try:
                factor.prefactor *= sign
            except Exception:
                factor *= sign
            total += factor
    return total
1298
1300 """ return the tensorial product of the object"""
1301 assert(obj.vartype == 4)
1302
1303 new_object = LorentzObjectRepresentation({}, \
1304 self.lorentz_ind + obj.lorentz_ind, \
1305 self.spin_ind + obj.spin_ind)
1306
1307
1308 lor1 = self.nb_lor
1309 lor2 = obj.nb_lor
1310 spin1 = self.nb_spin
1311 spin2 = obj.nb_spin
1312
1313
1314 if lor1 == 0 == spin1:
1315
1316 selfind = lambda indices: [0]
1317 else:
1318 selfind = lambda indices: indices[:lor1] + \
1319 indices[lor1 + lor2: lor1 + lor2 + spin1]
1320
1321
1322 if lor2 == 0 == spin2:
1323
1324 objind = lambda indices: [0]
1325 else:
1326 objind = lambda indices: indices[lor1: lor1 + lor2] + \
1327 indices[lor1 + lor2 + spin1:]
1328
1329
1330 for indices in new_object.listindices():
1331
1332 fac1 = self.get_rep(tuple(selfind(indices)))
1333 fac2 = obj.get_rep(tuple(objind(indices)))
1334 new_object.set_rep(indices, fac1 * fac2)
1335
1336 return new_object
1337
1339 """Try to factorize each component"""
1340 for ind, fact in self.items():
1341 if fact:
1342 self.set_rep(ind, fact.factorize())
1343
1344
1345 return self
1346
1348 """Check if we can simplify the object (check for non treated Sum)"""
1349
1350
1351 for ind, term in self.items():
1352 if hasattr(term, 'vartype'):
1353 self[ind] = term.simplify()
1354
1355 return self
1356
1357 @staticmethod
1359 """return two list, the first one contains the position of non summed
1360 index and the second one the position of summed index."""
1361
1362
1363
1364
1365
1366
1367 are_unique, are_sum = [], []
1368
1369
1370 for indice in list1:
1371 if indice in list2:
1372 are_sum.append(indice)
1373 else:
1374 are_unique.append(indice)
1375
1376
1377 for indice in list2:
1378 if indice not in are_sum:
1379 are_unique.append(indice)
1380
1381
1382 return are_unique, are_sum
1383
1384 @staticmethod
1386 """made a dictionary (pos -> index_value) for how call the object"""
1387 if not key:
1388 return {}
1389 out = {}
1390 for i, ind in enumerate(indices):
1391 out[key[i]] = ind
1392 return out
1393
1395 """return the indices in the correct order following the dicts rules"""
1396
1397 out = []
1398
1399 for value in self.lorentz_ind:
1400 out.append(l_dict[value])
1401
1402 for value in self.spin_ind:
1403 out.append(s_dict[value])
1404
1405 return out
1406
def split(self, variables_id):
    """Decompose every component according to the powers of *variables_id*.

    Returns a SplitCoefficient mapping a power tuple (one entry per id of
    *variables_id*) to a LorentzObjectRepresentation with the same
    Lorentz/spin indices as ``self``. Each representation starts with
    every component at 0 and accumulates the matching coefficients.

    Numeric components go to the all-zero power key. Any other component
    is split with its own ``split(variables_id)``, which may modify that
    component; MultVariable.split, for instance, strips ids in place.
    """
    out = SplitCoefficient()
    zero_rep = dict((tuple(ind), 0) for ind in self.listindices())
    constant_key = tuple([0] * len(variables_id))

    for ind in self.listindices():
        position = tuple(ind)
        component = self.get_rep(ind)
        if isinstance(component, numbers.Number):
            pieces = [(constant_key, component)]
        else:
            pieces = component.split(variables_id).items()
        for power, coeff in pieces:
            if power not in out:
                out[power] = LorentzObjectRepresentation(dict(zero_rep),
                                        self.lorentz_ind, self.spin_ind)
            out[power][position] += coeff
    return out
1438
1446 """Class needed for the iterator"""
1447
1449 """ create an iterator looping over the indices of a list of len "len"
1450 with each value can take value between 0 and 3 """
1451
1452 self.len = len
1453 if len:
1454
1455
1456 self.data = [-1] + [0] * (len - 1)
1457 else:
1458
1459 self.data = 0
1460 self.next = self.nextscalar
1461
1464
1466 for i in range(self.len):
1467 if self.data[i] < 3:
1468 self.data[i] += 1
1469 return self.data
1470 else:
1471 self.data[i] = 0
1472 raise StopIteration
1473
1475 if self.data:
1476 raise StopIteration
1477 else:
1478 self.data = True
1479 return [0]
1480
1482
1486
1488 """return the highest rank of the coefficient"""
1489
1490 return max([max(arg[:4]) for arg in self])
1491
1492
1493 if '__main__' ==__name__:
1494
1495 import cProfile
1499
1500 cProfile.run('create()')
1501