diff --git a/study.py b/study.py index 20c1475..52433c5 100644 --- a/study.py +++ b/study.py @@ -10,7 +10,16 @@ def study(): print("You need to pass a configuration file in parameter") exit(1) - config = importlib.import_module(sys.argv[1]) + try: + config = importlib.import_module(sys.argv[1]) + except ModuleNotFoundError as e: + try: + config = importlib.import_module(str(sys.argv[1]).replace(".py", "")) + except ModuleNotFoundError as e: + print(e) + print("You need to pass a configuration file in parameter") + exit(1) + shape = Shape(0, 0, 0, 0, 0, 0) sim = Simulator(shape, config) sim.initLogger()