diff --git "a/MIXTRAL_DatosValidacion.txt" "b/MIXTRAL_DatosValidacion.txt"
new file mode 100644--- /dev/null
+++ "b/MIXTRAL_DatosValidacion.txt"
@@ -0,0 +1,18101 @@
+[INST] Calculate fasta statistics given the path to a fasta/multifasta.. [/INST] def swell_from_fasta(fasta_path):
+ if fasta_path == "-":
+ fastas = readfq.readfq(sys.stdin)
+ else:
+ fastas = readfq.readfq(open(fasta_path))
+ rows = []
+ for name, seq, qual in fastas:
+ rows.append([fasta_path, name] + calculate_fasta_stats(seq))
+ if fasta_path != "-":
+ fastas.close()
+ return ["fasta_path", "header", "num_seqs", "num_bases", "pc_acgt", "pc_masked", "pc_invalid", "pc_ambiguous", "longest_gap", "longest_ungap"], rows
+[INST] Calculate fasta statistics directly from a sequence.. [/INST] def swell_from_fasta_seq(seq, fasta_path="", header=""):
+ rows = [[fasta_path, header] + calculate_fasta_stats(seq)]
+ return ["fasta_path", "header", "num_seqs", "num_bases", "pc_acgt", "pc_masked", "pc_invalid", "pc_ambiguous", "longest_gap", "longest_ungap"], rows
+[INST] Sums all the numbers in the specified iterable. [/INST] def add(numbers):
+ the_sum = 0
+ for number in numbers:
+ the_sum += int(number)
+ return the_sum
+[INST] Divides the 1..Nth numbers from the 0th one. [/INST] def divide(numbers):
+ result = numbers[0]
+ for number in numbers[1:]:
+ result /= number
+ return result
+[INST] Raises the 0th number to the 1..Nth numbers as powers. [/INST] def exponent(numbers):
+ result = numbers[0]
+ for number in numbers[1:]:
+ result *= number
+ return result
+[INST] Takes numbers params and sums them together. [/INST] def add():
+ if not request.args.getlist('numbers'):
+ return "You need to give numbers to add", 400
+ numbers = request.args.get('numbers')
+ return str(calculator.add(numbers))
+[INST] Takes numbers params and subtracts them from the first. [/INST] def subtract():
+ if not request.args.get('numbers'):
+ return "You need to give numbers to subtract", 400
+ numbers = request.args.getlist('numbers')
+ return str(calculator.subtract(numbers))
+[INST] Takes numbers params and multiplies them together. [/INST] def multiply():
+ if not request.args.get('numbers'):
+ return "You need to give numbers to multiply", 400
+ numbers = request.args['numbers']
+ return str(calculator.multiply(numbers))
+[INST] Takes numbers params and divides them.. [/INST] def divide():
+ if not request.args.get('numbers'):
+ return "You need to give numbers to divide", 400
+ return "I'm a divider"
+[INST] Takes numbers params and creates the exponential of them aka x^y.. [/INST] def exponential():
+ if not request.args.get('numbers'):
+ return "You need to give numbers to do x^y", 400
+ numbers = request.args['numbers']
+ return str(calculator.exponent(numbers))
+[INST] Calculates the root of the inputs.
+Takes numbers params and finds the xth root of y where x is the first
+number param and y is the second. [/INST] def root():
+ if not request.args.get('numbers'):
+ return "You need to give numbers to find the root.", 400
+ numbers = request.args['numbers']
+ return str(calculator.root(numbers))
+[INST] Mean between the predictions for the five characters.. [/INST] def accuracy(output, target):
+ accuracy_result = 0
+ for y, t in zip(output, target):
+ _, predicted = torch.max(y.data, 1)
+ correct_predictions = (predicted == t).sum().item()
+ accuracy_result += correct_predictions / t.size(0)
+ return accuracy_result / len(target)
+[INST] Function for obtaining temperature from Yandex.
+Accepts the name of the city (city).
+The temperature is extracted directly from the HTML page.
+Return a list with two values.
+In case of success - value of temperature and URL, otherwise - error and URL.. [/INST] def yandex(city):
+ try:
+ myheader = {
+ 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 ' \
+ '(KHTML, like Gecko) Chrome/61.0.3163.79 Safari/537.36',
+ }
+ url = 'https://yandex.ru/pogoda/{}'.format(city)
+ req = request.Request(url, headers = myheader)
+ answer = BeautifulSoup(request.urlopen(req))
+ temp_value = str(answer.find('span', 'temp__value')).split('>')[1].split('<')[0]
+ if (temp_value[0] == '-') or (ord(temp_value[0]) == 8722):
+ temp = float(temp_value[1:]) * (-1)
+ else:
+ temp = float(temp_value)
+ return [temp, url]
+ except Exception as connection_error:
+ return [connection_error, url]
+[INST] Function for obtaining temperature from Yandex.
+Accepts the name of the city (city)
+The temperature is extracted from the JSON file obtained with OpenWeatherMap API.
+
+Return a list with two values.
+In case of success - value of temperature and URL, otherwise - error and URL.. [/INST] def open_weather_map(city):
+ try:
+ url = 'http://api.openweathermap.org/data/2.5/weather?q={}&appid=c7365fbce4cdaa0eed49c8adb6828336'.format(city)
+ req = requests.get(url)
+ temperature = float(req.json()['main']['temp']) - 273.15
+ return [round(temperature, 1), url]
+ except Exception as connection_error:
+ return [connection_error, url]
+[INST] Auto-update weather function
+The function takes a list of the cities to update.
+
+If the error connecting to sources - an error with
+a status of 500 and JSON with the cause of the error and URL.
+
+If the connection is successful, it enters the
+data into the database and returns an empty response with code 200.. [/INST] def auto_update_function(cities):
+ try:
+ connect = psycopg2.connect(database = 'django_test', user = 'roman',
+ host = 'localhost', password = 'admin')
+ cursor = connect.cursor()
+ cursor.execute(
+ 'SELECT city_name FROM frontend_city;'
+ )
+ utc_timezone = pytz.timezone('UTC')
+ cities_list = []
+ cities_cursor = cursor.fetchall()
+ for i in range(len(cities_cursor)):
+ cities_list.append(cities_cursor[i][0])
+ for i in range(len(cities)):
+ yandex_value = yandex(cities[i])
+ open_weather_value = open_weather_map(cities[i])
+ if type(yandex_value[0]) == error.HTTPError:
+ data = {
+ 'Error': 'Error in auto update function.',
+ 'Time': str(datetime.datetime.now(utc_timezone)),
+ 'Reason': '{}. Please, check url: {}'.format(yandex_value[0], yandex_value[1])
+ }
+ json_data_error = json.dumps(data)
+ response = HttpResponse(json_data_error, status=500, content_type='application/json', charset='utf-8')
+ return response
+ elif (type(open_weather_value[0]) == error.HTTPError):
+ data = {
+ 'Error': 'Error in auto update function.',
+ 'Time': datetime.datetime.now(utc_timezone),
+ 'Reason': '{}. Please, check url: {}'.format(open_weather_value[0], open_weather_value[1])
+ }
+ json_data_error = json.dumps(data)
+ response = HttpResponse(json_data_error, status=500, content_type='application/json', charset='utf-8')
+ return response
+ elif (cities[i] not in cities_list):
+ cursor.execute("INSERT INTO frontend_city (city_name) values ('{}');".format(cities[i]))
+ connect.commit()
+ data = {
+ 'Yandex': str(yandex_value[0]),
+ 'Open weather': str(open_weather_value[0])
+ }
+ cursor.execute("SELECT id FROM frontend_city WHERE city_name = '{}';".format(cities[i]))
+ city_id = cursor.fetchall()
+ city_id = city_id[0][0]
+ json_data = json.dumps(data)
+ cursor.execute(
+ "INSERT INTO frontend_history (city_id, temp_values, created) \
+ VALUES ({},'{}', '{}');".format(city_id, json_data,
+ datetime.datetime.now(utc_timezone)))
+ connect.commit()
+ connect.close()
+ response = HttpResponse(status=200, content_type='text/html', charset='utf-8')
+ return response
+ except Exception as connection_db_error:
+ data = {
+ 'Error': 'Error in auto update function.',
+ 'Time': str(datetime.datetime.now(utc_timezone)),
+ 'Reason': '{}'.format(connection_db_error)
+ }
+ json_data_error = json.dumps(data)
+ response = HttpResponse(json_data_error, status=500, content_type='application/json', charset='utf-8')
+ return response
+ connect.close()
+[INST] A script to retrieve data from the last update.
+First check the availability of the city in the database - if not,
+then the error 404 and JSON with error and reason.
+
+If the city is in the database - sort by the time of the addition and select the last entry.
+Return JSON with the results, code 200.
+
+If the error connecting to the database is a response
+with the code 500 and the cause of the error in JSON.. [/INST] def last_update_temperature(city):
+ try:
+ utc_timezone = pytz.timezone('UTC')
+ connect = psycopg2.connect(database='django_test', user='roman',
+ host='localhost', password='admin')
+ cursor = connect.cursor()
+ cursor.execute("SELECT id FROM frontend_city WHERE city_name = '{}';".format(city))
+ city_id = cursor.fetchall()
+ if len(city_id) == 0:
+ data = {
+ 'Error': 'Error in last update function.',
+ 'Time': str(datetime.datetime.now(utc_timezone)),
+ 'Reason': 'Please, check city name: {}'.format(city)
+ }
+ json_data_error = json.dumps(data)
+ response = HttpResponse(json_data_error, status=404, content_type='application/json', charset='utf-8')
+ return response
+ city_id = city_id[0][0]
+ cursor.execute(
+ "SELECT temp_values, created FROM frontend_history \
+ WHERE city_id = {} ORDER BY created;".format(city_id)
+ )
+ last_update = cursor.fetchall()[-1]
+ temperature = json.loads(last_update[0])
+ data = {
+ 'Time': str(last_update[1].astimezone(utc_timezone)),
+ 'Temperature': temperature,
+ }
+ json_data = json.dumps(data)
+ response = HttpResponse(json_data, status=200, content_type='application/json', charset='utf-8')
+ return response
+ except Exception as connection_db_error:
+ data = {
+ 'Error': 'Error in last update function.',
+ 'Time': str(datetime.datetime.now(utc_timezone)),
+ 'Reason': '{}'.format(connection_db_error)
+ }
+ json_data_error = json.dumps(data)
+ response = HttpResponse(json_data_error, status=500, content_type='application/json', charset='utf-8')
+ return response
+ connect.close()
+[INST] Manual copy of state dict.
+Why . Because when copying a state dict to another with load_state_dict, the values of weight are copied only
+when keys are the same in both state_dict, even if strict=False.. [/INST] def copy_state_dict(state_dict_1, state_dict_2):
+ state1_keys = list(state_dict_1.keys())
+ state2_keys = list(state_dict_2.keys())
+ for x in range(len(state1_keys)):
+ state_dict_2[state2_keys[x]] = state_dict_1[state1_keys[x]]
+ return state_dict_2
+[INST] Building the histogram of all possible orientation bins, given the number of bins per dimension and
+min/max limits on Z, Y and X axis (rotation).. [/INST] def build_histogram(n_bins_per_dim, min_lim, max_lim):
+ d = 3
+ n_bins = n_bins_per_dim ** d
+ bins_per_dim = torch.linspace(0.0, 1.0, n_bins_per_dim)
+ bins_all_dims = torch.cartesian_prod(bins_per_dim, bins_per_dim, bins_per_dim)
+ euler_bins = bins_all_dims * (max_lim - min_lim) + min_lim
+ quaternions_bins = torch.zeros((n_bins, 4), dtype=torch.float32)
+ for i in range(n_bins):
+ quaternions_bins[i, :] = se3lib.euler2quat(euler_bins[i, 0], euler_bins[i, 1], euler_bins[i, 2])
+ boundary_flags = torch.logical_or(euler_bins[:, 0] == max_lim[0], euler_bins[:, 2] == max_lim[2])
+ gymbal_flags = torch.logical_and(np.abs(euler_bins[:, 1]) == max_lim[1], euler_bins[:, 0] != min_lim[0])
+ redundant_flags = torch.logical_or(boundary_flags, gymbal_flags)
+ return quaternions_bins, redundant_flags
+[INST] Decode a batch of orientation (ori) using the pre-computed orientation decode variable (b) based on the histogram. [/INST] def decode_ori_batch(ori, b):
+ ori = ori.cpu()
+ batch_size = ori.size(0)
+ ori_avg = torch.zeros((batch_size, 4), dtype=torch.float32)
+ h_avg = torch.zeros((batch_size, 4, 4), dtype=torch.float32)
+ for i in range(batch_size):
+ ori_avg[i], h_avg[i] = decode_ori(ori[i], b)
+ return ori_avg, h_avg
+[INST] copy and resize Speed images to a new directory. The new (empty) folders must be created before calling
+this function. [/INST] def copy_speed_dataset_resize(old_path, new_path, new_size=(224, 224), split='train'):
+ if split not in {'train', 'test', 'real_test'}:
+ raise ValueError('Invalid split, has to be either \'train\', \'test\' or \'real_test\'')
+ with open(os.path.join(old_path, split + '.json'), 'r') as f:
+ target_list = json.load(f)
+ sample_ids = [label['filename'] for label in target_list]
+ image_root = os.path.join(old_path, 'images', split)
+ image_resize_root = os.path.join(new_path, 'images', split)
+ for i in range(len(sample_ids)):
+ img_name = os.path.join(image_root, sample_ids[i])
+ pil_img = Image.open(img_name).resize(new_size)
+ new_name = os.path.join(image_resize_root, sample_ids[i])
+ pil_img.save(new_name)
+[INST] Import the dataset. May take some seconds as we pre-compute the histogram to save time later. [/INST] def import_dataset(self):
+ print('Import dataset...')
+ if self.config.DATASET == 'SPEED':
+ dataloader = prepare_speed_dataset(self.config)
+ else:
+ raise ValueError('Dataset must be \'SPEED\' (URSO dataset not implemented)')
+ return dataloader
+[INST] Evaluation on test set for submission on ESA website. [/INST] def evaluate_submit(self, sub):
+ for phase in ['test', 'real_test']:
+ loop = tqdm(self.dataloader[phase], desc="Evaluation for submission",
+ bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', file=sys.stdout)
+ for inputs, filenames in loop:
+ inputs = inputs.to(self.config.DEVICE)
+ with torch.cuda.amp.autocast(enabled=self.config.USE_AMP):
+ ori, pos = self.predict(inputs, phase)
+ if self.config.ORI_TYPE == 'Classification':
+ ori, _ = decode_ori_batch(ori, self.config.B)
+ append = sub.append_real_test if phase == 'real_test' else sub.append_test
+ for filename, q, r in zip(filenames, ori.cpu().numpy(), pos.cpu().numpy()):
+ append(filename, q, r)
+[INST] Evaluation on validation set. Distance with the target spacecraft is also returned for each prediction. [/INST] def eval_error_distance(self):
+ phase = 'valid'
+ loop = tqdm(self.dataloader[phase], desc="Evaluation by distance", file=sys.stdout,
+ bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
+ ori_error = []
+ pos_error = []
+ distance = []
+ for inputs, targets in loop:
+ inputs = inputs.to(self.config.DEVICE)
+ targets['ori'], targets['pos'] = targets['ori'].to(self.config.DEVICE), \
+ targets['pos'].to(self.config.DEVICE)
+ with torch.set_grad_enabled(False):
+ ori, pos = self.predict(inputs, phase)
+ if self.config.ORI_TYPE == 'Classification':
+ ori, _ = decode_ori_batch(ori, self.config.B)
+ dist = torch.linalg.norm(targets['pos'], dim=1)
+ pos_err = torch.linalg.norm(targets['pos'] - pos, dim=1)
+ if self.config.ORI_TYPE == 'Regression':
+ inter_sum = torch.abs(torch.sum(ori * targets['ori'], dim=1))
+ else:
+ inter_sum = torch.abs(torch.sum(ori * targets['ori_original'], dim=1))
+ if True in inter_sum[inter_sum > 1.01]:
+ raise ValueError('Error computing orientation score intermediate sum')
+ inter_sum[inter_sum > 1] = 1
+ ori_err = 2 * torch.arccos(inter_sum) * 180 / np.pi
+ distance.extend(dist.tolist())
+ ori_error.extend(ori_err.tolist())
+ pos_error.extend(pos_err.tolist())
+ return ori_error, pos_error, distance
+[INST] This is an objective function for hyperparameter tuning with Optuna. [/INST] def objective(self, trial):
+ self.hparam_step += 1
+ self.config.WEIGHT_DECAY = trial.suggest_float("WEIGHT_DECAY", 0, 1e-2, step=1e-5)
+ self.model = self.import_model()
+ self.dataloader = self.import_dataset()
+ self.ori_criterion, self.pos_criterion = self.set_loss()
+ self.optimizer = self.set_optimizer()
+ self.scheduler = self.set_scheduler()
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.config.USE_AMP)
+ model, loss, score = self.train()
+ idx = score['valid']['esa'].index(min(score['valid']['esa']))
+ trial.report(score['real']['esa'][idx], step=0)
+ score_to_optimize = score['valid']['esa'][idx]
+ return score_to_optimize
+[INST] Page load definition for thermostat statistics.. [/INST] def smap_plot_thermostat(request, mac):
+ print "inside smap view method"
+ context = RequestContext(request)
+ if request.method == 'GET':
+ mac = '18b4302964f1'
+ device_metadata = [ob.device_control_page_info() for ob in DeviceMetadata.objects.filter(mac_address=mac)]
+ print device_metadata
+ device_id = device_metadata[0]['device_id']
+ device_type_id = device_metadata[0]['device_model_id']
+ mac_address = device_metadata[0]['mac_address']
+ device_type_id = device_type_id.device_model_id
+ print device_type_id
+ device_status = [ob.data_as_json() for ob in Thermostat.objects.filter(thermostat_id=device_id)]
+ device_zone = device_status[0]['zone']['id']
+ device_nickname = device_status[0]['nickname']
+ zone_nickname = device_status[0]['zone']['zone_nickname']
+ device_info = str(device_zone) + '/thermostat/' + device_id
+ device_info = device_info.encode('ascii', 'ignore')
+ device_smap_tag = '/bemoss/' + str(device_zone) + '/thermostat/' + device_id
+ device_smap_tag = device_smap_tag.encode('ascii', 'ignore')
+ temperature = device_smap_tag + '/temperature'
+ heat_setpoint = device_smap_tag + '/heat_setpoint'
+ cool_setpoint = device_smap_tag + '/cool_setpoint'
+ print temperature
+ _uuid_temperature = get_uuid_for_data_point(temperature)
+ _uuid_heat_setpoint = get_uuid_for_data_point(heat_setpoint)
+ _uuid_cool_setpoint = get_uuid_for_data_point(cool_setpoint)
+ zones = [ob.as_json() for ob in Building_Zone.objects.all()]
+ thermostats_sn = [ob.data_side_nav() for ob in Thermostat.objects.filter(network_status='ONLINE', thermostat_id__bemoss=True)]
+ vav_sn = [ob.data_side_nav() for ob in VAV.objects.filter(network_status='ONLINE', vav_id__bemoss=True)]
+ rtu_sn = [ob.data_side_nav() for ob in RTU.objects.filter(network_status='ONLINE', rtu_id__bemoss=True)]
+ lighting_sn = [ob.data_side_nav() for ob in Lighting.objects.filter(network_status='ONLINE', lighting_id__bemoss=True)]
+ plugload_sn = [ob.data_side_nav() for ob in Plugload.objects.filter(network_status='ONLINE', plugload_id__bemoss=True)]
+ occ_sensors_sn = [ob.data_side_nav() for ob in OccupancySensor.objects.filter(network_status='ONLINE', occupancy_sensor_id__bemoss=True)]
+ lt_sensors_sn = [ob.data_side_nav() for ob in AmbientLightSensor.objects.filter(network_status='ONLINE', ambient_light_sensor_id__bemoss=True)]
+ mtn_sensors_sn = [ob.data_side_nav() for ob in MotionSensor.objects.filter(network_status='ONLINE', motion_sensor_id__bemoss=True)]
+ powermeters_sn = [ob.data_side_nav() for ob in PowerMeter.objects.filter(network_status='ONLINE', power_meter_id__bemoss=True)]
+ active_al = get_notifications()
+ context.update({'active_al':active_al})
+ context.update({
+ 'zones': zones, 'thermostat_sn': thermostats_sn,
+ 'lighting_sn': lighting_sn, 'plugload_sn': plugload_sn, 'occ_sensors_sn': occ_sensors_sn,
+ 'lt_sensors_sn': lt_sensors_sn, 'mtn_sensors_sn': mtn_sensors_sn, 'powermeters_sn': powermeters_sn,
+ 'vav_sn': vav_sn, 'rtu_sn': rtu_sn
+ })
+ rs_temperature = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_temperature)
+ rs_heat_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_heat_setpoint)
+ rs_cool_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_cool_setpoint)
+ return render_to_response(
+ 'energygame/4_players.html',
+ {'temperature': rs_temperature, 'heat_setpoint': rs_heat_setpoint, 'cool_setpoint': rs_cool_setpoint,
+ 'device_info': device_info, 'mac': mac_address,
+ 'nickname': device_nickname,
+ 'zone_nickname': zone_nickname},
+ context)
+[INST] Page load definition for VAV statistics.. [/INST] def smap_plot_vav(request, mac):
+ print "inside smap view method"
+ context = RequestContext(request)
+ if request.method == 'GET':
+ device_metadata = [ob.device_control_page_info() for ob in DeviceMetadata.objects.filter(mac_address=mac)]
+ print device_metadata
+ device_id = device_metadata[0]['device_id']
+ device_type_id = device_metadata[0]['device_model_id']
+ device_type = device_metadata[0]['device_type']
+ device_type_id = device_type_id.device_model_id
+ print device_type_id
+ device_status = [ob.as_json() for ob in VAV.objects.filter(vav_id=device_id)]
+ device_zone = device_status[0]['zone']['id']
+ device_nickname = device_status[0]['nickname']
+ zone_nickname = device_status[0]['zone']['zone_nickname']
+ device_info = str(device_zone) + '/' + device_type + '/' + device_id
+ device_info = device_info.encode('ascii', 'ignore')
+ device_smap_tag = '/bemoss/' + str(device_zone) + '/' + device_type + '/' + device_id
+ device_smap_tag = device_smap_tag.encode('ascii', 'ignore')
+ temperature = device_smap_tag + '/temperature'
+ supply_temperature = device_smap_tag + '/supply_temperature'
+ heat_setpoint = device_smap_tag + '/heat_setpoint'
+ cool_setpoint = device_smap_tag + '/cool_setpoint'
+ flap_position = device_smap_tag + '/flap_position'
+ print temperature
+ _uuid_temperature = get_uuid_for_data_point(temperature)
+ _uuid_supply_temperature = get_uuid_for_data_point(supply_temperature)
+ _uuid_heat_setpoint = get_uuid_for_data_point(heat_setpoint)
+ _uuid_cool_setpoint = get_uuid_for_data_point(cool_setpoint)
+ _uuid_flap_position = get_uuid_for_data_point(flap_position)
+ zones = [ob.as_json() for ob in Building_Zone.objects.all()]
+ thermostats_sn = [ob.data_side_nav() for ob in Thermostat.objects.filter(network_status='ONLINE', thermostat_id__bemoss=True)]
+ vav_sn = [ob.data_side_nav() for ob in VAV.objects.filter(network_status='ONLINE', vav_id__bemoss=True)]
+ rtu_sn = [ob.data_side_nav() for ob in RTU.objects.filter(network_status='ONLINE', rtu_id__bemoss=True)]
+ lighting_sn = [ob.data_side_nav() for ob in Lighting.objects.filter(network_status='ONLINE', lighting_id__bemoss=True)]
+ plugload_sn = [ob.data_side_nav() for ob in Plugload.objects.filter(network_status='ONLINE', plugload_id__bemoss=True)]
+ occ_sensors_sn = [ob.data_side_nav() for ob in OccupancySensor.objects.filter(network_status='ONLINE', occupancy_sensor_id__bemoss=True)]
+ lt_sensors_sn = [ob.data_side_nav() for ob in AmbientLightSensor.objects.filter(network_status='ONLINE', ambient_light_sensor_id__bemoss=True)]
+ mtn_sensors_sn = [ob.data_side_nav() for ob in MotionSensor.objects.filter(network_status='ONLINE', motion_sensor_id__bemoss=True)]
+ powermeters_sn = [ob.data_side_nav() for ob in PowerMeter.objects.filter(network_status='ONLINE', power_meter_id__bemoss=True)]
+ active_al = get_notifications()
+ context.update({'active_al':active_al})
+ context.update({
+ 'zones': zones, 'thermostat_sn': thermostats_sn,
+ 'lighting_sn': lighting_sn, 'plugload_sn': plugload_sn, 'occ_sensors_sn': occ_sensors_sn,
+ 'lt_sensors_sn': lt_sensors_sn, 'mtn_sensors_sn': mtn_sensors_sn, 'powermeters_sn': powermeters_sn,
+ 'vav_sn': vav_sn, 'rtu_sn': rtu_sn
+ })
+ rs_temperature = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_temperature)
+ rs_supply_temperature = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_supply_temperature)
+ rs_heat_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_heat_setpoint)
+ rs_cool_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_cool_setpoint)
+ rs_flap_position = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_flap_position)
+ return render_to_response(
+ 'statistics/statistics_vav.html',
+ {'temperature': rs_temperature, 'supply_temperature': rs_supply_temperature,
+ 'flap_position': rs_flap_position, 'heat_setpoint': rs_heat_setpoint, 'cool_setpoint': rs_cool_setpoint,
+ 'zones': zones, 'mac': mac, 'device_info': device_info,
+ 'nickname': device_nickname,
+ 'zone_nickname': zone_nickname},
+ context)
+[INST] Page load definition for RTU statistics.. [/INST] def smap_plot_rtu(request, mac):
+ print "inside smap view method"
+ context = RequestContext(request)
+ if request.method == 'GET':
+ device_metadata = [ob.device_control_page_info() for ob in DeviceMetadata.objects.filter(mac_address=mac)]
+ print device_metadata
+ device_id = device_metadata[0]['device_id']
+ device_type_id = device_metadata[0]['device_model_id']
+ device_type = device_metadata[0]['device_type']
+ device_type_id = device_type_id.device_model_id
+ print device_type_id
+ device_status = [ob.as_json() for ob in RTU.objects.filter(rtu_id=device_id)]
+ device_zone = device_status[0]['zone']['id']
+ device_nickname = device_status[0]['nickname']
+ zone_nickname = device_status[0]['zone']['zone_nickname']
+ device_info = str(device_zone) + '/' + device_type + '/' + device_id
+ device_info = device_info.encode('ascii', 'ignore')
+ device_smap_tag = '/bemoss/' + str(device_zone) + '/' + device_type + '/' + device_id
+ device_smap_tag = device_smap_tag.encode('ascii', 'ignore')
+ outside_temperature = device_smap_tag + '/outside_temperature'
+ return_temperature = device_smap_tag + '/return_temperature'
+ supply_temperature = device_smap_tag + '/supply_temperature'
+ heat_setpoint = device_smap_tag + '/heat_setpoint'
+ cool_setpoint = device_smap_tag + '/cool_setpoint'
+ cooling_mode = device_smap_tag + '/cooling_mode'
+ heating = device_smap_tag + '/heating'
+ outside_damper_position = device_smap_tag + '/outside_damper_position'
+ bypass_damper_position = device_smap_tag + '/bypass_damper_position'
+ _uuid_outside_temperature = get_uuid_for_data_point(outside_temperature)
+ _uuid_return_temperature = get_uuid_for_data_point(return_temperature)
+ _uuid_supply_temperature = get_uuid_for_data_point(supply_temperature)
+ _uuid_heat_setpoint = get_uuid_for_data_point(heat_setpoint)
+ _uuid_cool_setpoint = get_uuid_for_data_point(cool_setpoint)
+ _uuid_cooling_mode = get_uuid_for_data_point(cooling_mode)
+ _uuid_heating = get_uuid_for_data_point(heating)
+ _uuid_outside_damper_position = get_uuid_for_data_point(outside_damper_position)
+ _uuid_bypass_damper_position = get_uuid_for_data_point(bypass_damper_position)
+ zones = [ob.as_json() for ob in Building_Zone.objects.all()]
+ thermostats_sn = [ob.data_side_nav() for ob in Thermostat.objects.filter(network_status='ONLINE', thermostat_id__bemoss=True)]
+ vav_sn = [ob.data_side_nav() for ob in VAV.objects.filter(network_status='ONLINE', vav_id__bemoss=True)]
+ rtu_sn = [ob.data_side_nav() for ob in RTU.objects.filter(network_status='ONLINE', rtu_id__bemoss=True)]
+ lighting_sn = [ob.data_side_nav() for ob in Lighting.objects.filter(network_status='ONLINE', lighting_id__bemoss=True)]
+ plugload_sn = [ob.data_side_nav() for ob in Plugload.objects.filter(network_status='ONLINE', plugload_id__bemoss=True)]
+ occ_sensors_sn = [ob.data_side_nav() for ob in OccupancySensor.objects.filter(network_status='ONLINE', occupancy_sensor_id__bemoss=True)]
+ lt_sensors_sn = [ob.data_side_nav() for ob in AmbientLightSensor.objects.filter(network_status='ONLINE', ambient_light_sensor_id__bemoss=True)]
+ mtn_sensors_sn = [ob.data_side_nav() for ob in MotionSensor.objects.filter(network_status='ONLINE', motion_sensor_id__bemoss=True)]
+ powermeters_sn = [ob.data_side_nav() for ob in PowerMeter.objects.filter(network_status='ONLINE', power_meter_id__bemoss=True)]
+ active_al = get_notifications()
+ context.update({'active_al':active_al})
+ context.update({
+ 'zones': zones, 'thermostat_sn': thermostats_sn,
+ 'lighting_sn': lighting_sn, 'plugload_sn': plugload_sn, 'occ_sensors_sn': occ_sensors_sn,
+ 'lt_sensors_sn': lt_sensors_sn, 'mtn_sensors_sn': mtn_sensors_sn, 'powermeters_sn': powermeters_sn,
+ 'vav_sn': vav_sn, 'rtu_sn': rtu_sn
+ })
+ rs_outside_temperature = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_outside_temperature)
+ rs_return_temperature = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_return_temperature)
+ rs_supply_temperature = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_supply_temperature)
+ rs_heat_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_heat_setpoint)
+ rs_cool_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_cool_setpoint)
+ rs_cooling_mode = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_cooling_mode)
+ rs_heating = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_heating)
+ rs_outside_damper_position = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_outside_damper_position)
+ rs_bypass_damper_position = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_bypass_damper_position)
+ return render_to_response(
+ 'statistics/statistics_rtu.html',
+ {'outside_temperature': rs_outside_temperature, 'supply_temperature': rs_supply_temperature,
+ 'return_temperature': rs_return_temperature, 'heating': rs_heating,
+ 'outside_damper_position': rs_outside_damper_position,
+ 'bypass_damper_position': rs_bypass_damper_position, 'cooling_mode': rs_cooling_mode,
+ 'heat_setpoint': rs_heat_setpoint, 'cool_setpoint': rs_cool_setpoint,
+ 'zones': zones, 'mac': mac, 'device_info': device_info,
+ 'nickname': device_nickname,
+ 'zone_nickname': zone_nickname},
+ context)
+[INST] Page load definition for thermostat statistics.. [/INST] def smap_plot_thermostat(request, mac):
+ print "inside smap view method"
+ context = RequestContext(request)
+ if request.method == 'GET':
+ device_metadata = [ob.device_control_page_info() for ob in DeviceMetadata.objects.filter(mac_address=mac)]
+ print device_metadata
+ device_id = device_metadata[0]['device_id']
+ device_type_id = device_metadata[0]['device_model_id']
+ mac_address = device_metadata[0]['mac_address']
+ device_type_id = device_type_id.device_model_id
+ print device_type_id
+ device_status = [ob.data_as_json() for ob in Thermostat.objects.filter(thermostat_id=device_id)]
+ device_zone = device_status[0]['zone']['id']
+ device_nickname = device_status[0]['nickname']
+ zone_nickname = device_status[0]['zone']['zone_nickname']
+ device_info = str(device_zone) + '/thermostat/' + device_id
+ device_info = device_info.encode('ascii', 'ignore')
+ device_smap_tag = '/bemoss/' + str(device_zone) + '/thermostat/' + device_id
+ device_smap_tag = device_smap_tag.encode('ascii', 'ignore')
+ temperature = device_smap_tag + '/temperature'
+ heat_setpoint = device_smap_tag + '/heat_setpoint'
+ cool_setpoint = device_smap_tag + '/cool_setpoint'
+ print temperature
+ _uuid_temperature = get_uuid_for_data_point(temperature)
+ _uuid_heat_setpoint = get_uuid_for_data_point(heat_setpoint)
+ _uuid_cool_setpoint = get_uuid_for_data_point(cool_setpoint)
+ zones = [ob.as_json() for ob in Building_Zone.objects.all()]
+ thermostats_sn = [ob.data_side_nav() for ob in Thermostat.objects.filter(network_status='ONLINE', thermostat_id__bemoss=True)]
+ vav_sn = [ob.data_side_nav() for ob in VAV.objects.filter(network_status='ONLINE', vav_id__bemoss=True)]
+ rtu_sn = [ob.data_side_nav() for ob in RTU.objects.filter(network_status='ONLINE', rtu_id__bemoss=True)]
+ lighting_sn = [ob.data_side_nav() for ob in Lighting.objects.filter(network_status='ONLINE', lighting_id__bemoss=True)]
+ plugload_sn = [ob.data_side_nav() for ob in Plugload.objects.filter(network_status='ONLINE', plugload_id__bemoss=True)]
+ occ_sensors_sn = [ob.data_side_nav() for ob in OccupancySensor.objects.filter(network_status='ONLINE', occupancy_sensor_id__bemoss=True)]
+ lt_sensors_sn = [ob.data_side_nav() for ob in AmbientLightSensor.objects.filter(network_status='ONLINE', ambient_light_sensor_id__bemoss=True)]
+ mtn_sensors_sn = [ob.data_side_nav() for ob in MotionSensor.objects.filter(network_status='ONLINE', motion_sensor_id__bemoss=True)]
+ powermeters_sn = [ob.data_side_nav() for ob in PowerMeter.objects.filter(network_status='ONLINE', power_meter_id__bemoss=True)]
+ active_al = get_notifications()
+ context.update({'active_al':active_al})
+ context.update({
+ 'zones': zones, 'thermostat_sn': thermostats_sn,
+ 'lighting_sn': lighting_sn, 'plugload_sn': plugload_sn, 'occ_sensors_sn': occ_sensors_sn,
+ 'lt_sensors_sn': lt_sensors_sn, 'mtn_sensors_sn': mtn_sensors_sn, 'powermeters_sn': powermeters_sn,
+ 'vav_sn': vav_sn, 'rtu_sn': rtu_sn
+ })
+ rs_temperature = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_temperature)
+ rs_heat_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_heat_setpoint)
+ rs_cool_setpoint = get_data_from_smap("http://localhost/backend/api/data/uuid/" + _uuid_cool_setpoint)
+ return render_to_response(
+ 'statistics/statistics_thermostat.html',
+ {'temperature': rs_temperature, 'heat_setpoint': rs_heat_setpoint, 'cool_setpoint': rs_cool_setpoint,
+ 'device_info': device_info, 'mac': mac_address,
+ 'nickname': device_nickname,
+ 'zone_nickname': zone_nickname},
+ context)
+[INST] Handle message and send to browser.. [/INST] def on_match_device_status_update(self, topic, headers, message, match):
+ print os.path.basename(__file__)+"@on_match_device_status_update"
+ print "message:"+str(message)
+ device_info = topic.split('/')
+ device_id = device_info[6]
+ device_type = device_info[5]
+ page_load_helper.page_load(device_id, device_type, message[0])
+ _log.debug("Topic: {topic}, Headers: {headers}, "
+ "Message: {message}".format(
+ topic=topic, headers=headers, message=message))
+ topic = topic.split('/')
+ topic_to_tcp = '/ui/web/'+topic[5]+'/'+topic[4]+'/'+topic[6]+'/'+topic[7]+'/'+topic[8]
+ print topic_to_tcp
+ self.publish(str(topic_to_tcp), headers, str(message))
+[INST] Handle message and send to browser.. [/INST] def on_match_device_status_update_rtu(self, topic, headers, message, match):
+ print os.path.basename(__file__)+"@on_match_device_status_update"
+ print "message:"+str(message)
+ device_info = topic.split('/')
+ device_id = device_info[6]
+ device_type = device_info[5]
+ page_load_helper.page_load(device_id, device_type, message[0])
+ _log.debug("Topic: {topic}, Headers: {headers}, "
+ "Message: {message}".format(
+ topic=topic, headers=headers, message=message))
+ topic = topic.split('/')
+ topic_to_tcp = '/ui/web/'+topic[5]+'/'+topic[4]+'/'+topic[6]+'/'+topic[7]+'/'+topic[8]
+ print topic_to_tcp
+ self.publish(str(topic_to_tcp), headers, str(message))
+[INST] Handle message and send to browser.. [/INST] def on_match_device_status_update_vav(self, topic, headers, message, match):
+ print os.path.basename(__file__)+"@on_match_device_status_update"
+ print "message:"+str(message)
+ device_info = topic.split('/')
+ device_id = device_info[6]
+ device_type = device_info[5]
+ page_load_helper.page_load(device_id, device_type, message[0])
+ _log.debug("Topic: {topic}, Headers: {headers}, "
+ "Message: {message}".format(
+ topic=topic, headers=headers, message=message))
+ topic = topic.split('/')
+ topic_to_tcp = '/ui/web/'+topic[5]+'/'+topic[4]+'/'+topic[6]+'/'+topic[7]+'/'+topic[8]
+ print topic_to_tcp
+ self.publish(str(topic_to_tcp), headers, str(message))
+[INST] Main method called by the eggsecutable.. [/INST] def main(argv=sys.argv):
+ try:
+ utils.default_main(ListenerAgent,
+ description='Example VOLTTRON heartbeat agent',
+ argv=argv)
+ except Exception as e:
+ _log.exception('unhandled exception')
+[INST] The main function which creates the pipeline and runs it.. [/INST] def run(argv=None):
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input',
+ dest='input',
+ required=False,
+ help='Input file to read. This can be a local file or '
+ 'a file in a Google Storage Bucket.',
+ default='gs://jk-bq-dev-datasets/test.tfrecords')
+ parser.add_argument('--output',
+ dest='output',
+ required=False,
+ help='Output BQ table to write results to.',
+ default='sandbox.person')
+ known_args, pipeline_args = parser.parse_known_args(argv)
+ table_spec = known_args.output
+ with beam.Pipeline(options=PipelineOptions(pipeline_args)) as p:
+ (p
+ | 'Read from a TFRecords' >> beam.io.ReadFromTFRecord(known_args.input)
+ | 'Transform to Dict' >> beam.ParDo(ProtoToDict())
+ | 'Write to BigQuery' >> beam.io.WriteToBigQuery(
+ table_spec,
+ write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
+ create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER))
+[INST] Try to find matches between registered filter patterns with a message.. [/INST] def match_filter_patterns(self, content: str, guild_id: int) -> Optional[typing.Tuple[re.Match, str]]:
+ for pattern, pattern_identifier in self._filter_cache[guild_id].items():
+ if search := pattern.search(content):
+ return search, pattern_identifier
+[INST] Notify moderators when a message filter is triggered.. [/INST] async def notify_mods(
+ self,
+ event: GuildMessageCreateEvent,
+ matching_content: Optional[str],
+ type_of_filter: str,
+ footer: Optional[str] = None
+ ) -> None:
+ message_channel = self.bot.cache.get_guild_channel(event.message.channel_id)
+ user = event.message.author
+ channel = self.bot.cache.get_guild_channel(Channels.mod_alerts)
+ description = (
+ f"**Sent by {user.mention} in {message_channel.mention}**\n"
+ f"{'**Filter matching content**:' if matching_content else ''}{matching_content}\n"
+ f"**Full message:**\n{event.message.content}"
+ )
+ embed = hikari.Embed(
+ title=f"**Filter triggered:** {type_of_filter}",
+ description=description,
+ colour=hikari.Color.of((255, 255, 255))
+ )
+ mods_role = self.bot.cache.get_role(Roles.mods)
+ embed.set_footer(text=footer)
+ await channel.send(content=mods_role.mention, embed=embed)
+ await event.message.delete()
+[INST] Return True for server moderator.. [/INST] async def plugin_check(self, ctx: Union[lightbulb.Context, GuildMessageCreateEvent]):
+ if Roles.mods in [role.id for role in ctx.member.get_roles()]:
+ return True
+ return False
+[INST] Connect to database and create initial tables.. [/INST] async def init_db():
+ conn: asyncpg.Connection = await asyncpg.connect(Client.database_url)
+ tables = Path("postgres", "tables")
+ logger.info("Creating tables.")
+ for table_file in tables.iterdir():
+ await conn.execute(table_file.read_text())
+ return conn
+[INST] Load extensions when bot is starting.. [/INST] async def on_starting(self, _event: hikari.StartingEvent) -> None:
+ logging.info("Connecting to database...")
+ self.db_conn = await init_db()
+ logging.info("Loading extensions...")
+ for ext in Client.extensions:
+ ext = str(ext).replace(os.sep, ".")[:-3]
+ self.load_extension(ext)
+ logging.info(f"{ext}: Loaded")
+ logging.info("Done loading extensions.")
+[INST] Increment the value assigned to a slice. [/INST] def inc(self, start: typing.Optional[int], stop: typing.Optional[int], value: V) -> None:
+ if start is None:
+ if stop is None:
+ self.value_offset += value
+ else:
+ self.dual[stop - 1] += value
+ elif stop is None:
+ self.dual[start - 1] -= value
+ self.value_offset += value
+ elif start < stop:
+ self.dual[start - 1] -= value
+ self.dual[stop - 1] += value
+ elif start > stop:
+ raise ValueError("start > stop")
+[INST] Returns an object that allows assigning values to indices and aggregating them by slices. [/INST] def ixs_by_slices(*, zero_factory: ZF = None, zero_test: ZT = None) -> by_slices.Aggregator[V]:
+ return by_slices.UnboundedAggregator(
+ negative=by_slices.VariableSizeLeftBoundedAggregator(zero_factory=zero_factory,
+ zero_test=zero_test),
+ nonnegative=by_slices.VariableSizeLeftBoundedAggregator(zero_factory=zero_factory,
+ zero_test=zero_test),
+ )
+[INST] Returns an object that allows assigning values to slices and aggregating them by indices. [/INST] def slices_by_ixs(*, zero_factory: ZF = None, zero_test: ZT = None) -> by_ixs.Aggregator[V]:
+ return by_ixs.Aggregator(
+ dual=ixs_by_slices(zero_factory=zero_factory, zero_test=zero_test),
+ zero_factory=zero_factory,
+ )
+[INST] Extract connection info and result variable from SQL
+Please don't add any more syntax requiring
+special parsing.
+Instead, add @arguments to SqlMagic.execute.
+
+We're grandfathering the
+connection string and `<<` operator in.. [/INST] def parse(cell, config):
+ result = {"connection": "", "sql": "", "result_var": None}
+ pieces = cell.split(None, 3)
+ if not pieces:
+ return result
+ result["connection"] = _connection_string(pieces[0], config)
+ if result["connection"]:
+ pieces.pop(0)
+ if len(pieces) > 1 and pieces[1] == "<<":
+ result["result_var"] = pieces.pop(0)
+ pieces.pop(0)
+ result["sql"] = (" ".join(pieces)).strip()
+ return result
+[INST] - comment from a line
+The argparser unfortunately expects -- to precede an option,
+but in SQL that delineates a comment. So this removes comments
+so a line can safely be fed to the argparser.. [/INST] def without_sql_comment(parser, line):
+ args = _option_strings_from_parser(parser)
+ result = itertools.takewhile(
+ lambda word: (not word.startswith("--")) or (word in args),
+ shlex.split(line, posix=False),
+ )
+ return " ".join(result)
+[INST] Implements PERSIST, which writes a DataFrame to the RDBMS. [/INST] def _persist_dataframe(self, raw, conn, user_ns, append=False):
+ if not DataFrame:
+ raise ImportError("Must `pip install pandas` to use DataFrames")
+ frame_name = raw.strip(";")
+ if not frame_name:
+ raise SyntaxError("Syntax: %sql --persist ")
+ try:
+ frame = eval(frame_name, user_ns)
+ except SyntaxError:
+ raise SyntaxError("Syntax: %sql --persist ")
+ if not isinstance(frame, DataFrame) and not isinstance(frame, Series):
+ raise TypeError("%s is not a Pandas DataFrame or Series" % frame_name)
+ table_name = frame_name.lower()
+ table_name = self.legal_sql_identifier.search(table_name).group(0)
+ if_exists = "append" if append else "fail"
+ frame.to_sql(table_name, conn.session.engine, if_exists=if_exists)
+ return "Persisted %s" % table_name
+[INST] Provides an IPython session in which tables have been created. [/INST] def ip():
+ ip_session = get_ipython()
+ runsql(
+ ip_session,
+ [
+ "CREATE TABLE test (n INT, name TEXT)",
+ "INSERT INTO test VALUES (1, 'foo')",
+ "INSERT INTO test VALUES (2, 'bar')",
+ "CREATE TABLE author (first_name, last_name, year_of_death)",
+ "INSERT INTO author VALUES ('William', 'Shakespeare', 1616)",
+ "INSERT INTO author VALUES ('Bertold', 'Brecht', 1956)",
+ ],
+ )
+ yield ip_session
+ runsql(ip_session, "DROP TABLE test")
+ runsql(ip_session, "DROP TABLE author")
+[INST] Capture the next fame, if any. [/INST] def enter_frame(self):
+ assert not self._is_entered_frame, \
+ 'previous enter_frame() had no matching exit_frame()'
+ if self._capture is not None:
+ self._is_entered_frame = self._capture.grab()
+[INST] Authenticate to ownCloud.
+This will create a session on the server.. [/INST] def login(self, user_id, password):
+ self.__session = requests.session()
+ self.__session.verify = self.__verify_certs
+ self.__session.auth = (user_id, password)
+ res = self.__session.get(self.url)
+ if res.status_code == 200:
+ if self.__single_session:
+ self.__session.auth = None
+ return
+ self.__session.close()
+ self.__session = None
+ raise ResponseError(res)
+[INST] Log out the authenticated user and close the session.. [/INST] def logout(self):
+ self.__session.close()
+ return True
+[INST] Returns the file info for the given remote file. [/INST] def file_info(self, path):
+ res = self.__make_dav_request('PROPFIND', path)
+ if res:
+ return res[0]
+ return None
+[INST] Returns the listing/contents of the given remote directory. [/INST] def list(self, path):
+ if not path[-1] == '/':
+ path += '/'
+ res = self.__make_dav_request('PROPFIND', path)
+ if res:
+ return res[1:]
+ return None
+[INST] Write data into a remote file. [/INST] def put_file_contents(self, remote_path, data):
+ return self.__make_dav_request('PUT', remote_path, data = data)
+[INST] Upload a directory with all its contents. [/INST] def put_directory(self, target_path, local_directory, **kwargs):
+ target_path = self.__normalize_path(target_path)
+ if not target_path[-1] == '/':
+ target_path += '/'
+ gathered_files = []
+ if not local_directory[-1] == '/':
+ local_directory += '/'
+ basedir = os.path.basename(local_directory[0: -1]) + '/'
+ for path, _, files in os.walk(local_directory):
+ gathered_files.append(
+ (path, basedir + path[len(local_directory):], files)
+ )
+ for path, remote_path, files in gathered_files:
+ self.mkdir(target_path + remote_path + '/')
+ for name in files:
+ if not self.put_file(
+ target_path + remote_path + '/',
+ path + '/' + name,
+ **kwargs
+ ):
+ return False
+ return True
+[INST] Uploads a file using chunks. If the file is smaller than
+``chunk_size`` it will be uploaded directly.. [/INST] def __put_file_chunked(self, remote_path, local_source_file, **kwargs):
+ chunk_size = kwargs.get('chunk_size', 10 * 1024 * 1024)
+ result = True
+ transfer_id = int(time.time())
+ remote_path = self.__normalize_path(remote_path)
+ if remote_path[-1] == '/':
+ remote_path += os.path.basename(local_source_file)
+ stat_result = os.stat(local_source_file)
+ file_handle = open(local_source_file, 'r', 8192)
+ file_handle.seek(0, os.SEEK_END)
+ size = file_handle.tell()
+ file_handle.seek(0)
+ headers = {}
+ if kwargs.get('keep_mtime', True):
+ headers['X-OC-MTIME'] = stat_result.st_mtime
+ if size == 0:
+ return self.__make_dav_request(
+ 'PUT',
+ remote_path,
+ data = '',
+ headers = headers
+ )
+ chunk_count = size / chunk_size
+ if chunk_count > 1:
+ headers['OC-CHUNKED'] = 1
+ if size % chunk_size > 0:
+ chunk_count += 1
+ for chunk_index in range(0, chunk_count):
+ data = file_handle.read(chunk_size)
+ if chunk_count > 1:
+ chunk_name = '%s-chunking-%s-%i-%i' % \
+ (remote_path, transfer_id, chunk_count, chunk_index)
+ else:
+ chunk_name = remote_path
+ if not self.__make_dav_request(
+ 'PUT',
+ chunk_name,
+ data = data,
+ headers = headers
+ ):
+ result = False
+ break
+ file_handle.close()
+ return result
+[INST] Deletes a remote file or directory. [/INST] def delete(self, path):
+ return self.__make_dav_request('DELETE', path)
+[INST] Shares a remote file with link. [/INST] def share_file_with_link(self, path):
+ path = self.__normalize_path(path)
+ post_data = {'shareType': self.OCS_SHARE_TYPE_LINK, 'path': path}
+ res = self.__make_ocs_request(
+ 'POST',
+ self.OCS_SERVICE_SHARE,
+ 'shares',
+ data = post_data
+ )
+ if res.status_code == 200:
+ tree = ET.fromstring(res.text)
+ self.__check_ocs_status(tree)
+ data_el = tree.find('data')
+ return PublicShare(
+ int(data_el.find('id').text),
+ path,
+ data_el.find('url').text,
+ data_el.find('token').text
+ )
+ raise ResponseError(res)
+[INST] Checks whether a path is already shared. [/INST] def is_shared(self, path):
+ self.file_info(path)
+ try:
+ result = self.get_shares(path)
+ if result:
+ return (len(result) > 0)
+ except ResponseError as e:
+ if e.status_code != 404:
+ raise e
+ return False
+ return False
+[INST] Shares a remote file with specified user. [/INST] def share_file_with_user(self, path, user, **kwargs):
+ perms = kwargs.get('perms', self.OCS_PERMISSION_READ)
+ if (((not isinstance(perms, int)) or (perms > self.OCS_PERMISSION_ALL))
+ or ((not isinstance(user, basestring)) or (user == ''))):
+ return False
+ path = self.__normalize_path(path)
+ post_data = {'shareType': self.OCS_SHARE_TYPE_USER, 'shareWith': user, 'path': path, 'permissions': perms}
+ res = self.__make_ocs_request(
+ 'POST',
+ self.OCS_SERVICE_SHARE,
+ 'shares',
+ data = post_data
+ )
+ if res.status_code == 200:
+ tree = ET.fromstring(res.text)
+ self.__check_ocs_status(tree)
+ data_el = tree.find('data')
+ return UserShare(
+ int(data_el.find('id').text),
+ path,
+ perms
+ )
+ raise ResponseError(res)
+[INST] Checks the status code of an OCS request. [/INST] def __check_ocs_status(tree):
+ code_el = tree.find('meta/statuscode')
+ if code_el is not None and code_el.text != '100':
+ raise ResponseError(int(code_el.text))
+[INST] Removes the leading "remote.php/webdav" path from the given path. [/INST] def __strip_dav_path(self, path):
+ if (path.startswith(self.__davpath)):
+ return path[len(self.__davpath):]
+ return path
+[INST] This function patches tornado request handlers init method
+and then patches method that handle HTTP requests inside.
+This happens dynamically every time a request handler instace
+is created. This is needed as HTTP handler method do not exists
+on request handlers by default and are supposed to be added
+by users.
+
+check if patching __new__ would be better in any way.. [/INST] def _patch_handler_init(init, handler, args, kwargs):
+ init(*args, **kwargs)
+ tracing = handler.settings.get("opentracing_tracing")
+ if not tracing:
+ return
+ if not tracing._trace_all:
+ return
+ for method in handler.SUPPORTED_METHODS:
+ handlers.wrap_method(handler, method.lower())
+[INST] Function decorator that traces functions
+NOTE: Must be placed before the Tornado decorators
+@param attributes any number of request attributes
+(strings) to be set as tags on the created span. [/INST] def trace(self, *attributes):
+ @wrapt.decorator
+ def wrapper(wrapped, instance, args, kwargs):
+ if self._trace_all:
+ return wrapped(*args, **kwargs)
+ handler = instance
+ with tornado_context():
+ try:
+ self._apply_tracing(handler, list(attributes))
+ result = wrapped(*args, **kwargs)
+ self._handle_wrapped_result(handler, result)
+ except Exception as exc:
+ self._finish_tracing(handler, error=exc, tb=sys.exc_info()[2])
+ raise
+ return result
+ return wrapper
+[INST] Helper function to avoid rewriting for middleware and decorator.
+Returns a new span from the request with logged attributes and
+correct operation name from the func.. [/INST] def _apply_tracing(self, handler, attributes):
+ operation_name = self._get_operation_name(handler)
+ headers = handler.request.headers
+ request = handler.request
+ try:
+ span_ctx = self._tracer.extract(opentracing.Format.HTTP_HEADERS, headers)
+ scope = self._tracer.start_active_span(operation_name, child_of=span_ctx)
+ except (
+ opentracing.InvalidCarrierException,
+ opentracing.SpanContextCorruptedException,
+ ):
+ scope = self._tracer.start_active_span(operation_name)
+ setattr(request, SCOPE_ATTR, scope)
+ scope.span.set_tag(tags.COMPONENT, "tornado")
+ scope.span.set_tag(tags.SPAN_KIND, tags.SPAN_KIND_RPC_SERVER)
+ scope.span.set_tag(tags.HTTP_METHOD, request.method)
+ scope.span.set_tag(tags.HTTP_URL, request.uri)
+ for attr in attributes:
+ if hasattr(request, attr):
+ payload = str(getattr(request, attr))
+ if payload:
+ scope.span.set_tag(attr, payload)
+ self._call_start_span_cb(scope.span, request)
+ if self._trace_response_headers_enabled:
+ trace_id = getattr(scope.span.context, "trace_id", 0)
+ span_id = getattr(scope.span.context, "span_id", 0)
+ if trace_id and span_id:
+ handler.add_header("Access-Control-Expose-Headers", "Server-Timing")
+ handler.add_header(
+ "Server-Timing",
+ 'traceparent;desc="00-{trace_id}-{span_id}-01"'.format(
+ trace_id="{:016x}".format(trace_id),
+ span_id="{:016x}".format(span_id),
+ ),
+ )
+ return scope
+[INST] Calculates xyzt homophily from a batch of graphs.
+Homophily is a graph scalar quantity that measures the likeness of variables
+in nodes. Notice that this calculator assumes a special order of input
+features in x.. [/INST] def calculate_xyzt_homophily(x, edge_index, batch):
+ hx = homophily(edge_index, x[:, 0], batch).reshape(-1, 1)
+ hy = homophily(edge_index, x[:, 1], batch).reshape(-1, 1)
+ hz = homophily(edge_index, x[:, 2], batch).reshape(-1, 1)
+ ht = homophily(edge_index, x[:, 3], batch).reshape(-1, 1)
+ return hx, hy, hz, ht
+[INST] Calculate the matrix of pairwise distances between pulses in (x,y,z)-coordinates.. [/INST] def calculate_distance_matrix(xyz_coords: Tensor) -> Tensor:
+ diff = xyz_coords.unsqueeze(dim=2) - xyz_coords.T.unsqueeze(dim=0)
+ return torch.sqrt(torch.sum(diff**2, dim=1))
+[INST] Ingests data, builds graph (connectivity/adjacency), and preprocesses features.. [/INST] def _forward(self, data: Data) -> Data:
+ self._validate_features(data)
+ data.x[:, 0] /= 100.0
+ data.x[:, 1] /= 100.0
+ data.x[:, 2] += 350.0
+ data.x[:, 2] /= 100.0
+ data.x[:, 3] /= 1.05e04
+ data.x[:, 3] -= 1.0
+ data.x[:, 3] *= 20.0
+ data.x[:, 4] /= 1.0
+ data.x[:, 5] -= 1.25
+ data.x[:, 5] /= 0.25
+ data.x[:, 6] /= 0.05
+ return data
+[INST] Ingests data, builds graph (connectivity/adjacency), and preprocesses features.. [/INST] def _forward(self, data: Data) -> Data:
+ self._validate_features(data)
+ data.x[:, 0] /= 500.0
+ data.x[:, 1] /= 500.0
+ data.x[:, 2] /= 500.0
+ data.x[:, 3] /= 2e04
+ data.x[:, 3] -= 1.0
+ data.x[:, 4] = torch.log10(data.x[:, 4]) / 2.0
+ data.x[:, 6] /= 0.05
+ data.x[:, 7] -= 50.0
+ data.x[:, 7] /= 50.0
+ data.x[:, 8] /= 20.0
+ data.x[:, 9] -= 60.0
+ data.x[:, 9] /= 60.0
+ data.x[:, 13] /= 130.0
+ return data
+[INST] Ingests data, builds graph (connectivity/adjacency), and preprocesses features.. [/INST] def _forward(self, data: Data) -> Data:
+ self._validate_features(data)
+ data = group_pulses_to_dom(data)
+ data = group_pulses_to_pmt(data)
+ xyz = torch.stack((data["dom_x"], data["dom_y"], data["dom_z"]), dim=1)
+ pmt_dir = torch.stack(
+ (data["pmt_dir_x"], data["pmt_dir_x"], data["pmt_dir_x"]), dim=1
+ )
+ charge = data["charge"].unsqueeze(dim=1)
+ center_of_gravity = sum_pool_and_distribute(
+ xyz * charge, data.batch
+ ) / sum_pool_and_distribute(charge, data.batch)
+ vector_to_center_of_gravity = center_of_gravity - xyz
+ distance_to_center_of_gravity = torch.norm(
+ vector_to_center_of_gravity, p=2, dim=1
+ )
+ unit_vector_to_center_of_gravity = vector_to_center_of_gravity / (
+ distance_to_center_of_gravity.unsqueeze(dim=1) + 1e-3
+ )
+ cos_angle_wrt_center_of_gravity = (
+ pmt_dir * unit_vector_to_center_of_gravity
+ ).sum(dim=1)
+ photoelectrons_on_pmt = (
+ sum_pool_and_distribute(data["charge"], data.pmt_index, data.batch)
+ .floor()
+ .clip(1, None)
+ )
+ data.x = torch.cat(
+ (
+ data.x,
+ photoelectrons_on_pmt.unsqueeze(dim=1),
+ distance_to_center_of_gravity.unsqueeze(dim=1),
+ cos_angle_wrt_center_of_gravity.unsqueeze(dim=1),
+ ),
+ dim=1,
+ )
+ data.x[:, 0] /= 500.0
+ data.x[:, 1] /= 500.0
+ data.x[:, 2] /= 500.0
+ data.x[:, 3] /= 2e04
+ data.x[:, 3] -= 1.0
+ data.x[:, 4] = torch.log10(data.x[:, 4]) / 2.0
+ data.x[:, 6] /= 0.05
+ data.x[:, 7] -= 50.0
+ data.x[:, 7] /= 50.0
+ data.x[:, 8] /= 20.0
+ data.x[:, 9] -= 60.0
+ data.x[:, 9] /= 60.0
+ data.x[:, 13] /= 130.0
+ data.x[:, 14] = (
+ torch.log10(data.x[:, 14]) / 2.0
+ )
+ data.x[:, 15] = (
+ torch.log10(1e-03 + data.x[:, 15]) / 2.0
+ )
+ return data
+[INST] Additional poolings of feature tensor `x` on `data`.
+By default the nominal `pooling_method` is used for features as well.
+This method can be overwritten for bespoke coarsening operations.. [/INST] def _additional_features(self, cluster: LongTensor, data: Data) -> Tensor:
+ return None
+[INST] Perform clustering of nodes in `data` by assigning unique cluster indices to each.. [/INST] def _perform_clustering(self, data: Data) -> LongTensor:
+ dom_index = group_by(
+ data, ["dom_x", "dom_y", "dom_z", "rde", "pmt_area"]
+ )
+ return dom_index
+[INST] Plots 2D contours from GraphNeT PISA fits.. [/INST] def plot_2D_contour(
+ contour_data,
+ xlim=(0.4, 0.6),
+ ylim=(2.38 * 1e-3, 2.55 * 1e-3),
+ chi2_critical_value=4.605,
+ width=3.176,
+ height=2.388,
+):
+ fig, ax = plt.subplots(figsize=(width, height), constrained_layout=True)
+ proxy = []
+ labels = []
+ for entry in contour_data:
+ entry_data, model_name, label, ls, color = read_entry(entry)
+ model_idx = entry_data["model"] == model_name
+ model_data = entry_data.loc[model_idx]
+ x = pd.unique(model_data.sort_values("theta23_fixed")["theta23_fixed"])
+ y = pd.unique(model_data.sort_values("dm31_fixed")["dm31_fixed"])
+ z = np.zeros((len(y), len(x)))
+ for i in range(len(x)):
+ for k in range(len(y)):
+ idx = (model_data["theta23_fixed"] == x[i]) & (
+ model_data["dm31_fixed"] == y[k]
+ )
+ match = model_data["mod_chi2"][idx]
+ if len(match) > 0:
+ if model_data["converged"][idx].values is True:
+ match = float(match)
+ else:
+ match = 10000
+ else:
+ match = 10000
+ z[k, i] = match
+ CS = ax.contour(
+ np.sin(np.deg2rad(x)) ** 2,
+ y,
+ z,
+ levels=[chi2_critical_value],
+ colors=color,
+ label=label,
+ linestyles=ls,
+ linewidths=2,
+ )
+ proxy.extend(
+ [plt.Rectangle((0, 0), 1, 1, fc=color) for pc in CS.collections]
+ )
+ if chi2_critical_value == 4.605:
+ label = label + " 90 $\\%$ CL"
+ labels.append(label)
+ plt.legend(proxy, labels, frameon=False, loc="upper right")
+ plt.xlim(xlim[0], xlim[1])
+ plt.ylim(ylim[0], ylim[1])
+ plt.xlabel("$\\sin^2(\\theta_{23})$", fontsize=12)
+ plt.ylabel("$\\Delta m_{31}^2 [eV^2]$", fontsize=12)
+ plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
+ plt.title("Sensitivity (Simplified Analysis)")
+ return fig
+[INST] Plots 1D contours from GraphNeT PISA fits.. [/INST] def plot_1D_contour(
+ contour_data, chi2_critical_value=2.706, width=2 * 3.176, height=2.388
+):
+ variables = ["theta23_fixed", "dm31_fixed"]
+ fig, ax = plt.subplots(
+ 1, 2, figsize=(width, height), constrained_layout=True
+ )
+ ls = 0
+ for entry in contour_data:
+ entry_data, model_name, label, ls, color = read_entry(entry)
+ for variable in variables:
+ model_idx = entry_data["model"] == model_name
+ padding_idx = entry_data[variable] != -1
+ chi2_idx = entry_data["mod_chi2"] < chi2_critical_value
+ model_data = entry_data.loc[
+ (model_idx) & (padding_idx) & (chi2_idx), :
+ ]
+ x = model_data.sort_values(variable)
+ if variable == "theta23_fixed":
+ ax[0].set_ylabel("$\\chi^2$", fontsize=12)
+ ax[0].plot(
+ np.sin(np.deg2rad(x[variable])) ** 2,
+ x["mod_chi2"],
+ color=color,
+ label=label,
+ ls=ls,
+ )
+ ax[0].set_xlabel("$\\sin(\\theta_{23})^2$", fontsize=12)
+ elif variable == "dm31_fixed":
+ ax[1].plot(
+ x[variable], x["mod_chi2"], color=color, label=label, ls=ls
+ )
+ ax[1].ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
+ ax[1].set_xlabel("$\\Delta m_{31}^2 [eV^2]$", fontsize=12)
+ h = [item.get_text() for item in ax[1].get_yticklabels()]
+ empty_string_labels = [""] * len(h)
+ ax[1].set_yticklabels(empty_string_labels)
+ ax[0].set_ylim(0, chi2_critical_value)
+ ax[1].set_ylim(0, chi2_critical_value)
+ ax[0].legend()
+ return fig
+[INST] Run minimal icetray chain with GraphNeT module.. [/INST] def main(input_files, output_file, key, pulsemaps, gcd_file, events_max):
+ makedirs(dirname(output_file), exist_ok=True)
+ if gcd_file is None:
+ gcd_candidates = [p for p in input_files if is_gcd_file(p)]
+ assert (
+ len(gcd_candidates) == 1
+ ), f"Did not get exactly one GCD-file candidate in `{dirname(input_files[0])}: {gcd_candidates}"
+ gcd_file = gcd_candidates[0]
+ input_files = [p for p in input_files if not is_gcd_file(p)]
+ min_hit_oms = 10
+ tray = I3Tray()
+ tray.Add("I3Reader", filenamelist=input_files)
+ tray.Add(
+ lambda frame: len(
+ dataclasses.I3RecoPulseSeriesMap.from_frame(frame, pulsemaps[0])
+ )
+ >= min_hit_oms,
+ )
+ tray.Add(
+ GraphNeTModuleIceCubeUpgrade,
+ keys=key,
+ model=MODEL_PATH,
+ pulsemaps=pulsemaps,
+ gcd_file=gcd_file,
+ )
+ tray.Add("I3Writer", filename=output_file)
+ if events_max > 0:
+ tray.Execute(events_max)
+ else:
+ tray.Execute()
+[INST] Adds particle and interaction ID from database `db` to dataframe `df`.. [/INST] def add_pid_and_interaction(db, df):
+ events = df["event_no"]
+ with sqlite3.connect(db) as con:
+ query = (
+ "select event_no, pid, interaction_type from truth where event_no in %s"
+ % str(tuple(events))
+ )
+ data = (
+ pd.read_sql(query, con)
+ .sort_values("event_no")
+ .reset_index(drop=True)
+ )
+ df = df.sort_values("event_no").reset_index(drop=True)
+ df["interaction_type"] = data["interaction_type"]
+ df["pid"] = data["pid"]
+ for variable in df.columns:
+ if variable == "energy":
+ df[variable] = np.log10(df[variable])
+ if variable + "_pred" in df.keys():
+ df[variable + "_pred"] = np.log10(df[variable + "_pred"])
+ if variable + "_retro" in df.keys():
+ df[variable + "_retro"] = np.log10(df[variable + "_retro"])
+ return df
+[INST] Calculate the uncertainty on the estimated width from the 68-interpercentile range.. [/INST] def calculate_width_error(diff):
+ N = len(diff)
+ x_16 = abs(
+ diff - np.percentile(diff, 16, interpolation="nearest")
+ ).argmin()
+ x_84 = abs(
+ diff - np.percentile(diff, 84, interpolation="nearest")
+ ).argmin()
+ if len(diff) > 0:
+ error_width = np.sqrt(
+ (1 / empirical_pdf(x_84, diff) ** 2) * (0.84 * (1 - 0.84) / N)
+ + (1 / empirical_pdf(x_16, diff) ** 2) * (0.16 * (1 - 0.16) / N)
+ ) * (1 / 2)
+ else:
+ error_width = np.nan
+ return error_width
+[INST] Check whether `data` contains a column with a name containing "retro".. [/INST] def check_for_retro(data: pd.DataFrame) -> bool:
+ columns = data.columns
+ is_retro = False
+ for column in columns:
+ if "retro" in column:
+ is_retro = True
+ break
+ return is_retro
+[INST] Plot reconstruction resoltion (width) for DynEdge vs.. [/INST] def PlotWidth(key_limits, biases):
+ key_limits = key_limits["width"]
+ if "retro" in biases.keys():
+ contains_retro = True
+ else:
+ contains_retro = False
+ for key in biases["dynedge"].keys():
+ fig, ax = plt.subplots(2, 3, figsize=(11.69, 8.27))
+ fig.suptitle("dynedge: %s" % key, size=30)
+ pid_count = 0
+ for pid in biases["dynedge"][key].keys():
+ interaction_count = 0
+ for interaction_type in biases["dynedge"][key][pid]:
+ if interaction_type != str(0.0):
+ plot_data = biases["dynedge"][key][pid][interaction_type]
+ if contains_retro:
+ plot_data_retro = biases["retro"][key][pid][
+ interaction_type
+ ]
+ if len(plot_data["mean"]) != 0:
+ ax2 = ax[interaction_count, pid_count].twinx()
+ ax2.bar(
+ x=(plot_data["mean"]),
+ height=plot_data["count"],
+ alpha=0.3,
+ color="grey",
+ align="edge",
+ width=0.25,
+ )
+ ax[interaction_count, pid_count].errorbar(
+ plot_data["mean"],
+ plot_data["width"],
+ plot_data["width_error"],
+ linestyle="dotted",
+ fmt="o",
+ capsize=10,
+ label="dynedge",
+ )
+ if contains_retro:
+ ax[interaction_count, pid_count].errorbar(
+ plot_data_retro["mean"],
+ plot_data_retro["width"],
+ plot_data_retro["width_error"],
+ linestyle="dotted",
+ fmt="o",
+ capsize=10,
+ label="RetroReco",
+ )
+ if pid == str(12.0):
+ pid_tag = "e"
+ if pid == str(14.0):
+ pid_tag = "u"
+ if pid == str(16.0):
+ pid_tag = "T"
+ if interaction_type == str(1.0):
+ interaction_tag = "cc"
+ if interaction_type == str(2.0):
+ interaction_tag = "nc"
+ if interaction_type == str(0.0):
+ interaction_tag = "unknown"
+ plt.title(
+ "$\\nu_%s$ %s" % (pid_tag, interaction_tag),
+ size=20,
+ )
+ ax[interaction_count, pid_count].tick_params(
+ axis="x", labelsize=10
+ )
+ ax[interaction_count, pid_count].tick_params(
+ axis="y", labelsize=10
+ )
+ ax[interaction_count, pid_count].set_xlim(
+ key_limits[key]["x"]
+ )
+ ax[interaction_count, pid_count].set_ylim(
+ key_limits[key]["y"]
+ )
+ ax[interaction_count, pid_count].legend()
+ plt.tick_params(right=False, labelright=False)
+ if (interaction_count == 0) & (pid_count == 0) or (
+ interaction_count == 1
+ ) & (pid_count == 0):
+ ax[interaction_count, pid_count].set_ylabel(
+ "W($log_{10}$($\\frac{pred}{truth}$)) [GeV]",
+ size=20,
+ )
+ if interaction_count != 0:
+ ax[interaction_count, pid_count].set_xlabel(
+ "$energy_{log10}$ GeV", size=25
+ )
+ interaction_count += 1
+ pid_count += 1
+ return fig
+[INST] Plot relative improvement of DynEdge vs.. [/INST] def PlotRelativeImprovement(key_limits, biases):
+ key_limits = key_limits["rel_imp"]
+ for key in biases["dynedge"].keys():
+ fig, ax = plt.subplots(2, 3, figsize=(11.69, 8.27))
+ fig.suptitle("dynedge: %s" % key, size=30)
+ pid_count = 0
+ for pid in biases["dynedge"][key].keys():
+ interaction_count = 0
+ for interaction_type in biases["dynedge"][key][pid]:
+ if interaction_type != str(0.0):
+ plot_data = biases["dynedge"][key][pid][interaction_type]
+ plot_data_retro = biases["retro"][key][pid][
+ interaction_type
+ ]
+ if len(plot_data["mean"]) != 0:
+ ax2 = ax[interaction_count, pid_count].twinx()
+ ax2.bar(
+ x=(plot_data["mean"]),
+ height=plot_data["count"],
+ alpha=0.3,
+ color="grey",
+ align="edge",
+ width=0.25,
+ )
+ ax[interaction_count, pid_count].plot(
+ plot_data["mean"],
+ np.repeat(0, len(plot_data["mean"])),
+ color="black",
+ lw=4,
+ )
+ ax[interaction_count, pid_count].errorbar(
+ plot_data["mean"],
+ 1
+ - np.array(plot_data["width"])
+ / np.array(plot_data_retro["width"]),
+ marker="o",
+ markeredgecolor="black",
+ )
+ if pid == str(12.0):
+ pid_tag = "e"
+ if pid == str(14.0):
+ pid_tag = "u"
+ if pid == str(16.0):
+ pid_tag = "T"
+ if interaction_type == str(1.0):
+ interaction_tag = "cc"
+ if interaction_type == str(2.0):
+ interaction_tag = "nc"
+ if interaction_type == str(0.0):
+ interaction_tag = "unknown"
+ plt.title(
+ "$\\nu_%s$ %s" % (pid_tag, interaction_tag),
+ size=20,
+ )
+ ax[interaction_count, pid_count].tick_params(
+ axis="x", labelsize=10
+ )
+ ax[interaction_count, pid_count].tick_params(
+ axis="y", labelsize=10
+ )
+ ax[interaction_count, pid_count].set_xlim(
+ key_limits[key]["x"]
+ )
+ ax[interaction_count, pid_count].set_ylim(
+ key_limits[key]["y"]
+ )
+ ax[interaction_count, pid_count].legend()
+ plt.tick_params(right=False, labelright=False)
+ if (interaction_count == 0) & (pid_count == 0) or (
+ interaction_count == 1
+ ) & (pid_count == 0):
+ ax[interaction_count, pid_count].set_ylabel(
+ "Relative Improvement", size=20
+ )
+ if interaction_count != 0:
+ ax[interaction_count, pid_count].set_xlabel(
+ "$energy_{log10}$ GeV", size=25
+ )
+ interaction_count += 1
+ pid_count += 1
+ return fig
+[INST] Assert that a valid combination of transformation arguments are passed and update the corresponding functions. [/INST] def _validate_and_set_transforms(
+ self,
+ transform_prediction_and_target: Union[Callable, None],
+ transform_target: Union[Callable, None],
+ transform_inference: Union[Callable, None],
+ transform_support: Union[Callable, None],
+ ):
+ assert not (
+ (transform_prediction_and_target is not None)
+ and (transform_target is not None)
+ ), "Please specify at most one of `transform_prediction_and_target` and `transform_target`"
+ assert (transform_target is not None) == (
+ transform_inference is not None
+ ), "Please specify both `transform_inference` and `transform_target`"
+ if transform_target is not None:
+ if transform_support is not None:
+ assert (
+ len(transform_support) == 2
+ ), "Please specify min and max for transformation support."
+ x_test = torch.from_numpy(
+ np.linspace(transform_support[0], transform_support[1], 10)
+ )
+ else:
+ x_test = np.logspace(-6, 6, 12 + 1)
+ x_test = torch.from_numpy(
+ np.concatenate([-x_test[::-1], [0], x_test])
+ )
+ t_test = torch.unsqueeze(transform_target(x_test), -1)
+ t_test = torch.squeeze(transform_inference(t_test), -1)
+ valid = torch.isfinite(t_test)
+ assert torch.allclose(
+ t_test[valid], x_test[valid]
+ ), "The provided transforms for targets during training and predictions during inference are not inverse. Please adjust transformation functions or support."
+ del x_test, t_test, valid
+ if transform_prediction_and_target is not None:
+ self._transform_prediction_training = (
+ transform_prediction_and_target
+ )
+ self._transform_target = transform_prediction_and_target
+ elif transform_target is not None:
+ self._transform_prediction_inference = transform_inference
+ self._transform_target = transform_target
+[INST] Calculates where a simulated muon stops and if this is inside the detectors fiducial volume.
+IMPORTANT: The final position of the muon is saved in truth extractor/databases as position_x,position_y and position_z.
+This is analogoues to the neutrinos whose interaction vertex is saved under the same name.. [/INST] def muon_stopped(truth, borders, horizontal_pad=100.0, vertical_pad=100.0):
+ border = mpath.Path(borders[0])
+ start_pos = np.array(
+ [truth["position_x"], truth["position_y"], truth["position_z"]]
+ )
+ travel_vec = -1 * np.array(
+ [
+ truth["track_length"]
+ * np.cos(truth["azimuth"])
+ * np.sin(truth["zenith"]),
+ truth["track_length"]
+ * np.sin(truth["azimuth"])
+ * np.sin(truth["zenith"]),
+ truth["track_length"] * np.cos(truth["zenith"]),
+ ]
+ )
+ end_pos = start_pos + travel_vec
+ stopped_xy = border.contains_point(
+ (end_pos[0], end_pos[1]), radius=-horizontal_pad
+ )
+ stopped_z = (end_pos[2] > borders[1][0] + vertical_pad) * (
+ end_pos[2] < borders[1][1] - vertical_pad
+ )
+ return {
+ "x": end_pos[0],
+ "y": end_pos[1],
+ "z": end_pos[2],
+ "stopped": (stopped_xy * stopped_z),
+ }
+[INST] Sum-pool values across the cluster, and distribute the individual nodes.. [/INST] def sum_pool_and_distribute(
+ tensor: Tensor,
+ cluster_index: LongTensor,
+ batch: Optional[LongTensor] = None,
+) -> Tensor:
+ if batch is None:
+ batch = torch.zeros(tensor.size(dim=0)).long()
+ tensor_pooled, _ = sum_pool_x(cluster_index, tensor, batch)
+ inv, _ = consecutive_cluster(cluster_index)
+ tensor_unpooled = tensor_pooled[inv]
+ return tensor_unpooled
+[INST] Group rows in `tensor` that are identical. [/INST] def _group_identical(
+ tensor: Tensor, batch: Optional[LongTensor] = None
+) -> LongTensor:
+ if batch is not None:
+ tensor = tensor.cat((tensor, batch.unsqueeze(dim=1)), dim=1)
+ return torch.unique(tensor, return_inverse=True, dim=0)[1]
+[INST] Group nodes in `data` that have identical values of `keys`.
+This grouping is done with in each event in case of batching. This allows
+for, e.g., assigning the same index to all pulses on the same PMT or DOM in
+the same event. This can be used for coarsening graphs, e.g., from pulse
+level to DOM-level by aggregating feature across each group returned by this
+method.. [/INST] def group_by(data: Data, keys: List[str]) -> LongTensor:
+ features = [getattr(data, key) for key in keys]
+ tensor = torch.stack(features).T
+ batch = getattr(tensor, "batch", None)
+ index = _group_identical(tensor, batch)
+ return index
+[INST] r"""Pools and coarsens a graph given by the
+:class:`torch_geometric.data.Data` object according to the clustering
+defined in :attr:`cluster`.
+All nodes within the same cluster will be represented as one node.
+Final node features are defined by the *sum* of features of all nodes
+within the same cluster, node positions are averaged and edge indices are
+defined to be the union of the edge indices of all nodes within the same
+cluster.. [/INST] def sum_pool(cluster, data, transform=None):
+ cluster, perm = consecutive_cluster(cluster)
+ x = None if data.x is None else _sum_pool_x(cluster, data.x)
+ index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)
+ batch = None if data.batch is None else pool_batch(perm, data.batch)
+ pos = None if data.pos is None else pool_pos(cluster, data.pos)
+ data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)
+ if transform is not None:
+ data = transform(data)
+ return data
+[INST] r"""Pools and coarsens a graph given by the
+:class:`torch_geometric.data.Data` object according to the clustering
+defined in :attr:`cluster`.
+All nodes within the same cluster will be represented as one node.
+Final node features are defined by the *std* of features of all nodes
+within the same cluster, node positions are averaged and edge indices are
+defined to be the union of the edge indices of all nodes within the same
+cluster.. [/INST] def std_pool(cluster, data, transform=None):
+ cluster, perm = consecutive_cluster(cluster)
+ x = None if data.x is None else _std_pool_x(cluster, data.x)
+ index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)
+ batch = None if data.batch is None else pool_batch(perm, data.batch)
+ pos = None if data.pos is None else pool_pos(cluster, data.pos)
+ data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)
+ if transform is not None:
+ data = transform(data)
+ return data
+[INST] Decorator for only exposing function if `icecube` module is present.. [/INST] def requires_icecube(test_function):
+ def wrapper(*args, **kwargs):
+ if has_icecube_package():
+ return test_function(*args, **kwargs)
+ else:
+ logger.info(
+ f"Function `{test_function.__name__}` not used since `icecube` isn't available."
+ )
+ return
+ wrapper.__name__ = test_function.__name__
+ return wrapper
+[INST] Creates a MixSlice from plaintext data.. [/INST] def encrypt(data, path, key, iv, threads=None, padder=None):
+ padder = padder or _Padder(blocksize=MixSlice.MACRO_SIZE)
+ padded_data = padder.pad(data)
+ fragments = _mix_and_slice(data=padded_data, key=key,
+ iv=iv, threads=threads)
+ fragments = [_BytesIO(f) for f in fragments]
+ if not _os.path.exists(path):
+ _os.makedirs(path)
+ name = "frag_%%0%dd.dat" % len(str(len(fragments)))
+ for fragid in range(len(fragments)):
+ fragment = fragments[fragid]
+ assert isinstance(fragment, _BytesIO)
+ fragment.seek(0)
+ destination = _os.path.join(path, name % fragid)
+ with open(destination, "wb") as fp:
+ _shutil.copyfileobj(fragment, fp)
+ fragment.close()
+[INST] adds a salt to the hash function (OPTIONAL)
+should be called AFTER Init, and BEFORE update
+salt: a bytestring, length determined by hashbitlen.
+if hashbitlen=224 or 256, then salt will be 16 bytes
+if hashbitlen=384 or 512, then salt will be 32 bytes. [/INST] def addsalt(self, salt):
+ if self.init != 1:
+ raise Exception('addsalt() not called after init() and before update()')
+ saltsize = 16 if self.hashbitlen in [224, 256] else 32
+ if len(salt) != saltsize:
+ raise Exception('incorrect salt length')
+ ret = LIB.AddSalt(self.state, salt)
+ if ret:
+ raise Exception('AddSalt() ret = %d', ret)
+[INST] update the state with new data, storing excess data
+as necessary. may be called multiple times and if a
+call sends less than a full block in size, the leftover
+is cached and will be consumed in the next call
+data: data to be hashed (bytestring). [/INST] def update(self, data):
+ self.init = 2
+ datalen = len(data ) *8
+ if not datalen: return
+ ret = LIB.Update(self.state, data, datalen)
+ if ret:
+ raise Exception('Update() ret = %d', ret)
+[INST] finalize the hash -- pad and hash remaining data
+returns hashval, the digest. [/INST] def final(self, data=b''):
+ if data:
+ self.update(data)
+ hashval = c_buffer(int(self. hashbitlen /8))
+ ret = LIB.Final(self.state, hashval)
+ if ret:
+ raise Exception('Final() ret = %d', ret)
+ return hashval.raw
+[INST] all-in-one function
+hashbitlen must be one of 224, 256, 384, 512
+data data to be hashed (bytestring)
+databitlen length of data to be hashed in *bits
+returns digest value (bytestring). [/INST] def BLAKE_func(hashbitlen, data, databitlen):
+ return BLAKE(hashbitlen).final(data)
+[INST] Main function to be used for bruteforcing a hash. [/INST] def bruteforce_main(verf_hash, algorithm=None, wordlist=None, salt=None, placement=None, all_algs=False, posx="",
+ use_hex=False, verbose=False, batch=False, rounds=10):
+ wordlist_created = False
+ if wordlist is None:
+ create_dir("bf-dicts", verbose=verbose)
+ for item in os.listdir(os.getcwd() + "/bf-dicts"):
+ if WORDLIST_RE.match(item):
+ wordlist_created = True
+ wordlist = "{}/bf-dicts/{}".format(os.getcwd(), item)
+ if not wordlist_created:
+ LOGGER.info("Creating wordlist..")
+ create_wordlist(verbose=verbose)
+ else:
+ LOGGER.info("Reading from, {}..".format(wordlist))
+ if algorithm is None:
+ hash_type = verify_hash_type(verf_hash, least_likely=all_algs)
+ LOGGER.info("Found {} possible hash type(s) to run against: {} ".format(len(hash_type) - 1 if hash_type[1] is None
+ else len(hash_type),
+ hash_type[0] if hash_type[1] is None else
+ hash_type))
+ for alg in hash_type:
+ if alg is None:
+ err_msg = (
+ "Ran out of algorithms to try. There are no more "
+ "algorithms currently available that match this hashes "
+ "length, and complexity.")
+ LOGGER.fatal(err_msg.format(DAGON_ISSUE_LINK))
+ break
+ else:
+ if ":::" in verf_hash:
+ LOGGER.debug("It appears that you are trying to crack an '{}' hash, "
+ "these hashes have a certain sequence to them that looks "
+ "like this 'USERNAME:SID:LM_HASH:NTLM_HASH:::'. What you're "
+ "wanting is the NTLM part, of the hash, fix your hash and try "
+ "again..".format(alg.upper()))
+ shutdown(1)
+ LOGGER.info("Starting bruteforce with {}..".format(alg.upper()))
+ bruteforcing = hash_words(verf_hash, wordlist, alg, salt=salt, placement=placement, posx=posx,
+ use_hex=use_hex, verbose=verbose, rounds=rounds)
+ if bruteforcing is None:
+ LOGGER.warning("Unable to find a match for '{}', using {}..".format(verf_hash, alg.upper()))
+ else:
+ match_found(bruteforcing)
+ break
+ else:
+ LOGGER.info("Using algorithm, {}..".format(algorithm.upper()))
+ results = hash_words(verf_hash, wordlist, algorithm, salt=salt, placement=placement, posx=posx, verbose=verbose)
+ if results is None:
+ LOGGER.warning("Unable to find a match using {}..".format(algorithm.upper()))
+ if not batch:
+ verify = prompt("Would you like to attempt to verify the hash type automatically and crack it", "y/N")
+ else:
+ verify = "n"
+ if verify.startswith(("y", "Y")):
+ bruteforce_main(verf_hash, wordlist=wordlist, salt=salt, placement=placement, posx=posx, use_hex=use_hex,
+ verbose=verbose)
+ else:
+ LOGGER.warning("Unable to produce a result for given hash '{}' using {}..".format(
+ verf_hash, algorithm.upper()))
+ else:
+ match_found(results)
+[INST] adds a salt to the hash function (OPTIONAL)
+should be called AFTER Init, and BEFORE update
+salt: a bytestring, length determined by hashbitlen.
+if not of sufficient length, the bytestring
+will be assumed to be a big endian number and
+prefixed with an appropriate number of null
+bytes, and if too large, only the low order
+bytes will be used.
+
+if hashbitlen=224 or 256, then salt will be 16 bytes
+if hashbitlen=384 or 512, then salt will be 32 bytes. [/INST] def addsalt(self, salt):
+ if self.state != 1:
+ raise Exception('addsalt() not called after init() and before update()')
+ saltsize = self.WORDBYTES * 4
+ if len(salt) < saltsize:
+ salt = (chr(0) * (saltsize - len(salt)) + salt)
+ else:
+ salt = salt[-saltsize:]
+ self.salt[0] = self.byte2int(salt[: 4 << self.mul])
+ self.salt[1] = self.byte2int(salt[4 << self.mul: 8 << self.mul])
+ self.salt[2] = self.byte2int(salt[8 << self.mul:12 << self.mul])
+ self.salt[3] = self.byte2int(salt[12 << self.mul:])
+[INST] update the state with new data, storing excess data
+as necessary. may be called multiple times and if a
+call sends less than a full block in size, the leftover
+is cached and will be consumed in the next call
+data: data to be hashed (bytestring). [/INST] def update(self, data):
+ self.state = 2
+ BLKBYTES = self.BLKBYTES
+ BLKBITS = self.BLKBITS
+ datalen = len(data)
+ if not datalen: return
+ if type(data) == type(u''):
+ data = data.encode('UTF-8')
+ left = len(self.cache)
+ fill = BLKBYTES - left
+ if left and datalen >= fill:
+ self.cache = self.cache + data[:fill]
+ self.t += BLKBITS
+ self._compress(self.cache)
+ self.cache = b''
+ data = data[fill:]
+ datalen -= fill
+ while datalen >= BLKBYTES:
+ self.t += BLKBITS
+ self._compress(data[:BLKBYTES])
+ data = data[BLKBYTES:]
+ datalen -= BLKBYTES
+ if datalen > 0:
+ self.cache = self.cache + data[:datalen]
+[INST] finalize the hash -- pad and hash remaining data
+returns hashval, the digest. [/INST] def final(self, data=''):
+ if self.state == 3:
+ return self.hash
+ if data:
+ self.update(data)
+ ZZ = b'\x00'
+ ZO = b'\x01'
+ OZ = b'\x80'
+ OO = b'\x81'
+ PADDING = OZ + ZZ * 128
+ tt = self.t + (len(self.cache) << 3)
+ if self.BLKBYTES == 64:
+ msglen = self._int2eightByte(tt)
+ else:
+ low = tt & self.MASK
+ high = tt >> self.WORDBITS
+ msglen = self._int2eightByte(high) + self._int2eightByte(low)
+ sizewithout = self.BLKBYTES - ((self.WORDBITS >> 2) + 1)
+ if len(self.cache) == sizewithout:
+ self.t -= 8
+ if self.hashbitlen in [224, 384]:
+ self.update(OZ)
+ else:
+ self.update(OO)
+ else:
+ if len(self.cache) < sizewithout:
+ if len(self.cache) == 0:
+ self.nullt = 1
+ self.t -= (sizewithout - len(self.cache)) << 3
+ self.update(PADDING[:sizewithout - len(self.cache)])
+ else:
+ self.t -= (self.BLKBYTES - len(self.cache)) << 3
+ self.update(PADDING[:self.BLKBYTES - len(self.cache)])
+ self.t -= (sizewithout + 1) << 3
+ self.update(PADDING[1:sizewithout + 1])
+ self.nullt = 1
+ if self.hashbitlen in [224, 384]:
+ self.update(ZZ)
+ else:
+ self.update(ZO)
+ self.t -= 8
+ self.t -= self.BLKBYTES
+ self.update(msglen)
+ hashval = []
+ if self.BLKBYTES == 64:
+ for h in self.h:
+ hashval.append(self._int2fourByte(h))
+ else:
+ for h in self.h:
+ hashval.append(self._int2eightByte(h))
+ self.hash = b''.join(hashval)[:self.hashbitlen >> 3]
+ self.state = 3
+ return self.hash
+[INST] convert a number to a 4-byte string, high order
+truncation possible (in Python x could be a BIGNUM). [/INST] def _int2fourByte(self, x):
+ return struct.pack('!L', x)
+[INST] Parse a given file for anything that matches the hashes in the
+hash type regex dict. Possible that this will pull random bytes
+of data from the files.. [/INST] def hash_file_generator(self):
+ matched_hashes = set()
+ keys = [k for k in bin.verify_hashes.verify.HASH_TYPE_REGEX.iterkeys()]
+ with open(self.words) as wordlist:
+ for item in wordlist.readlines():
+ for s in item.split(" "):
+ for k in keys:
+ if k.match(s):
+ matched_hashes.add(s)
+ return list(matched_hashes)
+[INST] Create a blowfish hash using bcrypt
+> :param string: string to generate a Blowfish hash from
+> :return: Blowfish hash. [/INST] def blowfish(string, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ return bcrypt.hashpw(str(string), bcrypt.gensalt())
+[INST] Create a PostgreSQL hash, if no salt is provided, salt will be created
+> :param string: string to be hashed
+> :return: a PostgreSQL hash. [/INST] def postgres(string, salt=None, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if salt is None:
+ salt = lib.settings.random_salt_generator(use_string=True)[0]
+ obj = hashlib.md5()
+ obj.update(string + salt)
+ data = obj.hexdigest()
+ return "md5{}".format(data)
+[INST] Create a MsSQL 2000 hash from a given string, if no salt is given, random salt will be generated
+> :param string: the string to hash
+> :return: a MsSQL 2000 hash. [/INST] def mssql_2000(string, salt=None, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj1 = hashlib.sha1()
+ obj2 = hashlib.sha1()
+ if salt is None:
+ salt = lib.settings.random_salt_generator(use_string=True)[0]
+ crypt_salt = salt.encode("hex")
+ data_string = "".join(map(lambda s: ("%s\0" if ord(s) < 256 else "%s") % s.encode("utf8"), string))
+ obj1.update(data_string + crypt_salt)
+ obj2.update(data_string.upper() + salt)
+ hash_val = "0100{}{}{}".format(crypt_salt, obj1.hexdigest(), obj2.hexdigest())
+ return "0x{}".format(hash_val.upper())
+[INST] Create an MsSQL 2005 hash, if not salt is given, salt will be created
+> :param string: string to be hashed
+> :return: a MsSQL 2005 hash. [/INST] def mssql_2005(string, salt=None, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if salt is None:
+ salt = lib.settings.random_salt_generator(use_string=True)[0]
+ data_string = "".join(map(lambda s: ("%s\0" if ord(s) < 256 else "%s") % s.encode("utf8"), string))
+ obj = hashlib.sha1()
+ obj.update(data_string + salt)
+ hash_data = obj.hexdigest()
+ return "0x0100{}{}".format(salt.encode("hex"), hash_data)
+[INST] Create a RipeMD160 hash from a given string
+> :param string: string to be hashed
+> :return: a hashed string with or without salt. [/INST] def ripemd160(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.new("ripemd160")
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create a Blake224 hash from given string
+> :param string: string to be hashed
+> :return: a blake224 hash. [/INST] def blake224(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = blake.BLAKE(224)
+ if salt is not None and front and not back:
+ digest = obj.hexdigest(salt + string)
+ elif salt is not None and back and not front:
+ digest = obj.hexdigest(string + salt)
+ else:
+ digest = obj.hexdigest(string)
+ return digest
+[INST] Create a Blake256 hash from a given string
+> :param string: string to be hashed
+> :return: a blake256 hash. [/INST] def blake256(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = blake.BLAKE(256)
+ if salt is not None and front and not back:
+ digest = obj.hexdigest(salt + string)
+ elif salt is not None and back and not front:
+ digest = obj.hexdigest(string + salt)
+ else:
+ digest = obj.hexdigest(string)
+ return digest
+[INST] Create a bBlake384 hash from a given string
+> :param string: string to be hashed
+> :return: a blake384 hash. [/INST] def blake384(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = blake.BLAKE(384)
+ if salt is not None and front and not back:
+ digest = obj.hexdigest(salt + string)
+ elif salt is not None and back and not front:
+ digest = obj.hexdigest(string + salt)
+ else:
+ digest = obj.hexdigest(string)
+ return digest
+[INST] Create a Blake512 hash from a given string
+> :param string: string to ne hashed
+> :return: a blake512 hash. [/INST] def blake512(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = blake.BLAKE(512)
+ if salt is not None and front and not back:
+ digest = obj.hexdigest(salt + string)
+ elif salt is not None and back and not front:
+ digest = obj.hexdigest(string + salt)
+ else:
+ digest = obj.hexdigest(string)
+ return digest
+[INST] Create an MD2 hash from a given string
+> :param string: string to be hashed
+> :return: an MD2 hash. [/INST] def md2(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if salt is not None and front and not back:
+ obj = md2_hash.md2h(salt + string)
+ elif salt is not None and back and not front:
+ obj = md2_hash.md2h(string + salt)
+ else:
+ obj = md2_hash.md2h(string)
+ return obj
+[INST] Create an SHA1 hash from a given string
+> :param string: string to be hashed
+> :return: a SHA1 hashed string. [/INST] def sha1(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.sha1()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create a SHA1 hash in given rounds, meaning re-hexdigest the hash with a already created
+hash value. [/INST] def sha1_rounds(string, rounds=10, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.sha1()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ hashed = obj.hexdigest()
+ for _ in range(int(rounds) + 1):
+ obj1 = hashlib.sha1()
+ obj1.update(hashed)
+ hashed = obj1.hexdigest()
+ return hashed
+[INST] Create a SHA224 hash from a given string
+> :param string: string to be hashed
+> :return: an SHA224 hash. [/INST] def sha224(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.sha224()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create an SHA256 hash from a given string
+> :param string: string to be hashed
+> :return: a SHA256 hash. [/INST] def sha256(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.sha256()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create an SHA512 hash from a given string
+> :param string: string to be hashed
+> :return: an SHA512 hash. [/INST] def sha512(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.sha512()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create a SHA3 224 hash from a given string
+> :param string: string to be hashed
+> :return: an SHA3 224 hash. [/INST] def sha3_224(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = sha3.sha3_224()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create a SHA3 256 hash from a given string
+> :param string: string to be hashed
+> :return: SHA3 256 hash. [/INST] def sha3_256(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = sha3.sha3_256()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create a SHA3 384 hash from a given string
+> :param string: string to hash
+> :return: SHA3 384 hash. [/INST] def sha3_384(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = sha3.sha3_384()
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Create a WHIRLPOOL hash from a given string
+> :param string: string to be hashed
+> :return: a WHIRLPOOL hash. [/INST] def whirlpool(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.new("whirlpool")
+ if salt is not None and front and not back:
+ obj.update(salt + string)
+ elif salt is not None and back and not front:
+ obj.update(string + salt)
+ else:
+ obj.update(string)
+ return obj.hexdigest()
+[INST] Hash a password using Tiger192
+> :param string: string to be hashed into Tiger192
+> :return: a Tiger192 hash. [/INST] def tiger192(string, salt=None, front=False, back=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if salt is not None and front and not back:
+ obj = tiger.hash(salt + string)
+ elif salt is not None and back and not front:
+ obj = tiger.hash(string + salt)
+ else:
+ obj = tiger.hash(string)
+ return obj.lower()
+[INST] Create a CRC32 hash from a given string
+> :param string: string to be hashed
+> :return: a CRC32 hash. [/INST] def crc32(string, salt=None, front=False, back=False, use_hex=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if salt is not None and front and not back:
+ long_int = hex(zlib.crc32(salt + string) % 2 ** 32)
+ elif salt is not None and back and not front:
+ long_int = hex(zlib.crc32(string + salt) % 2 ** 32)
+ else:
+ long_int = hex(zlib.crc32(string) % 2 ** 32)
+ if not use_hex:
+ return str(long_int)[2:-1]
+ else:
+ return long_int
+[INST] Create a CRC64 hash from a given string
+> :param string: string to be hashed
+> :return: a CRC64 hash. [/INST] def crc64(string, salt=None, front=False, back=False, use_hex=False, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ if salt is not None and front and not back:
+ long_int = _crc64.crc64(salt + string)
+ elif salt is not None and back and not front:
+ long_int = _crc64.crc64(string + salt)
+ else:
+ long_int = _crc64.crc64(string)
+ if not use_hex:
+ return str(hex(long_int))[2:-1]
+ else:
+ return long_int
+[INST] Create an NTLM hash, identical to the one used in Windows protocol
+> :param string: string to be hashed
+> :return: a NTLM hashed string. [/INST] def ntlm(string, **placeholder):
+ if type(string) is unicode:
+ string = lib.settings.force_encoding(string)
+ obj = hashlib.new("md4")
+ obj.update(string.encode("utf-16le"))
+ data = obj.digest()
+ return binascii.hexlify(data)
+[INST] Returns an iteration over the characters in a file.. [/INST] def readchars(f):
+ while True:
+ b = f.read(1)
+ if (b == ""): return
+ yield b
+[INST] Create a heatmap from a numpy array and two lists of labels.
+Parameters
+
+data
+A 2D numpy array of shape (N, M).
+row_labels
+A list or array of length N with the labels for the rows.
+col_labels
+A list or array of length M with the labels for the columns.
+ax
+A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
+not provided, use current axes or create a new one. Optional.
+cbar_kw
+A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
+cbarlabel
+The label for the colorbar. Optional.
+kwargs
+All other arguments are forwarded to `imshow`.. [/INST] def heatmap(data, row_labels, col_labels, ax=None,
+ cbar_kw={}, cbarlabel="", **kwargs):
+ if not ax:
+ ax = plt.gca()
+ im = ax.imshow(data, **kwargs)
+ ax.set_xticks(np.arange(data.shape[1]))
+ ax.set_yticks(np.arange(data.shape[0]))
+ ax.set_xticklabels(col_labels)
+ ax.set_yticklabels(row_labels)
+ ax.tick_params(top=False, bottom=False,
+ labeltop=True, labelbottom=True)
+ plt.setp(ax.get_xticklabels(), rotation=0, ha="center",
+ rotation_mode="anchor")
+ ax.spines[:].set_visible(False)
+ ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
+ ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
+ ax.grid(which="minor", axis="x", color="black", linestyle='--', linewidth=1)
+ ax.grid(which="minor", axis="y", color="black", linestyle='-', linewidth=3)
+ ax.tick_params(which="minor", top=False, left=False)
+ return im
+[INST] Provided with Laps object, plot the relative pace.. [/INST] def plot_relative_pace(laptimes: ff1.core.Laps, session : str="fp3") -> None:
+ fig = ((laptimes[[session]] - laptimes[[session]].min())
+ .sort_values(session, ascending=False)
+ .plot(kind='barh', figsize=(7,10)))
+ plt.grid()
+ plt.show()
+[INST] Prepare data for plotting a relative gap plot over all laps.. [/INST] def prep_for_lap_plot(self) -> pd.DataFrame:
+ df = (
+ self.get_laps()
+ .assign(LapTimeNorm=lambda df: df.apply(lambda row: row['Time'] - row["LapStartTime"], axis=1))
+ .assign(LapTimeSeconds=lambda df: [i.seconds for i in df["LapTimeNorm"]])
+ )
+ ref = self.get_reference(df)
+ return (
+ df
+ .groupby(['LapNumber', "Driver"], as_index=False)["LapTimeSeconds"]
+ .agg(('first'))
+ .pivot(columns="Driver", index="LapNumber", values="LapTimeSeconds")
+ .fillna(method='pad')
+ .cumsum()
+ .apply(lambda row: (row - ref), axis=0)
+ )
+[INST] returns 1, 2 or 3 for when a dart closes the game / leg or 0 if not closed out. [/INST] def check_winning_condition(self, player_index, visit):
+ i = 0
+ for dart in visit.darts:
+ i = i + 1
+ if dart.multiplier == DartMultiplier.DOUBLE and self.scores[player_index] - dart.get_score() == 0:
+ self.scores[player_index] = 0
+ self.match.status = MatchStatus.FINISHED
+ return i
+ else:
+ print("deducting for " + str(player_index))
+ self.scores[player_index] -= dart.get_score()
+ return 0
+[INST] Store stats both for in-memory immediate use and send a message to the stats consumer (micro)service. [/INST] def record_statistics(self, player_index, visit, result):
+ if result is not 0:
+ visit.remove_trailing_darts(result)
+ self.match.visits[player_index].append(visit)
+ if len(self.match.visits[player_index]) == 3:
+ self.first9[player_index] = (STARTING_TOTAL - self.scores[player_index]) / 3
+ num_darts_thrown = (len(self.match.visits[player_index]) - 1) * 3
+ num_darts_thrown += 3 if result is 0 else result
+ if result is not 0:
+ self.match.winning_num_darts = num_darts_thrown
+ self.match.winning_player_index = player_index
+ self.averages[player_index] = (STARTING_TOTAL - self.scores[player_index]) / num_darts_thrown
+ username = self.match.players[player_index]
+ match_type = "X01"
+ darts = []
+ for dart in visit.darts:
+ darts.append([dart.multiplier, dart.segment])
+ message = [username, match_type, darts]
+ connection = pika.BlockingConnection(pika.ConnectionParameters('localhost'))
+ channel = connection.channel()
+ channel.queue_declare(queue='player-stats')
+ channel.basic_publish(exchange='',
+ routing_key='player-stats',
+ body=json.dumps(message))
+ connection.close()
+[INST] Determine whether n is palindrome.. [/INST] def is_palindrome(n):
+ number_original = list(map(int, str(n)))
+ number_reversed = number_original[:]
+ number_reversed.reverse()
+ return number_original == number_reversed
+[INST] Determine whether n is the product of 3-digit numbers.. [/INST] def is_prod_of_two_3_digit_num(n):
+ result = False
+ for i in range(100, 1000):
+ if n % i == 0 and n // i in range(100, 1000):
+ result = True
+ break
+ return result
+[INST] Determine if num_to_check is a prime.. [/INST] def is_prime(num_to_check):
+ for prime in primes:
+ if num_to_check % prime == 0:
+ return False
+ return True
+[INST] Determine whether the given integer n is prime.. [/INST] def is_prime(n):
+ result = True
+ for i in range(n - 1, 2, -1):
+ if n % i == 0:
+ result = False
+ break
+ return result
+[INST] Dynamically pull in this mapping's parser and parse the mapping.. [/INST] def parse(self):
+ module_path, parser_class = self.parser.rsplit('.', 1)
+ module = __import__(module_path, fromlist=[parser_class])
+ parser_class = getattr(module, parser_class)
+ parser = parser_class(self)
+ parser.parse()
+[INST] Create or update a django-celery periodic task for this mapping.. [/INST] def save(self, *args, **kwargs):
+ super(Mapping, self).save(*args, **kwargs)
+ from djcelery.models import CrontabSchedule, PeriodicTask
+ crontab, created = CrontabSchedule.objects.get_or_create(minute='0', hour='*', day_of_week='*')
+ task = 'feedmapper.tasks.feedmapper_sync'
+ args = '[%s]' % self.id
+ task, created = PeriodicTask.objects.get_or_create(task=task, args=args)
+ task.name = self.label
+ if not (task.interval or task.crontab):
+ task.crontab = crontab
+ task.save()
+[INST] Grab the requested Mapping and parse it.. [/INST] def feedmapper_sync(mapping_id):
+ try:
+ mapping = Mapping.objects.get(pk=mapping_id)
+ mapping.parse()
+ except Mapping.DoesNotExist:
+ logger = feedmapper_sync.get_logger()
+ logger.info("feedmapper_sync failed for mapping with ID %s" % mapping_id)
+[INST] Return a mock asynchronous function.. [/INST] def AsyncMock(*args, **kwargs):
+ m = mock.MagicMock(*args, **kwargs)
+ @coroutine
+ def mock_coro(*args, **kwargs):
+ return m(*args, **kwargs)
+ mock_coro.mock = m
+ return mock_coro
+[INST] Example of how to send server generated events to clients.. [/INST] def background_thread():
+ count = 0
+ while True:
+ sio.sleep(10)
+ count += 1
+ sio.emit('my response', {'data': 'Server generated event'},
+ namespace='/test')
+[INST] This function solves the system of equations describing the dynamics
+of neutron-production via muon transmutation. It takes in the
+initial values of the problem and returns the value after 1 step
+of the Runge-Kutta algorithm.. [/INST] def rk4_solve(CS135, CS133, XE135, XE134, XE133, XE132, BA135, fCS135, fCS133, fXE135, fXE134, fXE133, fXE132, fBA135, dt):
+ CS135_1 = fCS135(CS135, CS133, XE135, XE134, XE133, XE132, BA135)*dt
+ CS133_1 = fCS133(CS135, CS133, XE135, XE134, XE133, XE132, BA135)*dt
+ XE135_1 = fXE135(CS135, CS133, XE135, XE134, XE133, XE132, BA135)*dt
+ XE134_1 = fXE134(CS135, CS133, XE135, XE134, XE133, XE132, BA135)*dt
+ XE133_1 = fXE133(CS135, CS133, XE135, XE134, XE133, XE132, BA135)*dt
+ XE132_1 = fXE132(CS135, CS133, XE135, XE134, XE133, XE132, BA135)*dt
+ BA135_1 = fBA135(CS135, CS133, XE135, XE134, XE133, XE132, BA135)*dt
+ CS135_k = CS135 + CS135_1*0.5
+ CS133_k = CS133 + CS133_1*0.5
+ XE135_k = XE135 + XE135_1*0.5
+ XE134_k = XE134 + XE134_1*0.5
+ XE133_k = XE133 + XE133_1*0.5
+ XE132_k = XE132 + XE132_1*0.5
+ BA135_k = BA135 + BA135_1*0.5
+ CS135_2 = fCS135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ CS133_2 = fCS133(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE135_2 = fXE135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE134_2 = fXE134(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE133_2 = fXE133(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE132_2 = fXE132(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ BA135_2 = fBA135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ CS135_k = CS135 + CS135_2*0.5
+ CS133_k = CS133 + CS133_2*0.5
+ XE135_k = XE135 + XE135_2*0.5
+ XE134_k = XE134 + XE134_2*0.5
+ XE133_k = XE133 + XE133_2*0.5
+ XE132_k = XE132 + XE132_2*0.5
+ BA135_k = BA135 + BA135_2*0.5
+ CS135_3 = fCS135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ CS133_3 = fCS133(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE135_3 = fXE135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE134_3 = fXE134(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE133_3 = fXE133(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE132_3 = fXE132(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ BA135_3 = fBA135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ CS135_k = CS135 + CS135_3
+ CS133_k = CS133 + CS133_3
+ XE135_k = XE135 + XE135_3
+ XE134_k = XE134 + XE134_3
+ XE133_k = XE133 + XE133_3
+ XE132_k = XE132 + XE132_3
+ BA135_k = BA135 + BA135_3
+ CS135_4 = fCS135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ CS133_4 = fCS133(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE135_4 = fXE135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE134_4 = fXE134(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE133_4 = fXE133(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ XE132_4 = fXE132(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ BA135_4 = fBA135(CS135_k, CS133_k, XE135_k, XE134_k, XE133_k, XE132_k, BA135_k)*dt
+ CS135 = CS135 + (CS135_1 + 2*(CS135_2 + CS135_3) + CS135_4)/6
+ CS133 = CS133 + (CS133_1 + 2*(CS133_2 + CS133_3) + CS133_4)/6
+ XE135 = XE135 + (XE135_1 + 2*(XE135_2 + XE135_3) + XE135_4)/6
+ XE134 = XE134 + (XE134_1 + 2*(XE134_2 + XE134_3) + XE134_4)/6
+ XE133 = XE133 + (XE133_1 + 2*(XE133_2 + XE133_3) + XE133_4)/6
+ XE132 = XE132 + (XE132_1 + 2*(XE132_2 + XE132_3) + XE132_4)/6
+ BA135 = BA135 + (BA135_1 + 2*(BA135_2 + BA135_3) + BA135_4)/6
+ return CS135, CS133, XE135, XE134, XE133, XE132, BA135
+[INST] load's purpose is to load in the raw data.
+assumes:
+valid path_to_file and filenames
+
+that the data file is comma separated and contains time and voltage
+data in columns 3 and 4, respectively.
+
+the path to the folder containing data
+
+the filename of the data
+
+sets the filename, data_path, and raw data instance attributes. [/INST] def load(self, path_to_file, filename):
+ self.filename = filename
+ self.data_path = path_to_file
+ mat = np.genfromtxt(self.data_path + self.filename, delimiter=',',
+ usecols=(3,4))
+ self.raw_t, self.raw_V = mat[:,0]-mat[0,0], mat[:,1]
+ self.raw_numturns = self.true_freq*self.raw_t
+ nocsv, empty = self.filename.split('.csv')
+ self.LINAC_delay, self.pulse_width = (float(nocsv.split('_')[2]),
+ float(nocsv.split('_')[-1]))
+ return
+[INST] clean has two purposes: trim the data to specify the signal region,
+and remove the rf noise data from the raw voltage data.
+
+If None or 'None' is passed in for the noise_filename, the data is
+only trimmed.
+
+that the load() method has been already called and there is raw
+data to be cleaned.
+
+the start index (excluding indices strictly less)
+trim_stop(int): the stop index (excluding indices greater than or equal)
+noise_filename(string): filename containing the noise data
+
+- only sets the clean data instance attributes. [/INST] def clean(self, trim_start=10000, trim_stop=50000,
+ noise_filname='rf_noise.csv'):
+ if noise_filname is None or noise_filname == 'None':
+ self.clean_V = self.raw_V[trim_start:trim_stop]
+ else:
+ noise_dat = np.genfromtxt(self.data_path + noise_filname,
+ delimiter=',', usecols=(4,))
+ temp = self.raw_V - noise_dat
+ self.clean_V = temp[trim_start:trim_stop]
+ self.clean_t = self.raw_t[trim_start:trim_stop]
+ self.clean_numturns = self.raw_numturns[trim_start:trim_stop]
+ return
+[INST] calculates the positive envelop of the given
+waveform by default (neg = True, for negative envelope) and fits
+the envelop to a model function.
+
+the waveform contains clean data -- i.e.
+
+a boolean flag which determines if the positive envelope or
+negative envelope (default False -- positive envelope)
+is needed.
+
+sets the envelope, fitted envelope, and ideal
+instance attributes.. [/INST] def calculate_envelope(self, neg = False):
+ Then I run the algorithm again, on the result of the first run
+ the net effect is a smoothing out of the maxima and we end up
+ with the envelope.
+ This sometimes has a bit of noise in the final result -- but
+ works most of the time.
+ tempx, tempy = peak_detect(self.waveform.clean_t, self.waveform.clean_V)
+ self.envelope_t, self.envelope_V = peak_detect(tempx, tempy)
+ fitStart = 16
+ mod, params = envelope_fit(self.envelope_t[fitStart:],
+ self.envelope_V[fitStart:],
+ verbose=False)
+ self.envelope_V_predict = mod.eval(params, t=self.waveform.clean_t)
+ self.ideal_V = self.waveform.clean_V - self.envelope_V_predict
+ self.envelope_V_predict = mod.eval(params, t=self.envelope_t)
+ self.envelope_t_predict = self.envelope_t
+ self.envelope_numturns_predict = self.envelope_t_predict*self.waveform.true_freq
+ filtered = self.ideal_V <= 0
+ self.ideal_V, self.ideal_t, self.ideal_numturns = (
+ self.ideal_V[filtered], self.waveform.clean_t[filtered],
+ self.waveform.clean_numturns[filtered])
+ return
+[INST] Callable function used as a model for fitting. This function
+is the overdamped step response for an RLC second order circuit.
+
+nothing. [/INST] def step_responseRLC_overdamped(t, a_0, a_1, a_2, alpha, w_d):
+ return (a_0 + a_1 * np.exp(-alpha * t) * np.cosh(w_d * t)
+ + a_2 * np.exp(-alpha * t) * np.sinh(w_d * t))
+[INST] Callable function used as a model for fitting. This function
+is the underdamped step response for an RLC second order circuit.
+
+nothing. [/INST] def step_responseRLC_underdamped(t, a_0, a_1, a_2, alpha, w_d):
+ return (a_0 + a_1 * np.exp(-alpha * t) * np.cos(w_d * t)
+ + a_2 * np.exp(-alpha * t) * np.sin(w_d * t))
+[INST] Modified to be compatible with this file's structure.
+
+
+
+function for detecting local maximas and minmias in a signal.
+Discovers peaks by searching for values which are surrounded by lower
+or larger values for maximas and minimas respectively
+
+
+
+- ym, xm : the maximum values and corresponding x values. [/INST] def peak_detect(xarr, yarr, lookahead = 1, delta=0):
+ maxPeaks = []
+ minPeaks = []
+ dump = []
+ length = len(yarr)
+ if lookahead < 1:
+ raise ValueError, "Lookahead must be '1' or above in value"
+ if not (np.isscalar(delta) and delta >= 0):
+ raise ValueError, "delta must be a positive number"
+ mn, mx = np.Inf, -np.Inf
+ for index, (x, y) in enumerate(zip(xarr[:-lookahead],
+ yarr[:-lookahead])):
+ if y > mx:
+ mx = y
+ mxpos = x
+ if y < mn:
+ mn = y
+ mnpos = x
+ if y < mx-delta and mx != np.Inf:
+ if yarr[index:index+lookahead].max() < mx:
+ maxPeaks.append([mxpos, mx])
+ dump.append(True)
+ mx = np.Inf
+ mn = np.Inf
+ if index+lookahead >= length:
+ break
+ continue
+ if y > mn+delta and mn != -np.Inf:
+ if yarr[index:index+lookahead].min() > mn:
+ minPeaks.append([mnpos, mn])
+ dump.append(False)
+ mn = -np.Inf
+ mx = -np.Inf
+ if index+lookahead >= length:
+ break
+ try:
+ if dump[0]:
+ maxPeaks.pop(0)
+ else:
+ minPeaks.pop(0)
+ del dump
+ except IndexError:
+ pass
+ xm, ym = (np.asarray([p[0] for p in maxPeaks]),
+ np.asarray([p[1] for p in maxPeaks]))
+ return xm, ym
+[INST] Uses the package lmfit to fit the given data: t and V. The model for this
+fit is an overdamped step response of a second order RLC circuit.
+
+the data is suitable for the fitting model of overdamped step response
+of a second order RLC circuit.
+
+the time-axis (or x-axis) of data
+V: the voltage values
+
+the model object corresponding to the model provided.
+result.params: the Parameter object corresponding to the
+solution to the fitting algorithm.. [/INST] def envelope_fit(t, V, verbose = True):
+ init = [0.5, -2, 200, 1000, 50]
+ fit_mod = lmfit.model.Model(step_responseRLC_overdamped)
+ pars = fit_mod.make_params()
+ pars['a_0'].set(init[0])
+ pars['a_1'].set(init[1])
+ pars['a_2'].set(init[2])
+ pars['alpha'].set(init[3])
+ pars['w_d'].set(init[4])
+ result = fit_mod.fit(V, pars, t=t)
+ if verbose:
+ print(result.fit_report())
+ return fit_mod, result.params
+[INST] Uses the package lmfit to fit the given data: t and V. The model for this
+fit is an underdamped step response of a second order RLC circuit.
+
+the data is suitable for the fitting model of overdamped step response
+of a second order RLC circuit.
+
+the time-axis (or x-axis) of data
+V: the voltage values
+
+result.best_fit: the y-values corresponding to the best fit (the best
+parameters are in result.params)
+result.params: the Parameter object corresponding to the
+solution to the fitting algorithm.. [/INST] def int_bunch_fit(t, V, verbose = True):
+ init = [-0.5, -0.5, 0.2, 0.055, 50]
+ fit_mod = lmfit.model.Model(step_responseRLC_underdamped)
+ pars = fit_mod.make_params()
+ pars['a_0'].set(init[0])
+ pars['a_1'].set(init[1])
+ pars['a_2'].set(init[2])
+ pars['alpha'].set(init[3])
+ pars['w_d'].set(init[4])
+ result = fit_mod.fit(V, pars, t=t)
+ if verbose:
+ print(result.fit_report())
+ return result.best_fit, result.params
+[INST] Plots the noise data with pretty formatting.
+assumes:
+The data is a csv file where the time and voltage data are in columns
+3 and 4, respectively.
+
+save_path, if given, is a valid directory.
+
+The absolute or relative path to the
+folder containing the noise data file.
+
+the filename of noise data file
+
+the path where the plot gets saved (default to current
+working directory)
+
+simply plots and saves the plot in the directory given by
+save_path.. [/INST] def plot_noise(path_to_data, filename, save_path='.//'):
+ dat = np.genfromtxt(path_to_data+filename, delimiter=',', usecols=(3,4))
+ x, y = dat[:,0], dat[:,1]
+ fig = plt.figure(figsize = (10,8))
+ ax = plt.gca()
+ ax.plot(x, y, c=color_dict['data'], label=filename.replace('_', '\_'), alpha = 0.7)
+ for label in ax.get_xticklabels()[1::2]:
+ label.set_visible(False)
+ for label in ax.get_yticklabels()[1::2]:
+ label.set_visible(False)
+ ax.yaxis.grid(True)
+ ax.spines["top"].set_visible(True)
+ ax.spines["right"].set_visible(False)
+ ax.spines["bottom"].set_visible(True)
+ ax.spines["left"].set_visible(False)
+ plt.xlabel('Time $[s]$', fontsize=15)
+ plt.ylabel('Voltage $[V]$', fontsize=15)
+ plt.legend(loc='best', fancybox=True, fontsize=15)
+ plt.tight_layout()
+ plt.savefig(save_path+filename.split('.')[0] + '.pdf')
+ return
+[INST] Analyzes all the preferred files using the DataAnalysis.py module in the
+directory given by data_path.
+
+the data files are comma separated files with time and voltage data in
+columns 3 and 4 respectively.
+
+the relative or absolute path to the directory where
+the files are stored (note in windows you need
+instead of \ so that python can see it as a backslash)
+
+a list of Analyzer objects, ready to be plotted.. [/INST] def batch_analyze(data_path):
+ analyzed_batch = []
+ for filename in os.listdir(data_path):
+ if 'inu' in filename:
+ waveform = DataAnalysis.Waveform()
+ waveform.load(data_path, filename)
+ waveform.clean()
+ analyzed_waveform = DataAnalysis.Analyzer(waveform)
+ analyzed_waveform.calculate_envelope()
+ analyzed_waveform.calculate_bunchArea_diff()
+ analyzed_waveform.calculate_P()
+ analyzed_batch.append(analyzed_waveform)
+ return analyzed_batch
+[INST] This is the main function -- it calls all the specified functions and
+methods and completes the analysis.. [/INST] def main():
+ waveform = DataAnalysis.Waveform()
+ waveform.load(data_path, filename_sample)
+ waveform.clean(trim_start=15000)
+ analyzed_waveform = DataAnalysis.Analyzer(waveform)
+ analyzed_waveform.calculate_envelope()
+ analyzed_waveform.calculate_bunchArea_diff()
+ analyzed_waveform.calculate_P()
+ analyzed_wave_list = batch_analyze(data_path)
+ batch_plot_results(analyzed_wave_list)
+[INST] Generate static data from a scenario.
+Generates objects defined in the scenario's list of objects.
+Output is written to stdout.. [/INST] def gen(scenario, indent):
+ scenario_dict = yaml.safe_load(scenario)
+ objects = resolve_objects(scenario_dict)
+ output = json.dumps(objects, indent=indent)
+ click.echo(output)
+[INST] Run a scenario.
+Not yet implemented.. [/INST] def run(scenario):
+ scenario_dict = yaml.safe_load(scenario)
+ scenario_name = scenario_dict.get('name')
+ click.echo(f'Running scenario {scenario_name}...')
+ click.echo(f'Not yet implemented.')
+[INST] Recurse through a nested data structure and format strings using attributes of obj and triggers. [/INST] def recursive_resolve(data, obj, trigger=None, triggers=[]):
+ if isinstance(data, dict):
+ if 'triggers' in data:
+ return [recursive_resolve(data['triggers'], obj, trigger, triggers) for trigger in triggers]
+ return {k: recursive_resolve(v, obj, trigger, triggers) for k, v in data.items()}
+ elif isinstance(data, list):
+ return [recursive_resolve(i, obj, trigger, triggers) for i in data]
+ elif isinstance(data, str):
+ if 'self.' in data:
+ _, att = data.split('.', 1)
+ return getattr(obj, att)
+ if 'trigger.' in data:
+ _, att = data.split('.', 1)
+ return getattr(trigger, att)
+ else:
+ return data.format(**vars(obj))
+ else:
+ return data
+[INST] Resolve an attribute value spec, return a lambda to call for the value. [/INST] def resolve_value(value, args, kwargs, ctx):
+ try:
+ _ = iter(value)
+ except TypeError:
+ return lambda self, value=value: value
+ if 'self.' in value:
+ _, att = value.split('.', 1)
+ return lambda self, att=att: getattr(self, att)
+ elif 'scenario.' in value:
+ _, key = value.split('.', 1)
+ return lambda self, key=key, ctx=ctx: ctx.scenario[key]
+ elif 'faker' in value:
+ try:
+ provider = kwargs.pop('provider')
+ except KeyError:
+ _, provider = value.split('.', 1)
+ return (
+ lambda self, provider=provider, args=args, kwargs=kwargs, ctx=ctx:
+ ctx.faker.format(provider, *args, **kwargs)
+ )
+ elif 'choice' in value:
+ return lambda self, args=args: random.choice(args)
+ return lambda self, value=value: value
+[INST] Resolve a set of object attributes. [/INST] def resolve_attributes(attributes, ctx):
+ result_attributes = {}
+ for key, value in attributes.items():
+ args = []
+ kwargs = {}
+ if isinstance(value, dict):
+ value, rest = list(value.items())[0]
+ if isinstance(rest, dict):
+ kwargs = rest
+ elif isinstance(rest, list):
+ args = rest
+ else:
+ args = [rest]
+ if isinstance(value, list):
+ value, *args = value
+ for i, arg in enumerate(args):
+ args[i] = resolve_value(arg, args, kwargs, ctx)(None)
+ for k, v in kwargs.items():
+ kwargs[k] = resolve_value(v, args, kwargs, ctx)(None)
+ value = resolve_value(value, args, kwargs, ctx)
+ result_attributes[key] = value
+ return result_attributes
+[INST] Resolve a set of classes, generate objects, resolve objects. [/INST] def resolve_objects(scenario):
+ set_seed(scenario)
+ faker = init_faker(scenario)
+ @attr.s
+ class Context:
+ faker = attr.ib()
+ scenario = attr.ib()
+ ctx = Context(faker, scenario)
+ objects = []
+ triggered_objects = []
+ outputs = []
+ classes = [
+ cls for cls in scenario['objects']
+ if cls.get('count')
+ ]
+ triggered_classes = [
+ cls for cls in scenario['objects']
+ if (cls.get('trigger') or cls.get('triggers'))
+ ]
+ for cls in classes:
+ Class = resolve_class(cls, ctx)
+ output = cls.get('output')
+ for i in range(cls.get('count')):
+ obj = Class()
+ objects.append(obj)
+ if output:
+ resolved_output = recursive_resolve(output, obj)
+ outputs.append(resolved_output)
+ for cls in triggered_classes:
+ Class = resolve_class(cls, ctx)
+ triggers = [cls.get('trigger')]
+ if not triggers[0]:
+ triggers = cls.get('triggers')
+ trigger_objs = []
+ for obj in objects:
+ obj_class = obj.__class__.__name__
+ for trigger in triggers:
+ tri_class, field, comp_func, tri_value = resolve_trigger(trigger)
+ obj_value = getattr(obj, field)
+ if obj_class == tri_class and comp_func(obj_value, tri_value):
+ if obj not in trigger_objs:
+ trigger_objs.append(obj)
+ output = cls.get('output')
+ def gen_obj_output(trigger_obj):
+ obj = Class()
+ triggered_objects.append(obj)
+ if output:
+ resolved_output = recursive_resolve(output, obj, trigger_obj, trigger_objs)
+ outputs.append(resolved_output)
+ if len(triggers) > 1:
+ gen_obj_output(trigger_objs[0])
+ else:
+ for trigger_obj in trigger_objs:
+ gen_obj_output(trigger_obj)
+ return outputs
+[INST] the URL of the homepage we wish to scrape
+retailer: the name of the Retailer we wish to preserve for human readability
+domain: the domain of the retailer's store to avoid entering foreign stores, blogs, or investor pages
+save_every: number of minutes between back-up saves, default is 15, 0 will never save till completion
+save_every multiplies 60 for a seconds value, so fractions of minutes is also acceptable
+can at lowest save after every URL parse, but this is HIGHLY NOT recommended. [/INST] def scrape(retailer, url, domain, save_every=15):
+ links_to_visit = deque([url])
+ links_visited = set()
+ local_urls = set(domain)
+ foreign_urls = set()
+ broken_urls = set()
+ user_agent = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/35.0.1916.47 Safari/537.36'
+ all_products = []
+ performance_data = []
+ iteration = 0
+ now = dt2.now()
+ d = {"Iteration": iteration,
+ "URL_Processed": "",
+ "Queue_Size": 1,
+ "Visited_Sites": 0,
+ "Products": len(all_products),
+ "Timestamp": str(now)
+ }
+ performance_data.append(d)
+ never_save = False
+ if (save_every > 0):
+ save_every = 60 * save_every
+ else:
+ never_save = True
+ last_save = now
+ while len(links_to_visit):
+ iteration += 1
+ url = links_to_visit.popleft()
+ links_visited.add(url)
+ print("Processing %s" % url)
+ try:
+ request = urllib.request.Request(url, headers={'User-Agent': user_agent})
+ response = urllib.request.urlopen(request)
+ except(requests.exceptions.MissingSchema, requests.exceptions.ConnectionError, requests.exceptions.InvalidURL,
+ requests.exceptions.InvalidSchema):
+ broken_urls.add(url)
+ continue
+ except:
+ print("Unknown error on URL: " + str(url))
+ continue
+ parts = urlsplit(url)
+ base = "{0.netloc}".format(parts)
+ strip_base = base.replace("www.", "")
+ base_url = "{0.scheme}://{0.netloc}".format(parts)
+ path = url[:url.rfind('/') + 1] if '/' in parts.path else url
+ soup = BeautifulSoup(response, 'html.parser')
+ output_str = soup.prettify("utf-8").decode('utf-8')
+ output_html = html.unescape(output_str)
+ output = BeautifulSoup(output_html, "lxml")
+ link_elements = output.find_all(
+ 'p', attrs={"class": 'brief'}
+ )
+ result = []
+ for element in link_elements:
+ result.append(element.text.strip().replace(' ',' '))
+ date_elements = output.find_all(
+ 'span', attrs={"class": 'ib published-date'}
+ )
+ date = []
+ for element in date_elements:
+ date.append(element.text.strip())
+ final = pd.DataFrame({'date': date, 'information': result})
+ filename = "data/"+store + "-promotion.xlsx"
+ if not os.path.exists(filename):
+ final.to_excel(filename, index=None)
+ else:
+ current = pd.read_excel(filename)
+ final = pd.concat([current, final], axis=0)
+ final.to_excel(filename, index=None)
+[INST] return the data structure for the current emlx file
+an email object
+the plist structure as a dict data structure. [/INST] def parse(self):
+ with open(self.filename, "r", errors='ignore') as f:
+ content = f.readlines()
+ if content:
+ t = open(self.filename, "rb")
+ self.bytecount = int(t.readline().strip())
+ self.email = email.message_from_bytes(t.read(self.bytecount))
+ for part in self.email.walk():
+ if part.get_content_type() == 'text/plain':
+ self.txt = part.get_payload()
+ elif part.get_content_type() == "text/html":
+ self.html = part.get_payload()
+ self.html2txt = html2text.html2text(part.get_payload())
+ return self.email
+[INST] It converts a django model to alchemy orm object.. [/INST] def declare(django_model, db_type=AUTO_DETECTED_DB_TYPE, back_type='backref'):
+ model_info = parse_model(django_model)
+ if django_model in existing:
+ return existing[django_model]
+ rel_options = OrderedDict()
+ attrs = OrderedDict({'__tablename__': model_info['table_name']})
+ for name, fields in model_info['fields'].items():
+ rel_option = fields.get('__rel_kwargs__', {})
+ if rel_option:
+ rel_options[name] = rel_option
+ col_types = {}
+ col_type_options = {}
+ for _db_type in DB_TYPES:
+ col_types[_db_type] = fields.get('__{}_type__'.format(_db_type), None)
+ col_type_options[_db_type] = fields.get('__{}_type_kwargs__'.format(_db_type), {})
+ type_key = 'default' if col_types.get(db_type) is None else db_type
+ if col_types[type_key]:
+ col_args = [col_types[type_key](**col_type_options[type_key])]
+ if '__fk_kwargs__' in fields:
+ col_args.append(ForeignKey(**_extract_kwargs(fields['__fk_kwargs__'])))
+ column = attrs[name] = Column(*col_args, **_extract_kwargs(fields))
+ rel_option['foreign_keys'] = [column]
+ for logical_name, rel_option in rel_options.items():
+ if '__secondary_model__' in rel_option:
+ secondary = rel_option['secondary'] = declare(rel_option['__secondary_model__'], db_type=db_type, back_type=back_type).__table__
+ if '__logical_name__' in rel_option:
+ logical_name = rel_option['__logical_name__']
+ back = rel_option.get('__back__', None)
+ if back and back_type:
+ rel_option[back_type] = back.rstrip('+').lower()
+ attrs[logical_name] = relationship(rel_option['__target__'], **_extract_kwargs(rel_option))
+ cls = existing[django_model] = type(model_info['table_name'], (Base,), attrs)
+ return cls
+[INST] It makes sqlalchemy model objects from django models.. [/INST] def transfer(models, exports, db_type=AUTO_DETECTED_DB_TYPE, back_type='backref', as_table=False, name_formatter=NAME_FORMATTER):
+ for model in parse_models(models).values():
+ declare(model, db_type=db_type, back_type=back_type)
+ for django_model, alchemy_model in existing.items():
+ if models.__name__ == django_model.__module__:
+ key = name_formatter(django_model._meta.object_name)
+ exports[key] = alchemy_model.__table__ if as_table else alchemy_model
+[INST] It loads all models automatically.. [/INST] def autoload(config=D2A_CONFIG.get('AUTOLOAD', {})):
+ module = config.get('module', 'models_sqla')
+ option = config.get('option', {})
+ for app in settings.INSTALLED_APPS:
+ mods = app.split('.')
+ for i in range(1, len(mods) + 1):
+ mod = '.'.join(mods[:i])
+ d = '{mod}.models'.format(mod=mod)
+ a = '{mod}.{module}'.format(mod=mod, module=module)
+ if importlib.util.find_spec(d) is None:
+ continue
+ try:
+ importlib.import_module(a)
+ except ImportError:
+ sys.modules[a] = types.ModuleType(a)
+ transfer(importlib.import_module(d), sys.modules[a].__dict__, **option)
+[INST] This is an XML parser for the Pascal VOC datasets. It might be applicable to other datasets with minor changes to
+the code, but in its current form it expects the data format and XML tags of the Pascal VOC datasets.. [/INST] def parse_xml(self,
+ images_dirs,
+ image_set_filenames,
+ annotations_dirs=[],
+ classes=['background',
+ 'aeroplane', 'bicycle', 'bird', 'boat',
+ 'bottle', 'bus', 'car', 'cat',
+ 'chair', 'cow', 'diningtable', 'dog',
+ 'horse', 'motorbike', 'person', 'pottedplant',
+ 'sheep', 'sofa', 'train', 'tvmonitor'],
+ include_classes = 'all',
+ exclude_truncated=False,
+ exclude_difficult=False,
+ ret=False):
+ self.images_dirs = images_dirs
+ self.annotations_dirs = annotations_dirs
+ self.image_set_filenames = image_set_filenames
+ self.classes = classes
+ self.include_classes = include_classes
+ self.filenames = []
+ self.image_ids = []
+ self.labels = []
+ if not annotations_dirs:
+ self.labels = None
+ annotations_dirs = [None] * len(images_dirs)
+ for images_dir, image_set_filename, annotations_dir in zip(images_dirs, image_set_filenames, annotations_dirs):
+ with open(image_set_filename) as f:
+ image_ids = [line.strip() for line in f]
+ self.image_ids += image_ids
+ for image_id in tqdm(image_ids, desc=os.path.basename(image_set_filename)):
+ filename = '{}'.format(image_id) + '.jpg'
+ self.filenames.append(os.path.join(images_dir, filename))
+ if not annotations_dir is None:
+ with open(os.path.join(annotations_dir, image_id + '.xml')) as f:
+ soup = BeautifulSoup(f, 'xml')
+ folder = soup.folder.text
+ boxes = []
+ objects = soup.find_all('object')
+ for obj in objects:
+ class_name = obj.find('name').text
+ class_id = self.classes.index(class_name)
+ if (not self.include_classes == 'all') and (not class_id in self.include_classes): continue
+ pose = obj.pose.text
+ truncated = int(obj.truncated.text)
+ if exclude_truncated and (truncated == 1): continue
+ difficult = int(obj.difficult.text)
+ if exclude_difficult and (difficult == 1): continue
+ xmin = int(obj.bndbox.xmin.text)
+ ymin = int(obj.bndbox.ymin.text)
+ xmax = int(obj.bndbox.xmax.text)
+ ymax = int(obj.bndbox.ymax.text)
+ item_dict = {'folder': folder,
+ 'image_name': filename,
+ 'image_id': image_id,
+ 'class_name': class_name,
+ 'class_id': class_id,
+ 'pose': pose,
+ 'truncated': truncated,
+ 'difficult': difficult,
+ 'xmin': xmin,
+ 'ymin': ymin,
+ 'xmax': xmax,
+ 'ymax': ymax}
+ box = []
+ for item in self.labels_output_format:
+ box.append(item_dict[item])
+ boxes.append(box)
+ self.labels.append(boxes)
+ if ret:
+ return self.filenames, self.labels, self.image_ids
+[INST] This is an JSON parser for the MS COCO datasets. It might be applicable to other datasets with minor changes to
+the code, but in its current form it expects the JSON format of the MS COCO datasets.. [/INST] def parse_json(self,
+ images_dirs,
+ annotations_filenames,
+ ground_truth_available=False,
+ include_classes = 'all',
+ ret=False):
+ self.images_dirs = images_dirs
+ self.annotations_filenames = annotations_filenames
+ self.include_classes = include_classes
+ self.filenames = []
+ self.image_ids = []
+ self.labels = []
+ if not ground_truth_available:
+ self.labels = None
+ with open(annotations_filenames[0], 'r') as f:
+ annotations = json.load(f)
+ self.cats_to_names = {}
+ self.classes_to_names = []
+ self.classes_to_names.append('background')
+ self.cats_to_classes = {}
+ self.classes_to_cats = {}
+ for i, cat in enumerate(annotations['categories']):
+ self.cats_to_names[cat['id']] = cat['name']
+ self.classes_to_names.append(cat['name'])
+ self.cats_to_classes[cat['id']] = i + 1
+ self.classes_to_cats[i + 1] = cat['id']
+ for images_dir, annotations_filename in zip(self.images_dirs, self.annotations_filenames):
+ with open(annotations_filename, 'r') as f:
+ annotations = json.load(f)
+ if ground_truth_available:
+ image_ids_to_annotations = defaultdict(list)
+ for annotation in annotations['annotations']:
+ image_ids_to_annotations[annotation['image_id']].append(annotation)
+ for img in annotations['images']:
+ self.filenames.append(os.path.join(images_dir, img['file_name']))
+ self.image_ids.append(img['id'])
+ if ground_truth_available:
+ annotations = image_ids_to_annotations[img['id']]
+ boxes = []
+ for annotation in annotations:
+ cat_id = annotation['category_id']
+ if (not self.include_classes == 'all') and (not cat_id in self.include_classes): continue
+ class_id = self.cats_to_classes[cat_id]
+ xmin = annotation['bbox'][0]
+ ymin = annotation['bbox'][1]
+ width = annotation['bbox'][2]
+ height = annotation['bbox'][3]
+ xmax = xmin + width
+ ymax = ymin + height
+ item_dict = {'image_name': img['file_name'],
+ 'image_id': img['id'],
+ 'class_id': class_id,
+ 'xmin': xmin,
+ 'ymin': ymin,
+ 'xmax': xmax,
+ 'ymax': ymax}
+ box = []
+ for item in self.labels_output_format:
+ box.append(item_dict[item])
+ boxes.append(box)
+ self.labels.append(boxes)
+ if ret:
+ return self.filenames, self.labels, self.image_ids
+[INST] Decorator for a matplotlib plot function.. [/INST] def frame(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ buffer = io.BytesIO()
+ func(*args, **kwargs)
+ plt.savefig(buffer, format="png")
+ buffer.seek(0)
+ image = Image.open(buffer)
+ plt.close()
+ return image
+ return wrapper
+[INST] Save decorated frames to an animated gif.
+frames (list): collection of frames built with the frame decorator
+path (str): filename with relative or absolute path
+duration (int): milliseconds between frames. [/INST] def save(frames, path, duration=100):
+ frames[0].save(
+ path,
+ save_all=True,
+ append_images=frames[1:],
+ optimize=True,
+ duration=duration,
+ loop=0,
+ )
+[INST] Change presence and report ready.. [/INST] async def on_ready():
+ activity = discord.Game(name=f"@Vibrant for help ({len(bot.guilds)})")
+ await bot.change_presence(activity=activity)
+ print("Ready Player One.")
+[INST] Change the bots prefix for the guild.. [/INST] async def change_server_prefix(self, ctx, *, new_prefix):
+ db.guilds.update_one(
+ {"_id": ctx.guild.id},
+ {"$set": {"prefix": new_prefix}}
+ )
+ await ctx.send(f"Prefix changed to `{new_prefix}` {check_emoji()}")
+[INST] Sends the message containing the catalog.. [/INST] async def send(self, channel):
+ if self.type == "Image":
+ if len(self.content) == 1:
+ imagefile = image_to_file(self.content[0])
+ return await channel.send(file=imagefile)
+ await self.delete_message()
+ imagefile = image_to_file(self.content[self.page])
+ self.message = await channel.send(file=imagefile)
+ await self.add_reactions(self.image_reactions)
+ elif self.type == "Embed":
+ if self.message:
+ await self.message.edit(embed=self.generate_embed())
+ else:
+ self.message = await channel.send(embed=self.generate_embed())
+ await self.add_reactions(self.embed_reactions)
+ cfg.catalogs[self.channel.id] = self
+[INST] Attempts to delete the message stored.. [/INST] async def delete_message(self):
+ if self.message:
+ try:
+ await self.message.delete()
+ except discord.errors.NotFound:
+ pass
+[INST] Add new guild to database.. [/INST] async def on_guild_join(self, guild):
+ db.guilds.insert_one(
+ {
+ "_id": guild.id,
+ "prefix": "$",
+ "wc": None,
+ "colors": [],
+ "themes": []
+ }
+ )
+[INST] Update database if welcome channel is deleted.. [/INST] async def on_guild_channel_delete(self, channel):
+ db.guilds.update_one(
+ {"_id": channel.guild.id, "wc": channel.id},
+ {"$set": {"wc": None}}
+ )
+[INST] Change the active colors to a theme.. [/INST] async def load_theme(self, ctx, *, themename):
+ themes = db.get(ctx.guild.id, "themes")
+ theme = theme_lookup(themename, themes)
+ output_suppressed = ctx.guild.id in cfg.suppress_output
+ if not theme:
+ raise UserInputError("Could not find that theme")
+ if not output_suppressed:
+ embed = discord.Embed(
+ title=f"Clearing Colors {loading_emoji()}")
+ msg = await ctx.send(embed=embed)
+ cfg.heavy_command_active.add(ctx.guild.id)
+ cfg.suppress_output.add(ctx.guild.id)
+ await ctx.invoke(self.bot.get_command("clear_colors"))
+ if not output_suppressed:
+ cfg.suppress_output.discard(ctx.guild.id)
+ embed = discord.Embed(title=f"Creating Roles {loading_emoji()}")
+ await msg.edit(embed=embed)
+ owned_colors = [color for color in theme["colors"]
+ if color["members"]
+ and all([ctx.guild.get_member(member_id) for member_id in color["members"]])]
+ colors = deepcopy(theme["colors"])
+ for color in colors:
+ color["role"] = None
+ color["members"] = []
+ db.guilds.update_one(
+ {"_id": ctx.guild.id},
+ {"$set": {"colors": colors}}
+ )
+ cm_pairs = []
+ for color in owned_colors:
+ role = await ColorAssignment.create_role(ctx.guild, color)
+ color["role"] = role.id
+ for member_id in color["members"]:
+ cm_pairs.append((color, ctx.guild.get_member(member_id)))
+ if not output_suppressed:
+ embed = discord.Embed(title=f"Applying Color {loading_emoji()}")
+ await msg.edit(embed=embed)
+ for color, member in cm_pairs:
+ await ColorAssignment.color(member, color)
+ await asyncio.sleep(1)
+ cfg.heavy_command_active.discard(ctx.guild.id)
+ if ctx.guild.id not in cfg.suppress_output:
+ success_embed = discord.Embed(
+ title=f"Loaded {theme['name']} {check_emoji()}",
+ color=discord.Color.green()
+ )
+ await msg.edit(embed=success_embed)
+ await ctx.invoke(self.bot.get_command("colors"))
+[INST] Disable all commands in cog if heavy command is running.. [/INST] def cog_check(self, ctx):
+ if ctx.guild.id in cfg.heavy_command_active:
+ raise CommandError("A command is running, please wait")
+ return True
+[INST] Removes the color role if a user has one.. [/INST] async def uncolor(member, color):
+ if not color["role"]:
+ return
+ role = member.guild.get_role(color["role"])
+ await member.remove_roles(role)
+[INST] Create a role and update database.. [/INST] async def create_role(guild, color):
+ role = await guild.create_role(
+ name=color["name"],
+ color=utils.discord_color(color)
+ )
+ db.guilds.update_one(
+ {"_id": guild.id, "colors.name": color["name"]},
+ {"$set": {"colors.$.role": role.id}}
+ )
+ return role
+[INST] Color a specified user a specified color.. [/INST] async def color_user(self, ctx, member: discord.Member, *, cstring=""):
+ colors = db.get(ctx.guild.id, "colors")
+ ucolor = utils.find_user_color(member, colors)
+ if not colors:
+ raise CommandError("There are no active colors")
+ if len(colors) > 1 and not cstring:
+ exclusive_colors = [color for color in colors if color != ucolor]
+ color = utils.color_lookup(cstring, exclusive_colors)
+ else:
+ color = utils.color_lookup(cstring, colors)
+ if not color:
+ raise CommandError("Color Not Found")
+ if color == ucolor:
+ raise CommandError(f"{member.name} already has that color")
+ if ucolor:
+ await self.uncolor(member, ucolor)
+ await self.color(member, color)
+ embed = discord.Embed(
+ title=f"{member.name} is {color['name']} {check_emoji()}",
+ color=utils.discord_color(color)
+ )
+ await ctx.send(embed=embed)
+[INST] Display an image of equipped colors.. [/INST] async def colorme(self, ctx, *, cstring=""):
+ await ctx.invoke(
+ self.bot.get_command("color"),
+ member=ctx.author,
+ cstring=cstring
+ )
+[INST] Display an image of equipped colors.. [/INST] async def uncolorme(self, ctx):
+ colors = db.get(ctx.guild.id, "colors")
+ ucolor = utils.find_user_color(ctx.author, colors)
+ if ucolor:
+ await self.uncolor(ctx.author, ucolor)
+ response = f"You have been uncolored {check_emoji()}"
+ else:
+ response = "You don't have a color"
+ await ctx.send(embed=discord.Embed(title=response))
+[INST] Gather all of the uncolored users and assigns them a color. [/INST] async def color_server(self, ctx, *, color: utils.ColorConverter = None):
+ output_suppressed = ctx.guild.id in cfg.suppress_output
+ colors = db.get(ctx.guild.id, "colors")
+ if not colors:
+ raise CommandError("There are no active colors")
+ cfg.heavy_command_active.add(ctx.guild.id)
+ uncolored = [member for member in ctx.guild.members
+ if not utils.find_user_color(member, colors)]
+ if not output_suppressed:
+ embed = discord.Embed(title=f"Creating Roles {loading_emoji()}")
+ msg = await ctx.send(embed=embed)
+ colors = [color] if color else colors
+ index = random.randrange(len(colors))
+ sliced_colors = colors[index:] + colors[:index]
+ for color in sliced_colors[:len(uncolored)]:
+ if not color["role"]:
+ role = await self.create_role(ctx.guild, color)
+ color["role"] = role.id
+ if not output_suppressed:
+ embed = discord.Embed(
+ title=f"Coloring {len(uncolored)} People {loading_emoji()}",
+ description=f"This will take around {len(uncolored)} seconds"
+ )
+ await msg.edit(embed=embed)
+ for color, member in zip(cycle(sliced_colors), uncolored):
+ await self.color(member, color)
+ await asyncio.sleep(1)
+ cfg.heavy_command_active.discard(ctx.guild.id)
+ if not output_suppressed:
+ embed = discord.Embed(
+ title=f"Colored {len(uncolored)} members {check_emoji()}",
+ color=discord.Color.green()
+ )
+ await msg.edit(embed=embed)
+[INST] Remove all colors but not delete them.. [/INST] async def uncolor_server(self, ctx):
+ colors = db.get(ctx.guild.id, "colors")
+ cfg.heavy_command_active.add(ctx.guild.id)
+ for color in colors:
+ if color["role"]:
+ role = ctx.guild.get_role(color["role"])
+ await role.delete()
+ cfg.heavy_command_active.discard(ctx.guild.id)
+ if ctx.guild.id not in cfg.suppress_output:
+ embed = discord.Embed(title=f"Everyone has been uncolored {check_emoji()}",
+ color=discord.Color.green())
+ await ctx.send(embed=embed)
+[INST] Welcome and attempt to randomly color a new user.. [/INST] async def on_member_join(self, member):
+ wc, colors = db.get_many(member.guild.id, "wc", "colors")
+ if not wc:
+ return
+ welcome_channel = member.guild.get_channel(wc)
+ if colors:
+ color = random.choice(colors)
+ await self.color(member, color)
+ accent = utils.discord_color(color)
+ else:
+ accent = discord.Colour.blurple()
+ embed = discord.Embed(
+ title=f"{member.name} has joined the server!",
+ description=f"Please give {member.mention} a warm welcome!",
+ color=accent)
+ embed.set_thumbnail(url=member.avatar_url)
+ await welcome_channel.send(embed=embed)
+[INST] Sends a goodbye message when a member leaves the server.. [/INST] async def on_member_remove(self, member):
+ wc, colors = db.get_many(member.guild.id, "wc", "colors")
+ ucolor = utils.find_user_color(member, colors)
+ if not wc:
+ return
+ welcome_channel = member.guild.get_channel(wc)
+ embed = discord.Embed(title=f"{member.name} has left the server!",
+ description="They won't be missed",
+ color=discord.Color.red())
+ embed.set_thumbnail(url=member.avatar_url)
+ await welcome_channel.send(embed=embed)
+ if ucolor:
+ members = ucolor["members"]
+ members.remove(member.id)
+ db.guilds.update_one(
+ {"_id": member.guild.id, "colors.name": ucolor["name"]},
+ {"$set": {"colors.$.members": members}}
+ )
+[INST] Updates color info if a users color role is manually removed. [/INST] async def on_member_update(self, before, after):
+ if before.roles == after.roles:
+ return
+ roles_before = set(before.roles)
+ roles_after = set(after.roles)
+ removed_roles = roles_before - roles_after
+ added_roles = roles_after - roles_before
+ for role in removed_roles:
+ response = db.guilds.update_one(
+ {"_id": after.guild.id, "colors.role": role.id},
+ {"$pull": {"colors.$.members": after.id}}
+ )
+ if not role.members and response.matched_count == 1:
+ try:
+ await role.delete()
+ except discord.errors.NotFound:
+ pass
+ except discord.Forbidden:
+ pass
+ for role in added_roles:
+ db.guilds.update_one(
+ {"_id": after.guild.id, "colors.role": role.id},
+ {"$push": {"colors.$.members": after.id}}
+ )
+[INST] Ensure user has permissions and heavy command is not active for cog commands.. [/INST] def cog_check(self, ctx):
+ if not ctx.author.guild_permissions.manage_roles:
+ raise CommandError(f"You need Manage Roles permission")
+ if ctx.guild.id in cfg.heavy_command_active:
+ raise CommandError("Please wait for the current command to finish")
+ return True
+[INST] Add a color to the database colors.. [/INST] async def add_color(self, ctx, hexcode, *, name=""):
+ colors = db.get(ctx.guild.id, "colors")
+ if len(colors) >= cfg.color_limit:
+ raise CommandError(f"Color Limit Reached ({len(colors)}/50)")
+ if not utils.validate_hex(hexcode):
+ raise UserInputError(
+ f"Invalid hexcode. Get help at https://htmlcolorcodes.com/")
+ if "|" in name or len(name) > 100:
+ raise UserInputError(
+ "Color names must be shorter than 100 characters and cannot include `|`")
+ if name in [color["name"] for color in colors]:
+ raise UserInputError("Cannot have duplicate color names")
+ if not name:
+ name = f"Color {len(colors) + 1}"
+ if hexcode in {"#000000", "#000"}:
+ hexcode = "#000001"
+ new_color = {
+ "name": name,
+ "hexcode": hexcode,
+ "role": None,
+ "members": []
+ }
+ db.guilds.update_one(
+ {"_id": ctx.guild.id},
+ {"$push": {"colors": new_color}}
+ )
+ await ctx.invoke(bot.get_command("colors"))
+[INST] Remove a color from a guild's colors.. [/INST] async def remove_color(self, ctx, *, color: utils.ColorConverter):
+ db.guilds.update_one(
+ {"_id": ctx.guild.id},
+ {"$pull": {"colors": color}}
+ )
+ if color["role"]:
+ role = ctx.guild.get_role(color["role"])
+ await role.delete()
+ await ctx.invoke(bot.get_command("colors"))
+[INST] Change the way a color looks.. [/INST] async def recolor(self, ctx, *, query):
+ colors = db.get(ctx.guild.id, "colors")
+ try:
+ before, after = query.split("|")
+ except ValueError:
+ raise UserInputError(
+ "Command should be formatted as: $rename | ")
+ color = utils.color_lookup(before.strip(), colors)
+ if not color:
+ raise UserInputError("Couldn't find that color")
+ after = after.strip()
+ if not utils.validate_hex(after):
+ raise UserInputError(
+ "Invalid hexcode. Get help at https://htmlcolorcodes.com/")
+ if color["role"]:
+ role = ctx.guild.get_role(color["role"])
+ new_color = utils.to_rgb(after)
+ await role.edit(color=discord.Color.from_rgb(*new_color))
+ db.guilds.update_one(
+ {"_id": ctx.guild.id, "colors": color},
+ {"$set": {"colors.$.hexcode": after}}
+ )
+ await ctx.invoke(bot.get_command("colors"))
+[INST] Removes a role from a color if user deletes it. [/INST] async def on_guild_role_delete(self, role):
+ db.guilds.update_one(
+ {"_id": role.guild.id, "colors.role": role.id},
+ {"$set": {"colors.$.role": None, "colors.$.members": []}}
+ )
+[INST] Updated a color if the associated role is updated. [/INST] async def on_guild_role_update(self, before, after):
+ if before.name != after.name:
+ db.guilds.update_one(
+ {"_id": before.guild.id, "colors.role": before.id},
+ {"$set": {"colors.$.name": after.name}}
+ )
+ if before.color != after.color:
+ db.guilds.update_one(
+ {"_id": before.guild.id, "colors.role": before.id},
+ {"$set": {"colors.$.hexcode": str(after.color)}}
+ )
+[INST] Creates an image displaying themes.. [/INST] def draw_themes(themes):
+ canvas_width = 900
+ canvas_height = 600
+ slice_length = 5
+ slice_amt = math.ceil(len(themes)/slice_length)
+ cont_height = canvas_height // slice_length
+ text_top = int(cont_height * .1)
+ text_height = int(cont_height * .30)
+ color_top = int(cont_height * .5)
+ color_height = int(cont_height * .45)
+ fnt = ImageFont.truetype(
+ font=f"assets{sep}fonts{sep}Roboto.ttf",
+ size=text_height
+ )
+ theme_slices = [themes[i:i+slice_length]
+ for i in range(0, len(themes), slice_length)]
+ theme_images = []
+ for i, theme_slice in enumerate(theme_slices):
+ img = Image.new(
+ mode='RGBA',
+ size=(canvas_width, cont_height * len(theme_slice)),
+ color=(0, 0, 0, 0)
+ )
+ draw = ImageDraw.Draw(img)
+ for j, theme in enumerate(theme_slice):
+ index = i * slice_length + j + 1
+ text = f"{index}. {theme['name']}"
+ text_width, text_height = draw.textsize(text, fnt)
+ x = (canvas_width/2)-(text_width/2)
+ y = j * cont_height + text_top
+ draw.text((x, y), text, font=fnt, fill=(240, 240, 240))
+ color_width = canvas_width/len(theme["colors"])
+ for k, color in enumerate(theme["colors"]):
+ x0 = k * color_width
+ y0 = j * cont_height + color_top
+ x1 = x0 + color_width - 5
+ y1 = y0 + color_height
+ draw.rectangle(
+ xy=[(x0, y0), (x1, y1)],
+ fill=utils.to_rgb(color["hexcode"])
+ )
+ theme_images.append(img)
+ return theme_images
+[INST] Draw the guild's themes and send in channel.. [/INST] async def show_themes(self, ctx):
+ themes = db.get(ctx.guild.id, "themes")
+ if not themes:
+ return await ctx.send(embed=Embed(title="You have no themes"))
+ theme_files = self.draw_themes(themes)
+ catalog = Catalog(theme_files)
+ await catalog.send(ctx.channel)
+[INST] Shows a detailed, textbased view of your themes.. [/INST] async def show_themes_in_detail(self, ctx):
+ themes = db.get(ctx.guild.id, "themes")
+ themes_embed = Embed(title="Themes in Detail", description="")
+ for theme in themes:
+ colors = [f"**{color['name']}**: {color['hexcode']}"
+ for color in theme["colors"]]
+ themes_embed.add_field(name=theme["name"], value="\n".join(colors))
+ await ctx.send(embed=themes_embed)
+[INST] Draw and send an image of all presets. [/INST] async def show_imports(self, ctx):
+ themes = []
+ for fname in os.listdir(f"assets{sep}presets{sep}"):
+ with open(f"assets{sep}presets{sep}{fname}", "r") as preset:
+ themes.append(json.load(preset))
+ catalog = Catalog(self.draw_themes(themes))
+ await catalog.send(ctx.channel)
+[INST] purge a set amount of messages!. [/INST] async def delete_messages(self, ctx, amount):
+ if ctx.author.id in admin_ids:
+ await ctx.channel.purge(limit=int(amount)+1)
+ await ctx.send(f"purged {amount} messages", delete_after=2)
+[INST] Ensure user has permissions and heavy command is not active for cog commands.. [/INST] def cog_check(self, ctx):
+ if not ctx.author.guild_permissions.manage_roles:
+ raise CommandError(f"You need Manage Roles permission")
+ return True
+[INST] Save the current state of the guilds colors. [/INST] async def save_theme(self, ctx, *, name=""):
+ colors, themes = db.get_many(ctx.guild.id, "colors", "themes")
+ if not colors:
+ raise CommandError("There must be colors to save a theme")
+ if len(themes) >= cfg.theme_limit:
+ raise CommandError(f"Theme Limit Reached ({len(themes)}/10)")
+ if "|" in name or len(name) >= 100:
+ raise UserInputError(
+ "Color names must be shorter than 100 characters and cannot include `|`")
+ if not name:
+ name = f"Theme {len(themes) + 1}"
+ for color in colors:
+ del color["role"]
+ new_theme = {
+ "name": name,
+ "colors": colors,
+ }
+ db.guilds.update_one(
+ {"_id": ctx.guild.id},
+ {"$push": {"themes": new_theme}}
+ )
+ await ctx.invoke(self.bot.get_command("themes"))
+[INST] Overwrite one of the Guild's themes with another.. [/INST] async def overwrite_theme(self, ctx, *, themename):
+ colors, themes = db.get_many(ctx.guild.id, "colors", "themes")
+ old_theme = theme_lookup(themename, themes)
+ if not old_theme:
+ raise UserInputError("Could not find that theme")
+ for color in colors:
+ del color["role"]
+ new_theme = {
+ "name": old_theme["name"],
+ "colors": colors,
+ }
+ db.guilds.update_one(
+ {"_id": ctx.guild.id, "themes": old_theme},
+ {"$set": {"themes.$": new_theme}}
+ )
+ if ctx.guild.id not in cfg.suppress_output:
+ await ctx.invoke(self.bot.get_command("themes"))
+[INST] Rename a theme in the guild.. [/INST] async def rename_theme(self, ctx, *, query):
+ themes = db.get(ctx.guild.id, "themes")
+ try:
+ before, after = query.split("|")
+ except ValueError:
+ raise UserInputError(
+ "Command should be formatted as: $t.rn | ")
+ theme = theme_lookup(before.strip(), themes)
+ if not theme:
+ raise UserInputError("Couldn't find that theme")
+ after = after.strip()
+ if not after:
+ raise UserInputError(
+ "Command should be formatted as: $rename | ")
+ db.guilds.update_one(
+ {"_id": ctx.guild.id, "themes": theme},
+ {"$set": {"themes.$.name": after}}
+ )
+ await ctx.invoke(self.bot.get_command("themes"))
+[INST] Save a preset as a theme.. [/INST] async def import_colors(self, ctx, *, name):
+ themes = db.get(ctx.guild.id, "themes")
+ if len(themes) >= cfg.theme_limit:
+ raise CommandError(f"Theme Limit Reached ({len(themes)}/10)")
+ name = name.replace(" ", "")
+ try:
+ with open(f"assets{sep}presets{sep}{name.lower()}.json") as preset_data:
+ preset = json.load(preset_data)
+ except FileNotFoundError:
+ raise UserInputError("Could not find that preset")
+ db.guilds.update_one(
+ {"_id": ctx.guild.id},
+ {"$push": {"themes": preset}}
+ )
+ if ctx.guild.id not in cfg.suppress_output:
+ await ctx.invoke(self.bot.get_command("themes"))
+[INST] Find a color in a list of colors based on a query. [/INST] async def convert(self, ctx, arg):
+ colors = db.get(ctx.guild.id, "colors")
+ if not colors:
+ raise CommandError("You have no active colors")
+ color = color_lookup(arg, colors)
+ if not color:
+ raise CommandError("Color Not Found")
+ return color
+[INST] Find a theme based on a string.. [/INST] def theme_lookup(arg, themes):
+ if arg.isdigit():
+ try:
+ return themes[int(arg)-1]
+ except IndexError:
+ pass
+ for theme in themes:
+ if theme["name"].lower() == arg.lower():
+ return theme
+[INST] Find a color based on a string.. [/INST] def color_lookup(arg, colors):
+ if arg == "":
+ return random.choice(colors)
+ if arg.isdigit():
+ try:
+ return colors[int(arg)-1]
+ except IndexError:
+ pass
+ for color in colors:
+ if color["name"].lower() == arg.lower():
+ return color
+[INST] Extract the color from a list that contains the member id.. [/INST] def find_user_color(user, colors):
+ for color in colors:
+ if user.id in color["members"]:
+ return color
+[INST] The event triggered when an error is raised while invoking a command.. [/INST] async def on_command_error(self, ctx, error):
+ gumbachi = cfg.bot.get_user(128595549975871488)
+ if hasattr(ctx.command, 'on_error'):
+ return
+ error = getattr(error, 'original', error)
+ if isinstance(error, commands.CommandNotFound):
+ return
+ cfg.heavy_command_active.discard(ctx.guild.id)
+ cfg.suppress_output.discard(ctx.guild.id)
+ if isinstance(error, commands.CommandError):
+ embed = discord.Embed(title=str(error), color=discord.Color.red())
+ return await ctx.send(embed=embed)
+ if isinstance(error, discord.Forbidden):
+ if error.code == 50001:
+ embed = discord.Embed(
+ title="Missing Access",
+ description="The bot is not able to access the channel\nPlease ensure the bot has the correct permissions",
+ color=discord.Color.red()
+ )
+ elif error.code == 50013:
+ embed = discord.Embed(
+ title="Bot Missing Permissions",
+ description="The bot is missing permissions\nPlease ensure the bot has its required permissions",
+ color=discord.Color.red()
+ )
+ return await ctx.send(embed=embed)
+ if isinstance(error, discord.errors.NotFound):
+ if error.code == 10008:
+ embed = discord.Embed(
+ title="Missing Message",
+ description="Kindly don't delete the progress message",
+ color=discord.Color.red()
+ )
+ elif error.code == 50011:
+ embed = discord.Embed(
+ title="Bot couldn't find role",
+ description="Bot couldn't find a role. Try resetting your colors if this persists",
+ color=discord.Color.red()
+ )
+ return await ctx.send(embed=embed)
+ error_embed = discord.Embed(
+ title=f'Your command: {ctx.message.content}',
+ description=f"{ctx.guild.id}: {error}",
+ color=discord.Colour.red()
+ )
+ await ctx.send(embed=error_embed)
+ await gumbachi.send(embed=error_embed)
+ print(f'Ignoring exception in command {ctx.command}:', file=sys.stderr)
+ traceback.print_exception(
+ type(error), error, error.__traceback__, file=sys.stderr)
+[INST] Draw the colors in the current set. [/INST] def draw_colors(colors):
+ rows = math.ceil(len(colors) / 3)
+ row_height = 50
+ column_width = 300
+ columns = 3
+ img = Image.new(mode='RGBA',
+ size=(columns * column_width, rows * row_height),
+ color=(0, 0, 0, 0))
+ draw = ImageDraw.Draw(img)
+ fnt = ImageFont.truetype(
+ font=f".{sep}assets{sep}fonts{sep}Roboto.ttf", size=30)
+ for i, color in enumerate(colors):
+ rgb = utils.to_rgb(color["hexcode"])
+ rem = i % 3
+ div = i//3
+ x1 = column_width * rem
+ x2 = row_height * div
+ y1 = column_width * (rem+1)
+ y2 = row_height * (div+1)
+ draw.rectangle([x1, x2, y1, y2], fill=rgb,
+ outline=(0, 0, 0, 0), width=2)
+ W, H = column_width*rem, row_height*div+10
+ msg = f"{i + 1}. {color['name']}"
+ w, _ = draw.textsize(msg, fnt)
+ text_size_limit = column_width - (column_width/10)
+ while w > text_size_limit:
+ msg = msg[:-2]
+ w, _ = draw.textsize(msg, fnt)
+ if w <= text_size_limit:
+ msg = msg + "..."
+ r, g, b = rgb
+ luminance = (0.299 * r + 0.587 * g + 0.114 * b)/255
+ text_color = (0, 0, 0) if luminance > 0.5 else (240, 240, 240)
+ x = (column_width-w)/2 + W
+ y = H
+ draw.text((x, y), msg, font=fnt, fill=text_color)
+ byte_arr = io.BytesIO()
+ img.save(byte_arr, format='PNG')
+ byte_arr = byte_arr.getvalue()
+ im = io.BytesIO(byte_arr)
+ return File(im, filename="colors.png")
+[INST] Display an image of equipped colors.. [/INST] async def show_colors(self, ctx):
+ colors = db.get(ctx.guild.id, "colors")
+ if not colors:
+ return await ctx.send(embed=Embed(title="You have no colors"))
+ await ctx.send(file=self.draw_colors(colors))
+[INST] Show what the database thinks colors are (For testing/support).. [/INST] async def show_colors_in_detail(self, ctx):
+ colors = db.get(ctx.guild.id, "colors")
+ cinfo = Embed(title="Detailed Color Info", description="")
+ for color in colors:
+ members = [bot.get_user(id).name for id in color["members"]]
+ cinfo.add_field(
+ name=color["name"],
+ value=f"**ROLE:** {color['role']}\n**MEMBERS({len(members)}):** {', '.join(members)}")
+ await ctx.send(embed=cinfo)
+[INST] Predict/infer text in input image (which can be a file path).. [/INST] def predict(self, image: Union[str, Path, Image.Image]) -> str:
+ image_pil = image
+ if not isinstance(image, Image.Image):
+ image_pil = util.read_image_pil(image, grayscale=True)
+ image_pil = resize_image(image_pil, IMAGE_SCALE_FACTOR)
+ image_tensor = self.transform(image_pil)
+ y_pred = self.scripted_model(image_tensor.unsqueeze(axis=0))[0]
+ pred_str = convert_y_label_to_string(y=y_pred, mapping=self.mapping, ignore_tokens=self.ignore_tokens)
+ return pred_str
+[INST] Returns a list of nodes in a shortest path between source and target
+using the A* ("A-star") algorithm.
+
+Heurestic function changed to include the dictionairy with the coordinates
+
+Weights used in the function include distant and local scores of the edges. [/INST] def astar_path_Kasia(G, source, target, scenario, coordDict):
+ def heuristic_Kasia(theNode, theTarget, coordDict):
+ nodeX = coordDict[theNode][0]
+ nodeY = coordDict[theNode][1]
+ targetX = coordDict[theTarget][0]
+ targetY = coordDict[theTarget][1]
+ distanceToTarget = math.sqrt(math.pow((targetX-nodeX),2)+math.pow((targetY-nodeY),2))
+ return distanceToTarget
+ if source not in G or target not in G:
+ msg = 'Either source {} or target {} is not in G'
+ raise nx.NodeNotFound(msg.format(source, target))
+ push = heappush
+ pop = heappop
+ distanceToTarget = heuristic_Kasia(source, target, coordDict)
+ print('Max heuristic value (eucidlian distance to destination from the starting node): ', distanceToTarget)
+ print('Max Total Score of the edge: 1')
+ if scenario == 1:
+ weightDS = 0.5
+ weightLS = 0.5
+ elif scenario == 2:
+ weightDS = w1
+ weightLS = (1-w1)
+ heuristicCostRatio = 1000
+ c = count()
+ queue = [(0, next(c), source, 0, None)]
+ enqueued = {}
+ explored = {}
+ while queue:
+ _, __, curnode, dist, parent = pop(queue)
+ if curnode == target:
+ path = [curnode]
+ node = parent
+ while node is not None:
+ path.append(node)
+ node = explored[node]
+ path.reverse()
+ return path
+ if curnode in explored:
+ continue
+ explored[curnode] = parent
+ for neighbor, w in G[curnode].items():
+ if neighbor in explored:
+ continue
+ if scenario==2:
+ weightDS = (2*w1-1)/distanceToTarget * heuristic_Kasia(neighbor, target, coordDict) + (1-w1)
+ weightLS = (-1)*(2*w1-1)/distanceToTarget * heuristic_Kasia(neighbor, target, coordDict) + w1
+ ncost = dist + TotalScoreToHeuresticScaleFactor * heuristicCostRatio * (weightDS * (1-w.get('DS', 1)) + weightLS * (1-w.get('LS', 1)))
+ if neighbor in enqueued:
+ qcost, h = enqueued[neighbor]
+ if qcost <= ncost:
+ continue
+ else:
+ h = heuristic_Kasia(neighbor, target, coordDict)
+ enqueued[neighbor] = ncost, h
+ push(queue, (ncost + h, next(c), neighbor, ncost, curnode))
+ raise nx.NetworkXNoPath("Node %s not reachable from %s" % (target, source))
+[INST] Run as few or as many tests as needed here.. [/INST] def runTest(self):
+ self.setUp()
+ self.test_RawImageGuess1()
+[INST] Copy a user editable config file to the target directory.. [/INST] def copy_config(self, overwrite=False,
+ new_ext='.new'):
+ if self.directory_map is None:
+ error_str = "No source directory available."
+ raise ValueError(error_str)
+ self.directory_map.copy_file(self.config_file_name,
+ overwrite=overwrite,
+ new_ext=new_ext)
+ return
+[INST] Make header and footer strings consisting of a bar of characters of
+fixed width, with a title embeded in the header bar.. [/INST] def make_head_foot_bar(cls, header_title, bar_width, bar_char='*'):
+ title_space = ' {} '.format(header_title)
+ header = "{0:{1}^{2}}".format(title_space, bar_char[0], bar_width)
+ footer = bar_char * bar_width
+ return (header, footer)
+[INST] Copy a user editable config file to the target directory.. [/INST] def copy_config(self, overwrite=False,
+ new_ext='.new'):
+ super(ReadINI, self).copy_config(overwrite, new_ext)
+ if self.validation_file_name is None: return
+ self.directory_map.copy_file(self.validation_file_name,
+ overwrite=overwrite,
+ new_ext=new_ext)
+ return
+[INST] Create strings with the specific validation errors.. [/INST] def _type_fails(self, results):
+ log_lines = []
+ for key, value in results.iteritems():
+ if issubclass(type(value), ValidateError):
+ log_str = (' - Key "{}" failed with error:\n'
+ ' {}').format(key, value)
+ log_lines.append(log_str)
+ elif value is False:
+ log_str = ' - Key "{}" must be set.'.format(key)
+ log_lines.append(log_str)
+ elif isinstance(value, dict):
+ add_logs = self._type_fails(value)
+ log_lines.extend(add_logs)
+ return log_lines
+[INST] Load the YAML configuration file.. [/INST] def read(self):
+ yaml_config_path = self.get_config_path()
+ with open(yaml_config_path, 'r') as conf:
+ config_dict = yaml.load(conf, Loader=Loader)
+ return config_dict
+[INST] Write the YAML configuration file.. [/INST] def write(self, obj_to_serialise, default_flow_style=False):
+ yaml_config_path = self.get_config_path()
+ self.target_dir.makedir()
+ with open(yaml_config_path, 'w') as yaml_file:
+ yaml.dump(obj_to_serialise,
+ yaml_file,
+ default_flow_style=default_flow_style,
+ Dumper=Dumper)
+ return
+[INST] Load the logging configuration file.. [/INST] def configure_logger(cls, log_config_dict):
+ dictConfig(log_config_dict)
+ return
+[INST] not all notebook editors are good with tags, but papermill uses it
+to find the parameters cell.
+This injects tag to the file based on the content of a cell. [/INST] def add_tag_based_on_content(input_file: Path, tag: str, content: str):
+ with open(input_file) as f:
+ nb = json.load(f)
+ change = False
+ for n, cell in enumerate(nb["cells"]):
+ if cell["cell_type"] == "code":
+ if cell["source"] and content in "".join(cell["source"]):
+ tags = cell["metadata"].get("Tags", [])
+ if tag not in tags:
+ tags.append("parameters")
+ change = True
+ nb["cells"][n]["metadata"]["tags"] = tags
+ if change:
+ with open(input_file, "w") as f:
+ json.dump(nb, f)
+[INST] execute the notebook with the parameters
+to the papermill storage folder. [/INST] def papermill(self, slug, params, rerun: bool = True):
+ actual_path = self.raw_path()
+ if rerun is False:
+ print("Not papermilling, just copying current file")
+ shutil.copy(self.raw_path(), self.papermill_path(slug))
+ else:
+ add_tag_based_on_content(actual_path, "parameters", "#default-params")
+ pm.execute_notebook(
+ actual_path, self.papermill_path(slug), parameters=params
+ )
+[INST] the location the html or file is output to. [/INST] def rendered_filename(self, slug: str, ext: str = ".md"):
+ name = self._parent.name
+ output_folder = Path("_render", "_parts", name, slug)
+ if output_folder.exists() is False:
+ output_folder.mkdir(parents=True)
+ return output_folder / (self.name + ext)
+[INST] render papermilled version to a file. [/INST] def render(self, slug: str, hide_input: bool = True):
+ include_input = not hide_input
+ input_path = self.papermill_path(slug)
+ exporters.render_to_markdown(
+ input_path,
+ self.rendered_filename(slug, ".md"),
+ clear_and_execute=False,
+ include_input=include_input,
+ )
+ exporters.render_to_html(
+ input_path,
+ self.rendered_filename(slug, ".html"),
+ clear_and_execute=False,
+ include_input=include_input,
+ )
+[INST] for values that are going to be populated by jinja
+this will populate/repopulate based on the currently known context. [/INST] def init_rendered_values(self, context):
+ self._rendered_data = self._data.copy()
+ for m_path, items in self._data["context"].items():
+ mod = import_module(m_path)
+ for i in items:
+ context[i] = getattr(mod, i)
+ self.params = self.get_rendered_parameters(context)
+ rendered_properties = ["title", "slug"]
+ context = {**self.params, **context}
+ for r in rendered_properties:
+ self._rendered_data[r] = render(self._rendered_data[r], context)
+ self.slug = self._rendered_data["slug"]
+ self.title = self._rendered_data["title"]
+[INST] render the the file through the respective papermills. [/INST] def render(self, context: Optional[dict] = None):
+ if context is None:
+ context = {}
+ if context:
+ self.init_rendered_values(context)
+ slug = self.slug
+ render_dir = Path("_render", self.name, self.slug)
+ if render_dir.exists() is False:
+ render_dir.mkdir(parents=True)
+ for n in self.notebooks:
+ n.papermill(self.slug, self.params, rerun=self.options["rerun"])
+ n.render(self.slug, hide_input=self.options["hide_input"])
+ for ext in [".md", ".html"]:
+ dest = self.rendered_filename(ext)
+ files = [x.rendered_filename(self.slug, ext) for x in self.notebooks]
+ combine_outputs(files, dest)
+ resources_dir = files[0].parent / "_notebook_resources"
+ dest_resources = dest.parent / "_notebook_resources"
+ shutil.copytree(resources_dir, dest_resources, dirs_exist_ok=True)
+ input_path_html = self.rendered_filename(".html")
+ output_path_doc = self.rendered_filename(".docx")
+ template = Path("notebook_helper", "resources", "reference.docx")
+ if template.exists() is False:
+ raise ValueError("Missing Template")
+ reference_doc = str(template)
+ pypandoc.convert_file(
+ str(input_path_html),
+ "docx",
+ outputfile=str(output_path_doc),
+ extra_args=[
+ f"--resource-path={str(render_dir)}",
+ f"--reference-doc={reference_doc}",
+ ],
+ )
+[INST] Upload result to service (gdrive currently). [/INST] def upload(self):
+ for k, v in self._data["upload"].items():
+ if k == "gdrive":
+ file_name = self._rendered_data["title"]
+ file_path = self.rendered_filename(".docx")
+ g_folder_id = v["g_folder_id"]
+ g_drive_id = v["g_drive_id"]
+ g_drive_upload_and_format(file_name, file_path, g_folder_id, g_drive_id)
+[INST] Assign labels to clusters
+Expects a dictionary of cluster number to label
+Label can be a tuple of a label and a longer description. [/INST] def add_labels(self, labels: Dict[int, Union[str, Tuple[str, str]]]):
+ new = copy.deepcopy(self)
+ for n, label in labels.items():
+ desc = ""
+ if isinstance(label, tuple):
+ desc = label[1]
+ label = label[0]
+ new.assign_name(n, label, desc)
+ return new
+[INST] Plot either all possible x, y graphs for k clusters
+or just the subset with the named x_var and y_var.. [/INST] def plot(
+ self,
+ limit_columns: Optional[List[str]] = None,
+ only_one: Optional[Any] = None,
+ show_legend: bool = True,
+ ):
+ k = self.k
+ df = self.df
+ num_rows = 3
+ vars = self.cols
+ if limit_columns:
+ vars = [x for x in vars if x in limit_columns]
+ combos = list(combinations(vars, 2))
+ rows = math.ceil(len(combos) / num_rows)
+ plt.rcParams["figure.figsize"] = (15, 5 * rows)
+ df["labels"] = self.get_cluster_labels()
+ if only_one:
+ df["labels"] = df["labels"] == only_one
+ df["labels"] = df["labels"].map({True: only_one, False: "Other clusters"})
+ chart_no = 0
+ rgb_values = sns.color_palette("Set2", len(df["labels"].unique()))
+ color_map = dict(zip(df["labels"].unique(), rgb_values))
+ fig = plt.figure()
+ for x_var, y_var in combos:
+ chart_no += 1
+ ax = fig.add_subplot(rows, num_rows, chart_no)
+ for c, d in df.groupby("labels"):
+ scatter = ax.scatter(d[x_var], d[y_var], color=color_map[c], label=c)
+ ax.set_xlabel(self._axis_label(x_var))
+ ax.set_ylabel(self._axis_label(y_var))
+ if show_legend:
+ ax.legend()
+ plt.show()
+[INST] fetch k means results for this cluster. [/INST] def _get_clusters(self, k: int):
+ km = KMeans(n_clusters=k, random_state=self.default_seed)
+ return km.fit(self.df)
+[INST] Graph the elbow and Silhouette method for finding the optimal k.
+High silhouette value good.
+Parameters are the search space.. [/INST] def find_k(self, start: int = 15, stop: Optional[int] = None, step: int = 1):
+ if start and not stop:
+ stop = start
+ start = 2
+ def s_score(kmeans):
+ return silhouette_score(self.df, kmeans.labels_, metric="euclidean")
+ df = pd.DataFrame({"n": range(start, stop, step)})
+ df["k_means"] = df["n"].apply(self.get_clusters)
+ df["sum_squares"] = df["k_means"].apply(lambda x: x.inertia_)
+ df["silhouette"] = df["k_means"].apply(s_score)
+ plt.rcParams["figure.figsize"] = (10, 5)
+ plt.subplot(1, 2, 1)
+ plt.plot(df["n"], df["sum_squares"], "bx-")
+ plt.xlabel("k")
+ plt.ylabel("Sum of squared distances")
+ plt.title("Elbow Method For Optimal k")
+ plt.subplot(1, 2, 2)
+ plt.plot(df["n"], df["silhouette"], "bx-")
+ plt.xlabel("k")
+ plt.ylabel("Silhouette score")
+ plt.title("Silhouette Method For Optimal k")
+ plt.show()
+[INST] raincloud plot of a variable, grouped by different clusters. [/INST] def raincloud(
+ self,
+ column: str,
+ one_value: Optional[str] = None,
+ groups: Optional[str] = "Cluster",
+ use_source: bool = True,
+ ):
+ k = self.k
+ if use_source:
+ df = self.source_df.copy()
+ else:
+ df = self.df
+ df["Cluster"] = self.get_cluster_labels()
+ df.viz.raincloud(
+ values=column,
+ groups=groups,
+ one_value=one_value,
+ title=f"Raincloud plot for {column} variable.",
+ )
+[INST] Raincloud plot for a single cluster showing the
+distribution of different variables. [/INST] def reverse_raincloud(self, cluster_label: str):
+ df = self.df.copy()
+ df["Cluster"] = self.get_cluster_labels()
+ df = df.melt("Cluster")[lambda df: ~(df["variable"] == " ")]
+ df["value"] = df["value"].astype(float)
+ df = df[lambda df: (df["Cluster"] == cluster_label)]
+ df.viz.raincloud(
+ values="value",
+ groups="variable",
+ title=f"Raincloud plot for Cluster: {cluster_label}",
+ )
+[INST] Raincloud tool to examine clusters showing the
+distribution of different variables. [/INST] def reverse_raincloud_tool(self):
+ tool = interactive(
+ self.reverse_raincloud, cluster_label=self.get_label_options()
+ )
+ display(tool)
+[INST] Raincloud tool to examine variables showing the
+distribution of different clusters
+The reverse option flips this.. [/INST] def raincloud_tool(self, reverse: bool = False):
+ if reverse:
+ return self.reverse_raincloud_tool()
+ def func(variable, comparison, use_source_values):
+ groups = "Cluster"
+ if comparison == "all":
+ comparison = None
+ if comparison == "none":
+ groups = None
+ comparison = None
+ self.raincloud(
+ variable,
+ one_value=comparison,
+ groups=groups,
+ use_source=use_source_values,
+ )
+ comparison_options = ["all", "none"] + self.get_label_options()
+ tool = interactive(
+ func,
+ variable=self.cols,
+ use_source_values=True,
+ comparison=comparison_options,
+ )
+ display(tool)
+[INST] tool to review how labels assigned for each cluster. [/INST] def label_tool(self):
+ k = self.k
+ def func(cluster, sort, include_data_labels):
+ if sort == "Index":
+ sort = None
+ df = self.label_review(
+ label=cluster, sort=sort, include_data=include_data_labels
+ )
+ display(df)
+ return df
+ sort_options = ["Index", "% of cluster", "% of label"]
+ tool = interactive(
+ func,
+ cluster=self.get_label_options(),
+ sort=sort_options,
+ include_data_labels=True,
+ )
+ display(tool)
+[INST] Review labeled data for a cluster. [/INST] def label_review(
+ self,
+ label: Optional[int] = 1,
+ sort: Optional[str] = None,
+ include_data: bool = True,
+ ):
+ k = self.k
+ def to_count_pivot(df):
+ mdf = df.drop(columns=["label"]).melt()
+ mdf["Count"] = mdf["variable"] + mdf["value"]
+ return mdf.pivot_table(
+ "Count", index=["variable", "value"], aggfunc="count"
+ )
+ df = self.label_df
+ if include_data is False:
+ df = df[[x for x in df.columns if x not in self.cols]]
+ df["label"] = self.get_cluster_labels()
+ opt = to_count_pivot(df).rename(columns={"Count": "overall_count"})
+ df = df.loc[df["label"] == label]
+ pt = to_count_pivot(df).join(opt)
+ pt = pt.rename(columns={"Count": "cluster_count"})
+ pt["% of cluster"] = (pt["cluster_count"] / len(df)).round(3) * 100
+ pt["% of label"] = (pt["cluster_count"] / pt["overall_count"]).round(3) * 100
+ if sort:
+ pt = pt.sort_values(sort, ascending=False)
+ return pt
+[INST] Extend axis label with extra notes. [/INST] def _axis_label(self, label_txt: str) -> str:
+ txt = label_txt
+ if self.normalize:
+ txt = txt + " (normalized)"
+ return txt
+[INST] return the original df but with a label column attached. [/INST] def df_with_labels(self) -> pd.DataFrame:
+ k = self.k
+ df = self.source_df.copy()
+ df["label"] = self.get_cluster_labels(include_short=False)
+ df["label_id"] = self.get_cluster_label_ids()
+ df["label_desc"] = self.get_cluster_descs()
+ return df
+[INST] Plot either all possible x, y, z graphs for k clusters
+or just the subset with the named x_var and y_var.. [/INST] def plot3d(
+ self,
+ x_var: Optional[str] = None,
+ y_var: Optional[str] = None,
+ z_var: Optional[str] = None,
+ ):
+ k = self.k
+ df = self.df
+ labels = self.get_cluster_labels()
+ combos = list(combinations(df.columns, 3))
+ if x_var:
+ combos = [x for x in combos if x[0] == x_var]
+ if y_var:
+ combos = [x for x in combos if x[1] == y_var]
+ if z_var:
+ combos = [x for x in combos if x[1] == y_var]
+ rows = math.ceil(len(combos) / 2)
+ plt.rcParams["figure.figsize"] = (20, 10 * rows)
+ chart_no = 0
+ fig = plt.figure()
+ for x_var, y_var, z_var in combos:
+ chart_no += 1
+ ax = fig.add_subplot(rows, 2, chart_no, projection="3d")
+ ax.scatter(df[x_var], df[y_var], df[z_var], c=labels)
+ ax.set_xlabel(self._axis_label(x_var))
+ ax.set_ylabel(self._axis_label(y_var))
+ ax.set_zlabel(self._axis_label(z_var))
+ plt.title(f"Data with {k} clusters")
+ plt.show()
+[INST] Expects the results of df.space.self_distance to be put into
+the dataframes in the input
+Will merge multiple kinds of distance into a common dataframe
+the str in the dictionary it expects is the label for the column. [/INST] def join_distance(df_label_dict: Dict[str, pd.DataFrame]) -> pd.DataFrame:
+ def prepare(df, label):
+ return (
+ df.set_index(list(df.columns[:2]))
+ .rename(columns={"distance": label})
+ .drop(columns=["match", "position"], errors="ignore")
+ )
+ to_join = [prepare(df, label) for label, df in df_label_dict.items()]
+ df = reduce(lambda x, y: x.join(y), to_join)
+ df = df.reset_index()
+ return df
+[INST] returns a Cluster helper object for this dataframe. [/INST] def cluster(
+ self,
+ id_col: Optional[str] = None,
+ cols: Optional[List[str]] = None,
+ label_cols: Optional[List[str]] = None,
+ normalize: bool = True,
+ transform: List[Callable] = None,
+ k: Optional[int] = None,
+ ) -> Cluster:
+ return Cluster(
+ self._obj,
+ id_col=id_col,
+ cols=cols,
+ label_cols=label_cols,
+ normalize=normalize,
+ transform=transform,
+ k=k,
+ )
+[INST] Calculate the distance between all objects in a dataframe
+in an n-dimensional space.
+get back a dataframe with two labelled columns as well as the
+distance.
+id_col : unique column containing an ID or similar
+cols: all columns to be used in the calculation of distance
+normalize: should these columns be normalised before calculating
+distance
+transform: additonal functions to apply to columns after normalizating. [/INST] def self_distance(
+ self,
+ id_col: Optional[str] = None,
+ cols: Optional[List] = None,
+ normalize: bool = False,
+ transform: List[callable] = None,
+ ):
+ source_df = self._obj
+ if id_col == None:
+ id_col = source_df.index.name
+ source_df = source_df.reset_index()
+ if id_col not in source_df.columns:
+ source_df = source_df.reset_index()
+ if cols is None:
+ cols = [x for x in source_df.columns if x != id_col]
+ a_col = id_col + "_A"
+ b_col = id_col + "_B"
+ _ = list(product(source_df[id_col], source_df[id_col]))
+ df = pd.DataFrame(_, columns=[a_col, b_col])
+ grid = source_df[cols]
+ if normalize:
+ grid = grid.apply(fnormalize, axis=0)
+ if transform:
+ for k, v in transform.items():
+ grid[k] = v(grid[k])
+ distance = pdist(grid.to_numpy())
+ df["distance"] = squareform(distance).flatten()
+ df = df.loc[~(df[a_col] == df[b_col])]
+ return df
+[INST] Either merges self and other
+(both of whichs hould be the result of
+space.self_distance)
+or a dictionary of dataframes and labels
+not including the current dataframe.. [/INST] def join_distance(
+ self,
+ other: Union[Dict[str, pd.DataFrame], pd.DataFrame],
+ our_label: Optional[str] = "A",
+ their_label: Optional[str] = "B",
+ ):
+ if not isinstance(other, dict):
+ df_label_dict = {our_label: self._obj, their_label: other}
+ else:
+ df_label_dict = other
+ return join_distance(df_label_dict)
+[INST] add a match percentage column where the tenth most distance is a 0% match
+and 0 distance is an 100% match.. [/INST] def match_distance(self):
+ df = self._obj
+ def standardise_distance(df):
+ df = df.copy()
+ tenth_from_last_score = df["distance"].sort_values().tail(10).iloc[0]
+ df["match"] = 1 - (df["distance"] / tenth_from_last_score)
+ df["match"] = df["match"].round(3) * 100
+ df["match"] = df["match"].apply(lambda x: x if x > 0 else 0)
+ df = df.sort_values("match", ascending=False)
+ return df
+ return (
+ df.groupby(df.columns[0], as_index=False)
+ .apply(standardise_distance)
+ .reset_index(drop=True)
+ )
+[INST] add a position column that indicates the relative similarity based on distance. [/INST] def local_rankings(self):
+ df = self._obj
+ def get_position(df):
+ df["position"] = df["distance"].rank(method="first")
+ return df
+ return (
+ df.groupby(df.columns[0], as_index=False)
+ .apply(get_position)
+ .reset_index(drop=True)
+ )
+[INST] Given all distances in joint space,
+calculate a composite.
+Set normalize to true to scale all distances between 0 and 1
+Shouldn't be needed where a product of previous rounds of normalization
+A scale factor of 2 for a column reduces distances by half. [/INST] def composite_distance(self, normalize: bool = False):
+ df = self._obj.copy()
+ def normalize_series(s: pd.Series):
+ return s / s.max()
+ cols = df.columns[2:]
+ cols = [df[x] for x in cols]
+ if normalize:
+ cols = [normalize_series(x) for x in cols]
+ squared_cols = [x ** 2 for x in cols]
+ res = reduce(pd.Series.add, squared_cols)
+ res = res.apply(np.sqrt)
+ ndf = df[df.columns[:2]]
+ ndf["distance"] = res
+ return ndf
+[INST] Expects the dataframe returned by `join_distance`.
+Groups by column 1, Expects first two columns to be id columns.
+Beyond that, will see if all columns (representing distances)
+have the same items
+in their lowest 'k' matches.
+Returns a column that can be averaged to get the overlap between
+two metrics.. [/INST] def same_nearest_k(self, k: int = 5):
+ df = self._obj
+ def top_k(df, k=5):
+ df = df.set_index(list(df.columns[:2])).rank()
+ df = df <= k
+ same_rank = df.sum(axis=1).reset_index(drop=True) == len(list(df.columns))
+ data = [[same_rank.sum() / k]]
+ d = pd.DataFrame(data, columns=[f"same_top_{k}"])
+ return d.iloc[0]
+ return df.groupby(df.columns[0]).apply(top_k, k=k).reset_index()
+[INST] Given the result of 'join_distance' explore how similar
+items fall in 'top_k' for a range of values of k.. [/INST] def agreement(self, ks: List[int] = [1, 2, 3, 5, 10, 25]):
+ df = self._obj
+ def get_average(k):
+ return df.joint_space.same_nearest_k(k=k).mean().round(2)[0]
+ r = pd.DataFrame({"top_k": ks})
+ r["agreement"] = r["top_k"].apply(get_average)
+ return r
+[INST] upload file to Climate Emergency metrics folder. [/INST] def upload_file(file_name, file_path, g_folder_id, g_drive_id):
+ api = DriveIntegration(settings["GOOGLE_CLIENT_JSON"])
+ print("uploading document to drive")
+ url = api.upload_file(file_name, file_path, g_folder_id, g_drive_id)
+ print(url)
+ return url
+[INST] Apply google sheets formatter to URL. [/INST] def format_document(url):
+ api = ScriptIntergration(settings["GOOGLE_CLIENT_JSON"])
+ script_id = (
+ "AKfycbwjKpOgzKaDHahyn-7If0LzMhaNfMTTsiHf6nvgL2gaaVsgI_VvuZjHJWAzRaehENLX"
+ )
+ func = api.get_function(script_id, "formatWordURL")
+ print("formatting document, this may take a few minutes")
+ v = func(url)
+ print(v)
+[INST] take a json spec and produce a chart
+mostly needed for the weird work arounds needed for importing layer charts. [/INST] def json_to_chart(json_spec: str) -> alt.Chart:
+ di = json.loads(json_spec)
+ if "layer" in di:
+ layers = di["layer"]
+ del di["layer"]
+ del di["width"]
+ chart = LayerChart.from_dict(
+ {"config": di["config"], "layer": [], "datasets": di["datasets"]}
+ )
+ for n, l in enumerate(layers):
+ di_copy = di.copy()
+ di_copy.update(l)
+ del di_copy["config"]
+ del di_copy["$schema"]
+ del di_copy["datasets"]
+ del di_copy["width"]
+ c = Chart.from_dict(di_copy)
+ chart += c
+ else:
+ del di["width"]
+ del di["config"]["view"]
+ chart = Chart.from_dict(di)
+ return chart
+[INST] simple tracking loop using rich progress. [/INST] def track_progress(
+ iterable: Iterable,
+ name: Optional[str] = None,
+ total: Optional[int] = None,
+ update_label: bool = False,
+ label_func: Optional[Callable] = lambda x: x,
+ clear: Optional[bool] = True,
+):
+ if name is None:
+ name = ""
+ if total is None:
+ total = len(iterable)
+ console.clear_live()
+ with Progress(console=console, transient=clear) as progress:
+ task = progress.add_task(name, total=total)
+ for i in iterable:
+ yield i
+ if update_label:
+ description = f"{name}: {label_func(i)}"
+ else:
+ description = name
+ progress.update(task, advance=1, description=description)
+[INST] show a raincloud plot of the values of a series
+Optional split by a second series (group)
+with labels.. [/INST] def raincloud(
+ self,
+ groups: Optional[pd.Series] = None,
+ ort: Optional[str] = "h",
+ pal: Optional[str] = "Set2",
+ sigma: Optional[float] = 0.2,
+ title: str = "",
+ all_data_label: str = "All data",
+ x_label: Optional[str] = None,
+ y_label: Optional[str] = None,
+ ):
+ s = self._obj
+ df = pd.DataFrame(s)
+ if groups is not None:
+ df[groups.name] = groups
+ x_col = groups.name
+ else:
+ df[" "] = all_data_label
+ x_col = " "
+ f, ax = plt.subplots(figsize=(14, 2 * df[x_col].nunique()))
+ pt.RainCloud(
+ x=df[x_col],
+ y=df[s.name],
+ palette=pal,
+ bw=sigma,
+ width_viol=0.6,
+ ax=ax,
+ orient=ort,
+ move=0.3,
+ )
+ if title:
+ plt.title(title, loc="center", fontdict={"fontsize": 30})
+ if x_label is not None:
+ plt.xlabel(x_label, fontdict={"fontsize": 12})
+ if y_label is not None:
+ plt.ylabel(y_label, fontdict={"fontsize": 12}, rotation=0)
+ plt.show()
+[INST] helper function for visualising one column against
+another with raincloud plots.. [/INST] def raincloud(
+ self,
+ values: str,
+ groups: Optional[str] = None,
+ one_value: Optional[str] = None,
+ limit: Optional[List[str]] = None,
+ ort: Optional[str] = "h",
+ pal: Optional[str] = "Set2",
+ sigma: Optional[float] = 0.2,
+ title: Optional[str] = "",
+ all_data_label: str = "All data",
+ x_label: Optional[str] = None,
+ y_label: Optional[str] = None,
+ ):
+ df = self._obj
+ if limit:
+ df = df.loc[df[groups].isin(limit)]
+ if groups is None:
+ df[" "] = all_data_label
+ groups = " "
+ if one_value:
+ df[groups] = (df[groups] == one_value).map(
+ {False: "Other clusters", True: one_value}
+ )
+ f, ax = plt.subplots(figsize=(14, 2 * df[groups].nunique()))
+ pt.RainCloud(
+ x=df[groups],
+ y=df[values],
+ palette=pal,
+ bw=sigma,
+ width_viol=0.6,
+ ax=ax,
+ orient=ort,
+ move=0.3,
+ )
+ if title:
+ plt.title(title, loc="center", fontdict={"fontsize": 30})
+ if x_label is not None:
+ plt.xlabel(x_label, fontdict={"fontsize": 12})
+ if y_label is not None:
+ plt.ylabel(y_label, fontdict={"fontsize": 12}, rotation=0)
+ plt.show()
+[INST] dumbed down version of altair save function that just assumes
+we're sending extra properties to the embed options. [/INST] def save_chart(chart, filename, scale_factor=1, **kwargs):
+ if isinstance(filename, Path):
+ if filename.parent.exists() is False:
+ filename.parent.mkdir()
+ filename = str(filename)
+ altair_save_chart(
+ chart,
+ filename,
+ scale_factor=scale_factor,
+ embed_options=kwargs,
+ method=Renderer.default_renderer,
+ )
+[INST] Split a string to meet line limit. [/INST] def split_text_to_line(text: str, cut_off: int = 60) -> List[str]:
+ bits = text.split(" ")
+ rows = []
+ current_item = []
+ for b in bits:
+ if len(" ".join(current_item + [b])) > cut_off:
+ rows.append(" ".join(current_item))
+ current_item = []
+ current_item.append(b)
+ rows.append(" ".join(current_item))
+ return rows
+[INST] arguments passed will be sent to display process. [/INST] def display_options(self, **kwargs):
+ self._display_options.update(kwargs)
+ return self
+[INST] take a new df and update the chart. [/INST] def update_df(self, df: pd.DataFrame):
+ self.datasets[self.data["name"]] = df.to_dict("records")
+ return self
+[INST] get the dataset from the chart as a df. [/INST] def df(self):
+ return self._get_df()
+[INST] Apply a transformation on each cell.. [/INST] def preprocess_cell(self, cell, resources, cell_index):
+ if cell["source"]:
+ if "#HIDE" == cell["source"][:5]:
+ cell.transient = {"remove_source": True}
+ return cell, resources
+[INST] do not indent markdown tables when exporting through this filter. [/INST] def indent(instr, nspaces=4, ntabs=0, flatten=False):
+ if instr.strip() and instr.strip()[0] == "|":
+ return instr
+ if "WARN Dropping" in instr:
+ return ""
+ return normal_indent(instr, nspaces, ntabs, flatten)
+[INST] Utility method for multiplying the slot by 16. [/INST] def step(slot):
+ texSlot = slot*16
+ return texSlot
+[INST] Parses each block texture into a usable PNG file like terrain.png. [/INST] def parse_terrain_png(self):
+ new_terrain = Image.new("RGBA", (512, 512), None)
+ for tex in self.block_image.keys():
+ if not self.__stop and tex in textureSlots.keys():
+ try:
+ image = self.block_image[tex]
+ if image.mode != "RGBA":
+ image = image.convert("RGBA")
+ slot = textureSlots[tex]
+ new_terrain.paste(image, slot, image)
+ self.propogated_textures.append(slot)
+ except Exception as e:
+ try:
+ print "An Exception occurred while trying to parse textures for {}".format(self._pack_name)
+ except:
+ print "An Exception occurred while trying to parse textures for {}".format(repr(self._pack_name))
+ traceback.print_stack()
+ print "Exception Message: "+str(e)
+ print "Exception type: "+str(type(e))
+ self.__stop = True
+ self._isEmpty = True
+ pass
+ copy = self.old_terrain.copy()
+ for t in self.all_texture_slots:
+ if t not in self.propogated_textures:
+ old_tex = copy.crop((t[0],t[1],t[0]+16,t[1]+16))
+ new_terrain.paste(old_tex, t, old_tex)
+ new_terrain.save(self._terrain_path)
+ try:
+ os.remove(self._pack_name.replace(" ", "_")+".png")
+ except:
+ pass
+ if not self.propogated_textures:
+ os.remove(self._terrain_path)
+ self._isEmpty = True
+ del self.block_image
+[INST] Opens the zip file and puts texture data into a dictionary, where the key is the texture file name, and the value is a PIL.Image instance. [/INST] def open_pack(self):
+ zfile = zipfile.ZipFile(self.zipfile)
+ for name in zfile.infolist():
+ if name.filename.endswith(".png") and not name.filename.split(os.path.sep)[-1].startswith("._"):
+ filename = "assets/minecraft/textures/blocks"
+ if name.filename.startswith(filename) and name.filename.replace(filename+"/", "").replace(".png","") in textureSlots:
+ block_name = os.path.normpath(name.filename).split(os.path.sep)[-1]
+ block_name = block_name.split(".")[0]
+ fp = zfile.open(name)
+ possible_texture = Image.open(fp)
+ if possible_texture.size == (16, 16):
+ self.block_image[block_name] = possible_texture
+ if block_name.startswith("repeater_") or block_name.startswith("comparator_"):
+ self.block_image[block_name+"_west"] = possible_texture.rotate(-90)
+ self.block_image[block_name+"_north"] = possible_texture.rotate(180)
+ self.block_image[block_name+"_east"] = possible_texture.rotate(90)
+ self.block_image[block_name+"_south"] = possible_texture
+ if block_name == "piston_side":
+ self.block_image["piston_up"] = possible_texture
+ self.block_image["piston_left"] = possible_texture.rotate(90)
+ self.block_image["piston_down"] = possible_texture.rotate(180)
+ self.block_image["piston_right"] = possible_texture.rotate(-90)
+ if block_name == "hay_block_side":
+ self.block_image["hay_block_side_rotated"] = possible_texture.rotate(-90)
+ if block_name == "quartz_block_lines":
+ self.block_image["quartz_block_lines_rotated"] = possible_texture.rotate(-90)
+ if block_name.startswith("bed_"):
+ if block_name == "bed_head_side":
+ self.block_image["bed_head_side_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ if block_name == "bed_feet_side":
+ self.block_image["bed_feet_side_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ if block_name == "bed_head_top":
+ self.block_image["bed_head_top_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ self.block_image["bed_head_top_bottom"] = possible_texture.rotate(-90)
+ self.block_image["bed_head_top_top"] = possible_texture.rotate(90)
+ if block_name == "bed_feet_top":
+ self.block_image["bed_feet_top_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ self.block_image["bed_feet_top_bottom"] = possible_texture.rotate(-90)
+ self.block_image["bed_feet_top_top"] = possible_texture.rotate(90)
+ else:
+ if possible_texture.size == (32, 32):
+ self.block_image[block_name] = possible_texture.resize((16, 16))
+ elif possible_texture.size == (64, 64) or possible_texture.size == (128, 128) or possible_texture.size == (256, 256):
+ self.big_textures_counted += 1
+ else:
+ self.block_image[block_name] = possible_texture.crop((0,0,16,16))
+ if self.big_textures_counted >= self.big_textures_max:
+ self.handle_too_big_packs()
+ else:
+ self.parse_terrain_png()
+[INST] Scraps the block textures folder and puts texture data into a dictionary with exactly identical structure as ZipResourcePack. [/INST] def add_textures(self):
+ base_path = os.path.join(self._full_path, "assets", "minecraft", "textures", "blocks")
+ if os.path.exists(base_path):
+ files = os.listdir(base_path)
+ for tex_file in files:
+ if tex_file.endswith(".png") and not tex_file.startswith("._") and tex_file.replace(".png","") in textureSlots:
+ possible_texture = Image.open(os.path.join(base_path, tex_file))
+ block_name = tex_file[:-4]
+ if possible_texture.size == (16, 16):
+ self.block_image[block_name] = possible_texture
+ if block_name.startswith("repeater_") or block_name.startswith("comparator_"):
+ self.block_image[block_name+"_west"] = possible_texture.rotate(-90)
+ self.block_image[block_name+"_north"] = possible_texture.rotate(180)
+ self.block_image[block_name+"_east"] = possible_texture.rotate(90)
+ self.block_image[block_name+"_south"] = possible_texture
+ if block_name == "piston_side":
+ self.block_image["piston_up"] = possible_texture
+ self.block_image["piston_left"] = possible_texture.rotate(90)
+ self.block_image["piston_down"] = possible_texture.rotate(180)
+ self.block_image["piston_right"] = possible_texture.rotate(-90)
+ if block_name == "hay_block_side":
+ self.block_image["hay_block_side_rotated"] = possible_texture.rotate(-90)
+ if block_name == "quartz_block_lines":
+ self.block_image["quartz_block_lines_rotated"] = possible_texture.rotate(-90)
+ if block_name.startswith("bed_"):
+ if block_name == "bed_head_side":
+ self.block_image["bed_head_side_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ if block_name == "bed_feet_side":
+ self.block_image["bed_feet_side_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ if block_name == "bed_head_top":
+ self.block_image["bed_head_top_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ self.block_image["bed_head_top_bottom"] = possible_texture.rotate(-90)
+ self.block_image["bed_head_top_top"] = possible_texture.rotate(90)
+ if block_name == "bed_feet_top":
+ self.block_image["bed_feet_top_flipped"] = possible_texture.transpose(Image.FLIP_LEFT_RIGHT)
+ self.block_image["bed_feet_top_bottom"] = possible_texture.rotate(-90)
+ self.block_image["bed_feet_top_top"] = possible_texture.rotate(90)
+ else:
+ if possible_texture.size == (32, 32):
+ self.block_image[block_name] = possible_texture.resize((16, 16))
+ if possible_texture.size == (64, 64) or possible_texture.size == (128, 128) or possible_texture.size == (256, 256):
+ self.big_textures_counted += 1
+ else:
+ self.block_image[block_name] = possible_texture.crop((0,0,16,16))
+ if self.big_textures_counted >= self.big_textures_max:
+ self.handle_too_big_packs()
+ else:
+ self.parse_terrain_png()
+[INST] Show the progress for a long-running synchronous operation.
+progressIterator should be a generator-like object that can return
+either None, for an indeterminate indicator,
+A float value between 0.0 and 1.0 for a determinate indicator,
+A string, to update the progress info label
+or a tuple of (float value, string) to set the progress and update the label. [/INST] def showProgress(progressText, progressIterator, cancel=False):
+ class ProgressWidget(Dialog):
+ progressFraction = 0.0
+ firstDraw = False
+ root = None
+ def draw(self, surface):
+ if self.root is None:
+ self.root = self.get_root()
+ Widget.draw(self, surface)
+ frameStart = datetime.now()
+ frameInterval = timedelta(0, 1, 0) / 2
+ amount = None
+ try:
+ while datetime.now() < frameStart + frameInterval:
+ amount = progressIterator.next()
+ if self.firstDraw is False:
+ self.firstDraw = True
+ break
+ except StopIteration:
+ self.dismiss()
+ infoText = ""
+ if amount is not None:
+ if isinstance(amount, tuple):
+ if len(amount) > 2:
+ infoText = ": " + amount[2]
+ amount, max = amount[:2]
+ else:
+ max = amount
+ maxwidth = (self.width - self.margin * 2)
+ if amount is None:
+ self.progressBar.width = maxwidth
+ self.progressBar.bg_color = (255, 255, 25, 255)
+ elif isinstance(amount, basestring):
+ self.statusText = amount
+ else:
+ self.progressAmount = amount
+ if isinstance(amount, (int, float)):
+ self.progressFraction = float(amount) / (float(max) or 1)
+ self.progressBar.width = maxwidth * self.progressFraction
+ self.statusText = str("{0} / {1}".format(amount, max))
+ else:
+ self.statusText = str(amount)
+ if infoText:
+ self.statusText += infoText
+ @property
+ def estimateText(self):
+ delta = (datetime.now() - self.startTime)
+ progressPercent = (int(self.progressFraction * 10000))
+ left = delta * (10000 - progressPercent) / (progressPercent or 1)
+ return _("Time left: {0}").format(left)
+ def cancel(self):
+ if cancel:
+ self.dismiss(False)
+ def idleevent(self, evt):
+ self.invalidate()
+ def key_down(self, event):
+ pass
+ def key_up(self, event):
+ pass
+ def mouse_up(self, event):
+ try:
+ if "SelectionTool" in str(self.root.editor.currentTool):
+ if self.root.get_nudge_block().count > 0:
+ self.root.get_nudge_block().mouse_up(event)
+ except:
+ pass
+ widget = ProgressWidget()
+ widget.progressText = _(progressText)
+ widget.statusText = ""
+ widget.progressAmount = 0.0
+ progressLabel = ValueDisplay(ref=AttrRef(widget, 'progressText'), width=550)
+ statusLabel = ValueDisplay(ref=AttrRef(widget, 'statusText'), width=550)
+ estimateLabel = ValueDisplay(ref=AttrRef(widget, 'estimateText'), width=550)
+ progressBar = Widget(size=(550, 20), bg_color=(150, 150, 150, 255))
+ widget.progressBar = progressBar
+ col = (progressLabel, statusLabel, estimateLabel, progressBar)
+ if cancel:
+ cancelButton = Button("Cancel", action=widget.cancel, fg_color=(255, 0, 0, 255))
+ col += (Column((cancelButton,), align="r"),)
+ widget.add(Column(col))
+ widget.shrink_wrap()
+ widget.startTime = datetime.now()
+ if widget.present():
+ return widget.progressAmount
+ else:
+ return "Canceled"
+[INST] Convenient function that fixes any bugs/typos (in the usercache.json file) that Podshot may have created. [/INST] def fixAllOfPodshotsBugs(self):
+ for player in self._playerCacheList:
+ if "Timstamp" in player:
+ player["Timestamp"] = player["Timstamp"]
+ del player["Timstamp"]
+ self._save()
+[INST] Loads from the usercache.json file if it exists, if not an empty one will be generated. [/INST] def load(self):
+ if not os.path.exists(userCachePath):
+ out = open(userCachePath, 'w')
+ json.dump(self._playerCacheList, out)
+ out.close()
+ f = open(userCachePath, 'r')
+ line = f.readline()
+ if line.startswith("{"):
+ f.close()
+ self.__convert()
+ f.close()
+ try:
+ json_in = open(userCachePath)
+ self._playerCacheList = json.load(json_in)
+ except:
+ logger.warning("Usercache.json may be corrupted")
+ self._playerCacheList = []
+ finally:
+ json_in.close()
+ self.fixAllOfPodshotsBugs()
+ self.refresh_lock = threading.Lock()
+ self.player_refreshing = threading.Thread(target=self._refreshAll)
+ self.player_refreshing.daemon = True
+ self.player_refreshing.start()
+[INST] Checks to see if the name is already in the cache. [/INST] def nameInCache(self, name):
+ isInCache = False
+ for p in self._playerCacheList:
+ if p["Playername"] == name:
+ isInCache = True
+ return isInCache
+[INST] Checks to see if the UUID is already in the cache. [/INST] def uuidInCache(self, uuid, seperator=True):
+ isInCache = False
+ for p in self._playerCacheList:
+ if seperator:
+ if p["UUID (Separator)"] == uuid:
+ isInCache = True
+ else:
+ if p["UUID (No Separator)"] == uuid:
+ isInCache = True
+ return isInCache
+[INST] Refreshes all players in the cache, regardless of how long ago the name was synced. [/INST] def force_refresh(self):
+ players = self._playerCacheList
+ for player in players:
+ self.getPlayerInfo(player["UUID (Separator)"], force=True)
+[INST] Removes all failed UUID/Player name lookups from the cache. [/INST] def cleanup(self):
+ remove = []
+ for player in self._playerCacheList:
+ if not player["WasSuccessful"]:
+ remove.append(player)
+ for toRemove in remove:
+ self._playerCacheList.remove(toRemove)
+ self._save()
+[INST] Rediculously complicated function to compare current System version to inputted version.. [/INST] def OSXVersionChecker(name,compare):
+ if compare != 'gt' and compare != 'lt' and compare != 'eq' and compare != 'gteq' and compare != 'lteq':
+ print "Invalid version check {}".format(compare)
+ return False
+ if sys.platform == 'darwin':
+ try:
+ systemVersion = platform.mac_ver()[0].split('.')
+ if len(systemVersion) == 2:
+ systemVersion.append('0')
+ major, minor, patch = 10, 0, 0
+ if name.lower() == 'cheetah':
+ minor = 0
+ patch = 4
+ elif name.lower() == 'puma':
+ minor = 1
+ patch = 5
+ elif name.lower() == 'jaguar':
+ minor = 2
+ patch = 8
+ elif name.lower() == 'panther':
+ minor = 3
+ patch = 9
+ elif name.lower() == 'tiger':
+ minor = 4
+ patch = 11
+ elif name.lower() == 'snow_leopard':
+ minor = 5
+ patch = 8
+ elif name.lower() == 'snow_leopard':
+ minor = 6
+ patch = 8
+ elif name.lower() == 'lion':
+ minor = 7
+ patch = 5
+ elif name.lower() == 'mountain_lion':
+ minor = 8
+ patch = 5
+ elif name.lower() == 'mavericks':
+ minor = 9
+ patch = 5
+ elif name.lower() == 'yosemite':
+ minor = 10
+ patch = 0
+ else:
+ major = 0
+ if int(systemVersion[0]) > int(major):
+ ret_val = 1
+ elif int(systemVersion[0]) < int(major):
+ ret_val = -1
+ else:
+ if int(systemVersion[1]) > int(minor):
+ ret_val = 1
+ elif int(systemVersion[1]) < int(minor):
+ ret_val = -1
+ else:
+ if int(systemVersion[2]) > int(patch):
+ ret_val = 1
+ elif int(systemVersion[2]) < int(patch):
+ ret_val = -1
+ else:
+ ret_val = 0
+ if ret_val == 0 and (compare == 'eq' or compare == 'gteq' or compare == 'lteq'):
+ return True
+ elif ret_val == -1 and (compare == 'lt' or compare == 'lteq'):
+ return True
+ elif ret_val == 1 and (compare == 'gt' or compare == 'gteq'):
+ return True
+ except:
+ print "An error occured determining the system version"
+ return False
+ else:
+ return False
+[INST] Convert a conjunction of expressions into a string. [/INST] def convert_conjunction(conj):
+ expr = getOptField(conj, "Expression")
+ if expr != None:
+ return convert_expression(expr)
+ children = getArray(conj, "ConjunctionsOrDisjunctions")
+ operator = getOptField(conj, "OR")
+ assert operator == None
+ rec = map(convert_conjunction, children)
+ return ", ".join(rec)
+[INST] Converts a tail into a disjunction of conjunctions.
+Returns a list with all disjunctions. [/INST] def normalize_tail(tail):
+ return [getField(tail, "ConjunctionsOrDisjunctions")]
+[INST] True if a conjunction contains any relations. [/INST] def has_relations(conj):
+ expr = getOptField(conj, "Expression")
+ if expr != None:
+ return expression_has_relations(expr)
+ children = getArray(conj, "ConjunctionsOrDisjunctions")
+ rec = map(has_relations, children)
+ return reduce(lambda a,b: a or b, rec, False)
+[INST] Convert a rule and emit the output. [/INST] def process_rule(rule, files, preprocess):
+ head = getField(rule, "Head")
+ tail = getField(rule, "Tail")
+ headClauses = getListField(head, "Clause", "ClauseList")
+ tails = normalize_tail(tail)
+ if not has_relations(tail) and preprocess:
+ for clause in headClauses:
+ name = getField(clause, "Identifier")
+ relations[relation_name(name.value)] = True
+ else:
+ heads = map(convert_head_clause, headClauses)
+ convertedTails = map(convert_conjunction, tails)
+ for ct in convertedTails:
+ files.output(",\n\t".join(heads) + " :- " + ct + ".")
+[INST] Process a relation declaration and emit output to files. [/INST] def process_relation_decl(relationdecl, files, preprocess):
+ id = getField(relationdecl, "Identifier")
+ params = getListField(relationdecl, "Parameter", "ParameterList")
+ if preprocess:
+ relname = register_relation(id.value)
+ return
+ relname = relation_name(id.value)
+ paramdecls = map(convert_decl_param, params)
+ role = "input " if relations[relname] else ""
+ print(relationdecl.tree_str())
+ if getOptField(getField(relationdecl, "OUTPUT_DEPRECATED_opt"), "OUTPUT_DEPRECATED") != None or outrelations[relname]:
+ role = "output "
+ files.output(role + "relation " + relname + "(" + ", ".join(paramdecls) + ")")
+[INST] Return true if v is dimensionless.. [/INST] def is_dimensionless(self, v):
+ return (
+ not isinstance(v, u.Quantity) or v.unit == u.dimensionless_unscaled
+ )
+[INST] Assign a unit to a dimensionless object.
+If the object already has a dimension it returns it without change
+
+Examples. [/INST] def convert_if_dimensionless(self, value):
+ if self.is_dimensionless(value) and value is not None:
+ return value * self.unit
+ return value
+[INST] Executes the correct action according the user input.
+Parameters:
+args - arguments collected by argparser. [/INST] def run(args):
+ setup = SetupSimulator()
+ if args.install:
+ setup.clear()
+ start_time = time.time()
+ if not setup.is_connected_internet():
+ print "Ensure you have internet connection"
+ sys.exit(1)
+ setup.pretty_print("Installing")
+ setup.verify_dependencies()
+ create_directory(var.DOWNLOAD_DIR, setup)
+ download_license(var.LICENSES, var.DOWNLOAD_DIR, setup)
+ if setup.read_license():
+ download_common_pckg(var.COMMON_FILES, var.DOWNLOAD_DIR, setup)
+ download_by_distro(setup.get_distro(), var.DIR_PATH,
+ var.DOWNLOAD_DIR, setup)
+ setup.check_sum(var.DOWNLOAD_DIR)
+ install_packages(var.VERSIONS, var.DOWNLOAD_DIR, setup)
+ extract_img(var.DISK, var.DOWNLOAD_DIR, setup)
+ create_symlink(var.DISK_SYM_LINK, var.DISK, var.DOWNLOAD_DIR,
+ setup)
+ if setup.cmd_exists('mambo'):
+ setup.execute_cmd('mambo --help')
+ print "Execution time: %s seconds." % (time.time() - start_time)
+ elif args.start:
+ try:
+ start_simulator(args.start, setup)
+ except (KeyboardInterrupt, SystemExit, RuntimeError):
+ exit(1)
+[INST] Create a directory where all packages will be stored.. [/INST] def create_directory(target_directory, setup_simulator):
+ if setup_simulator.directory_exists(target_directory):
+ setup_simulator.remove_directory(target_directory)
+ else:
+ setup_simulator.create_directory(target_directory)
+[INST] Download the necessary packages. They are stored into the files
+license and simulator. The first line contains the base URI and
+the rest of the file contains the packages.. [/INST] def download_common_pckg(common_files, download_directory, setup_simulator):
+ setup_simulator.print_line()
+ for download in common_files:
+ with open(download) as fdownload:
+ ftpurl = fdownload.readline().strip('\n')
+ packages = fdownload.readlines()
+ size = len(packages)
+ for pkg in range(size):
+ fname = download_directory + "/" + packages[pkg].strip('\n')
+ if not setup_simulator.file_exists(fname):
+ print " * Downloading " + packages[pkg].strip('\n')
+ setup_simulator.download(ftpurl + packages[pkg].strip('\n'),
+ packages[pkg].strip('\n'),
+ download_directory)
+[INST] Download the necessary packages by distro.. [/INST] def download_by_distro(distro, dir_path, download_directory, setup_simulator):
+ if var.UBUNTU in distro:
+ dfile = dir_path + "/resources/distros/ubuntu.config"
+ elif var.FEDORA in distro:
+ dfile = dir_path + "/resources/distros/fedora.config"
+ else:
+ dfile = dir_path + "/resources/distros/suserhelcentos.config"
+ with open(dfile) as fdownload:
+ packages = fdownload.readlines()
+ size = len(packages)
+ for pkg in range(size):
+ sim_package = packages[pkg].split("/")[-1]
+ fname = download_directory + "/" + sim_package.strip('\n')
+ if not setup_simulator.file_exists(fname):
+ print " * Downloading " + sim_package.strip('\n')
+ setup_simulator.download(packages[pkg],
+ sim_package.strip('\n'),
+ download_directory)
+[INST] Install the simulator packages for p8 and p9 according
+the host distro. [/INST] def install_packages(simulator_versions, download_directory, setup_simulator):
+ for simulator_version in simulator_versions:
+ if not setup_simulator.directory_exists(simulator_version):
+ setup_simulator.print_line()
+ print "Installing the simulator packages..."
+ if var.UBUNTU in setup_simulator.get_distro():
+ for deb_pkg in os.listdir(download_directory):
+ if deb_pkg.endswith(".deb"):
+ pkg_name = download_directory + "/" + deb_pkg
+ setup_simulator.install_deb(pkg_name)
+ else:
+ for rpm_pkg in os.listdir(download_directory):
+ if rpm_pkg.endswith(".rpm"):
+ pkg_name = download_directory + "/" + rpm_pkg
+ setup_simulator.install_rpm(pkg_name)
+[INST] Extract the bzip2 file which contains the Debian sysroot. [/INST] def extract_img(disk_img, download_directory, setup_simulator):
+ full_img_path = download_directory + disk_img
+ try:
+ if not setup_simulator.file_exists(full_img_path):
+ if setup_simulator.file_exists(full_img_path + ".bz2"):
+ setup_simulator.print_line()
+ print "Extracting the image (it will take a while)..."
+ cmd = "bzip2 -dkv " + full_img_path + ".bz2"
+ setup_simulator.execute_cmd(cmd)
+ print
+ except (KeyboardInterrupt, SystemExit, RuntimeError):
+ raise
+[INST] Customize the disk img by copying the script which installs the SDK
+and its dependencies inside it. [/INST] def customize_img(disk_img, lock, mount, download_directory, setup_simulator):
+ full_img_path = download_directory + disk_img
+ if not setup_simulator.file_exists(lock):
+ setup_simulator.print_line()
+ print "Customizing the image..."
+ setup_simulator.configure_image(full_img_path, lock, mount)
+ else:
+ print "Image alredy customized"
+[INST] Configure a symlink to be used by the tcl script. [/INST] def create_symlink(sym_link, disk, download_directory, setup_simulator):
+ if not setup_simulator.file_exists(download_directory + sym_link):
+ cmd = download_directory + disk + " " + download_directory + sym_link
+ setup_simulator.execute_cmd("ln -s " + cmd)
+[INST] starts the simulator according the version selected by the user. [/INST] def start_simulator(version, setup_simulator):
+ if setup_simulator.show_connection_info(version):
+ os.chdir(var.DOWNLOAD_DIR)
+ set_network(setup_simulator)
+ if 'power8' in version:
+ p8_prefix = '/opt/ibm/systemsim-p8/run/pegasus/'
+ p8_sim = p8_prefix + 'power8 -W -f'
+ p8_tcl = p8_prefix + 'linux/boot-linux-le.tcl'
+ setup_simulator.execute_cmd(p8_sim + ' ' + p8_tcl)
+ elif 'power9' in version:
+ p9_prefix = '/opt/ibm/systemsim-p9/run/p9/'
+ p9_sim = p9_prefix + 'power9 -W -f'
+ p9_tcl = p9_prefix + 'linux/boot-linux-le-skiboot.tcl'
+ setup_simulator.execute_cmd(p9_sim + ' ' + p9_tcl)
+ unset_network(setup_simulator)
+[INST] install DEB file via apt-get. [/INST] def install_deb_apt(self, package):
+ try:
+ self.execute_cmd('sudo apt-get -y install ' + package)
+ except (KeyboardInterrupt, SystemExit, RuntimeError):
+ raise
+[INST] configure the image, copying the configurerepos.sh into it.. [/INST] def configure_image(self, disk_img, lock, mount_point):
+ try:
+ self.execute_cmd('sudo mkdir ' + mount_point)
+ self.execute_cmd('sudo mount -o loop ' + disk_img + ' ' + mount_point)
+ mtp = mount_point + "/home"
+ cmd = var.DOWNLOAD_DIR + 'configurerepos.sh' + ' ' + mtp
+ self.execute_cmd('sudo cp -rp ' + cmd)
+ self.execute_cmd('sudo umount ' + mount_point)
+ self.execute_cmd('sudo rm -rf ' + mount_point)
+ self.execute_cmd('touch ' + lock)
+ print "done"
+ except (KeyboardInterrupt, SystemExit, RuntimeError, OSError, IOError):
+ print "\n ERROR: could not configure the " + disk_img
+ print " exiting now!"
+ sys.exit(1)
+[INST] verify if the required dependencies are installed. [/INST] def verify_dependencies(self):
+ self.print_line()
+ print " * Checking dependencies..."
+ try:
+ for dep in var.DEPENDENCIES:
+ if not self.cmd_exists(dep):
+ self.install_dependencies(dep)
+ except (KeyboardInterrupt, SystemExit, RuntimeError):
+ sys.exit(1)
+[INST] return the size of file formated. [/INST] def size_of(value):
+ for unit in ['', 'Ki', 'Mi']:
+ if abs(value) < 1024.0:
+ return "%3.1f %s%s" % (value, unit, 'B')
+ value = value / 1024.0
+ return "%.1f%s%s" % (value, 'Yi', 'B')
+[INST] extract and convert the license from dos to unix. [/INST] def configure_license():
+ try:
+ licensezip = zipfile.ZipFile(var.LICENSE_FILE_ZIP, 'r')
+ licensezip.extractall(var.DOWNLOAD_DIR)
+ licensezip.close()
+ licensetext = open(var.LICENSE, 'rb').read().replace('\r\n', '\n')
+ open(var.LICENSE, 'wb').write(licensetext)
+ except (OSError, IOError):
+ print "\n ERROR: could not configure the license."
+ sys.exit(1)
+[INST] Show to the user how to connect to the simulator. [/INST] def show_connection_info(self, version):
+ try:
+ self.print_line()
+ sversion = 'IBM POWER' + version[-1:] + ' Functional Simulator'
+ print '\nYou are starting the ' + sversion
+ print 'When the boot process is complete, use the following'
+ print 'credentials to access it via ssh:\n'
+ print ' ssh root@172.19.98.109'
+ print ' password: mambo'
+ opt = raw_input("\n Would you like to start it now? [Y/n] ")
+ if "N" in opt or "n" in opt:
+ sys.exit(0)
+ else:
+ return True
+ except (KeyboardInterrupt, SystemExit, RuntimeError):
+ sys.exit(1)
+[INST] Loads model weights from provided path.. [/INST] def load_model(model_file_path: str) -> nn.Module:
+ embed_dim = 64
+ scene_dim = 1024
+ encoder_type = "kwmlp"
+ model = AudioMLP_Wrapper(
+ sample_rate=16000,
+ timestamp_embedding_size=embed_dim,
+ scene_embedding_size=scene_dim,
+ encoder_type=encoder_type,
+ encoder_ckpt=model_file_path
+ )
+ assert isinstance(model, nn.Module)
+ assert hasattr(model, "sample_rate")
+ assert hasattr(model, "timestamp_embedding_size")
+ assert hasattr(model, "scene_embedding_size")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+ model.eval()
+ return model
+[INST] Do some initial padding in order to get embeddings at the start/end of audio.. [/INST] def initial_padding(audio: Tensor, sr=16000, hop_ms=10, window_ms=30) -> Tensor:
+ init_pad = int((window_ms // 2 - hop_ms) / 1000 * sr) if window_ms // 2 > hop_ms else 0
+ end_pad = int((window_ms // 2 ) / 1000 * sr)
+ return F.pad(audio, (init_pad, end_pad), "constant", 0)
+[INST] r"""
+Returns the name of a mission or list of missions given a specified time.
+
+Parameters
+
+time : datetime.datetime or list
+Datetime object or list of datetime objects representing the time of the requested mission.
+
+Returns
+
+list
+The names of any/all missions that had in-storm observations during the specified time.. [/INST] def findMission(self,time):
+ if isinstance(time,list):
+ t1=min(time)
+ t2=max(time)
+ else:
+ t1 = t2 = time
+ selected=[]
+ for name in self.missiondata:
+ t_start = min(self.missiondata[name]['time'])
+ t_end = max(self.missiondata[name]['time'])
+ if (t_start
+[INST] r"""
+Creates a hovmoller plot of azimuthally-averaged recon data.
+
+Parameters
+
+recon_select : Requested recon data
+pandas.DataFrame or dict,
+or datetime or list of start/end datetimes.
+varname : Variable to average and plot .
+String
+ax : axes
+Instance of axes to plot on. If none, one will be generated. Default is none.
+return_ax : bool
+If True, returns the axes instance on which the plot was generated for the user to further modify. Default is False.
+cartopy_proj : ccrs
+Instance of a cartopy projection to use. If none, one will be generated. Default is none.
+
+Other Parameters
+
+prop : dict
+Customization properties for recon plot. Please refer to :ref:`options-prop-recon-hovmoller` for available options.. [/INST] def plot_hovmoller(self,recon_select=None,varname='wspd',radlim=None,track_dict=None,plane_p_range=None,\
+ window=6,align='center',ax=None,return_ax=False,**kwargs):
+ prop = kwargs.pop('prop',{})
+ default_prop = {'cmap':'category','levels':None,'smooth_contourf':False}
+ for key in default_prop.keys():
+ if key not in prop.keys():
+ prop[key]=default_prop[key]
+ if recon_select is None:
+ dfRecon = self.recentered
+ elif isinstance(recon_select,pd.core.frame.DataFrame):
+ dfRecon = recon_select
+ elif isinstance(recon_select,dict):
+ dfRecon = pd.DataFrame.from_dict(recon_select)
+ else:
+ dfRecon = self.__getSubTime(recon_select)
+ if plane_p_range is not None:
+ dfRecon = dfRecon.loc[(dfRecon['plane_p']>min(plane_p_range)) & (dfRecon['plane_p'] 2 or prop['smooth_contourf'] == False:
+ cf = ax.contourf(radius,time,gfilt1d(vardata,sigma=3,axis=1),
+ levels=clevs,cmap=cmap)
+ else:
+ cf = ax.contourf(radius,time,gfilt1d(vardata,sigma=3,axis=1),
+ cmap=cmap,levels=np.linspace(min(prop['levels']),max(prop['levels']),256))
+ ax.axis([0,max(radius),min(time),max(time)])
+ cbar = plt.colorbar(cf,orientation='horizontal',pad=0.1)
+ ax.yaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H'))
+ for tick in ax.xaxis.get_major_ticks():
+ tick.label.set_fontsize(14)
+ for tick in ax.yaxis.get_major_ticks():
+ tick.label.set_fontsize(14)
+ ax.set_ylabel('UTC Time (MM-DD HH)',fontsize=15)
+ ax.set_xlabel('Radius (km)',fontsize=15)
+ --------
+ title_left, title_right = hovmoller_plot_title(self.storm_obj,Hov_dict,varname)
+ ax.set_title(title_left,loc='left',fontsize=16,fontweight='bold')
+ ax.set_title(title_right,loc='right',fontsize=12)
+ if return_ax:
+ return ax
+[INST] r"""
+Interpolate a storm temporally to a specified time resolution.
+
+Parameters
+
+timeres : int
+Temporal resolution in hours to interpolate storm data to. Default is 1 hour.
+dt_window : int
+Time window in hours over which to calculate temporal change data. Default is 24 hours.
+
+Returns
+
+tropycal.tracks.Storm
+New Storm object containing the updated dictionary.. [/INST] def interp(self,timeres=1,dt_window=24,dt_align='middle'):
+ NEW_STORM = copy.copy(self)
+ newdict = interp_storm(self.dict,timeres,dt_window,dt_align)
+ for key in newdict.keys():
+ NEW_STORM.dict[key] = newdict[key]
+ for key in NEW_STORM.dict.keys():
+ if key == 'realtime': continue
+ if isinstance(NEW_STORM.dict[key], (np.ndarray,list)) == False and isinstance(NEW_STORM.dict[key], dict) == False:
+ NEW_STORM[key] = NEW_STORM.dict[key]
+ NEW_STORM.coords[key] = NEW_STORM.dict[key]
+ if isinstance(NEW_STORM.dict[key], (np.ndarray,list)) == True and isinstance(NEW_STORM.dict[key], dict) == False:
+ NEW_STORM.dict[key] = list(NEW_STORM.dict[key])
+ NEW_STORM.vars[key] = np.array(NEW_STORM.dict[key])
+ NEW_STORM[key] = np.array(NEW_STORM.dict[key])
+ return NEW_STORM
+[INST] r"""
+Converts the storm dict into an xarray Dataset object.
+
+Returns
+
+xarray.Dataset
+An xarray Dataset object containing information about the storm.. [/INST] def to_xarray(self):
+ try:
+ import xarray as xr
+ except ImportError as e:
+ raise RuntimeError("Error: xarray is not available. Install xarray in order to use this function.") from e
+ time = self.dict['date']
+ ds = {}
+ attrs = {}
+ keys = [k for k in self.dict.keys() if k != 'date']
+ for key in keys:
+ if isinstance(self.dict[key], list):
+ ds[key] = xr.DataArray(self.dict[key],coords=[time],dims=['time'])
+ else:
+ attrs[key] = self.dict[key]
+ ds = xr.Dataset(ds,attrs=attrs)
+ return ds
+[INST] r"""
+Converts the storm dict into a pandas DataFrame object.
+
+Parameters
+
+attrs_as_columns : bool
+If True, adds Storm object attributes as columns in the DataFrame returned. Default is False.
+
+Returns
+
+pandas.DataFrame
+A pandas DataFrame object containing information about the storm.. [/INST] def to_dataframe(self, attrs_as_columns=False):
+ try:
+ import pandas as pd
+ except ImportError as e:
+ raise RuntimeError("Error: pandas is not available. Install pandas in order to use this function.") from e
+ time = self.dict['date']
+ ds = {}
+ keys = [k for k in self.dict.keys()]
+ for key in keys:
+ if isinstance(self.dict[key], list):
+ ds[key] = self.dict[key]
+ else:
+ if attrs_as_columns:
+ ds[key] = self.dict[key]
+ ds = pd.DataFrame(ds)
+ return ds
+[INST] r"""
+Searches for the given word or phrase through all NHC forecast discussions for this storm.
+
+Parameters
+
+query : str or list
+String or list representing a word(s) or phrase(s) to search for within the NHC forecast discussions . Query is case insensitive.
+
+Returns
+
+list
+List of dictionaries containing all relevant forecast discussions.. [/INST] def query_nhc_discussions(self,query):
+ if self.source != "hurdat":
+ msg = "Error: NHC data can only be accessed when HURDAT is used as the data source."
+ raise RuntimeError(msg)
+ if self.invest:
+ raise RuntimeError("Error: NHC does not issue advisories for invests that have not been designated as Potential Tropical Cyclones.")
+ storm_id = self.dict['operational_id']
+ storm_year = self.dict['year']
+ if storm_id == '':
+ msg = "No NHC operational data is available for this storm."
+ raise RuntimeError(msg)
+ if isinstance(query,str) == False and isinstance(query,list) == False:
+ msg = "'query' must be of type str or list."
+ raise TypeError(msg)
+ if isinstance(query,list):
+ for i in query:
+ if isinstance(i,str) == False:
+ msg = "Entries of list 'query' must be of type str."
+ raise TypeError(msg)
+ disco_dict = self.list_nhc_discussions()
+ output = []
+ for idx,forecast_date in enumerate(disco_dict['utc_date']):
+ forecast = self.get_nhc_discussion(forecast=forecast_date)
+ text = forecast['text'].lower()
+ if isinstance(query,str):
+ if text.find(query.lower()) >= 0: output.append(forecast)
+ else:
+ found = False
+ for i_query in query:
+ if text.find(i_query.lower()) >= 0: found = True
+ if found: output.append(forecast)
+ return output
+[INST] r"""
+Downloads the NHC offical Tropical Cyclone Report (TCR) for the requested storm to the requested directory. Available only for storms with advisories issued by the National Hurricane Center.
+
+Parameters
+
+save_path : str
+Path of directory to download the TCR into. Default is current working directory.. [/INST] def download_tcr(self,save_path=""):
+ if self.invest:
+ raise RuntimeError("Error: NHC does not issue advisories for invests that have not been designated as Potential Tropical Cyclones.")
+ if self.source != "hurdat":
+ msg = "NHC data can only be accessed when HURDAT is used as the data source."
+ raise RuntimeError(msg)
+ if self.year < 1995:
+ msg = "Tropical Cyclone Reports are unavailable prior to 1995."
+ raise RuntimeError(msg)
+ if isinstance(save_path,str) == False:
+ msg = "'save_path' must be of type str."
+ raise TypeError(msg)
+ storm_id = self.dict['id'].upper()
+ storm_name = self.dict['name'].title()
+ url = f"https://www.nhc.noaa.gov/data/tcr/{storm_id}_{storm_name}.pdf"
+ request = requests.get(url)
+ if request.status_code != 200:
+ msg = "This tropical cyclone does not have a Tropical Cyclone Report (TCR) available."
+ raise RuntimeError(msg)
+ response = requests.get(url)
+ full_path = os.path.join(save_path,f"TCR_{storm_id}_{storm_name}.pdf")
+ with open(full_path, 'wb') as f:
+ f.write(response.content)
+[INST] r"""
+Creates a plot of the storm and associated tornado tracks.
+
+Parameters
+
+dist_thresh : int
+Distance threshold (in kilometers) from the tropical cyclone track over which to attribute tornadoes to the TC. Default is 1000 km.
+Tors : pandas.DataFrame
+DataFrame containing tornado data associated with the storm. If None, data is automatically retrieved from TornadoDatabase. A dataframe of tornadoes associated with the TC will then be saved to this instance of storm
+for future use.
+domain : str
+Domain for the plot. Default is "dynamic". Please refer to :ref:`options-domain` for available domain options.
+plotPPH : bool or str
+Whether to plot practically perfect forecast (PPH). True defaults to "daily". Default is False.
+
+False** - no PPH plot.
+True** - defaults to "daily".
+"total"** - probability of a tornado within 25mi of a point during the period of time selected.
+"daily"** - average probability of a tornado within 25mi of a point during a day starting at 12 UTC.
+plot_all : bool
+Whether to plot dots for all observations along the track. If false, dots will be plotted every 6 hours. Default is false.
+ax : axes
+Instance of axes to plot on. If none, one will be generated. Default is none.
+cartopy_proj : ccrs
+Instance of a cartopy projection to use. If none, one will be generated. Default is none.
+save_path : str
+Relative or full path of directory to save the image in. If none, image will not be saved.
+
+Other Parameters
+
+prop : dict
+Customization properties of plot.
+map_prop : dict
+Customization properties of Cartopy map. Please refer to :ref:`options-map-prop` for available options.
+
+Returns
+
+ax
+Instance of axes containing the plot is returned.. [/INST] def plot_tors(self,dist_thresh=1000,Tors=None,domain="dynamic",plotPPH=False,plot_all=False,\
+ ax=None,cartopy_proj=None,save_path=None,prop={},map_prop={}):
+ try:
+ prop['PPHcolors']
+ except:
+ prop['PPHcolors']='Wistia'
+ if Tors is None:
+ try:
+ self.stormTors
+ except:
+ warn_message = "Reading in tornado data for this storm. If you seek to analyze tornado data for multiple storms, run \"TrackDataset.assign_storm_tornadoes()\" to avoid this warning in the future."
+ warnings.warn(warn_message)
+ Tors = TornadoDataset()
+ self.stormTors = Tors.get_storm_tornadoes(self,dist_thresh)
+ if len(self.stormTors) == 0:
+ raise RuntimeError("No tornadoes were found with this storm.")
+ if len(self.stormTors) < 5:
+ warn_message = f"{len(self.stormTors)} tornadoes were found with this storm. Default domain to east_conus."
+ warnings.warn(warn_message)
+ domain = 'east_conus'
+ self.plot_obj_tc = TrackPlot()
+ try:
+ self.plot_obj_tor = TornadoPlot()
+ except:
+ from ..tornado.plot import TornadoPlot
+ self.plot_obj_tor = TornadoPlot()
+ if cartopy_proj is None:
+ if max(self.dict['lon']) > 150 or min(self.dict['lon']) < -150:
+ self.plot_obj_tor.create_cartopy(proj='PlateCarree',central_longitude=180.0)
+ self.plot_obj_tc.create_cartopy(proj='PlateCarree',central_longitude=180.0)
+ else:
+ self.plot_obj_tor.create_cartopy(proj='PlateCarree',central_longitude=0.0)
+ self.plot_obj_tc.create_cartopy(proj='PlateCarree',central_longitude=0.0)
+ plot_ax,leg_tor,domain = self.plot_obj_tor.plot_tornadoes(self.stormTors,domain,ax=ax,return_ax=True,return_domain=True,\
+ plotPPH=plotPPH,prop=prop,map_prop=map_prop)
+ tor_title = plot_ax.get_title('left')
+ plot_ax = self.plot_obj_tc.plot_storm(self.dict,domain=domain,ax=plot_ax,prop=prop,map_prop=map_prop)
+ plot_ax.add_artist(leg_tor)
+ storm_title = plot_ax.get_title('left')
+ plot_ax.set_title(f'{storm_title}\n{tor_title}',loc='left',fontsize=17,fontweight='bold')
+ if save_path is not None and isinstance(save_path,str):
+ plt.savefig(save_path,bbox_inches='tight')
+ return plot_ax
+[INST] r"""
+Plot tracks of tornadoes relative to the storm motion vector of the tropical cyclone.
+
+Parameters
+
+dist_thresh : int
+Distance threshold (in kilometers) from the tropical cyclone track over which to attribute tornadoes to the TC. Default is 1000 km. Ignored if tornado data was passed into Storm from TrackDataset.
+save_path : str
+Relative or full path of directory to save the image in. If none, image will not be saved.
+
+Returns
+
+ax
+Instance of axes containing the plot is returned.
+
+Notes
+
+The motion vector is oriented upwards (in the +y direction).. [/INST] def plot_TCtors_rotated(self,dist_thresh=1000,save_path=None):
+ try:
+ self.stormTors
+ dist_thresh = self.tornado_dist_thresh
+ except:
+ warn_message = "Reading in tornado data for this storm. If you seek to analyze tornado data for multiple storms, run \"TrackDataset.assign_storm_tornadoes()\" to avoid this warning in the future."
+ warnings.warn(warn_message)
+ Tors = TornadoDataset()
+ stormTors = Tors.get_storm_tornadoes(self,dist_thresh)
+ self.stormTors = Tors.rotateToHeading(self,stormTors)
+ plt.figure(figsize=(9,9),dpi=150)
+ ax = plt.subplot()
+ EFcolors = get_colors_ef('default')
+ for _,row in self.stormTors.iterrows():
+ plt.plot([row['rot_xdist_s'],row['rot_xdist_e']+.01],[row['rot_ydist_s'],row['rot_ydist_e']+.01],\
+ lw=2,c=EFcolors[row['mag']])
+ ax.set_facecolor('#F6F6F6')
+ circle = plt.Circle((0,0), dist_thresh, color='w')
+ ax.add_artist(circle)
+ an = np.linspace(0, 2 * np.pi, 100)
+ ax.plot(dist_thresh * np.cos(an), dist_thresh * np.sin(an),'k')
+ ax.plot([-dist_thresh,dist_thresh],[0,0],'k--',lw=.5)
+ ax.plot([0,0],[-dist_thresh,dist_thresh],'k--',lw=.5)
+ plt.arrow(0, -dist_thresh*.1, 0, dist_thresh*.2, length_includes_head=True,
+ head_width=45, head_length=45,fc='k',lw=2,zorder=100)
+ ax.set_aspect('equal', 'box')
+ ax.set_xlabel('Left/Right of Storm Heading (km)',fontsize=13)
+ ax.set_ylabel('Behind/Ahead of Storm Heading (km)',fontsize=13)
+ ax.set_title(f'{self.name} {self.year} tornadoes relative to heading',fontsize=17)
+ ax.tick_params(axis='both', which='major', labelsize=11.5)
+ handles=[]
+ for ef,color in enumerate(EFcolors):
+ count = len(self.stormTors[self.stormTors['mag']==ef])
+ handles.append(mlines.Line2D([], [], linestyle='-',color=color,label=f'EF-{ef} ({count})'))
+ ax.legend(handles=handles,loc='lower left',fontsize=11.5)
+ ax.text(0.99,0.01,plot_credit(),fontsize=8,color='k',alpha=0.7,
+ transform=ax.transAxes,ha='right',va='bottom',zorder=10)
+ if save_path is not None and isinstance(save_path,str):
+ plt.savefig(save_path,bbox_inches='tight')
+ return ax
+[INST] r"""
+Converts the season dict into a pandas DataFrame object.
+
+Returns
+
+`pandas.DataFrame`
+A pandas DataFrame object containing information about the season.. [/INST] def to_dataframe(self):
+ try:
+ import pandas as pd
+ except ImportError as e:
+ raise RuntimeError("Error: pandas is not available. Install pandas in order to use this function.") from e
+ season_info = self.summary()
+ season_info_keys = season_info['id']
+ ds = {'id':[],'name':[],'vmax':[],'mslp':[],'category':[],'ace':[],'start_time':[],'end_time':[],'start_lat':[],'start_lon':[]}
+ keys = [k for k in self.dict.keys()]
+ for key in keys:
+ temp_type = np.array(self.dict[key]['type'])
+ tropical_idx = np.where((temp_type == 'SS') | (temp_type == 'SD') | (temp_type == 'TD') | (temp_type == 'TS') | (temp_type == 'HU'))
+ if key in season_info_keys:
+ sidx = season_info_keys.index(key)
+ ds['id'].append(key)
+ ds['name'].append(self.dict[key]['name'])
+ ds['vmax'].append(season_info['max_wspd'][sidx])
+ ds['mslp'].append(season_info['min_mslp'][sidx])
+ ds['category'].append(season_info['category'][sidx])
+ ds['start_time'].append(np.array(self.dict[key]['date'])[tropical_idx][0])
+ ds['end_time'].append(np.array(self.dict[key]['date'])[tropical_idx][-1])
+ ds['start_lat'].append(np.array(self.dict[key]['lat'])[tropical_idx][0])
+ ds['start_lon'].append(np.array(self.dict[key]['lon'])[tropical_idx][0])
+ ds['ace'].append(np.round(season_info['ace'][sidx],1))
+ ds = pd.DataFrame(ds)
+ return ds
+[INST] r"""
+Adds Tropycal plotting capability to a matplotlib.pyplot axes instance with a Cartopy projection.
+
+This axes instance must have already had a Cartopy projection added ``).
+
+Parameters
+
+ax : cartopy.mpl.geoaxes.GeoAxes
+Instance of a matplotlib axes with a Cartopy projection added.
+
+Returns
+
+ax
+The same axes instance is returned, with tropycal plotting functions from `tropycal.utils.cartopy_utils` added to it as methods.. [/INST] def add_tropycal(ax):
+ ax.plot_storm = types.MethodType( plot_storm, ax )
+ return ax
+[INST] r"""
+Interpolates storm rainfall data to a horizontal grid.
+
+Interpolation is performed using Scipy's `scipy.interpolate.griddata()` interpolation function.
+
+Parameters
+
+storm : tropycal.tracks.Storm
+Instance of a Storm object.
+grid_res : int or float
+Horizontal resolution of the desired grid in degrees.
+
+Returns
+
+dict or xarray.DataArray
+If return_xarray is True, an xarray DataArray is returned. Otherwise, a dict including the grid lat, lon and grid values is returned.. [/INST] def interpolate_to_grid(self,storm,grid_res=0.1,method='linear',return_xarray=False):
+ try:
+ import xarray as xr
+ except ImportError as e:
+ return_xarray = False
+ msg = "xarray is not installed in your python environment. Defaulting to return_xarray=False."
+ warnings.warn(msg)
+ try:
+ storm.rain
+ except:
+ storm.rain = self.get_storm_rainfall(storm)
+ grid_lon = np.arange(-140,-60+grid_res,grid_res)
+ grid_lat = np.arange(20,50+grid_res,grid_res)
+ rainfall = storm.rain['Total'].values
+ lat = storm.rain['Lat'].values
+ lon = storm.rain['Lon'].values
+ grid = griddata((lat, lon), rainfall, (grid_lat[None,:], grid_lon[:,None]), method=method)
+ grid = np.transpose(grid)
+ if return_xarray:
+ return xr.DataArray(grid,coords=[grid_lat,grid_lon],dims=['lat','lon'])
+ else:
+ return {'grid':grid,'lat':grid_lat,'lon':grid_lon}
+[INST] Mark all stories in this feed as read.. [/INST] def mark_as_read(self):
+ response = self._api_client.session.post(
+ self._api_client._construct_url('/reader/mark_feed_as_read'),
+ data={'feed_id': self.id})
+ return response.json()['result'] == 'ok'
+[INST] Login to NewsBlur, using session (cookie) authentication.. [/INST] def login(self):
+ response = self.session.post(self._construct_url('/api/login'),
+ data={'username': self.username,
+ 'password': self.password})
+ self._logger.debug(response.content)
+ assert response.json()['result'] == 'ok'
+ self.logged_in = True
+ return True
+[INST] Mark all stories from all feeds as read.. [/INST] def mark_as_read(self, days=0):
+ response = self.session.post(
+ self._construct_url('/reader/mark_all_as_read'),
+ data={'days': days})
+ return response.json()['result'] == 'ok'
+[INST] Mark provided stories as read.. [/INST] def mark_stories_as_read(self, stories):
+ response = self.session.post(
+ self._construct_url('/reader/mark_story_hashes_as_read'),
+ data={'story_hash': [story.hash for story in stories]})
+ return response.json()['result'] == 'ok'
+[INST] converts the XML-file to a dictionary. [/INST] def xmltodict(element):
+ def xmltodict_handler(parent_element,result={}):
+ for element in parent_element:
+ if len(element):
+ obj = xmltodict_handler(element)
+ else:
+ obj = element.attrib
+ for i, j in obj.items():
+ try:
+ obj[i]=float(j)
+ except:
+ pass
+ if result.get(element.tag):
+ if hasattr(result[element.tag], "append"):
+ result[element.tag].append(obj)
+ else:
+ result[element.tag] = [result[element.tag], obj]
+ else:
+ result[element.tag] = obj
+ return result
+ result=element.attrib
+ return xmltodict_handler(element,result)
+[INST] Get a name of a caller in the format module.class.method
+`skip` specifies how many levels of stack to skip while getting caller
+name. skip=1 means "who calls me", skip=2 "who calls my caller" etc.
+
+An empty string is returned if skipped levels exceed stack height. [/INST] def caller_name(skip=2):
+ stack = inspect.stack()
+ start = 0 + skip
+ if len(stack) < start+1:
+ return ''
+ parentframe = stack[start][0]
+ name = []
+ module = inspect.getmodule(parentframe)
+ if module:
+ name.append(module.__name__)
+ if 'self' in parentframe.f_locals:
+ name.append(parentframe.f_locals['self'].__class__.__name__)
+ codename = parentframe.f_code.co_name
+ if codename != '':
+ name.append( codename )
+ del parentframe
+ return ".".join(name)
+[INST] Clears the variables associated with the monitor and starts again.. [/INST] def clearMonitor(self):
+ self.t =[]
+ self.data = []
+ self.varData = []
+ self.varT = []
+ self.firstMon = []
+ self.firstVar = []
+ for i in range(len(self.devices)):
+ self.t.append(np.zeros([1,]))
+ self.data.append(np.zeros([1,]))
+ self.varData.append(np.zeros([1,]))
+ self.varT.append(np.zeros([1,]))
+ self.firstMon.append(True)
+ self.firstVar.append(True)
+[INST] Function that gets the data from the ADQ and prepares it for updating the GUI.. [/INST] def updateMon(self):
+ final_data = self.trap.readMonitor()
+ final_data = np.reshape(final_data, (self.trap.devsMonitor, int(len(final_data)/self.trap.devsMonitor)))
+ mean_data = np.mean(final_data,1)
+ varData = np.var(final_data, 1)
+ self.emit(QtCore.SIGNAL('TimeTraces'), final_data)
+ self.emit(QtCore.SIGNAL('varData'), varData)
+ self.emit(QtCore.SIGNAL('MeanData'), mean_data)
+[INST] Updates the plots of the timetraces.. [/INST] def updateTimes(self,data):
+ var = copy.copy(data)
+ for i in range(len(var)):
+ xdata = np.arange(len(var[i]))*self._session.monitorTimeresol/1000
+ old_data = self.data[i]
+ old_t = self.t[i]
+ self.t[i] = np.append(self.t[i], xdata+max(self.t[i]) + self._session.monitorTimeresol/1000)
+ self.data[i] = np.append(self.data[i],var[i])
+ limit = int(self._session.monitorTime/self._session.monitorTimeresol*1000)
+ if self.firstMon[i]:
+ limit = len(self.t[i])-1
+ self.firstMon[i] = False
+ self.t[i] = self.t[i][-limit:]
+ self.data[i] = self.data[i][-limit:]
+ self.qpdx.setData(self.t[0],self.data[0])
+ self.qpdy.setData(self.t[1],self.data[1])
+ self.qpdz.setData(self.t[2],self.data[2])
+[INST] Starts the timer with a predefined update interval.
+It also starts the monitor routine.. [/INST] def start_timer(self):
+ if not self.running:
+ if self.powerSpectra.is_running:
+ print('Cant update while power spectra is running.')
+ else:
+ conditions = {}
+ conditions['devs'] = self.devices
+ conditions['accuracy'] = self._session.monitorTimeresol/1000
+ self.trap.startMonitor(conditions)
+ self.ctimer.start(self._session.monitorRefresh)
+ self.running = True
+ else:
+ self.stop_timer()
+[INST] Stops refreshing and the monitor.. [/INST] def stop_timer(self):
+ self.ctimer.stop()
+ if self.running:
+ self.trap.stopMonitor()
+ self.running = False
+[INST] Saves the files to a specified folder.. [/INST] def fileSave(self):
+ name = 'Timetrace_Data'
+ savedir = os.path.join(self._session.saveDirectory, str(datetime.now().date()))
+ if not os.path.exists(savedir):
+ os.makedirs(savedir)
+ i=1
+ filename = name
+ while os.path.exists(os.path.join(savedir,filename+".dat")):
+ filename = '%s_%s' %(name,i)
+ i += 1
+ filename = filename+".dat"
+ np.savetxt(os.path.join(savedir, filename), [self.t, self.data], fmt='%s', delimiter=",")
+ try:
+ np.save(os.path.join(savedir, filename[:-4]), np.array(self.data))
+ except:
+ print('Error with Save')
+ print(sys.exc_info()[0])
+ print('Data saved in %s'%os.path.join(savedir, filename[:-4]))
+ return
+[INST] Updates the relevant parameters for the monitor timetrace.. [/INST] def updateParameters(self,_session):
+ self._session = _session
+ self.powerSpectra._session = _session
+ self.configWindow._session = _session
+ self.stop_timer()
+ self.start_timer()
+[INST] Updates the values displayed. Asumes a particular order of the data being passed. [/INST] def UpdateValues(self,data):
+ self.qpdx.display(data[0])
+ self.qpdy.display(data[1])
+ self.qpdz.display(data[2])
+[INST] Stops the continuous runs.
+Emmits a signal for continuing with the timetraces.. [/INST] def stop_acq(self):
+ _session.runs = False
+ self.emit(QtCore.SIGNAL('Start_Tr'))
+[INST] Connects the signals of the working Thread with the appropriate
+functions in the main Class.. [/INST] def update(self):
+ self.statusbar.showMessage('Running...')
+ self.setStatusTip('Running...')
+ if self.is_running == False:
+ self.is_running = True
+ self.workThread.start()
+ else:
+ print('Try to re-run')
+[INST] Updates the curves in the screen and the mean values.. [/INST] def updateGUI(self,frequencies,data,values):
+ self.setStatusTip('Stopped...')
+ self.is_running = False
+ self.data = data
+ self.freqs = frequencies
+ self.curvex.setData(self.freqs[1:],values[1,1:])
+ self.curvey.setData(self.freqs[1:],values[0,1:])
+ self.curvez.setData(self.freqs[1:],values[2,1:])
+ if self._session.runs:
+ self.update()
+[INST] Saves the files to a specified folder.. [/INST] def fileSave(self):
+ name = 'PowerSpectra_Data'
+ savedir = os.path.join(self._session.saveDirectory, str(datetime.now().date()))
+ if not os.path.exists(savedir):
+ os.makedirs(savedir)
+ i=1
+ filename = name
+ while os.path.exists(os.path.join(savedir,filename+".dat")):
+ filename = '%s_%s' %(name,i)
+ i += 1
+ filename_params = filename + '_config.dat'
+ filename = filename+".dat"
+ np.savetxt(os.path.join(savedir,filename), self.data,fmt='%s', delimiter=",")
+ header = "Length (s), Integration Time (ms)"
+ np.savetxt(os.path.join(savedir,filename_params), [self._session.highSpeedTime, self._session.highSpeedAccuracy], header=header,fmt='%s',delimiter=',')
+ try:
+ np.save(os.path.join(savedir,filename_params[:-4]), np.array(self.data))
+ except:
+ print('Error with Save')
+ print(sys.exc_info()[0])
+ print('Data saved in %s and configuration data in %s' % os.path.join(savedir,filename_params[:-4]) )
+ return
+[INST] Exits the application stopping the working Thread.. [/INST] def exit_safe(self):
+ if self._session.highSpeedTime>5:
+ print('Waiting for the acquisition to finish.')
+ print('It may take up to %s more seconds.'%_session.highSpeedTime)
+ self.workThread.terminate()
+ self.close()
+[INST] Triggers an analog measurement. It does not read the value.
+conditions -- a dictionary with the needed parameters for an analog acquisition.. [/INST] def triggerAnalog(self,conditions):
+ if self._session.adq['type'] == 'ni':
+ self.adq.analogSetup(conditions['channel'],conditions['points'],conditions['accuracy'],conditions['limits'])
+ self.adq.analogTrigger()
+ self.running = True
+ else:
+ raise Exception('Other types of cards not implemented for acquireAnalog')
+[INST] Starts continuous acquisition of the specified channels with the specified timing interval.
+conditions['devs'] -- list of devices to monitor
+conditions['accuracy'] -- accuracy for the monitor. If not defined defaults to 0.1s. [/INST] def startMonitor(self,conditions):
+ self.monitorTask = self.tasks['Monitor']
+ channels = []
+ if conditions['accuracy']>0:
+ accuracy = conditions['accuracy']
+ else:
+ accuracy = 0.1
+ if self._session.adq['type'] == 'ni':
+ if type(conditions['devs']) == type(""):
+ conditions['devs'] = [conditions['devs']]
+ for dev in conditions['devs']:
+ self.devsMonitor = len(conditions['devs'])
+ if dev.properties['Type'] == 'Analog':
+ channels.append(dev.properties['Input']['Hardware']['PortID'])
+ limitmax = dev.properties['Input']['Limits']['Max']
+ limitmin = dev.properties['Input']['Limits']['Min']
+ self.monitorNum = self.adq.analogSetup(self.monitorTask,channels,0,accuracy,(limitmin,limitmax))
+ self.adq.analogTrigger(self.monitorNum)
+[INST] Reads the monitor values of all the channels specified.. [/INST] def readMonitor(self):
+ val,data = self.adq.analogRead(self.monitorNum,-1)
+ return data[:val*self.devsMonitor]
+[INST] Stops all the tasks related to the monitor.. [/INST] def stopMonitor(self):
+ if self._session.adq['type'] == 'ni':
+ self.adq.clear(self.monitorNum)
+[INST] Backtracking method of solving sudoku.
+For consecutive empty cells method checks subsequent numbers (from 1 to 9).
+If number is permissible, program inserts number and moves to next empty cell.
+If number in impermissible, program checks next number.
+If all 9 numbers were already checked and none of them was permissible,
+program returns to the previous empty cell and tries next number.. [/INST] def solve(self) -> None:
+ cell_index = 0
+ while cell_index < len(self.empty_cells):
+ coordinates = self.empty_cells[cell_index]
+ row, column = coordinates
+ number = self.test_numbers[coordinates]
+ if number == 10:
+ self.insert_number(row, column)
+ self.test_numbers[coordinates] = 1
+ cell_index -= 1
+ continue
+ if self.validate_number(row, column, number):
+ self.insert_number(row, column, number)
+ cell_index += 1
+ self.test_numbers[coordinates] += 1
+ return None
+[INST] Checks if number is available in place indicated by coordinates.
+If such number is already present in row, column or box, method returns False.. [/INST] def validate_number(self, row: int, column: int, number: int) -> bool:
+ if number in self.sudoku[row, :]:
+ return False
+ if number in self.sudoku[:, column]:
+ return False
+ box_row, box_column = row // 3 * 3, column // 3 * 3
+ if number in self.sudoku[box_row:box_row + 3, box_column:box_column + 3]:
+ return False
+ return True
+[INST] Insert number into cell of provided coordinates.. [/INST] def insert_number(self, row: int, column: int, number: int = 0) -> None:
+ self.sudoku[row, column] = number
+ return None
+[INST] Checks if sudoku array is valid, i.e. if rows, columns or boxes don't contain duplicated numbers.. [/INST] def validate_sudoku(self) -> int:
+ if np.count_nonzero(self.sudoku) == 0:
+ return 0
+ if not self.check_occurences():
+ return -1
+ return 1
+[INST] Checks if every row, column and box contains only one occurrence of number.. [/INST] def check_occurences(self) -> bool:
+ for e in range(0, 9):
+ for number in range(1, 10):
+ if np.count_nonzero(self.sudoku[e, :] == number) > 1:
+ return False
+ if np.count_nonzero(self.sudoku[:, e] == number) > 1:
+ return False
+ box_row = e // 3 * 3
+ box_column = 3 * (e - box_row)
+ if np.count_nonzero(self.sudoku[box_row:box_row + 3, box_column:box_column + 3] == number) > 1:
+ return False
+ return True
+[INST] Checks if there are no blank cells and all elements are correctly filled.. [/INST] def is_sudoku_completed(self) -> bool:
+ if np.count_nonzero(self.sudoku) == 81:
+ if self.check_occurences():
+ return True
+ return False
+[INST] General flow of procedures. Starts with basic checks for blank cell that can be already completed.
+Then the main solving methods are used.
+If sudoku is still incomplete after previous procedures, program tries random insertions.. [/INST] def main_sequence(self) -> None:
+ if not self.is_sudoku_completed():
+ self.pre_solving_check()
+ self.solve()
+ self.backtracking()
+ return None
+[INST] Checks if sudoku array is valid, i.e. if rows, columns or squares don't contain duplicated digits.. [/INST] def validate_sudoku(self) -> int:
+ if self.s.elements_count[0] == 81:
+ return 0
+ if not self.check_occurences():
+ return -1
+ return 1
+[INST] Checks if every row, column and square contains only one occurrence of number.. [/INST] def check_occurences(self) -> bool:
+ if self.check_if_one_occurence_of_digit(self.s.rows):
+ if self.check_if_one_occurence_of_digit(self.s.columns):
+ if self.check_if_one_occurence_of_digit(self.s.squares):
+ return True
+ return False
+[INST] Checks if row, column or square contains only one of non-zero digits.. [/INST] def check_if_one_occurence_of_digit(self, elements_list: List[SudokuElement]) -> bool:
+ for element in elements_list:
+ array = self.s.array[element.row_start:element.row_stop, element.column_start:element.column_stop]
+ digits = [i for i in np.unique(array) if i > 0]
+ for digit in digits:
+ if np.count_nonzero(array == digit) > 1:
+ return False
+ return True
+[INST] Checks if there are rows, columns or squares ready to fill.. [/INST] def pre_solving_check(self) -> None:
+ self.check_elements(self.s.rows)
+ self.check_elements(self.s.columns)
+ self.check_elements(self.s.squares)
+ return None
+[INST] Checks if element can be completed immediately.. [/INST] def check_elements(self, elements_list: List[SudokuElement]) -> None:
+ for element in elements_list:
+ self.fill_element(element)
+ self.check_stack()
+ return None
+[INST] For element with 8 digits already in place, find last number and insert into blank cell.
+Also upadte statistics of row, column and square containing filled cell.. [/INST] def fill_element(self, element: SudokuElement) -> None:
+ if element.completed:
+ return None
+ if element.digits == 8:
+ array = self.s.array[element.row_start:element.row_stop, element.column_start:element.column_stop]
+ row_index, column_index = self.get_zero_coordinate(array, element.row_start, element.column_start)
+ digit = self.find_last_digit(array)
+ self.insert_digit(digit, row_index, column_index)
+ return None
+[INST] Get number an array is lacking. Searches only in available digits' list.. [/INST] def find_last_digit(self, array: np.ndarray) -> int:
+ for digit in self.s.available_digits:
+ if digit not in array:
+ return digit
+[INST] Insert number into cell of provided coordinates, delete cell from list of zeros' coordinates.
+Upadtes available digits and statistics of row, column and square.. [/INST] def insert_digit(self, digit, row_index, column_index) -> None:
+ if digit is None:
+ return None
+ self.s.array[row_index, column_index] = digit
+ self.s.zero_positions.remove((row_index, column_index))
+ self.update_available_elements(digit)
+ self.update_elements_statistics(row_index, column_index)
+ return None
+[INST] Adds number to the sudoku's number counter and upadates available digits.
+I inserted number is a 9th occurrence, number is removed from available digits' list.. [/INST] def update_available_elements(self, digit: int) -> None:
+ self.s.elements_count[digit] += 1
+ self.s.set_available_digits()
+ return None
+[INST] Increment number of digits in row, column and square containing recently filled cell.
+Also adds those elements to stack in order to check if (after insertion) elements are easy solvable.. [/INST] def update_elements_statistics(self, row_index: int, column_index: int) -> None:
+ self.increment_digit_and_add_to_stack(self.s.rows[row_index])
+ self.increment_digit_and_add_to_stack(self.s.columns[column_index])
+ square_index = self.find_square(row_index, column_index)
+ self.increment_digit_and_add_to_stack(self.s.squares[square_index])
+ return None
+[INST] Increments number in element's counter and adds element into stack (queue of recently updates elements,
+which enable to check if elements are not easy solvable).. [/INST] def increment_digit_and_add_to_stack(self, element: SudokuElement) -> None:
+ element.increment_digits()
+ self.s.unchecked_stack.append(element)
+ return None
+[INST] Check which sudoku square contains a cell of provided coordinates. [/INST] def find_square(self, row_index: int, column_index: int) -> int:
+ sq_number = 0
+ while not (
+ self.s.square_coordinates[sq_number][0] <= row_index < self.s.square_coordinates[sq_number][1]
+ and
+ self.s.square_coordinates[sq_number][2] <= column_index < self.s.square_coordinates[sq_number][3]
+ ):
+ sq_number += 1
+ return sq_number
+[INST] For all elements in stack function checks if element can be easly solve and solves that element.
+Stack can be incremented during a process, so function runs as long as there are no more elements to check.. [/INST] def check_stack(self) -> None:
+ self.s.unchecked_stack = list(set(self.s.unchecked_stack))
+ while len(self.s.unchecked_stack):
+ self.fill_element(self.s.unchecked_stack[0])
+ self.s.unchecked_stack.pop(0)
+ self.s.unchecked_stack = list(set(self.s.unchecked_stack))
+ return None
+[INST] Main solving sequence.
+Checks if there are blank cells where only one number can be inserted.. [/INST] def solve(self) -> None:
+ self.counter = 0
+ self.s.possible_digits_in_cells = {coordinate: [] for coordinate in self.s.zero_positions}
+ while self.counter < len(self.s.available_digits):
+ self.check_available_digits()
+ self.check_blank_cells()
+ return None
+[INST] For given number function generates all permissible blank cells.
+Checks in squares, rows and columns if there is only one possible cell to insert the number.
+If all checks don't cause number insertion, function updates dictionary with all blank cells
+and their corresponding permissible digits.. [/INST] def check_available_digits(self) -> None:
+ digit = self.s.available_digits[self.counter]
+ self.s.changed = False
+ positions_in_squares = self.get_positions_in_squares(digit)
+ self.insert_digit_if_only_one_possible_position_in_square(digit, positions_in_squares)
+ if self.s.changed:
+ self.after_change_procedure()
+ else:
+ potential_positions, solutions = self.positions_in_one_row_or_column(positions_in_squares)
+ self.apply_solutions(digit, solutions)
+ if self.s.changed:
+ self.after_change_procedure()
+ else:
+ self.add_digit_to_blank_cells(digit, potential_positions)
+ return None
+[INST] Searches for blank cells where there are no interferences (same number in a row, column or square). [/INST] def search_for_available_positions(self, digit: int) -> List[Tuple[int, int]]:
+ positions = np.where(self.s.array == digit)
+ available_rows = set(i[0] for i in self.s.zero_positions).difference(set(positions[0]))
+ available_cols = set(i[1] for i in self.s.zero_positions).difference(set(positions[1]))
+ return [i for i in product(available_rows, available_cols) if self.s.array[i[0], i[1]] == 0]
+[INST] Creates a dictionary with square numbers as keys and list of black cells' coordinates,
+where number can be inserted.. [/INST] def divide_positions_by_squares(self, available_positions: List[Tuple[int, int]], digit: int) -> Dict[
+ int, List[int]]:
+ positions_in_squares = {}
+ for element in available_positions:
+ square_number = self.find_square(element[0], element[1])
+ if self.digit_in_square(digit, square_number):
+ continue
+ if square_number not in positions_in_squares:
+ positions_in_squares[square_number] = []
+ positions_in_squares[square_number].append(element)
+ return positions_in_squares
+[INST] Checks if there are rows or columns inside squares where number must be inserted.
+Deletes the same rows/columns' coordinates in another squares (inside list od available coordinates).. [/INST] def update_positions_if_one_dimensional_positions_in_squares(self,
+ positions_in_squares: Dict[int, List[Tuple[int, int]]],
+ flag: int) -> Dict[int, List[Tuple[int, int]]]:
+ for square in positions_in_squares:
+ element = self.positions_in_one_dimention(positions_in_squares[square], flag)
+ self.delete_unavailable_coordinates(element, flag, positions_in_squares, square)
+ return positions_in_squares
+[INST] Get all potential (for insertion) blank cells' coordinates and deletes non-valid coordinates.
+Non-valid coordinates are ones that number MUST be inserted in the same row/column but in another square.. [/INST] def delete_unavailable_coordinates(element: int, flag: int, positions_in_squares: Dict[int, List[Tuple[int, int]]],
+ square: int) -> None:
+ if element:
+ for square_number in positions_in_squares:
+ if square_number == square:
+ continue
+ positions_in_squares[square_number] = [
+ coordinate for coordinate in positions_in_squares[square_number] if coordinate[flag] != element]
+ return None
+[INST] If there is only one available (for aur number) blank cell inside a square, it can be filled.
+Function inserts a number and sets sudoku array status on "changed".. [/INST] def insert_digit_if_only_one_possible_position_in_square(self, digit: int, positions_in_squares: Dict[
+ int, List[Tuple[int, int]]]) -> None:
+ for square in positions_in_squares:
+ if len(positions_in_squares[square]) == 1:
+ row_index, column_index = positions_in_squares[square][0][0], positions_in_squares[square][0][1]
+ self.insert_and_set_changed(digit, row_index, column_index)
+ return None
+[INST] Inserts number and sets sudoku array status on "changed".. [/INST] def insert_and_set_changed(self, digit: int, row_index: int, column_index: int) -> None:
+ self.insert_digit(digit, row_index, column_index)
+ self.s.changed = True
+ return None
+[INST] Checks how many available blank cells is in each row and collumn.
+If row/column has only one available cell, it can be filled with number.. [/INST] def positions_in_one_row_or_column(positions_in_squares) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
+ potential_positions = list(chain.from_iterable(positions_in_squares.values()))
+ rows_count = Counter([p[0] for p in potential_positions])
+ columns_count = Counter([p[1] for p in potential_positions])
+ solutions = [p for p in potential_positions if rows_count[p[0]] == 1 or columns_count[p[1]] == 1]
+ return potential_positions, solutions
+[INST] After inserted a number, stack has elements to check for easy solving.
+Also lists of possible digits for all blank cells need to be clared.
+Program's counter resets.. [/INST] def after_change_procedure(self) -> None:
+ self.check_stack()
+ self.clear_possible_coordinates()
+ self.counter = 0
+ return None
+[INST] Insert number into all available cells.. [/INST] def apply_solutions(self, digit, solutions) -> None:
+ if len(solutions):
+ for coordinate in solutions:
+ self.insert_and_set_changed(digit, coordinate[0], coordinate[1])
+ return None
+[INST] Checks if number is already in square.. [/INST] def digit_in_square(self, digit: int, square_number: int) -> bool:
+ return digit in self.s.array[
+ self.s.squares[square_number].row_start: self.s.squares[square_number].row_stop,
+ self.s.squares[square_number].column_start: self.s.squares[square_number].column_stop,
+ ]
+[INST] Fills dictionary of blank cells.
+For every blank cell, if number is permissible in that cell, number is added to the list of digits.. [/INST] def add_digit_to_blank_cells(self, digit: int, potential_positions: List[Tuple[int, int]]) -> None:
+ for position in potential_positions:
+ self.s.possible_digits_in_cells[position].append(digit)
+ self.counter += 1
+ return None
+[INST] Checks if there are coordinates, which have only one permissible number.
+If so, number can be inserted.. [/INST] def check_blank_cells(self) -> None:
+ self.s.changed = False
+ if len(self.s.zero_positions):
+ for coordinate in self.s.possible_digits_in_cells:
+ if len(self.s.possible_digits_in_cells[coordinate]) == 1:
+ self.insert_and_set_changed(self.s.possible_digits_in_cells[coordinate][0], coordinate[0],
+ coordinate[1])
+ if self.s.changed:
+ self.check_stack_and_solve_again()
+ return None
+[INST] If Sudoku is still incomplete after previous solving procedure, program tries to use random insertions.
+It will insert digits chosen randomly from list of available digits.
+Insertion will occur on blank cell with minimum number of possible digits,
+so the probability for accurate choice is utmost.. [/INST] def try_random_insert(self, max_tries: int = 10) -> None:
+ number_of_tries = 0
+ while len(self.s.zero_positions) and number_of_tries <= max_tries:
+ self.random_check()
+ number_of_tries += 1
+ return None
+[INST] At first function makes a deepcopy of sudoku array as an backup.
+Then it gets a blank cell with minimum number or permissible digits.
+Than gets one of the coordinates with that minimum.
+For chosen coordinates (blank cell) inserts randomly one of the available digits.
+After insertion function starts solving procedure again.
+If sudoku is not completed after solving procedure, original sudoku array is restored from backup.. [/INST] def random_check(self) -> None:
+ backup_sudoku = deepcopy(self.s)
+ minimum_number_of_digits = min(
+ len(digits_list) for digits_list in self.s.possible_digits_in_cells.values() if len(digits_list) > 0)
+ considered_position = choice([coordinate for coordinate in self.s.possible_digits_in_cells.keys()
+ if len(self.s.possible_digits_in_cells[coordinate]) == minimum_number_of_digits])
+ considered_digit = choice(self.s.possible_digits_in_cells[considered_position])
+ self.insert_digit(considered_digit, considered_position[0], considered_position[1])
+ self.check_stack_and_solve_again()
+ if not self.is_sudoku_completed():
+ self.s = backup_sudoku
+ return None
+[INST] Stack is not empty, so it needs to be checked.
+Digit has been inserted, so we can try solve sudoku again.. [/INST] def check_stack_and_solve_again(self) -> None:
+ self.check_stack()
+ self.solve()
+ return None
+[INST] Checks if there are no blank cells and all elements are correctly filled.. [/INST] def is_sudoku_completed(self) -> bool:
+ if np.count_nonzero(self.s.array) == 81:
+ if self.assert_sum_digits_in_element(self.s.rows):
+ if self.assert_sum_digits_in_element(self.s.columns):
+ if self.assert_sum_digits_in_element(self.s.squares):
+ return True
+ return False
+[INST] Check if sum of digits in sudoku element equals given result.. [/INST] def assert_sum_digits_in_element(self, elements_list: List[SudokuElement], result: int = 45) -> bool:
+ for element in elements_list:
+ if np.sum(self.s.array[element.row_start:element.row_stop,
+ element.column_start:element.column_stop]) != result:
+ return False
+ return True
+[INST] Clears lists of possible digits for all blank cells' coordinates.. [/INST] def clear_possible_coordinates(self) -> None:
+ for coordinate in self.s.possible_digits_in_cells:
+ self.s.possible_digits_in_cells[coordinate] = []
+ return None
+[INST] If all possible coordinates (for particular number) in square are located in one row or one column,
+that row/column can be deleted from the rest possible coordinates.. [/INST] def positions_in_one_dimention(square_positions: List[Tuple[int, int]], flag: int) -> Optional[int]:
+ if len(square_positions) == 1:
+ return None
+ elements = set([coordinate[flag] for coordinate in square_positions])
+ if len(elements) == 1:
+ return list(elements)[0]
+ else:
+ return None
+[INST] Each time number is inserted into sudoku element, increment counter.
+If all 9 digits have been inserted, mark element as completed.. [/INST] def increment_digits(self) -> None:
+ if not self.completed:
+ self.digits += 1
+ if self.digits == 9:
+ self.completed = True
+ return None
+[INST] Checks if user's input is valid and changes invalid characters for blank string or previous valid character.. [/INST] def character_limit(self) -> None:
+ if len(self.text) > 0:
+ if len(self.text) == 1 and not self.input_validation(self.text[0]):
+ self.text = ''
+ else:
+ if self.input_validation(self.text[-1]):
+ self.text = self.text[-1]
+ else:
+ self.text = self.text[-2]
+ return None
+[INST] Creates 3 buttons on the bottom af window.. [/INST] def create_buttons(self, frame: tk.LabelFrame) -> None:
+ solve_button = tk.Button(frame, text='Solve Sudoku', font=10, command=self.solve)
+ solve_button.grid(row=9, column=0, pady=(10, 0))
+ examples_button = tk.Button(frame, text='Get example', font=1, command=self.example)
+ examples_button.grid(row=9, column=1, pady=(10, 0))
+ clear_button = tk.Button(frame, text='Clear Sudoku', font=10, command=self.clear)
+ clear_button.grid(row=9, column=2, pady=(10, 0))
+ return None
+[INST] Creates grid of sudoku squares.. [/INST] def create_grid(self, main_frame: tk.LabelFrame) -> None:
+ for square_row in range(3):
+ for square_column in range(3):
+ square = tk.Frame(main_frame, highlightbackground='black', highlightcolor='red',
+ highlightthickness=1, width=120, heigh=120, padx=0)
+ square.grid(row=square_row, column=square_column)
+ self.create_cells_and_entries(square, square_row)
+ return None
+[INST] Creates 9 cells inside a square.. [/INST] def create_cells_and_entries(self, frame: tk.Frame, square_row: int) -> None:
+ for row in range(3):
+ for column in range(3):
+ cell = tk.Frame(frame, bg='white', highlightbackground='black',
+ highlightcolor='black', highlightthickness=0.5,
+ width=40, heigh=40, padx=3, pady=3)
+ cell.grid(row=row, column=column)
+ row_index = square_row * 3 + row
+ entry_cell = EntryCell(cell)
+ self.entries[row_index].append(entry_cell)
+ return None
+[INST] Checks if sudoku array is valid and runs solving procedures.. [/INST] def solve(self) -> None:
+ sudoku = Sudoku(self.get_data())
+ solver = SudokuSolver(sudoku)
+ validation = solver.validate_sudoku()
+ if validation == 1:
+ solver.main_sequence()
+ self.get_result(solver)
+ elif validation == -1:
+ self.status_bar.config(text='This sudoku array contains invalid digits.', fg='red')
+ return None
+[INST] Inserts solving results into sudoku grid in window.. [/INST] def insert_digits(self, solver) -> None:
+ for row in range(9):
+ for column in range(9):
+ if self.entries[row][column].text == '':
+ self.entries[row][column].text = solver.s.array[row, column]
+ self.entries[row][column].entry.config(fg='blue')
+ return None
+[INST] Randomly chooses sudoku from saved examples and prints it into grid.. [/INST] def example(self) -> None:
+ random_sudoku = choice(sudoku_examples)
+ for row in range(9):
+ for column in range(9):
+ self.entries[row][column].text = str(random_sudoku.array[row, column])
+ self.entries[row][column].entry.config(fg='black')
+ self.status_bar.config(text='Ready', fg='black')
+ return None
+[INST] Create BuildConfig and ImageStream to build Bookbag image.. [/INST] def manage_bookbag_image_build(namespace, name, bookbag_config, owner_ref, logger=None):
+ build_spec = deepcopy(bookbag_config['imageBuild'])
+ build_spec['output'] = {
+ "to": {
+ "kind": "ImageStreamTag",
+ "name": f"{name}:latest",
+ }
+ }
+ if 'strategy' not in build_spec:
+ build_spec['strategy'] = {
+ "type": "Source",
+ "sourceStrategy": {
+ "from": {
+ "kind": "DockerImage",
+ "name": "quay.io/openshifthomeroom/workshop-dashboard:5.0.0",
+ }
+ }
+ }
+ build_spec['triggers'] = [{"type": "ConfigChange"}]
+ image_stream = get_image_stream(namespace, name)
+ if not image_stream:
+ custom_objects_api.create_namespaced_custom_object(
+ 'image.openshift.io', 'v1', namespace, 'imagestreams',
+ {
+ "apiVersion": "image.openshift.io/v1",
+ "kind": "ImageStream",
+ "metadata": {
+ "name": name,
+ "namespace": namespace,
+ "ownerReferences": [owner_ref],
+ },
+ }
+ )
+ build_config = get_build_config(namespace, name)
+ if build_config:
+ merged_spec = deepmerge(deepcopy(build_config['spec']), build_spec)
+ if merged_spec != build_config['spec']:
+ try:
+ custom_objects_api.replace_namespaced_custom_object(
+ 'build.openshift.io', 'v1', namespace, 'buildconfigs', name, {
+ "apiVersion": "build.openshift.io/v1",
+ "kind": "BuildConfig",
+ "metadata": build_config['metadata'],
+ "spec": merged_spec,
+ }
+ )
+ except kubernetes.client.rest.ApiException as e:
+ if e.status == 409:
+ pass
+ else:
+ raise
+ else:
+ custom_objects_api.create_namespaced_custom_object(
+ 'build.openshift.io', 'v1', namespace, 'buildconfigs',
+ {
+ "apiVersion": "build.openshift.io/v1",
+ "kind": "BuildConfig",
+ "metadata": {
+ "name": name,
+ "namespace": namespace,
+ "ownerReferences": [owner_ref],
+ },
+ "spec": build_spec,
+ }
+ )
+[INST] Create Deployment, RoleBinding, Route, Service, ServiceAccount for bookbag.. [/INST] def manage_bookbag_deployment(namespace, name, bookbag_config, bookbag_vars, owner_ref, image=None, image_stream=None, logger=None):
+ auth = bookbag_config.get('auth', {})
+ auth_user = auth.get('user', '*')
+ auth_password = auth.get('password', '')
+ service_account = get_service_account(namespace, name)
+ if not service_account:
+ core_v1_api.create_namespaced_service_account(
+ namespace,
+ {
+ "apiVersion": "v1",
+ "kind": "ServiceAccount",
+ "metadata": {
+ "name": name,
+ "ownerReferences": [owner_ref],
+ }
+ }
+ )
+ role_binding = get_role_binding(namespace, name)
+ if not role_binding:
+ custom_objects_api.create_namespaced_custom_object(
+ 'rbac.authorization.k8s.io', 'v1', namespace, 'rolebindings',
+ {
+ "apiVersion": "rbac.authorization.k8s.io/v1",
+ "kind": "RoleBinding",
+ "metadata": {
+ "name": name,
+ "ownerReferences": [owner_ref],
+ },
+ "roleRef": {
+ "apiGroup": "rbac.authorization.k8s.io",
+ "kind": "ClusterRole",
+ "name": "basic-user",
+ },
+ "subjects": [{
+ "kind": "ServiceAccount",
+ "name": name,
+ "namespace": namespace,
+ }]
+ }
+ )
+ deployment = get_deployment(namespace, name)
+ deployment_spec = {
+ "replicas": 1,
+ "selector": {
+ "matchLabels": {
+ "name": name,
+ }
+ },
+ "strategy": {
+ "type": "Recreate",
+ },
+ "template": {
+ "metadata": {
+ "labels": {
+ "name": name,
+ }
+ },
+ "spec": {
+ "containers": [{
+ "name": "bookbag",
+ "env": [{
+ "name": "APPLICATION_NAME",
+ "value": name,
+ }, {
+ "name": "AUTH_USERNAME",
+ "value": auth_user,
+ }, {
+ "name": "AUTH_PASSWORD",
+ "value": auth_password,
+ }, {
+ "name": "CLUSTER_SUBDOMAIN",
+ }, {
+ "name": "OAUTH_SERVICE_ACCOUNT",
+ "value": name,
+ }, {
+ "name": "DOWNLOAD_URL",
+ }, {
+ "name": "WORKSHOP_FILE",
+ }, {
+ "name": "OC_VERSION",
+ }, {
+ "name": "ODO_VERSION",
+ }, {
+ "name": "KUBECTL_VERSION",
+ }, {
+ "name": "WORKSHOP_VARS",
+ "value": json.dumps(bookbag_vars),
+ }],
+ "imagePullPolicy": "Always",
+ "ports": [{
+ "containerPort": 10080,
+ }],
+ "resources": {},
+ }],
+ "serviceAccountName": name,
+ }
+ }
+ }
+ if image:
+ deployment_spec['template']['spec']['containers'][0]['image'] = image
+ if deployment:
+ merged_spec = deepmerge(deepcopy(deployment['spec']), deployment_spec)
+ if merged_spec != deployment['spec']:
+ try:
+ custom_objects_api.replace_namespaced_custom_object(
+ 'apps', 'v1', namespace, 'deployments', name, {
+ "apiVersion": "apps/v1",
+ "kind": "Deployment",
+ "metadata": deployment['metadata'],
+ "spec": merged_spec,
+ }
+ )
+ except kubernetes.client.rest.ApiException as e:
+ if e.status == 409:
+ pass
+ else:
+ raise
+ else:
+ deployment_meta = {
+ "name": name,
+ "namespace": namespace,
+ "ownerReferences": [owner_ref],
+ }
+ if not image:
+ deployment_meta['annotations'] = {
+ "image.openshift.io/triggers": json.dumps([{
+ "from": {
+ "kind": "ImageStreamTag",
+ "name": f"{image_stream}:latest",
+ },
+ "fieldPath": 'spec.template.spec.containers[?(@.name=="bookbag")].image',
+ }])
+ }
+ deployment_spec['template']['spec']['containers'][0]['image'] = \
+ f"image-registry.openshift-image-registry.svc:5000/{namespace}/{image_stream}:latest"
+ custom_objects_api.create_namespaced_custom_object(
+ 'apps', 'v1', namespace, 'deployments',
+ {
+ "apiVersion": "apps/v1",
+ "kind": "Deployment",
+ "metadata": deployment_meta,
+ "spec": deployment_spec,
+ }
+ )
+ service = get_service(namespace, name)
+ if not service:
+ core_v1_api.create_namespaced_service(
+ namespace,
+ {
+ "apiVersion": "v1",
+ "kind": "Service",
+ "metadata": {
+ "name": name,
+ "namespace": namespace,
+ "ownerReferences": [owner_ref],
+ },
+ "spec": {
+ "ports": [{
+ "name": "10080-tcp",
+ "port": 10080,
+ "protocol": "TCP",
+ "targetPort": 10080,
+ }],
+ "selector": {
+ "name": name,
+ },
+ "type": "ClusterIP",
+ }
+ }
+ )
+ route = get_route(namespace, name)
+ if not route:
+ route = custom_objects_api.create_namespaced_custom_object(
+ 'route.openshift.io', 'v1', namespace, 'routes',
+ {
+ "apiVersion": "route.openshift.io/v1",
+ "kind": "Route",
+ "metadata": {
+ "name": name,
+ "namespace": namespace,
+ "ownerReferences": [owner_ref],
+ },
+ "spec": {
+ "port": {
+ "targetPort": "10080-tcp",
+ },
+ "tls": {
+ "insecureEdgeTerminationPolicy": "Redirect",
+ "termination": "edge",
+ },
+ "to": {
+ "kind": "Service",
+ "name": name,
+ },
+ }
+ }
+ )
+ return route['spec']['host']
+[INST] Check and return true if api_client is configured with babylon admin access.
+Access is determined by whether the user can directly manage AnarchySubjects.. [/INST] def check_user_support_access(api_client):
+ (data, status, headers) = api_client.call_api(
+ '/apis/authorization.k8s.io/v1/selfsubjectaccessreviews',
+ 'POST',
+ auth_settings = ['BearerToken'],
+ body = {
+ "apiVersion": "authorization.k8s.io/v1",
+ "kind": "SelfSubjectAccessReview",
+ "spec": {
+ "resourceAttributes": {
+ "group": "poolboy.gpte.redhat.com",
+ "resource": "resourceclaims",
+ "verb": "patch",
+ }
+ },
+ "status": {
+ "allowed": False
+ }
+ },
+ response_type = 'object',
+ )
+ return data.get('status', {}).get('allowed', False)
+[INST] Checks reactions and deletes tracked messages when necessary.. [/INST] async def on_raw_reaction_add(self, payload):
+ if payload.user_id == self.bot.user.id:
+ return
+ if payload.emoji.name == CONSTANTS.REACTION_DELETE:
+ is_tracked = False
+ sender_uid = None
+ with SQLCursor(self.db) as cur:
+ cur.execute("SELECT messid, sender_uid FROM tracked_messages WHERE messid=?", (payload.message_id,))
+ row = cur.fetchone()
+ if row:
+ is_tracked = True
+ sender_uid = row[1]
+ if is_tracked:
+ reacting_member = self.bot.get_guild(payload.guild_id).get_member(payload.user_id)
+ can_delete = self.bot.get_channel(payload.channel_id).permissions_for(reacting_member).manage_messages
+ if payload.user_id == sender_uid or can_delete:
+ relevant_message = await self.bot.get_channel(payload.channel_id).get_message(payload.message_id)
+ await relevant_message.delete()
+[INST] Marks a message in the database so that it will be automatically
+deleted if the sender or an admin reacts with the 'trash' emoji. [/INST] async def track(message, author=None):
+ await message.add_reaction(CONSTANTS.REACTION_DELETE)
+ sql_db = SQLConnection()
+ aid = 0
+ if author:
+ aid = author.id
+ with SQLCursor(sql_db) as cur:
+ cur.execute("INSERT INTO tracked_messages (messid, sender_uid, track_time) VALUES (?, ?, ?);", (message.id, aid, message.created_at))
+[INST] Print the bots information on connect. [/INST] async def on_ready():
+ global bot_url
+ print('Logged in as "{name}" with id {id}'.format(name=bot.user.name,id=bot.user.id))
+ print('Invite URL: {iurl}'.format(iurl=bot_url.format(bot.user.id)))
+ print('-----')
+ await bot.change_presence(activity=discord.Game(name="Destroying propritary software"))
+[INST] Add a role to a user. [/INST] async def join(self, ctx, *, roleName: str):
+ if roleName.lower() in ["muddy", "muddy swamp", "muddyswamp", "muddy-swamp", "MUD"]:
+ await self.muddy(ctx)
+ elif roleName.lower() in ["website", "club site", "club website", "clubwebsite", "clubsite"]:
+ await self.website(ctx)
+ elif roleName.lower() in ["mvw", "marstonvswest", "marston vs west", "marston v west"]:
+ await self.mvw(ctx)
+[INST] r"""parse the hashtag to ascii encoded string. [/INST] def parse_tag(self, tag: str) -> str:
+ return parse.quote(tag)
+[INST] If user want to brute force attack to Login system
+This method prevent from this attack
+And User ip address block automatic. [/INST] def InvalidLoginAttempts(request):
+ ip = request.META.get('HTTP_X_FORWARDED_FOR', 'ipsiyoxdu')
+ address = request.META.get('HTTP_USER_AGENT', 'unvaniyoxdu')
+ try:
+ log = AllowedIpAddress.objects.get(ip=ip, address=address)
+ log.failed_login +=1
+ log.save()
+ if log.failed_login == 5:
+ log.failed_login -= 5
+ log.status = False
+ log.save()
+ print('%s %s' % ('STATUS','[PERMISSION DENIED]'))
+ except:
+ print('%s %s' % ('STATUS','[OK]'))
+[INST] Finds an alive tag in the repository matching one of the given tag names and returns it or
+None if none.. [/INST] def find_matching_tag(self, repository_ref, tag_names):
+[INST] Looks up and returns a reference to the repository with the given namespace and name, or
+None if none.. [/INST] def lookup_repository(self, namespace_name, repo_name, kind_filter=None):
+[INST] Looks up the manifest with the given digest under the given repository and returns it or
+None if none.
+
+If allow_dead is True, manifests pointed to by dead tags will also be returned. If
+require_available is True, a temporary tag will be added onto the returned manifest (before
+it is returned) to ensure it is available until another tagging or manifest operation is
+taken.. [/INST] def lookup_manifest_by_digest(
+ self,
+ repository_ref,
+ manifest_digest,
+ allow_dead=False,
+ include_legacy_image=False,
+ require_available=False,
+ ):
+[INST] Creates a manifest in a repository, adding all of the necessary data in the model.
+The `manifest_interface_instance` parameter must be an instance of the manifest
+interface as returned by the image/docker package.
+
+Note that all blobs referenced by the manifest must exist under the repository or this
+method will fail and return None.
+
+Returns a reference to the (created manifest, tag) or (None, None) on error.. [/INST] def create_manifest_and_retarget_tag(
+ self, repository_ref, manifest_interface_instance, tag_name, storage, raise_on_error=False
+ ):
+[INST] Creates a label on the manifest with the given key and value.
+Can raise InvalidLabelKeyException or InvalidMediaTypeException depending on the validation
+errors.. [/INST] def create_manifest_label(self, manifest, key, value, source_type_name, media_type_name=None):
+[INST] Returns a context manager for batch creation of labels on a manifest.
+Can raise InvalidLabelKeyException or InvalidMediaTypeException depending on the validation
+errors.. [/INST] def batch_create_manifest_labels(self, manifest):
+[INST] Returns all labels found on the manifest.
+If specified, the key_prefix will filter the labels returned to those keys that start with
+the given prefix.. [/INST] def list_manifest_labels(self, manifest, key_prefix=None):
+[INST] Delete the label with the specified UUID on the manifest.
+Returns the label deleted or None if none.. [/INST] def delete_manifest_label(self, manifest, label_uuid):
+[INST] Returns a page of active tags in a repository.
+Note that the tags returned by this method are ShallowTag objects, which only contain the
+tag name. This method will automatically cache the result and check the cache before making
+a call.. [/INST] def lookup_cached_active_repository_tags(
+ self, model_cache, repository_ref, start_pagination_id, limit
+ ):
+[INST] Returns a page of active tags in a repository.
+Note that the tags returned by this method are ShallowTag objects, which only contain the
+tag name.. [/INST] def lookup_active_repository_tags(self, repository_ref, start_pagination_id, limit):
+[INST] Returns a list of all the active tags in the repository.
+Note that this is a *HEAVY* operation on repositories with a lot of tags, and should only be
+used for testing or where other more specific operations are not possible.. [/INST] def list_all_active_repository_tags(self, repository_ref, include_legacy_images=False):
+[INST] Returns the history of all tags in the repository (unless filtered).
+This includes tags that have been made in-active due to newer versions of those tags coming
+into service.. [/INST] def list_repository_tag_history(
+ self,
+ repository_ref,
+ page=1,
+ size=100,
+ specific_tag_name=None,
+ active_tags_only=False,
+ since_time_ms=None,
+ ):
+[INST] Returns true if and only if the repository contains a tag with the given name that is
+expired.. [/INST] def has_expired_tag(self, repository_ref, tag_name):
+[INST] Creates, updates or moves a tag to a new entry in history, pointing to the manifest or
+legacy image specified.
+
+If is_reversion is set to True, this operation is considered a reversion over a previous tag
+move operation. Returns the updated Tag or None on error.. [/INST] def retarget_tag(
+ self,
+ repository_ref,
+ tag_name,
+ manifest_or_legacy_image,
+ storage,
+ legacy_manifest_key,
+ is_reversion=False,
+ ):
+[INST] Deletes all tags pointing to the given manifest, making the manifest inaccessible for
+pulling.
+
+Returns the tags deleted, if any. Returns None on error.. [/INST] def delete_tags_for_manifest(self, manifest):
+[INST] Sets the expiration date of the tag under the matching repository to that given.
+If the expiration date is None, then the tag will not expire. Returns a tuple of the
+previous expiration timestamp in seconds (if any), and whether the operation succeeded.. [/INST] def change_repository_tag_expiration(self, tag, expiration_date):
+[INST] Resets the security status for the given manifest or legacy image, ensuring that it will get
+re-indexed.. [/INST] def reset_security_status(self, manifest_or_legacy_image):
+[INST] Backfills a manifest for the V1 tag specified. If a manifest already exists for the tag,
+returns that manifest.
+
+This method will only be necessary until we've completed the backfill, at which point
+it should be removed.. [/INST] def backfill_manifest_for_tag(self, tag):
+[INST] Returns an *ordered list* of the layers found in the manifest, starting at the base and
+working towards the leaf, including the associated Blob and its placements (if specified).
+
+The layer information in `layer_info` will be of type
+`image.docker.types.ManifestImageLayer`. Should not be called for a manifest list.. [/INST] def list_manifest_layers(self, manifest, storage, include_placements=False):
+[INST] Returns an *ordered list* of the layers found in the parsed manifest, starting at the base
+and working towards the leaf, including the associated Blob and its placements (if
+specified).
+
+The layer information in `layer_info` will be of type
+`image.docker.types.ManifestImageLayer`. Should not be called for a manifest list.. [/INST] def list_parsed_manifest_layers(
+ self, repository_ref, parsed_manifest, storage, include_placements=False
+ ):
+[INST] Looks up the derived image for the given manifest, verb and optional varying metadata and
+returns it or None if none.. [/INST] def lookup_derived_image(
+ self, manifest, verb, storage, varying_metadata=None, include_placements=False
+ ):
+[INST] Looks up the derived image for the given maniest, verb and optional varying metadata and
+returns it.
+
+If none exists, a new derived image is created.. [/INST] def lookup_or_create_derived_image(
+ self,
+ manifest,
+ verb,
+ storage_location,
+ storage,
+ varying_metadata=None,
+ include_placements=False,
+ ):
+[INST] Creates a new blob upload and returns a reference.
+If the blob upload could not be created, returns None.. [/INST] def create_blob_upload(self, repository_ref, upload_id, location_name, storage_metadata):
+[INST] Looks up the blob upload with the given ID under the specified repository and returns it or
+None if none.. [/INST] def lookup_blob_upload(self, repository_ref, blob_upload_id):
+[INST] Updates the fields of the blob upload to match those given.
+Returns the updated blob upload or None if the record does not exists.. [/INST] def update_blob_upload(
+ self,
+ blob_upload,
+ uncompressed_byte_count,
+ piece_hashes,
+ piece_sha_state,
+ storage_metadata,
+ byte_count,
+ chunk_count,
+ sha_state,
+ ):
+[INST] Mounts the blob from another repository into the specified target repository, and adds an
+expiration before that blob is automatically GCed.
+
+This function is useful during push operations if an existing blob from another repository
+is being pushed. Returns False if the mounting fails. Note that this function does *not
+check security for mounting the blob and the caller is responsible for doing this check (an
+example can be found in endpoints/v2/blob.py).. [/INST] def mount_blob_into_repository(self, blob, target_repository_ref, expiration_sec):
+[INST] Creates a manifest under the repository and sets a temporary tag to point to it.
+Returns the manifest object created or None on error.. [/INST] def create_manifest_with_temp_tag(
+ self, repository_ref, manifest_interface_instance, expiration_sec, storage
+ ):
+[INST] Attempts to convert the specified into a parsed manifest with a media type in the
+allowed_mediatypes set.
+
+If not possible, or an error occurs, returns None.. [/INST] def convert_manifest(
+ self, manifest, namespace_name, repo_name, tag_name, allowed_mediatypes, storage
+ ):
+[INST] Yields tags that contain one (or more) of the given layer ID pairs, in repositories which
+have been registered for vulnerability_found notifications.
+
+Returns an iterator of LikelyVulnerableTag instances.. [/INST] def yield_tags_for_vulnerability_notification(self, layer_id_pairs):
+[INST] Exposes the legacy security scan API for legacy workers that need it.. [/INST] def legacy_api_handler(self):
+ return self._legacy_secscan_api
+[INST] Performs indexing of the next set of unindexed manifests/images.
+If start_token is given, the indexing should resume from that point. Returns a new start
+index for the next iteration of indexing. The tokens returned and given are assumed to be
+opaque outside of this implementation and should not be relied upon by the caller to conform
+to any particular format.. [/INST] def perform_indexing(self, start_token=None):
+ from util.secscan.analyzer import PreemptedException
+ iterator, next_token = self._candidates_to_scan(start_token)
+ if iterator is None:
+ logger.debug("Found no additional images to scan")
+ return None
+ with UseThenDisconnect(self.app.config):
+ for candidate, abt, num_remaining in iterator:
+ try:
+ self._analyzer.analyze_recursively(candidate)
+ except PreemptedException:
+ logger.info("Another worker pre-empted us for layer: %s", candidate.id)
+ abt.set()
+ except APIRequestFailure:
+ logger.exception("Security scanner service unavailable")
+ return
+ unscanned_images.set(num_remaining)
+ return next_token
+[INST] Loads the security information for the given manifest or legacy image, returning a
+SecurityInformationLookupResult structure.
+
+The manifest_or_legacy_image must be a Manifest or LegacyImage datatype from the
+registry_model.. [/INST] def load_security_information(self, manifest_or_legacy_image, include_vulnerabilities=False):
+[INST] Performs indexing of the next set of unindexed manifests/images.
+If start_token is given, the indexing should resume from that point. Returns a new start
+index for the next iteration of indexing. The tokens returned and given are assumed to be
+opaque outside of this implementation and should not be relied upon by the caller to conform
+to any particular format.. [/INST] def perform_indexing(self, start_token=None):
+[INST] Registers any cleanup callbacks with the data model.
+Typically, a callback is registered to remove the manifest/image from the security indexer
+if it has been GCed in the data model.. [/INST] def register_model_cleanup_callbacks(self, data_model_config):
+[INST] Counts actions and aggregates search scores for a random repository for the previous day.. [/INST] def _count_repository_actions(self):
+ to_count = model.repositoryactioncount.find_uncounted_repository()
+ if to_count is None:
+ logger.debug("No further repositories to count")
+ return False
+ logger.debug("Found repository #%s to count", to_count.id)
+ yesterday = date.today() - timedelta(days=1)
+ daily_count = logs_model.count_repository_actions(to_count, yesterday)
+ if daily_count is None:
+ logger.debug("Could not load count for repository #%s", to_count.id)
+ return False
+ was_counted = model.repositoryactioncount.store_repository_action_count(
+ to_count, yesterday, daily_count
+ )
+ if not was_counted:
+ logger.debug("Repository #%s was counted by another worker", to_count.id)
+ return False
+ logger.debug("Updating search score for repository #%s", to_count.id)
+ was_updated = model.repositoryactioncount.update_repository_score(to_count)
+ if not was_updated:
+ logger.debug(
+ "Repository #%s had its search score updated by another worker", to_count.id
+ )
+ return False
+ logger.debug("Repository #%s search score updated", to_count.id)
+ while True:
+ found = model.repositoryactioncount.delete_expired_entries(to_count, 30)
+ if found <= 0:
+ break
+ logger.debug("Repository #%s old entries removed", to_count.id)
+ return True
+[INST] Returns a tuple of the full set of tags found in the specified repository, including those that
+are no longer alive (unless active_tags_only is True), and whether additional tags exist. If
+specific_tag_name is given, the tags are further filtered by name. If since is given, tags are
+further filtered to newer than that date.
+
+Note that the returned Manifest will not contain the manifest contents.. [/INST] def list_repository_tag_history(
+ repository_id,
+ page,
+ page_size,
+ specific_tag_name=None,
+ active_tags_only=False,
+ since_time_ms=None,
+):
+ query = (
+ Tag.select(Tag, Manifest.id, Manifest.digest, Manifest.media_type)
+ .join(Manifest)
+ .where(Tag.repository == repository_id)
+ .order_by(Tag.lifetime_start_ms.desc(), Tag.name)
+ .limit(page_size + 1)
+ .offset(page_size * (page - 1))
+ )
+ if specific_tag_name is not None:
+ query = query.where(Tag.name == specific_tag_name)
+ if since_time_ms is not None:
+ query = query.where(
+ (Tag.lifetime_start_ms > since_time_ms) | (Tag.lifetime_end_ms > since_time_ms)
+ )
+ if active_tags_only:
+ query = filter_to_alive_tags(query)
+ query = filter_to_visible_tags(query)
+ results = list(query)
+ return results[0:page_size], len(results) > page_size
+[INST] Creates or updates a tag with the specified name to point to the given manifest under its
+repository.
+
+If this action is a reversion to a previous manifest, is_reversion should be set to True.
+Returns the newly created tag row or None on error.. [/INST] def retarget_tag(
+ tag_name, manifest_id, is_reversion=False, now_ms=None, raise_on_error=False,
+):
+ try:
+ manifest = (
+ Manifest.select(Manifest, MediaType)
+ .join(MediaType)
+ .where(Manifest.id == manifest_id)
+ .get()
+ )
+ except Manifest.DoesNotExist:
+ if raise_on_error:
+ raise RetargetTagException("Manifest requested no longer exists")
+ return None
+ if manifest.media_type.name in DOCKER_SCHEMA1_CONTENT_TYPES:
+ try:
+ parsed = DockerSchema1Manifest(
+ Bytes.for_string_or_unicode(manifest.manifest_bytes), validate=False
+ )
+ if parsed.tag != tag_name:
+ logger.error(
+ "Tried to re-target schema1 manifest with tag `%s` to tag `%s",
+ parsed.tag,
+ tag_name,
+ )
+ return None
+ except MalformedSchema1Manifest as msme:
+ logger.exception("Could not parse schema1 manifest")
+ if raise_on_error:
+ raise RetargetTagException(msme)
+ return None
+ legacy_image = get_legacy_image_for_manifest(manifest)
+ now_ms = now_ms or get_epoch_timestamp_ms()
+ now_ts = int(now_ms / 1000)
+ with db_transaction():
+ existing_tag = get_tag(manifest.repository_id, tag_name)
+ if existing_tag is not None:
+ _, okay = set_tag_end_ms(existing_tag, now_ms)
+ if not okay:
+ return None
+ created = Tag.create(
+ name=tag_name,
+ repository=manifest.repository_id,
+ lifetime_start_ms=now_ms,
+ reversion=is_reversion,
+ manifest=manifest,
+ tag_kind=Tag.tag_kind.get_id("tag"),
+ )
+ return created
+[INST] Deletes the given tag by marking it as expired.. [/INST] def _delete_tag(tag, now_ms):
+ now_ts = int(now_ms / 1000)
+ with db_transaction():
+ updated = (
+ Tag.update(lifetime_end_ms=now_ms)
+ .where(Tag.id == tag.id, Tag.lifetime_end_ms == tag.lifetime_end_ms)
+ .execute()
+ )
+ if updated != 1:
+ return None
+ try:
+ old_style_tag = (
+ TagToRepositoryTag.select(TagToRepositoryTag, RepositoryTag)
+ .join(RepositoryTag)
+ .where(TagToRepositoryTag.tag == tag)
+ .get()
+ ).repository_tag
+ old_style_tag.lifetime_end_ts = now_ts
+ old_style_tag.save()
+ except TagToRepositoryTag.DoesNotExist:
+ pass
+ return tag
+[INST] Deletes all tags pointing to the given manifest.
+Returns the list of tags deleted.. [/INST] def delete_tags_for_manifest(manifest):
+ query = Tag.select().where(Tag.manifest == manifest)
+ query = filter_to_alive_tags(query)
+ query = filter_to_visible_tags(query)
+ tags = list(query)
+ now_ms = get_epoch_timestamp_ms()
+ with db_transaction():
+ for tag in tags:
+ _delete_tag(tag, now_ms)
+ return tags
+[INST] Adjusts the specified Tag query to only return those tags alive.
+If now_ms is specified, the given timestamp (in MS) is used in place of the current timestamp
+for determining wherther a tag is alive.. [/INST] def filter_to_alive_tags(query, now_ms=None, model=Tag):
+ if now_ms is None:
+ now_ms = get_epoch_timestamp_ms()
+ return query.where((model.lifetime_end_ms >> None) | (model.lifetime_end_ms > now_ms)).where(
+ model.hidden == False
+ )
+[INST] Changes the expiration of the specified tag to the given expiration datetime.
+If the expiration datetime is None, then the tag is marked as not expiring. Returns a tuple of
+the previous expiration timestamp in seconds (if any), and whether the operation succeeded.. [/INST] def change_tag_expiration(tag_id, expiration_datetime):
+ try:
+ tag = Tag.get(id=tag_id)
+ except Tag.DoesNotExist:
+ return (None, False)
+ new_end_ms = None
+ min_expire_sec = convert_to_timedelta(config.app_config.get("LABELED_EXPIRATION_MINIMUM", "1h"))
+ max_expire_sec = convert_to_timedelta(
+ config.app_config.get("LABELED_EXPIRATION_MAXIMUM", "104w")
+ )
+ if expiration_datetime is not None:
+ lifetime_start_ts = int(tag.lifetime_start_ms / 1000)
+ offset = timegm(expiration_datetime.utctimetuple()) - lifetime_start_ts
+ offset = min(max(offset, min_expire_sec.total_seconds()), max_expire_sec.total_seconds())
+ new_end_ms = tag.lifetime_start_ms + (offset * 1000)
+ if new_end_ms == tag.lifetime_end_ms:
+ return (None, True)
+ return set_tag_end_ms(tag, new_end_ms)
+[INST] Yields any alive Tags found in repositories with an event with the given name registered and
+whose legacy Image has the given docker image ID and storage UUID.. [/INST] def lookup_notifiable_tags_for_legacy_image(docker_image_id, storage_uuid, event_name):
+ event = ExternalNotificationEvent.get(name=event_name)
+ images = (
+ Image.select()
+ .join(ImageStorage)
+ .where(Image.docker_image_id == docker_image_id, ImageStorage.uuid == storage_uuid)
+ )
+ for image in list(images):
+ try:
+ RepositoryNotification.get(repository=image.repository_id, event=event)
+ except RepositoryNotification.DoesNotExist:
+ continue
+ for tag in tags_containing_legacy_image(image):
+ yield tag
+[INST] Filters the query by ensuring the repositories returned have the given event.
+NOTE: This is for legacy support in the old security notification worker and should
+be removed once that code is no longer necessary.. [/INST] def _filter_has_repository_event(query, event):
+ return (
+ query.join(Repository)
+ .join(RepositoryNotification)
+ .where(RepositoryNotification.event == event)
+ )
+[INST] Filters the query by ensuring the tags live in a repository that has the given
+event. Also orders the results by lifetime_start_ms.
+
+This is for legacy support in the old security notification worker and should
+be removed once that code is no longer necessary.. [/INST] def filter_tags_have_repository_event(query, event):
+ query = _filter_has_repository_event(query, event)
+ query = query.switch(Tag).order_by(Tag.lifetime_start_ms.desc())
+ return query
+[INST] Returns a repository that has garbage (defined as an expired Tag that is past
+the repo's namespace's expiration window) or None if none.. [/INST] def find_repository_with_garbage(limit_to_gc_policy_s):
+ expiration_timestamp = get_epoch_timestamp_ms() - (limit_to_gc_policy_s * 1000)
+ try:
+ candidates = (
+ Tag.select(Tag.repository)
+ .join(Repository)
+ .join(Namespace, on=(Repository.namespace_user == Namespace.id))
+ .where(
+ ~(Tag.lifetime_end_ms >> None),
+ (Tag.lifetime_end_ms <= expiration_timestamp),
+ (Namespace.removed_tag_expiration_s == limit_to_gc_policy_s),
+ )
+ .limit(GC_CANDIDATE_COUNT)
+ .distinct()
+ .alias("candidates")
+ )
+ found = (
+ Tag.select(candidates.c.repository_id)
+ .from_(candidates)
+ .order_by(db_random_func())
+ .get()
+ )
+ if found is None:
+ return
+ return Repository.get(Repository.id == found.repository_id)
+ except Tag.DoesNotExist:
+ return None
+ except Repository.DoesNotExist:
+ return None
+[INST] Creates a new manifest builder for populating manifests under the specified repository and
+returns it.
+
+Returns None if the builder could not be constructed.. [/INST] def create_manifest_builder(repository_ref, storage, legacy_signing_key):
+ builder_id = str(uuid.uuid4())
+ builder = _ManifestBuilder(
+ repository_ref, _BuilderState(builder_id, {}, {}, {}, []), storage, legacy_signing_key
+ )
+ builder._save_to_session()
+ return builder
+[INST] Looks up the manifest builder with the given ID under the specified repository and returns it or
+None if none.. [/INST] def lookup_manifest_builder(repository_ref, builder_id, storage, legacy_signing_key):
+ builder_state_tuple = session.get(_SESSION_KEY)
+ if builder_state_tuple is None:
+ return None
+ builder_state = _BuilderState(*builder_state_tuple)
+ if builder_state.builder_id != builder_id:
+ return None
+ return _ManifestBuilder(repository_ref, builder_state, storage, legacy_signing_key)
+[INST] Returns the unique ID for this builder.. [/INST] def builder_id(self):
+ return self._builder_state.builder_id
+[INST] Returns the tags committed by this builder, if any.. [/INST] def committed_tags(self):
+ return [
+ registry_model.get_repo_tag(self._repository_ref, tag_name, include_legacy_image=True)
+ for tag_name in self._builder_state.tags.keys()
+ ]
+[INST] Starts a new layer with the given ID to be placed into a manifest.
+Returns the layer started or None if an error occurred.. [/INST] def start_layer(
+ self, layer_id, v1_metadata_string, location_name, calling_user, temp_tag_expiration
+ ):
+ repository = model.repository.lookup_repository(self._repository_ref._db_id)
+ if repository is None:
+ return None
+ namespace_name = repository.namespace_user.username
+ repo_name = repository.name
+ try:
+ v1_metadata = json.loads(v1_metadata_string)
+ except ValueError:
+ logger.exception(
+ "Exception when trying to parse V1 metadata JSON for layer %s", layer_id
+ )
+ return None
+ except TypeError:
+ logger.exception(
+ "Exception when trying to parse V1 metadata JSON for layer %s", layer_id
+ )
+ return None
+ if layer_id != v1_metadata["id"]:
+ return None
+ parent_id = v1_metadata.get("parent", None)
+ parent_image = None
+ if parent_id is not None:
+ parent_image = model.image.get_repo_image(namespace_name, repo_name, parent_id)
+ if parent_image is None:
+ return None
+ existing_image = registry_model.get_legacy_image(self._repository_ref, layer_id)
+ if existing_image is not None:
+ self._builder_state.images[layer_id] = existing_image.id
+ self._save_to_session()
+ return ManifestLayer(layer_id, v1_metadata_string, existing_image.id)
+ with db_transaction():
+ created = model.image.find_create_or_link_image(
+ layer_id, repository, calling_user, {}, location_name
+ )
+ temp_tag_name = model.tag.create_temporary_hidden_tag(
+ repository, created, temp_tag_expiration
+ )
+ if temp_tag_name is None:
+ return None
+ command_list = v1_metadata.get("container_config", {}).get("Cmd", None)
+ command = json.dumps(command_list) if command_list else None
+ model.image.set_image_metadata(
+ layer_id,
+ namespace_name,
+ repo_name,
+ v1_metadata.get("created"),
+ v1_metadata.get("comment"),
+ command,
+ v1_metadata_string,
+ parent=parent_image,
+ )
+ self._builder_state.images[layer_id] = created.id
+ self._save_to_session()
+ return ManifestLayer(layer_id, v1_metadata_string, created.id)
+[INST] Returns a layer with the given ID under this builder.
+If none exists, returns None.. [/INST] def lookup_layer(self, layer_id):
+ if layer_id not in self._builder_state.images:
+ return None
+ image = model.image.get_image_by_db_id(self._builder_state.images[layer_id])
+ if image is None:
+ return None
+ return ManifestLayer(layer_id, image.v1_json_metadata, image.id)
+[INST] Assigns a blob to a layer.. [/INST] def assign_layer_blob(self, layer, blob, computed_checksums):
+ assert blob
+ assert not blob.uploading
+ repo_image = model.image.get_image_by_db_id(layer.db_id)
+ if repo_image is None:
+ return None
+ with db_transaction():
+ existing_storage = repo_image.storage
+ repo_image.storage = blob._db_id
+ repo_image.save()
+ if existing_storage.uploading:
+ self._builder_state.temp_storages.append(existing_storage.id)
+ self._builder_state.checksums[layer.layer_id] = computed_checksums
+ self._save_to_session()
+ return True
+[INST] Returns whether the checksum for a layer matches that specified.. [/INST] def validate_layer_checksum(self, layer, checksum):
+ return checksum in self.get_layer_checksums(layer)
+[INST] Saves a precomputed checksum for a layer.. [/INST] def save_precomputed_checksum(self, layer, checksum):
+ checksums = self._builder_state.checksums.get(layer.layer_id) or []
+ checksums.append(checksum)
+ self._builder_state.checksums[layer.layer_id] = checksums
+ self._save_to_session()
+[INST] Commits a new tag + manifest for that tag to the repository with the given name, pointing to
+the given layer.. [/INST] def commit_tag_and_manifest(self, tag_name, layer):
+ legacy_image = registry_model.get_legacy_image(self._repository_ref, layer.layer_id)
+ if legacy_image is None:
+ return None
+ tag = registry_model.retarget_tag(
+ self._repository_ref, tag_name, legacy_image, self._storage, self._legacy_signing_key
+ )
+ if tag is None:
+ return None
+ self._builder_state.tags[tag_name] = tag._db_id
+ self._save_to_session()
+ return tag
+[INST] Marks the manifest builder as complete and disposes of any state.
+This call is optional and it is expected manifest builders will eventually time out if
+unused for an extended period of time.. [/INST] def done(self):
+ temp_storages = self._builder_state.temp_storages
+ for storage_id in temp_storages:
+ try:
+ storage = ImageStorage.get(id=storage_id)
+ if storage.uploading and storage.content_checksum != EMPTY_LAYER_BLOB_DIGEST:
+ ImageStoragePlacement.delete().where(
+ ImageStoragePlacement.storage == storage
+ ).execute()
+ storage.delete_instance()
+ except ImageStorage.DoesNotExist:
+ pass
+ session.pop(_SESSION_KEY, None)
+[INST] Returns the HTTP status code for the expected failure under the specified protocol version (1 or
+2).
+
+If none, returns the success status code.. [/INST] def _get_expected_code(expected_failure, version, success_status_code):
+ if not expected_failure:
+ return success_status_code
+ return expected_failure[version]
+[INST] Unless a specific tag is known, 'skopeo inspect' won't work.
+Here first 'latest' is checked and then the tag expression, split at commas, is each checked
+until one works.. [/INST] def tags(
+ self,
+ repository,
+ rule_value,
+ username=None,
+ password=None,
+ tls_verify=True,
+ proxy=None,
+ verbose_logs=False,
+ ):
+ args = ["/usr/bin/skopeo"]
+ if verbose_logs:
+ args = args + ["--debug"]
+ args = args + ["inspect", "--tls-verify=%s" % tls_verify]
+ args = args + self.external_registry_credentials("--creds", username, password)
+ if not rule_value:
+ rule_value = []
+ all_tags = []
+ for tag in rule_value + ["latest"]:
+ result = self.run_skopeo(args + [quote("%s:%s" % (repository, tag))], proxy)
+ if result.success:
+ all_tags = json.loads(result.stdout)["RepoTags"]
+ if all_tags is not []:
+ break
+ return SkopeoResults(result.success, all_tags, result.stdout, result.stderr)
+[INST] qid : list, default - qubit id's list for all of the qubits
+qubit id's list to reset.
+
+Notes
+
+If 'qid' is set, specified qubits are reset after
+measurement. So if the specified qubits are entangled with the
+remaining qubits, output quantum state is probabilistic. If no
+qubits are set, all qubits are zero reset.. [/INST] def reset(self, qid=[]):
+ qstate_reset(self, qid=qid)
+[INST] free memory of the all quantum states.
+Parameters
+
+qstates : instance of QState,instance of QState,
+set of QState instances
+
+Returns
+
+None. [/INST] def free_all(cls, *qstates):
+ warnings.warn("No need to call 'free_all' method because free automatically, or you can use class method 'del_all' to free memory explicitly.")
+[INST] free memory of the all quantum states.
+Parameters
+
+qstates : instance of QState,instance of QState,
+set of QState instances
+
+Returns
+
+None. [/INST] def del_all(cls, *qstates):
+ for qs in qstates:
+ if type(qs) is list or type(qs) is tuple:
+ cls.del_all(*qs)
+ elif type(qs) is QState:
+ del qs
+ else:
+ raise QState_Error_FreeAll()
+[INST] get the partial quantum state.
+Parameters
+
+qid : list of int, default - list of all of the qubit id
+qubit id's list to get as a partial quantum
+system.
+
+Returns
+
+qs : instance of QState
+partial quantum state.
+
+Notes
+
+If 'qid' is set, specified partial quantum system are got
+after remaining system are measured. So if the specified
+quantum system are entangled with the remaining system, output
+quantum state is probabilistic. If no qubits are specified,
+you get the copy of original quantum state.. [/INST] def partial(self, qid=None):
+ vec = self.get_amp(qid)
+ qs = QState(vector=vec)
+ return qs
+[INST] show the quantum state
+(elements of the state vector and probabilities).
+
+Parameters
+
+qid : list of int, default - list of all of the qubit id
+qubit id's list to show.
+nonzero : bool, default False
+if True, only non-zero amplitudes are printed.
+
+Returns
+
+None
+
+Notes
+
+If 'qid' is set, it shows the elements of quantum state vector
+for partial quantum system specified by 'qid'. If the
+specified quantum system are entangled with the remaining
+system, output is probabilistic. If no qubits are specified,
+it shows the whole quantum state. This method does not change
+the original quantum state.
+
+Examples. [/INST] def show(self, qid=None, nonzero=False):
+ qstate_print(self, qid, nonzero)
+[INST] get the inner product with quantum state.
+Parameters
+
+qstate : instance of QState
+one of the two quantum state.
+qid : list of int, default - list of all of the qubit id
+qubit id's list.
+
+Returns
+
+inp : complex
+inner produt ().
+
+Notes
+
+If 'qid' is set, you can get the inner product for partial
+quantum state. If the specified quantum system are entangled
+with the remaining system, output value is probabilistic,
+while original quantum states do not change.. [/INST] def inpro(self, qstate, qid=[]):
+ if qid == []:
+ inp = qstate_inner_product(self, qstate)
+ else:
+ qs_0 = self.partial(qid=qid)
+ qs_1 = qstate.partial(qid=qid)
+ inp = qstate_inner_product(qs_0, qs_1)
+ return inp
+[INST] get the fidelity with quantum state.
+Parameters
+
+qstate : instance of QState
+one of the two quantum state.
+qid : list of int
+qubit id's list.
+
+Returns
+
+fid : float
+fidelity of two quantum states. absolute value of the
+inner product of two quantum states.
+
+Notes
+
+If 'qid' is set, you can get the fidelity for partial quantum
+state. If the specified quantum system are entangled with the
+remaining system, output value is probabilistic, while
+original quantum states do not change.. [/INST] def fidelity(self, qstate, qid=[]):
+ return abs(self.inpro(qstate, qid=qid))
+[INST] operate unitary operator to quantum state.
+Parameters
+
+pp : instance of PauliProduct
+pauli product to operate
+ctrl : int
+contoroll qubit id for controlled pauli product
+
+Returns
+
+self : instance of QState
+quantum state after operation. [/INST] def operate(self, pp=None, ctrl=None):
+ pauli_list = pp.pauli_list
+ qid = pp.qid
+ if ctrl is None:
+ for q, pauli in zip(qid, pauli_list):
+ if pauli == 'X':
+ self.x(q)
+ elif pauli == 'Y':
+ self.y(q)
+ elif pauli == 'Z':
+ self.z(q)
+ else:
+ continue
+ else:
+ if ctrl in qid:
+ raise ValueError("controll and target qubit id conflict")
+ for q, pauli in zip(qid, pauli_list):
+ if pauli == 'X':
+ self.cx(ctrl, q)
+ elif pauli == 'Y':
+ self.cy(ctrl, q)
+ elif pauli == 'Z':
+ self.cz(ctrl, q)
+ else:
+ continue
+ return self
+[INST] free memory of the all stabilizer.
+Parameters
+
+stabs : instance of Stabilizer,instance of Stabilizer,
+set of Stabilizer instances
+
+Returns
+
+None. [/INST] def free_all(cls, *stabs):
+ warnings.warn("No need to call 'free_all' method because free automatically, or you can use class method 'del_all' to free memory explicitly.")
+[INST] free memory of the all stabilizer.
+Parameters
+
+stabs : instance of Stabilizer,instance of Stabilizer,
+set of Stabilizer instances
+
+Returns
+
+None. [/INST] def del_all(cls, *stabs):
+ for sb in stabs:
+ if type(sb) is list or type(sb) is tuple:
+ cls.del_all(*sb)
+ elif type(sb) is Stabilizer:
+ del sb
+ else:
+ raise Stabilizer_Error_FreeAll()
+[INST] linear sum of the density operators.
+Parameters
+
+densop : list of instances of DensOp
+densitiy operators.
+prob : list of float
+probabilities (coefficients of the linear sum).. [/INST] def mix(cls, densop=[], prob=[]):
+ N = len(densop)
+ if sum(prob) != 1.0:
+ s = sum(prob)
+ prob = [p/s for p in prob]
+ de_out = densop[0].clone()
+ de_out.mul(factor=prob[0])
+ for i in range(1,len(densop)):
+ de_tmp = densop[i].clone()
+ de_tmp.mul(factor=prob[i])
+ de_out.add(densop=de_tmp)
+ return de_out
+[INST] free memory of the all density operators.
+Parameters
+
+densops : instance of DensOp,instance of DensOp,
+set of DensOp instances
+
+Returns
+
+None. [/INST] def free_all(cls, *densops):
+ warnings.warn("No need to call 'free_all' method because free automatically, or you can use class method 'del_all' to free memory explicitly.")
+[INST] free memory of the all density operators.
+Parameters
+
+densops : instance of DensOp,instance of DensOp,
+set of DensOp instances
+
+Returns
+
+None. [/INST] def del_all(cls, *densops):
+ for de in densops:
+ if type(de) is list or type(de) is tuple:
+ cls.del_all(*de)
+ elif type(de) is DensOp:
+ del de
+ else:
+ raise DensOp_Error_FreeAll()
+[INST] get the copy of density operator.
+Parameters
+
+None
+
+Returns
+
+densop : instance of DensOp
+copy of the original density operator.. [/INST] def clone(self):
+ obj = densop_copy(self)
+ de = ctypes.cast(obj.value, ctypes.POINTER(self.__class__)).contents
+ return de
+[INST] get the partial trace of density operator.
+Parameters
+
+qid : list of int
+qubit id's list to show.
+
+Returns
+
+densop : instance of DensOp
+density operator after partial trace.. [/INST] def patrace(self, qid=[]):
+ obj = densop_patrace(self, qid=qid)
+ de = ctypes.cast(obj.value, ctypes.POINTER(self.__class__)).contents
+ return de
+[INST] get the tensor product with density operator.
+Parameters
+
+densop : instance of DensOp
+density operator to get the tensor product
+
+Returns
+
+densop_out : instance of DensOp
+tensor produt of 'self' and 'densop'.. [/INST] def tenspro(self, densop):
+ obj = densop_tensor_product(self, densop)
+ de = ctypes.cast(obj.value, ctypes.POINTER(self.__class__)).contents
+ return de
+[INST] get the composite density operator of same density operators.
+Parameters
+
+num : int
+number of density operators
+
+Returns
+
+de : instance of DensOp
+composite density operator.. [/INST] def composite(self, num=0):
+ if num <= 1:
+ return self
+ else:
+ de = self.clone()
+ for i in range(num-1):
+ de_tmp = de.tenspro(self)
+ de = de_tmp.clone()
+ return de
+[INST] get the expectation value of matrix under this density operator.
+Parameters
+
+matrix : list of list of complex
+matrix expression of hermitian operator.
+
+Returns
+
+value : float
+expectation value.
+
+Notes
+
+'matrix' must be hermitian, and its dimension is equal to
+the dimension of density operator.. [/INST] def expect(self, matrix=None):
+ densop = self.clone()
+ densop_apply_matrix(densop, matrix=matrix, dire='left')
+ value = densop.trace()
+ return value
+[INST] get the probabilities for measuring operators.
+(Kraus or POVM operators).
+
+Parameters
+
+kraus : list of list of comprex
+Kraus operators.
+povm : list of list of comprex
+POVM operators.
+qid : list
+qubit id's list to measure.
+
+Returns
+
+prob : list of float
+probabilities for measuring operators.
+
+Notes
+
+Either 'kraus' or 'povm' must be set. If 'qid' is not set, all
+of the qubits are measured, the dimention of Kraus or POVM
+operator must be equal to the dimension of density
+operator. If 'qid' is set, the part of qubits are measured,
+the dimension of Kraus or POVM operator must be equal to the 2
+power of 'qid' length. This method does not change the
+original density operator.. [/INST] def probability(self, kraus=[], povm=[], qid=[]):
+ if kraus != []:
+ N = len(kraus)
+ prob = [0.0]*N
+ for i in range(N):
+ prob[i] = densop_probability(self, matrix=kraus[i], qid=qid,
+ matrix_type='kraus')
+ if abs(prob[i]) < EPS:
+ prob[i] = 0.0
+ elif povm != []:
+ N = len(povm)
+ prob = [0.0]*N
+ for i in range(N):
+ prob[i] = densop_probability(self, matrix=povm[i], qid=qid,
+ matrix_type='povm')
+ if abs(prob[i]) < EPS:
+ prob[i] = 0.0
+ else:
+ raise DensOp_Error_Probability()
+ return prob
+[INST] get the spectrum.
+Parameters
+
+None
+
+Returns
+
+qstate : list of QState
+list of the quantum state basis.
+prob : list of float
+list of coefficients for each quantum states basis.. [/INST] def spectrum(self):
+ mat = self.get_elm()
+ eigvals,eigvecs = self.__mat_spectrum(mat)
+ prob = [eigvals[i] for i in range(len(eigvals)) if abs(eigvals[i]) > EPS]
+ vecs = [eigvecs[i] for i in range(len(eigvals)) if abs(eigvals[i]) > EPS]
+ qstate = [QState(vector=vecs[i]) for i in range(len(prob))]
+ return qstate,prob
+[INST] free memory of the all quantum computers.
+Parameters
+
+qcomps : instance of QComp,instance of QComp,
+set of QComp instances
+
+Returns
+
+None. [/INST] def free_all(cls, *qcomps):
+ warnings.warn("No need to call 'free_all' method because free automatically, or you can use class method 'del_all' to free memory explicitly.")
+[INST] free memory of the all quantum computers.
+Parameters
+
+qcomps : instance of QComp,instance of QComp,
+set of QComp instances
+
+Returns
+
+None. [/INST] def del_all(cls, *qcomps):
+ for qc in qcomps:
+ if type(qc) is list or type(qc) is tuple:
+ cls.del_all(*qc)
+ elif type(qc) is QComp:
+ del qc
+ else:
+ raise QComp_Error_FreeAll()
+[INST] last measured value (binary string). [/INST] def last(self):
+ mval = self.measured_value(angle=self.angle, phase=self.phase)
+ digits = len(self.qid)
+ return '{:0{digits}b}'.format(mval, digits=digits)
+[INST] Flush modified and created object to the database before clearing the
+cache.. [/INST] def save_changes(self, session=None, cursor=None):
+ self.log_memory_usage()
+ if session is None:
+ session = self.session
+ if cursor is None and self._use_copy:
+ cursor = self.session.using_bind(
+ self._engine_name).connection().connection.cursor()
+ self._log('debug', 'Flushing model cache.')
+ self._assign_sequences()
+ self._sync_relationship_attrs()
+ for model_name in self._save_order:
+ if model_name not in self._cached_instances:
+ continue
+ objects = self._cached_instances[model_name]
+ objects = self._filter_sa_result_objects(objects)
+ if len(objects) == 0:
+ continue
+ self._deregister_change_handler(
+ type(objects[0]), self._instance_change_handler)
+ new_objects, updated_objects = [], []
+ for object in objects:
+ if attributes.instance_state(object).key is None:
+ new_objects.append(object)
+ else:
+ updated_objects.append(object)
+ if self._use_copy:
+ self._save_by_copy(cursor, new_objects)
+ else:
+ session.bulk_save_objects(new_objects)
+ session.bulk_save_objects(updated_objects)
+ session.flush()
+ if self._use_copy:
+ cursor.connection.commit()
+ self.clear(session)
+ self._log('info', 'Flushed model cache.')
+[INST] Insert objects by using PostgreSQL's COPY command. This is
+database-specific but one of the most efficient ways to populate a
+table. Expects instances of one single model per call.. [/INST] def _save_by_copy(self, cursor, objects):
+ if len(objects) == 0:
+ return
+ model = inspect(objects[0]).mapper
+ for table in model.tables:
+ file = io.StringIO()
+ writer = csv.DictWriter(file, table.columns.keys())
+ columns = [
+ c for c in model.c.items()
+ if c[1].table == table or c[1].primary_key
+ ]
+ for object in objects:
+ row = {}
+ for attr, column in columns:
+ value = getattr(object, attr)
+ if value is None:
+ if column.default is not None:
+ value = column.default.arg
+ else:
+ value = r'\N'
+ elif type(value) != column.type.python_type:
+ value = column.type.python_type(value)
+ row[column.key] = value
+ writer.writerow(row)
+ file.seek(0)
+ columns_string = ','.join(
+ [f'"{column}"' for column in table.columns.keys()]
+ )
+ cursor.copy_expert(
+ f'COPY {table} ({columns_string}) '
+ "FROM STDIN WITH CSV DELIMITER ',' NULL '\\N'",
+ file,
+ )
+ cursor.connection.commit()
+[INST] Clear the cache. Will result in data loss of unflushed objects.. [/INST] def clear(self, session=None):
+ self._cached_instances.clear()
+ self._indices.clear()
+ gc.collect()
+ self.log_memory_usage()
+[INST] Return a list of every existing and newly created instance of `model`.. [/INST] def _get_model_cache(self, model):
+ model_key = self._model_key(model)
+ if model_key not in self._cached_instances:
+ if model_key in self._preload_models_data:
+ query = model.query(*self._preload_models_data[model_key])
+ else:
+ query = model.query()
+ if model_key in self._preload_models_filter:
+ query = query.filter(self._preload_models_filter[model_key])
+ if self._prefetch is not None and model_key in self._prefetch:
+ query = query.options(
+ sqlalchemy.orm.joinedload(
+ *[attr for attr in self._prefetch[model_key]]
+ )
+ )
+ if self._preload_models:
+ self._cached_instances[model_key] = query.all()
+ else:
+ self._cached_instances[model_key] = query.limit(0).all()
+ self._register_change_handler(model, self._instance_change_handler)
+ return self._cached_instances[model_key]
+[INST] Return a dictionary containing tuples of indexed attributes as keys.. [/INST] def _get_model_indices(self, model):
+ model_key = self._model_key(model)
+ if model_key not in self._indices:
+ self._indices[model_key] = {}
+ return self._indices[model_key]
+[INST] Return a dictionary containing model attribute values as keys and a
+list of matching instances as values.. [/INST] def _get_attribute_index(self, model, attributes):
+ model_indices = self._get_model_indices(model)
+ attribute_key = tuple(sorted(attributes))
+ if attribute_key not in model_indices:
+ model_cache = self._get_model_cache(model)
+ indexed_instances = {}
+ for instance in model_cache:
+ instance_key = self._object_instance_key(
+ instance, attribute_key
+ )
+ cache = indexed_instances.setdefault(instance_key, [])
+ cache.append(instance)
+ model_indices[attribute_key] = indexed_instances
+ return model_indices[attribute_key]
+[INST] Return the key used to reference a given `model` in the instance cache
+and indices.. [/INST] def _model_key(self, model):
+ return model.__name__
+[INST] Assign sequence values to empty model attributes specified in the
+constructor.
+
+This iterates over every cached model and its corresponding sequences
+and assigns sequence values where applicable
+Empty sequence attributes are counted and matching sequence values are
+fetched from the database in a single request.. [/INST] def _assign_sequences(self):
+ for model_name, objects in self._cached_instances.items():
+ if model_name not in self._sequences:
+ continue
+ model_sequences = self._sequences[model_name]
+ for sequence in model_sequences:
+ new_objects = list(
+ filter(lambda o: getattr(o, sequence[0]) is None, objects)
+ )
+ if len(new_objects) > 0:
+ sequence_values = self.session.using_bind(
+ self._engine_name
+ ).execute(
+ "select nextval('%s') from generate_series(1,%s)"
+ % (sequence[1], len(new_objects))
+ )
+ for object, value in zip(new_objects, sequence_values):
+ setattr(object, sequence[0], value[0])
+[INST] Synchronize relationship attributes with their corresponding ID
+attributes.
+
+When using SQLAlchemy bulk functionality, the ORM behavior is changed
+in multiple ways. One important change is that relationship attributes
+aren't evaluated anymore.
+To still handle related object as expected, we iterate over
+relationship attributes and set their corresponding ID attributes on
+the same object.. [/INST] def _sync_relationship_attrs(self):
+ for objects in self._cached_instances.values():
+ objects = self._filter_sa_result_objects(objects)
+ if len(objects) == 0:
+ continue
+ model = type(list(objects)[0])
+ rel_map = [
+ (rel.key, pair[1].key, pair[0].key)
+ for rel in inspect(model).relationships
+ for pair in rel.local_remote_pairs
+ if rel.direction is not sqlalchemy.orm.interfaces.ONETOMANY
+ ]
+ for object in objects:
+ for from_obj, from_attr, to_attr in rel_map:
+ if getattr(object, to_attr, None) is None:
+ setattr(
+ object,
+ to_attr,
+ getattr(
+ getattr(object, from_obj), from_attr, None
+ ),
+ )
+[INST] Re-index model instances that were changed
+(to keep the index up-to-date).
+Called by SQLAlchemy's model `set` event which fires when a model
+attribute was changed.. [/INST] def _instance_change_handler(self, instance, value, oldvalue, initiator):
+ if oldvalue is sqlalchemy.util.symbol('NEVER_SET'):
+ return
+ model = type(instance)
+ changed_attr = initiator.key
+ model_indices = self._get_model_indices(model)
+ for attribute_key in model_indices.keys():
+ if changed_attr not in attribute_key:
+ continue
+ old_instance_key = self._object_instance_key(
+ instance, attribute_key, replace={changed_attr: oldvalue}
+ )
+ new_instance_key = self._object_instance_key(
+ instance, attribute_key, replace={changed_attr: value}
+ )
+ old_cache = model_indices[attribute_key].get(old_instance_key, [])
+ if len(old_cache) > 1:
+ old_cache.remove(instance)
+ else:
+ model_indices[attribute_key].pop(old_instance_key, [])
+ cache = model_indices[attribute_key].setdefault(
+ new_instance_key, []
+ )
+ if instance not in cache:
+ cache.append(instance)
+[INST] Create a persisted example object in the database.. [/INST] def example_model(test_model_factory):
+ model = test_model_factory('db1')
+ db = risclog.sqlalchemy.db.get_database(testing=True)
+ db.create_all('db1')
+ model.persist()
+ transaction.commit()
+ return model
+[INST] Somewhat misleading name. This method is used to determine if we should
+(a) be running training AFTER environment data collection, and (b)
+evaluating with test-time rollouts periodically. For (a) the reason was
+the --disp option, which is needed to see cloth, will not let us run
+multiple Environment calls. This also applies to (b), but for (b) there's
+an extra aspect: most of these new environments have reward functions
+that are not easily interpretable, or they have extra stuff in `info`
+that would be better for us to use. In that case we should be using
+`load.py` to roll out these policies.
+
+Actually, even for stuff like cable-shape, where the reward really tells
+us all we need to know, it would be nice to understand failure cases
+based on the type of the target. So for now let's set it to record
+`cable-` which will ignore `cable` but catch all my custom environments.
+The custom cable environments don't use --disp but let's just have them
+here for consistency in the main-then-load paradigm.. [/INST] def has_deformables(task):
+ return ('cable-' in task) or ('cloth' in task) or ('bag' in task)
+[INST] Be careful with checking this condition.
+
+all right, let's modify this to incorpoate gt_state w/out too much
+extra work.. [/INST] def is_goal_conditioned(args):
+ goal_tasks = ['insertion-goal', 'cable-shape-notarget', 'cable-line-notarget',
+ 'cloth-flat-notarget', 'bag-color-goal']
+ goal_task = (args.task in goal_tasks)
+ if goal_task:
+ assert 'goal' in args.agent or 'gt_state' in args.agent, \
+ 'Agent should be a goal-based agent, or gt_state agent.'
+ return goal_task
+[INST] In some cases, we should filter out demonstrations.
+Filter for if t == 0, which means the initial state was a success.
+Also, for the bag envs, if we end up in a catastrophic state, I exit
+gracefully and we should avoid those demos (they won't have images we
+need for the dataset anyway).. [/INST] def ignore_this_demo(args, demo_reward, t, last_extras):
+ ignore = (t == 0)
+ if 'exit_gracefully' in last_extras:
+ assert last_extras['exit_gracefully']
+ return True
+ if (args.task in ['bag-color-goal']) and demo_reward <= 0.5:
+ return True
+ We can get 0.5 reward by touching the cube only (bad).
+ if args.task == 'bag-items-easy' and demo_reward <= 0.5:
+ return True
+ if args.task == 'bag-items-hard':
+ return (last_extras['zone_items_rew'] < 0.5 or
+ last_extras['zone_beads_rew'] == 0)
+ return ignore
+[INST] Prediction as the mean of the most-weighted gaussian.. [/INST] def pick_max_mean(pi, mu, var):
+ mu = tf.reshape(mu, (tf.shape(mu)[0], tf.shape(pi)[1], -1))
+ d_out = tf.shape(mu)[-1]
+ batch_size, k = pi.shape
+ prediction = np.zeros((batch_size, d_out))
+ argmax_pi = tf.argmax(pi, axis=1)
+ for i in range(batch_size):
+ ith_argmax_pi = argmax_pi[i].numpy()
+ prediction[i] = mu[i, ith_argmax_pi]
+ return prediction
+[INST] Prediction as a sample from the gaussian mixture.. [/INST] def sample_from_pdf(pi, mu, var, num_samples=1):
+ pi, mu, var = pi.numpy(), mu.numpy(), var.numpy()
+ var = var**4
+ pi = pi * (1/pi.sum(1)[..., None])
+ batch_size, k = pi.shape
+ mu = tf.reshape(mu, (tf.shape(mu)[0], tf.shape(pi)[1], -1))
+ d_out = tf.shape(mu)[-1]
+ samples = np.zeros((batch_size, num_samples, d_out))
+ for i in range(batch_size):
+ for j in range(num_samples):
+ idx = np.random.choice(range(k), p=pi[i])
+ draw = np.random.normal(mu[i, idx], np.sqrt(var[i, idx]))
+ samples[i,j] = draw
+ return samples
+[INST] Train on dataset for a specific number of iterations.. [/INST] def train(self, dataset, num_iter, writer):
+ for i in range(num_iter):
+ obs, act, info = dataset.random_sample()
+ configs = act['camera_config']
+ colormap, heightmap = self.get_heightmap(obs, configs)
+ loss = 0.
+ print(f'Train Iter: {self.total_iter + i} Loss: {loss:.4f}')
+ self.total_iter += num_iter
+ self.save()
+[INST] Run inference and return best action given visual observations.. [/INST] def act(self, obs, info):
+ act = {'camera_config': self.camera_config, 'primitive': None}
+ if not obs:
+ return act
+ colormap, heightmap = self.get_heightmap(obs, self.camera_config)
+ p0_position = (self.bounds[:, 1] - self.bounds[:, 0]) / 2
+ p0_position += self.bounds[:, 0]
+ p1_position = p0_position
+ rotation = p.getQuaternionFromEuler((0, 0, 0))
+ act['primitive'] = 'pick_place'
+ if self.task == 'sweeping':
+ act['primitive'] = 'sweep'
+ elif self.task == 'pushing':
+ act['primitive'] = 'push'
+ params = {'pose0': (p0_position, rotation),
+ 'pose1': (p1_position, rotation)}
+ act['params'] = params
+ return act
+[INST] Train on dataset for a specific number of iterations.
+Daniel: notice how little training data we use. One 'iteration' is
+simply one image and an associated action, drawn by (a) sampling
+demo, then (b) sampling time within it. We do heavy data
+augmentation, but it's still just one real image.
+
+If using a goal image, we use a different random_sample method that
+also picks the LAST image of that episode, which is assigned as the
+goal image. This would likely not work for super long-horizon tasks,
+but maybe; (Agarwal et al., NeurIPS 2016) in the PokeBot paper
+actually got something like this 'greedy-style' planning to work.
+Otherwise we might have to do something like (Nair et al., ICRA 2017)
+in the follow-up work where we feed in a target image for each time
+step, which would be the *next* image saved.
+
+For data augmentation with this goal image, I believe we should stack
+the current and goal image together, and THEN do augmentation. The
+perturb method will make sure placing pixels are preserved -- which
+for short-horizon environments usually means the goal image will
+contain most of the relevant information. When data augmenting, for
+both normal and goal-conditioned Transporters, the p1_theta
+(rotation) is the same, but pick points are correctly 'converted' to
+those appropriate for the augmented images.. [/INST] def train(self, dataset, num_iter, writer):
+ for i in range(num_iter):
+ if self.use_goal_image:
+ obs, act, info, goal = dataset.random_sample(goal_images=True)
+ else:
+ obs, act, info = dataset.random_sample()
+ configs = act['camera_config']
+ colormap, heightmap = self.get_heightmap(obs, configs)
+ if self.use_goal_image:
+ colormap_g, heightmap_g = self.get_heightmap(goal, configs)
+ pose0, pose1 = act['params']['pose0'], act['params']['pose1']
+ p0_position, p0_rotation = pose0[0], pose0[1]
+ p0 = utils.position_to_pixel(p0_position, self.bounds, self.pixel_size)
+ p0_theta = -np.float32(p.getEulerFromQuaternion(p0_rotation)[2])
+ p1_position, p1_rotation = pose1[0], pose1[1]
+ p1 = utils.position_to_pixel(p1_position, self.bounds, self.pixel_size)
+ p1_theta = -np.float32(p.getEulerFromQuaternion(p1_rotation)[2])
+ p1_theta = p1_theta - p0_theta
+ p0_theta = 0
+ input_image = self.concatenate_c_h(colormap, heightmap)
+ if self.use_goal_image:
+ goal_image = self.concatenate_c_h(colormap_g, heightmap_g)
+ input_image = np.concatenate((input_image, goal_image), axis=2)
+ assert input_image.shape[2] == 12, input_image.shape
+ original_pixels = (p0, p1)
+ input_image, pixels = utils.perturb(input_image, [p0, p1])
+ p0, p1 = pixels
+ if False:
+ self.visualize_images(p0, p0_theta, p1, p1_theta, original_pixels,
+ colormap=colormap, heightmap=heightmap,
+ colormap_g=colormap_g, heightmap_g=heightmap_g,
+ input_image=input_image, before_aug=False)
+ if self.attn_no_targ and self.use_goal_image:
+ maxdim = int(input_image.shape[2] / 2)
+ input_only = input_image[:, :, :maxdim]
+ loss0 = self.attention_model.train(input_only, p0, p0_theta)
+ else:
+ loss0 = self.attention_model.train(input_image, p0, p0_theta)
+ with writer.as_default():
+ tf.summary.scalar('attention_loss', self.attention_model.metric.result(),
+ step=self.total_iter+i)
+ if isinstance(self.transport_model, Attention):
+ loss1 = self.transport_model.train(input_image, p1, p1_theta)
+ elif isinstance(self.transport_model, TransportGoal):
+ half = int(input_image.shape[2] / 2)
+ img_curr = input_image[:, :, :half]
+ img_goal = input_image[:, :, half:]
+ loss1 = self.transport_model.train(img_curr, img_goal, p0, p1, p1_theta)
+ else:
+ loss1 = self.transport_model.train(input_image, p0, p1, p1_theta)
+ with writer.as_default():
+ tf.summary.scalar('transport_loss', self.transport_model.metric.result(),
+ step=self.total_iter+i)
+ print(f'Train Iter: {self.total_iter + i} Loss: {loss0:.4f} {loss1:.4f}')
+ self.total_iter += num_iter
+ self.save()
+[INST] Run inference and return best action given visual observations.
+If goal-conditioning, provide `goal`.. [/INST] def act(self, obs, info, debug_imgs=False, goal=None):
+ act = {'camera_config': self.camera_config, 'primitive': None}
+ if not obs:
+ return act
+ colormap, heightmap = self.get_heightmap(obs, self.camera_config)
+ if goal is not None:
+ colormap_g, heightmap_g = self.get_heightmap(goal, self.camera_config)
+ input_image = self.concatenate_c_h(colormap, heightmap)
+ if self.use_goal_image:
+ goal_image = self.concatenate_c_h(colormap_g, heightmap_g)
+ input_image = np.concatenate((input_image, goal_image), axis=2)
+ assert input_image.shape[2] == 12, input_image.shape
+ if self.attn_no_targ and self.use_goal_image:
+ maxdim = int(input_image.shape[2] / 2)
+ input_only = input_image[:, :, :maxdim]
+ attention = self.attention_model.forward(input_only)
+ else:
+ attention = self.attention_model.forward(input_image)
+ argmax = np.argmax(attention)
+ argmax = np.unravel_index(argmax, shape=attention.shape)
+ p0_pixel = argmax[:2]
+ p0_theta = argmax[2] * (2 * np.pi / attention.shape[2])
+ if isinstance(self.transport_model, TransportGoal):
+ half = int(input_image.shape[2] / 2)
+ img_curr = input_image[:, :, :half]
+ img_goal = input_image[:, :, half:]
+ transport = self.transport_model.forward(img_curr, img_goal, p0_pixel)
+ else:
+ transport = self.transport_model.forward(input_image, p0_pixel)
+ argmax = np.argmax(transport)
+ argmax = np.unravel_index(argmax, shape=transport.shape)
+ p1_pixel = argmax[:2]
+ p1_theta = argmax[2] * (2 * np.pi / transport.shape[2])
+ p0_position = utils.pixel_to_position(p0_pixel, heightmap, self.bounds, self.pixel_size)
+ p1_position = utils.pixel_to_position(p1_pixel, heightmap, self.bounds, self.pixel_size)
+ p0_rotation = p.getQuaternionFromEuler((0, 0, -p0_theta))
+ p1_rotation = p.getQuaternionFromEuler((0, 0, -p1_theta))
+ act['primitive'] = 'pick_place'
+ if self.task == 'sweeping':
+ act['primitive'] = 'sweep'
+ elif self.task == 'pushing':
+ act['primitive'] = 'push'
+ params = {'pose0': (p0_position, p0_rotation),
+ 'pose1': (p1_position, p1_rotation)}
+ act['params'] = params
+ if self.task in ['bag-items-easy', 'bag-items-hard', 'bag-color-goal']:
+ self._determine_task_stage(p0_pixel, p1_pixel)
+ if debug_imgs:
+ a_heat = self.attention_model.get_attention_heatmap(attention)
+ t_heat = self.transport_model.get_transport_heatmap(transport)
+ extras = {
+ 'input_c': cv2.cvtColor(colormap, cv2.COLOR_RGB2BGR),
+ 'attn_heat_bgr': a_heat,
+ 'tran_heat_bgr': t_heat,
+ 'tran_rot_argmax': argmax[2],
+ 'tran_p1_theta': p1_theta,
+ }
+ return act, extras
+ else:
+ return act
+[INST] Concatenates color and height images to get a 6D image.. [/INST] def concatenate_c_h(self, colormap, heightmap):
+ img = np.concatenate((colormap,
+ heightmap[..., None],
+ heightmap[..., None],
+ heightmap[..., None]), axis=2)
+ assert img.shape == self.input_shape, img.shape
+ return img
+[INST] Load pre-trained models.. [/INST] def load(self, num_iter):
+ attention_fname = 'attention-ckpt-%d.h5' % num_iter
+ transport_fname = 'transport-ckpt-%d.h5' % num_iter
+ attention_fname = os.path.join(self.models_dir, attention_fname)
+ transport_fname = os.path.join(self.models_dir, transport_fname)
+ self.attention_model.load(attention_fname)
+ self.transport_model.load(transport_fname)
+ self.total_iter = num_iter
+[INST] Determines task stage for the bag-items tasks.
+Hacky solution, unfortunately, assumes we assigned task.env. Assumes that we
+have an actual `self.real_task` we can use; `self.task` is just a string.
+Currently working reasonably well for bag-items-easy. Note: see gt_state.py
+for the version that works for the gt_state baselines. Also, note that the
+self.real_task.task_stage is reset to 1 whenever we reset() so there's no
+need to deal with that logic here.
+
+(0) So far this is working reasonably well. Note that:
+if object_mask[p0_pixel] == 38, then we're picking a cube.
+for bag-items-hard, ID 39 is also a block (note: no cubes).
+and if we end up putting both items into the bag, BUT we STILL end up
+with object_mask[p0_pixel] as gripping one of those items, we DO NOT go
+to task stage 3, which makes sense and is correct.
+
+(1) However, it will not be able to catch this failure:
+putting two items in the bag.
+but picking one item out of the bag, and inserting it to the target
+then properly doing 'task stage 3' by pulling and inserting the bag
+with ONE item into the zone. However, the first item that was pulled out
+may be in the zone, leading to an 'undesirable' success.
+
+(2) Another potential failure that I observed:
+FIRST action: sometimes the robot grips an item at the start, which is
+the right thing to do if the bag is already open. But we may grip the block
+EVEN THOUGH the pixel may not correspond to the block in the the segmentation
+mask. This happens when the pick point is just outside the block's pixels.
+Then this will result in a very slow pick and place since we follow task
+stage 1 parameters with slow movements and often get movej time outs if the
+block started far from the bag. In that case we can argue that we should have
+had a better pick point so that we correctly detect that we should move onto
+task stage 2. A better solution would be to check for any contact points
+after we grip, and test for contact with block IDs.
+I ended up increasing the task stage 1 speed from 0.003 --> 0.004.
+
+(3) Of course, there is still a failure of picking a vertex instead of a bag bead
+during the bag pulling stage. AH, oh well.. [/INST] def _determine_task_stage(self, p0_pixel, p1_pixel):
+ real_task = self.real_task
+ colormap, heightmap, object_mask = real_task.get_object_masks(real_task.env)
+ if False:
+ nb = len([x for x in os.listdir('.') if '.png' in x])
+ mask = np.array(object_mask / np.max(object_mask) * 255).astype(np.uint8)
+ mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
+ p0 = (p0_pixel[1], p0_pixel[0])
+ p1 = (p1_pixel[1], p1_pixel[0])
+ cv2.circle(mask, p0, radius=3, color=(255,0,255), thickness=-1)
+ cv2.circle(mask, p1, radius=3, color=(255,255,0), thickness=-1)
+ cv2.imwrite(f'mask_{nb}.png', mask)
+ if self.task in ['bag-items-easy', 'bag-items-hard']:
+ if object_mask[p0_pixel] in [38, 39]:
+ real_task.task_stage = 2
+ elif real_task.task_stage == 2:
+ real_task.task_stage = 3
+ elif self.task in ['bag-color-goal']:
+ if object_mask[p0_pixel] == real_task.single_block_ID:
+ real_task.task_stage = 2
+ else:
+ real_task.task_stage = 1
+ else:
+ raise NotImplementedError(self.task)
+[INST] code to debug and visualuze the image (including perturbed).
+The height maps will not be grayscale because of matplotlib's color
+scheme, I think. Using cv2.imwrite(..., heightmap) shows grayscale.. [/INST] def visualize_images(self, p0, p0_theta, p1, p1_theta, original_pixels,
+ colormap, heightmap, colormap_g, heightmap_g, input_image, before_aug):
+ print(f'\nForward pass.')
+ p0_theta_d = (180 / np.pi) * p0_theta
+ p1_theta_d = (180 / np.pi) * p1_theta
+ heightmap = heightmap / np.max(heightmap) * 255
+ heightmap_g = heightmap_g / np.max(heightmap_g) * 255
+ plt.subplots(1, 6, figsize=(12,4))
+ plt.subplot(1,6,1)
+ plt.imshow(colormap)
+ plt.subplot(1,6,2)
+ plt.imshow(heightmap)
+ plt.subplot(1,6,3)
+ plt.imshow(colormap_g)
+ plt.subplot(1,6,4)
+ plt.imshow(heightmap_g)
+ plt.subplot(1,6,5)
+ plt.imshow(np.array(input_image[:,:,0:3]).astype(np.uint8))
+ plt.subplot(1,6,6)
+ plt.imshow(np.array(input_image[:,:,6:9]).astype(np.uint8))
+ op_0, op_1 = original_pixels
+ title = f'Before Data Aug: ' \
+ f'Pick: ({op_0}, {p0_theta:0.2f}) ' \
+ f'Place: ({op_1}, {p1_theta:0.2f}={p1_theta_d:0.2f})\n' \
+ f'After Data Aug: ' \
+ f'Pick: ({p0}) ' \
+ f'Place: ({p1})'
+ plt.suptitle(title, fontsize=15)
+ plt.tight_layout()
+ plt.show()
+[INST] Initialize models, including normalization parameters.. [/INST] def init_model(self, dataset):
+ self.set_max_obs_vector_length(dataset)
+ if self.goal_conditioned:
+ _, _, info, goal = dataset.random_sample(goal_images=True)
+ obs_vector = self.info_to_gt_obs(info, goal=goal)
+ else:
+ _, _, info = dataset.random_sample()
+ obs_vector = self.info_to_gt_obs(info)
+ obs_dim = obs_vector.shape[0]
+ act_dim = 3
+ self.pick_model = MlpModel(self.BATCH_SIZE, obs_dim, act_dim, 'relu', self.USE_MDN, dropout=0.1)
+ sampled_gt_obs = []
+ num_samples = 1000
+ for _ in range(num_samples):
+ t_worldaug_world, _ = self.get_augmentation_transform()
+ if self.goal_conditioned:
+ _, _, info, goal = dataset.random_sample(goal_images=True)
+ sampled_gt_obs.append(self.info_to_gt_obs(info, t_worldaug_world, goal=goal))
+ else:
+ _, _, info = dataset.random_sample()
+ sampled_gt_obs.append(self.info_to_gt_obs(info, t_worldaug_world))
+ sampled_gt_obs = np.array(sampled_gt_obs)
+ obs_train_parameters = dict()
+ obs_train_parameters['mean'] = sampled_gt_obs.mean(axis=(0)).astype(np.float32)
+ obs_train_parameters['std'] = sampled_gt_obs.std(axis=(0)).astype(np.float32)
+ self.pick_model.set_normalization_parameters(obs_train_parameters)
+ self.pick_obs_train_mean = obs_train_parameters['mean']
+ self.pick_obs_train_std = obs_train_parameters['std']
+ obs_dim = obs_vector.shape[0] + act_dim
+ act_dim = 3
+ self.place_model = MlpModel(self.BATCH_SIZE, obs_dim, act_dim, 'relu', self.USE_MDN, dropout=0.1)
+ sampled_gt_obs = []
+ num_samples = 1000
+ for _ in range(num_samples):
+ t_worldaug_world, _ = self.get_augmentation_transform()
+ if self.goal_conditioned:
+ _, act, info, goal = dataset.random_sample(goal_images=True)
+ obs = self.info_to_gt_obs(info, t_worldaug_world, goal=goal)
+ else:
+ _, act, info = dataset.random_sample()
+ obs = self.info_to_gt_obs(info, t_worldaug_world)
+ obs = np.hstack((obs, self.act_to_gt_act(act, t_worldaug_world)[:3]))
+ sampled_gt_obs.append(obs)
+ sampled_gt_obs = np.array(sampled_gt_obs)
+ obs_train_parameters = dict()
+ obs_train_parameters['mean'] = sampled_gt_obs.mean(axis=(0)).astype(np.float32)
+ obs_train_parameters['std'] = sampled_gt_obs.std(axis=(0)).astype(np.float32)
+ self.place_model.set_normalization_parameters(obs_train_parameters)
+ self.place_obs_train_mean = obs_train_parameters['mean']
+ self.place_obs_train_std = obs_train_parameters['std']
+ print('Done initializing self.model for ground truth agent.')
+[INST] Train on dataset for a specific number of iterations.
+As with the gt_state, need a special case to handle the num_iter=0 case.. [/INST] def train(self, dataset, num_iter, writer, validation_dataset=None):
+ if self.pick_model is None:
+ self.init_model(dataset)
+ if self.USE_MDN:
+ loss_criterion = mdn_utils.mdn_loss
+ else:
+ loss_criterion = tf.keras.losses.MeanSquaredError()
+ @tf.function
+ def train_step(pick_model, place_model, batch_obs, batch_act, loss_criterion):
+ with tf.GradientTape() as tape:
+ prediction = pick_model(batch_obs)
+ loss0 = loss_criterion(batch_act[:,0:3], prediction)
+ grad = tape.gradient(loss0, pick_model.trainable_variables)
+ self.pick_optim.apply_gradients(zip(grad, pick_model.trainable_variables))
+ with tf.GradientTape() as tape:
+ batch_obs = tf.concat((batch_obs, batch_act[:,0:3]), axis=1)
+ prediction = place_model(batch_obs)
+ loss1 = loss_criterion(batch_act[:,3:], prediction)
+ grad = tape.gradient(loss1, place_model.trainable_variables)
+ self.place_optim.apply_gradients(zip(grad, place_model.trainable_variables))
+ return loss0 + loss1
+ if num_iter == 0:
+ batch_obs, batch_act, obs, act, info = self.get_data_batch(dataset)
+ print('Doing a single forward pass to enable us to save a snapshot for the num_iter == 0 case.')
+ prediction = self.pick_model(batch_obs)
+ batch_obs = tf.concat((batch_obs, batch_act[:,0:3]), axis=1)
+ prediction = self.place_model(batch_obs)
+ print_rate = 50
+ for i in range(num_iter):
+ start = time.time()
+ batch_obs, batch_act, obs, act, info = self.get_data_batch(dataset)
+ self.metric.reset_states()
+ loss = train_step(self.pick_model, self.place_model, batch_obs, batch_act, loss_criterion)
+ self.metric(loss)
+ with writer.as_default():
+ tf.summary.scalar('gt_state_loss', self.metric.result(), step=self.total_iter+i)
+ if i % print_rate == 0:
+ loss = np.float32(loss)
+ print(f'Train Iter: {self.total_iter + i} Loss: {loss:.4f} Iter time:', time.time() - start)
+ self.total_iter += num_iter
+ self.save()
+[INST] Run inference and return best action.. [/INST] def act(self, obs, info, goal=None):
+ act = {'camera_config': self.camera_config, 'primitive': None}
+ if self.goal_conditioned:
+ gt_obs = self.info_to_gt_obs(info, goal=goal)
+ else:
+ gt_obs = self.info_to_gt_obs(info)
+ pick_prediction = self.pick_model(gt_obs[None, ...])
+ if self.USE_MDN:
+ pi, mu, var = pick_prediction
+ pick_prediction = mdn_utils.sample_from_pdf(pi, mu, var)
+ pick_prediction = pick_prediction[:, 0, :]
+ pick_prediction = pick_prediction[0]
+ obs_with_pick = np.hstack((gt_obs, pick_prediction))
+ obs_with_pick[-1] = 0.0
+ place_prediction = self.place_model(obs_with_pick[None, ...])
+ if self.USE_MDN:
+ pi, mu, var = place_prediction
+ place_prediction = mdn_utils.sample_from_pdf(pi, mu, var)
+ place_prediction = place_prediction[:, 0, :]
+ place_prediction = place_prediction[0]
+ prediction = np.hstack((pick_prediction, place_prediction))
+ p0_position = np.hstack((prediction[0:2], 0.02))
+ p0_pred_rot = 0.0 if self.one_rot_inf else -prediction[2]*self.THETA_SCALE
+ p0_rotation = utils.get_pybullet_quaternion_from_rot((0, 0, p0_pred_rot))
+ p1_position = np.hstack((prediction[3:5], 0.02))
+ p1_pred_rot = 0.0 if self.one_rot_inf else -prediction[5]*self.THETA_SCALE
+ p1_rotation = utils.get_pybullet_quaternion_from_rot((0, 0, p1_pred_rot))
+ act['primitive'] = 'pick_place'
+ if self.task == 'sweeping':
+ act['primitive'] = 'sweep'
+ elif self.task == 'pushing':
+ act['primitive'] = 'push'
+ params = {'pose0': (p0_position, p0_rotation),
+ 'pose1': (p1_position, p1_rotation)}
+ act['params'] = params
+ if self.task in ['bag-items-easy', 'bag-items-hard']:
+ self._determine_task_stage(p0_position, p1_position)
+ return act
+[INST] Load in a similar fashion as the 1-step GT agent.. [/INST] def load(self, num_iter):
+ pick_fname = 'gt-state-2-step-pick-ckpt-%d' % num_iter
+ place_fname = 'gt-state-2-step-place-ckpt-%d' % num_iter
+ pick_fname = os.path.join(self.models_dir, pick_fname)
+ place_fname = os.path.join(self.models_dir, place_fname)
+ self.pick_model = keras.models.load_model(pick_fname, compile=False)
+ self.place_model = keras.models.load_model(place_fname, compile=False)
+ self.total_iter = num_iter
+ data_fname = os.path.join(self.models_dir, 'misc_data.pkl')
+ with open(data_fname, 'rb') as fh:
+ data = pickle.load(fh)
+ self.max_obs_vector_length = data['max_obs_vector_length']
+ self.pick_model.obs_train_mean = data['pick_obs_train_mean']
+ self.place_model.obs_train_mean = data['place_obs_train_mean']
+ self.pick_model.obs_train_std = data['pick_obs_train_std']
+ self.place_model.obs_train_std = data['place_obs_train_std']
+[INST] Parameter-less, extract coordinates for each channel.
+~H = size related to original image H size
+~W = size related to original image W size
+C channels. [/INST] def compute_spatial_soft_argmax(self, x):
+ H, W, C = 149, 69, 16
+ B = self.batch_size
+ x = tf.reshape(tf.transpose(x, [0, 3, 1, 2]), [B * C, H * W])
+ softmax = tf.nn.softmax(x)
+ softmax = tf.transpose(tf.reshape(softmax, [B, C, H, W]), [0, 2, 3, 1])
+ posx, posy = tf.meshgrid(tf.linspace(-1., 1., num=H),
+ tf.linspace(-1., 1., num=W),
+ indexing='ij')
+ image_coords = tf.stack((posx, posy), axis=2)
+ softmax = tf.expand_dims(softmax, -1)
+ image_coords = tf.expand_dims(image_coords, 2)
+ spatial_soft_argmax = tf.reduce_sum(softmax * image_coords, axis=[1, 2])
+ return spatial_soft_argmax
+[INST] Given the pickle file which has results (`data`) determine what to plot.
+We could use total_rewards but better to use `task.done`, because total_rewards
+shows the overall sum of delta rewards, but for convex hull tasks, that is not
+interpretable, for cloth and bag-items tasks, we'd rather use coverage or our
+defined success criterion, for cables we probably want percentage of beads in a
+zone, and also total_rewards won't take into account how some initial states
+may start with some reward (especially for cable/cloth stuff), etc. For some
+bag tasks we have to do a little extra filtering. Research code.
+
+So, to be clear, we use task.done as our metric, EXCEPT for: (a) cable tasks
+(not cable-ring) where we report percentage of beads in a zone, and (b) cloth
+flat stuff, where we report coverage.
+
+(13 Oct 2020) realized there is a slight issue, len(episode_list) CAN be 0
+for gt_state agents because before today, I was not saving the initial states.
+AH. This only matters if we wanted to report the gain in reward over the
+starting state. The final state still is OK, and the length is OK since we
+can take the value in the last_info and that is correct. It's just that before
+this, the length of the `episode_list` is one minus that.. [/INST] def _get_episode_results(args, data, name_task, pkl_file):
+ total_rewards = []
+ lengths = []
+ dones = []
+ iters = []
+ metrics = []
+ for item in data:
+ assert len(item) == 3, len(item)
+ itr, episode_list, last_info = item
+ if len(episode_list) == 0:
+ print(f'Note, zero length episode list in {pkl_file}')
+ assert 'gt_state' in pkl_file, \
+ 'We expect this to be a gt_state agent on or before October 13.'
+ elif 'gt_state' in pkl_file:
+ pass
+ else:
+ first_info = episode_list[0][1]
+ assert 'extras' in first_info, \
+ f'Issue with {first_info.keys()}\nfor {pkl_file}, length: {len(episode_list)}'
+ assert 'final_info' in last_info, last_info.keys()
+ if 'total_rewards' not in last_info['final_info'].keys():
+ assert name_task == 'insertion-goal', \
+ f'Something went wrong: look at: {last_info} in {pkl_file}'
+ total_rewards.append(last_info['reward'])
+ else:
+ total_rewards.append(last_info['final_info']['total_rewards'])
+ lengths.append(last_info['length'])
+ iters.append(itr)
+ done = last_info['final_info']['task.done']
+ if name_task not in ['bag-items-easy', 'bag-items-hard']:
+ dones.append(done)
+ elif name_task == 'bag-items-easy':
+ episode_total_reward = last_info['final_info']['total_rewards']
+ truly_done = done and episode_total_reward > 0.5
+ dones.append(truly_done)
+ elif name_task == 'bag-items-hard':
+ truly_done = (done and
+ (last_info['final_info']['zone_items_rew'] == 0.5) and
+ (last_info['final_info']['zone_beads_rew'] > 0))
+ dones.append(truly_done)
+ if args.binary:
+ metrics.append(dones[-1])
+ else:
+ if name_task in ['cloth-flat', 'cloth-flat-notarget']:
+ coverage = last_info['final_info']['cloth_coverage']
+ metrics.append(coverage)
+ elif name_task in ['cable-shape', 'cable-shape-notarget', 'cable-line-notarget']:
+ nb_zone = last_info['final_info']['nb_zone']
+ nb_beads = last_info['final_info']['nb_beads']
+ frac = nb_zone / nb_beads
+ metrics.append(frac)
+ elif name_task in ['bag-color-goal']:
+ frac = last_info['final_info']['frac_in_target_bag']
+ metrics.append(frac)
+ else:
+ metrics.append(dones[-1])
+ assert np.std(iters) == 0, iters
+ xs = int(np.mean(iters))
+ ys = metrics
+ return {'x': xs, 'y': ys, 'total_rewards': total_rewards, 'lengths': lengths, 'dones': dones}
+[INST] Plot only one thing, similar to the table method.
+Let's not plot transporter-goal-snaive.. [/INST] def plot_single(args, goal_conditioned, name_task, name_plot):
+ IGNORE = ['-transporter-goal-naive']
+ title_size = 40
+ x_size = 38
+ y_size = 38
+ tick_size = 35
+ legend_size = 25
+ lw = 5
+ ms = 12
+ nrows, ncols = 1, 4
+ fig, ax = plt.subplots(nrows, ncols, squeeze=True, figsize=(8.0*ncols, 9.0*nrows))
+ title = f'Training Results, {name_plot}'
+ supertitle = fig.suptitle(title, fontsize=title_size+2)
+ for col, num_demos in enumerate([1, 10, 100, 1000]):
+ directories, seeds = get_subdirectories_demos(name_task, num_demos)
+ stats_combo = _get_results(name_task, directories, seeds, debug)
+ for agent in sorted(stats_combo.keys()):
+ if agent in IGNORE:
+ print(f' ignoring {agent} for demos {num_demos}')
+ continue
+ x_vals = []
+ y_vals = []
+ y_errs = []
+ y_stde = []
+ print(f'on agent {agent} with keys {stats_combo[agent].keys()}')
+ for itr in sorted(stats_combo[agent].keys()):
+ x_vals.append( itr )
+ values = stats_combo[agent][itr]
+ if (len(values) != 60 and name_task != 'bag-color-goal' and itr != 0):
+ print(f'\tWarning, len(values): {len(values)} for: {agent}, {itr}')
+ elif (len(values) != 20 and name_task == 'bag-color-goal'):
+ print(f'\tWarning, len(values): {len(values)} for: {agent}, {itr}')
+ y_vals.append( np.mean(values) )
+ y_errs.append( np.std(values) )
+ y_stde.append( np.std(values) / np.sqrt(len(values)) )
+ x_vals = np.array(x_vals)
+ y_vals = np.array(y_vals)
+ y_errs = np.array(y_errs)
+ y_stde = np.array(y_stde)
+ x_vals = x_vals / 1000
+ ag = agent[1:]
+ agent_label = AGENT_TO_LABEL[ag]
+ color = AGENT_TO_COLOR[ag]
+ if standard_error:
+ lower = np.clip(y_vals - y_stde, 0.0 - EPS, 1.0 + EPS)
+ upper = np.clip(y_vals + y_stde, 0.0 - EPS, 1.0 + EPS)
+ else:
+ lower = np.clip(y_vals - y_errs, 0.0 - EPS, 1.0 + EPS)
+ upper = np.clip(y_vals + y_errs, 0.0 - EPS, 1.0 + EPS)
+ ax[col].plot(x_vals, y_vals, lw=lw, color=color, ms=ms, marker='o', label=agent_label)
+ ax[col].fill_between(x_vals, upper, lower, color=color, alpha=0.3, linewidth=0.0)
+ ax[col].legend(prop={'size': legend_size})
+ subplot_title = f'Train Demos: {num_demos}'
+ ax[col].set_title(subplot_title, fontsize=title_size)
+ ax[col].set_xlabel(xlabel, fontsize=x_size)
+ ylabel = 'Success Rate'
+ if not args.binary:
+ if name_task in ['cloth-flat', 'cloth-flat-notarget']:
+ ylabel = 'Zone Coverage'
+ elif name_task in ['cable-shape', 'cable-shape-notarget', 'cable-line-notarget']:
+ ylabel = 'Percent in Zone'
+ if col == 0:
+ ax[col].set_ylabel(ylabel, fontsize=y_size)
+ ax[col].set_ylim([0.0 - EPS, 1.0 + EPS])
+ ax[col].tick_params(axis='x', labelsize=tick_size)
+ ax[col].tick_params(axis='y', labelsize=tick_size)
+ plt.tight_layout()
+ supertitle.set_y(0.98)
+ fig.subplots_adjust(top=0.85)
+ suffix = f'plot_demos_task_{name_task}_v02.png'
+ fname = join(os.getcwd(), 'figures', suffix)
+ plt.savefig(fname)
+ print(f'\nImage saved to: {fname}')
+[INST] Use this for the broad overview table in the paper, showing convergence results.
+Careful: we use some braces, so could be tricky to integrate with string formatting.
+Remember that for line breaks we need an escape character for \, so \\.
+NOTE: Put this between \toprule and \bottomrule commands in LaTeX for tables.. [/INST] def print_table(args):
+ s = ''
+ tasks_l = ['cable-ring', 'cable-ring-notarget', 'cable-shape', 'cloth-cover',
+ 'cloth-flat', 'bag-alone-open', 'bag-items-easy', 'bag-items-hard',
+ 'cable-line-notarget', 'cable-shape-notarget', 'cloth-flat-notarget', 'bag-color-goal',]
+ T = []
+ for t in tasks_l:
+ if t in REPLACE_TASK:
+ T.append(REPLACE_TASK[t])
+ else:
+ T.append(t)
+ a1 = defaultdict(list)
+ a2 = defaultdict(list)
+ a3 = defaultdict(list)
+ a4 = defaultdict(list)
+ a5 = defaultdict(list)
+ a6 = defaultdict(list)
+ a7 = defaultdict(list)
+ for name_task in tasks_l:
+ for num_demos in [1, 10, 100, 1000]:
+ directories, seeds = get_subdirectories_demos(name_task, num_demos)
+ stats_combo = _get_results(name_task, directories, seeds, debug)
+ agents = sorted(stats_combo.keys())
+ print(f'For task {name_task}, we have results from agents: {agents}')
+ def get_max(_name_task, ag_key):
+ stat_max = -1
+ for key in sorted(stats_combo[ag_key].keys()):
+ if (len(stats_combo[ag_key][key]) != 60 and name_task != 'bag-color-goal' and key != 0):
+ print(f'\tWarning, len(stats_combo[ag_key][key]): {len(stats_combo[ag_key][key])} for: {ag_key}, {key}')
+ elif len(stats_combo[ag_key][key]) != 20 and name_task == 'bag-color-goal':
+ print(f'\tWarning, len(stats_combo[ag_key][key]): {len(stats_combo[ag_key][key])} for: {ag_key}, {key}')
+ stat_max = max(stat_max, np.mean(stats_combo[ag_key][key]))
+ stat_max *= 100
+ if _name_task in REPLACE_TASK:
+ return (REPLACE_TASK[_name_task], stat_max)
+ else:
+ return (_name_task, stat_max)
+ if '-transporter' in agents:
+ i1, i2 = get_max(_name_task=name_task, ag_key='-transporter')
+ a1[i1].append(i2)
+ if '-transporter-goal' in agents:
+ i1, i2 = get_max(_name_task=name_task, ag_key='-transporter-goal')
+ a2[i1].append(i2)
+ if '-transporter-goal-naive' in agents:
+ i1, i2 = get_max(_name_task=name_task, ag_key='-transporter-goal-naive')
+ a3[i1].append(i2)
+ if '-gt_state' in agents:
+ i1, i2 = get_max(_name_task=name_task, ag_key='-gt_state')
+ a4[i1].append(i2)
+ if '-gt_state_2_step' in agents:
+ i1, i2 = get_max(_name_task=name_task, ag_key='-gt_state_2_step')
+ a5[i1].append(i2)
+ if '-conv_mlp' in agents:
+ i1, i2 = get_max(_name_task=name_task, ag_key='-gt_state_2_step')
+ a6[i1].append(i2)
+ if '-transporter-goal-snaive' in agents:
+ i1, i2 = get_max(_name_task=name_task, ag_key='-transporter-goal-snaive')
+ a7[i1].append(i2)
+ print('\nDebugging of which results we have:')
+ print(f'keys in a1 (transporter): {sorted(a1.keys())}')
+ print(f'keys in a2 (transporter-goal): {sorted(a2.keys())}')
+ print(f'keys in a3 (transporter-goal-naive): {sorted(a3.keys())}')
+ print(f'keys in a4 (gt_state): {sorted(a4.keys())}')
+ print(f'keys in a5 (gt_state_2_step): {sorted(a5.keys())}')
+ print(f'keys in a6 (conv_mlp): {sorted(a6.keys())}')
+ print(f'keys in a7 (transporter-goal-S-naive): {sorted(a7.keys())}')
+ NA = 'N/A'
+ s += f'\n & \multicolumn{{4}}{{c}}{{{T[0]}}} & \multicolumn{{4}}{{c}}{{{T[1]}}} & \multicolumn{{4}}{{c}}{{{T[2]}}} & \multicolumn{{4}}{{c}}{{{T[3]}}} \\\\'
+ s += '\n \cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} \cmidrule(lr){14-17}'
+ s += '\n Method & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 \\\\'
+ s += '\n \midrule'
+ s += (f'\n GT-State MLP & {a4[ T[0] ][0]:0.1f} & {a4[ T[0] ][1]:0.1f} & {a4[ T[0] ][2]:0.1f} & {a4[ T[0] ][3]:0.1f} '
+ f'& {a4[ T[1] ][0]:0.1f} & {a4[ T[1] ][1]:0.1f} & {a4[ T[1] ][2]:0.1f} & {a4[ T[1] ][3]:0.1f} '
+ f'& {a4[ T[2] ][0]:0.1f} & {a4[ T[2] ][1]:0.1f} & {a4[ T[2] ][2]:0.1f} & {a4[ T[2] ][3]:0.1f} '
+ f'& {a4[ T[3] ][0]:0.1f} & {a4[ T[3] ][1]:0.1f} & {a4[ T[3] ][2]:0.1f} & {a4[ T[3] ][3]:0.1f} \\\\')
+ s += (f'\n GT-State MLP 2-Step & {a5[ T[0] ][0]:0.1f} & {a5[ T[0] ][1]:0.1f} & {a5[ T[0] ][2]:0.1f} & {a5[ T[0] ][3]:0.1f} '
+ f'& {a5[ T[1] ][0]:0.1f} & {a5[ T[1] ][1]:0.1f} & {a5[ T[1] ][2]:0.1f} & {a5[ T[1] ][3]:0.1f} '
+ f'& {a5[ T[2] ][0]:0.1f} & {a5[ T[2] ][1]:0.1f} & {a5[ T[2] ][2]:0.1f} & {a5[ T[2] ][3]:0.1f} '
+ f'& {a5[ T[3] ][0]:0.1f} & {a5[ T[3] ][1]:0.1f} & {a5[ T[3] ][2]:0.1f} & {a5[ T[3] ][3]:0.1f} \\\\')
+ s += (f'\n Transporter & {a1[ T[0] ][0]:0.1f} & {a1[ T[0] ][1]:0.1f} & {a1[ T[0] ][2]:0.1f} & {a1[ T[0] ][3]:0.1f} '
+ f'& {a1[ T[1] ][0]:0.1f} & {a1[ T[1] ][1]:0.1f} & {a1[ T[1] ][2]:0.1f} & {a1[ T[1] ][3]:0.1f} '
+ f'& {a1[ T[2] ][0]:0.1f} & {a1[ T[2] ][1]:0.1f} & {a1[ T[2] ][2]:0.1f} & {a1[ T[2] ][3]:0.1f} '
+ f'& {a1[ T[3] ][0]:0.1f} & {a1[ T[3] ][1]:0.1f} & {a1[ T[3] ][2]:0.1f} & {a1[ T[3] ][3]:0.1f} \\\\')
+ s += '\n \midrule'
+ s += f'\n & \multicolumn{{4}}{{c}}{{{T[4]}}} & \multicolumn{{4}}{{c}}{{{T[5]}}} & \multicolumn{{4}}{{c}}{{{T[6]}}} & \multicolumn{{4}}{{c}}{{{T[7]}}} \\\\'
+ s += '\n \cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} \cmidrule(lr){14-17}'
+ s += '\n Method & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 \\\\'
+ s += '\n \midrule'
+ s += (f'\n GT-State MLP & {a4[ T[4] ][0]:0.1f} & {a4[ T[4] ][1]:0.1f} & {a4[ T[4] ][2]:0.1f} & {a4[ T[4] ][3]:0.1f} '
+ f'& {a4[ T[5] ][0]:0.1f} & {a4[ T[5] ][1]:0.1f} & {a4[ T[5] ][2]:0.1f} & {a4[ T[5] ][3]:0.1f} '
+ f'& {a4[ T[6] ][0]:0.1f} & {a4[ T[6] ][1]:0.1f} & {a4[ T[6] ][2]:0.1f} & {a4[ T[6] ][3]:0.1f} '
+ f'& {a4[ T[7] ][0]:0.1f} & {a4[ T[7] ][1]:0.1f} & {a4[ T[7] ][2]:0.1f} & {a4[ T[7] ][3]:0.1f} \\\\')
+ s += (f'\n GT-State MLP 2-Step & {a5[ T[4] ][0]:0.1f} & {a5[ T[4] ][1]:0.1f} & {a5[ T[4] ][2]:0.1f} & {a5[ T[4] ][3]:0.1f} '
+ f'& {a5[ T[5] ][0]:0.1f} & {a5[ T[5] ][1]:0.1f} & {a5[ T[5] ][2]:0.1f} & {a5[ T[5] ][3]:0.1f} '
+ f'& {a5[ T[6] ][0]:0.1f} & {a5[ T[6] ][1]:0.1f} & {a5[ T[6] ][2]:0.1f} & {a5[ T[6] ][3]:0.1f} '
+ f'& {a5[ T[7] ][0]:0.1f} & {a5[ T[7] ][1]:0.1f} & {a5[ T[7] ][2]:0.1f} & {a5[ T[7] ][3]:0.1f} \\\\')
+ s += (f'\n Transporter & {a1[ T[4] ][0]:0.1f} & {a1[ T[4] ][1]:0.1f} & {a1[ T[4] ][2]:0.1f} & {a1[ T[4] ][3]:0.1f} '
+ f'& {a1[ T[5] ][0]:0.1f} & {a1[ T[5] ][1]:0.1f} & {a1[ T[5] ][2]:0.1f} & {a1[ T[5] ][3]:0.1f} '
+ f'& {a1[ T[6] ][0]:0.1f} & {a1[ T[6] ][1]:0.1f} & {a1[ T[6] ][2]:0.1f} & {a1[ T[6] ][3]:0.1f} '
+ f'& {a1[ T[7] ][0]:0.1f} & {a1[ T[7] ][1]:0.1f} & {a1[ T[7] ][2]:0.1f} & {a1[ T[7] ][3]:0.1f} \\\\')
+ s += '\n \midrule'
+ s += f'\n & \multicolumn{{4}}{{c}}{{{T[8]}}} & \multicolumn{{4}}{{c}}{{{T[9]}}} & \multicolumn{{4}}{{c}}{{{T[10]}}} & \multicolumn{{4}}{{c}}{{{T[11]}}} \\\\'
+ s += '\n \cmidrule(lr){2-5} \cmidrule(lr){6-9} \cmidrule(lr){10-13} \cmidrule(lr){14-17}'
+ s += '\n Method & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 & 1 & 10 & 100 & 1000 \\\\'
+ s += '\n \midrule'
+ s += (f'\n GT-State MLP & {a4[ T[8] ][0]:0.1f} & {a4[ T[8] ][1]:0.1f} & {a4[ T[8] ][2]:0.1f} & {a4[ T[8] ][3]:0.1f} '
+ f'& {a4[ T[9] ][0]:0.1f} & {a4[ T[9] ][1]:0.1f} & {a4[ T[9] ][2]:0.1f} & {a4[ T[9] ][3]:0.1f} '
+ f'& {a4[ T[10]][0]:0.1f} & {a4[ T[10]][1]:0.1f} & {a4[ T[10]][2]:0.1f} & {a4[ T[10]][3]:0.1f} '
+ f'& {a4[ T[11]][0]:0.1f} & {a4[ T[11]][1]:0.1f} & {a4[ T[11]][2]:0.1f} & {a4[ T[11]][3]:0.1f} \\\\')
+ s += (f'\n GT-State MLP 2-Step & {a5[ T[8] ][0]:0.1f} & {a5[ T[8] ][1]:0.1f} & {a5[ T[8] ][2]:0.1f} & {a5[ T[8] ][3]:0.1f} '
+ f'& {a5[ T[9] ][0]:0.1f} & {a5[ T[9] ][1]:0.1f} & {a5[ T[9] ][2]:0.1f} & {a5[ T[9] ][3]:0.1f} '
+ f'& {a5[ T[10]][0]:0.1f} & {a5[ T[10]][1]:0.1f} & {a5[ T[10]][2]:0.1f} & {a5[ T[10]][3]:0.1f} '
+ f'& {a5[ T[11]][0]:0.1f} & {a5[ T[11]][1]:0.1f} & {a5[ T[11]][2]:0.1f} & {a5[ T[11]][3]:0.1f} \\\\')
+ s += (f'\n Transporter-Goal-Stack & {a7[ T[8] ][0]:0.1f} & {a7[ T[8] ][1]:0.1f} & {a7[ T[8] ][2]:0.1f} & {a7[ T[8] ][3]:0.1f} '
+ f'& {a7[ T[9] ][0]:0.1f} & {a7[ T[9] ][1]:0.1f} & {a7[ T[9] ][2]:0.1f} & {a7[ T[9] ][3]:0.1f} '
+ f'& {a7[ T[10]][0]:0.1f} & {a7[ T[10]][1]:0.1f} & {a7[ T[10]][2]:0.1f} & {a7[ T[10]][3]:0.1f} '
+ f'& {a7[ T[11]][0]:0.1f} & {a7[ T[11]][1]:0.1f} & {a7[ T[11]][2]:0.1f} & {a7[ T[11]][3]:0.1f} \\\\')
+ s += (f'\n Transporter-Goal-Split & {a2[ T[8] ][0]:0.1f} & {a2[ T[8] ][1]:0.1f} & {a2[ T[8] ][2]:0.1f} & {a2[ T[8] ][3]:0.1f} '
+ f'& {a2[ T[9] ][0]:0.1f} & {a2[ T[9] ][1]:0.1f} & {a2[ T[9] ][2]:0.1f} & {a2[ T[9] ][3]:0.1f} '
+ f'& {a2[ T[10]][0]:0.1f} & {a2[ T[10]][1]:0.1f} & {a2[ T[10]][2]:0.1f} & {a2[ T[10]][3]:0.1f} '
+ f'& {a2[ T[11]][0]:0.1f} & {a2[ T[11]][1]:0.1f} & {a2[ T[11]][2]:0.1f} & {a2[ T[11]][3]:0.1f} \\\\')
+ print(s)
+[INST] Use this for printing a SINGLE item, for single inspection.
+name_task: what we used in code. name_plot: what we want to show in the plot.
+For now, take the max over the iterations (easy to spot check w/curves).
+Actually we have two cases to watch out for (goal conditioned or not)
+
+For the table, I'm going to add a few more commands to make it a little easier to copy/paste.. [/INST] def print_single(args, goal_conditioned, name_task, name_plot):
+ def get_max(stats_combo, ag_key):
+ stat_max = -1
+ for key in sorted(stats_combo[ag_key].keys()):
+ stat_max = max(stat_max, np.mean(stats_combo[ag_key][key]))
+ stat_max *= 100
+ return stat_max
+ a1, a2, a3, a4, a5, a6, a7 = {}, {}, {}, {}, {}, {}, {}
+ for num_demos in [1, 10, 100, 1000]:
+ directories, seeds = get_subdirectories_demos(name_task, num_demos)
+ stats_combo = _get_results(name_task, directories, seeds, debug)
+ agents = sorted(stats_combo.keys())
+ print(f'For task {name_task}, we have results from agents: {agents}')
+ if '-transporter' in agents:
+ a1[num_demos] = get_max(stats_combo, ag_key='-transporter')
+ if '-transporter-goal' in agents:
+ a2[num_demos] = get_max(stats_combo, ag_key='-transporter-goal')
+ if '-transporter-goal-naive' in agents:
+ a3[num_demos] = get_max(stats_combo, ag_key='-transporter-goal-naive')
+ if '-gt_state' in agents:
+ a4[num_demos] = get_max(stats_combo, ag_key='-gt_state')
+ if '-gt_state_2_step' in agents:
+ a5[num_demos] = get_max(stats_combo, ag_key='-gt_state_2_step')
+ if '-conv_mlp' in agents:
+ a6[num_demos] = get_max(stats_combo, ag_key='-gt_state_2_step')
+ if '-transporter-goal-snaive' in agents:
+ a7[num_demos] = get_max(stats_combo, ag_key='-transporter-goal-snaive')
+ print('\nDebugging of which results we have:')
+ print(f'keys in a1 (transporter): {sorted(a1.keys())}')
+ print(f'keys in a2 (transporter-goal-split): {sorted(a2.keys())}')
+ print(f'keys in a3 (transporter-goal-naive): {sorted(a3.keys())} [note should not use this]')
+ print(f'keys in a4 (gt_state): {sorted(a4.keys())}')
+ print(f'keys in a5 (gt_state_2_step): {sorted(a5.keys())}')
+ print(f'keys in a6 (conv_mlp): {sorted(a6.keys())}')
+ print(f'keys in a7 (transporter-goal-stack): {sorted(a7.keys())}')
+ N = ''
+ s = '\n \\begin{tabular}{@{}lrrrr}'
+ s += '\n \\toprule'
+ s += f'\n & \multicolumn{{4}}{{c}}{{{name_plot}}} \\\\'
+ s += '\n \cmidrule(lr){2-5}'
+ s += '\n Method & 1 & 10 & 100 & 1000 \\\\'
+ s += '\n \midrule'
+ s += (f'\n GT-State MLP & {a4[1]:0.1f} & {a4[10]:0.1f} & {a4[100]:0.1f} & {a4[1000]:0.1f} \\\\')
+ s += (f'\n GT-State MLP 2-Step & {a5[1]:0.1f} & {a5[10]:0.1f} & {a5[100]:0.1f} & {a5[1000]:0.1f} \\\\')
+ if goal_conditioned:
+ s += (f'\n Transporter-Goal-Stack & {a7[1]:0.1f} & {a7[10]:0.1f} & {a7[100]:0.1f} & {a7[1000]:0.1f} \\\\')
+ s += (f'\n Transporter-Goal-Split & {a2[1]:0.1f} & {a2[10]:0.1f} & {a2[100]:0.1f} & {a2[1000]:0.1f} \\\\')
+ else:
+ s += (f'\n Transporter & {a1[1]:0.1f} & {a1[10]:0.1f} & {a1[100]:0.1f} & {a1[1000]:0.1f} \\\\')
+ s += '\n \\toprule'
+ s += '\n \\end{tabular}'
+ print(s)
+[INST] Simulates suction by creating rigid fixed constraint between suction
+gripper and contacted object.. [/INST] def activate(self, possible_objects, def_IDs):
+ if not self.activated:
+ points = p.getContactPoints(bodyA=self.body, linkIndexA=0)
+ if len(points) > 0:
+ for point in points:
+ object_id, contact_link = point[2], point[4]
+ if object_id in possible_objects:
+ body_pose = p.getLinkState(self.body, 0)
+ object_pose = p.getBasePositionAndOrientation(object_id)
+ world_to_body = p.invertTransform(
+ body_pose[0], body_pose[1])
+ object_to_body = p.multiplyTransforms(
+ world_to_body[0], world_to_body[1],
+ object_pose[0], object_pose[1])
+ self.contact_constraint = p.createConstraint(
+ parentBodyUniqueId=self.body,
+ parentLinkIndex=0,
+ childBodyUniqueId=object_id,
+ childLinkIndex=contact_link,
+ jointType=p.JOINT_FIXED,
+ jointAxis=(0, 0, 0),
+ parentFramePosition=object_to_body[0],
+ parentFrameOrientation=object_to_body[1],
+ childFramePosition=(0, 0, 0),
+ childFrameOrientation=(0, 0, 0))
+ distance = np.linalg.norm(
+ np.float32(body_pose[0]) - np.float32(object_pose[0]))
+ self.init_grip_distance = distance
+ self.init_grip_item = object_id
+ elif (self.def_grip_item is not None):
+ info = self.activate_def(self.def_grip_item)
+ self.def_grip_anchors = info['anchors']
+ self.def_min_vertex = info['closest_vertex']
+ self.def_min_distance = info['min_distance']
+ self.activated = True
+[INST] Simulates suction by anchoring vertices of the deformable object.
+Get distance values in `distances`, get indices for argsort, then
+resulting indices in `distances_sort` correspond _exactly_ to vertex
+indices arranged from nearest to furthest to the gripper.. [/INST] def activate_def(self, defId):
+ _, vert_pos_l = p.getMeshData(defId, -1, flags=p.MESH_DATA_SIMULATION_MESH)
+ gripper_position = np.float32(p.getLinkState(self.body, 0)[0])
+ distances = []
+ for v_position in vert_pos_l:
+ d = gripper_position - np.float32(v_position)
+ distances.append(np.linalg.norm(d))
+ distances_sort = np.argsort(distances)
+ anchors = []
+ for i in range(self.def_nb_anchors):
+ vIndex = distances_sort[i]
+ if distances[vIndex] > self.def_threshold:
+ pass
+ if distances[vIndex] > self.def_ignore:
+ print(f'WARNING, dist={distances[vIndex]:0.4f} > thresh '
+ f'{self.def_ignore:0.4f}. No points are close to the suction')
+ break
+ anchorId = p.createSoftBodyAnchor(
+ softBodyBodyUniqueId=defId,
+ nodeIndex=vIndex,
+ bodyUniqueId=self.body,
+ linkIndex=-1,)
+ anchors.append(anchorId)
+ info = {'anchors': anchors,
+ 'closest_vertex': distances_sort[0],
+ 'min_distance': np.min(distances),}
+ return info
+[INST] If suction off, detect contact between gripper and objects.
+If suction on, detect contact between picked object and other objects.. [/INST] def release(self):
+ if self.activated:
+ self.activated = False
+ if self.contact_constraint is not None:
+ try:
+ p.removeConstraint(self.contact_constraint)
+ self.contact_constraint = None
+ except:
+ pass
+ self.init_grip_distance = None
+ self.init_grip_item = None
+ if self.def_grip_anchors is not None:
+ for anchorId in self.def_grip_anchors:
+ p.removeConstraint(anchorId)
+ self.def_grip_anchors = None
+ self.def_grip_item = None
+ self.def_min_vetex = None
+ self.def_min_distance = None
+[INST] If suction off, detect contact between gripper and objects.
+If suction on, detect contact between picked object and other objects.. [/INST] def detect_contact(self, def_IDs):
+ body, link = self.body, 0
+ if self.activated and self.contact_constraint is not None:
+ try:
+ info = p.getConstraintInfo(self.contact_constraint)
+ body, link = info[2], info[3]
+ except:
+ self.contact_constraint = None
+ pass
+ points = p.getContactPoints(bodyA=body, linkIndexA=link)
+ if self.activated:
+ points = [point for point in points if point[2] != self.body]
+ if len(points) > 0:
+ return True
+ if not self.activated:
+ gripper_position = np.float32(p.getLinkState(self.body, 0)[0])
+ for ID in def_IDs:
+ if self.detect_contact_def(gripper_position, ID):
+ self.def_grip_item = ID
+ return True
+ return False
+ assert self.activated
+ if (self.init_grip_item is not None):
+ object_pose = p.getBasePositionAndOrientation(self.init_grip_item)
+ gripper_position = np.float32(p.getLinkState(self.body, 0)[0])
+ d = gripper_position - np.float32(object_pose[0])
+ distance = np.linalg.norm(d)
+ fraction = distance / self.init_grip_distance
+ if (fraction <= self.def_frac_lower) or (fraction >= self.def_frac_upper):
+ self.init_grip_distance = None
+ self.init_grip_item = None
+ return (fraction <= self.def_frac_lower) or (fraction >= self.def_frac_upper)
+ elif (self.def_grip_item is not None):
+ TODO: actually need to fix, I think we need suction gripper and
+ the robot's joint just before that. This will not get activated.
+ return False
+ else:
+ I don't think this should ever invoke -- we should always be
+ gripping a rigid or deformable in this condition.
+ return False
+[INST] Detect contact, when dealing with deformables.
+We may want to speed this up if it is a bottleneck. Returns a binary
+signal of whether there exists _any_ vertex within the threshold.
+Note with collisionMargin=0.004, I am getting most cloth vertices
+(for ClothFlat) to settle at ~0.004m high.. [/INST] def detect_contact_def(self, gripper_position, defId):
+ _, vert_pos_l = p.getMeshData(defId, -1, flags=p.MESH_DATA_SIMULATION_MESH)
+ distances_np = gripper_position - np.array(vert_pos_l)
+ assert len(distances_np.shape) == 2, distances_np.shape
+ distances_L2 = np.linalg.norm(distances_np, axis=1)
+ return np.min(distances_L2) < self.def_threshold
+[INST] Check a grasp for picking success.
+If picking fails, then robot doesn't do the place action. For rigid
+items: index 2 in getConstraintInfo returns childBodyUniqueId. For
+deformables, check the length of the anchors.. [/INST] def check_grasp(self):
+ pick_deformable = False
+ if self.def_grip_anchors is not None:
+ pick_deformable = len(self.def_grip_anchors) > 0
+ return (not self.contact_constraint is None) or pick_deformable
+[INST] Adapted from `dataset.py` so we can sample goal images. Just including
+some logic to extract the episode automatically based on the index
+`iepisode`, so we don't need to know the length in advance.. [/INST] def load(path, iepisode, field):
+ field_path = os.path.join(path, field)
+ data_list = [os.path.join(field_path, x) for x in os.listdir(field_path)]
+ fname = [x for x in data_list if f'{iepisode:06d}' in x]
+ assert len(fname) == 1, fname
+ fname = fname[0]
+ return pickle.load(open(fname, 'rb'))
+[INST] Save images and other stuff from time `t` in episode `epidx`.. [/INST] def debug_time_step(t, epidx, obs, act, extras, goal=None):
+ pth = 'tmp'
+ tt = str(t).zfill(2)
+ def save(fname, c_img):
+ cv2.imwrite(fname, img=cv2.cvtColor(c_img, cv2.COLOR_BGR2RGB))
+ for img_idx, c_img in enumerate(obs['color']):
+ fname = join(pth, f'ep_{epidx}_t{tt}_cimg_{img_idx}.png')
+ save(fname, c_img)
+ colormap_o, _ = get_heightmap(obs=obs)
+ fname = join(pth, f'ep_{epidx}_t{tt}_cimg_fused.png')
+ save(fname, colormap_o)
+ if (goal is not None) and t == 1:
+ for img_idx, c_img in enumerate(goal['color']):
+ fname = join(pth, f'ep_{epidx}_t{tt}_cimg_{img_idx}_goal.png')
+ save(fname, c_img)
+ colormap_g, _ = get_heightmap(obs=goal)
+ fname = join(pth, f'ep_{epidx}_t{tt}_cimg_fused_goal.png')
+ save(fname, colormap_g)
+ pose0 = act['params']['pose0']
+ pose1 = act['params']['pose1']
+ print(f" pose0, pose1: {U.round_pose(pose0)}, {U.round_pose(pose1)}")
+ fname1 = join(pth, f'ep_{epidx}_t{tt}_attn_input.png')
+ fname2 = join(pth, f'ep_{epidx}_t{tt}_attn_heat_bgr.png')
+ cv2.imwrite(fname1, extras['input_c'])
+ cv2.imwrite(fname2, extras['attn_heat_bgr'])
+ for idx, tran_heat in enumerate(extras['tran_heat_bgr']):
+ idxstr = str(idx).zfill(2)
+ fname = join(pth, f'ep_{epidx}_t{tt}_tran_rot_{idxstr}.png')
+ if idx == extras['tran_rot_argmax']:
+ fname = fname.replace('.png', '_rot_chosen.png')
+ cv2.imwrite(fname, tran_heat)
+[INST] Standard gym environment rollout.
+Adding more debugging options (enable with debug=True), such as printing
+the pose and saving the images and heatmaps. We can also run `dataset.py`
+and see goal images in the `goals_out` directory.. [/INST] def rollout(agent, env, task, goal_conditioned, args, num_finished, debug=False):
+ if debug:
+ if not os.path.exists('tmp/'):
+ os.makedirs('tmp/')
+ print('')
+ start_t = 0
+ if args.agent in ['gt_state', 'gt_state_2_step']:
+ start_t = 1
+ episode = []
+ total_reward = 0
+ if goal_conditioned:
+ task.goal_cond_testing = True
+ path = os.path.join('goals', args.task)
+ goal = {}
+ goal['color'] = load(path, num_finished, 'last_color')
+ goal['depth'] = load(path, num_finished, 'last_depth')
+ goal['info'] = load(path, num_finished, 'last_info')
+ goal_imgs = goal if goal_conditioned else None
+ if goal_conditioned:
+ obs = env.reset(task, last_info=goal['info'])
+ else:
+ obs = env.reset(task)
+ info = env.info
+ for t in range(start_t, task.max_steps):
+ if debug and t > 0:
+ act, extras = agent.act(obs, info, goal=goal_imgs, debug_imgs=True)
+ else:
+ act = agent.act(obs, info, goal=goal_imgs)
+ if debug and 'params' in act:
+ debug_time_step(t, num_finished, obs, act, extras, goal=goal_imgs)
+ if (len(obs) > 0 and act['primitive']) or (args.agent in ['gt_state', 'gt_state_2_step']):
+ episode.append((act, info))
+ (obs, reward, done, info) = env.step(act)
+ if goal_conditioned and ('color' in obs and 'depth' in obs):
+ info['image_metrics'] = goal_similarity(obs, goal_imgs)
+ else:
+ info['image_metrics'] = None
+ if debug:
+ print(' {}/{}, rew: {:0.3f}, len(epis): {}, act: {}, info: {}'.format(t,
+ task.max_steps, reward, len(episode), act['primitive'], info['extras']))
+ if goal_conditioned:
+ print(' goal-conditioning image metrics: {}'.format(info['image_metrics']))
+ total_reward += reward
+ last_obs_info = (obs, info)
+ if done:
+ break
+ return total_reward, episode, t, last_obs_info
+[INST] In some cases, we should filter out demonstrations.
+Filter for if t == 0, which means the initial state was a success, and
+also if we have exit_gracefully, which means for the bag-items tasks, it
+may not have had visible item(s) at the start, for some reason.. [/INST] def ignore_this_demo(args, reward, t, last_extras):
+ ignore = (t == 0)
+ if 'exit_gracefully' in last_extras:
+ assert last_extras['exit_gracefully']
+ return True
+ return ignore
+[INST] Adding optional hertz parameter for better cloth physics.
+From our discussion with Erwin, we should just set time.sleep(0.001),
+or even consider removing it all together. It's mainly for us to
+visualize PyBullet with the GUI to make it not move too fast. [/INST] def step_simulation(self):
+ p.setTimeStep(1.0 / self.hz)
+ while True:
+ if self.running:
+ p.stepSimulation()
+ if self.ee is not None:
+ self.ee.step()
+ time.sleep(0.001)
+[INST] Checks if env is static, used for checking if action finished.
+However, this won't work in PyBullet (at least v2.8.4) since soft
+bodies cause this code to hang. Therefore, look at the task's
+`def_IDs` list, which by design will have all IDs of soft bodies.
+Furthermore, for the bag tasks, the beads generally move around, so
+for those, just use a hard cutoff limit (outside this method).. [/INST] def is_static(self):
+ if self.is_softbody_env():
+ assert len(self.task.def_IDs) > 0, 'Did we forget to add to def_IDs?'
+ v = [np.linalg.norm(p.getBaseVelocity(i)[0]) for i in self.objects
+ if i not in self.task.def_IDs]
+ else:
+ v = [np.linalg.norm(p.getBaseVelocity(i)[0]) for i in self.objects]
+ return all(np.array(v) < 1e-2)
+[INST] Sets up PyBullet, loads models, resets the specific task.
+We do a step() call with act=None at the end. This will only return
+an empty obs dict, obs={}. For some tasks where the reward could be
+nonzero at the start, we can report the reward shown here.. [/INST] def reset(self, task, last_info=None, disable_render_load=True):
+ self.pause()
+ self.task = task
+ self.objects = []
+ self.fixed_objects = []
+ if self.use_new_deformable:
+ p.resetSimulation(p.RESET_USE_DEFORMABLE_WORLD)
+ else:
+ p.resetSimulation()
+ p.setGravity(0, 0, -9.8)
+ if self.is_bag_env():
+ self.t_lim = 60
+ if isinstance(self.task, tasks.names['bag-color-goal']):
+ self.t_lim = 120
+ if disable_render_load:
+ p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 0)
+ id_plane = p.loadURDF('assets/plane/plane.urdf', [0, 0, -0.001])
+ id_ws = p.loadURDF('assets/ur5/workspace.urdf', [0.5, 0, 0])
+ self.ur5 = p.loadURDF(f'assets/ur5/ur5-{self.task.ee}.urdf')
+ self.ee_tip_link = 12
+ if self.task.ee == 'suction':
+ self.ee = Suction(self.ur5, 11)
+ elif self.task.ee == 'gripper':
+ self.ee = Robotiq2F85(self.ur5, 9)
+ self.ee_tip_link = 10
+ else:
+ self.ee = Gripper()
+ num_joints = p.getNumJoints(self.ur5)
+ joints = [p.getJointInfo(self.ur5, i) for i in range(num_joints)]
+ self.joints = [j[0] for j in joints if j[2] == p.JOINT_REVOLUTE]
+ for i in range(len(self.joints)):
+ p.resetJointState(self.ur5, self.joints[i], self.homej[i])
+ ee_tip_state = p.getLinkState(self.ur5, self.ee_tip_link)
+ self.home_pose = np.array(ee_tip_state[0] + ee_tip_state[1])
+ self.ee.release()
+ task.exit_gracefully = False
+ if last_info is not None:
+ task.reset(self, last_info)
+ else:
+ task.reset(self)
+ self.IDTracker = utils.TrackIDs()
+ self.IDTracker.add(id_plane, 'Plane')
+ self.IDTracker.add(id_ws, 'Workspace')
+ self.IDTracker.add(self.ur5, 'UR5')
+ try:
+ self.IDTracker.add(self.ee.body, 'Gripper.body')
+ except:
+ pass
+ try:
+ task_IDs = task.get_ID_tracker()
+ for i in task_IDs:
+ self.IDTracker.add(i, task_IDs[i])
+ except AttributeError:
+ pass
+ assert id_ws == 1, f'Workspace ID: {id_ws}'
+ if self.is_softbody_env():
+ self.ee.set_def_threshold(threshold=self.task.def_threshold)
+ self.ee.set_def_nb_anchors(nb_anchors=self.task.def_nb_anchors)
+ assert self.hz >= 480, f'Error, hz={self.hz} is too small!'
+ self.start()
+ if disable_render_load:
+ p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1)
+ (obs, _, _, _) = self.step()
+ return obs
+[INST] Execute action with specified primitive.
+For each episode (training, loading, etc.), this is normally called the first
+time from `env.reset()` above, with NO action, and returns an EMPTY observation.
+Then, it's called a SECOND time with an action that lacks a primitive (set to
+None, even though the key exists). But, this method will return an actual image
+observation that we pass to the policy. Finally, subsequent calls will have
+proper actions.
+
+(Sept 08) Added graceful exit functionality. This will let the code terminate
+early but with task.done=False as we didn't actually 'finish' -- just failed.
+Should use task.done=True to indicate any form of 'successful' dones (hitting
+time limit doesn't count).
+
+(Oct 09) Clarify documentation for the confusing first time step. I now see
+with ground truth agents, there is no 'second action lacking a primitive',
+because ground truth agents don't need images .. [/INST] def step(self, act=None):
+ if act and act['primitive']:
+ success = self.primitives[act['primitive']](**act['params'])
+ if (not success) or self.task.exit_gracefully:
+ _, reward_extras = self.task.reward()
+ info = self.info
+ reward_extras['task.done'] = False
+ if self.task.exit_gracefully:
+ reward_extras['exit_gracefully'] = True
+ self.task.exit_gracefully = False
+ if isinstance(self.task, tasks.names['cloth-flat-notarget']):
+ info['sampled_zone_pose'] = self.task.zone_pose
+ elif isinstance(self.task, tasks.names['bag-color-goal']):
+ info['bag_base_pos'] = self.task.bag_base_pos[0]
+ info['bag_base_orn'] = self.task.bag_base_orn[0]
+ info['bag_target_color'] = self.task.bag_colors[0]
+ info['extras'] = reward_extras
+ return {}, 0, True, info
+ start_t = time.time()
+ while not self.is_static():
+ if self.is_bag_env() and (time.time() - start_t > 2.0):
+ break
+ time.sleep(0.001)
+ reward, reward_extras = self.task.reward()
+ done = self.task.done()
+ info = self.info
+ reward_extras['task.done'] = done
+ info['extras'] = reward_extras
+ if isinstance(self.task, tasks.names['cloth-flat-notarget']):
+ info['sampled_zone_pose'] = self.task.zone_pose
+ elif isinstance(self.task, tasks.names['bag-color-goal']):
+ info['bag_base_pos'] = self.task.bag_base_pos[0]
+ info['bag_base_orn'] = self.task.bag_base_orn[0]
+ info['bag_target_color'] = self.task.bag_colors[0]
+ obs = {}
+ if act and 'camera_config' in act:
+ obs['color'], obs['depth'] = [], []
+ for config in act['camera_config']:
+ color, depth, _ = self.render(config)
+ obs['color'].append(color)
+ obs['depth'].append(depth)
+ return obs, reward, done, info
+[INST] Render RGB-D image with specified configuration.. [/INST] def render(self, config):
+ lookdir = np.array([0, 0, 1]).reshape(3, 1)
+ updir = np.array([0, -1, 0]).reshape(3, 1)
+ rotation = p.getMatrixFromQuaternion(config['rotation'])
+ rotm = np.array(rotation).reshape(3, 3)
+ lookdir = (rotm @ lookdir).reshape(-1)
+ updir = (rotm @ updir).reshape(-1)
+ lookat = config['position'] + lookdir
+ focal_length = config['intrinsics'][0]
+ znear, zfar = config['zrange']
+ viewm = p.computeViewMatrix(config['position'], lookat, updir)
+ fovh = (np.arctan((config['image_size'][0] /
+ 2) / focal_length) * 2 / np.pi) * 180
+ aspect_ratio = config['image_size'][1] / config['image_size'][0]
+ projm = p.computeProjectionMatrixFOV(fovh, aspect_ratio, znear, zfar)
+ _, _, color, depth, segm = p.getCameraImage(
+ width=config['image_size'][1],
+ height=config['image_size'][0],
+ viewMatrix=viewm,
+ projectionMatrix=projm,
+ shadow=1,
+ flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX,
+ renderer=p.ER_BULLET_HARDWARE_OPENGL)
+ color_image_size = (config['image_size'][0],
+ config['image_size'][1], 4)
+ color = np.array(color, dtype=np.uint8).reshape(color_image_size)
+ color = color[:, :, :3]
+ color_image_size = (color_image_size[0], color_image_size[1], 3)
+ if config['noise']:
+ color = np.int32(color)
+ color += np.int32(np.random.normal(0, 3, color_image_size))
+ color = np.uint8(np.clip(color, 0, 255))
+ depth_image_size = (config['image_size'][0], config['image_size'][1])
+ zbuffer = np.array(depth).reshape(depth_image_size)
+ depth = (zfar + znear - (2. * zbuffer - 1.) * (zfar - znear))
+ depth = (2. * znear * zfar) / depth
+ if config['noise']:
+ depth += np.random.normal(0, 0.003, depth_image_size)
+ segm = np.uint8(segm).reshape(depth_image_size)
+ return color, depth, segm
+[INST] Move UR5 to target joint configuration.. [/INST] def movej(self, targj, speed=0.01, t_lim=20):
+ t0 = time.time()
+ while (time.time() - t0) < t_lim:
+ currj = [p.getJointState(self.ur5, i)[0] for i in self.joints]
+ currj = np.array(currj)
+ diffj = targj - currj
+ if all(np.abs(diffj) < 1e-2):
+ return True
+ norm = np.linalg.norm(diffj)
+ v = diffj / norm if norm > 0 else 0
+ stepj = currj + v * speed
+ gains = np.ones(len(self.joints))
+ p.setJointMotorControlArray(
+ bodyIndex=self.ur5,
+ jointIndices=self.joints,
+ controlMode=p.POSITION_CONTROL,
+ targetPositions=stepj,
+ positionGains=gains)
+ time.sleep(0.001)
+ print('Warning: movej exceeded {} sec timeout. Skipping.'.format(t_lim))
+ return False
+[INST] Move UR5 to target end effector pose.. [/INST] def movep(self, pose, speed=0.01):
+ targj = self.solve_IK(pose)
+ return self.movej(targj, speed, self.t_lim)
+[INST] Execute pick and place primitive.
+Standard ravens tasks use the `delta` vector to lower the gripper
+until it makes contact with something. With deformables, however, we
+need to consider cases when the gripper could detect a rigid OR a
+soft body (cloth or bag); it should grip the first item it touches.
+This is handled in the Gripper class.
+
+Different deformable ravens tasks use slightly different parameters
+for better physics (and in some cases, faster simulation). Therefore,
+rather than make special cases here, those tasks will define their
+own action parameters, which we use here if they exist. Otherwise, we
+stick to defaults from standard ravens. Possible action parameters a
+task might adjust.. [/INST] def pick_place(self, pose0, pose1):
+ speed = 0.01
+ delta_z = -0.001
+ prepick_z = 0.3
+ postpick_z = 0.3
+ preplace_z = 0.3
+ pause_place = 0.0
+ final_z = 0.3
+ if hasattr(self.task, 'primitive_params'):
+ ts = self.task.task_stage
+ if 'prepick_z' in self.task.primitive_params[ts]:
+ prepick_z = self.task.primitive_params[ts]['prepick_z']
+ speed = self.task.primitive_params[ts]['speed']
+ delta_z = self.task.primitive_params[ts]['delta_z']
+ postpick_z = self.task.primitive_params[ts]['postpick_z']
+ preplace_z = self.task.primitive_params[ts]['preplace_z']
+ pause_place = self.task.primitive_params[ts]['pause_place']
+ def_IDs = []
+ if hasattr(self.task, 'def_IDs'):
+ def_IDs = self.task.def_IDs
+ success = True
+ pick_position = np.array(pose0[0])
+ pick_rotation = np.array(pose0[1])
+ prepick_position = pick_position.copy()
+ prepick_position[2] = prepick_z
+ prepick_pose = np.hstack((prepick_position, pick_rotation))
+ success &= self.movep(prepick_pose)
+ target_pose = prepick_pose.copy()
+ delta = np.array([0, 0, delta_z, 0, 0, 0, 0])
+ while not self.ee.detect_contact(def_IDs) and target_pose[2] > 0:
+ target_pose += delta
+ success &= self.movep(target_pose)
+ self.ee.activate(self.objects, def_IDs)
+ if self.is_softbody_env() or self.is_new_cable_env():
+ prepick_pose[2] = postpick_z
+ success &= self.movep(prepick_pose, speed=speed)
+ time.sleep(pause_place)
+ elif isinstance(self.task, tasks.names['cable']):
+ prepick_pose[2] = 0.03
+ success &= self.movep(prepick_pose, speed=0.001)
+ else:
+ prepick_pose[2] += pick_position[2]
+ success &= self.movep(prepick_pose)
+ pick_success = self.ee.check_grasp()
+ if pick_success:
+ place_position = np.array(pose1[0])
+ place_rotation = np.array(pose1[1])
+ preplace_position = place_position.copy()
+ preplace_position[2] = 0.3 + pick_position[2]
+ preplace_pose = np.hstack((preplace_position, place_rotation))
+ if self.is_softbody_env() or self.is_new_cable_env():
+ preplace_pose[2] = preplace_z
+ success &= self.movep(preplace_pose, speed=speed)
+ time.sleep(pause_place)
+ elif isinstance(self.task, tasks.names['cable']):
+ preplace_pose[2] = 0.03
+ success &= self.movep(preplace_pose, speed=0.001)
+ else:
+ success &= self.movep(preplace_pose)
+ target_pose = preplace_pose.copy()
+ while not self.ee.detect_contact(def_IDs) and target_pose[2] > 0:
+ target_pose += delta
+ success &= self.movep(target_pose)
+ self.ee.release()
+ preplace_pose[2] = final_z
+ success &= self.movep(preplace_pose)
+ else:
+ self.ee.release()
+ prepick_pose[2] = final_z
+ success &= self.movep(prepick_pose)
+ return success
+[INST] I want a way to track new cable-related stuff alone.. [/INST] def is_new_cable_env(self):
+ return (isinstance(self.task, tasks.names['cable-shape']) or
+ isinstance(self.task, tasks.names['cable-shape-notarget']) or
+ isinstance(self.task, tasks.names['cable-line-notarget']) or
+ isinstance(self.task, tasks.names['cable-ring']) or
+ isinstance(self.task, tasks.names['cable-ring-notarget']))
+[INST] Keep this updated when I adjust environment names.. [/INST] def is_cloth_env(self):
+ return (isinstance(self.task, tasks.names['cloth-flat']) or
+ isinstance(self.task, tasks.names['cloth-flat-notarget']) or
+ isinstance(self.task, tasks.names['cloth-cover']))
+[INST] Keep this updated when I adjust environment names.. [/INST] def is_bag_env(self):
+ return (isinstance(self.task, tasks.names['bag-alone-open']) or
+ isinstance(self.task, tasks.names['bag-items-easy']) or
+ isinstance(self.task, tasks.names['bag-items-hard']) or
+ isinstance(self.task, tasks.names['bag-color-goal']))
+[INST] Helper to handle common pick-place code for the oracle policy.
+We often have this patten: vertex positions and target positions in
+2D, and then potentially slightly overshoot the target. For example,
+with cloth it's helpful to do this since otherwise the physics will
+favor the cloth 'resetting' to its original state. Get the direction
+by creating vectors and then add to the new target. Then form tuples
+for the action params. Assumes no new rotations.. [/INST] def params_no_rots(self, vertex_pos, target_pos, overshoot):
+ p0 = (vertex_pos[0], vertex_pos[1], 0.001)
+ p1 = (target_pos[0], target_pos[1], 0.001)
+ direction = np.float32(p0) - np.float32(p1)
+ length = np.linalg.norm(direction)
+ direction = direction / length
+ new_p0 = np.float32(p1) + direction * (length - 0.00)
+ new_p1 = np.float32(p0) - direction * (length + overshoot)
+ p0, p1 = tuple(new_p0), tuple(new_p1)
+ params = {'pose0': (p0, (0,0,0,1)), 'pose1': (p1, (0,0,0,1))}
+ return params
+[INST] Check if the task is done AND has not failed.
+
+For multi-step tasks such as `cloth-cover-item`, one must need to
+accomplish all stages correctly.
+
+(08 Sept 2020): I added self.exit_gracefully, which will help us
+quickly exit if the demo has failed, but this is not handled here.
+
+(14 Sept 2020) We want self.done to return True if we should consider
+this episode a 'success'. If not, then if we return done=True with
+the understanding that it's a failure, we need to filter this episode
+elsewhere. So far the only tasks where this extra consideration has
+to be applied are bag-items-easy and bag-items-hard. See `main.py`
+and the `ignore_this_demo()` method.. [/INST] def done(self):
+ zone_done, goal_done, cable_done, cov_done, bag_done = \
+ False, False, False, False, False
+ if self.metric == 'zone':
+ zone_done = self.total_rewards == 1
+ elif self.metric == 'cable-target':
+ cable_done = self.total_rewards == 1
+ elif self.metric == 'cable-ring':
+ fraction = self.total_rewards / self.circle_area
+ if fraction >= self.area_thresh:
+ print(f'Frac: {fraction:0.4f} exceeds: {self.area_thresh:0.4f}')
+ cable_done = True
+ elif self.metric == 'cloth-coverage':
+ if self.total_rewards >= self.coverage_threshold:
+ cov_done = True
+ elif self.metric == 'cloth-cover-item':
+ cov_done = self.total_rewards == 1
+ elif self.metric == 'bag-alone-open':
+ fraction = self.total_rewards / self.circle_area
+ if fraction >= self.area_thresh:
+ print(f'Frac: {fraction:0.4f} exceeds: {self.area_thresh:0.4f}')
+ bag_done = True
+ elif self.metric == 'bag-items':
+ bag_done = self.total_rewards > 0
+ elif self.metric == 'bag-color-goal':
+ bag_done = self.total_rewards >= 0.95
+ if hasattr(self, 'goal'):
+ goal_done = len(self.goal['steps']) == 0
+ return zone_done or cable_done or cov_done or bag_done or goal_done
+[INST] Read template file and replace string keys.
+The replace[field] needs to be a tuple. Use (item,) for a single
+item, but you can combine keys with NAMEX where X starts from 0.. [/INST] def fill_template(self, template, replace):
+ filepath = os.path.dirname(os.path.abspath(__file__))
+ template = os.path.join(filepath, '..', template)
+ with open(template, 'r') as file:
+ fdata = file.read()
+ for field in replace:
+ for i in range(len(replace[field])):
+ fdata = fdata.replace(f'{field}{i}', str(replace[field][i]))
+ alphabet = string.ascii_lowercase + string.digits
+ rname = ''.join(random.choices(alphabet, k=16))
+ fname = f'{template}.{rname}'
+ with open(fname, 'w') as file:
+ file.write(fdata)
+ return fname
+[INST] Get random collision-free pose in workspace bounds for object.
+For tasks like sweeping to a zone target, generates a target zone at
+random, then later generates items to be swept in it. The second step
+requires free space to avoid being within the zone.
+
+The `mask` defines a distribution over pixels in the 320x160 top-down
+image, where 0=black means don't sample those pixels. For the
+remaining pixels (1=white), sample uniformly from those. It should
+correctly detect if we have a bag [but may require a low enough z
+coord]. If you increase `object_size`, then that reduces the
+available sampling area.. [/INST] def random_pose(self, env, object_size, hard_code=False):
+ plane_id = 1
+ max_size = np.sqrt(object_size[0]**2 + object_size[1]**2)
+ erode_size = int(np.round(max_size / self.pixel_size))
+ colormap, heightmap, object_mask = self.get_object_masks(env)
+ mask = np.uint8(object_mask == plane_id)
+ mask[0, :], mask[:, 0], mask[-1, :], mask[:, -1] = 0, 0, 0, 0
+ mask = cv2.erode(mask, np.ones((erode_size, erode_size), np.uint8))
+ if np.sum(mask) == 0:
+ print('Warning! Sum of mask is zero in random_pose().')
+ return
+ pixel = utils.sample_distribution(np.float32(mask))
+ position = utils.pixel_to_position(
+ pixel, heightmap, self.bounds, self.pixel_size)
+ position = (position[0], position[1], object_size[2] / 2)
+ rtheta = np.random.rand() * 2 * np.pi
+ rotation = p.getQuaternionFromEuler((0, 0, rtheta))
+ if hard_code:
+ position = (0.0, 0.0, 0.0)
+ rotation = p.getQuaternionFromEuler((0,0,0))
+ return position, rotation
+[INST] apply a rigid body transformation on `position` by `pose`.
+The `pose` has the rotation matrix and translation vector. Apply the
+rotation matrix on `position` and then translate. Example: used in
+cables env to sequentially translate each bead in the cable.
+
+Returns a *position*, not a pose (no orientation).. [/INST] def apply(self, pose, position):
+ position = np.float32(position)
+ position_shape = position.shape
+ position = np.float32(position).reshape(3, -1)
+ rotation = np.float32(p.getMatrixFromQuaternion(pose[1])).reshape(3, 3)
+ translation = np.float32(pose[0]).reshape(3, 1)
+ position = rotation @ position + translation
+ return tuple(position.reshape(position_shape))
+[INST] Given either object OR action pose info, return stuff to put in GT observation.
+Note: only called from within this class and 2-step case via subclassing.
+
+During training, there is normally data augmentation applied, so t_worldaug_world
+is NOT None, and augmentation is applied as if the 'image' were adjusted. However,
+for actions, we preserve theta so `object_quat_xyzw` does not change, and we query
+index=2 for the z rotation. For most deformables tasks, we did not use rotations,
+hence theta=0.
+
+Get t_world_object which is a 4x4 homogeneous matrix, then stick the position there (why?).
+Then we do the matrix multiplication with '@' to get the homogeneous matrix representation
+of the object after augmentation is applied.
+
+Augmentation for cloth vertices. For now I'm following what they do for positions only
+(ignoring orientation), and returning updated (x,y,z) where the last value is the z
+(height) and not a rotation (doesn't make sense with vertrics anyway). Augmentations
+are all in SE(2), so the z coordinates should not change. Indeed that's what I see after
+extracting positions from `t_worldaug_object`.
+
+And yes, with vertices, they are flattened in the same way regardless of whether data
+augmentation is used, so every third value represents a vertex height.. [/INST] def extract_x_y_theta(self, object_info, t_worldaug_world=None, preserve_theta=False, softbody=False):
+ if (self.task in TASKS_SOFT) and (len(object_info) == 2) and softbody:
+ nb_vertices, vert_pos_l = object_info
+ assert nb_vertices == 100, f'We should be using 100 vertices but have: {nb_vertices}'
+ assert nb_vertices == len(vert_pos_l), f'{nb_vertices}, {len(vert_pos_l)}'
+ vert_pos_np = np.array(vert_pos_l)
+ if t_worldaug_world is not None:
+ augmented_vertices = []
+ for i in range(vert_pos_np.shape[0]):
+ vertex_position = vert_pos_np[i,:]
+ t_world_object = transformations.quaternion_matrix( (1,0,0,0) )
+ t_world_object[0:3, 3] = np.array(vertex_position)
+ t_worldaug_object = t_worldaug_world @ t_world_object
+ new_v_position = t_worldaug_object[0:3, 3]
+ augmented_vertices.append(new_v_position)
+ augmented_flattened = np.concatenate(augmented_vertices).flatten().astype(np.float32)
+ return augmented_flattened
+ else:
+ vert_flattened = vert_pos_np.flatten().astype(np.float32)
+ return vert_flattened
+ object_position = object_info[0]
+ object_quat_xyzw = object_info[1]
+ if t_worldaug_world is not None:
+ object_quat_wxyz = (object_quat_xyzw[3], object_quat_xyzw[0], object_quat_xyzw[1], object_quat_xyzw[2])
+ t_world_object = transformations.quaternion_matrix(object_quat_wxyz)
+ t_world_object[0:3, 3] = np.array(object_position)
+ assert t_world_object.shape == (4,4), f'shape is not 4x4: {t_world_object.shape}'
+ assert t_worldaug_world.shape == (4,4), f'shape is not 4x4: {t_worldaug_world.shape}'
+ t_worldaug_object = t_worldaug_world @ t_world_object
+ object_quat_wxyz = transformations.quaternion_from_matrix(t_worldaug_object)
+ if not preserve_theta:
+ object_quat_xyzw = (object_quat_wxyz[1], object_quat_wxyz[2], object_quat_wxyz[3], object_quat_wxyz[0])
+ object_position = t_worldaug_object[0:3, 3]
+ object_xy = object_position[0:2]
+ object_theta = -np.float32(utils.get_rot_from_pybullet_quaternion(object_quat_xyzw)[2]) / self.THETA_SCALE
+ return np.hstack((object_xy, object_theta)).astype(np.float32), object_position, object_quat_xyzw
+[INST] from info dict of IDs, create the observation for GT models.
+Assumes `info` consists of just PyBullet object IDs. Creates a numpy array
+from combining the `object_x_y_theta` from all IDs, and potentially add more
+info based on if using box dimensions or colors; see `__init__()` above.
+
+For soft body tasks, we should have data generated so that info[cloth_id] or
+info[bag_id] contains the 3D vertex position of each point. For now we only
+test with cloth since the bag ones are honestly probably better suited with
+just the beads (but will have to check). Fortunately, all the cloth tasks as
+of late October 2020 will use ID=5 as the soft body andit's also easy to check.
+
+(14 Oct 2020) adding `goal` as argument for goal conditioning. Use goal['info'].
+We're going to stack the two together into one input.
+(14 Oct 2020) adding support for cloth. We don't use `pos` and `quat`, just use
+the returned object_x_y_theta.. [/INST] def info_to_gt_obs(self, info, t_worldaug_world=None, goal=None):
+ info = self.remove_nonint_keys(info)
+ if goal is not None:
+ g_info = self.remove_nonint_keys(goal['info'])
+ else:
+ g_info = {}
+ observation_vector = []
+ object_keys = sorted(info.keys())
+ for object_key in object_keys:
+ if (self.task in TASKS_SOFT and len(info[object_key]) == 2):
+ object_x_y_theta = self.extract_x_y_theta(info[object_key], t_worldaug_world, softbody=True)
+ else:
+ object_x_y_theta, pos, quat = self.extract_x_y_theta(info[object_key], t_worldaug_world)
+ observation_vector.append(object_x_y_theta)
+ if self.use_box_dimensions:
+ observation_vector.append(self.extract_box_dimensions(info[object_key]))
+ if self.use_colors:
+ if self.task in ['bag-color-goal']:
+ if object_key in [4, 38]:
+ observation_vector.append(self.extract_color(info[object_key]))
+ else:
+ observation_vector.append(self.extract_color(info[object_key]))
+ for object_key in sorted(g_info.keys()):
+ if (self.task in TASKS_SOFT and len(g_info[object_key]) == 2):
+ object_x_y_theta = self.extract_x_y_theta(g_info[object_key], t_worldaug_world, softbody=True)
+ else:
+ object_x_y_theta, pos, quat = self.extract_x_y_theta(g_info[object_key], t_worldaug_world)
+ observation_vector.append(object_x_y_theta)
+ if self.use_box_dimensions:
+ observation_vector.append(self.extract_box_dimensions(g_info[object_key]))
+ if self.use_colors:
+ if self.task in ['bag-color-goal']:
+ if object_key in [4, 38]:
+ observation_vector.append(self.extract_color(g_info[object_key]))
+ else:
+ observation_vector.append(self.extract_color(g_info[object_key]))
+ if (self.task in TASKS_SOFT):
+ observation_vector = np.concatenate(observation_vector).reshape(-1).astype(np.float32)
+ else:
+ observation_vector = np.array(observation_vector).reshape(-1).astype(np.float32)
+ if self.max_obs_vector_length != 0:
+ observation_vector = np.pad(observation_vector,
+ [0, self.max_obs_vector_length-len(observation_vector)])
+ return observation_vector
+[INST] similarly, from action, create the appropriate ground truth action.
+This may involve a transformation if doing data augmentation.
+Comment from Andy/Pete: dont update theta due to suction invariance to theta. [/INST] def act_to_gt_act(self, act, t_worldaug_world=None, transform_params=None):
+ pick_se2, _, _ = self.extract_x_y_theta(act['params']['pose0'], t_worldaug_world, preserve_theta=True)
+ place_se2, _, _ = self.extract_x_y_theta(act['params']['pose1'], t_worldaug_world, preserve_theta=True)
+ return np.hstack((pick_se2, place_se2)).astype(np.float32)
+[INST] Initialize self.model, including normalization parameters.. [/INST] def init_model(self, dataset):
+ self.set_max_obs_vector_length(dataset)
+ if self.goal_conditioned:
+ _, _, info, goal = dataset.random_sample(goal_images=True)
+ obs_vector = self.info_to_gt_obs(info, goal=goal)
+ else:
+ _, _, info = dataset.random_sample()
+ obs_vector = self.info_to_gt_obs(info)
+ obs_dim = obs_vector.shape[0]
+ act_dim = 6
+ if self.six_dof:
+ act_dim = 9
+ self.model = MlpModel(self.BATCH_SIZE, obs_dim, act_dim, 'relu', self.USE_MDN, dropout=0.1)
+ sampled_gt_obs = []
+ num_samples = 1000
+ for _ in range(num_samples):
+ t_worldaug_world, _ = self.get_augmentation_transform()
+ if self.goal_conditioned:
+ _, _, info, goal = dataset.random_sample(goal_images=True)
+ sampled_gt_obs.append(self.info_to_gt_obs(info, t_worldaug_world, goal=goal))
+ else:
+ _, _, info = dataset.random_sample()
+ sampled_gt_obs.append(self.info_to_gt_obs(info, t_worldaug_world))
+ sampled_gt_obs = np.array(sampled_gt_obs)
+ obs_train_parameters = dict()
+ obs_train_parameters['mean'] = sampled_gt_obs.mean(axis=(0)).astype(np.float32)
+ obs_train_parameters['std'] = sampled_gt_obs.std(axis=(0)).astype(np.float32)
+ self.model.set_normalization_parameters(obs_train_parameters)
+ self.obs_train_mean = obs_train_parameters['mean']
+ self.obs_train_std = obs_train_parameters['std']
+ print('Done initializing self.model for ground truth agent.')
+[INST] Train on dataset for a specific number of iterations.
+Daniel: not testing with validation, argument copied over from ravens. Naively,
+we can train with MSE, but better to use a mixture model (MDN) since the output
+should be multi-modal; could be several pick points, and several placing points
+wrt those pick points.
+
+Also, notice how one iteration involves taking a batch of data, whereas for
+Transporters, we only used one image per 'iteration'. This means that each gt-state
+model consumes 128x more data points during training.
+
+there is a special case for num_iter=0 which I use for the "zero iteration" baseline.
+Without this, it is impossible to save the model because it hasn't been built yet. However,
+even with this I get the warning.
+
+Skipping full serialization of Keras layer , because it is not built.
+
+but the loading seems to be OK, so I'm not sure what could be the issue, and given that
+it's for the 0 iteration case, it's unlikely to be worth fully investigating for now.. [/INST] def train(self, dataset, num_iter, writer, validation_dataset=None):
+ if self.model is None:
+ self.init_model(dataset)
+ if self.USE_MDN:
+ loss_criterion = mdn_utils.mdn_loss
+ else:
+ loss_criterion = tf.keras.losses.MeanSquaredError()
+ @tf.function
+ def train_step(model, batch_obs, batch_act, loss_criterion):
+ with tf.GradientTape() as tape:
+ prediction = model(batch_obs)
+ loss = loss_criterion(batch_act, prediction)
+ grad = tape.gradient(loss, model.trainable_variables)
+ self.optim.apply_gradients(zip(grad, model.trainable_variables))
+ return loss
+ print_rate = 50
+ VALIDATION_RATE = 1000
+ if num_iter == 0:
+ batch_obs, batch_act, obs, act, info = self.get_data_batch(dataset)
+ print('Doing a single forward pass to enable us to save a snapshot for the num_iter == 0 case.')
+ prediction = self.model(batch_obs)
+ for i in range(num_iter):
+ start = time.time()
+ batch_obs, batch_act, obs, act, info = self.get_data_batch(dataset)
+ self.metric.reset_states()
+ loss = train_step(self.model, batch_obs, batch_act, loss_criterion)
+ self.metric(loss)
+ with writer.as_default():
+ tf.summary.scalar('gt_state_loss', self.metric.result(), step=self.total_iter+i)
+ if i % print_rate == 0:
+ loss = np.float32(loss)
+ print(f'Train Iter: {self.total_iter + i} Loss: {loss:.4f} Iter time:', time.time() - start)
+ if ((self.total_iter + i) % VALIDATION_RATE == 0) and (validation_dataset is not None):
+ print("Validating!")
+ tf.keras.backend.set_learning_phase(0)
+ self.val_metric.reset_states()
+ batch_obs, batch_act, _, _, _ = self.get_data_batch(validation_dataset)
+ prediction = self.model(batch_obs)
+ loss = loss_criterion(batch_act, prediction)
+ self.val_metric(loss)
+ with writer.as_default():
+ tf.summary.scalar('validation_gt_state_loss', self.val_metric.result(),
+ step=self.total_iter+i)
+ tf.keras.backend.set_learning_phase(1)
+ self.total_iter += num_iter
+ self.save()
+[INST] Run inference and return best action.. [/INST] def act(self, obs, info, goal=None):
+ act = {'camera_config': self.camera_config, 'primitive': None}
+ if self.goal_conditioned:
+ gt_obs = self.info_to_gt_obs(info, goal=goal)
+ gt_act_center = self.info_to_gt_obs(info, goal=goal)
+ else:
+ gt_obs = self.info_to_gt_obs(info)
+ gt_act_center = self.info_to_gt_obs(info)
+ prediction = self.model(gt_obs[None, ...])
+ if self.USE_MDN:
+ mdn_prediction = prediction
+ pi, mu, var = mdn_prediction
+ prediction = mdn_utils.sample_from_pdf(pi, mu, var)
+ prediction = prediction[:, 0, :]
+ prediction = prediction[0]
+ p0_position = np.hstack((prediction[0:2], 0.02))
+ p0_pred_rot = 0.0 if self.one_rot_inf else -prediction[2]*self.THETA_SCALE
+ p0_rotation = utils.get_pybullet_quaternion_from_rot((0, 0, p0_pred_rot))
+ p1_position = np.hstack((prediction[3:5], 0.02))
+ p1_pred_rot = 0.0 if self.one_rot_inf else -prediction[5]*self.THETA_SCALE
+ p1_rotation = utils.get_pybullet_quaternion_from_rot((0, 0, p1_pred_rot))
+ act['primitive'] = 'pick_place'
+ if self.task == 'sweeping':
+ act['primitive'] = 'sweep'
+ elif self.task == 'pushing':
+ act['primitive'] = 'push'
+ params = {'pose0': (p0_position, p0_rotation),
+ 'pose1': (p1_position, p1_rotation)}
+ act['params'] = params
+ if self.task in ['bag-items-easy', 'bag-items-hard', 'bag-color-goal']:
+ self._determine_task_stage(p0_position, p1_position)
+ return act
+[INST] Add a cable like Andy does it in his cable environment.
+
+When iterating through the number of parts, ensure that the given
+cable is _separate_ from prior cables, in case there are more than
+one. ALL beads are put in the `env.objects` list.
+
+The zone_range is used because we need the cables to start outside of
+the zone. However, we should check if sampling multiple cables will
+work; there might not be space to sample a zone and two cables.
+
+Parameters. [/INST] def add_cable(self, env, size_range, info={}, cable_idx=0,
+ direction='z', max_force=100):
+ num_parts = self.num_parts
+ radius = self.radius
+ length = self.length
+ color = self.colors[cable_idx] + [1]
+ color_end = U.COLORS['yellow'] + [1]
+ distance = length / num_parts
+ position, _ = self.random_pose(env, size_range)
+ position = np.float32(position)
+ part_shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=[radius]*3)
+ part_visual = p.createVisualShape(p.GEOM_SPHERE, radius=radius*1.5)
+ for i in range(num_parts):
+ if direction == 'x':
+ position[0] += distance
+ parent_frame = (distance, 0, 0)
+ elif direction == 'y':
+ position[1] += distance
+ parent_frame = (0, distance, 0)
+ else:
+ position[2] += distance
+ parent_frame = (0, 0, distance)
+ part_id = p.createMultiBody(0.1, part_shape, part_visual,
+ basePosition=position)
+ if i > 0:
+ constraint_id = p.createConstraint(
+ parentBodyUniqueId=env.objects[-1],
+ parentLinkIndex=-1,
+ childBodyUniqueId=part_id,
+ childLinkIndex=-1,
+ jointType=p.JOINT_POINT2POINT,
+ jointAxis=(0, 0, 0),
+ parentFramePosition=parent_frame,
+ childFramePosition=(0, 0, 0))
+ p.changeConstraint(constraint_id, maxForce=max_force)
+ if (i > 0) and (i < num_parts - 1):
+ p.changeVisualShape(part_id, -1, rgbaColor=color)
+ elif i == num_parts - 1:
+ p.changeVisualShape(part_id, -1, rgbaColor=color_end)
+ self.cable_bead_IDs.append(part_id)
+ self._IDs[part_id] = f'cable_part_{str(part_id).zfill(2)}'
+ env.objects.append(part_id)
+ self.object_points[part_id] = np.float32((0, 0, 0)).reshape(3, 1)
+ if self._name == 'cable-shape' or self._name == 'cable-shape-notarget' or \
+ self._name == 'cable-line-notarget':
+ lx = info['lengthx']
+ ly = info['lengthy']
+ r = radius
+ if info['nb_sides'] == 1:
+ x_coord = lx / 2 - (distance * i)
+ y_coord = 0
+ true_position = (x_coord - r, y_coord, 0)
+ elif info['nb_sides'] == 2:
+ if i < info['cutoff']:
+ x_coord = lx / 2 - (distance * i)
+ y_coord = -ly / 2
+ true_position = (x_coord - r, y_coord, 0)
+ else:
+ x_coord = -lx / 2
+ y_coord = -ly / 2 + (distance * (i - info['cutoff']))
+ true_position = (x_coord, y_coord + r, 0)
+ elif info['nb_sides'] == 3:
+ v1 = (self.num_parts - info['cutoff']) / 2
+ v2 = self.num_parts - v1
+ if i < v1:
+ x_coord = lx / 2
+ y_coord = ly / 2 - (distance * i)
+ true_position = (x_coord, y_coord - r, 0)
+ elif i < v2:
+ x_coord = lx / 2 - (distance * (i - v1))
+ y_coord = -ly / 2
+ true_position = (x_coord - r, y_coord, 0)
+ else:
+ x_coord = -lx / 2
+ y_coord = -ly / 2 + (distance * (i - v2))
+ true_position = (x_coord, y_coord + r, 0)
+ elif info['nb_sides'] == 4:
+ v1 = info['cutoff'] / 2
+ v2 = num_parts / 2
+ v3 = (num_parts + info['cutoff']) / 2
+ if i < v1:
+ x_coord = lx / 2 - (distance * i)
+ y_coord = -ly / 2
+ true_position = (x_coord, y_coord, 0)
+ elif i < v2:
+ x_coord = -lx / 2
+ y_coord = -ly / 2 + (distance * (i - v1))
+ true_position = (x_coord, y_coord, 0)
+ elif i < v3:
+ x_coord = -lx / 2 + (distance * (i - v2))
+ y_coord = ly / 2
+ true_position = (x_coord, y_coord, 0)
+ else:
+ x_coord = lx / 2
+ y_coord = ly / 2 - (distance * (i - v3))
+ true_position = (x_coord, y_coord, 0)
+ true_position = self.apply(self.zone_pose, true_position)
+ self.goal['places'][part_id] = (true_position, (0, 0, 0, 1.))
+ symmetry = 0
+ self.goal['steps'][0][part_id] = (symmetry, [part_id])
+ if self.target_debug_markers:
+ sq_pose = ((true_position[0], true_position[1], 0.002), (0,0,0,1))
+ sq_template = 'assets/square/square-template-allsides-blue.urdf'
+ replace = {'DIM': (0.003,), 'HALF': (0.003 / 2,)}
+ urdf = self.fill_template(sq_template, replace)
+ env.add_object(urdf, sq_pose, fixed=True)
+ os.remove(urdf)
+ else:
+ print(f'Warning, env {self._name} will not have goals.')
+[INST] Add a cable, but make it connected at both ends to form a ring.
+For consistency, add each `part_id` to various information tracking
+lists and dictionaries .. [/INST] def add_cable_ring(self, env, info={}, cable_idx=0):
+ def rad_to_deg(rad):
+ return (rad * 180.0) / np.pi
+ def get_discretized_rotations(num_rotations):
+ theta = i * (2 * np.pi) / num_rotations
+ return (theta, rad_to_deg(theta))
+ num_parts = self.num_parts
+ radius = self.radius
+ color = self.colors[cable_idx] + [1]
+ ring_radius = info['ring_radius']
+ beads = []
+ bead_positions_l = []
+ position = np.float32(info['center_position'])
+ part_shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=[radius]*3)
+ part_visual = p.createVisualShape(p.GEOM_SPHERE, radius=radius*1.5)
+ for i in range(num_parts):
+ angle_rad, _ = get_discretized_rotations(num_parts)
+ px = ring_radius * np.cos(angle_rad)
+ py = ring_radius * np.sin(angle_rad)
+ bead_position = np.float32([position[0] + px, position[1] + py, 0.01])
+ part_id = p.createMultiBody(0.1, part_shape, part_visual,
+ basePosition=bead_position)
+ p.changeVisualShape(part_id, -1, rgbaColor=color)
+ if i > 0:
+ parent_frame = bead_position - bead_positions_l[-1]
+ constraint_id = p.createConstraint(
+ parentBodyUniqueId=beads[-1],
+ parentLinkIndex=-1,
+ childBodyUniqueId=part_id,
+ childLinkIndex=-1,
+ jointType=p.JOINT_POINT2POINT,
+ jointAxis=(0, 0, 0),
+ parentFramePosition=parent_frame,
+ childFramePosition=(0, 0, 0))
+ p.changeConstraint(constraint_id, maxForce=100)
+ if i == num_parts - 1:
+ parent_frame = bead_positions_l[0] - bead_position
+ constraint_id = p.createConstraint(
+ parentBodyUniqueId=part_id,
+ parentLinkIndex=-1,
+ childBodyUniqueId=beads[0],
+ childLinkIndex=-1,
+ jointType=p.JOINT_POINT2POINT,
+ jointAxis=(0, 0, 0),
+ parentFramePosition=parent_frame,
+ childFramePosition=(0, 0, 0))
+ p.changeConstraint(constraint_id, maxForce=100)
+ beads.append(part_id)
+ bead_positions_l.append(bead_position)
+ self.cable_bead_IDs.append(part_id)
+ self._IDs[part_id] = f'cable_part_{str(part_id).zfill(2)}'
+ env.objects.append(part_id)
+ self.object_points[part_id] = np.float32((0, 0, 0)).reshape(3, 1)
+ if 'cable-ring' in self._name:
+ true_position = (bead_position[0], bead_position[1], 0)
+ self.goal['places'][part_id] = (true_position, (0, 0, 0, 1.))
+ symmetry = 0
+ self.goal['steps'][0][part_id] = (symmetry, [part_id])
+ if info['targets_visible']:
+ sq_pose = ((true_position[0], true_position[1], 0.002), (0,0,0,1))
+ sq_template = 'assets/square/square-template-allsides-green.urdf'
+ replace = {'DIM': (0.003,), 'HALF': (0.003 / 2,)}
+ urdf = self.fill_template(sq_template, replace)
+ env.add_object(urdf, sq_pose, fixed=True)
+ os.remove(urdf)
+ else:
+ print(f'Warning, env {self._name} will not have goals.')
+[INST] Generate randomly shaped box, from aligning env.. [/INST] def add_random_box(self, env):
+ box_size = self.random_size(0.05, 0.15, 0.05, 0.15, 0.01, 0.06)
+ box_pose = self.random_pose(env, box_size)
+ box_template = 'assets/box/box-template.urdf'
+ box_urdf = self.fill_template(box_template, {'DIM': box_size})
+ box_id = env.add_object(box_urdf, box_pose)
+ os.remove(box_urdf)
+ self.color_random_brown(box_id)
+ self.object_points[box_id] = np.float32((0, 0, 0)).reshape(3, 1)
+ self._IDs[box_id] = 'random_box'
+ return box_id
+[INST] Only if we are using the cable-ring environment
+So far I think using 0.8 or higher might be too hard because moving
+the ring to the target can cause other areas to decrease.. [/INST] def area_thresh(self):
+ return 0.75
+[INST] Reset to start an episode.
+If generating training data for goal-conditioned Transporters with
+`main.py` or goal images using `generate_goals.py`, then call the
+superclass. The code already puts the bead poses inside `info`. For
+this env it's IDs 4 through 27 (for 24 beads) but I scale it based on
+num_parts in case we change this value.
+
+If loading using `load.py` (detect with self.goal_cond_testing) then
+must make targets based on loaded info. However, we still have to
+randomly create the cable, so the easiest way might be to make the
+cable as usual, and then just override the 'places' key later.. [/INST] def reset(self, env, last_info=None):
+ super().reset(env)
+ if self.goal_cond_testing:
+ assert last_info is not None
+ self.goal['places'] = self._get_goal_info(last_info)
+[INST] Used to determine the goal given the last `info` dict.. [/INST] def _get_goal_info(self, last_info):
+ start_ID = 4
+ end_ID = start_ID + self.num_parts
+ places = {}
+ for ID in range(start_ID, end_ID):
+ assert ID in last_info, f'something went wrong with ID={ID}'
+ position, _, _ = last_info[ID]
+ places[ID] = (position, (0, 0, 0, 1.))
+ return places
+[INST] Apply rigid transformation to 3D pointcloud.. [/INST] def transform_pointcloud(points, transform):
+ padding = ((0, 0), (0, 0), (0, 1))
+ homogen_points = np.pad(points.copy(), padding,
+ 'constant', constant_values=1)
+ for i in range(3):
+ points[..., i] = np.sum(transform[i, :] * homogen_points, axis=-1)
+ return points
+[INST] Reconstruct top-down heightmap views from multiple 3D pointclouds.
+The color and depth are np.arrays or lists, where the leading dimension
+or list wraps around differnet viewpoints. So, if the np.array shape is
+(3,480,640,3), then the leading '3' denotes the number of camera views.
+
+documentation.. [/INST] def reconstruct_heightmaps(color, depth, configs, bounds, pixel_size):
+ heightmaps, colormaps = [], []
+ for color, depth, config in zip(color, depth, configs):
+ intrinsics = np.array(config['intrinsics']).reshape(3, 3)
+ xyz = get_pointcloud(depth, intrinsics)
+ position = np.array(config['position']).reshape(3, 1)
+ rotation = p.getMatrixFromQuaternion(config['rotation'])
+ rotation = np.array(rotation).reshape(3, 3)
+ transform = np.eye(4)
+ transform[:3, :] = np.hstack((rotation, position))
+ xyz = transform_pointcloud(xyz, transform)
+ heightmap, colormap = get_heightmap(xyz, color, bounds, pixel_size)
+ heightmaps.append(heightmap)
+ colormaps.append(colormap)
+ return heightmaps, colormaps
+[INST] Convert from pixel location on heightmap to 3D position.. [/INST] def pixel_to_position(pixel, height, bounds, pixel_size, skip_height=False):
+ u, v = pixel
+ x = bounds[0, 0] + v * pixel_size
+ y = bounds[1, 0] + u * pixel_size
+ if not skip_height:
+ z = bounds[2, 0] + height[u, v]
+ else:
+ z = 0.0
+ return (x, y, z)
+[INST] Convert from 3D position to pixel location on heightmap.. [/INST] def position_to_pixel(position, bounds, pixel_size):
+ u = int(np.round((position[1] - bounds[1, 0]) / pixel_size))
+ v = int(np.round((position[0] - bounds[0, 0]) / pixel_size))
+ return (u, v)
+[INST] Vectorized version of unproject(), for N points.. [/INST] def unproject_vectorized(uv_coordinates: np.ndarray, depth_values: np.ndarray,
+ intrinsic: np.ndarray,
+ distortion: np.ndarray) -> np.ndarray:
+ cam_mtx = intrinsic
+ cam_dist = np.array(distortion)
+ points_undistorted = cv2.undistortPoints(
+ uv_coordinates.reshape((-1, 1, 2)), cam_mtx, cam_dist).squeeze()
+ x = points_undistorted[:, 0] * depth_values
+ y = points_undistorted[:, 1] * depth_values
+ xyz = np.vstack((x, y, depth_values)).T
+ return xyz
+[INST] Unproject depth image into 3D point cloud, using calibration.. [/INST] def unproject_depth_vectorized(im_depth: np.ndarray, depth_dist: np.ndarray,
+ camera_mtx: np.ndarray,
+ camera_dist: np.ndarray) -> np.ndarray:
+ h, w = im_depth.shape
+ u_map, v_map = np.meshgrid(np.linspace(0, w - 1, w), np.linspace(0, h - 1, h))
+ adjusted_depth = depth_dist[0] + im_depth * depth_dist[1]
+ uv_coordinates = np.stack((u_map.reshape(-1), v_map.reshape(-1)), axis=-1)
+ return unproject_vectorized(uv_coordinates, adjusted_depth.reshape(-1),
+ camera_mtx, camera_dist)
+[INST] Sample data point from a custom distribution.. [/INST] def sample_distribution(prob, n_samples=1):
+ flat_prob = np.ndarray.flatten(prob) / np.sum(prob)
+ rand_ind = np.random.choice(
+ np.arange(len(flat_prob)), n_samples, p=flat_prob, replace=False)
+ rand_ind_coords = np.array(np.unravel_index(rand_ind, prob.shape)).T
+ return np.int32(rand_ind_coords.squeeze())
+[INST] Valid transform only if pixel locations are still in FoV after transform.. [/INST] def check_transform(image, pixel, transform):
+ new_pixel = np.flip(np.int32(np.round(np.dot(transform, np.float32(
+ [pixel[1], pixel[0], 1.]).reshape(3, 1))))[:2].squeeze())
+ valid = np.all(new_pixel >= 0) and new_pixel[0] < image.shape[
+ 0] and new_pixel[1] < image.shape[1]
+ return valid, new_pixel
+[INST] Train on dataset for a specific number of iterations.. [/INST] def train(self, dataset, num_iter, writer):
+ @tf.function
+ def pick_train_step(model, optim, in_tensor, yxtheta, loss_criterion):
+ with tf.GradientTape() as tape:
+ output = model(in_tensor)
+ loss = loss_criterion(yxtheta, output)
+ grad = tape.gradient(loss, model.trainable_variables)
+ optim.apply_gradients(
+ zip(grad, model.trainable_variables))
+ return loss
+ @tf.function
+ def place_train_step(model, optim, in_tensor, yxtheta, loss_criterion):
+ with tf.GradientTape() as tape:
+ output = model(in_tensor)
+ loss = loss_criterion(yxtheta, output)
+ grad = tape.gradient(loss, model.trainable_variables)
+ optim.apply_gradients(
+ zip(grad, model.trainable_variables))
+ return loss
+ for i in range(num_iter):
+ start = time.time()
+ input_images, p0s, p0_thetas = [], [], []
+ p1s, p1_thetas = [], []
+ for _ in range(self.batch_size):
+ obs, act, info = dataset.random_sample()
+ configs = act['camera_config']
+ colormap, heightmap = self.get_heightmap(obs, configs)
+ l_object = info[4]
+ pose0 = l_object[0], l_object[1]
+ l_target = info[5]
+ pose1 = l_target[0], l_target[1]
+ p0_position, p0_rotation = pose0[0], pose0[1]
+ p0 = utils.position_to_pixel(p0_position, self.bounds, self.pixel_size)
+ p0_theta = -np.float32(p.getEulerFromQuaternion(p0_rotation)[2])
+ p1_position, p1_rotation = pose1[0], pose1[1]
+ p1 = utils.position_to_pixel(p1_position, self.bounds, self.pixel_size)
+ p1_theta = -np.float32(p.getEulerFromQuaternion(p1_rotation)[2])
+ p1_xytheta = np.array([p1_position[0], p1_position[1], p1_theta])
+ input_image = colormap
+ input_images.append(input_image)
+ p0s.append(p0)
+ p0_thetas.append(p0_theta)
+ p1s.append(p1)
+ p1_thetas.append(p1_theta)
+ input_image = np.array(input_images)
+ p0 = np.array(p0s)
+ p0_theta = np.array(p0_thetas)
+ p1 = np.array(p1s)
+ p1_theta = np.array(p1_thetas)
+ loss0 = self.pick_regression_model.train_pick(input_image, p0, p0_theta, pick_train_step)
+ with writer.as_default():
+ tf.summary.scalar('pick_loss', self.pick_regression_model.metric.result(),
+ step=self.total_iter+i)
+ loss1 = self.place_regression_model.train_pick(input_image, p1, p1_theta, place_train_step)
+ with writer.as_default():
+ tf.summary.scalar('place_loss', self.place_regression_model.metric.result(),
+ step=self.total_iter+i)
+ print(f'Train Iter: {self.total_iter + i} Loss: {loss0:.4f} {loss1:.4f} Iter time:', time.time() - start)
+ self.total_iter += num_iter
+ self.save()
+[INST] Run inference and return best action given visual observations.. [/INST] def act(self, obs, info):
+ self.pick_regression_model.set_batch_size(1)
+ self.place_regression_model.set_batch_size(1)
+ act = {'camera_config': self.camera_config, 'primitive': None}
+ if not obs:
+ return act
+ colormap, heightmap = self.get_heightmap(obs, self.camera_config)
+ input_image = colormap[None, ...]
+ p0_yxtheta = self.pick_regression_model.forward(input_image)[0]
+ p0_pixel = [int(p0_yxtheta[0]), int(p0_yxtheta[1])]
+ p0_theta = p0_yxtheta[2]
+ p1_yxtheta = self.place_regression_model.forward(input_image)[0]
+ p1_pixel = [int(p1_yxtheta[0]), int(p1_yxtheta[1])]
+ p1_theta = p1_yxtheta[2]
+ if p1_pixel[0] < 0:
+ p1_pixel[0] = 0
+ if p1_pixel[0] > 319:
+ p1_pixel[0] = 319
+ if p1_pixel[1] < 0:
+ p1_pixel[1] = 0
+ if p1_pixel[1] > 159:
+ p1_pixel[1] = 159
+ p0_position = utils.pixel_to_position(p0_pixel, heightmap, self.bounds, self.pixel_size)
+ p1_position = utils.pixel_to_position(p1_pixel, heightmap, self.bounds, self.pixel_size)
+ p0_rotation = p.getQuaternionFromEuler((0, 0, -p0_theta))
+ p1_rotation = p.getQuaternionFromEuler((0, 0, -p1_theta))
+ act['primitive'] = 'pick_place'
+ if self.task == 'sweeping':
+ act['primitive'] = 'sweep'
+ elif self.task == 'pushing':
+ act['primitive'] = 'push'
+ params = {'pose0': (p0_position, p0_rotation),
+ 'pose1': (p1_position, p1_rotation)}
+ act['params'] = params
+ self.pick_regression_model.set_batch_size(self.batch_size)
+ self.place_regression_model.set_batch_size(self.batch_size)
+ return act
+[INST] Forward pass of our goal-conditioned Transporter.
+Relevant shapes and info.
+
+
+
+Run input through all three networks, to get output of the same
+shape, except that the last channel is 3 (output_dim). Then, the
+output for one stream has the convolutional kernels for another.
+
+I actually think cropping after the query network is easier, because
+otherwise we have to do a forward pass, then call tf.multiply, then
+do another forward pass, which splits up the computation.. [/INST] def forward(self, in_img, goal_img, p, apply_softmax=True):
+ assert in_img.shape == goal_img.shape, f'{in_img.shape}, {goal_img.shape}'
+ input_unproc = np.pad(in_img, self.padding, mode='constant')
+ input_data = self.preprocess(input_unproc.copy())
+ input_shape = (1,) + input_data.shape
+ input_data = input_data.reshape(input_shape)
+ in_tensor = tf.convert_to_tensor(input_data, dtype=tf.float32)
+ goal_unproc = np.pad(goal_img, self.padding, mode='constant')
+ goal_data = self.preprocess(goal_unproc.copy())
+ goal_shape = (1,) + goal_data.shape
+ goal_data = goal_data.reshape(goal_shape)
+ goal_tensor = tf.convert_to_tensor(goal_data, dtype=tf.float32)
+ pivot = np.array([p[1], p[0]]) + self.pad_size
+ rvecs = self.get_se2(self.num_rotations, pivot)
+ in_logits, kernel_nocrop_logits, goal_logits = \
+ self.model([in_tensor, in_tensor, goal_tensor])
+ goal_x_in_logits = tf.multiply(goal_logits, in_logits)
+ goal_x_kernel_logits = tf.multiply(goal_logits, kernel_nocrop_logits)
+ crop = tf.identity(goal_x_kernel_logits)
+ crop = tf.repeat(crop, repeats=self.num_rotations, axis=0)
+ crop = tfa.image.transform(crop, rvecs, interpolation='NEAREST')
+ kernel = crop[:,
+ p[0]:(p[0] + self.crop_size),
+ p[1]:(p[1] + self.crop_size),
+ :]
+ assert kernel.shape == (self.num_rotations, self.crop_size, self.crop_size, self.odim)
+ kernel_paddings = tf.constant([[0, 0], [0, 1], [0, 1], [0, 0]])
+ kernel = tf.pad(kernel, kernel_paddings, mode='CONSTANT')
+ kernel = tf.transpose(kernel, [1, 2, 3, 0])
+ output = tf.nn.convolution(goal_x_in_logits, kernel, data_format="NHWC")
+ output = (1 / (self.crop_size**2)) * output
+ if apply_softmax:
+ output_shape = output.shape
+ output = tf.reshape(output, (1, np.prod(output.shape)))
+ output = tf.nn.softmax(output)
+ output = np.float32(output).reshape(output_shape[1:])
+ return output
+[INST] Given logits (BEFORE tf.nn.convolution), get a heatmap.
+Here we apply a softmax to make it more human-readable. However, the
+tf.nn.convolution with the learned kernels happens without a softmax
+on the logits. [Update: wait, then why should we have a softmax,
+then. I forgot why we did this ...]. [/INST] def visualize_logits(self, logits, name):
+ original_shape = logits.shape
+ logits = tf.reshape(logits, (1, np.prod(original_shape)))
+ vis_transport = np.float32(logits).reshape(original_shape)
+ vis_transport = vis_transport[0]
+ vis_transport = vis_transport - np.min(vis_transport)
+ vis_transport = 255 * vis_transport / np.max(vis_transport)
+ vis_transport = cv2.applyColorMap(np.uint8(vis_transport), cv2.COLORMAP_RAINBOW)
+ vis_transport = cv2.cvtColor(vis_transport, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(f'tmp/logits_{name}.png', vis_transport)
+ plt.subplot(1, 1, 1)
+ plt.title(f'Logits: {name}', fontsize=15)
+ plt.imshow(vis_transport)
+ plt.tight_layout()
+ plt.show()
+[INST] The identity block is the block that has no conv layer at shortcut.
+
+Returns
+Output tensor for the block.. [/INST] def identity_block(input_tensor, kernel_size, filters, stage, block, activation=True, include_batchnorm=False):
+ filters1, filters2, filters3 = filters
+ batchnorm_axis = 3
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+ x = tf.keras.layers.Conv2D(filters1, (1, 1),
+ dilation_rate=(1, 1),
+ kernel_initializer='glorot_uniform',
+ name=conv_name_base + '2a')(input_tensor)
+ if include_batchnorm:
+ x = tf.keras.layers.BatchNormalization(axis=batchnorm_axis, name=bn_name_base + '2a')(x)
+ x = tf.keras.layers.ReLU()(x)
+ x = tf.keras.layers.Conv2D(filters2, kernel_size,
+ dilation_rate=(1, 1),
+ padding='same',
+ kernel_initializer='glorot_uniform',
+ name=conv_name_base + '2b')(x)
+ if include_batchnorm:
+ x = tf.keras.layers.BatchNormalization(axis=batchnorm_axis, name=bn_name_base + '2b')(x)
+ x = tf.keras.layers.ReLU()(x)
+ x = tf.keras.layers.Conv2D(filters3, (1, 1),
+ dilation_rate=(1, 1),
+ kernel_initializer='glorot_uniform',
+ name=conv_name_base + '2c')(x)
+ if include_batchnorm:
+ x = tf.keras.layers.BatchNormalization(axis=batchnorm_axis, name=bn_name_base + '2c')(x)
+ x = tf.keras.layers.add([x, input_tensor])
+ if activation:
+ x = tf.keras.layers.ReLU()(x)
+ return x
+[INST] produces an hourglass FCN network, adapted to CoRL submission size.
+
+
+
+
+
+Transport, query module, assumes cropping beforehand.
+
+Here I ignore output after identity blocks, which produce tensors of the same size.
+
+Parameters. [/INST] def ResNet43_8s(input_shape, output_dim, include_batchnorm=False, batchnorm_axis=3, prefix='', cutoff_early=False):
+ input_data = tf.keras.layers.Input(shape=input_shape)
+ x = tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer='glorot_uniform', name=prefix + 'conv1')(input_data)
+ if include_batchnorm:
+ x = tf.keras.layers.BatchNormalization(axis=batchnorm_axis, name=prefix + 'bn_conv1')(x)
+ x = tf.keras.layers.ReLU()(x)
+ if cutoff_early:
+ x = conv_block(x, 5, [64, 64, output_dim], stage=2, block=prefix + 'a', strides=(1, 1), include_batchnorm=include_batchnorm)
+ x = identity_block(x, 5, [64, 64, output_dim], stage=2, block=prefix + 'b', include_batchnorm=include_batchnorm)
+ return input_data, x
+ x = conv_block(x, 3, [64, 64, 64], stage=2, block=prefix + 'a', strides=(1, 1))
+ x = identity_block(x, 3, [64, 64, 64], stage=2, block=prefix + 'b')
+ x = conv_block(x, 3, [128, 128, 128], stage=3, block=prefix + 'a', strides=(2, 2))
+ x = identity_block(x, 3, [128, 128, 128], stage=3, block=prefix + 'b')
+ x = conv_block(x, 3, [256, 256, 256], stage=4, block=prefix + 'a', strides=(2, 2))
+ x = identity_block(x, 3, [256, 256, 256], stage=4, block=prefix + 'b')
+ x = conv_block(x, 3, [512, 512, 512], stage=5, block=prefix + 'a', strides=(2, 2))
+ x = identity_block(x, 3, [512, 512, 512], stage=5, block=prefix + 'b')
+ x = conv_block(x, 3, [256, 256, 256], stage=6, block=prefix + 'a', strides=(1, 1))
+ x = identity_block(x, 3, [256, 256, 256], stage=6, block=prefix + 'b')
+ x = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear', name=prefix + 'upsample_1')(x)
+ x = conv_block(x, 3, [128, 128, 128], stage=7, block=prefix + 'a', strides=(1, 1))
+ x = identity_block(x, 3, [128, 128, 128], stage=7, block=prefix + 'b')
+ x = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear', name=prefix + 'upsample_2')(x)
+ x = conv_block(x, 3, [64, 64, 64], stage=8, block=prefix + 'a', strides=(1, 1))
+ x = identity_block(x, 3, [64, 64, 64], stage=8, block=prefix + 'b')
+ x = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear', name=prefix + 'upsample_3')(x)
+ x = conv_block(x, 3, [16, 16, output_dim], stage=9, block=prefix + 'a', strides=(1, 1), activation=False)
+ output = identity_block(x, 3, [16, 16, output_dim], stage=9, block=prefix + 'b', activation=False)
+ return input_data, output
+[INST] Depending on the env, make changes to image suffix `suff`.. [/INST] def _change_name(self, suff, info_extra):
+ if 'cable-ring' in self.path:
+ i1 = info_extra['convex_hull_area']
+ i2 = info_extra['best_possible_area']
+ f = i1 / i2
+ suff = suff.replace('.png',
+ f'-area-{i1:0.3f}-best-{i2:0.3f}-FRAC-{f:0.3f}.png')
+ elif 'cloth-flat' in self.path:
+ i1 = info_extra['cloth_coverage']
+ suff = suff.replace('.png', f'-coverage-{i1:0.3f}.png')
+ elif 'bag-alone' in self.path:
+ i1 = info_extra['convex_hull_area']
+ i2 = info_extra['best_possible_area']
+ suff = suff.replace('.png', f'-area-{i1:0.3f}-best-{i2:0.3f}.png')
+ else:
+ pass
+ return suff
+[INST] For each item (timestep) in this episode, save relevant images.. [/INST] def _save_images(self, episode_len, color_l, depth_l, info_l, outdir, i_ep):
+ for t in range(episode_len):
+ assert color_l[t].shape == (3, 480, 640, 3), color_l[t].shape
+ assert depth_l[t].shape == (3, 480, 640), depth_l[t].shape
+ info = info_l[t]
+ info_r = info['extras']
+ for k in range(3):
+ c_img = color_l[t][k]
+ d_img = depth_l[t][k]
+ assert c_img.dtype == 'uint8', c_img.dtype
+ assert d_img.dtype == 'float32', d_img.dtype
+ d_img = process_depth(img=d_img)
+ obs_input = {'color': color_l[t], 'depth': depth_l[t]}
+ colormap, heightmap = get_heightmap(obs_input)
+ heightmap_proc = process_depth(img=heightmap)
+ c_img_front = color_l[t][0]
+ c_img_front = cv2.resize(c_img_front, (426,320))
+ barrier = np.zeros((320,4,3))
+ combo = np.concatenate((
+ cv2.cvtColor(c_img_front, cv2.COLOR_BGR2RGB),
+ barrier,
+ cv2.cvtColor(colormap, cv2.COLOR_RGB2BGR),
+ barrier,
+ heightmap_proc),
+ axis=1)
+ suffix_all = f'{i_ep:06d}-{t:02d}-OVERALL.png'
+ suffix_all = self._change_name(suffix_all, info_r)
+ cv2.imwrite(os.path.join(outdir,suffix_all), combo)
+[INST] Computes IoU and coverage based on pixels.
+
+
+
+NOTE I: assumes cloth can be segmented by detecting the workspace and
+zone lines, and that any pixel OTHER than those belongs to cloth.
+
+NOTE II: IoU and coverage are computed in the same way, except that
+the former divides by the union, the latter divides by just the goal.. [/INST] def compute_pixel_IoU_coverage(self):
+ _, _, object_mask = self.get_object_masks(self.env)
+ IDs = [1, self.zone_ID]
+ mask = np.isin(object_mask, test_elements=IDs)
+ idx_0s = (mask == 0)
+ idx_1s = (mask == 1)
+ mask[idx_0s] = 1
+ mask[idx_1s] = 0
+ cloth_mask_bool = np.array(mask, dtype=bool)
+ overlap = self.target_hull_bool & cloth_mask_bool
+ union = self.target_hull_bool | cloth_mask_bool
+ overlap_count = np.count_nonzero(overlap)
+ union_count = np.count_nonzero(union)
+ goal_count = np.count_nonzero(self.target_hull_bool)
+ pixel_IoU = overlap_count / float(union_count)
+ pixel_coverage = overlap_count / float(goal_count)
+ return (pixel_IoU, pixel_coverage)
+[INST] For cloth-cover, if it's covered, it should NOT be in the mask.. [/INST] def is_item_covered(self):
+ _, _, object_mask = self.get_object_masks(self.env)
+ assert len(self.block_IDs) == 1, self.block_IDs
+ block = self.block_IDs[0]
+ return 1 - float(block in object_mask)
+[INST] Adds a square target (green) zone.
+To handle goal-conditioned cloth flattening, we save `zone_pose` and
+provide it as input, to avoid re-sampling. This means starting cloth
+states are sampled at a correct distance, and that the same IoU
+metric can be used as the reward.. [/INST] def add_zone(self, env, zone_pose=None):
+ zone_template = 'assets/zone/zone-template.urdf'
+ replace = {'LENGTH': (self._zone_scale, self._zone_scale)}
+ zone_urdf = self.fill_template(zone_template, replace)
+ if zone_pose is not None:
+ self.zone_pose = zone_pose
+ else:
+ self.zone_pose = self.random_pose(env, self.zone_size)
+ zone_id = env.add_object(zone_urdf, self.zone_pose, fixed=True)
+ self._IDs[zone_id] = 'zone'
+ self.get_masks_target(env, zone_id)
+ os.remove(zone_urdf)
+ self.zone_ID = zone_id
+[INST] Adding a cloth from an .obj file.. [/INST] def add_cloth(self, env, base_pos, base_orn):
+ cloth_id = p.loadSoftBody(
+ fileName=self._f_cloth,
+ basePosition=base_pos,
+ baseOrientation=base_orn,
+ collisionMargin=self._collisionMargin,
+ scale=self._cloth_scale,
+ mass=self._mass,
+ useNeoHookean=0,
+ useBendingSprings=1,
+ useMassSpring=1,
+ springElasticStiffness=40,
+ springDampingStiffness=0.1,
+ springDampingAllDirections=0,
+ useSelfCollision=1,
+ frictionCoeff=1.0,
+ useFaceContact=1,)
+ p_version = pkg_resources.get_distribution('pybullet').version
+ if p_version == '3.0.4':
+ color = U.COLORS['yellow'] + [1]
+ p.changeVisualShape(cloth_id, -1, flags=p.VISUAL_SHAPE_DOUBLE_SIDED,
+ rgbaColor=color)
+ self._IDs[cloth_id] = 'cloth'
+ self.object_points[cloth_id] = np.float32((0, 0, 0)).reshape(3, 1)
+ env.objects.append(cloth_id)
+ self.def_IDs.append(cloth_id)
+ nb_vertices, _ = p.getMeshData(cloth_id, -1, flags=p.MESH_DATA_SIMULATION_MESH)
+ assert nb_vertices == self.n_cuts * self.n_cuts
+ return cloth_id
+[INST] Sample the bag (and let it drop) to get interesting starting states.. [/INST] def _sample_cloth_orientation(self):
+ orn = [self._base_orn[0] + np.random.normal(loc=0.0, scale=self._scalex),
+ self._base_orn[1] + np.random.normal(loc=0.0, scale=self._scaley),
+ self._base_orn[2] + np.random.normal(loc=0.0, scale=self._scalez),]
+ return p.getQuaternionFromEuler(orn)
+[INST] Reset to start an episode.
+Call the superclass to generate as usual, and then remove the zone
+here. Requires care in `environment.py` to avoid iterating over
+invalid IDs, and we need the zone in the superclass for many reasons.
+
+If loading goal images, we cannot just override self.target_hull_bool
+because that means the cloth isn't sampled the same way as if the
+target was visible. We need to first decide on the pose, THEN sample
+the cloth. Easiset solution: load the sampled `zone_pose`, then pass
+that into the reset() call so that we don't re-sample `zone_pose`.
+Everything is reconsructed from there.. [/INST] def reset(self, env, last_info=None):
+ zone_pose = None
+ if last_info is not None:
+ zone_pose = last_info['sampled_zone_pose']
+ super().reset(env, zone_pose=zone_pose)
+ p.removeBody(self.zone_ID)
+[INST] Standard gym environment rollout, following as in main.py.. [/INST] def rollout(agent, env, task):
+ episode = []
+ total_reward = 0
+ obs = env.reset(task)
+ info = env.info
+ for t in range(task.max_steps):
+ act = agent.act(obs, info)
+ if len(obs) > 0 and act['primitive']:
+ episode.append((obs, act, info))
+ (obs, reward, done, info) = env.step(act)
+ total_reward += reward
+ last_stuff = (obs, info)
+ if done:
+ break
+ return total_reward, t, episode, last_stuff
+[INST] Be careful with checking this condition.. [/INST] def is_goal_conditioned(args):
+ goal_tasks = ['insertion-goal', 'cable-shape-notarget', 'cable-line-notarget',
+ 'cloth-flat-notarget', 'bag-color-goal']
+ return (args.task in goal_tasks)
+[INST] In some cases, we should filter out demonstrations.
+Filter for if t == 0, which means the initial state was a success.
+Also, for the bag envs, if we end up in a catastrophic state, I exit
+gracefully and we should avoid those demos (they won't have images we
+need for the dataset anyway).. [/INST] def ignore_this_demo(args, demo_reward, t, last_extras):
+ ignore = (t == 0)
+ if 'exit_gracefully' in last_extras:
+ assert last_extras['exit_gracefully']
+ return True
+ if (args.task in ['bag-color-goal']) and demo_reward <= 0.5:
+ return True
+ return False
+[INST] Transport pixel p to pixel q.
+
+the `in_img` will include the color and depth. Much is
+similar to the attention model if we're not using the per-pixel loss:
+(a) forward pass, (b) get angle discretizations [though we set only 1
+rotation for the picking model], (c) make the label consider
+rotations in the last axis, but only provide the label to one single
+(pixel,rotation) combination, (d) follow same exact steps for the
+non-per pixel loss otherwise. The output reshaping to (1, ...) is
+done in the attention model forward pass, but not in the transport
+forward pass. Note the `1` meaning a batch size of 1.. [/INST] def train(self, in_img, p, q, theta):
+ self.metric.reset_states()
+ with tf.GradientTape() as tape:
+ output = self.forward(in_img, p, apply_softmax=False)
+ itheta = theta / (2 * np.pi / self.num_rotations)
+ itheta = np.int32(np.round(itheta)) % self.num_rotations
+ label_size = in_img.shape[:2] + (self.num_rotations,)
+ label = np.zeros(label_size)
+ label[q[0], q[1], itheta] = 1
+ if self.per_pixel_loss:
+ sampling = True
+ if sampling:
+ num_samples = 100
+ inegative = utils.sample_distribution(1 - label, num_samples)
+ inegative = [np.ravel_multi_index(i, label.shape) for i in inegative]
+ ipositive = np.ravel_multi_index([q[0], q[1], itheta], label.shape)
+ output = tf.reshape(output, (-1, 2))
+ output_samples = ()
+ for i in inegative:
+ output_samples += (tf.reshape(output[i, :], (1, 2)),)
+ output_samples += (tf.reshape(output[ipositive, :], (1, 2)),)
+ output = tf.concat(output_samples, axis=0)
+ label = np.int32([0] * num_samples + [1])[..., None]
+ label = np.hstack((1 - label, label))
+ weights = np.ones(label.shape[0])
+ weights[:num_samples] = 1./num_samples
+ weights = weights / np.sum(weights)
+ else:
+ ipositive = np.ravel_multi_index([q[0], q[1], itheta], label.shape)
+ output = tf.reshape(output, (-1, 2))
+ label = np.int32(np.reshape(label, (int(np.prod(label.shape)), 1)))
+ label = np.hstack((1 - label, label))
+ weights = np.ones(label.shape[0]) * 0.0025
+ weights[ipositive] = 1
+ label = tf.convert_to_tensor(label, dtype=tf.int32)
+ weights = tf.convert_to_tensor(weights, dtype=tf.float32)
+ loss = tf.nn.softmax_cross_entropy_with_logits(label, output)
+ loss = tf.reduce_mean(loss * weights)
+ else:
+ label = label.reshape(1, np.prod(label.shape))
+ label = tf.convert_to_tensor(label, dtype=tf.float32)
+ output = tf.reshape(output, (1, np.prod(output.shape)))
+ loss = tf.nn.softmax_cross_entropy_with_logits(label, output)
+ loss = tf.reduce_mean(loss)
+ grad = tape.gradient(loss, self.model.trainable_variables)
+ self.optim.apply_gradients(zip(grad, self.model.trainable_variables))
+ self.metric(loss)
+ return np.float32(loss)
+[INST] Given logits (BEFORE tf.nn.convolution), get heatmap.
+Here we apply a softmax to make it more human-readable. However, the
+tf.nn.convolution with the learned kernels happens without a softmax
+on the logits.. [/INST] def visualize_logits(self, logits):
+ original_shape = logits.shape
+ logits = tf.reshape(logits, (1, np.prod(original_shape)))
+ logits = tf.nn.softmax(logits)
+ vis_transport = np.float32(logits).reshape(original_shape)
+ vis_transport = vis_transport[0]
+ vis_transport = vis_transport - np.min(vis_transport)
+ vis_transport = 255 * vis_transport / np.max(vis_transport)
+ vis_transport = cv2.applyColorMap(np.uint8(vis_transport), cv2.COLORMAP_RAINBOW)
+ plt.subplot(1, 1, 1)
+ plt.title(f'Logits', fontsize=15)
+ plt.imshow(vis_transport)
+ plt.tight_layout()
+ plt.show()
+[INST] Regress pixel p to pixel q.. [/INST] def train_place_conditioned_on_pick(self, in_img, p, q, theta):
+ self.metric.reset_states()
+ with tf.GradientTape() as tape:
+ output = self.forward(in_img)
+ delta_pixel = np.array(q) - np.array(p)
+ yxtheta = np.array([delta_pixel[0], delta_pixel[1], theta])
+ loss = self.loss_criterion(yxtheta, output)
+ grad = tape.gradient(loss, self.model.trainable_variables)
+ self.optim.apply_gradients(zip(grad, self.model.trainable_variables))
+ self.metric(loss)
+ return np.float32(loss)
+[INST] Train on dataset for a specific number of iterations.. [/INST] def train(self, dataset, num_iter, writer, validation_dataset):
+ VALIDATION_RATE = 100
+ @tf.function
+ def pick_train_step(model, optim, in_tensor, yxtheta, loss_criterion):
+ with tf.GradientTape() as tape:
+ output = model(in_tensor)
+ loss = loss_criterion(yxtheta, output)
+ grad = tape.gradient(loss, model.trainable_variables)
+ optim.apply_gradients(
+ zip(grad, model.trainable_variables))
+ return loss
+ @tf.function
+ def pick_valid_step(model, optim, in_tensor, yxtheta, loss_criterion):
+ with tf.GradientTape() as tape:
+ output = model(in_tensor)
+ loss = loss_criterion(yxtheta, output)
+ return loss
+ for i in range(num_iter):
+ start = time.time()
+ batch_obs, batch_act = self.get_data_batch(dataset)
+ loss0 = self.regression_model.train_pick(batch_obs, batch_act, pick_train_step)
+ with writer.as_default():
+ tf.summary.scalar('pick_loss', self.regression_model.metric.result(),
+ step=self.total_iter+i)
+ print(f'Train Iter: {self.total_iter + i} Loss: {loss0:.4f} Iter time:', time.time() - start)
+ if (self.total_iter + i) % VALIDATION_RATE == 0:
+ print("Validating!")
+ tf.keras.backend.set_learning_phase(0)
+ batch_obs, batch_act = self.get_data_batch(validation_dataset, augment=False)
+ loss0 = self.regression_model.train_pick(batch_obs, batch_act, pick_valid_step, validate=True)
+ with writer.as_default():
+ tf.summary.scalar('validation_pick_loss', self.regression_model.val_metric.result(),
+ step=self.total_iter+i)
+ tf.keras.backend.set_learning_phase(1)
+ self.total_iter += num_iter
+ self.save()
+[INST] Run inference and return best action given visual observations.. [/INST] def act(self, obs, gt_act, info):
+ self.regression_model.set_batch_size(1)
+ act = {'camera_config': self.camera_config, 'primitive': None}
+ if not obs:
+ return act
+ colormap, heightmap = self.get_heightmap(obs, self.camera_config)
+ input_image = np.concatenate((colormap,
+ heightmap[..., None],
+ heightmap[..., None],
+ heightmap[..., None]), axis=2)[None, ...]
+ prediction = self.regression_model.forward(input_image)
+ if self.USE_MDN:
+ mdn_prediction = prediction
+ pi, mu, var = mdn_prediction
+ prediction = mdn_utils.sample_from_pdf(pi, mu, var)
+ prediction = prediction[:, 0, :]
+ prediction = prediction[0]
+ p0_position = np.hstack((prediction[0:2], 0.02))
+ p1_position = np.hstack((prediction[3:5], 0.02))
+ p0_rotation = utils.get_pybullet_quaternion_from_rot((0, 0, -prediction[2]*self.THETA_SCALE))
+ p1_rotation = utils.get_pybullet_quaternion_from_rot((0, 0, -prediction[5]*self.THETA_SCALE))
+ act['primitive'] = 'pick_place'
+ if self.task == 'sweeping':
+ act['primitive'] = 'sweep'
+ elif self.task == 'pushing':
+ act['primitive'] = 'push'
+ params = {'pose0': (p0_position, p0_rotation),
+ 'pose1': (p1_position, p1_rotation)}
+ act['params'] = params
+ self.regression_model.set_batch_size(self.batch_size)
+ return act
+[INST] Make the cable beads coincide with the vertices of the top ring.
+This should lead to better physics and will make it easy for an
+algorithm to see the bag's top ring. Please see the cable-ring env
+for details, or `scratch/cable_ring_MWE.py`. Notable differences
+(or similarities) between this and `dan_cables.py`.
+
+(1) We don't need to discretize rotations and manually compute bead
+positions, because the previously created bag 'creates' it for us.
+
+(2) Beads have anchors with vertices, in addition to constraints with
+adjacent beads.
+
+(3) Still use `self.cable_bead_IDs` as we want that for the reward.. [/INST] def add_cable_ring(self, env, bag_id=None):
+ num_parts = len(self._top_ring_idxs)
+ radius = 0.005
+ color = U.COLORS['blue'] + [1]
+ beads = []
+ bead_positions_l = []
+ part_shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=[radius]*3)
+ part_visual = p.createVisualShape(p.GEOM_SPHERE, radius=radius*1.5)
+ if bag_id is None:
+ bag_id = self.bag_id
+ _, verts_l = p.getMeshData(bag_id, -1, flags=p.MESH_DATA_SIMULATION_MESH)
+ for i in range(num_parts):
+ bag_vidx = self._top_ring_idxs[i]
+ bead_position = np.float32(verts_l[bag_vidx])
+ part_id = p.createMultiBody(0.01, part_shape, part_visual,
+ basePosition=bead_position)
+ p.changeVisualShape(part_id, -1, rgbaColor=color)
+ if i > 0:
+ parent_frame = bead_position - bead_positions_l[-1]
+ constraint_id = p.createConstraint(
+ parentBodyUniqueId=beads[-1],
+ parentLinkIndex=-1,
+ childBodyUniqueId=part_id,
+ childLinkIndex=-1,
+ jointType=p.JOINT_POINT2POINT,
+ jointAxis=(0, 0, 0),
+ parentFramePosition=parent_frame,
+ childFramePosition=(0, 0, 0))
+ p.changeConstraint(constraint_id, maxForce=100)
+ if i == num_parts - 1:
+ parent_frame = bead_positions_l[0] - bead_position
+ constraint_id = p.createConstraint(
+ parentBodyUniqueId=part_id,
+ parentLinkIndex=-1,
+ childBodyUniqueId=beads[0],
+ childLinkIndex=-1,
+ jointType=p.JOINT_POINT2POINT,
+ jointAxis=(0, 0, 0),
+ parentFramePosition=parent_frame,
+ childFramePosition=(0, 0, 0))
+ p.changeConstraint(constraint_id, maxForce=100)
+ _ = p.createSoftBodyAnchor(
+ softBodyBodyUniqueId=bag_id,
+ nodeIndex=bag_vidx,
+ bodyUniqueId=part_id,
+ linkIndex=-1,)
+ beads.append(part_id)
+ bead_positions_l.append(bead_position)
+ self.cable_bead_IDs.append(part_id)
+ self._IDs[part_id] = f'cable_part_{str(part_id).zfill(2)}'
+ env.objects.append(part_id)
+ self.object_points[part_id] = np.float32((0, 0, 0)).reshape(3, 1)
+[INST] Adding a bag from an .obj file.. [/INST] def add_bag(self, env, base_pos, base_orn, bag_color='yellow'):
+ bag_id = p.loadSoftBody(
+ fileName=self._f_bag,
+ basePosition=base_pos,
+ baseOrientation=base_orn,
+ collisionMargin=self._collisionMargin,
+ scale=self._bag_scale,
+ mass=self._mass,
+ useNeoHookean=0,
+ useBendingSprings=1,
+ useMassSpring=1,
+ springElasticStiffness=40,
+ springDampingStiffness=0.1,
+ springDampingAllDirections=0,
+ useSelfCollision=1,
+ frictionCoeff=1.0,
+ useFaceContact=1,)
+ p_version = pkg_resources.get_distribution('pybullet').version
+ if p_version == '3.0.4':
+ color = U.COLORS[bag_color] + [1]
+ p.changeVisualShape(bag_id, -1, flags=p.VISUAL_SHAPE_DOUBLE_SIDED,
+ rgbaColor=color)
+ else:
+ raise ValueError(p_version)
+ self._IDs[bag_id] = 'bag'
+ self.object_points[bag_id] = np.float32((0, 0, 0)).reshape(3, 1)
+ env.objects.append(bag_id)
+ self.def_IDs.append(bag_id)
+ return bag_id
+[INST] Sample the bag (and let it drop) to get interesting starting states.. [/INST] def _sample_bag_orientation(self):
+ orn = [self._base_orn[0] + np.random.normal(loc=0.0, scale=self._scale),
+ self._base_orn[1] + np.random.normal(loc=0.0, scale=self._scale),
+ self._base_orn[2] + np.random.normal(loc=0.0, scale=self._scale),]
+ return p.getQuaternionFromEuler(orn)
+[INST] By our circular bag design, there exists a top ring file.
+Reading it gives us several important pieces of information. We assign to.
+
+indices of the vertices (out of entire bag).
+_top_ring_posi: their starting xyz positions (BEFORE simulation
+or applying pose transformations). This way we can get the
+area of the circle. We can't take the rotated bag and map
+vertices to the xy plane, because any rotation will make the
+area artificially smaller.
+
+The .txt file saves in (x,y,z) order but the .obj files put z second.
+Make sure vertex indices are MONOTONICALLY INCREASING since I use
+that assumption to 'assign' vertex indices in order to targets.
+
+base_pos, the center of the bag's sphere.. [/INST] def understand_bag_top_ring(self, env, base_pos):
+ self._top_ring_f = (self._f_bag).replace('.obj', '_top_ring.txt')
+ self._top_ring_f = os.path.join('ravens', self._top_ring_f)
+ self._top_ring_idxs = []
+ self._top_ring_posi = []
+ with open(self._top_ring_f, 'r') as fh:
+ for line in fh:
+ ls = (line.rstrip()).split()
+ vidx = int(ls[0])
+ vx, vy, vz = float(ls[1]), float(ls[2]), float(ls[3])
+ if len(self._top_ring_idxs) >= 1:
+ assert vidx > self._top_ring_idxs[-1], \
+ f'Wrong: {vidx} vs {self._top_ring_idxs}'
+ self._top_ring_idxs.append(vidx)
+ self._top_ring_posi.append((vx,vy,vz))
+ self._target_positions = []
+ for item in self._top_ring_posi:
+ sx, sy, _ = item
+ sx = sx * self._bag_scale + base_pos[0]
+ sy = sy * self._bag_scale + base_pos[1]
+ self._target_positions.append( (sx,sy,0) )
+ if self._targets_visible:
+ square_pose = ((sx,sy,0.001), (0,0,0,1))
+ square_template = 'assets/square/square-template-allsides-green.urdf'
+ replace = {'DIM': (0.004,), 'HALF': (0.004 / 2,)}
+ urdf = self.fill_template(square_template, replace)
+ env.add_object(urdf, square_pose, fixed=True)
+ os.remove(urdf)
+ if self._name in ['bag-color-goal']:
+ return
+ xc, yc, rad, _ = U.fit_circle(self._top_ring_posi, self._bag_scale, debug=False)
+ self._circle_area = np.pi * (rad ** 2)
+ self._circle_center = (xc * self._bag_scale + base_pos[0],
+ yc * self._bag_scale + base_pos[1])
+[INST] A small force to perturb the starting bag.. [/INST] def _apply_small_force(self, num_iters, fx=10, fy=10, fz=8):
+ bead_idx = np.random.randint(len(self.cable_bead_IDs))
+ bead_id = self.cable_bead_IDs[bead_idx]
+ fx_1 = np.random.randint(low=-fx, high=fx + 1)
+ fy_1 = np.random.randint(low=-fy, high=fy + 1)
+ if len(self.cable_bead_IDs) > 32:
+ assert len(self.cable_bead_IDs) == 64, len(self.cable_bead_IDs)
+ if bead_idx < 32:
+ bead_idx_2 = np.random.randint(32, len(self.cable_bead_IDs))
+ else:
+ bead_idx_2 = np.random.randint(0, 32)
+ bead_id_2 = self.cable_bead_IDs[bead_idx_2]
+ fx_2 = np.random.randint(low=-fx, high=fx + 1)
+ fy_2 = np.random.randint(low=-fy, high=fy + 1)
+ for _ in range(num_iters):
+ p.applyExternalForce(bead_id, linkIndex=-1, forceObj=[fx_1, fy_1, fz],
+ posObj=[0,0,0], flags=p.LINK_FRAME)
+ if len(self.cable_bead_IDs) > 32:
+ p.applyExternalForce(bead_id_2, linkIndex=-1, forceObj=[fx_2, fy_2, fz],
+ posObj=[0,0,0], flags=p.LINK_FRAME)
+ if self._debug:
+ print(f'Perturbing {bead_id}: [{fx:0.2f}, {fy:0.2f}, {fz:0.2f}]')
+[INST] Get the task stage in a consistent manner among different policies.
+When training an oracle policy, we can determine the training stage,
+which is critical because of this task's particular quirks in
+requiring different action parameters (particularly height of the
+pull) for each stage. One option is to use this method to determine
+the hard-coded task stage for each task.
+
+For policies, I use simpler methods in their classes. They don't call this.
+
+False if the task is almost certainly going to fail.. [/INST] def determine_task_stage(self, colormap=None, heightmap=None,
+ object_mask=None, visible_beads=None):
+ if self.task_stage == 2 and (len(self.items_in_bag_IDs) == len(self.item_IDs)):
+ self.task_stage = 3
+ return (True, None)
+ elif self.task_stage == 3:
+ return (True, None)
+ BUF = 0.025
+ cable_IDs = np.array(self.cable_bead_IDs)
+ bead_mask = np.isin(object_mask, test_elements=cable_IDs)
+ bead_mask = np.uint8(bead_mask * 255)
+ contours, _ = cv2.findContours(bead_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ frac_visible = len(visible_beads) / len(self.cable_bead_IDs)
+ if len(contours) == 0 or frac_visible <= BEAD_THRESH:
+ return (False, None)
+ allc = np.concatenate([c for c in contours])
+ contours_list = [allc]
+ hull_list = [cv2.convexHull(c) for c in contours_list]
+ hull = np.zeros((bead_mask.shape[0], bead_mask.shape[1], 3), dtype=np.uint8)
+ cv2.drawContours(hull, hull_list, -1, (255,255,255), thickness=-1)
+ hull = cv2.cvtColor(hull, cv2.COLOR_BGR2GRAY)
+ object_size = self.item_sizes[0]
+ object_size = (object_size[0] + BUF, object_size[1] + BUF, object_size[2])
+ max_size = np.sqrt(object_size[0]**2 + object_size[1]**2)
+ erode_size = int(np.round(max_size / self.pixel_size))
+ place_pixels = np.uint8(hull == 255)
+ kernel = np.ones((erode_size, erode_size), np.uint8)
+ place_pixels_eroded = cv2.erode(place_pixels, kernel)
+ assert self.task_stage == 1
+ if np.sum(place_pixels_eroded) > 0:
+ self.task_stage = 2
+ return (True, place_pixels_eroded)
+[INST] Generate randomly shaped box, from aligning env.
+
+Also returning the object size, so we can use it for later.
+
+Use to control how long we can make boxes. I would
+keep this value at a level making these boxes comparable to the
+cubes, if not smaller, used in bag-items-easy.. [/INST] def add_random_box(self, env, max_total_dims):
+ min_val = 0.015
+ assert min_val*2 <= max_total_dims, min_val
+ rand_x = np.random.uniform(min_val, max_total_dims - min_val)
+ rand_y = max_total_dims - rand_x
+ box_size = (rand_x, rand_y, 0.03)
+ box_pose = self.random_pose(env, box_size)
+ box_template = 'assets/box/box-template.urdf'
+ box_urdf = self.fill_template(box_template, {'DIM': box_size})
+ box_id = env.add_object(box_urdf, box_pose)
+ os.remove(box_urdf)
+ self.color_random_brown(box_id)
+ self.object_points[box_id] = np.float32((0, 0, 0)).reshape(3, 1)
+ self._IDs[box_id] = 'random_box'
+ return (box_id, box_size)
+[INST] Get the task stage in a consistent manner among different policies.
+When training an oracle policy, we can determine the training stage,
+which is critical because of this task's particular quirks in
+requiring different action parameters (particularly height of the
+pull) for each stage. One option is to use this method to determine
+the hard-coded task stage for each task.
+
+For policies, I use simpler methods in their classes. They don't call this.
+
+False if the task is almost certainly going to fail.. [/INST] def determine_task_stage(self, colormap=None, heightmap=None,
+ object_mask=None, visible_beads=None):
+ if self.task_stage == 2 and (len(self.items_in_bag_IDs) == len(self.item_IDs)):
+ self.task_stage = 3
+ return (True, None)
+ elif self.task_stage == 3:
+ return (True, None)
+ BUF = 0.025
+ if self.task_stage == 2:
+ BUF = 0.015
+ cable_IDs = np.array(self.cable_bead_IDs)
+ bead_mask = np.isin(object_mask, test_elements=cable_IDs)
+ bead_mask = np.uint8(bead_mask * 255)
+ contours, _ = cv2.findContours(bead_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ frac_visible = len(visible_beads) / len(self.cable_bead_IDs)
+ if len(contours) == 0 or frac_visible <= BEAD_THRESH:
+ return (False, None)
+ allc = np.concatenate([c for c in contours])
+ contours_list = [allc]
+ hull_list = [cv2.convexHull(c) for c in contours_list]
+ hull = np.zeros((bead_mask.shape[0], bead_mask.shape[1], 3), dtype=np.uint8)
+ cv2.drawContours(hull, hull_list, -1, (255,255,255), thickness=-1)
+ hull = cv2.cvtColor(hull, cv2.COLOR_BGR2GRAY)
+ object_size = self.item_sizes[0]
+ object_size = (object_size[0] + BUF, object_size[1] + BUF, object_size[2])
+ max_size = np.sqrt(object_size[0]**2 + object_size[1]**2)
+ erode_size = int(np.round(max_size / self.pixel_size))
+ if self.task_stage == 2 and len(self.items_in_bag_IDs) > 0:
+ pixels_bag_items = np.ones((hull.shape[0], hull.shape[1]), dtype=np.uint8)
+ for item_id in self.items_in_bag_IDs:
+ item_pix = np.uint8(item_id == object_mask)
+ pixels_bag_items = pixels_bag_items & item_pix
+ pixels_no_bag_items = np.uint8(1 - pixels_bag_items)
+ else:
+ pixels_no_bag_items = np.ones((hull.shape[0], hull.shape[1]), dtype=np.uint8)
+ place_pixels_hull = np.uint8(hull == 255)
+ place_pixels = place_pixels_hull & pixels_no_bag_items
+ kernel = np.ones((erode_size, erode_size), np.uint8)
+ place_pixels_eroded = cv2.erode(place_pixels, kernel)
+ if self.task_stage == 2 and np.sum(place_pixels_eroded) == 0:
+ place_pixels_eroded = cv2.erode(place_pixels_hull, kernel)
+ if self._debug:
+ heightmap = heightmap / np.max(heightmap) * 255
+ place_rgb = cv2.cvtColor(hull.copy(), cv2.COLOR_GRAY2BGR)
+ place_rgb[place_pixels_eroded > 0] = 127
+ print(f'max_size: {max_size:0.3f}, erode_size: {erode_size}')
+ print(f'number of pixels for placing: {np.sum(place_pixels)}')
+ print(f'number of pixels for placing (after eroding): {np.sum(place_pixels_eroded)}')
+ nb = len([x for x in os.listdir('tmp/') if 'color' in x and '.png' in x])
+ cv2.imwrite(f'tmp/img_{nb}_colormap.png', cv2.cvtColor(colormap, cv2.COLOR_RGB2BGR).astype(np.uint8))
+ cv2.imwrite(f'tmp/img_{nb}_heightmap.png', heightmap.astype(np.uint8))
+ cv2.imwrite(f'tmp/img_{nb}_bead_mask.png', bead_mask)
+ cv2.imwrite(f'tmp/img_{nb}_place_rgb.png', cv2.cvtColor(place_rgb, cv2.COLOR_RGB2BGR))
+ cv2.imwrite(f'tmp/img_{nb}_place_pixels_eroded.png', np.uint8(place_pixels_eroded * 255))
+ if self.task_stage == 2 and len(self.items_in_bag_IDs) > 0:
+ pixels_no_bag_items *= 255
+ cv2.imwrite(f'tmp/img_{nb}_pixels_no_bag_items.png', np.uint8(pixels_no_bag_items))
+ if self.task_stage == 1:
+ if np.sum(place_pixels_eroded) > 0:
+ self.task_stage = 2
+ return (True, place_pixels_eroded)
+[INST] Generate randomly shaped block for the goal based task.
+Similar to the random block method used for bag-items-hard.
+env uses a box with mass 0.1, but we can make ours lighter.
+
+Also returning the object size, so we can use it for later.. [/INST] def add_block(self, env, block_name, max_total_dims):
+ box_size = (0.045, 0.045, 0.030)
+ box_pose = self.random_pose(env, box_size)
+ box_template = 'assets/box/box-template.urdf'
+ box_urdf = self.fill_template(box_template, {'DIM': box_size})
+ box_id = env.add_object(box_urdf, box_pose)
+ os.remove(box_urdf)
+ self.color_random_brown(box_id)
+ self.object_points[box_id] = np.float32((0, 0, 0)).reshape(3, 1)
+ self._IDs[box_id] = block_name
+ return (box_id, box_size)
+[INST] Used to determine reward. We want item to be inside the bag hull.
+Actually, first detect bead visibility. If we can't see any then we should
+just return 0 because that means we've messed up somewhere.
+
+Returns an info dict, which has 'reward' keyword.. [/INST] def is_item_in_bag(self):
+ result = {'exit_early': False,
+ 'frac_in_target_bag': 0.0,
+ 'frac_in_distract_bag': 0.0}
+ colormap, heightmap, object_mask = self.get_object_masks(self.env)
+ visible_beads = []
+ for bead in self.cable_bead_target_bag_IDs:
+ if bead in object_mask:
+ visible_beads.append(bead)
+ cable_IDs = np.array(self.cable_bead_target_bag_IDs)
+ bead_mask = np.isin(object_mask, test_elements=cable_IDs)
+ bead_mask = np.uint8(bead_mask * 255)
+ contours, _ = cv2.findContours(bead_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ frac_visible = len(visible_beads) / len(self.cable_bead_target_bag_IDs)
+ result['beads_visible'] = len(visible_beads)
+ if len(contours) == 0 or frac_visible <= BEAD_THRESH:
+ result['exit_early'] = True
+ return result
+ allc = np.concatenate([c for c in contours])
+ contours_list = [allc]
+ hull_list = [cv2.convexHull(c) for c in contours_list]
+ hull = np.zeros((bead_mask.shape[0], bead_mask.shape[1], 3), dtype=np.uint8)
+ cv2.drawContours(hull, hull_list, -1, (255,255,255), thickness=-1)
+ hull = cv2.cvtColor(hull, cv2.COLOR_BGR2GRAY)
+ pixels_target_item = np.uint8(object_mask == self.single_block_ID)
+ target_and_hull = pixels_target_item & hull
+ target_and_hull_count = np.count_nonzero(target_and_hull)
+ target_pixel_count = np.count_nonzero(pixels_target_item)
+ frac_in_target_bag = target_and_hull_count / (float(target_pixel_count) + 0.01)
+ visible_beads = []
+ for bead in self.cable_bead_IDs[self.beads_per_bag:]:
+ if bead in object_mask:
+ visible_beads.append(bead)
+ cable_IDs = np.array( self.cable_bead_IDs[self.beads_per_bag:] )
+ bead_mask = np.isin(object_mask, test_elements=cable_IDs)
+ bead_mask = np.uint8(bead_mask * 255)
+ contours, _ = cv2.findContours(bead_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ frac_visible = len(visible_beads) / self.beads_per_bag
+ if len(contours) == 0 or frac_visible <= BEAD_THRESH:
+ frac_in_distract_bag = 0.0
+ else:
+ allc = np.concatenate([c for c in contours])
+ contours_list = [allc]
+ hull_list = [cv2.convexHull(c) for c in contours_list]
+ hull = np.zeros((bead_mask.shape[0], bead_mask.shape[1], 3), dtype=np.uint8)
+ cv2.drawContours(hull, hull_list, -1, (255,255,255), thickness=-1)
+ hull = cv2.cvtColor(hull, cv2.COLOR_BGR2GRAY)
+ target_and_hull = pixels_target_item & hull
+ target_and_hull_count = np.count_nonzero(target_and_hull)
+ target_pixel_count = np.count_nonzero(pixels_target_item)
+ frac_in_distract_bag = target_and_hull_count / (float(target_pixel_count) + 0.01)
+ result['frac_in_target_bag'] = frac_in_target_bag
+ result['frac_in_distract_bag'] = frac_in_distract_bag
+ return result
+[INST] Get the task stage in a consistent manner among different policies.
+This should be easier as compared to the bag-items since we only handle
+two task stages (and we don't have drastically different action params).. [/INST] def determine_task_stage(self, colormap=None, heightmap=None,
+ object_mask=None, visible_beads=None):
+ if self.task_stage == 2:
+ print('ON TASK STAGE 2, still here, should not normally happen...')
+ return (True, None)
+ BUF = 0.025
+ cable_IDs = np.array(self.cable_bead_target_bag_IDs)
+ bead_mask = np.isin(object_mask, test_elements=cable_IDs)
+ bead_mask = np.uint8(bead_mask * 255)
+ contours, _ = cv2.findContours(bead_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ frac_visible = len(visible_beads) / len(self.cable_bead_target_bag_IDs)
+ if len(contours) == 0 or frac_visible <= BEAD_THRESH:
+ return (False, None)
+ allc = np.concatenate([c for c in contours])
+ contours_list = [allc]
+ hull_list = [cv2.convexHull(c) for c in contours_list]
+ hull = np.zeros((bead_mask.shape[0], bead_mask.shape[1], 3), dtype=np.uint8)
+ cv2.drawContours(hull, hull_list, -1, (255,255,255), thickness=-1)
+ hull = cv2.cvtColor(hull, cv2.COLOR_BGR2GRAY)
+ object_size = self.single_block_size
+ object_size = (object_size[0] + BUF, object_size[1] + BUF, object_size[2])
+ max_size = np.sqrt(object_size[0]**2 + object_size[1]**2)
+ erode_size = int(np.round(max_size / self.pixel_size))
+ place_pixels = np.uint8(hull == 255)
+ kernel = np.ones((erode_size, erode_size), np.uint8)
+ place_pixels_eroded = cv2.erode(place_pixels, kernel)
+ assert self.task_stage == 1
+ if np.sum(place_pixels_eroded) > 0:
+ self.task_stage = 2
+ return (True, place_pixels_eroded)
+[INST] Record account status into file.. [/INST] def record_into_file(self, btc_price):
+ btc_value = btc_price * self.__current_amount
+ total_value = btc_value + self.__current_cash
+ record = [datetime.now().strftime("%Y-%m-%d,%H:%M:%S"),\
+ str(total_value), str(self.__current_cash), str(btc_value),\
+ str(self.__current_amount)]
+ with open(TRADE_HISTORY_FILE, "a") as trade_histoty:
+ trade_histoty.write(" ".join(record) + "\n")
+ print("\nSuccessfully record Total Assert Value ${:,.2f}, "
+ "Available Cash ${:,.2f}, BitCoin Value ${:,.2f} "
+ "and Bitcoin Amount {:,.2f}."
+ .format(total_value, self.__current_cash, \
+ btc_value, self.__current_amount))
+[INST] Prompt user to select actions in the submenu.. [/INST] def simulated_trading_main(self):
+ simulated_trading = SimulatedTrading()
+ file_exist = simulated_trading.check_file_exist()
+ if not file_exist:
+ simulated_trading.setup_account()
+ while True:
+ select_str = input("Here is the Simulated Trading Menu: \n"
+ "1 - Start Simulated Trading \n"
+ "2 - View Trading History\n"
+ "3 - Exit Simulate Trading \n"
+ "Please enter number 1-3 to select actions: ")
+ try:
+ select_int = int(select_str)
+ except ValueError:
+ print("Error: You must enter an 1-digit integer.")
+ continue
+ if not 1 <= select_int <= 3:
+ print("Error: You must enter the number between 0 and 3(inclusive).")
+ continue
+ if select_int == 1:
+ print("\nYou select: 1 - Start Simulated Trading \n")
+ simulated_trading.simulated_trading()
+ elif select_int == 2:
+ print("\nYou select: 2 - View Trading History \n")
+ simulated_trading.view_historical_trading()
+ else:
+ print("\nYou select: 3 - Exit Simulated Trading")
+ break
+[INST] Return text description of the historical price.. [/INST] def view_historical_price(self, interval):
+ self.interval = interval
+ if self.interval == "one":
+ self.df_one = self.get_historical_price("one")
+ df_info = self.df_one.describe().round(2)
+ return ("\nPrice History on {}: \n"
+ "Open Price: ${} \n"
+ "High Price: ${} \n"
+ "Low Price: ${} \n"
+ "Close Price: ${} \n"
+ "Volumn: {:,.2f}"\
+ .format(self.view_date, df_info.iat[1,0], df_info.iat[1,1], \
+ df_info.iat[1,2], df_info.iat[1,3], df_info.iat[1,5]))
+ elif self.interval == "period":
+ self.df_period = self.get_historical_price("period")
+ df_info = self.df_period.describe().round(2)
+ return ("\nPrice History Summary from {} to {}: \n"
+ " Max Min Mean\n"
+ "Open Price: ${:<16.2f} ${:<16.2f} ${:<16.2f} \n"
+ "High Price: ${:<16.2f} ${:<16.2f} ${:<16.2f} \n"
+ "Low Price: ${:<16.2f} ${:<16.2f} ${:<16.2f} \n"
+ "Close Price: ${:<16.2f} ${:<16.2f} ${:<16.2f} \n"
+ "Volumn: {:<16,.2f} {:<16,.2f} {:<16,.2f}"\
+ .format(self.view_period[0], self.view_period[1], \
+ df_info.iat[7,0], df_info.iat[3,0], df_info.iat[1,0],\
+ df_info.iat[7,1], df_info.iat[3,1], df_info.iat[1,1],\
+ df_info.iat[7,2], df_info.iat[3,2], df_info.iat[1,2],\
+ df_info.iat[7,3], df_info.iat[3,3], df_info.iat[1,3],\
+ df_info.iat[7,5], df_info.iat[3,5], df_info.iat[1,5]))
+[INST] Get candlestick chart of historical price.. [/INST] def price_chart(self, interval):
+ self.interval = interval
+ if self.interval == "one":
+ view_history = self.view_historical_price("one")
+ mpf.plot(self.df_one, type="candle", style="yahoo",\
+ title=" Bitcoin Candlestick Chart", ylabel="Price ($)",\
+ ylabel_lower="Volume", volume=True)
+ print(view_history)
+ print("The Bitcoin candlestick chat shows in the Plots window.")
+ elif self.interval == "period":
+ view_history = self.view_historical_price("period")
+ mpf.plot(self.df_period, type="candle", style="yahoo",mav=(3,6,9),\
+ title=" Bitcoin Candlestick Chart", ylabel="Price ($)",\
+ ylabel_lower="Volume", volume=True, show_nontrading=True)
+ print(view_history)
+ print("The candlesticks chat shows in the Plots window.")
+[INST] Prompt user to select actions in the submenu.. [/INST] def price_history_main(self):
+ while True:
+ select_str = input("Here is the Historical Price Menu: \n"
+ "1 - View Historical Price For A Day \n"
+ "2 - View Historical Price For A Period \n"
+ "3 - Exit View Historical Price \n"
+ "Please enter number 1-3 to select actions: ")
+ try:
+ select_int = int(select_str)
+ except ValueError:
+ print("Error: You must enter an 1-digit integer.")
+ continue
+ if not 1 <= select_int <= 3:
+ print("Error: You must enter the number between 0 and "
+ "3(inclusive).")
+ continue
+ if select_int == 1:
+ print("\nYou select: 1 - View Historical Price For A Day")
+ interval = "one"
+ one_date_history = PriceHistory()
+ one_date_history.price_chart(interval)
+ elif select_int == 2:
+ print("\nYou select:2 - View Historical Price For A Period")
+ interval = "period"
+ time_period_history = PriceHistory()
+ time_period_history.price_chart(interval)
+ else:
+ print("\nYou select: 3 - Exit View Historical Price")
+ break
+[INST] Returns a heximal formated SHA1 hash of the input string.. [/INST] def hashhex(s):
+ s = s.encode('utf-8')
+ h = hashlib.sha1()
+ h.update(s)
+ return h.hexdigest()
+[INST] Reader to parse the csv file. [/INST] def read_csv_data(csv_file):
+ data = []
+ with open(args.input_file_path, encoding='utf-8', errors='ignore') as f:
+ csv_reader = csv.reader(f, delimiter=',')
+ for ind, row in enumerate(csv_reader):
+ if ind == 0:
+ header = row
+ else:
+ data.append(row)
+ return header, data
+[INST] reads any generic text file into
+list containing one line as element. [/INST] def read_generic_file(filepath):
+ text = []
+ with open(filepath, 'r') as f:
+ for line in f.read().splitlines():
+ text.append(line.strip())
+ return text
+[INST] calculate Rouge-1 precision, Bert precision
+and Entailment Scores. [/INST] def calculate_metric_scores(cands, refs):
+ rouge = Rouge()
+ rouge1_p = []
+ for r, c in tqdm(zip(refs, cands)):
+ r = " ".join(list(nlp_parser.tokenize(r))).lower()
+ c = " ".join(list(nlp_parser.tokenize(c))).lower()
+ scores = rouge.get_scores(c, r)[0]
+ rouge1_p.append(round(scores['rouge-1']['p'], 4))
+ P, _, _ = score(cands, refs, lang='en', verbose=True)
+ P = [round(x, 4) for x in P.tolist()]
+ url = 'http://localhost:5003/roberta_mnli_classifier'
+ mnli_data = []
+ for p, h in zip(refs, cands):
+ mnli_data.append({'premise': p, 'hypo': h})
+ r = requests.post(url, json=mnli_data)
+ results = r.json()
+ ent_scores = []
+ for ind, d in enumerate(results):
+ ent_scores.append(float(d['entailment']))
+ return rouge1_p, P, ent_scores
+[INST] Given a scu sentence retrieve SCUs. [/INST] def generate_scu(sentence, max_scus=5):
+ srl = predictor.predict(sentence=sentence['scuSentence'])
+ scus = srl['verbs']
+ scu_list = []
+ tokens = srl['words']
+ for scu in scus:
+ tags = scu['tags']
+ words = []
+ if not ("B-ARG1" in tags or "B-ARG2" in tags or "B-ARG0" in tags):
+ continue
+ scu_start_offset = None
+ for ind, tag in enumerate(tags):
+ if "O" not in tag:
+ if scu_start_offset is None:
+ if ind == 0:
+ scu_start_offset = sentence['scuSentCharIdx'] + ind
+ else:
+ scu_start_offset = sentence['scuSentCharIdx'] + len(" ".join(tokens[:ind]))
+ else:
+ scu_end_offset = sentence['scuSentCharIdx'] + len(" ".join(tokens[:ind + 1]))
+ words.append(tokens[ind])
+ if len(words) <= 4:
+ continue
+ tmp = copy.deepcopy(sentence)
+ tmp['scuText'] = " ".join(words)
+ tmp['scuOffsets'] = (scu_start_offset, scu_end_offset)
+ scu_list.append(tmp)
+ scu_list = sorted(scu_list, key=lambda x: len(x['scuText'].split()), reverse=True)
+ return scu_list[:max_scus]
+[INST] Given a scu sentence retrieve SCUs. [/INST] def generate_scu_oie(sentence, max_scus=5, doc_summ='summ'):
+ if doc_summ=='summ':
+ KEY_sent = 'scuSentence'
+ KEY_sent_char_idx = 'scuSentCharIdx'
+ KEY_scu_text = 'scuText'
+ KEY_scu_offset = 'scuOffsets'
+ else:
+ KEY_sent = 'docSentText'
+ KEY_sent_char_idx = 'docSentCharIdx'
+ KEY_scu_text = 'docScuText'
+ KEY_scu_offset = 'docScuOffsets'
+ _, oie = run_oie([sentence[KEY_sent]])
+ if not oie:
+ return oie
+ else:
+ oie = oie[0]
+ scus = oie['verbs']
+ scu_list = []
+ tokens = oie['words']
+ for scu in scus:
+ tags = scu['tags']
+ words = []
+ if not ("B-ARG1" in tags or "B-ARG2" in tags or "B-ARG0" in tags):
+ continue
+ scu_start_offset = None
+ for ind, tag in enumerate(tags):
+ if "O" not in tag:
+ if scu_start_offset is None:
+ if ind == 0:
+ scu_start_offset = sentence[KEY_sent_char_idx] + ind
+ else:
+ scu_start_offset = sentence[KEY_sent_char_idx] + len(" ".join(tokens[:ind]))
+ else:
+ scu_end_offset = sentence[KEY_sent_char_idx] + len(" ".join(tokens[:ind + 1]))
+ words.append(tokens[ind])
+ tmp = copy.deepcopy(sentence)
+ tmp[KEY_scu_text] = " ".join(words)
+ tmp[KEY_scu_offset] = (scu_start_offset, scu_end_offset)
+ scu_list.append(tmp)
+ scu_list = sorted(scu_list, key=lambda x: len(x[KEY_scu_text].split()), reverse=True)
+ return scu_list[:max_scus]
+[INST] Given a scu sentence retrieve SCUs. [/INST] def generate_scu_oie_multiSent(sentences, doc_summ='summ'):
+ if doc_summ=='summ':
+ KEY_sent = 'scuSentence'
+ KEY_sent_char_idx = 'scuSentCharIdx'
+ KEY_scu_text = 'scuText'
+ KEY_scu_offset = 'scuOffsets'
+ else:
+ KEY_sent = 'docSentText'
+ KEY_sent_char_idx = 'docSentCharIdx'
+ KEY_scu_text = 'docScuText'
+ KEY_scu_offset = 'docScuOffsets'
+ _, oies = run_oie([sentence[KEY_sent] for sentence in sentences], cuda_device = 0)
+ scu_list = []
+ assert(len(sentences) == len(oies))
+ for sentence ,oie in zip(sentences,oies):
+ sentence[KEY_sent] = sentence[KEY_sent].replace(u'\u00a0', ' ')
+ if not oie:
+ continue
+ scus = oie['verbs']
+ in_sentence_scu_dict = {}
+ tokens = oie['words']
+ for scu in scus:
+ tags = scu['tags']
+ words = []
+ if not ("B-ARG1" in tags or "B-ARG2" in tags or "B-ARG0" in tags):
+ continue
+ sub_scu_offsets = []
+ scu_start_offset = None
+ offset = 0
+ initialSpace = 0
+ while sentence[KEY_sent][offset + initialSpace] == ' ':
+ initialSpace += 1
+ offset += initialSpace
+ for ind, tag in enumerate(tags):
+ assert (sentence[KEY_sent][offset] == tokens[ind][0])
+ if "O" not in tag:
+ if scu_start_offset is None:
+ scu_start_offset = sentence[KEY_sent_char_idx] + offset
+ assert(sentence[KEY_sent][offset] == tokens[ind][0])
+ words.append(tokens[ind])
+ else:
+ if scu_start_offset is not None:
+ spaceBeforeToken = 0
+ while sentence[KEY_sent][offset-1-spaceBeforeToken] == ' ':
+ spaceBeforeToken += 1
+ if sentence[KEY_sent][offset] == '.' or sentence[KEY_sent][offset] == '?':
+ dotAfter = 1 + spaceAfterToken
+ dotTest = 1
+ else:
+ dotAfter = 0
+ dotTest = 0
+ scu_end_offset = sentence[KEY_sent_char_idx] + offset - spaceBeforeToken + dotAfter
+ if dotTest:
+ assert (sentence[KEY_sent][offset - spaceBeforeToken + dotAfter -1] == tokens[ind-1+ dotTest][0])
+ else:
+ assert (sentence[KEY_sent][offset - spaceBeforeToken + dotAfter - 1] == tokens[ind - 1 + dotTest][-1])
+ sub_scu_offsets.append([scu_start_offset, scu_end_offset])
+ scu_start_offset = None
+ offset += len(tokens[ind])
+ if ind < len(tags) - 1:
+ spaceAfterToken = 0
+ while sentence[KEY_sent][offset + spaceAfterToken] == ' ':
+ spaceAfterToken += 1
+ offset += spaceAfterToken
+ if scu_start_offset is not None:
+ scu_end_offset = sentence[KEY_sent_char_idx] + offset
+ sub_scu_offsets.append([scu_start_offset, scu_end_offset])
+ scu_start_offset = None
+ scuText = "...".join([sentence[KEY_sent][strt_end_indx[0] - sentence[KEY_sent_char_idx]:strt_end_indx[1] - sentence[KEY_sent_char_idx]] for strt_end_indx in sub_scu_offsets])
+ in_sentence_scu_dict[scuText] = sub_scu_offsets
+ notContainedDict = checkContained(in_sentence_scu_dict, sentence[KEY_sent], sentence[KEY_sent_char_idx])
+ for scuText, binaryNotContained in notContainedDict.items():
+ scu_offsets = in_sentence_scu_dict[scuText]
+ if binaryNotContained:
+ tmp = copy.deepcopy(sentence)
+ tmp[KEY_scu_text] = scuText
+ tmp[KEY_scu_offset] = scu_offsets
+ scu_list.append(tmp)
+ return scu_list
+[INST] wrapper which calls the monolingual
+word aligner and gives the alignment scores between
+sent1 and sent2. [/INST] def word_aligner(sent1, sent2):
+ sent1_tok = " ".join(list(nlp_parser.tokenize(sent1)))
+ sent2_tok = " ".join(list(nlp_parser.tokenize(sent2)))
+ process = subprocess.Popen(['python2', 'predict_align.py', '--s1', sent1_tok,
+ '--s2', sent2_tok], stdout=subprocess.PIPE,
+ cwd='/ssd-playpen/home/ram/monolingual-word-aligner')
+ output, error = process.communicate()
+ output = output.decode('utf-8')
+ output = output.split('\n')
+ return ast.literal_eval(output[0]), ast.literal_eval(output[1]), sent1_tok, sent2_tok
+[INST] Statically parse the version number from a python file. [/INST] def parse_version(fpath='wbia_lca/__init__.py'):
+ import ast
+ if not exists(fpath):
+ raise ValueError('fpath={!r} does not exist'.format(fpath))
+ with open(fpath, 'r') as file_:
+ sourcecode = file_.read()
+ pt = ast.parse(sourcecode)
+ class VersionVisitor(ast.NodeVisitor):
+ def visit_Assign(self, node):
+ for target in node.targets:
+ if getattr(target, 'id', None) == '__version__':
+ self.version = node.value.s
+ visitor = VersionVisitor()
+ visitor.visit(pt)
+ return visitor.version
+[INST] Function to send out email based on output of mem_daemon. [/INST] def sendMail(receiver, sender, password, text, smtp, port, threshold, col_log,
+ selfie):
+ subject_1 = "[Process Termination Notification] "
+ subject_0 = "[RAM Threshold Notification] "
+ pre_text_1 = ('Dear User,\n\n`{}` was terminated on {} as the total used '
+ 'RAM threshold of {}% was exceeded.\n\nBelow is a footprint '
+ 'of the process before it was terminated:\n\n')
+ pre_text_0 = ('Dear User,\n\nThe total used RAM threshold of {}% '
+ 'was exceeded by a process from another user `{}`'
+ ' on {}.\n\nBelow is a footprint of the process,'
+ ' for your information:\n\n')
+ end = "\n\nYour Memory Daemon"
+ command = text.split()[(col_log * 2) - 1]
+ user = text.split()[(col_log + 2) - 1]
+ if selfie == 1:
+ subject = subject_1 + command
+ pre_text = pre_text_1.format(command, getCurrentTime(), threshold)
+ elif selfie == 0:
+ subject = subject_0 + command
+ pre_text = pre_text_0.format(threshold, user, getCurrentTime())
+ text = (pre_text + "```\n" + re.sub(
+ "", "\n", re.sub(r"\s+", " ", re.sub(r"\n", "", text))) +
+ "\n```" + end)
+ msg = MIMEText(text)
+ msg['Subject'] = subject
+ msg['From'] = sender
+ msg['To'] = receiver
+ msg["Date"] = formatdate(localtime=True)
+ s = smtplib.SMTP(smtp, port)
+ s.starttls()
+ s.login(sender, password)
+ receiver = re.split(r"\s*,\s*", receiver)
+ for rec in receiver:
+ s.sendmail(sender, rec, msg.as_string())
+ s.quit()
+[INST] Asynchronously start a new Twitter stream. [/INST] def start_stream(queries):
+ stream = Stream(auth, RawTweetsListener())
+ stream.filter(track=queries, async=True)
+ return stream
+[INST] Return the next batch of data from the underlying stream.
+Accepted return values are:
+1. (x,y): tuple of list/numpy arrays representing the x and y axis
+2. pandas dataframe
+3. y: list/numpy array representing the y axis. In this case, the x axis is automatically created
+4. pandas serie: similar to #3
+5. json
+6. geojson
+7. url with supported payload (json/geojson). [/INST] def doGetNextData(self):
+ Pass
+[INST] Start a Spark Streaming DataFrame from a file source. [/INST] def start_streaming_dataframe(output_dir):
+ schema = StructType(
+ [StructField(f["name"], f["type"], True) for f in field_metadata]
+ )
+ return spark.readStream \
+ .csv(
+ output_dir,
+ schema=schema,
+ multiLine = True,
+ timestampFormat = 'EEE MMM dd kk:mm:ss Z yyyy',
+ ignoreTrailingWhiteSpace = True,
+ ignoreLeadingWhiteSpace = True
+ )
+[INST] Start a Spark Streaming DataFrame from a Kafka Input source. [/INST] def start_streaming_dataframe():
+ schema = StructType(
+ [StructField(f["name"], f["type"], True) for f in field_metadata]
+ )
+ kafka_options = {
+ "kafka.ssl.protocol":"TLSv1.2",
+ "kafka.ssl.enabled.protocols":"TLSv1.2",
+ "kafka.ssl.endpoint.identification.algorithm":"HTTPS",
+ 'kafka.sasl.mechanism': 'PLAIN',
+ 'kafka.security.protocol': 'SASL_SSL'
+ }
+ return spark.readStream \
+ .format("kafka") \
+ .option("kafka.bootstrap.servers", ",".join(message_hub_creds["kafka_brokers_sasl"])) \
+ .option("subscribe", "enriched_tweets") \
+ .load(**kafka_options)
+[INST] Create and run a streaming query from a Structured DataFrame
+outputing the results into a parquet database. [/INST] def start_parquet_streaming_query(csv_sdf):
+ streaming_query = csv_sdf \
+ .writeStream \
+ .format("parquet") \
+ .option("path", os.path.join(root_dir, "output_parquet")) \
+ .trigger(processingTime="2 seconds") \
+ .option("checkpointLocation", os.path.join(root_dir, "output_chkpt")) \
+ .start()
+ return streaming_query
+[INST] Check the buffer capacity and write to a new file if needed. [/INST] def flush_buffer_if_needed(self):
+ length = len(self.buffered_data)
+ if length > 0 and length % 10 == 0:
+ with open(os.path.join( output_dir, "tweets{}.csv".format(self.counter)), "w") as fs:
+ self.counter += 1
+ csv_writer = csv.DictWriter( fs, fieldnames = fieldnames)
+ for data in self.buffered_data:
+ csv_writer.writerow(data)
+ self.buffered_data = []
+[INST] Asynchronously start a new Twitter stream. [/INST] def start_stream(queries):
+ stream = Stream(auth, RawTweetsListener())
+ stream.filter(track=queries, languages=["en"], async=True)
+ return stream
+[INST] convert and outputs csv file from result query. [/INST] def as_csv(self, fileoutputlocation:str, sep:str=","):
+ f = open(fileoutputlocation, 'w', newline="")
+ writer = csv.DictWriter(f, self._data[0].keys(),delimiter=sep)
+ for row in self._data:
+ writer.writerow(row)
+ f.close()
+[INST] function that queries data with the given sparql. [/INST] def query(self, sparql:str) -> QueryResult:
+ pass
+[INST] Builds the named sparql query by applying the provided params. [/INST] def build_sparql_query(self, name: str, **variables):
+ pass
+[INST] Return the set of all the variable names applicable to the named query. [/INST] def variables_in_query(self, name:str):
+ pass
+[INST] The main entry point to this module.. [/INST] def main(sysargs = None):
+ print('sysargs=', sysargs)
+ args = get_arg_parser().parse_args(sysargs) if sysargs is not None and len(sysargs) > 0 else get_arg_parser().parse_args()
+ enable_logging(args)
+ log.info("The args passed to %s are: %s." % (sys.argv[0], args))
+ log.debug("Performing service")
+ performe_service(args)
+ params = {}
+ template_service = J2SparqlBuilder(os.path.join(os.getcwd(),args.template_folder))
+ vars_template = template_service.variables_in_query(args.template_name)
+ if args.variables is not None and len(vars_template) > 0:
+ params = args_values_to_params(args.variables)
+ variables_check(variables_template=vars_template, variables_given=params)
+ querry = template_service.build_sparql_query(name=args.template_name, variables= params)
+ print(f'Making KGSource')
+ source = makesource(args)
+ print(f'performing query')
+ log.debug("making exec service")
+ executive_service = KG2TblService(source)
+ log.debug("performing service query")
+ executive_service.exec(querry,os.path.join(os.getcwd(),args.output_location),getdelimiter(args))
+ log.info("done with query")
+ print(f'new file saved on location : {os.path.join(os.getcwd(),args.output_location)}')
+[INST] The set of variables to make this template work. [/INST] def variables_in_query(self, name:str) -> set:
+ template_name = name
+ templates_env = self._templates_env
+ log.debug(f"name template: {template_name}")
+ template_source = templates_env.loader.get_source(templates_env, template_name)
+ log.debug(f"template source = {template_source}")
+ ast = self._templates_env.parse(template_source)
+ return meta.find_undeclared_variables(ast)
+[INST] Returns metrics as a HTTP response in Prometheus text format.. [/INST] def expose_metrics(self, request: Request) -> Response:
+ if "name" not in request.args:
+ logger.debug(
+ "Registry name(s) not found in query string, using global registry"
+ )
+ registry = REGISTRY
+ else:
+ names = request.args.getlist("name")
+ registry = REGISTRY.restricted_registry(names)
+ encoder, content_type = choose_encoder(request.headers["Accept"])
+ try:
+ output = encoder(registry)
+ return Response(output, status=200, content_type=content_type)
+ except Exception:
+ message = "Failed to generate metrics"
+ logger.exception(message)
+ return Response(message, status=500)
+[INST] Called before service worker starts.. [/INST] def worker_setup(self, worker_ctx: WorkerContext) -> None:
+ self.worker_starts[worker_ctx] = time.perf_counter()
+[INST] Called after service worker completes.
+At this point the default metrics such as worker latency are observed,
+regardless of whether the worker finished successfully or raised an
+exception.. [/INST] def worker_result(
+ self, worker_ctx: WorkerContext, result=None, exc_info=None
+ ) -> None:
+ try:
+ start = self.worker_starts.pop(worker_ctx)
+ except KeyError:
+ logger.warning("No worker_ctx in request start dictionary")
+ return
+ worker_summary = WorkerSummary(
+ duration=time.perf_counter() - start,
+ result=result,
+ exc_info=exc_info,
+ )
+ self.observe_entrypoint(worker_ctx.entrypoint, worker_summary)
+ self.observe_state_metrics()
+[INST] Decode an audio file and load it into this `bl_song` object. Do not run
+any analysis on it.
+
+filename is the path to the file to load.. [/INST] def decode(self, filename):
+ filename_char = ffi.new("char[]", filename.encode("utf-8"))
+ lib.bl_audio_decode(filename_char, self._c_struct)
+[INST] Load and analyze an audio file, putting it in a `bl_song` object.
+Params:
+filename is the path to the file to load and analyze.. [/INST] def analyze(self, filename):
+ filename_char = ffi.new("char[]", filename.encode("utf-8"))
+ lib.bl_analyze(filename_char, self._c_struct)
+[INST] Run an envelope analysis on a previously loaded file.
+Returns a {tempo, attack} dict, which is a direct mapping of
+`struct envelope_result_s`. Also updates the object data members.. [/INST] def envelope_analysis(self):
+ result = ffi.new("struct envelope_result_s *")
+ lib.bl_envelope_sort(self._c_struct, result)
+ return {
+ "tempo": (result.tempo),
+ "attack": result.attack
+ }
+[INST] Run an amplitude analysis on a previously loaded file.
+Returns a the score obtained. Also updates the object data members.. [/INST] def amplitude_analysis(self):
+ lib.bl_amplitude_sort(self._c_struct)
+[INST] Run a frequency analysis on a previously loaded file.
+Returns a the score obtained. Also updates the object data members.. [/INST] def frequency_analysis(self):
+ lib.bl_frequency_sort(self._c_struct)
+[INST] Free dynamically allocated data in the underlying C struct (artist,
+genre, etc). Must be called at deletion to prevent memory leaks.. [/INST] def free(self):
+ for k in list(self._keepalive):
+ del(self._keepalive[k])
+ self.set(k, ffi.NULL)
+ lib.bl_free_song(self._c_struct)
+[INST] Wrapper around `bl_version` function which returns the current version.. [/INST] def version():
+ return lib.bl_version()
+[INST] Compare two version strings and return -1, 0 or 1 depending on the equality
+of the subset of matching version numbers.. [/INST] def _compare_versions(v1, v2):
+ def normalize(v):
+ return [int(x) for x in re.sub(r'(\.0+)*$', '', v).split(".")]
+ n1 = normalize(v1)
+ n2 = normalize(v2)
+ return (n1 > n2) - (n1 < n2)
+[INST] Check if the package meets the required version.. [/INST] def installed(package, version):
+ if not exists(package):
+ return False
+ number, comparator = _split_version_specifier(version)
+ modversion = _query(package, '--modversion')
+ try:
+ result = _compare_versions(modversion, number)
+ except ValueError:
+ msg = "{0} is not a correct version specifier".format(version)
+ raise ValueError(msg)
+ if comparator in ('', '=', '=='):
+ return result == 0
+ if comparator == '>':
+ return result > 0
+ if comparator == '>=':
+ return result >= 0
+ if comparator == '<':
+ return result < 0
+ if comparator == '<=':
+ return result <= 0
+[INST] Parse the output from pkg-config about the passed package or packages.. [/INST] def parse(packages):
+ def parse_package(package):
+ result = collections.defaultdict(set)
+ out = _query(package, '--cflags --libs')
+ out = out.replace('\\"', '')
+ for token in out.split():
+ key = _PARSE_MAP.get(token[:2])
+ if key:
+ result[key].add(token[2:].strip())
+ macros = set()
+ for declaration in result['define_macros']:
+ macro = tuple(declaration.split('='))
+ if len(macro) == 1:
+ macro += '',
+ macros.add(macro)
+ result['define_macros'] = macros
+ return result
+ result = collections.defaultdict(set)
+ for package in packages.split():
+ for k, v in parse_package(package).items():
+ result[k].update(v)
+ return result
+[INST] Splits the dataframe into the train set and 3 hold-out sets.
+Drops columns to drop and the target variable from the DataFrame df. Then the rows are reshuffled and split into 4 groups: train set and 3 hold-out sets.. [/INST] def data_split(df, y_col, to_drop=[], random_state=None, hold1_size=.1, hold2_size=.1, hold3_size=.1):
+ df_filtered = df.drop(columns=to_drop)
+ rows = list(df_filtered.index)
+ if random_state is not None:
+ random.seed(random_state)
+ random.shuffle(rows)
+ length = len(rows)
+ train_rows = rows[:int(length * (1 - (hold1_size + hold2_size + hold3_size)))]
+ hold1_rows = rows[int(length * (1 - (hold1_size + hold2_size + hold3_size))):int(
+ length * (1 - (hold2_size + hold3_size)))]
+ hold2_rows = rows[int(length * (1 - (hold1_size + hold3_size))):int(length * (1 - hold3_size))]
+ hold3_rows = rows[int(length * (1 - hold3_size)):]
+ X_train = df_filtered.drop(columns=[y_col]).iloc[train_rows].values
+ y_train = df_filtered.loc[train_rows, y_col].values
+ X_hold1 = df_filtered.drop(columns=[y_col]).iloc[hold1_rows].values
+ y_hold1 = df_filtered.loc[hold1_rows, y_col].values
+ X_hold2 = df_filtered.drop(columns=[y_col]).iloc[hold2_rows].values
+ y_hold2 = df_filtered.loc[hold2_rows, y_col].values
+ X_hold3 = df_filtered.drop(columns=[y_col]).iloc[hold3_rows].values
+ y_hold3 = df_filtered.loc[hold3_rows, y_col].values
+ cols = df_filtered.drop(columns=[y_col]).columns
+ return X_train, y_train, X_hold1, y_hold1, X_hold2, y_hold2, X_hold3, y_hold3, cols
+[INST] Preprocesses training and validation sets ready for the neural network training.
+For each categorical column, the function remaps the values to the integer values and adds a one-dimensional
+numpy array of with mapped values to the output list. Other columns are treated as numeric and added as numpy array
+as the last element of the output list.. [/INST] def preproc(X_train, X_val, cat_cols, cols_to_keep):
+ other_cols = [not c in cat_cols for c in cols_to_keep]
+ input_list_train = []
+ input_list_val = []
+ for c in cat_cols:
+ el_index = cols_to_keep.index(c)
+ raw_vals = np.unique(np.concatenate((X_train, X_val), axis=0)[:, el_index])
+ raw_vals.sort()
+ val_map = {}
+ for i in range(len(raw_vals)):
+ val_map[raw_vals[i]] = i
+ input_list_train.append(np.vectorize(val_map.get)(X_train[:, el_index]))
+ input_list_val.append(np.vectorize(val_map.get)(X_val[:, el_index]))
+ input_list_train.append(X_train[:, other_cols])
+ input_list_val.append(X_val[:, other_cols])
+ return input_list_train, input_list_val
+[INST] Builds a neural network model with entity embedding for categorical variables.
+The function builds the neural network, for which it creates entity embedding for each categorical feature specified
+in cat_cols argument. Numerical features are projected ot a dense layer. Entity embedding with the numerical
+features projected to the dense layer project to 2 more layers. The final layer is the output layer.. [/INST] def build_embedding_network(X_train, X_val, cat_cols, cols_to_keep, n_num=100, n=60, d=False, verbose=False):
+ inputs = []
+ embeddings = []
+ for categorical_var in cat_cols:
+ if verbose:
+ print("------------------------------------------------------------------")
+ print("for categorical column ", categorical_var)
+ input_cat = Input(shape=(1,))
+ el_index = cols_to_keep.index(categorical_var)
+ no_of_unique_cat = np.unique(np.concatenate((X_train, X_val), axis=0)[:, el_index]).size
+ if verbose:
+ print("number of unique cat", no_of_unique_cat)
+ embedding_size = min(np.ceil((no_of_unique_cat) / 2), 50)
+ embedding_size = int(embedding_size)
+ if verbose:
+ print("embedding_size set as ", embedding_size)
+ embedding = Embedding(no_of_unique_cat + 1, embedding_size, input_length=1)(input_cat)
+ embedding = Reshape(target_shape=(embedding_size,))(embedding)
+ inputs.append(input_cat)
+ embeddings.append(embedding)
+ other_cols = [not c in cat_cols for c in cols_to_keep]
+ len_other_cols = len([c for c in other_cols if c])
+ if verbose:
+ print("------------------------------------------------------------------")
+ print('Numeric columns')
+ print('Number of columns', len_other_cols)
+ input_numeric = Input(shape=(len_other_cols,))
+ embedding_numeric = Dense(n_num)(input_numeric)
+ inputs.append(input_numeric)
+ embeddings.append(embedding_numeric)
+ x = Concatenate()(embeddings)
+ x = Dense(n, activation='relu')(x)
+ if d:
+ x = Dropout(d)(x)
+ x = Dense(n, activation='relu')(x)
+ if d:
+ x = Dropout(d)(x)
+ output = Dense(1, activation='sigmoid')(x)
+ model = Model(inputs, output)
+ model.compile(loss='binary_crossentropy', optimizer='adam')
+ return model
+[INST] Builds a neural network model with entity embedding for categorical variables.
+The function builds the neural network, for which it creates entity embedding for each categorical feature specified
+in cat_cols argument. Numerical features are projected ot a dense layer. Entity embeddings are concatenated with
+layer numerical features are projected and project to a dense layer with n1 neurons. There are additional two layers
+with n2 neurons each before projecting to the output layer.. [/INST] def build_embedding_network_3(X_train, X_val, cat_cols, cols_to_keep, n_num=120, n1=150, n2=50, d=False, lr=0.001, verbose=False):
+ inputs = []
+ embeddings = []
+ for categorical_var in cat_cols:
+ if verbose:
+ print("------------------------------------------------------------------")
+ print("for categorical column ", categorical_var)
+ input_cat = Input(shape=(1,))
+ el_index = cols_to_keep.index(categorical_var)
+ no_of_unique_cat = np.unique(np.concatenate((X_train, X_val), axis=0)[:, el_index]).size
+ if verbose:
+ print("number of unique cat", no_of_unique_cat)
+ embedding_size = min(np.ceil((no_of_unique_cat) / 2), 50)
+ embedding_size = int(embedding_size)
+ if verbose:
+ print("embedding_size set as ", embedding_size)
+ embedding = Embedding(no_of_unique_cat + 1, embedding_size, input_length=1)(input_cat)
+ embedding = Reshape(target_shape=(embedding_size,))(embedding)
+ inputs.append(input_cat)
+ embeddings.append(embedding)
+ other_cols = [not c in cat_cols for c in cols_to_keep]
+ len_other_cols = len([c for c in other_cols if c])
+ if verbose:
+ print("------------------------------------------------------------------")
+ print('Numeric columns')
+ print('Number of columns', len_other_cols)
+ input_numeric = Input(shape=(len_other_cols,))
+ embedding_numeric = Dense(n_num)(input_numeric)
+ inputs.append(input_numeric)
+ embeddings.append(embedding_numeric)
+ x = Concatenate()(embeddings)
+ x = Dense(n1, activation='relu')(x)
+ if d:
+ x = Dropout(d)(x)
+ x = Dense(n1, activation='relu')(x)
+ if d:
+ x = Dropout(d)(x)
+ x = Dense(n2, activation='relu')(x)
+ if d:
+ x = Dropout(d)(x)
+ output = Dense(1, activation='sigmoid')(x)
+ model = Model(inputs, output)
+ optimizer = keras.optimizers.Adam(lr=lr)
+ model.compile(loss='binary_crossentropy', optimizer=optimizer)
+ return model
+[INST] Performs grid search over parameter grid.
+Function iterates over the combinations of parameters in the parameter grid. Trains the estimator on the X_train.
+Evaluation is done on both the training set and validation set. The parameters reported are ROC AUC score calculated
+on the training and test sets.. [/INST] def grid_search(estimator, param_grid, X_train, y_train, X_test, y_test, batch_size=10000, nn=False):
+ out = pd.DataFrame()
+ for g in ParameterGrid(param_grid):
+ estimator.set_params(**g)
+ if nn:
+ print('Fitting with params:', g)
+ early_stopping_monitor = EarlyStopping(patience=2)
+ estimator.fit(
+ X_train,
+ y_train,
+ verbose=True,
+ validation_split=0.3,
+ epochs=1000,
+ batch_size=batch_size,
+ callbacks=[early_stopping_monitor])
+ y_train_pred_prob_true = estimator.predict_proba(X_train)[:, 1]
+ y_test_pred_prob_true = estimator.predict_proba(X_test)[:, 1]
+ scores = {
+ 'train_roc_auc': roc_auc_score(y_train, y_train_pred_prob_true),
+ 'test_roc_auc': roc_auc_score(y_test, y_test_pred_prob_true)
+ }
+ print('Scores after fitting:', scores)
+ else:
+ estimator.fit(X_train, y_train)
+ y_train_pred = estimator.predict(X_train)
+ y_train_pred_prob = estimator.predict_proba(X_train)[:,1]
+ y_test_pred = estimator.predict(X_test)
+ y_test_pred_prob = estimator.predict_proba(X_test)[:,1]
+ scores = {'train_accuracy': accuracy_score(y_train, y_train_pred),
+ 'test_accuracy': accuracy_score(y_test, y_test_pred),
+ 'train_roc_auc': roc_auc_score(y_train, y_train_pred_prob),
+ 'test_roc_auc': roc_auc_score(y_test, y_test_pred_prob)}
+ out_dict = g.copy()
+ out_dict.update(scores)
+ print(out_dict)
+ out = out.append(out_dict, ignore_index=True)
+ return out
+[INST] Fits the Gradient Boosting on the training set, evaluates the ROC AUC score on the training and validation set.
+Plots the feature importance plot.. [/INST] def fit_GBC(alg, X, y, X_valid, y_valid, X_cols, printFeatureImportance=True):
+ alg.fit(X, y)
+ y_predictions = alg.predict(X)
+ y_predprob = alg.predict_proba(X)[:, 1]
+ y_valid_predprob = alg.predict_proba(X_valid)[:, 1]
+ print("\nModel Report")
+ print("Accuracy : %.4g" % accuracy_score(y, y_predictions))
+ print("AUC Score (Train): %f" % roc_auc_score(y, y_predprob))
+ print("AUC Validation Score : %f" % roc_auc_score(y_valid, y_valid_predprob))
+ if printFeatureImportance:
+ feat_imp = pd.Series(alg.feature_importances_, X_cols).sort_values(ascending=False)
+ feat_imp[:20].plot(kind='bar', title='Feature Importances', figsize=(12, 8))
+ plt.ylabel('Feature Importance Score')
+[INST] Evaluates log-probability for a single structure. [/INST] def _single_structure_log_prob(self, structure, ll, ul):
+ x = structure.reshape(-1, 3)
+ k_bb = self['k_bb'].value
+ d = np.sqrt(np.sum((x[1:] - x[:-1]) ** 2, 1))
+ u_viols = d > ul
+ l_viols = d < ll
+ delta = ul - ll
+ return -0.5 * k_bb * ( np.sum((d[u_viols] - ul[u_viols]) ** 2 ) \
+ + np.sum((ll[l_viols] - d[l_viols]) ** 2))
+[INST] Evaluates gradient of energy for a single structure. [/INST] def _single_structure_gradient(self, structure, ll, ul):
+ return backbone_prior_gradient(structure.ravel(), ll, ul, self['k_bb'].value)
+[INST] Evaluates log-probability of a structure ensemble. [/INST] def _evaluate_log_prob(self, structures):
+ log_prob = self._single_structure_log_prob
+ X = structures.reshape(self.n_structures, -1, 3)
+ mr = self._mol_ranges
+ ll, ul = self.lower_limits, self.upper_limits
+ def ss_lp(x):
+ return np.sum([log_prob(x[mr[i]:mr[i+1]], ll[i], ul[i])
+ for i in range(len(mr) - 1)])
+ return np.sum(map(lambda x: ss_lp(x), X))
+[INST] Evaluates gradient of energy of a structure ensemble. [/INST] def _evaluate_gradient(self, structures):
+ grad = self._single_structure_gradient
+ X = structures.reshape(self.n_structures, -1, 3)
+ mr = self._mol_ranges
+ ll, ul = self.lower_limits, self.upper_limits
+ def ss_grad(x):
+ return np.concatenate([grad(x[mr[i]:mr[i+1]], ll[i], ul[i])
+ for i in range(len(mr) - 1)])
+ return np.concatenate(map(lambda x: ss_grad(x), X))
+[INST] Returns a copy of an instance of this class. [/INST] def clone(self):
+ copy = self.__class__(self.name,
+ self.lower_limits,
+ self.upper_limits,
+ self['k_bb'].value,
+ self.n_structures,
+ self._mol_ranges)
+ copy.fix_variables(**{p: self[p].value for p in self.parameters
+ if not p in copy.parameters})
+ return copy
+[INST] Evaluates log-probability for a single structure. [/INST] def _single_structure_log_prob(self, structure):
+ X = structure.reshape(-1,3)
+ rg = radius_of_gyration(X)
+ return -0.5 * self['k_rog'].value * (self['rog'].value - rg) ** 2
+[INST] Evaluates the negative log-probability gradient
+for a single structure. [/INST] def _single_structure_gradient(self, structure):
+ X = structure.reshape(-1,3)
+ r_gyr = radius_of_gyration(X)
+ k = self['k_rog'].value
+ target_rog = self['rog'].value
+ return -k * (target_rog - r_gyr) * (X - X.mean(0)).ravel() / r_gyr / len(X)
+[INST] Loads full results of a Replica Exchange
+simulation.. [/INST] def load_samples(samples_folder, n_replicas, n_samples,
+ dump_interval, burnin, interval=1):
+ samples = []
+ for i in xrange(1, n_replicas + 1):
+ samples.append(load_sr_samples(samples_folder, i, n_samples, dump_interval,
+ burnin, interval))
+ return np.array(samples)
+[INST] Loads results for a single replica resulting from a Replica Exchange
+simulation.. [/INST] def load_sr_samples(samples_folder, replica_id, n_samples,
+ dump_interval, burnin, interval=1):
+ samples = []
+ for j in xrange((burnin / dump_interval) * dump_interval,
+ n_samples - dump_interval, dump_interval):
+ path = samples_folder + 'samples_replica{}_{}-{}.pickle'
+ samples += np.load(path.format(replica_id, j, j+dump_interval), allow_pickle=True)
+ start = burnin * (burnin < dump_interval) + (burnin > dump_interval) * (burnin % dump_interval)
+ return np.array(samples[start::interval])
+[INST] Writes a structure ensemble to a PDB file.. [/INST] def write_ensemble(X, filename, mol_ranges=None, center=True):
+ from csb.bio.structure import Atom, ProteinResidue, Chain, Structure, Ensemble
+ from csb.bio.sequence import ProteinAlphabet
+ if center:
+ X -= X.mean(1)[:,None,:]
+ if mol_ranges is None:
+ mol_ranges = np.array([0, X.shape[1]])
+ ensemble = Ensemble()
+ for i, x in enumerate(X):
+ structure = Structure('')
+ structure.model_id = i + 1
+ mol_coordinates = np.array([x[start:end]
+ for start, end in zip(mol_ranges[:-1],
+ mol_ranges[1:])])
+ for j, mol in enumerate(mol_coordinates):
+ structure.chains.append(Chain(chr(65 + j)))
+ for k, y in enumerate(mol):
+ atom = Atom(k+1, 'CA', 'C', y)
+ residue = ProteinResidue(k, 'ALA')
+ residue.atoms.append(atom)
+ structure.chains[chr(65 + j)].residues.append(residue)
+ ensemble.models.append(structure)
+ ensemble.to_pdb(filename)
+[INST] Writes a VMD script to show structures
+This writes a VMD script loading a structure ensemble PDB file, setting
+bead radii to given values and showing the structures as a chain of beads.. [/INST] def write_VMD_script(ensemble_pdb_file, bead_radii, output_file):
+ lines = ['color Display Background white',
+ 'menu main on',
+ 'menu graphics on',
+ 'mol load pdb {}'.format(ensemble_pdb_file),
+ 'mol color Index',
+ 'mol delrep 0 0',
+ 'mol representation VDW',
+ 'mol addrep 0'
+ ]
+ radii_set = set(bead_radii)
+ for br in radii_set:
+ p1 = 'set sel [atomselect top "index '
+ p2 = ''
+ for i, r in enumerate(bead_radii):
+ if r == br:
+ p2 += '{} '.format(i)
+ p3 = '"]'
+ lines.append(p1 + p2[:-1] + p3)
+ lines.append('$sel set radius {}'.format(br))
+ with open(output_file,'w') as opf:
+ [opf.write(line + '\n') for line in lines]
+[INST] Writes a PyMol script to show structures
+This writes a PyMol script loading a structure ensemble PDB file, setting
+bead radii to given values and showing the structures as a chain of beads.
+warning:: I'm not sure whether this works (I mostly use VMD). [/INST] def write_pymol_script(ensemble_pdb_file, bead_radii, output_file,
+ repr='spheres', grid=False):
+ epf = ensemble_pdb_file
+ fname = epf[-epf[::-1].find('/'):epf.find('.pdb')]
+ lines = ['load {}'.format(ensemble_pdb_file),
+ 'hide all',
+ 'util.chainbow',
+ 'set cartoon_trace_atoms=1',
+ 'set ribbon_trace_atoms=1',
+ ]
+ if repr == 'spheres':
+ for i, r in enumerate(bead_radii):
+ lines.append('set sphere_scale={}, resi {}'.format(r / 4.0, i))
+ elif repr == 'cartoon':
+ lines.append('as cartoon')
+ elif repr == 'ribbon':
+ lines.append('as ribbon')
+ if grid:
+ lines.append('set grid_mode=1')
+ lines.append('hide {}'.format(fname))
+ with open(output_file,'w') as opf:
+ [opf.write(line + '\n') for line in lines]
+[INST] Loads results of a simulation using a config file
+This returns posterior samples from a simulation given a config file.. [/INST] def load_samples_from_cfg(config_file, burnin=35000):
+ from .setup_functions import parse_config_file
+ cfg = parse_config_file(config_file)
+ output_folder = cfg['general']['output_folder']
+ n_beads = int(cfg['general']['n_beads'])
+ n_structures = int(cfg['general']['n_structures'])
+ n_samples = int(cfg['replica']['n_samples'])
+ samples = load_sr_samples(output_folder + 'samples/',
+ int(cfg['replica']['n_replicas']),
+ n_samples,
+ int(cfg['replica']['samples_dump_interval']),
+ burnin)
+ return samples
+[INST] Loads results of a simulation using a config file
+This returns posterior samples from a simulation given a config file
+and automatically determines the number of actually drawn samples,
+i.e., it ignores to the n_samples setting in the config file.. [/INST] def load_samples_from_cfg_auto(config_file, burnin=35000):
+ import os
+ from .setup_functions import parse_config_file
+ cfg = parse_config_file(config_file)
+ output_folder = cfg['general']['output_folder']
+ n_structures = int(cfg['general']['n_structures'])
+ n_replicas = int(cfg['replica']['n_replicas'])
+ dump_interval = int(cfg['replica']['samples_dump_interval'])
+ n_drawn_samples = 0
+ fname = output_folder + 'samples/samples_replica' \
+ + str(n_replicas) + '_{}-{}.pickle'
+ while True:
+ if os.path.exists(fname.format(n_drawn_samples,
+ n_drawn_samples + dump_interval)):
+ n_drawn_samples += dump_interval
+ else:
+ break
+ samples = load_sr_samples(output_folder + 'samples/',
+ n_replicas,
+ n_drawn_samples + 1,
+ dump_interval,
+ burnin)
+ return samples
+[INST] Loads a structure ensemble from a PDB file. [/INST] def load_ensemble_from_pdb(filename):
+ if False:
+ from csb.bio.io.wwpdb import StructureParser
+ ensemble = StructureParser(filename).parse_models()
+ return np.array([m.get_coordinates(['CA']) for m in ensemble])
+ else:
+ ip = open(filename)
+ lines = ip.readlines()
+ ip.close()
+ nres = 0
+ for l in lines:
+ w = l[:4]
+ if nres > 0 and w != 'ATOM':
+ break
+ if l[:4] == 'ATOM':
+ nres += 1
+ import re
+ atoms = []
+ for l in lines:
+ if 'ATOM' in l: atoms.append(l)
+ atoms = [x.split() for x in atoms]
+ atoms = [x[6:9] for x in atoms]
+ atoms = np.array([np.array(x).astype(np.float) for x in atoms])
+ atoms = np.array(np.split(atoms, nres))
+ atoms = atoms.reshape(len(filter(lambda x: 'MODEL' in x, lines)), -1, 3)
+ return atoms
+[INST] Calculates the density of states (DOS) using non-parametric
+histogram reweighting (WHAM).. [/INST] def calculate_DOS(config_file, n_samples, subsamples_fraction, burnin,
+ n_iter=100000, tol=1e-10, save_output=True, output_suffix=''):
+ from ensemble_hic.wham import PyWHAM as WHAM, DOS
+ from ensemble_hic.setup_functions import parse_config_file, make_posterior
+ from ensemble_hic.analysis_functions import load_sr_samples
+ settings = parse_config_file(config_file)
+ n_replicas = int(settings['replica']['n_replicas'])
+ target_replica = n_replicas
+ params = {'n_samples': n_samples,
+ 'burnin': burnin,
+ 'subsamples_fraction': subsamples_fraction,
+ 'niter': n_iter,
+ 'tol': tol
+ }
+ n_samples = min(params['n_samples'], int(settings['replica']['n_samples']))
+ dump_interval = int(settings['replica']['samples_dump_interval'])
+ output_folder = settings['general']['output_folder']
+ if output_folder[-1] != '/':
+ output_folder += '/'
+ n_beads = int(settings['general']['n_beads'])
+ n_structures = int(settings['general']['n_structures'])
+ schedule = np.load(output_folder + 'schedule.pickle')
+ posterior = make_posterior(settings)
+ p = posterior
+ variables = p.variables
+ energies = []
+ L = p.likelihoods['ensemble_contacts']
+ data = L.forward_model.data_points
+ P = p.priors['nonbonded_prior']
+ sels = []
+ for i in range(n_replicas):
+ samples = load_sr_samples(output_folder + 'samples/', i+1, n_samples+1,
+ dump_interval, burnin=params['burnin'])
+ sel = np.random.choice(len(samples),
+ int(len(samples) / float(subsamples_fraction)),
+ replace=False)
+ samples = samples[sel]
+ sels.append(sel)
+ energies.append([[-L.log_prob(**x.variables) if 'lammda' in schedule else 0,
+ -P.log_prob(structures=x.variables['structures'])
+ if 'beta' in schedule else 0]
+ for x in samples])
+ print "Calculated energies for {}/{} replicas...".format(i, n_replicas)
+ energies = np.array(energies)
+ energies_flat = energies.reshape(np.prod(energies.shape[:2]), 2)
+ sched = np.array([schedule['lammda'], schedule['beta']])
+ q = np.array([[(energy * replica_params).sum() for energy in energies_flat]
+ for replica_params in sched.T])
+ wham = WHAM(len(energies_flat), n_replicas)
+ wham.N[:] = len(energies_flat)/n_replicas
+ wham.run(q, niter=params['niter'], tol=params['tol'], verbose=100)
+ dos = DOS(energies_flat, wham.s, sort_energies=False)
+ if save_output:
+ import os
+ import sys
+ from cPickle import dump
+ ana_path = output_folder + 'analysis/'
+ if not os.path.exists(ana_path):
+ os.makedirs(ana_path)
+ with open(ana_path + 'dos{}.pickle'.format(output_suffix), 'w') as opf:
+ dump(dos, opf)
+ with open(ana_path + 'wham_params{}.pickle'.format(output_suffix), 'w') as opf:
+ dump(params, opf)
+ with open(ana_path + 'wham_sels{}.pickle'.format(output_suffix), 'w') as opf:
+ dump(np.array(sels), opf)
+ return dos
+[INST] Calculates the evidence from a DOS object. [/INST] def calculate_evidence(dos):
+ from csb.numeric import log_sum_exp
+ return log_sum_exp(-dos.E.sum(1) + dos.s) - \
+ log_sum_exp(-dos.E[:,1] + dos.s)
+[INST] Evaluates the potentital energy of a structure. [/INST] def energy(self, structure):
+ pass
+[INST] Evaluates the energy gradient for a structure. [/INST] def gradient(self, structure):
+ pass
+[INST] Cython implementation of the potential energy. [/INST] def energy(self, structure):
+ from ensemble_hic.forcefield_c import forcefield_energy
+ E = forcefield_energy(structure, self.bead_radii,
+ self.bead_radii2,
+ self.force_constant)
+ return E
+[INST] Cython implementation of the energy gradient. [/INST] def gradient(self, structure):
+ from ensemble_hic.forcefield_c import forcefield_gradient
+ grad = forcefield_gradient(structure, self.bead_radii,
+ self.bead_radii2,
+ self.force_constant)
+ return grad
+[INST] Edge length of the cubic cells.. [/INST] def cellsize():
+ pass
+[INST] Number of cubic cells in each spatial direction such that the
+total number of cells is 'n_cells^3'.. [/INST] def n_cells():
+ pass
+[INST] Maximum number of particles that fits into a cell.. [/INST] def n_per_cell():
+ pass
+[INST] Update the neighbor list.
+Parameters
+
+universe :
+Universe containing all particles whose pairwise interactions
+will be evaluated.
+
+update_box : boolean
+By toggling the flag, we can switch off the adaption of the
+cell grid. [/INST] def update(self, universe, update_box=True):
+ self.ctype.update(universe.coords, int(update_box))
+[INST] Reads the raw 5C data file and returns reverse restriction fragments,
+forward restriction fragments, and a matrix of shape
+(# forward fragments + 2, # reverse fragments + 2).
+First two rows are start / end genomic coordinates of reverse restriction
+fragments, first two columns are start / end genomic coordinates of forward
+restriction fragments.
+Genomic coordinates are hg11. [/INST] def parse_5C_file(filename):
+ data = open(filename).readlines()
+ data = data[7:]
+ data = [y.split('\t') for y in data]
+ data = np.array(data)
+ rev_fragments = [x[x.find('chrX:')+5:] for x in data[0]]
+ rev_fragments = [x.split('-') for x in rev_fragments]
+ rev_fragments = [(int(x[0]), int(x[1])) for x in rev_fragments[1:]]
+ rev_fragments = np.array(rev_fragments).swapaxes(1,0)
+ for_fragments = [x[x.find('chrX:')+5:] for x in data[1:,0]]
+ for_fragments = [x.split('-') for x in for_fragments]
+ for_fragments = [(int(x[0]), int(x[1])) for x in for_fragments]
+ for_fragments = np.array(for_fragments)
+ matrix = np.zeros((len(for_fragments) + 2, len(rev_fragments.T) + 2))
+ matrix[2:,:2] = for_fragments
+ matrix[:2,2:] = rev_fragments
+ matrix[2:,2:] = data[1:,1:]
+ return rev_fragments, for_fragments, matrix
+[INST] Extracts a region from a matrix produced by parse_5C_file.
+Returns the reverse and forward restriction fragments in the region
+and the part of the matrix covered by the region. [/INST] def extract_region(matrix, region_start, region_end):
+ land = np.logical_and
+ region_row_mask = land(matrix[:,0] >= region_start, matrix[:,1] <= region_end)
+ region_col_mask = land(matrix[0,:] >= region_start, matrix[1,:] <= region_end)
+ region = matrix[region_row_mask]
+ region = region[:,region_col_mask]
+ region_fors = matrix[region_row_mask, :2]
+ region_revs = matrix[:2, region_col_mask]
+ fragment_lengths = np.concatenate((region_fors[:,1] - region_fors[:,0],
+ region_revs[1,:] - region_revs[0,:])).astype(int)
+ return region_revs, region_fors, region
+[INST] Divides a region on a chromosome (or rather, the part of it covered by complete
+restriction fragments) into segments of equal, given length and one last
+segment which is smaller than the others such that the segments completely
+cover the region. These segments will be represented by spherical beads later.
+Returns the limits of the segments. [/INST] def calculate_bead_lims(bead_size, region_revs, region_fors):
+ region_length = np.max((region_fors[-1,1], region_revs[1,-1])) \
+ - np.min((region_fors[0,0], region_revs[0,0]))
+ n_beads = int(round(region_length / bead_size)) + 1
+ bead_lims = [np.min((region_fors[0,0], region_revs[0,0])) + i * bead_size
+ for i in range(n_beads)]
+ bead_lims[-1] = np.max((region_fors[-1,1], region_revs[1,-1]))
+ return np.array(bead_lims)
+[INST] Calculates a mapping assigning a bead to each restriction fragment.
+If one restriction fragment spans several beads, it will have the center
+bead (or center - 1 for even number of beads) assigned.
+Returns the mappings for reverse and forward restriction fragments. [/INST] def calculate_mappings(region_revs, region_fors, bead_lims):
+ region_revs = region_revs.T
+ mappings = []
+ for rfs in (region_revs, region_fors):
+ mapping = []
+ for b, e in rfs:
+ mapping.append((np.where(bead_lims <= b)[0][-1],
+ np.where(bead_lims <= e)[0][-1]))
+ mapping = np.array(mapping)
+ mapping = mapping.mean(1).astype(int)
+ mappings.append(mapping)
+ return mappings[0], mappings[1]
+[INST] Builds a square contact frequency matrix of shape (n_beads, n_beads).
+Contacts from restriction fragments mapping to the same bead are summed.
+A zero in this matrix means either that there was no data collected or that
+the number of counts is in fact zero. Later, we ignore zero-valued
+entries in the matrix.
+Return square contact frequency matrix.. [/INST] def build_cmatrix(rev_mapping, for_mapping, region, n_beads):
+ contmatrix = np.zeros((n_beads, n_beads))
+ cmatrix = np.zeros((n_beads, n_beads))
+ for i in range(n_beads):
+ contributing_fors = np.where(for_mapping == i)[0]
+ for j in range(n_beads):
+ contributing_revs = np.where(rev_mapping == j)[0]
+ for cfor in contributing_fors:
+ for crev in contributing_revs:
+ cmatrix[i,j] += region[cfor, crev]
+ cmatrix += cmatrix.T
+ return cmatrix
+[INST] Writes a square contact frequency matrix to a file, which will be the
+input for our structure calculation code.. [/INST] def write_cmatrix(cmatrix, filename):
+ with open(filename, 'w') as opf:
+ for i in range(len(cmatrix)):
+ for j in range(i+1, len(cmatrix)):
+ opf.write('{}\t{}\t{}\n'.format(i, j, int(cmatrix[i,j])))
+[INST] Calculates the shape of the Gamma distribution. [/INST] def _calculate_shape(self):
+ pass
+[INST] Calculates the rate of the Gamma distribution. [/INST] def _calculate_rate(self):
+ pass
+[INST] Draws a sample from the Gamma distribution specified
+by a rate and a scale parameter. [/INST] def sample(self, state=42):
+ rate = self._calculate_rate()
+ shape = self._calculate_shape()
+ sample = np.random.gamma(shape) / rate
+ if sample == 0.0:
+ sample += 1e-10
+ self.state = sample
+ return self.state
+[INST] Retrieves the prior distribution object associated with the
+scaling factor variable. [/INST] def _get_prior(self):
+ prior = filter(lambda p: 'norm' in p.variables, self.pdf.priors.values())[0]
+ return prior
+[INST] Checks whether retrieved prior distribution is in fact a Gamma
+distribution. [/INST] def _check_gamma_prior(self, prior):
+ from .gamma_prior import GammaPrior
+ return isinstance(prior, GammaPrior)
+[INST] Parses a config file consisting of several sections.
+I think I adapted this from the ConfigParser docs.. [/INST] def parse_config_file(config_file):
+ import ConfigParser
+ config = ConfigParser.ConfigParser()
+ config.read(config_file)
+ def config_section_map(section):
+ dict1 = {}
+ options = config.options(section)
+ for option in options:
+ try:
+ dict1[option] = config.get(section, option)
+ except:
+ dict1[option] = None
+ return dict1
+ return {section: config_section_map(section) for section in config.sections()}
+[INST] Makes the Gamma prior object for the scaling parameter
+Shape and rate of the Gamma distribution are set to rather broad
+values depending on the average number of counts in the data. [/INST] def make_norm_prior(norm_prior_settings, likelihood, n_structures):
+ from .gamma_prior import NormGammaPrior
+ shape = norm_prior_settings['shape']
+ rate = norm_prior_settings['rate']
+ if shape == rate == 'auto':
+ rate = 1.0 / n_structures
+ dp = likelihood.forward_model.data_points[:,2]
+ shape = np.mean(dp[dp > 0]) / float(n_structures)
+ else:
+ shape = float(shape)
+ rate = float(rate)
+ return NormGammaPrior(shape, rate)
+[INST] Helper function which creates an array of exponentially spaced values
+I use this to create temperature schedules for
+replica exchange simulations.. [/INST] def expspace(min, max, a, N):
+ g = lambda n: (max - min) / (np.exp(a*(N-1.0)) - 1.0) * (np.exp(a*(n-1.0)) - 1.0) + float(min)
+ return np.array(map(g, np.arange(1, N+1)))
+[INST] Makes a replica exchange schedule from settings specified in
+a config file.
+
+You can either have a linear or an exponential schedule
+and a separate prior annealing chain or not. You can also
+load a schedule from a Python pickle. It has to be a
+dict with the keys being the tempered parameters and the
+values the schedule for that parameter.
+But the latter option is currently handled in the run_simulation.py
+script.. [/INST] def make_replica_schedule(replica_params, n_replicas):
+ l_min = float(replica_params['lambda_min'])
+ l_max = float(replica_params['lambda_max'])
+ b_min = float(replica_params['beta_min'])
+ b_max = float(replica_params['beta_max'])
+ if replica_params['schedule'] == 'linear':
+ if replica_params['separate_prior_annealing'] == 'True':
+ separate_prior_annealing = True
+ else:
+ separate_prior_annealing = False
+ if separate_prior_annealing:
+ b_chain = np.arange(0, np.floor(n_replicas / 2))
+ l_chain = np.arange(np.floor(n_replicas / 2), n_replicas)
+ lambdas = np.concatenate((np.zeros(len(b_chain)) + l_min,
+ np.linspace(l_min, l_max,
+ len(l_chain))))
+ betas = np.concatenate((np.linspace(b_min, b_max,
+ len(b_chain)),
+ np.zeros(len(l_chain)) + b_max))
+ schedule = {'lammda': lambdas, 'beta': betas}
+ else:
+ schedule = {'lammda': np.linspace(l_min, l_max, n_replicas),
+ 'beta': np.linspace(b_min, b_max, n_replicas)}
+ elif replica_params['schedule'] == 'exponential':
+ l_rate = float(replica_params['lambda_rate'])
+ b_rate = float(replica_params['beta_rate'])
+ schedule = {'lammda': expspace(l_min, l_max, l_rate, n_replicas),
+ 'beta': expspace(b_min, b_max, b_rate, n_replicas)}
+ else:
+ msg = 'Schedule has to be either a file name, ' + \
+ '\'lambda_beta\', or \'exponential\''
+ raise ValueError(msg)
+ return schedule
+[INST] Makes a dictionary of (possibly MCMC) samplers for all variables. [/INST] def make_subsamplers(posterior, initial_state,
+ structures_hmc_params):
+ from binf.samplers.hmc import HMCSampler
+ p = posterior
+ variables = initial_state.keys()
+ structures_tl = int(structures_hmc_params['trajectory_length'])
+ structures_timestep = float(structures_hmc_params['timestep'])
+ s_adaption_limit = int(structures_hmc_params['adaption_limit'])
+ xpdf = p.conditional_factory(**{var: value for (var, value)
+ in initial_state.iteritems()
+ if not var == 'structures'})
+ structures_sampler = HMCSampler(xpdf,
+ initial_state['structures'],
+ structures_timestep, structures_tl,
+ variable_name='structures',
+ timestep_adaption_limit=s_adaption_limit)
+ subsamplers = dict(structures=structures_sampler)
+ if 'norm' in variables:
+ from .error_models import PoissonEM
+ if isinstance(p.likelihoods['ensemble_contacts'].error_model,
+ PoissonEM):
+ from .npsamplers import NormGammaSampler
+ norm_sampler = NormGammaSampler()
+ subsamplers.update(norm=norm_sampler)
+ else:
+ raise NotImplementedError('Norm sampling only implemented' +
+ 'for Poisson error model!')
+ return subsamplers
+[INST] Makes a set of fully elongated structures. [/INST] def make_elongated_structures(bead_radii, n_structures):
+ X = [bead_radii[0]]
+ for i in range(len(bead_radii) -1):
+ X.append(X[-1] + bead_radii[i+1] + bead_radii[i])
+ X = np.array(X) - np.mean(X)
+ X = np.array([X, np.zeros(len(bead_radii)),
+ np.zeros(len(bead_radii))]).T[None,:]
+ X = X.repeat(n_structures, 0).ravel().astype(float)
+ return X
+[INST] Makes a set of random structures with bead positions drawn
+from a normal distribution. [/INST] def make_random_structures(bead_radii, n_structures):
+ d = bead_radii.mean() * len(bead_radii) ** 0.333
+ X = np.random.normal(scale=d, size=(n_structures, len(bead_radii), 3))
+ return X.ravel()
+[INST] Conditions the posterior on the fixed variables. [/INST] def make_conditional_posterior(posterior, settings):
+ variables = settings['general']['variables'].split(',')
+ variables = [x.strip() for x in variables]
+ p = posterior
+ if not 'norm' in variables:
+ return p.conditional_factory(norm=settings['initial_state']['norm'])
+ else:
+ return p
+[INST] Makes the default backbone prior object.. [/INST] def make_backbone_prior(bead_radii, backbone_prior_params, n_beads,
+ n_structures):
+ from .backbone_prior import BackbonePrior
+ if 'mol_ranges' in backbone_prior_params:
+ mol_ranges = backbone_prior_params['mol_ranges']
+ else:
+ mol_ranges = None
+ if mol_ranges is None:
+ mol_ranges = np.array([0, n_beads])
+ else:
+ mol_ranges = np.loadtxt(mol_ranges).astype(int)
+ bb_ll = [np.zeros(mol_ranges[i+1] - mol_ranges[i] - 1)
+ for i in range(len(mol_ranges) - 1)]
+ bb_ul = [np.array([bead_radii[j] + bead_radii[j+1]
+ for j in range(mol_ranges[i], mol_ranges[i+1] - 1)])
+ for i in range(len(mol_ranges) - 1)]
+ BBP = BackbonePrior('backbone_prior',
+ lower_limits=bb_ll, upper_limits=bb_ul,
+ k_bb=float(backbone_prior_params['force_constant']),
+ n_structures=n_structures,
+ mol_ranges=mol_ranges
+ )
+ return BBP
+[INST] Sets up all structural prior distributions. [/INST] def make_priors(nonbonded_prior_params, backbone_prior_params,
+ sphere_prior_params, n_beads, n_structures):
+ nb_params = nonbonded_prior_params
+ try:
+ bead_radius = float(nb_params['bead_radii'])
+ bead_radii = np.ones(n_beads) * bead_radius
+ except:
+ bead_radii = np.loadtxt(nb_params['bead_radii'],
+ dtype=float)
+ NBP = make_nonbonded_prior(nb_params, bead_radii, n_structures)
+ BBP = make_backbone_prior(bead_radii, backbone_prior_params,
+ n_beads, n_structures)
+ priors = {NBP.name: NBP, BBP.name: BBP}
+ if sphere_prior_params['active'] == 'True':
+ SP = make_sphere_prior(sphere_prior_params, bead_radii, n_structures)
+ priors.update(**{SP.name: SP})
+ return priors
+[INST] Makes the default non-bonded structural prior object.
+This will either be a Boltzmann-like distribution or a
+Tsallis ensemble (currently not really supported).. [/INST] def make_nonbonded_prior(nb_params, bead_radii, n_structures):
+ from .forcefields import NBLForceField as ForceField
+ forcefield = ForceField(bead_radii, float(nb_params['force_constant']))
+ if not 'ensemble' in nb_params or nb_params['ensemble'] == 'boltzmann':
+ from .nonbonded_prior import BoltzmannNonbondedPrior2
+ NBP = BoltzmannNonbondedPrior2('nonbonded_prior', forcefield,
+ n_structures=n_structures, beta=1.0)
+ elif nb_params['ensemble'] == 'tsallis':
+ raise NotImplementedError
+ from .nonbonded_prior import TsallisNonbondedPrior2
+ NBP = TsallisNonbondedPrior('nonbonded_prior', bead_radii=bead_radii,
+ force_constant=force_constant,
+ n_structures=n_structures, q=1.0)
+ return NBP
+[INST] Makes a sphere structural prior object.
+This is a Boltzmann-like distribution with a potential energy
+harmonically restraining all beads to stay within a sphere
+of a given radius.. [/INST] def make_sphere_prior(sphere_prior_params, bead_radii, n_structures):
+ from .sphere_prior import SpherePrior
+ radius = sphere_prior_params['radius']
+ if radius == 'auto':
+ radius = 2 * bead_radii.mean() * len(bead_radii) ** (1 / 3.0)
+ else:
+ radius = float(radius)
+ SP = SpherePrior('sphere_prior',
+ sphere_radius=radius,
+ sphere_k=float(sphere_prior_params['force_constant']),
+ n_structures=n_structures, bead_radii=bead_radii)
+ return SP
+[INST] Sets up a likelihood object from settings parsed from a config
+file. [/INST] def make_likelihood(forward_model_params, error_model, data_filtering_params,
+ data_file, n_structures, bead_radii):
+ from .forward_models import EnsembleContactsFWM
+ from .likelihoods import Likelihood
+ data = parse_data(data_file, data_filtering_params)
+ cd_factor = float(forward_model_params['contact_distance_factor'])
+ contact_distances = (bead_radii[data[:,0]] + bead_radii[data[:,1]]) * cd_factor
+ FWM = EnsembleContactsFWM('fwm', n_structures, contact_distances,
+ data_points=data)
+ if error_model == 'poisson':
+ from .error_models import PoissonEM
+ EM = PoissonEM('ensemble_contacts_em', data[:,2])
+ else:
+ raise(NotImplementedError)
+ L = Likelihood('ensemble_contacts', FWM, EM, 1.0)
+ L = L.conditional_factory(smooth_steepness=forward_model_params['alpha'])
+ return L
+[INST] Evaluates the forward model, i.e., back-calculates contact
+data from a structure ensemble and other (nuisance) parameters. [/INST] def _evaluate(self, structures, smooth_steepness, norm):
+ X = structures.reshape(self.n_structures, -1, 3)
+ return ensemble_contacts_evaluate(X,
+ norm,
+ self['contact_distances'].value,
+ smooth_steepness,
+ data_points=numpy.array(self.data_points,
+ dtype=int))
+[INST] In theory, this evaluates the Jacobian matrix of the forward model,
+but I usually hardcode the multiplication of this with the error
+model gradient in Cython. [/INST] def _evaluate_jacobi_matrix(self, structures, smooth_steepness, norm):
+ raise NotImplementedError("Use fast likelihood gradients in " +
+ "likelihoods_c.pyx instead!")
+[INST] Register parameters of the statistical ensemble, for example
+the inverse temperature in case of a Boltzmann ensemble. [/INST] def _register_ensemble_parameters(self, **parameters):
+ pass
+[INST] Evaluates the gradient of the force field. [/INST] def _forcefield_gradient(self, structure):
+ return self.forcefield.gradient(structure)
+[INST] Evaluates the energy of the force field. [/INST] def _forcefield_energy(self, structure):
+ return self.forcefield.energy(structure)
+[INST] Derivative of the statistical ensemble w.r.t. the system energy.
+Should be called log_ensemble_derivative or sth. like that.. [/INST] def _log_ensemble_gradient(self, E):
+ pass
+[INST] The logarithm of the statistical ensemble, for example,
+beta * E in case of a Boltzmann ensemble. [/INST] def _log_ensemble(self, E):
+ pass
+[INST] Evaluates the log-probability of the data given the mock data. [/INST] def _evaluate_log_prob(self, mock_data):
+ d_counts = self.data
+ return -mock_data.sum() + numpy.sum(d_counts * numpy.log(mock_data))
+[INST] In theory, this evaluates the gradient of the negative log-probability,
+but I usually hardcode the multiplication of this with the forward
+model Jacobian in Cython. [/INST] def _evaluate_gradient(self, **variables):
+ pass
+[INST] Returns a copy of an instance of this class. [/INST] def clone(self):
+ copy = self.__class__(self.name, self.data)
+ copy.set_fixed_variables_from_pdf(self)
+ return copy
+[INST] check if the taxonomies in the list "selected" are
+contained in the metadata. [/INST] def _check_taxonomies(self,selected):
+ if (self.metadata):
+ return(set(selected) <= set(self.metadata['taxonomies']))
+ else:
+ print("_check_taxonomies: metadata are not defined.")
+ return(False)
+[INST] read fragility/vulnerability model from a json file.
+the file contains two dictionaries:
+1) 'meta' includes information (metadata) on the schema, the list of taxonomies and
+damage states
+2) 'data' provides the mean and log. std deviation of the lognormal
+distribution encoding the fragility / vulnerability descriptions
+the function returns a dictionary with metadata and a pandas dataframe. [/INST] def _read_schema(self, input_file):
+ with open(input_file,'r') as f:
+ parsed = json.load(f)
+ self.metadata = parsed['meta']
+ self.data = pd.DataFrame(parsed['data'])
+ return(0)
+[INST] write fragility/vulnerability schema to a json file.
+the file contains two dictionaries:
+1) 'meta' includes information (metadata) on the schema, the list of taxonomies and
+damage states
+2) 'data' provides the mean and log. std deviation of the lognormal
+distribution encoding the fragility / vulnerability descriptions
+the function accepts a dictionary with metadata and a pandas dataframe. [/INST] def _write_schema(self, metadata, data, output_file):
+ if ((metadata is not None) and (data is not None)):
+ modict = {}
+ modict['meta'] = metadata
+ modict['data'] = data.to_dict(orient='records')
+ with open(output_file,'w') as f:
+ json.dump(modict,f, indent=4)
+ return (0)
+ else:
+ print ("_write_schema: metadata or data are missing.")
+ return(1)
+[INST] extract a part of the model by doing a query on the
+selected taxonomies (selectedtaxonomies). [/INST] def _queryModel(self):
+ if (self.selectedtaxonomies):
+ if (self._check_taxonomies(self.selectedtaxonomies)):
+ self.query_result_metadata = self.metadata.copy()
+ self.query_result_metadata['taxonomies']=self.selectedtaxonomies
+ self.query_result_data = self.data.set_index('taxonomy').loc[self.selectedtaxonomies].reset_index()
+ else:
+ self.query_result_data = self.data
+ self.query_result_metadata = self.metadata
+ return(0)
+[INST] Export geopandas dataframe as GeoJson file. [/INST] def _exportGeoJson(self, dataframe, filename):
+ try:
+ os.remove(filename)
+ except OSError:
+ pass
+ dataframe.to_file(filename, driver='GeoJSON')
+ return (0)
+[INST] Export geopandas dataframe as nrml file. [/INST] def _exportNrml05(self, dataframe, filename, metadata, dicts,taxonomies):
+ xml_string = nrml.write_nrml05_expo(dataframe,metadata,dicts,taxonomies,filename)
+ return (0)
+[INST] Export query result as nrml and geojson files. [/INST] def _write_outputs(self):
+ output_geojson = os.path.join(self.path_outfile,self.out_file_geojson)
+ self._write_schema(self.query_result_metadata, self.query_result_data,output_geojson)
+[INST] Method to:
+load the fragility model from a file (json)
+query the model based on a list of taxonomies
+write the output(s). [/INST] def run(self):
+ if (self._check_schema()):
+ foldername = os.path.join(self.folder,"schemas/{}".format(self.schema))
+ self.path_infile = foldername
+ self.in_file = "{}_struct.json".format(self.schema)
+ else:
+ raise Exception ("schema {} not supported".format(self.schema))
+ in_file = os.path.join(self.path_infile,self.in_file)
+ self._read_schema(in_file)
+ self._queryModel()
+ self._write_outputs()
+ return (0)
+[INST] Creates an arg parser and uses that to create the Main class. [/INST] def create_with_arg_parser(cls):
+ arg_parser = argparse.ArgumentParser(
+ description='''Program to query a fragility/vulnerability
+ model from a database/file'''
+ )
+ arg_parser.add_argument(
+ 'schema',
+ help='Exposure/Vulnerability Schema',
+ type=str,
+ default="SARA_v1.0")
+ arg_parser.add_argument(
+ 'assetcategory',
+ help='Type of exposed assets',
+ type=str,
+ default='buildings')
+ arg_parser.add_argument(
+ 'losscategory',
+ help='damage or loss computation',
+ type=str,
+ default='structural')
+ arg_parser.add_argument(
+ '-taxonomies',
+ help='selected taxonomies',
+ type=str)
+ args = arg_parser.parse_args()
+ return cls(args)
+[INST] Sanitize the outputs of the tasks.. [/INST] def __sanitize_results(results: dict) -> dict:
+ sanitized_results = {}
+ for key, value in results.items():
+ sanitized_results[tools.remove_number_from_name(key)] = value
+ return sanitized_results
+[INST] Create an Task appropriate for the language.. [/INST] def from_yaml_node(
+ task_node: dict, student_hw_folder: Path, job_file: Path
+ ) -> Optional[Task]:
+ student_task_folder = student_hw_folder / task_node[Tags.FOLDER_TAG]
+ if not student_task_folder.exists():
+ log.warning("Folder '%s' does not exist. Skipping.", student_task_folder)
+ return None
+ language_tag = task_node[Tags.LANGUAGE_TAG]
+ if language_tag == LangTags.CPP:
+ return CppTask(task_node, student_task_folder, job_file)
+ if language_tag == LangTags.BASH:
+ return BashTask(task_node, student_task_folder, job_file)
+ log.error("Unknown Task language.")
+ return None
+[INST] Iterate over the tests and check them.. [/INST] def check(self: Task) -> Task.ResultDictType:
+ results: Task.ResultDictType = {}
+ def run_all_tests(
+ test_node: dict, executable_folder: Path
+ ) -> Task.ResultDictType:
+ results: Task.ResultDictType = {}
+ if Tags.INJECT_FOLDER_TAG not in test_node:
+ test_result = self._run_test(test_node, executable_folder)
+ results[
+ self.__with_number_prefix(test_node[Tags.NAME_TAG])
+ ] = test_result
+ return results
+ with tools.TempDirCopy(
+ source_folder=self._student_task_folder, prefix="build_injected"
+ ) as code_folder:
+ folders_to_inject = self.__get_folders_to_inject(
+ node=test_node, destination_root=code_folder
+ )
+ Task.__inject_folders(folders_to_inject)
+ build_result, build_folder = self._build_if_needed(code_folder)
+ if build_result and not build_result.succeeded():
+ raise Exception("Build with inserted folders must ALWAYS succeed!")
+ test_result = self._run_test(
+ test_node=test_node, executable_folder=build_folder
+ )
+ results[
+ self.__with_number_prefix(test_node[Tags.NAME_TAG])
+ ] = test_result
+ return results
+ with tools.TempDirCopy(self._student_task_folder) as code_folder:
+ build_result, build_folder = self._build_if_needed(code_folder)
+ if build_result:
+ results[self.__with_number_prefix(BUILD_SUCCESS_TAG)] = build_result
+ if not build_result.succeeded():
+ return results
+ for test_node in self._test_nodes:
+ results.update(
+ run_all_tests(test_node=test_node, executable_folder=build_folder)
+ )
+ style_errors = self._code_style_errors()
+ if style_errors:
+ results[STYLE_ERROR_TAG] = style_errors
+ return results
+[INST] Inject all folders overwriting existing folders in case of conflict.. [/INST] def __inject_folders(folders_to_inject: List[Task.Injection]):
+ for injection in folders_to_inject:
+ if injection.destination.exists():
+ rmtree(injection.destination)
+ copytree(injection.source, injection.destination)
+[INST] Check if code conforms to Google Style.. [/INST] def _code_style_errors(self: CppTask) -> Optional[tools.CmdResult]:
+ command = (
+ "cpplint --counting=detailed "
+ + "--filter=-legal,-readability/todo,"
+ + "-build/include_order,-runtime/threadsafe_fn,"
+ + "-runtime/arrays"
+ + ' $( find . -name "*.h" -o -name "*.cpp" | grep -vE "^./build/" )'
+ )
+ result = tools.run_command(
+ command,
+ cwd=self._student_task_folder,
+ timeout=self._build_timeout,
+ )
+ if result.stderr and TOTAL_ERRORS_FOUND_TAG in result.stderr:
+ return result
+ if result.stdout and TOTAL_ERRORS_FOUND_TAG in result.stdout:
+ return result
+ return None
+[INST] Add a number before a string.. [/INST] def remove_number_from_name(name: str) -> str:
+ if NUMBER_SPLIT_TAG not in name:
+ return name
+ return name.split(NUMBER_SPLIT_TAG)[1]
+[INST] Expand the path if it is not absolute.. [/INST] def expand_if_needed(input_path: Path) -> Path:
+ if input_path.is_absolute():
+ return input_path
+ new_path = input_path.expanduser()
+ if new_path.is_absolute():
+ return new_path
+ return Path.cwd() / new_path
+[INST] Convert the value to a specified type.. [/INST] def convert_to(
+ output_type: str, value: Any
+) -> Union[Tuple[Optional[str], str], Tuple[Optional[float], str]]:
+ if not value:
+ return None, "No value. Cannot convert {} to '{}'.".format(value, output_type)
+ try:
+ if output_type == OutputTags.STRING:
+ return str(value).strip(), "OK"
+ if output_type == OutputTags.NUMBER:
+ return float(value), "OK"
+ except ValueError as error:
+ log.error("Exception: %s.", error)
+ return None, str(error)
+ return None, "Unknown output type {}. Cannot convert.".format(output_type)
+[INST] Check if the command succeeded.. [/INST] def succeeded(self: CmdResult) -> bool:
+ if self.returncode is not None:
+ return self.returncode == CmdResult.SUCCESS
+ if self.stderr:
+ return False
+ return True
+[INST] Run a generic command in a subprocess.. [/INST] def run_command(
+ command: Union[List[str], str],
+ timeout: float,
+ shell: bool = True,
+ cwd: Path = Path.cwd(),
+ env: Optional[Mapping[str, Any]] = None,
+) -> CmdResult:
+ try:
+ startupinfo = None
+ if shell and isinstance(command, list):
+ command = subprocess.list2cmdline(command)
+ log.debug("running command: \n%s", command)
+ if env is None:
+ env = environ
+ process = __run_subprocess(
+ command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=shell,
+ cwd=str(cwd),
+ env=env,
+ startupinfo=startupinfo,
+ timeout=timeout,
+ )
+ return CmdResult(
+ returncode=process.returncode,
+ stdout=process.stdout.decode("utf-8"),
+ stderr=process.stderr.decode("utf-8"),
+ )
+ except subprocess.CalledProcessError as error:
+ output_text = error.output.decode("utf-8")
+ log.error("command '%s' finished with code: %s", error.cmd, error.returncode)
+ log.debug("command output: \n%s", output_text)
+ return CmdResult(returncode=error.returncode, stderr=output_text)
+ except subprocess.TimeoutExpired as error:
+ output_text = "Timeout: command '{}' ran longer than {} seconds".format(
+ error.cmd.strip(), error.timeout
+ )
+ log.error(output_text)
+ return CmdResult(returncode=1, stderr=output_text)
+[INST] Run a command as a subprocess.
+
+This code does essentially the same as subprocess.run(...) but makes sure to
+kill the whole process tree which allows to use the timeout even when using
+shell=True. The reason I don't want to stop using shell=True here is the
+convenience of piping arguments from one function to another.. [/INST] def __run_subprocess(
+ command: Union[List[str], str],
+ str_input: str = None,
+ timeout: float = None,
+ check: bool = False,
+ **kwargs
+) -> subprocess.CompletedProcess:
+ if str_input is not None:
+ if "stdin" in kwargs:
+ raise ValueError("stdin and str_input arguments may not both be used.")
+ kwargs["stdin"] = subprocess.PIPE
+ if timeout is None:
+ timeout = 20
+ with Popen(command, start_new_session=True, **kwargs) as process:
+ try:
+ stdout, stderr = process.communicate(str_input, timeout=timeout)
+ except TimeoutExpired as timeout_error:
+ os.killpg(process.pid, signal.SIGINT)
+ stdout, stderr = process.communicate()
+ raise TimeoutExpired(
+ process.args, timeout, output=stdout, stderr=stderr
+ ) from timeout_error
+ retcode = process.poll()
+ if retcode is None:
+ retcode = 1
+ if check and retcode:
+ raise CalledProcessError(
+ retcode, process.args, output=stdout, stderr=stderr
+ )
+ return CompletedProcess(process.args, retcode, stdout, stderr)
+[INST] Run over all Tasks in a single homework.. [/INST] def check_homework(self: "Checker", homework_node: dict) -> HomeworkResultDict:
+ results: HomeworkResultDict = {}
+ current_folder = Path(self._checked_code_folder, homework_node[Tags.FOLDER_TAG])
+ log.debug("current_folder: %s", current_folder)
+ if not current_folder.exists():
+ log.warning("Folder '%s' does not exist. Skiping.", current_folder)
+ return results
+ deadline_str = homework_node[Tags.DEADLINE_TAG]
+ deadline_datetime = datetime.strptime(deadline_str, tools.DATE_PATTERN)
+ if datetime.now() > deadline_datetime:
+ results[tools.EXPIRED_TAG] = {}
+ for idx, task_node in enumerate(homework_node[Tags.TASKS_TAG]):
+ task = Task.from_yaml_node(
+ task_node=task_node,
+ student_hw_folder=current_folder,
+ job_file=self._job_file_path,
+ )
+ if not task:
+ continue
+ results[tools.add_number_to_name(idx, task.name)] = task.check()
+ return results
+[INST] Run over all Tasks in all homeworks.. [/INST] def check_all_homeworks(self: "Checker") -> Dict[str, HomeworkResultDict]:
+ results: Dict[str, HomeworkResultDict] = {}
+ for idx, homework_node in enumerate(self._base_node[Tags.HOMEWORKS_TAG]):
+ hw_name = tools.add_number_to_name(idx, homework_node[Tags.NAME_TAG])
+ current_homework_results = self.check_homework(homework_node)
+ if current_homework_results:
+ results[hw_name] = current_homework_results
+ return results
+[INST] Update the table of completion.. [/INST] def update(self: "MdWriter", hw_results: Dict[str, HomeworkResultDict]):
+ for hw_name, hw_dict in sorted(hw_results.items()):
+ hw_name = remove_number_from_name(hw_name)
+ need_hw_name = True
+ expired = False
+ if EXPIRED_TAG in hw_dict:
+ expired = True
+ for task_name, ex_dict in sorted(hw_dict.items()):
+ task_name = remove_number_from_name(task_name)
+ if task_name == EXPIRED_TAG:
+ continue
+ need_task_name = True
+ for test_name, test_result in sorted(ex_dict.items()):
+ test_name = remove_number_from_name(test_name)
+ result_sign = SUCCESS_TAG if test_result.succeeded() else FAILED_TAG
+ extended_hw_name = (
+ hw_name + " `[PAST DEADLINE]`" if expired else hw_name
+ )
+ self._md_table += TABLE_TEMPLATE.format(
+ hw_name=extended_hw_name if need_hw_name else "",
+ task_name=task_name if need_task_name else "",
+ test_name=test_name,
+ result_sign=result_sign,
+ )
+ self._add_error(hw_name, task_name, test_name, test_result, expired)
+ need_hw_name = False
+ need_task_name = False
+[INST] Write all the added content to the md file.. [/INST] def write_md_file(self: "MdWriter", md_file_path: Path):
+ md_file_content = "# Test results\n"
+ md_file_content += self._md_table
+ if self._errors:
+ md_file_content += "\n# Encountered errors\n"
+ md_file_content += self._errors
+ md_file_content += SEPARATOR
+ md_file_content += FINISHING_NOTE
+ with open(md_file_path, "w") as md_file:
+ md_file.write(md_file_content)
+[INST] Add a section of errors to the md file.. [/INST] def _add_error(
+ self: "MdWriter",
+ hw_name: str,
+ task_name: str,
+ test_name: str,
+ test_result: CmdResult,
+ expired: bool,
+ ):
+ if test_result.succeeded():
+ return
+ if expired:
+ self._errors += EXPIRED_TEMPLATE.format(hw_name=hw_name)
+ return
+ self._errors += ERROR_TEMPLATE.format(
+ hw_name=hw_name,
+ task_name=task_name,
+ test_name=test_name,
+ stderr=test_result.stderr,
+ stdout=test_result.stdout,
+ )
+[INST] Generates an image object (from the qrcode library) representing the QR code for the given text.
+Any invalid argument is silently converted into the default value for that argument.. [/INST] def make_qr_code_image(text, image_factory, qr_code_options=QRCodeOptions()):
+ valid_version = _get_valid_version_or_none(qr_code_options.version)
+ valid_size = _get_valid_size_or_default(qr_code_options.size)
+ valid_error_correction = _get_valid_error_correction_or_default(qr_code_options.error_correction)
+ import qrcode
+ qr = qrcode.QRCode(
+ version=valid_version,
+ error_correction=valid_error_correction,
+ box_size=valid_size,
+ border=qr_code_options.border
+ )
+ qr.add_data(force_text(text))
+ if valid_version is None:
+ qr.make(fit=True)
+ return qr.make_image(image_factory=image_factory)
+[INST] Generates a or tag representing the QR code for the given text. This tag can be embedded into an
+HTML document.. [/INST] def make_embedded_qr_code(text, qr_code_options=QRCodeOptions()):
+ image_format = qr_code_options.image_format
+ img = make_qr_code_image(text, SvgEmbeddedInHtmlImage if image_format == SVG_FORMAT_NAME else PilImageOrFallback, qr_code_options=qr_code_options)
+ stream = BytesIO()
+ if image_format == SVG_FORMAT_NAME:
+ img.save(stream, kind=SVG_FORMAT_NAME.upper())
+ html_fragment = (str(stream.getvalue(), 'utf-8'))
+ else:
+ img.save(stream, format=PNG_FORMAT_NAME.upper())
+ html_fragment = 'data:image/png;base64, %s' % (str(base64.b64encode(stream.getvalue()), encoding='ascii'))
+ return mark_safe(html_fragment)
+[INST] Support batch or single str.. [/INST] def encode(self, text):
+ length = []
+ result = []
+ for item in text:
+ item = item.decode('utf-8','strict')
+ length.append(len(item))
+ r = []
+ for char in item:
+ index = self.dict[char]
+ r.append(index)
+ result.append(r)
+ max_len = 0
+ for r in result:
+ if len(r) > max_len:
+ max_len = len(r)
+ result_temp = []
+ for r in result:
+ for i in range(max_len - len(r)):
+ r.append(0)
+ result_temp.append(r)
+ text = result_temp
+ return (torch.LongTensor(text), torch.LongTensor(length))
+[INST] Decode encoded texts back into strs.. [/INST] def decode(self, t, length, raw=False):
+ if length.numel() == 1:
+ length = length[0]
+ assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
+ if raw:
+ return ''.join([self.alphabet[i - 1] for i in t])
+ else:
+ char_list = []
+ for i in range(length):
+ if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
+ char_list.append(self.alphabet[t[i] - 1])
+ return ''.join(char_list)
+ else:
+ assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
+ texts = []
+ index = 0
+ for i in range(length.numel()):
+ l = length[i]
+ texts.append(
+ self.decode(
+ t[index:index + l], torch.LongTensor([l]), raw=raw))
+ index += l
+ return texts
+[INST] Prepare PROST example for the T5 Colab Notebook. [/INST] def prep_example_prost_gcp(example):
+ template = '{ex_question} \\n (A) {A} (B) {B} (C) {C} (D) {D} \\n {context}'
+ instance = {
+ 'input': template.format_map(example),
+ 'target': example[list('ABCD')[example['label']]],
+ 'target_idx': example['label'],
+ **example}
+ return instance
+[INST] Normalize UnifiedQA Generated Text [1].. [/INST] def normalize_answer(s: str) -> str:
+ import re
+ import string
+ def remove_articles(text):
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
+ def white_space_fix(text):
+ return ' '.join(text.split())
+ def fix_buggy_characters(str):
+ return re.sub("[{}^\\\\`\u2047<]", " ", str)
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return ''.join(ch for ch in text if ch not in exclude)
+ return T.pipe(s, str.lower, remove_punc, fix_buggy_characters,
+ remove_articles, white_space_fix)
+[INST] Prepare PIQA example for scoring preds from the T5 notebook. [/INST] def prep_example_piqa(example, index):
+ match = re.match(r'^(.+) \\n \(A\) (.+) \(B\) (.+)$', example['input'])
+ instance = {}
+ instance['sol1'] = match[2].strip().lower()
+ instance['sol2'] = match[3].strip().lower()
+ example['target'] = example['target'].strip().lower()
+ instance['example_idx'] = index
+ instance['target'] = example['target']
+ instance['target_idx'] = (instance['sol1'], instance['sol2'],
+ example['target']).index(example['target'])
+ if instance['target_idx'] > 1:
+ logging.warning('failed to parse solution index.')
+ return instance
+[INST] Takes in a pivoted DF and produces a LaTeX Table
+Each row should be a model and each column a Task.. [/INST] def pivoted_to_longtex(df, bold=None, dest=None):
+ dfl = df.copy(deep=True)
+ header = r'\begin{tabular}{rl' + 'c' * (len(dfl.columns) + 2) + '}\n'
+ header += r'\toprule' + '\n'
+ dfl = dfl.reindex(sort_columns(dfl.index))
+ colmap = {c: c.capitalize() for c in dfl.columns}
+ colmap['circumference'] = 'Circum.'
+ dfl = dfl.rename(columns=colmap)
+ dfl = dfl[[c for c in task_order if c in dfl.columns]]
+ dfl['Average'] = dfl.mean(1)
+ nice_rows = make_nice_rows(dfl.index)
+ dfl_no_uqa = dfl[~dfl.index.str.startswith('allenai-unifiedqa')]
+ header += r'\multicolumn{2}{c}{Model} & ' + ' & '.join(
+ dfl.columns) + r' \\' + '\n' + r'\midrule' + '\n'
+ if bold == 'max':
+ best_models_idxs = dfl_no_uqa.idxmax(0)
+ elif bold == 'min':
+ best_models_idxs = dfl_no_uqa.idxmin(0)
+ else:
+ best_models_idxs = defaultdict(lambda: None)
+ last_model = nice_rows[0].split('&')[0]
+ ic(dfl)
+ for ri, (i, row) in enumerate(dfl.iterrows()):
+ strrow = [('---' if np.isnan(r) else f'{r*100:.1f}') for r in row]
+ for j, task in enumerate(row.index):
+ if best_models_idxs[task] == i:
+ strrow[j] = r'\textbf{' + strrow[j] + '}'
+ if nice_rows[ri].split('&')[0] != last_model:
+ header += r'\midrule' + '\n'
+ last_model = nice_rows[ri].split('&')[0]
+ elif ri == 0:
+ pass
+ else:
+ nice_rows[ri] = '&' + nice_rows[ri].split('&')[1]
+ header += nice_rows[ri] + ' & ' + ' & '.join(strrow) + r' \\' + '\n'
+ header += r'\midrule' + '\n' + r'Task Avg. & & '
+ task_avg = dfl[~dfl.index.str.startswith('allenai-unifiedqa')].mean(0)
+ strrow = [('-' if np.isnan(r) else f'{r*100:.1f}') for r in task_avg]
+ header += ' & '.join(strrow) + r' \\' + '\n'
+ header += r'\bottomrule' + '\n' + r'\end{tabular}'
+ if dest is not None:
+ with open(dest, 'w') as f:
+ f.write(header)
+ return header
+[INST] Processes all Predictions into Longform DF with Dataset Examples + Correct Tags.. [/INST] def process_preds_base(df, info):
+ dfp = to_wide_rankings(df)
+ dfp = dfp[[('example_idx', ''), ('model_name', ''), ('pred_idx', 'final')]]
+ dfp.columns = dfp.columns.droplevel(1)
+ uqa_df = df[df['model_name'].apply(lambda x: x.startswith('allenai'))]
+ if len(uqa_df.index) > 0:
+ uqa_df = uqa_df[['example_idx', 'model_name', 'pred_idx']]
+ dfp = pd.concat([dfp, uqa_df])
+ dfm = dfp.merge(info, on='example_idx', how='left')
+ dfm['correct'] = dfm['pred_idx'] == dfm['label']
+ dfm['question_group'] = dfm['name'].apply(
+ lambda x: x.split('_')[0].replace('non', ''))
+ return dfm
+[INST] Create examples for Albert.
+Albert uses whole-word masking, so [MASK] should be replaced with the
+number of tokens that the option has. This version only accounts for SINGLE token
+masking, see create_examples_albert_wwm for details on the other approach.. [/INST] def create_examples_albert(examples, tokenizer):
+ example = T.valmap(T.get(0), examples)
+ option_encodings = _get_option_encodings(example, tokenizer, True)
+ option_input_ids = [o.input_ids[0] for o in option_encodings]
+ example['question'] = re.sub(r'( \[MASK\])|(\[MASK\])', tokenizer.mask_token,
+ example['question'])
+ text_input = '{context} {question}'.format_map(example)
+ instance = tokenizer(text=text_input, padding='do_not_pad',
+ add_special_tokens=True)
+ target_tok_idx = try_get_index(instance.tokens(), tokenizer.mask_token)
+ instance['example_idx'] = example['example_idx']
+ instance['option_idx'] = list(range(4))
+ instance['target_tok_idx'] = target_tok_idx
+ instance['option_input_ids'] = option_input_ids
+ return {k: [instance[k]] for k in instance.keys()}
+[INST] Construct Test cases for a task.. [/INST] def build_examples_from_config(config, variables, product=True,
+ remove_duplicates=True):
+ expect_fn = globals()[config.pop('expect_fn')]
+ enum_variables = T.itemmap(lambda x: make_enum_vars(*x), variables)
+ logging.debug('variables: %s, ev: %s', variables, enum_variables)
+ template = copy.deepcopy(config)
+ template.update(template.pop('choices'))
+ all_keys = find_all_keys(template)
+ fillin_items = _get_fillin_items(all_keys, **variables)
+ fillin_values = ensure_listlike(list(fillin_items.values()))
+ val_sets = (itertools.product if product else zip)(*fillin_values)
+ if remove_duplicates:
+ val_sets = T.filter(lambda x: len(set(x)) == len(x), val_sets)
+ mappings = T.map(lambda vals: dict(zip(list(fillin_items), vals)), val_sets)
+ def generate_example(mapping):
+ data = recursive_format(template, mapping)
+ example = {**data, 'label': expect_fn(data, mapping, ev=enum_variables)}
+ for k in ('context', 'question', 'ex_question'):
+ example[k] = example[k][0].upper() + example[k][1:]
+ return example
+ examples = list(T.map(generate_example, mappings))
+ logging.debug('Built %d from task %s', len(examples), config['name'])
+ return examples
+[INST] Preprocess meta dicts -> IntEnum objects. [/INST] def preprocess_meta(fn: Callable):
+ @wraps(fn)
+ @T.curry
+ def _preprocess_meta(ex, meta, ev, **kwargs):
+ meta = enum_meta(meta, ev)
+ options = []
+ for i, k in enumerate('ABCD'):
+ option_meta = T.valfilter(lambda x: x['text'] == ex[k], meta)
+ option_meta = list(option_meta.values())
+ if len(option_meta) == 0:
+ continue
+ assert len(option_meta) == 1, (option_meta, meta)
+ options.append({'key': k, 'idx': i, **option_meta[0]})
+ meta['options'] = options
+ return fn(ex, meta, **kwargs)
+ return _preprocess_meta
+[INST] Finds all tag keys in object (with options). [/INST] def find_all_keys(obj) -> set[str]:
+ return T.pipe(obj, tree.flatten, set,
+ T.mapcat(lambda x: string.Formatter().parse(x)),
+ T.filter(T.get(1)),
+ T.map(lambda x: x[1] if not x[2] else '%s:%s' % (x[1], x[2])),
+ list, set)
+[INST] Formats all strings within an object, using mapping. [/INST] def recursive_format(obj: TemplateObj, mapping: Dict,
+ ignore_missing: bool = False) -> TemplateObj:
+ def formatfn(x):
+ fmt = SafeFormatter()
+ formatz = (lambda x, m: x.format(**m)
+ if not ignore_missing else fmt.format(x, **m))
+ options = re.compile(r'{([^}]+):([^}]+)}')
+ def mysub(match):
+ options, thing = match.group(1, 2)
+ ret = ''
+ if 'a' in options:
+ if ignore_missing and thing not in mapping:
+ return match.group()
+ else:
+ word = formatz('{%s}' % thing, mapping)
+ ret += '%s ' % add_article(word).split()[0]
+ ret += '{%s}' % thing
+ return ret
+ x = options.sub(mysub, x)
+ return formatz(x, mapping)
+ return tree.map_structure(formatfn, obj)
+[INST] List of all RFI files that need to be processed by the GBTIDL processing script
+param: path_to_current_RFI_files: String containing the path to all current RFI files, in which the files to be processed are contained
+return: data_to_process; which is a list of lists containing each file with the information needed to run the GBTIDL processing script. [/INST] def find_parameters_to_process_file(RFI_files_to_be_processed: list,path_to_current_RFI_files):
+ data_to_process = []
+ for file_to_be_processed in RFI_files_to_be_processed:
+ try:
+ _, line_to_start_reader = read_header(file_to_be_processed, path_to_current_RFI_files)
+ except FileNotFoundError:
+ print("file not found. Skipping.")
+ continue
+ with open(path_to_current_RFI_files+file_to_be_processed+"/"+file_to_be_processed+".raw.vegas/"+file_to_be_processed+".raw.vegas.index") as f:
+ data = pandas.read_csv(f,delimiter='\s+|\t+|\s+\t+|\t+\s+',skiprows=line_to_start_reader,engine='python')
+ if "Unknown" in data["SOURCE"]:
+ print("Unknown source, skipping.")
+ continue
+ max_scan_number = max(data["SCAN"])
+ min_scan_number = min(data["SCAN"])
+ number_of_feeds = max(data["FDNUM"])+ 1
+ number_of_IFs = max(data["IFNUM"]) + 1
+ receiver_name = file_to_be_processed.split("_")[-1]
+ verified_receiver_name = GBT_receiver_specs.FrontendVerification(receiver_name)
+ if verified_receiver_name == 'Unknown':
+ print("Unknown Receiver. Skipping.")
+ continue
+ scanlist = list(range(min_scan_number,max_scan_number))
+ ymax = ymax_determiner[verified_receiver_name]
+ oneline_data_to_be_processed = {
+ "filename": file_to_be_processed,
+ "frontend": verified_receiver_name,
+ "list_of_scans": scanlist,
+ "number_of_feeds":number_of_feeds,
+ "number_of_IFs":number_of_IFs,
+ "ymax":ymax
+ }
+ data_to_process.append(oneline_data_to_be_processed)
+ print(str(file_to_be_processed)+" parameter data gleaned")
+ return(data_to_process)
+[INST] : if the data has passed all checks up to this point, it is a dictionary containing metadata needed to process the RFI file.. [/INST] def analyze_file(file_to_process,output_directory):
+ if file_to_process['list_of_scans'] == []:
+ raise(EmptyScans)
+ path = str(pathlib.Path(__file__).parent.absolute())+'/'
+ temp_path = tempfile.gettempdir()+'/'
+ temp_file = open(temp_path+"temp_file.pro","w+")
+ temp_file.write('.compile '+path+'scalUtils_wilsonedit.pro\n')
+ temp_file.write('.compile '+path+'rfiDisplay_wilsonedit.pro\n')
+ temp_file.write('.compile '+path+'process_file\n')
+ temp_file.write('offline, \''+str(file_to_process["filename"])+'\' \n')
+ temp_file.write('process_file, '+str(file_to_process["list_of_scans"])+', fdnum='+str(file_to_process["number_of_feeds"])+'-1, ymax='+str(file_to_process["ymax"])+', ifmax = '+str(file_to_process["number_of_IFs"])+'-1, nzoom = 0, output_file=\''+output_directory+'\', temp_path=\''+temp_path+'\', /blnkChans, /makefile')
+ if file_to_process['frontend'] == 'Rcvr26_40':
+ temp_file.write(', /ka\n')
+ else:
+ temp_file.write('\n')
+ temp_file.write('exit')
+ temp_file.close()
+ process = subprocess.Popen(['gbtidl', temp_path+'temp_file.pro'])
+ try:
+ print('Running in process', process.pid)
+ process.wait(timeout=300)
+ subprocess_success = open(temp_path+"stat.txt","r").read().strip('\n')
+ if subprocess_success == "bad_data":
+ raise(BadIDLProcess)
+ except subprocess.TimeoutExpired:
+ print('Timed out - killing', process.pid)
+ process.kill()
+ raise(TimeoutError)
+ def remove_file(temp_path,filename):
+ if os.path.exists(temp_path+filename):
+ os.remove(temp_path+filename)
+ remove_file(temp_path,'stat.txt')
+ remove_file(temp_path,'temp_file.pro')
+[INST] A greenlet function that reads from the websocket and updates the world. [/INST] def read_ws(ws,client):
+ try:
+ while True:
+ msg = ws.receive()
+ if (msg is not None):
+ packet = json.loads(msg)
+ for name, data in packet.iteritems():
+ entity = myWorld.get(name)
+ for k, v in data.iteritems():
+ entity[k] = v
+ myWorld.set(name, entity)
+ else:
+ break
+ except:
+ pass
+ return None
+[INST] Fufill the websocket URL of /subscribe, every update notify the
+websocket and read updates from the websocket. [/INST] def subscribe_socket(ws):
+ client = Client()
+ clients.append(client)
+ client.put(json.dumps(myWorld.world()));
+ g = gevent.spawn( read_ws, ws, client )
+ try:
+ while True:
+ msg = client.get()
+ ws.send(msg)
+ except Exception as e:
+ print "WS Error: %s" % e
+ finally:
+ clients.remove(client)
+ gevent.kill(g)
+[INST] update the entities via this interface. [/INST] def update(entity):
+ data = flask_post_json(request)
+ for key, value in data.iteritems():
+ myWorld.update(entity, key, value);
+ return make_json_response(myWorld.get(entity))
+[INST] Used by the MobileNet-LSTM NN architecture.
+The (video) frames (2-dimensional) in sFrameBaseDir are fed into keModel
+and the resulting features are save to sFeatureBaseDir.. [/INST] def features_2D_predict_generator(sFrameBaseDir:str, sFeatureBaseDir:str, keModel:keras.Model,
+ nFramesNorm:int = 40):
+ _, h, w, c = keModel.input_shape
+ genFrames = FramesGenerator(sFrameBaseDir, 1, nFramesNorm, h, w, c,
+ liClassesFull = None, bShuffle=False)
+ print("Predict features with %s ... " % keModel.name)
+ nCount = 0
+ for _, seVideo in genFrames.dfVideos.iterrows():
+ sVideoName = seVideo.sFrameDir.split("/")[-1]
+ sLabel = seVideo.sLabel
+ sFeaturePath = sFeatureBaseDir + "/" + sLabel + "/" + sVideoName + ".npy"
+ if os.path.exists(sFeaturePath):
+ print("Video %5d: features already extracted to %s" % (nCount, sFeaturePath))
+ nCount += 1
+ continue
+ arX, _ = genFrames.data_generation(seVideo)
+ arFeature = keModel.predict(arX, verbose=0)
+ os.makedirs(sFeatureBaseDir + "/" + sLabel, exist_ok = True)
+ np.save(sFeaturePath, arFeature)
+ print("Video %5d: features %s saved to %s" % (nCount, str(arFeature.shape), sFeaturePath))
+ nCount += 1
+ print("%d features saved to files in %s" % (nCount+1, sFeatureBaseDir))
+ return
+[INST] Used by I3D-top-only model.
+The videos (frames) are fed into keModel (=I3D without top layers) and
+resulting features are saved to disc.
+(Later these features are used to train a small model containing
+only the adjusted I3D top layers.). [/INST] def features_3D_predict_generator(sFrameBaseDir:str, sFeatureBaseDir:str,
+ keModel:keras.Model, nBatchSize:int = 16):
+ _, nFramesModel, h, w, c = keModel.input_shape
+ genFrames = FramesGenerator(sFrameBaseDir, nBatchSize, nFramesModel, h, w, c,
+ liClassesFull = None, bShuffle=False)
+ print("Predict features with %s ... " % keModel.name)
+ nCount = 0
+ for _, seVideo in genFrames.dfVideos.iterrows():
+ sVideoName = seVideo.sFrameDir.split("/")[-1]
+ sLabel = seVideo.sLabel
+ sFeaturePath = sFeatureBaseDir + "/" + sLabel + "/" + sVideoName + ".npy"
+ if os.path.exists(sFeaturePath):
+ print("Video %5d: features already extracted to %s" % (nCount, sFeaturePath))
+ nCount += 1
+ continue
+ arFrames, _ = genFrames.data_generation(seVideo)
+ arFeature = keModel.predict(np.expand_dims(arFrames, axis=0))[0]
+ os.makedirs(sFeatureBaseDir + "/" + sLabel, exist_ok = True)
+ np.save(sFeaturePath, arFeature)
+ print("Video %5d: features %s saved to %s" % (nCount, str(arFeature.shape), sFeaturePath))
+ nCount += 1
+ print("%d features saved to files in %s" % (nCount+1, sFeatureBaseDir))
+ return
+[INST] Unzip videos use Label information defined in sListFile
+to move videos into folders=labels. [/INST] def unzip_sort_videos(sVideoDir, sZipFile, sListFile):
+ print("Unzipping and sorting ChaLearn videos from %s into %s" % (sZipFile, sVideoDir))
+ sCurrentDir = os.getcwd()
+ sTmpDir = sVideoDir + "/tmp"
+ print("Unzipping videos to {} ...".format(sTmpDir))
+ if os.path.exists(sTmpDir): raise ValueError("Folder {} alredy exists".format(sTmpDir))
+ zip_ref = zipfile.ZipFile(sZipFile, "r")
+ zip_ref.extractall(sTmpDir)
+ zip_ref.close()
+ print("{} files unzipped".format(len(glob.glob(sTmpDir + "/*/*/*.*"))))
+ dfFiles = pd.read_csv(sListFile,
+ sep=" ", header=None,
+ names=["sVideoPath", "s2", "nLabel"])
+ se_li_sVideoPath = dfFiles.sVideoPath.apply(lambda s: s.split("/"))
+ dfFiles["sTrainVal"] = se_li_sVideoPath.apply(lambda s: s[0])
+ dfFiles["sFileName"] = se_li_sVideoPath.apply(lambda s: s[2])
+ arTrainVal = dfFiles.sTrainVal.unique()
+ if len(arTrainVal) != 1: raise ValueError("sListFile supposed to contain only one train or val folder")
+ sTrainValDir = arTrainVal[0]
+ if os.path.exists(sTrainValDir): raise ValueError("Folder {} alredy exists".format(sVideoDir + "/" + sTrainValDir))
+ seClasses = dfFiles.groupby("nLabel").size().sort_values(ascending=True)
+ print("%d videos, with %d classes, occuring between %d-%d times" % \
+ (len(dfFiles), len(seClasses), min(seClasses), max(seClasses)))
+ for nClass, _ in seClasses.items():
+ sClassDir = sVideoDir + "/" + sTrainValDir + "/" + "c{:03d}".format(nClass)
+ print("Create directory", sClassDir)
+ os.makedirs(sClassDir)
+ nCount = 0
+ for pos, seVideo in dfFiles.iterrows():
+ sPathNew = sTrainValDir + "/c{:03d}".format(seVideo.nLabel) + "/" + seVideo.sFileName
+ if nCount % 1000 == 0:
+ print ("{:5d} Move video {}".format(nCount, sPathNew))
+ os.rename(sTmpDir + "/" + seVideo.sVideoPath, sVideoDir + "/" + sPathNew)
+ nCount += 1
+ print("{:d} videos moved".format(nCount))
+ shutil.rmtree(sTmpDir)
+ print("Removed tmp folder")
+ return
+[INST] Move fraction of the videos to another folder eg 20% from train to val. [/INST] def move_videos(sSourceDir, sTargetDir, fFrac = 0.2):
+ assert os.path.exists(sTargetDir) == False
+ sCurrentDir = os.getcwd()
+ os.chdir(sSourceDir)
+ liClasses = glob.glob("*")
+ print("Found %d classes in %s. Move %.0f%% videos to %s ..." % \
+ (len(liClasses), sSourceDir, fFrac*100, sTargetDir))
+ for sClass in liClasses:
+ os.chdir(sCurrentDir + "/" + sSourceDir + "/" + sClass)
+ seVideos = pd.Series(glob.glob("*.avi"))
+ seVideosSelected = seVideos.sample(frac = fFrac)
+ sTargetPath = sCurrentDir + "/" + sTargetDir + "/" + sClass
+ os.makedirs(sTargetPath, exist_ok=True)
+ for sVideoName in seVideosSelected:
+ os.rename(sVideoName, sTargetPath + "/" + sVideoName)
+ print("Class %s | %d videos: Moved %d videos" % (sClass, len(seVideos), len(seVideosSelected)))
+ os.chdir(sCurrentDir)
+ return
+[INST] Returns new image (not altering arImage). [/INST] def rectangle_text(arImage, sColor, sUpper, sLower = None, tuRectangle = (224, 224)):
+ nHeigth, nWidth, _ = arImage.shape
+ nRectHeigth, nRectWidth = tuRectangle
+ x1 = int((nWidth - nRectWidth) / 2)
+ y1 = int((nHeigth - nRectHeigth) / 2)
+ if sColor == "green": bgr = (84, 175, 25)
+ elif sColor == "orange": bgr = (60, 125, 235)
+ else:
+ bgr = (27, 13, 252)
+ arImageNew = np.copy(arImage)
+ cv2.rectangle(arImageNew, (x1, y1), (nWidth-x1, nHeigth-y1), bgr, 3)
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ fFontSize = 0.5
+ textSize = cv2.getTextSize(sUpper, font, 1.0, 2)[0]
+ cv2.putText(arImageNew, sUpper, (x1 + 7, y1 + textSize[1] + 7), font, fFontSize, bgr, 2)
+ if (sLower != None):
+ textSize = cv2.getTextSize(sLower, font, 1.0, 2)[0]
+ cv2.putText(arImageNew, sLower, (x1 + 7, nHeigth - y1 - 7), font, fFontSize, bgr, 2)
+ return arImageNew
+[INST] Read frame from webcam and display it with box+text. [/INST] def frame_show(oStream, sColor:str, sText:str, tuRectangle = (224, 224)):
+ (bGrabbed, oFrame) = oStream.read()
+ oFrame = rectangle_text(cv2.flip(oFrame, 1), sColor, sText, "", tuRectangle)
+ cv2.imshow("Video", oFrame)
+ cv2.waitKey(1)
+ return
+[INST] Fetches all the water points from WPDx in given administrative area. [/INST] def queryWPDx(zone):
+ start = time.clock()
+ client = Socrata("data.waterpointdata.org", None)
+ fields = 'adm1,adm2,country_id,country_name,created,data_lnk,fecal_coliform_presence,install_year,installer,photo_lnk,photo_lnk_description,report_date,source,status_id,subjective_quality,updated,water_source,water_tech,wpdx_id,lat_deg,lon_deg'
+ zone1 = zone
+ if zone1.lower() == 'tanzania':
+ zone1 = 'Tanzania, United Republic of'
+ if len(zone) > 2:
+ response = client.get("gihr-buz6", adm1=zone1, limit=50000, select= fields)
+ if len(response) > 1:
+ query_type = 'Admin1'
+ else:
+ response = client.get("gihr-buz6", adm2=zone1, limit=50000, select= fields)
+ if len(response) > 1:
+ query_type = 'Name'
+ else:
+ response = client.get("gihr-buz6", country_name=zone1, limit=500000, select= fields)
+ if len(response) > 1:
+ query_type = 'Country'
+ else:
+ arcpy.AddError("ERROR: Administrative zone not recognized")
+ sys.exit(1)
+ else:
+ response = client.get("gihr-buz6", country_id=zone1, limit=500000, select= fields)
+ if len(response) > 1:
+ query_type = 'cc'
+ else:
+ arcpy.AddError("ERROR: Country Code not recognized")
+ sys.exit(1)
+ arcpy.AddMessage("Query took: {:.2f} seconds".format(time.clock() - start))
+ arcpy.AddMessage("Found: {} points".format(len(response)))
+ mask = setEnvironment(zone, query_type)
+ return (response, mask)
+[INST] Calculates percentage of population unserved in each administrative area.. [/INST] def execute(self, parameters, messages):
+ global scratch
+ scratch = tempfile.mkdtemp()
+ zone = parameters[0].valueAsText
+ num = parameters[1].valueAsText
+ buff_dist = parameters[2].valueAsText
+ pop_grid = parameters[3].value
+ out_path = parameters[4].value
+ query_response, mask = queryWPDx(zone)
+ pnts = getWaterPoints(query_response)
+ start = time.clock()
+ pnts_func = arcpy.MakeFeatureLayer_management(pnts, 'Functioning',
+ "status_id='yes'")
+ pnts_buff = arcpy.Buffer_analysis(pnts_func, r"in_memory\buffer", "{} Meters".format(buff_dist))
+ if mask == "Error":
+ arcpy.env.extent = arcpy.Describe(pnts_buff).extent
+ arcpy.AddMessage("Buffer took: {:.2f} seconds".format(time.clock() - start))
+ area_urban = join(dirname(__file__), "Data", "ToolData.gdb", "Urban")
+ pop_not_served = getPopNotServed(pnts_buff, pop_grid, area_urban)
+ with open(join(dirname(__file__), "Data", "Paths.txt")) as paths_ref:
+ for line in paths_ref:
+ parts = line.split('...')
+ name = parts[0]
+ if name == pop_grid:
+ cell_size = float(parts[2])*111000
+ cell_factor = int(round(float(buff_dist) / cell_size))
+ neighborhood = arcpy.sa.NbrCircle(cell_factor, "CELL")
+ FocalStat = arcpy.sa.FocalStatistics(pop_not_served, neighborhood,
+ "SUM", "DATA")
+ agg = arcpy.sa.Aggregate(FocalStat, cell_factor, 'MAXIMUM')
+ agg_pnts = arcpy.RasterToPoint_conversion(agg, r"in_memory\agg_pnt", 'Value')
+ sort = arcpy.Sort_management(agg_pnts, r"in_memory\sort", 'grid_code DESCENDING')
+ top = arcpy.MakeFeatureLayer_management(sort, 'TOP', "OBJECTID<{}".format(num))
+ arcpy.AlterField_management (top, "grid_code", "Pop_Served", "Pop_Served")
+ output = arcpy.CopyFeatures_management(top, out_path)
+ parameters[4] = output
+ parameters[5].value = self.outputCSV(output, zone)
+[INST] Set whether tool is licensed to execute.. [/INST] def isLicensed(self):
+ if arcpy.CheckExtension("Spatial") == "Available":
+ return True
+ else:
+ return False
+[INST] Uses zonal statistics to calculate population served by each point. [/INST] def calcPriority(self, pnts_buff, pop_grid):
+ pnts = list()
+ with arcpy.da.SearchCursor(pnts_buff, 'wpdx_id', "status_id='no'") as cursor:
+ for row in cursor:
+ pnts.append(row[0])
+ start = time.clock()
+ pop_dict = dict()
+ with arcpy.da.SearchCursor(incr_pop, ['wpdx_id', 'SUM' ]) as cursor:
+ for row in cursor:
+ pop_dict[row[0]] = row[1]
+ Bellow is a workaround. Delete once bug from line 353 is fixed
+ why does this take 100 s more than same code in old toolbox?
+ for pnt in pnts:
+ pnt_id = pnt.split('-')[1]
+ point = arcpy.MakeFeatureLayer_management(pnts_buff, pnt,
+ "wpdx_id='{}'".format(pnt))
+ incr_pop = arcpy.gp.ZonalStatisticsAsTable_sa(point, 'wpdx_id',
+ pop_grid,
+ r"in_memory\pop{}".format(pnt_id),
+ 'DATA', 'SUM')
+ with arcpy.da.SearchCursor(incr_pop, ['wpdx_id', 'SUM']) as cursor:
+ for row in cursor:
+ pop_dict[row[0]] = row[1]
+ arcpy.AddMessage("Zonal Stats took: {:.2f} seconds".format(time.clock() - start))
+ return pop_dict
+[INST] The source code of the tool.. [/INST] def execute(self, parameters, messages):
+ global scratch
+ scratch = tempfile.mkdtemp()
+ zone = parameters[0].valueAsText
+ buff_dist = parameters[1].valueAsText
+ pop_grid = parameters[2].value
+ out_path = parameters[3].value
+ query_response, mask = queryWPDx(zone)
+ pnts = getWaterPoints(query_response)
+ start = time.clock()
+ pnts_buff = arcpy.Buffer_analysis(pnts, r"in_memory\buffer", "{} Meters".format(buff_dist))
+ if mask == "Error":
+ arcpy.env.extent = arcpy.Describe(pnts_buff).extent
+ arcpy.AddMessage("Buffer took: {:.2f} seconds".format(time.clock() - start))
+ pnts_buff_func = arcpy.MakeFeatureLayer_management(pnts_buff, 'Functioning',
+ "status_id='yes'")
+ pop_not_served = getPopNotServed(pnts_buff_func, pop_grid)
+ pop_dict = self.calcPriority(pnts_buff, pop_not_served)
+ arcpy.AddField_management(pnts, "Pop_Served", "FLOAT")
+ pnts_nonfunc = arcpy.MakeFeatureLayer_management(pnts, 'NonFunctioning',
+ "status_id='no'")
+ with arcpy.da.UpdateCursor(pnts_nonfunc, ['wpdx_id', 'Pop_Served']) as cursor:
+ for row in cursor:
+ try:
+ row[1] = pop_dict[row[0]]
+ cursor.updateRow(row)
+ except KeyError:
+ pass
+ output = arcpy.CopyFeatures_management(pnts_nonfunc, out_path)
+ parameters[3] = output
+ parameters[4].value = self.outputCSV(zone, query_response, pop_dict)
+[INST] Uses zonal statistics to calculate population unserved in each zone. [/INST] def calcUnserved(self, admin_zones, unserved_population):
+ start = time.clock()
+ pop_dict = dict()
+ pop_by_region = arcpy.gp.ZonalStatisticsAsTable_sa(admin_zones, 'Name',
+ unserved_population,
+ r"in_memory\pop",
+ '', 'SUM')
+ with arcpy.da.SearchCursor(pop_by_region, ['Name', 'SUM']) as cursor:
+ for row in cursor:
+ pop_dict[row[0]] = row[1]
+ arcpy.AddMessage("Zonal stats took: {:.2f} seconds".format(time.clock() - start))
+ return pop_dict
+[INST] Calculates percentage of population unserved in each administrative area.. [/INST] def execute(self, parameters, messages):
+ global scratch
+ scratch = tempfile.mkdtemp()
+ country = parameters[0].valueAsText
+ buff_dist = parameters[1].valueAsText
+ pop_grid = parameters[2].value
+ out_path = "in_memory\ServiceOverview"
+ query_response, mask = queryWPDx(country)
+ if mask == "Error":
+ sys.exit(1)
+ pnts = getWaterPoints(query_response)
+ start = time.clock()
+ pnts_func = arcpy.MakeFeatureLayer_management(pnts, 'Functioning',
+ "status_id='yes'")
+ pnts_buff = arcpy.Buffer_analysis(pnts_func, r"in_memory\buffer", "{} Meters".format(buff_dist))
+ arcpy.AddMessage("Buffer took: {:.2f} seconds".format(time.clock() - start))
+ area_urban = join(dirname(__file__), "Data", "ToolData.gdb", "Urban")
+ pop_not_served = getPopNotServed(pnts_buff, pop_grid, area_urban)
+ pop_dict = self.calcUnserved(mask, pop_not_served)
+ output = arcpy.CopyFeatures_management(mask, out_path)
+ with arcpy.da.UpdateCursor(output, ['Name', 'Pop_Unserved']) as cursor:
+ for row in cursor:
+ try:
+ row[1] = pop_dict[row[0]]
+ except KeyError:
+ row[1] = 0
+ cursor.updateRow(row)
+ arcpy.CalculateField_management(output, 'Percent_Served',
+ 'round(1-!Pop_Unserved!/!Rural_Pop_{}!,2)'.format(pop_grid),
+ 'Python')
+ arcpy.CalculateField_management(output, 'Percent_Served', 'max(0, !Percent_Served!)', 'PYTHON_9.3')
+ parameters[3] = output
+ parameters[4].value = self.outputCSV(country, output)
+[INST] Removes urban areas and areas near a functioning well from population raster.. [/INST] def execute(self, parameters, messages):
+ global scratch
+ scratch = tempfile.mkdtemp()
+ zone = parameters[0].valueAsText
+ buff_dist = parameters[1].valueAsText
+ pop_grid = parameters[2].value
+ out_path = parameters[3].value
+ query_response, mask = queryWPDx(zone)
+ pnts = getWaterPoints(query_response)
+ start = time.clock()
+ pnts_func = arcpy.MakeFeatureLayer_management(pnts, 'Functioning',
+ "status_id='yes'")
+ pnts_buff = arcpy.Buffer_analysis(pnts_func, r"in_memory\buffer", "{} Meters".format(buff_dist))
+ if mask == "Error":
+ arcpy.env.extent = arcpy.Describe(pnts_buff).extent
+ arcpy.AddMessage("Buffer took: {:.2f} seconds".format(time.clock() - start))
+ area_urban = join(dirname(__file__), "Data", "ToolData.gdb", "Urban")
+ pop_not_served = getPopNotServed(pnts_buff, pop_grid, area_urban)
+ output = arcpy.CopyRaster_management(pop_not_served, out_path)
+ parameters[3] = output
+[INST] Calculates rural population in each administrative area.. [/INST] def execute(self, parameters, messages):
+ admin = join(dirname(__file__), "Data", "ToolData.gdb", "Admin")
+ scratch = tempfile.mkdtemp()
+ gdb = arcpy.CreateFileGDB_management(scratch, "temp").getOutput(0)
+ country = parameters[0].valueAsText
+ if len(country) > 2:
+ query_type = 'Country'
+ else:
+ query_type = 'cc'
+ mask = setEnvironment(country, query_type)
+ cell_size = '0.0008333'
+ start = time.clock()
+ area_served = arcpy.PolygonToRaster_conversion(join(dirname(__file__), "Data", "ToolData.gdb", "Urban"),
+ 'RANK',
+ join(gdb, 'served'),
+ 'CELL_CENTER', 'NONE',
+ cell_size)
+ area_not_served = arcpy.gp.IsNull (area_served)
+ arcpy.AddMessage("Rasterize took: {:.2f} seconds".format(time.clock() - start))
+ start = time.clock()
+ with open(join(dirname(__file__), "Data", "Paths.txt")) as paths_ref:
+ for line in paths_ref:
+ parts = line.split('...')
+ name = parts[0]
+ pop_grid = parts[1]
+ pop_not_served = arcpy.sa.Con(area_not_served, pop_grid, '0', 'Value>0')
+ arcpy.AddMessage("Con took: {:.2f} seconds".format(time.clock() - start))
+ start = time.clock()
+ pop_by_region = arcpy.gp.ZonalStatisticsAsTable_sa(admin,
+ 'Name',
+ pop_not_served,
+ r"in_memory\pop{}".format(name),
+ '', 'SUM')
+ arcpy.AddMessage("Zonal Stats took: {:.2f} seconds".format(time.clock() - start))
+ pop_dict = dict()
+ with arcpy.da.SearchCursor(pop_by_region, ['Name', 'SUM']) as cursor:
+ for row in cursor:
+ pop_dict[row[0]] = row[1]
+ with arcpy.da.UpdateCursor(admin, ['Name', 'Rural_Pop_{}'.format(name)],
+ "{} = '{}'".format(query_type, country)
+ ) as cursor:
+ for row in cursor:
+ try:
+ row[1] = pop_dict[row[0]]
+ cursor.updateRow(row)
+ except KeyError:
+ pass
+[INST] Creates the convolutional part of the MOA model.
+Also casts the input uint8 observations to float32 and normalizes them to the range [0,1].. [/INST] def create_moa_encoder_model(obs_space, model_config):
+ original_obs_dims = obs_space.original_space.spaces["curr_obs"].shape
+ inputs = tf.keras.layers.Input(original_obs_dims, name="observations", dtype=tf.uint8)
+ last_layer = tf.keras.backend.cast(inputs, tf.float32)
+ last_layer = tf.math.divide(last_layer, 255.0)
+ conv_out = build_conv_layers(model_config, last_layer)
+ actor_critic_fc = build_fc_layers(model_config, conv_out, "policy")
+ moa_fc = build_fc_layers(model_config, conv_out, "moa")
+ return tf.keras.Model(inputs, [actor_critic_fc, moa_fc], name="MOA_Encoder_Model")
+[INST] First evaluate non-LSTM parts of model. Then add a time dimension to the batch before
+sending inputs to forward_rnn(), which evaluates the LSTM parts of the model.. [/INST] def forward(self, input_dict, state, seq_lens):
+ actor_critic_fc_output, moa_fc_output = self.moa_encoder_model(input_dict["obs"]["curr_obs"])
+ rnn_input_dict = {
+ "ac_trunk": actor_critic_fc_output,
+ "prev_moa_trunk": state[5],
+ "other_agent_actions": input_dict["obs"]["other_agent_actions"],
+ "visible_agents": input_dict["obs"]["visible_agents"],
+ "prev_actions": input_dict["prev_actions"],
+ }
+ for k, v in rnn_input_dict.items():
+ rnn_input_dict[k] = add_time_dimension(v, seq_lens)
+ output, new_state = self.forward_rnn(rnn_input_dict, state, seq_lens)
+ action_logits = tf.reshape(output, [-1, self.num_outputs])
+ counterfactuals = tf.reshape(
+ self._counterfactuals,
+ [-1, self._counterfactuals.shape[-2], self._counterfactuals.shape[-1]],
+ )
+ new_state.extend([action_logits, moa_fc_output])
+ self.compute_influence_reward(input_dict, state[4], counterfactuals)
+ return action_logits, new_state
+[INST] Forward pass through the MOA LSTMs.
+Implicitly assigns the value function output to self_value_out, and does not return this.. [/INST] def forward_rnn(self, input_dict, state, seq_lens):
+ pass_dict = {"curr_obs": input_dict["ac_trunk"]}
+ h1, c1, h2, c2, *_ = state
+ (self._model_out, self._value_out, output_h1, output_c1,) = self.actions_model.forward_rnn(
+ pass_dict, [h1, c1], seq_lens
+ )
+ prev_moa_trunk = input_dict["prev_moa_trunk"]
+ other_actions = input_dict["other_agent_actions"]
+ agent_action = tf.expand_dims(input_dict["prev_actions"], axis=-1)
+ all_actions = tf.concat([agent_action, other_actions], axis=-1, name="concat_true_actions")
+ self._true_one_hot_actions = self._reshaped_one_hot_actions(all_actions, "forward_one_hot")
+ true_action_pass_dict = {
+ "curr_obs": prev_moa_trunk,
+ "prev_total_actions": self._true_one_hot_actions,
+ }
+ self._action_pred, output_h2, output_c2 = self.moa_model.forward_rnn(
+ true_action_pass_dict, [h2, c2], seq_lens
+ )
+ counterfactual_preds = []
+ for i in range(self.num_outputs):
+ actions_with_counterfactual = tf.pad(
+ other_actions, paddings=[[0, 0], [0, 0], [1, 0]], mode="CONSTANT", constant_values=i
+ )
+ one_hot_actions = self._reshaped_one_hot_actions(
+ actions_with_counterfactual, "actions_with_counterfactual_one_hot"
+ )
+ pass_dict = {"curr_obs": prev_moa_trunk, "prev_total_actions": one_hot_actions}
+ counterfactual_pred, _, _ = self.moa_model.forward_rnn(pass_dict, [h2, c2], seq_lens)
+ counterfactual_preds.append(tf.expand_dims(counterfactual_pred, axis=-2))
+ self._counterfactuals = tf.concat(
+ counterfactual_preds, axis=-2, name="concat_counterfactuals"
+ )
+ self._other_agent_actions = input_dict["other_agent_actions"]
+ self._visibility = input_dict["visible_agents"]
+ return self._model_out, [output_h1, output_c1, output_h2, output_c2]
+[INST] Compute influence of this agent on other agents.. [/INST] def compute_influence_reward(self, input_dict, prev_action_logits, counterfactual_logits):
+ prev_agent_actions = tf.cast(tf.reshape(input_dict["prev_actions"], [-1, 1]), tf.int32)
+ predicted_logits = tf.gather_nd(
+ params=counterfactual_logits, indices=prev_agent_actions, batch_dims=1
+ )
+ predicted_logits = tf.reshape(
+ predicted_logits, [-1, self.num_other_agents, self.num_outputs]
+ )
+ predicted_logits = tf.nn.softmax(predicted_logits)
+ predicted_logits = predicted_logits / tf.reduce_sum(
+ predicted_logits, axis=-1, keepdims=True
+ )
+ marginal_logits = self.marginalize_predictions_over_own_actions(
+ prev_action_logits, counterfactual_logits
+ )
+ if self.influence_divergence_measure == "kl":
+ influence_reward = self.kl_div(predicted_logits, marginal_logits)
+ elif self.influence_divergence_measure == "jsd":
+ mean_probs = 0.5 * (predicted_logits + marginal_logits)
+ influence_reward = 0.5 * self.kl_div(predicted_logits, mean_probs) + 0.5 * self.kl_div(
+ marginal_logits, mean_probs
+ )
+ else:
+ sys.exit("Please specify an influence divergence measure from [kl, jsd]")
+ if self.influence_only_when_visible:
+ visibility = tf.cast(input_dict["obs"]["prev_visible_agents"], tf.float32)
+ influence_reward *= visibility
+ influence_reward = tf.reduce_sum(influence_reward, axis=-1)
+ self._social_influence_reward = influence_reward
+[INST] Calculates marginal policies for all other agents.. [/INST] def marginalize_predictions_over_own_actions(self, prev_action_logits, counterfactual_logits):
+ logits = tf.nn.softmax(prev_action_logits)
+ logits = logits / tf.reduce_sum(logits, axis=-1, keepdims=True)
+ counterfactual_logits = tf.reshape(
+ counterfactual_logits, [-1, self.num_outputs, self.num_other_agents, self.num_outputs],
+ )
+ counterfactual_logits = tf.nn.softmax(counterfactual_logits)
+ logits = tf.reshape(logits, [-1, self.num_outputs, 1, 1])
+ normalized_counterfactual_logits = logits * counterfactual_logits
+ marginal_probs = tf.reduce_sum(normalized_counterfactual_logits, axis=-3)
+ marginal_probs = marginal_probs / tf.reduce_sum(marginal_probs, axis=-1, keepdims=True)
+ return marginal_probs
+[INST] Calculate KL divergence between two distributions.. [/INST] def kl_div(x, y):
+ dist_x = tf.distributions.Categorical(probs=x)
+ dist_y = tf.distributions.Categorical(probs=y)
+ result = tf.distributions.kl_divergence(dist_x, dist_y)
+ is_finite = tf.reduce_all(tf.is_finite(result))
+ def true_fn():
+ return result
+ def false_fn():
+ return tf.zeros(tf.shape(result))
+ result = tf.cond(is_finite, true_fn=true_fn, false_fn=false_fn)
+ return result
+[INST] Converts the collection of all actions from a number encoding to a one-hot encoding.
+Then, flattens the one-hot encoding so that all concatenated one-hot vectors are the same
+dimension.. [/INST] def _reshaped_one_hot_actions(self, actions_tensor, name):
+ one_hot_actions = tf.keras.backend.one_hot(actions_tensor, self.num_outputs)
+ batch_time_dims = [
+ tf.shape(one_hot_actions)[k] for k in range(one_hot_actions.shape.rank - 2)
+ ]
+ reshape_dims = batch_time_dims + [actions_tensor.shape[-1] * self.num_outputs]
+ reshaped = tf.reshape(one_hot_actions, shape=reshape_dims, name=name)
+ return reshaped
+[INST] :returns Predicted actions. NB: Since the agent's own true action is not known when
+evaluating this model, the timestep is off by one (too late). Thus, for any index n > 0,
+the value at n is a prediction made at n-1, about the actions taken at n.
+predicted_actions[0] contains no sensible value, as this would have to be a prediction made
+at timestep -1, but we start time at 0.. [/INST] def predicted_actions(self):
+ return self._action_pred
+[INST] Rollout several timesteps of an episode of the environment.. [/INST] def rollout(self, horizon=50, save_path=None):
+ rewards = []
+ observations = []
+ shape = self.env.world_map.shape
+ full_obs = [np.zeros((shape[0], shape[1], 3), dtype=np.uint8) for i in range(horizon)]
+ for i in range(horizon):
+ agents = list(self.env.agents.values())
+ action_dim = self.env.action_space.n
+ agent_action_dict = dict()
+ for agent in agents:
+ rand_action = np.random.randint(action_dim)
+ agent_action_dict[agent.agent_id] = rand_action
+ obs, rew, dones, info, = self.env.step(agent_action_dict)
+ sys.stdout.flush()
+ if save_path is not None:
+ self.env.render(filename=save_path + "frame" + str(i).zfill(6) + ".png")
+ if i % 10 == 0:
+ print("Saved frame " + str(i) + "/" + str(horizon))
+ rgb_arr = self.env.full_map_to_colors()
+ full_obs[i] = rgb_arr.astype(np.uint8)
+ observations.append(obs["agent-0"])
+ rewards.append(rew["agent-0"])
+ if dones["__all__"]:
+ print("Environment ended early because all agents were done.")
+ break
+ return rewards, observations, full_obs
+[INST] Render a rollout into a video.. [/INST] def render_rollout(self, horizon=50, path=None, render_type="pretty", fps=8):
+ if path is None:
+ path = os.path.abspath(os.path.dirname(__file__)) + "/videos"
+ print(path)
+ if not os.path.exists(path):
+ os.makedirs(path)
+ video_name = self.env_name + "_trajectory"
+ if render_type == "pretty":
+ image_path = os.path.join(path, "frames/")
+ if not os.path.exists(image_path):
+ os.makedirs(image_path)
+ rewards, observations, full_obs = self.rollout(horizon=horizon, save_path=image_path)
+ utility_funcs.make_video_from_image_dir(path, image_path, fps=fps, video_name=video_name)
+ shutil.rmtree(image_path)
+ else:
+ rewards, observations, full_obs = self.rollout(horizon=horizon)
+ utility_funcs.make_video_from_rgb_imgs(full_obs, path, fps=fps, video_name=video_name)
+[INST] Forward pass through the LSTM.. [/INST] def forward_rnn(self, input_dict, state, seq_lens):
+ input = [input_dict["curr_obs"], seq_lens] + state
+ model_out, self._value_out, h, c = self.rnn_model(input)
+ return model_out, self._value_out, h, c
+[INST] Initialize the walls and the waste. [/INST] def custom_reset(self):
+ for waste_start_point in self.waste_start_points:
+ self.single_update_map(waste_start_point[0], waste_start_point[1], b"H")
+ for river_point in self.river_points:
+ self.single_update_map(river_point[0], river_point[1], b"R")
+ for stream_point in self.stream_points:
+ self.single_update_map(stream_point[0], stream_point[1], b"S")
+ self.compute_probabilities()
+[INST] Allows agents to take actions that are not move or turn. [/INST] def custom_action(self, agent, action):
+ updates = []
+ if action == "FIRE":
+ agent.fire_beam(b"F")
+ updates = self.update_map_fire(
+ agent.pos.tolist(),
+ agent.get_orientation(),
+ self.all_actions["FIRE"],
+ fire_char=b"F",
+ )
+ elif action == "CLEAN":
+ agent.fire_beam(b"C")
+ updates = self.update_map_fire(
+ agent.pos.tolist(),
+ agent.get_orientation(),
+ self.all_actions["FIRE"],
+ fire_char=b"C",
+ cell_types=[b"H"],
+ update_char=[b"R"],
+ blocking_cells=[b"H"],
+ )
+ return updates
+[INST] Change all empty cells marked with '0' to ' ' for consistency.. [/INST] def convert_empty_cells(view):
+ for x in range(len(view)):
+ for y in range(len(view[0])):
+ view[x, y] = b" " if view[x, y] == b"0" else view[x, y]
+ return view
+[INST] Computes multiplier for social_curiosity reward based on training steps
+taken and schedule parameters.. [/INST] def compute_curiosity_reward_weight(self):
+ weight = np.interp(
+ self.timestep,
+ self.curiosity_reward_schedule_steps,
+ self.curiosity_reward_schedule_weights,
+ )
+ return weight * self.baseline_curiosity_reward_weight
+[INST] Compute curiosity of this agent and add to rewards.. [/INST] def weigh_and_add_curiosity_reward(policy, sample_batch):
+ cur_curiosity_reward_weight = policy.compute_curiosity_reward_weight()
+ curiosity_reward = np.concatenate((sample_batch[SOCIAL_CURIOSITY_REWARD][1:], [0]))
+ reward = np.clip(curiosity_reward, -policy.curiosity_reward_clip, policy.curiosity_reward_clip)
+ reward = reward * cur_curiosity_reward_weight
+ sample_batch[SOCIAL_CURIOSITY_REWARD] = reward
+ sample_batch["rewards"] = sample_batch["rewards"] + reward
+ return sample_batch
+[INST] converts a list of strings into a numpy array
+Parameters
+
+list of strings
+List describing what the map should look like
+Returns
+
+np.ndarray
+numpy array describing the map with ' ' indicating an empty space. [/INST] def ascii_to_numpy(self, ascii_list):
+ arr = np.full((len(ascii_list), len(ascii_list[0])), b" ", dtype="c")
+ for row in range(arr.shape[0]):
+ for col in range(arr.shape[1]):
+ arr[row, col] = ascii_list[row][col]
+ return arr
+[INST] Takes in a dict of actions and converts them to a map update
+Parameters
+
+dict {agent-id: int}
+dict of actions, keyed by agent-id that are passed to the agent. The agent
+interprets the int and converts it to a command
+
+Returns
+
+dict of arrays representing agent observations
+rewards: dict of rewards for each agent
+dones: dict indicating whether each agent is done
+info: dict to pass extra info to gym. [/INST] def step(self, actions):
+ self.beam_pos = []
+ agent_actions = {}
+ for agent_id, action in actions.items():
+ agent_action = self.agents[agent_id].action_map(action)
+ agent_actions[agent_id] = agent_action
+ for agent in self.agents.values():
+ row, col = agent.pos[0], agent.pos[1]
+ self.single_update_world_color_map(row, col, self.world_map[row, col])
+ self.update_moves(agent_actions)
+ for agent in self.agents.values():
+ pos = agent.pos
+ new_char = agent.consume(self.world_map[pos[0], pos[1]])
+ self.single_update_map(pos[0], pos[1], new_char)
+ self.update_custom_moves(agent_actions)
+ self.custom_map_update()
+ map_with_agents = self.get_map_with_agents()
+ for agent in self.agents.values():
+ row, col = agent.pos[0], agent.pos[1]
+ if self.world_map[row, col] not in [b"F", b"C"]:
+ self.single_update_world_color_map(row, col, agent.get_char_id())
+ observations = {}
+ rewards = {}
+ dones = {}
+ info = {}
+ for agent in self.agents.values():
+ agent.full_map = map_with_agents
+ rgb_arr = self.color_view(agent)
+ if self.return_agent_actions:
+ prev_actions = np.array(
+ [actions[key] for key in sorted(actions.keys()) if key != agent.agent_id]
+ ).astype(np.uint8)
+ visible_agents = self.find_visible_agents(agent.agent_id)
+ observations[agent.agent_id] = {
+ "curr_obs": rgb_arr,
+ "other_agent_actions": prev_actions,
+ "visible_agents": visible_agents,
+ "prev_visible_agents": agent.prev_visible_agents,
+ }
+ agent.prev_visible_agents = visible_agents
+ else:
+ observations[agent.agent_id] = {"curr_obs": rgb_arr}
+ rewards[agent.agent_id] = agent.compute_reward()
+ dones[agent.agent_id] = agent.get_done()
+ if self.use_collective_reward:
+ collective_reward = sum(rewards.values())
+ for agent in rewards.keys():
+ rewards[agent] = collective_reward
+ dones["__all__"] = np.any(list(dones.values()))
+ return observations, rewards, dones, info
+[INST] Reset the environment.
+This method is performed in between rollouts. It resets the state of
+the environment.
+
+Returns
+
+dict of numpy ndarray
+the initial observation of the space. The initial reward is assumed
+to be zero.. [/INST] def reset(self):
+ self.beam_pos = []
+ self.agents = {}
+ self.setup_agents()
+ self.reset_map()
+ self.custom_map_update()
+ map_with_agents = self.get_map_with_agents()
+ observations = {}
+ for agent in self.agents.values():
+ agent.full_map = map_with_agents
+ rgb_arr = self.color_view(agent)
+ if self.return_agent_actions:
+ prev_actions = np.array([4 for _ in range(self.num_agents - 1)]).astype(np.uint8)
+ visible_agents = self.find_visible_agents(agent.agent_id)
+ observations[agent.agent_id] = {
+ "curr_obs": rgb_arr,
+ "other_agent_actions": prev_actions,
+ "visible_agents": visible_agents,
+ "prev_visible_agents": visible_agents,
+ }
+ agent.prev_visible_agents = visible_agents
+ else:
+ observations[agent.agent_id] = {"curr_obs": rgb_arr}
+ return observations
+[INST] Checks the map to make sure agents aren't duplicated. [/INST] def check_agent_map(self, agent_map):
+ unique, counts = np.unique(agent_map, return_counts=True)
+ count_dict = dict(zip(unique, counts))
+ for i in range(self.num_agents):
+ if count_dict[chr(i + 1)] != 1:
+ print("Error! Wrong number of agent", i, "in map!")
+ return False
+ return True
+[INST] Converts a map to an array of RGB values.
+Parameters
+
+np.ndarray
+map to convert to colors
+Double m to avoid shadowing map.
+color_map: dict
+mapping between array elements and desired colors
+rgb_arr: np.array
+Variable to store the mapping in
+orientation:
+The way in which the output should be oriented.
+UP = no rotation.
+RIGHT = Clockwise 90 degree rotation.
+DOWN = Clockwise 180 degree rotation.
+LEFT = Clockwise 270 degree rotation.
+Returns
+
+np.ndarray
+3-dim numpy array consisting of color map. [/INST] def map_to_colors(self, mmap, color_map, rgb_arr, orientation="UP"):
+ x_len = mmap.shape[0]
+ y_len = mmap.shape[1]
+ if orientation == "UP":
+ for row_elem in range(x_len):
+ for col_elem in range(y_len):
+ rgb_arr[row_elem, col_elem, :] = color_map[mmap[row_elem, col_elem]]
+ elif orientation == "LEFT":
+ for row_elem in range(x_len):
+ for col_elem in range(y_len):
+ rgb_arr[row_elem, col_elem, :] = color_map[mmap[col_elem, x_len - 1 - row_elem]]
+ elif orientation == "DOWN":
+ for row_elem in range(x_len):
+ for col_elem in range(y_len):
+ rgb_arr[row_elem, col_elem, :] = color_map[
+ mmap[x_len - 1 - row_elem, y_len - 1 - col_elem]
+ ]
+ elif orientation == "RIGHT":
+ for row_elem in range(x_len):
+ for col_elem in range(y_len):
+ rgb_arr[row_elem, col_elem, :] = color_map[mmap[y_len - 1 - col_elem, row_elem]]
+ else:
+ raise ValueError("Orientation {} is not valid".format(orientation))
+ return rgb_arr
+[INST] Creates an image of the map to plot or save.. [/INST] def render(self, filename=None):
+ rgb_arr = self.full_map_to_colors()
+ plt.cla()
+ plt.imshow(rgb_arr, interpolation="nearest")
+ if filename is None:
+ plt.show()
+ else:
+ plt.savefig(filename)
+[INST] Converts agent action tuples into a new map and new agent positions.
+Also resolves conflicts over multiple agents wanting a cell.
+
+This method works by finding all conflicts over a cell and randomly assigning them
+to one of the agents that desires the slot. It then sets all of the other agents
+that wanted the cell to have a move of staying. For moves that do not directly
+conflict with another agent for a cell, but may not be temporarily resolvable
+due to an agent currently being in the desired cell, we continually loop through
+the actions until all moves have been satisfied or deemed impossible.
+For example, agent 1 may want to move from [1,2] to [2,2] but agent 2 is in [2,2].
+Agent 2, however, is moving into [3,2]. Agent-1's action is first in the order so at the
+first pass it is skipped but agent-2 moves to [3,2]. In the second pass, agent-1 will
+then be able to move into [2,2].
+
+Parameters
+
+dict
+dict with agent_id as key and action as value. [/INST] def update_moves(self, agent_actions):
+ reserved_slots = []
+ for agent_id, action in agent_actions.items():
+ agent = self.agents[agent_id]
+ selected_action = self.all_actions[action]
+ if "MOVE" in action or "STAY" in action:
+ rot_action = self.rotate_action(selected_action, agent.get_orientation())
+ new_pos = agent.pos + rot_action
+ new_pos = agent.return_valid_pos(new_pos)
+ reserved_slots.append((*new_pos, b"P", agent_id))
+ elif "TURN" in action:
+ new_rot = self.update_rotation(action, agent.get_orientation())
+ agent.update_agent_rot(new_rot)
+ agent_by_pos = {tuple(agent.pos): agent.agent_id for agent in self.agents.values()}
+ agent_moves = {}
+ move_slots = []
+ agent_to_slot = []
+ for slot in reserved_slots:
+ row, col = slot[0], slot[1]
+ if slot[2] == b"P":
+ agent_id = slot[3]
+ agent_moves[agent_id] = [row, col]
+ move_slots.append([row, col])
+ agent_to_slot.append(agent_id)
+ if len(agent_to_slot) > 0:
+ shuffle_list = list(zip(agent_to_slot, move_slots))
+ np.random.shuffle(shuffle_list)
+ agent_to_slot, move_slots = zip(*shuffle_list)
+ unique_move, indices, return_count = np.unique(
+ move_slots, return_index=True, return_counts=True, axis=0
+ )
+ search_list = np.array(move_slots)
+ if np.any(return_count > 1):
+ for move, index, count in zip(unique_move, indices, return_count):
+ if count > 1:
+ conflict_indices = np.where((search_list == move).all(axis=1))[0]
+ all_agents_id = [agent_to_slot[i] for i in conflict_indices]
+ conflict_cell_free = True
+ for agent_id in all_agents_id:
+ moves_copy = agent_moves.copy()
+ if move.tolist() in self.agent_pos:
+ conflicting_agent_id = agent_by_pos[tuple(move)]
+ curr_pos = self.agents[agent_id].pos.tolist()
+ curr_conflict_pos = self.agents[conflicting_agent_id].pos.tolist()
+ conflict_move = agent_moves.get(
+ conflicting_agent_id, curr_conflict_pos
+ )
+ if agent_id == conflicting_agent_id:
+ conflict_cell_free = False
+ elif (
+ conflicting_agent_id not in moves_copy.keys()
+ or curr_conflict_pos == conflict_move
+ ):
+ conflict_cell_free = False
+ elif conflicting_agent_id in moves_copy.keys():
+ if (
+ agent_moves[conflicting_agent_id] == curr_pos
+ and move.tolist()
+ == self.agents[conflicting_agent_id].pos.tolist()
+ ):
+ conflict_cell_free = False
+ if conflict_cell_free:
+ self.agents[agent_to_slot[index]].update_agent_pos(move)
+ agent_by_pos = {
+ tuple(agent.pos): agent.agent_id for agent in self.agents.values()
+ }
+ remove_indices = np.where((search_list == move).all(axis=1))[0]
+ all_agents_id = [agent_to_slot[i] for i in remove_indices]
+ for agent_id in all_agents_id:
+ agent_moves[agent_id] = self.agents[agent_id].pos.tolist()
+ while len(agent_moves.items()) > 0:
+ agent_by_pos = {tuple(agent.pos): agent.agent_id for agent in self.agents.values()}
+ num_moves = len(agent_moves.items())
+ moves_copy = agent_moves.copy()
+ del_keys = []
+ for agent_id, move in moves_copy.items():
+ if agent_id in del_keys:
+ continue
+ if move in self.agent_pos:
+ conflicting_agent_id = agent_by_pos[tuple(move)]
+ curr_pos = self.agents[agent_id].pos.tolist()
+ curr_conflict_pos = self.agents[conflicting_agent_id].pos.tolist()
+ conflict_move = agent_moves.get(conflicting_agent_id, curr_conflict_pos)
+ if agent_id == conflicting_agent_id:
+ del agent_moves[agent_id]
+ del_keys.append(agent_id)
+ received a command
+ elif (
+ conflicting_agent_id not in moves_copy.keys()
+ or curr_conflict_pos == conflict_move
+ ):
+ del agent_moves[agent_id]
+ del_keys.append(agent_id)
+ elif conflicting_agent_id in moves_copy.keys():
+ if (
+ agent_moves[conflicting_agent_id] == curr_pos
+ and move == self.agents[conflicting_agent_id].pos.tolist()
+ ):
+ del agent_moves[conflicting_agent_id]
+ del agent_moves[agent_id]
+ del_keys.append(agent_id)
+ del_keys.append(conflicting_agent_id)
+ else:
+ self.agents[agent_id].update_agent_pos(move)
+ del agent_moves[agent_id]
+ del_keys.append(agent_id)
+ if len(agent_moves) == num_moves:
+ for agent_id, move in agent_moves.items():
+ self.agents[agent_id].update_agent_pos(move)
+ break
+[INST] For points in new_points, place desired char on the map
+Update the color map as well. [/INST] def update_map(self, new_points):
+ for point in new_points:
+ self.single_update_map(*point)
+[INST] Resets the map to be empty as well as a custom reset set by subclasses. [/INST] def reset_map(self):
+ self.world_map = np.full((len(self.base_map), len(self.base_map[0])), b" ", dtype="c")
+ self.world_map_color = np.full(
+ (len(self.base_map) + self.view_len * 2, len(self.base_map[0]) + self.view_len * 2, 3),
+ fill_value=0,
+ dtype=np.uint8,
+ )
+ self.build_walls()
+ self.custom_reset()
+[INST] From a firing position, fire a beam that may clean or hit agents
+Notes:
+(1) Beams are blocked by agents
+(2) A beam travels along until it hits a blocking cell at which beam the beam
+covers that cell and stops
+(3) If a beam hits a cell whose character is in cell_types, it replaces it with
+the corresponding index in update_char
+(4) As per the rules, the beams fire from in front of the agent and on its
+sides so the beam that starts in front of the agent travels out one
+cell further than it does along the sides.
+(5) This method updates the beam_pos, an internal representation of how
+which cells need to be rendered with fire_char in the agent view
+
+Parameters
+
+
+
+(tuple (row, col, char))
+the cells that have been hit by the beam and what char will be placed there. [/INST] def update_map_fire(
+ self,
+ firing_pos,
+ firing_orientation,
+ fire_len,
+ fire_char,
+ cell_types=[],
+ update_char=[],
+ blocking_cells=b"P",
+ beam_width=3,
+ ):
+ agent_by_pos = {tuple(agent.pos): agent_id for agent_id, agent in self.agents.items()}
+ start_pos = np.asarray(firing_pos)
+ firing_direction = ORIENTATIONS[firing_orientation]
+ right_shift = self.rotate_right(firing_direction)
+ if beam_width == 1:
+ firing_pos = [start_pos]
+ elif beam_width == 3:
+ firing_pos = [
+ start_pos,
+ start_pos + right_shift - firing_direction,
+ start_pos - right_shift - firing_direction,
+ ]
+ else:
+ raise NotImplementedError()
+ firing_points = []
+ updates = []
+ for pos in firing_pos:
+ next_cell = pos + firing_direction
+ for i in range(fire_len):
+ if (
+ self.test_if_in_bounds(next_cell)
+ and self.world_map[next_cell[0], next_cell[1]] != b"@"
+ ):
+ firing_points.append((next_cell[0], next_cell[1], fire_char))
+ for c in range(len(cell_types)):
+ if self.world_map[next_cell[0], next_cell[1]] == cell_types[c]:
+ updates.append((next_cell[0], next_cell[1], update_char[c]))
+ break
+ if [next_cell[0], next_cell[1]] in self.agent_pos:
+ agent_id = agent_by_pos[(next_cell[0], next_cell[1])]
+ self.agents[agent_id].hit(fire_char)
+ break
+ if self.world_map[next_cell[0], next_cell[1]] in blocking_cells:
+ break
+ next_cell += firing_direction
+ else:
+ break
+ self.beam_pos += firing_points
+ return updates
+[INST] Returns a randomly selected spawn point.. [/INST] def spawn_point(self):
+ spawn_index = 0
+ is_free_cell = False
+ curr_agent_pos = [agent.pos.tolist() for agent in self.agents.values()]
+ random.shuffle(self.spawn_points)
+ for i, spawn_point in enumerate(self.spawn_points):
+ if [spawn_point[0], spawn_point[1]] not in curr_agent_pos:
+ spawn_index = i
+ is_free_cell = True
+ assert is_free_cell, "There are not enough spawn points! Check your map?"
+ return np.array(self.spawn_points[spawn_index])
+[INST] Returns all the agents that can be seen by agent with agent_id
+Args
+
+str
+The id of the agent whose visible agents we are asking about
+Returns
+
+list
+which agents can be seen by the agent with id "agent_id". [/INST] def find_visible_agents(self, agent_id):
+ agent_pos = self.agents[agent_id].pos
+ upper_lim = int(agent_pos[0] + self.agents[agent_id].row_size)
+ lower_lim = int(agent_pos[0] - self.agents[agent_id].row_size)
+ left_lim = int(agent_pos[1] - self.agents[agent_id].col_size)
+ right_lim = int(agent_pos[1] + self.agents[agent_id].col_size)
+ other_agent_pos = [
+ self.agents[other_agent_id].pos
+ for other_agent_id in sorted(self.agents.keys())
+ if other_agent_id != agent_id
+ ]
+ return np.array(
+ [
+ 1
+ if (lower_lim <= agent_tup[0] <= upper_lim and left_lim <= agent_tup[1] <= right_lim)
+ else 0
+ for agent_tup in other_agent_pos
+ ],
+ dtype=np.uint8,
+ )
+[INST] Swaps batch and trajectory axis.
+Arguments:
+policy: Policy reference
+seq_lens: Sequence lengths if recurrent or None
+tensor: A tensor or list of tensors to reshape.
+drop_last: A bool indicating whether to drop the last
+trajectory item.
+
+A tensor with swapped axes or a list of tensors with
+swapped axes.. [/INST] def _make_time_major(policy, seq_lens, tensor, drop_last=False):
+ if isinstance(tensor, list):
+ return [_make_time_major(policy, seq_lens, t, drop_last) for t in tensor]
+ if policy.is_recurrent():
+ B = tf.shape(seq_lens)[0]
+ T = tf.shape(tensor)[0] // B
+ else:
+ T = policy.config["rollout_fragment_length"]
+ B = tf.shape(tensor)[0] // T
+ rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
+ res = tf.transpose(rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))
+ if drop_last:
+ return res[:-1]
+ return res
+[INST] Forward pass through the MOA LSTM.. [/INST] def forward_rnn(self, input_dict, state, seq_lens):
+ rnn_input = [input_dict["curr_obs"], seq_lens] + state
+ rnn_input.insert(1, input_dict["prev_total_actions"])
+ model_out, h, c = self.rnn_model(rnn_input)
+ return model_out, h, c
+[INST] Adds value function, logits, moa predictions, SCM loss/reward to experience train_batches.. [/INST] def extra_scm_fetches(policy):
+ ppo_fetches = extra_moa_fetches(policy)
+ ppo_fetches.update(scm_fetches(policy))
+ return ppo_fetches
+[INST] Add the influence and curiosity reward to the trajectory.
+Then, add the policy logits, VF preds, and advantages to the trajectory.. [/INST] def postprocess_ppo_scm(policy, sample_batch, other_agent_batches=None, episode=None):
+ batch = moa_postprocess_trajectory(policy, sample_batch)
+ batch = scm_postprocess_trajectory(policy, batch)
+ batch = postprocess_ppo_gae(policy, batch)
+ return batch
+[INST] Validates the PPO+MOA+SCM config. [/INST] def validate_ppo_scm_config(config):
+ validate_scm_config(config)
+ validate_moa_config(config)
+ validate_config(config)
+[INST] Creates a SCM+MOA+PPO policy class, then creates a trainer with this policy.. [/INST] def build_ppo_scm_trainer(scm_config):
+ tf.keras.backend.set_floatx("float32")
+ trainer_name = "SCMPPOTrainer"
+ scm_ppo_policy = build_tf_policy(
+ name="SCMPPOTFPolicy",
+ get_default_config=lambda: scm_config,
+ loss_fn=loss_with_scm,
+ make_model=build_model,
+ stats_fn=extra_scm_stats,
+ extra_action_fetches_fn=extra_scm_fetches,
+ postprocess_fn=postprocess_ppo_scm,
+ gradients_fn=clip_gradients,
+ before_init=setup_config,
+ before_loss_init=setup_ppo_scm_mixins,
+ mixins=[LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, ValueNetworkMixin]
+ + get_moa_mixins()
+ + get_curiosity_mixins(),
+ )
+ scm_ppo_trainer = build_trainer(
+ name=trainer_name,
+ default_policy=scm_ppo_policy,
+ make_policy_optimizer=choose_policy_optimizer,
+ default_config=scm_config,
+ validate_config=validate_ppo_scm_config,
+ after_optimizer_step=update_kl,
+ after_train_result=warn_about_bad_reward_scales,
+ mixins=[SCMResetConfigMixin],
+ )
+ return scm_ppo_trainer
+[INST] Evaluate the model.
+Adds time dimension to batch before sending inputs to forward_rnn(). [/INST] def forward(self, input_dict, state, seq_lens):
+ trunk = self.encoder_model(input_dict["obs"]["curr_obs"])
+ new_dict = {"curr_obs": add_time_dimension(trunk, seq_lens)}
+ output, new_state = self.forward_rnn(new_dict, state, seq_lens)
+ return tf.reshape(output, [-1, self.num_outputs]), new_state
+[INST] Forward pass through the LSTM.
+Implicitly assigns the value function output to self_value_out, and does not return this.. [/INST] def forward_rnn(self, input_dict, state, seq_lens):
+ h1, c1 = state
+ (self._model_out, self._value_out, output_h1, output_c1,) = self.policy_model.forward_rnn(
+ input_dict, [h1, c1], seq_lens
+ )
+ return self._model_out, [output_h1, output_c1]
+[INST] Adds the policy logits, VF preds, and advantages to the trajectory.. [/INST] def postprocess_a3c_moa(policy, sample_batch, other_agent_batches=None, episode=None):
+ batch = moa_postprocess_trajectory(policy, sample_batch)
+ batch = postprocess_advantages(policy, batch)
+ return batch
+[INST] Create the encoder submodel, which is part of the SCM.. [/INST] def create_scm_encoder_model(obs_space, model_config):
+ original_obs_dims = obs_space.original_space.spaces["curr_obs"].shape
+ input_layer = tf.keras.layers.Input(original_obs_dims, name="observations", dtype=tf.uint8)
+ last_layer = tf.keras.backend.cast(input_layer, tf.float32)
+ last_layer = tf.math.divide(last_layer, 255.0)
+ activation = get_activation_fn(model_config.get("conv_activation"))
+ filters = model_config.get("conv_filters")
+ out_size, kernel, stride = filters[-1]
+ conv_out = tf.keras.layers.Conv2D(
+ out_size,
+ kernel,
+ strides=(stride, stride),
+ activation=activation,
+ padding="valid",
+ name="conv_scm_encoder",
+ )(last_layer)
+ flattened_conv_out = tf.keras.layers.Flatten()(conv_out)
+ return tf.keras.Model(input_layer, flattened_conv_out, name="SCM_Encoder_Model")
+[INST] The forward pass through the SCM network.. [/INST] def forward(self, input_dict, state, seq_lens):
+ output, new_state = super(SocialCuriosityModule, self).forward(input_dict, state, seq_lens)
+ encoded_state = self.scm_encoder_model(input_dict["obs"]["curr_obs"])
+ new_state.append(encoded_state)
+ influence_reward = tf.expand_dims(self._social_influence_reward, axis=-1)
+ one_hot_actions = tf.reshape(
+ self._true_one_hot_actions, shape=[-1, self._true_one_hot_actions.shape[-1]]
+ )
+ lstm_input = tf.stop_gradient(state[2])
+ forward_model_input = {
+ "encoded_input_now": state[6],
+ "influence_reward_input": influence_reward,
+ "action_input": one_hot_actions,
+ "lstm_input": lstm_input,
+ }
+ inverse_model_input = {
+ "encoded_input_now": state[6],
+ "encoded_input_next": encoded_state,
+ "action_input": one_hot_actions,
+ "lstm_input": lstm_input,
+ }
+ forward_model_output = self.forward_model(forward_model_input)
+ inverse_model_output = self.inverse_model(inverse_model_input)
+ curiosity_reward = self.compute_curiosity_reward(encoded_state, forward_model_output)
+ curiosity_reward = tf.reshape(curiosity_reward, [-1])
+ self._social_curiosity_reward = curiosity_reward
+ inverse_model_loss = self.compute_inverse_model_loss(influence_reward, inverse_model_output)
+ self._inverse_model_loss = tf.reshape(inverse_model_loss, [-1])
+ return output, new_state
+[INST] Calculate the mean square error on a batched tensor.
+The output has the same amount of dimensions as the input,
+but sets the last dimension size to 1, which contains the mean.. [/INST] def batched_mse(true_tensor, pred_tensor):
+ squared_difference = tf.squared_difference(true_tensor, pred_tensor)
+ mse = tf.reduce_mean(squared_difference, axis=-1, keepdims=True)
+ return mse
+[INST] validation method to reduce malformed querys and unnecessary connection attempts. [/INST] def validate(keyword, target):
+ if keyword in StaticAnswers().keys('domain_keywords'):
+ if validators.domain(target) or validators.email(target):
+ return True
+ elif keyword in StaticAnswers().keys('number_keywords'):
+ if target is not None:
+ return target.isdigit()
+ elif keyword in StaticAnswers().keys("no_arg_keywords"):
+ return True
+ return False
+[INST] query and save the current xep list to reduce network bandwidth. [/INST] def req_xeplist(self):
+ if os.path.isfile("./common/.etag"):
+ with open("./common/.etag") as file:
+ local_etag = file.read()
+ else:
+ local_etag = ""
+ with requests.Session() as s:
+ s.headers.update({'Accept': 'application/xml'})
+ head = s.head("https://xmpp.org/extensions/xeplist.xml")
+ etag = head.headers['etag']
+ if local_etag == etag:
+ with open("./common/xeplist.xml", "r") as file:
+ self.xeplist = Et.fromstring(file.read())
+ elif os.path.isfile("./common/xeplist.xml") and head.status_code != 200:
+ with open("./common/xeplist.xml", "r") as file:
+ self.xeplist = Et.fromstring(file.read())
+ else:
+ r = s.get("https://xmpp.org/extensions/xeplist.xml")
+ r.encoding = 'utf-8'
+ local_etag = head.headers['etag']
+ with open("./common/xeplist.xml", "w") as file:
+ file.write(r.content.decode())
+ self.xeplist = Et.fromstring(r.content.decode())
+ with open('./common/.etag', 'w') as string:
+ string.write(local_etag)
+ for xep in self.xeplist.findall(".//*[@accepted='true']/number"):
+ self.acceptedxeps.append(xep.text)
+[INST] :param event -- An empty dictionary. The session_start event does not provide any additional data.. [/INST] def start(self, event):
+ self.send_presence()
+ self.get_roster()
+ if self.room:
+ for rooms in self.room.split(sep=","):
+ logging.debug("joining: %s" % rooms)
+ self.plugin['xep_0045'].join_muc(rooms, self.nick, wait=True)
+[INST] Returns the maximum amount of results a changelist can have for the
+"Show all" link to be displayed in a manner compatible with both Django
+1.4 and 1.3. See Django ticket #15997 for details.. [/INST] def list_max_show_all(changelist):
+ try:
+ from django.contrib.admin.views.main import MAX_SHOW_ALL_ALLOWED
+ return MAX_SHOW_ALL_ALLOWED
+ except ImportError:
+ return changelist.list_max_show_all
+[INST] Write unsigned 8-bit audio data to an audio file.. [/INST] def write_pcm_to_file(
+ output_path: Path, payload: bytes, samplerate: int,
+ output_options: List[str] = None
+ ):
+ output_options = output_options or []
+ ffmpeg = Popen([
+ 'ffmpeg', '-hide_banner', '-loglevel', 'error',
+ '-f', 'u8', '-ar', str(samplerate), '-i', 'pipe:',
+ *output_options,
+ str(output_path.resolve())
+ ], stdin=PIPE)
+ logger.info(f"Writing file {output_path.name}")
+ out, err = ffmpeg.communicate(payload)
+ if err:
+ print(f"ffmpeg returned with error ({err})")
+ return out
+[INST] Write unsigned 8-bit audio data to an MP3 file.. [/INST] def write_pcm_to_mp3(
+ output_path: Path,
+ payload: bytes,
+ samplerate: int,
+ quality_or_bitrate: Union[int, str] = 0,
+ title: str = None,
+ artist: str = None,
+ album: str = None,
+ year: str = None,
+ track: int = None,
+ comment: str = None
+ ):
+ options = ['-c:a', 'libmp3lame']
+ if not (
+ quality_or_bitrate is None
+ or isinstance(quality_or_bitrate, (int, str))
+ ):
+ raise TypeError(
+ f"quality_or_bitrate must one of type (int, str) or None, got "
+ f"{type(quality_or_bitrate)}"
+ )
+ if isinstance(quality_or_bitrate, int):
+ options += ['-q:a', str(quality_or_bitrate)]
+ elif isinstance(quality_or_bitrate, str):
+ options += ['-b:a', quality_or_bitrate]
+ options += ['-metadata', f'title={title}'] if title else []
+ options += ['-metadata', f'artist={artist}'] if artist else []
+ options += ['-metadata', f'album={album}'] if album else []
+ options += ['-metadata', f'date={year}'] if year else []
+ options += ['-metadata', f'track={track}'] if track else []
+ options += ['-metadata', f'comment={comment}'] if comment else []
+ return write_pcm_to_file(output_path, payload, samplerate, options)
+[INST] Try to convert string to integer, otherwise keep as string. [/INST] def try_int_coerce(string: str) -> Union[int, str]:
+ try:
+ return int(string)
+ except ValueError:
+ return string
+[INST] Walks os directory path collecting all CSV files.. [/INST] def walk_path(path):
+ file_list = []
+ for root, _, files in os.walk(path, topdown=False):
+ for name in files:
+ if name.startswith('example'):
+ continue
+ elif name.endswith('.csv'):
+ file_list.append(os.path.join(root, name))
+ return file_list
+[INST] Parses and shortens CSV file.. [/INST] def handle_csv(self, file):
+ if isinstance(file, list):
+ self._handle_csvs(file)
+ return
+ parsed_csv = self.parse_csv(file)
+ chunky_parsed = self.chunkify( parsed_csv )
+ for chunk in chunky_parsed:
+ r = self.shorten_csv(file, chunk)
+[INST] Parse CSV file into yourls-friendly csv.. [/INST] def parse_csv(self, filename):
+ _ = self._check_kwargs(('url', 'keyword', 'title'))
+ vals = {k: v for k, v in _}
+ try:
+ r = requests.get(filename)
+ fp = r.content.decode().splitlines()
+ except requests.exceptions.MissingSchema:
+ r = None
+ fp = open(filename, mode='r')
+ csv_reader = csv.reader(fp)
+ headers = [h.strip() for h in next(csv_reader)]
+ ret_csv = []
+ for line in csv_reader:
+ parsed_line = []
+ for k, v in vals.items():
+ try:
+ parsed_line.append( line[headers.index(v)].strip() )
+ except (ValueError, IndexError):
+ continue
+ _ = self._check_kwargs(['uri_stem',])
+ ret_csv.append((','.join(parsed_line) + '\n').replace(*[v for k, v in _], ''))
+ if not r:
+ fp.close()
+ return ret_csv
+[INST] Returns the HTML code for embedding the video.
+Expects youtube_parameters as a dictionary in the form {parameter:value}
+When using, this is a function so call with {{ video.embed_html|safe }}. [/INST] def embed_html(self, loop=False, autoplay=False, controls=False, mute=False, youtube_parameters=None):
+ if self.external_video:
+ if self.external_video_service == 'youtube':
+ return render_to_string('videos/youtube.html', {
+ 'src': self.external_video_iframe_url,
+ 'autoplay': int(autoplay),
+ 'controls': int(controls),
+ 'loop': int(loop),
+ 'muted': int(mute),
+ 'extra_parameters': ('&' + '&'.join('{}={}'.format(parameter, youtube_parameters[parameter]) for parameter in youtube_parameters)) if youtube_parameters else '',
+ 'schema': self.get_schema(),
+ })
+ elif self.external_video_service == 'vimeo':
+ return render_to_string('videos/vimeo.html', {
+ 'src': self.external_video_iframe_url,
+ 'autoplay': int(autoplay),
+ 'controls': int(not controls),
+ 'loop': int(loop),
+ 'muted': int(mute),
+ 'schema': self.get_schema(),
+ })
+ return render_to_string('videos/default.html', {
+ 'src': self.external_video_iframe_url,
+ })
+ if self.high_resolution_mp4 or self.low_resolution_mp4:
+ return render_to_string('videos/local.html', {
+ 'preload': 'auto' if autoplay else 'metadata',
+ 'autoplay': ' autoplay' if autoplay else '',
+ 'controls': ' controls' if controls else '',
+ 'loop': ' loop' if loop else '',
+ 'muted': ' muted' if mute else '',
+ 'high_resolution_file': self.high_resolution_mp4.file.url if self.high_resolution_mp4 else None,
+ 'high_resolution_filetype': self.high_resolution_mp4.file.url.split('.')[-1] if self.high_resolution_mp4 else None,
+ 'low_resolution_file': self.low_resolution_mp4.file.url if self.low_resolution_mp4 else None,
+ 'low_resolution_filetype': self.low_resolution_mp4.file.url.split('.')[-1] if self.low_resolution_mp4 else None,
+ 'image': self.image,
+ 'schema': self.get_schema(),
+ })
+[INST] Returns the response for a successful add action.. [/INST] def response_add(self, request, obj, post_url_continue=None):
+ if '_tinymce' in request.GET:
+ context = {'permalink': permalinks.create(obj),
+ 'title': obj.title}
+ return render(request, 'admin/media/file/filebrowser_add_success.html', context)
+ return super().response_add(request, obj, post_url_continue=post_url_continue)
+[INST] Renders the change list, but sets 'is_popup=True' into the template
+context to make it render the media library sans navigation, without
+needing _popup in the URL (which causes an exception with Jet's
+Javascript, which assumes that if _popup is in the URL that it is a
+related item popup).. [/INST] def media_library_changelist_view(self, request, extra_context=None):
+ context = extra_context or {}
+ context.setdefault('changelist_template_parent', 'reversion/change_list.html')
+ context['is_popup'] = True
+ context['is_media_library_iframe'] = True
+ return super().changelist_view(request, context)
+[INST] Renders a template as a HttpResponse using the context of this page.. [/INST] def render(self, request, template, context=None, **kwargs):
+ page_context = self.get_context_data()
+ page_context.update(context or {})
+ return render(request, template, page_context, **kwargs)
+[INST] Takes a page and duplicates it as a child of the original's parent page.
+Expects to be passed the original page and an optional function. [/INST] def duplicate_page(original_page, page_changes=None):
+ from .admin import page_admin
+ original_content = original_page.content
+ with update_index():
+ page = deepcopy(original_page)
+ page.pk = None
+ if callable(page_changes):
+ page = page_changes(page, original_page)
+ page.save()
+ content = deepcopy(original_content)
+ content.pk = None
+ content.page = page
+ content.save()
+ for content_cls, admin_cls in page_admin.content_inlines:
+ if not isinstance(content, content_cls):
+ continue
+ model_cls = admin_cls.model
+ fk = _get_foreign_key(Page, model_cls, fk_name=admin_cls.fk_name)
+ related_items = model_cls.objects.filter(**{fk.name: original_page.pk}).distinct().all()
+ for item in related_items:
+ new_object = deepcopy(item)
+ new_object.pk = None
+ setattr(new_object, fk.name, page)
+ new_object = overlay_obj(new_object, item, exclude=[fk.name, 'pk', 'id'], commit=True)
+ new_object.save()
+ return page
+[INST] A function that takes a page and overlay the fields and linked objects from a different page.. [/INST] def overlay_page_obj(original_page, overlay_page, commit=False):
+ from .admin import page_admin
+ original_content = original_page.content
+ page_fields_exclude = ['pk', 'id', 'version_for', 'left', 'right']
+ content_fields_exclude = ['pk', 'id', 'page']
+ checked_models = []
+ related_fields = []
+ def do_model_inline(cls):
+ checked_models.append(cls.model)
+ fk = _get_foreign_key(Page, cls.model, fk_name=cls.fk_name)
+ return fk.related_query_name()
+ for _, admin_cls in page_admin.content_inlines:
+ if admin_cls.model in checked_models:
+ continue
+ related_fields.append(do_model_inline(admin_cls))
+ overlay_obj(original_page, overlay_page, page_fields_exclude, related_fields, commit=commit)
+ overlay_obj(original_content, overlay_page.content, content_fields_exclude, [], commit=commit)
+ return original_page
+[INST] Expands permalinks in and tags.
+Images will also be automatically thumbnailed to fit their specified width
+and height.. [/INST] def process(text):
+ resolved_permalinks = {}
+ def sub_tag(match):
+ tagname = match.group(1)
+ attrs = dict(RE_ATTR.findall(match.group(2)))
+ def get_obj(attr_name):
+ if attr_name in attrs:
+ value = attrs[attr_name][1:-1]
+ if value not in resolved_permalinks:
+ try:
+ resolved_permalinks[value] = permalinks.resolve(value)
+ except (permalinks.PermalinkError, ObjectDoesNotExist):
+ resolved_permalinks[value] = None
+ obj = resolved_permalinks[value]
+ if obj:
+ attrs[attr_name] = '"%s"' % escape(obj.get_absolute_url())
+ attrs.setdefault("title", '"%s"' % escape(getattr(obj, "title", str(obj))))
+ return obj
+ return None
+ if tagname == "a":
+ obj = get_obj("href")
+ elif tagname == "img":
+ obj = get_obj("src")
+ if obj:
+ if hasattr(obj, 'attribution') or hasattr(obj, 'copyright'):
+ attrs["title"] = ''
+ if hasattr(obj, 'copyright') and obj.copyright:
+ attrs["title"] += '© {}. '.format(
+ obj.copyright,
+ )
+ if hasattr(obj, 'attribution') and obj.attribution:
+ attrs["title"] += obj.attribution
+ if attrs["title"]:
+ attrs["title"] = '"{}"'.format(attrs["title"])
+ else:
+ attrs["title"] = '"{}"'.format(obj.title)
+ try:
+ width = int(attrs["width"][1:-1])
+ height = int(attrs["height"][1:-1])
+ except (ValueError, KeyError, TypeError):
+ pass
+ else:
+ fieldname = None
+ for field in obj._meta.fields:
+ if isinstance(field, models.FileField):
+ fieldname = field.name
+ if fieldname:
+ try:
+ thumbnail = get_thumbnail(getattr(obj, fieldname), '{}x{}'.format(width, height), quality=99, format="PNG")
+ except IOError:
+ pass
+ else:
+ attrs["src"] = '"%s"' % escape(thumbnail.url)
+ attrs["width"] = '"%s"' % thumbnail.width
+ attrs["height"] = '"%s"' % thumbnail.height
+ else:
+ assert False
+ attrs_str = " ".join("%s=%s" % (key, value) for key, value in sorted(attrs.items()))
+ return "<%s %s%s>" % (tagname, attrs_str, match.group(3))
+ return RE_TAG.sub(sub_tag, text)
+[INST] Moves a page up or down.. [/INST] def move_page_view(self, request):
+ if not self.has_change_permission(request):
+ return HttpResponseForbidden('You do not have permission to move this page.')
+ existing_pages_list = Page.objects.all().exclude(
+ is_canonical_page=False,
+ ).select_for_update().values(
+ 'id',
+ 'parent_id',
+ 'left',
+ 'right',
+ 'title'
+ ).order_by('left')
+ existing_pages = dict(
+ (page['id'], page)
+ for page
+ in existing_pages_list
+ )
+ page = existing_pages[int(request.POST['page'])]
+ parent_id = page['parent_id']
+ siblings = [s for s in existing_pages_list if s['parent_id'] == parent_id]
+ direction = request.POST['direction']
+ if direction == 'up':
+ siblings.reverse()
+ elif direction == 'down':
+ pass
+ else:
+ raise ValueError('Direction should be "up" or "down".')
+ sibling_iter = iter(siblings)
+ for sibling in sibling_iter:
+ if sibling['id'] == page['id']:
+ break
+ try:
+ other_page = next(sibling_iter)
+ except StopIteration:
+ return HttpResponse('Page could not be moved, as nothing to swap with.')
+ first_page, second_page = sorted((page, other_page), key=lambda p: p['left'])
+ Page.objects.filter(left__gte=first_page['left'], right__lte=first_page['right']).update(
+ left=F('left') * -1,
+ right=F('right') * -1,
+ )
+ branch_width = first_page['right'] - first_page['left'] + 1
+ Page.objects.filter(left__gte=second_page['left'], right__lte=second_page['right']).update(
+ left=F('left') - branch_width,
+ right=F('right') - branch_width,
+ )
+ second_branch_width = second_page['right'] - second_page['left'] + 1
+ Page.objects.filter(left__lte=-first_page['left'], right__gte=-first_page['right']).update(
+ left=(F('left') - second_branch_width) * -1,
+ right=(F('right') - second_branch_width) * -1,
+ )
+ existing_pages_list = Page.objects.all().exclude(
+ is_canonical_page=True,
+ ).select_for_update().values(
+ 'id',
+ 'parent_id',
+ 'owner',
+ 'left',
+ 'right',
+ 'title'
+ ).order_by('left')
+ for child_page in existing_pages_list:
+ child_page.update(
+ left=F('owner__left'),
+ right=F('owner__right'),
+ )
+ return HttpResponse('Page #%s was moved %s.' % (page['id'], direction))
+[INST] Renders a navigation list for the given pages.
+The pages should all be a subclass of PageBase, and possess a get_absolute_url() method.
+
+You can also specify an alias for the navigation, at which point it will be set in the
+context rather than rendered.. [/INST] def render_navigation(context, pages, section=None):
+ return {
+ 'navigation': _navigation_entries(context, pages, section),
+ }
+[INST] The child pages for this page.. [/INST] def children(self):
+ children = []
+ page = self.canonical_version
+ if page.right - page.left > 1:
+ for child in page.child_set.filter(is_canonical_page=True):
+ child.parent = page
+ children.append(child)
+ return children
+[INST] The associated content model for this page.. [/INST] def content(self):
+ content_cls = ContentType.objects.get_for_id(
+ self.content_type_id).model_class()
+ content = content_cls._default_manager.get(page=self)
+ content.page = self
+ return content
+[INST] Performs a reverse URL lookup.. [/INST] def reverse(self, view_func, args=None, kwargs=None):
+ if args is None:
+ args = ()
+ if kwargs is None:
+ kwargs = {}
+ urlconf = ContentType.objects.get_for_id(
+ self.content_type_id
+ ).model_class().urlconf
+ return self.get_absolute_url().rstrip('/') + urls.reverse(
+ view_func,
+ args=args,
+ kwargs=kwargs,
+ urlconf=urlconf,
+ )
+[INST] Excises this whole branch from the tree.. [/INST] def _excise_branch(self):
+ branch_width = self._branch_width
+ Page.objects.filter(left__gte=self.left).update(
+ left=F('left') - branch_width,
+ )
+ Page.objects.filter(right__gte=self.left).update(
+ right=F('right') - branch_width,
+ )
+[INST] Inserts this whole branch into the tree.. [/INST] def _insert_branch(self):
+ branch_width = self._branch_width
+ Page.objects.filter(left__gte=self.left).update(
+ left=F('left') + branch_width,
+ )
+ Page.objects.filter(right__gte=self.left).update(
+ right=F('right') + branch_width,
+ )
+[INST] Filters the given queryset of pages to only contain ones that should be
+indexed by search engines.. [/INST] def filter_indexable_pages(queryset):
+ return queryset.filter(
+ robots_index=True,
+ content_type__in=[
+ ContentType.objects.get_for_model(content_model)
+ for content_model
+ in get_registered_content()
+ if content_model.robots_index
+ ]
+ )
+[INST] Little helper to create a page whose parent is the homepage.. [/INST] def _make_page(self, title, content_type):
+ content_page = Page.objects.create(
+ title=title,
+ slug=slugify(title),
+ parent=self.homepage,
+ content_type=content_type,
+ )
+ content_type.model_class().objects.create(
+ page=content_page,
+ )
+[INST] Private fit method of the Estimator, which trains the model.. [/INST] def _fit(self, df):
+ simple_rdd = df_to_simple_rdd(df, categorical=self.get_categorical_labels(), nb_classes=self.get_nb_classes(),
+ features_col=self.getFeaturesCol(), label_col=self.getLabelCol())
+ simple_rdd = simple_rdd.repartition(self.get_num_workers())
+ elephas_optimizer = None
+ if self.get_elephas_optimizer_config() is not None:
+ elephas_optimizer = get({'class_name': self.get_optimizer_config()['class_name'],
+ 'config': self.get_optimizer_config()})
+ keras_model = model_from_yaml(self.get_keras_model_config())
+ metrics = self.get_metrics()
+ loss = self.get_loss()
+ optimizer = get_optimizer(self.get_optimizer_config())
+ keras_model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
+ spark_model = SparkModel(model=keras_model, elephas_optimizer=elephas_optimizer,
+ mode=self.get_mode(), frequency=self.get_frequency(),
+ num_workers=self.get_num_workers())
+ spark_model.fit(simple_rdd, epochs=self.get_epochs(), batch_size=self.get_batch_size(),
+ verbose=self.get_verbosity(), validation_split=self.get_validation_split())
+ model_weights = spark_model.master_network.get_weights()
+ weights = simple_rdd.ctx.broadcast(model_weights)
+ return ElephasTransformer(labelCol=self.getLabelCol(),
+ outputCol='prediction',
+ keras_model_config=spark_model.master_network.to_yaml(),
+ weights=weights)
+[INST] Private transform method of a Transformer. This serves as batch-prediction method for our purposes.. [/INST] def _transform(self, df):
+ output_col = self.getOutputCol()
+ label_col = self.getLabelCol()
+ new_schema = copy.deepcopy(df.schema)
+ new_schema.add(StructField(output_col, StringType(), True))
+ rdd = df.rdd.coalesce(1)
+ features = np.asarray(
+ rdd.map(lambda x: from_vector(x.features)).collect())
+ model = model_from_yaml(self.get_keras_model_config())
+ model.set_weights(self.weights.value)
+ predictions = rdd.ctx.parallelize(
+ model.predict_classes(features)).coalesce(1)
+ predictions = predictions.map(lambda x: tuple(str(x)))
+ results_rdd = rdd.zip(predictions).map(lambda x: x[0] + x[1])
+ results_df = df.sql_ctx.createDataFrame(results_rdd, new_schema)
+ results_df = results_df.withColumn(
+ output_col, results_df[output_col].cast(DoubleType()))
+ results_df = results_df.withColumn(
+ label_col, results_df[label_col].cast(DoubleType()))
+ return results_df
+[INST] Train a keras model on a worker. [/INST] def train(self, data_iterator):
+ optimizer = get_optimizer(self.master_optimizer)
+ self.model = model_from_yaml(self.yaml, self.custom_objects)
+ self.model.compile(optimizer=optimizer,
+ loss=self.master_loss, metrics=self.master_metrics)
+ self.model.set_weights(self.parameters.value)
+ feature_iterator, label_iterator = tee(data_iterator, 2)
+ x_train = np.asarray([x for x, y in feature_iterator])
+ y_train = np.asarray([y for x, y in label_iterator])
+ weights_before_training = self.model.get_weights()
+ if x_train.shape[0] > self.train_config.get('batch_size'):
+ self.model.fit(x_train, y_train, **self.train_config)
+ weights_after_training = self.model.get_weights()
+ deltas = subtract_params(
+ weights_before_training, weights_after_training)
+ yield deltas, weights_after_training
+[INST] Train a keras model on a worker and send asynchronous updates
+to parameter server. [/INST] def train(self, data_iterator):
+ feature_iterator, label_iterator = tee(data_iterator, 2)
+ x_train = np.asarray([x for x, y in feature_iterator])
+ y_train = np.asarray([y for x, y in label_iterator])
+ if x_train.size == 0:
+ return
+ optimizer = get_optimizer(self.master_optimizer)
+ self.model = model_from_yaml(self.yaml, self.custom_objects)
+ self.model.compile(optimizer=optimizer,
+ loss=self.master_loss, metrics=self.master_metrics)
+ self.model.set_weights(self.parameters.value)
+ epochs = self.train_config['epochs']
+ batch_size = self.train_config.get('batch_size')
+ nb_train_sample = x_train.shape[0]
+ nb_batch = int(np.ceil(nb_train_sample / float(batch_size)))
+ index_array = np.arange(nb_train_sample)
+ batches = [
+ (i * batch_size, min(nb_train_sample, (i + 1) * batch_size))
+ for i in range(0, nb_batch)
+ ]
+ if self.frequency == 'epoch':
+ for epoch in range(epochs):
+ weights_before_training = self.client.get_parameters()
+ self.model.set_weights(weights_before_training)
+ self.train_config['epochs'] = 1
+ if x_train.shape[0] > batch_size:
+ self.model.fit(x_train, y_train, **self.train_config)
+ self.train_config['epochs'] = epochs
+ weights_after_training = self.model.get_weights()
+ deltas = subtract_params(
+ weights_before_training, weights_after_training)
+ self.client.update_parameters(deltas)
+ elif self.frequency == 'batch':
+ for epoch in range(epochs):
+ if x_train.shape[0] > batch_size:
+ for (batch_start, batch_end) in batches:
+ weights_before_training = self.client.get_parameters()
+ self.model.set_weights(weights_before_training)
+ batch_ids = index_array[batch_start:batch_end]
+ x = slice_arrays(x_train, batch_ids)
+ y = slice_arrays(y_train, batch_ids)
+ self.model.train_on_batch(x, y)
+ weights_after_training = self.model.get_weights()
+ deltas = subtract_params(
+ weights_before_training, weights_after_training)
+ self.client.update_parameters(deltas)
+ else:
+ raise ValueError(
+ 'frequency parameter can be `epoch` or `batch, got {}'.format(self.frequency))
+ yield []
+[INST] Inverse of the `serialize` function.
+Arguments
+config: Optimizer configuration dictionary.
+custom_objects: Optional dictionary mapping
+names (strings) to custom objects
+(classes and functions)
+to be considered during deserialization.
+Returns
+A Keras Optimizer instance.. [/INST] def deserialize(config, custom_objects=None):
+ all_classes = {
+ 'sgd': SGD,
+ 'rmsprop': RMSprop,
+ 'adagrad': Adagrad,
+ 'adadelta': Adadelta,
+ 'adam': Adam
+ }
+ if config['class_name'].lower() in all_classes:
+ config['class_name'] = config['class_name'].lower()
+ return deserialize_keras_object(config,
+ module_objects=all_classes,
+ custom_objects=custom_objects,
+ printable_module_name='optimizer')
+[INST] Main entry point for the querv CLI.. [/INST] def main():
+ args = docopt(__doc__, version=version("querv"))
+ prop_dict = {
+ "subnets": "SubnetId",
+ "images": "ImageId",
+ "keys": "KeyName",
+ "types": "InstanceType",
+ "VPCs": "VpcId",
+ }
+ query_input = get_option("property", args)
+ query = prop_dict[query_input]
+ ident = get_option("id", args)
+ if args["save"]:
+ save_snapshot(args)
+ sys.exit(0)
+ data = get_data(args)
+ summary_dict, summary_list = get_summary(data, ident, query)
+ print(pivot(summary_dict, summary_list))
+[INST] Loads the PyTorch model and the classes into memory from a tar.gz file on S3.. [/INST] def load_model():
+ tmp_dir = '/tmp/pytorch-serverless'
+ local_model = f'{tmp_dir}/model.tar.gz'
+ os.makedirs(tmp_dir, exist_ok=True)
+ logger.info(f'Loading {MODEL} from S3 bucket {S3_BUCKET} to {local_model}')
+ s3.download_file(S3_BUCKET, MODEL, local_model)
+ tarfile.open(local_model).extractall(tmp_dir)
+ os.remove(local_model)
+ classes = open(f'{tmp_dir}/classes', 'r').read().splitlines()
+ logger.info(f'Classes are {classes}')
+ model_path = glob.glob(f'{tmp_dir}/*_jit.pth')[0]
+ logger.info(f'Model path is {model_path}')
+ model = torch.jit.load(model_path, map_location=torch.device('cpu'))
+ return model.eval(), classes
+[INST] Predicts the class of an image_tensor.. [/INST] def predict(model, classes, image_tensor):
+ start_time = time.time()
+ predict_values = model(image_tensor)
+ logger.info("Inference time: {} seconds".format(time.time() - start_time))
+ softmaxed = F.softmax(predict_values, dim=1)
+ probability_tensor, index = torch.max(softmaxed, dim=1)
+ prediction = classes[index]
+ probability = "{:1.2f}".format(probability_tensor.item())
+ logger.info(f'Predicted class is {prediction} with a probability of {probability}')
+ return {'class': prediction, 'probability': probability}
+[INST] Transforms the posted image to a PyTorch Tensor.. [/INST] def image_to_tensor(preprocess_pipeline, body):
+ data = json.loads(body)
+ name = data['name']
+ image = data['file']
+ dec = base64.b64decode(image)
+ img = PIL.Image.open(io.BytesIO(dec))
+ img_tensor = preprocess_pipeline(img)
+ img_tensor = img_tensor.unsqueeze(0)
+ return img_tensor
+[INST] The main function which is called in the lambda function as defined in our template.yml. [/INST] def lambda_handler(event, context):
+ image_tensor = image_to_tensor(preprocess_pipeline, event['body'])
+ response = predict(model, classes, image_tensor)
+ return {
+ "statusCode": 200,
+ "body": json.dumps(response)
+ }
+[INST] r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
+The graph is differentiated using the chain rule.
+
+This function accumulates gradients in the leaves - you might need to zero
+them before calling it.. [/INST] def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
+ if grad_variables is not None:
+ warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
+ if grad_tensors is None:
+ grad_tensors = grad_variables
+ else:
+ raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
+ "arguments both passed to backward(). Please only "
+ "use 'grad_tensors'.")
+ tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
+ if grad_tensors is None:
+ grad_tensors = [None] * len(tensors)
+ elif isinstance(grad_tensors, torch.Tensor):
+ grad_tensors = [grad_tensors]
+ else:
+ grad_tensors = list(grad_tensors)
+ grad_tensors = _make_grads(tensors, grad_tensors)
+ if retain_graph is None:
+ retain_graph = create_graph
+ Variable._execution_engine.run_backward(
+ tensors, grad_tensors, retain_graph, create_graph,
+ allow_unreachable=True)
+[INST] Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops. [/INST] def onnxifi_caffe2_net(
+ pred_net,
+ input_shapes,
+ infer_shapes=False,
+ debug=False):
+ shape_hints = {}
+ external_inputs = []
+ if not infer_shapes:
+ for k, v in input_shapes.items():
+ need_input_tensor = True
+ if workspace.HasBlob(k):
+ itensor = workspace.FetchBlob(k)
+ if itensor.shape == v:
+ need_input_tensor = False
+ if need_input_tensor:
+ workspace.FeedBlob(k, np.random.randn(*v).astype(np.float32))
+ external_inputs.append(k)
+ for k, v in input_shapes.items():
+ shape_hints[k] = v
+ pred_net_str = C.onnxifi(pred_net.SerializeToString(),
+ external_inputs,
+ shape_hints,
+ infer_shapes,
+ debug)
+ pred_net_cut = caffe2_pb2.NetDef()
+ pred_net_cut.ParseFromString(pred_net_str)
+ return pred_net_cut
+[INST] r"""einsum(equation, *operands) -> Tensor
+This function provides a way of computing multilinear expressions using the
+Einstein summation convention.. [/INST] def einsum(equation, *operands):
+ if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
+ operands = operands[0]
+ return torch._C._VariableFunctions.einsum(equation, operands)
+[INST] r"""Returns a new tensor with boolean elements representing if each element is `Finite` or not.. [/INST] def isfinite(tensor):
+ if not isinstance(tensor, torch.Tensor):
+ raise ValueError("The argument is not a tensor", str(tensor))
+ if not tensor.is_floating_point():
+ return torch.ones_like(tensor, dtype=torch.uint8)
+ return (tensor == tensor) & (tensor.abs() != inf)
+[INST] r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.. [/INST] def unique(input, sorted=False, return_inverse=False, dim=None):
+ if dim is not None:
+ output, inverse_indices = torch._unique_dim(
+ input,
+ dim,
+ sorted=sorted,
+ return_inverse=return_inverse
+ )
+ else:
+ output, inverse_indices = torch._unique(
+ input,
+ sorted=sorted,
+ return_inverse=return_inverse,
+ )
+ if return_inverse:
+ return output, inverse_indices
+ else:
+ return output
+[INST] r"""Returns a contraction of a and b over multiple dimensions.
+:attr:`tensordot` implements a generalizes the matrix product.. [/INST] def tensordot(a, b, dims=2):
+ if isinstance(dims, (list, tuple)) or \
+ (isinstance(dims, torch.Tensor) and dims.numel() > 1):
+ dims_a, dims_b = dims
+ else:
+ if isinstance(dims, torch.Tensor):
+ dims = dims.item()
+ dims_a = list(range(-dims, 0))
+ dims_b = list(range(dims))
+ return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
+[INST] r"""Returns the indices that sort a tensor along a given dimension in ascending
+order by value.
+
+This is the second value returned by :meth:`torch.sort`. See its documentation
+for the exact semantics of this method.. [/INST] def argsort(input, dim=None, descending=False):
+ if dim is None:
+ return torch.sort(input, -1, descending)[1]
+ return torch.sort(input, dim, descending)[1]
+[INST] r"""Returns the matrix norm or vector norm of a given tensor.. [/INST] def norm(input, p="fro", dim=None, keepdim=False, out=None):
+ ndim = input.dim()
+ if dim is None and out is None:
+ if p == "fro":
+ return torch._C._VariableFunctions.frobenius_norm(input)
+ elif p != "nuc":
+ return torch._C._VariableFunctions.norm(input, p)
+ if p == "fro":
+ if dim is None:
+ dim = tuple(range(ndim))
+ if out is None:
+ return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim)
+ return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out)
+ elif p == "nuc":
+ if out is None:
+ torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
+ return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
+ else:
+ if out is None:
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim)
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
+[INST] r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
+using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
+of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
+needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
+If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.. [/INST] def chain_matmul(*matrices):
+ return torch._C._VariableFunctions.chain_matmul(matrices)
+[INST] r"""Sets the seed for generating random numbers for the current GPU.
+It's safe to call this function if CUDA is not available; in that
+case, it is silently ignored.. [/INST] def manual_seed(seed):
+ seed = int(seed)
+ _lazy_call(lambda: _C._cuda_manualSeed(seed))
+[INST] r"""Sets the seed for generating random numbers on all GPUs.
+It's safe to call this function if CUDA is not available; in that
+case, it is silently ignored.. [/INST] def manual_seed_all(seed):
+ seed = int(seed)
+ _lazy_call(lambda: _C._cuda_manualSeedAll(seed))
+[INST] r"""Sets the seed for generating random numbers to a random number for the current GPU.
+It's safe to call this function if CUDA is not available; in that
+case, it is silently ignored.
+
+:
+If you are working with a multi-GPU model, this function will only initialize
+the seed on one GPU.. [/INST] def seed():
+ _lazy_call(lambda: _C._cuda_seed())
+[INST] r"""Sets the seed for generating random numbers to a random number on all GPUs.
+It's safe to call this function if CUDA is not available; in that
+case, it is silently ignored.. [/INST] def seed_all():
+ _lazy_call(lambda: _C._cuda_seedAll())
+[INST] r"""Returns the current random seed of the current GPU.
+warning::
+This function eagerly initializes CUDA.. [/INST] def initial_seed():
+ _lazy_init()
+ return _C._cuda_initialSeed()
+[INST] Parses a type annotation specified as a comment.. [/INST] def parse_type_line(type_line):
+ arg_ann_str, ret_ann_str = split_type_line(type_line)
+ try:
+ arg_ann = eval(arg_ann_str, _eval_env)
+ except SyntaxError:
+ raise RuntimeError("Failed to parse the argument list of a type annotation")
+ if not isinstance(arg_ann, tuple):
+ arg_ann = (arg_ann,)
+ try:
+ ret_ann = eval(ret_ann_str, _eval_env)
+ except SyntaxError:
+ raise RuntimeError("Failed to parse the return type of a type annotation")
+ arg_types = [ann_to_type(ann) for ann in arg_ann]
+ return arg_types, ann_to_type(ret_ann)
+[INST] r"""A helper function for checkpointing sequential models.
+Sequential models execute a list of modules/functions in order
+(sequentially). Therefore, we can divide such a model in various segments
+and checkpoint each segment. All segments except the last will run in
+:func:`torch.no_grad` manner, i.e., not storing the intermediate
+activations. The inputs of each checkpointed segment will be saved for
+re-running the segment in the backward pass.
+
+See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
+
+
+
+At least one of the inputs needs to have :code:`requires_grad=True` if
+grads are needed for model inputs, otherwise the checkpointed part of the
+model won't have gradients.. [/INST] def checkpoint_sequential(functions, segments, *inputs):
+ def run_function(start, end, functions):
+ def forward(*inputs):
+ for j in range(start, end + 1):
+ if isinstance(inputs, tuple):
+ inputs = functions[j](*inputs)
+ else:
+ inputs = functions[j](inputs)
+ return inputs
+ return forward
+ if isinstance(functions, torch.nn.Sequential):
+ functions = list(functions.children())
+ segment_size = len(functions) // segments
+ end = -1
+ for start in range(0, segment_size * (segments - 1), segment_size):
+ end = start + segment_size - 1
+ inputs = checkpoint(run_function(start, end, functions), *inputs)
+ if not isinstance(inputs, tuple):
+ inputs = (inputs,)
+ return run_function(end + 1, len(functions) - 1, functions)(*inputs)
+[INST] Marks a function as a weak script function. When used in a script function
+or ScriptModule, the weak script function will be lazily compiled and
+inlined in the graph. When not used in a script function, the weak script
+annotation has no effect.. [/INST] def weak_script(fn, _frames_up=0):
+ _compiled_weak_fns[fn] = {
+ "status": COMPILATION_PENDING,
+ "compiled_fn": None,
+ "rcb": createResolutionCallback(_frames_up + 1)
+ }
+ return fn
+[INST] Dispatches to either of 2 weak script functions based on a boolean argument.
+In TorchScript, the boolean argument must be constant so that the correct
+function to use can be determined at compile time.. [/INST] def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
+ if _compiled_weak_fns.get(if_true) is None or _compiled_weak_fns.get(if_false) is None:
+ raise RuntimeError("both functions must be weak script")
+ def fn(*args, **kwargs):
+ dispatch_flag = False
+ if arg_name in kwargs:
+ dispatch_flag = kwargs[arg_name]
+ elif arg_index < len(args):
+ dispatch_flag = args[arg_index]
+ if dispatch_flag:
+ return if_true(*args, **kwargs)
+ else:
+ return if_false(*args, **kwargs)
+ if if_true.__doc__ is None and if_false.__doc__ is not None:
+ doc = if_false.__doc__
+ if_true.__doc__ = doc
+ elif if_false.__doc__ is None and if_true.__doc__ is not None:
+ doc = if_true.__doc__
+ if_false.__doc__ = doc
+ elif if_false.__doc__ is None and if_true.__doc__ is None:
+ doc = None
+ else:
+ raise RuntimeError("only one function can have a docstring")
+ fn.__doc__ = doc
+ _boolean_dispatched[fn] = {
+ "if_true": if_true,
+ "if_false": if_false,
+ "index": arg_index,
+ "default": default,
+ "arg_name": arg_name
+ }
+ return fn
+[INST] All previously saved modules, no matter their device, are first loaded onto CPU,
+and then are moved to the devices they were saved from. If this fails , an exception is raised.
+However, storages can be dynamically remapped to an alternative set of devices
+using the `map_location` argument. Comparing to :func:`torch.load`, `map_location`
+in this function is simplified, which only accepts a string ,
+or torch.device ). [/INST] def load(f, map_location=None):
+ m = ScriptModule()
+ def module_lookup(names):
+ curr = m
+ for name in names:
+ if not hasattr(curr, name):
+ setattr(curr, name, ScriptModule())
+ curr = getattr(curr, name)
+ return curr
+ if isinstance(map_location, string_classes):
+ map_location = torch.device(map_location)
+ elif not (map_location is None or
+ isinstance(map_location, torch.device)):
+ raise ValueError("map_location should be either None, string or torch.device, "
+ "but got type: " + str(type(map_location)))
+ if (str(map_location).startswith('cuda')):
+ validate_cuda_device(map_location)
+ if isinstance(f, str) or \
+ (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
+ (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+ torch._C.import_ir_module(module_lookup, f, map_location)
+ else:
+ torch._C.import_ir_module_from_buffer(module_lookup, f.read(), map_location)
+ return m
+[INST] Saves a ScriptModule to a file.. [/INST] def save(m, f):
+ if isinstance(f, str) or \
+ (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
+ (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+ m.save(f)
+ else:
+ ret = m.save_to_buffer()
+ f.write(ret)
+[INST] Trace a function and return an executable trace that will be optimized
+using just-in-time compilation.
+
+
+
+Tracing only correctly records functions and modules which are not data
+dependent and do not have
+any untracked external dependencies . If you trace such models, you may silently get
+incorrect results on subsequent invocations of the model. The tracer
+will try to emit warnings when doing something that may cause an
+incorrect trace to be produced.. [/INST] def trace(func,
+ example_inputs,
+ optimize=True,
+ check_trace=True,
+ check_inputs=None,
+ check_tolerance=1e-5,
+ _force_outplace=False):
+ if not _enabled:
+ return func
+ executor_options = {'optimize': bool(optimize)}
+ if isinstance(example_inputs, torch.Tensor):
+ example_inputs = (example_inputs,)
+ elif not isinstance(example_inputs, tuple):
+ example_inputs = tuple(example_inputs)
+ module = TopLevelTracedModule(func, **executor_options)
+ var_lookup_fn = _create_interpreter_name_lookup_fn(0)
+ module._create_method_from_trace('forward', func, example_inputs,
+ var_lookup_fn, _force_outplace)
+ if check_trace:
+ if check_inputs is not None:
+ _check_trace(check_inputs, func, executor_options, module, check_tolerance, _force_outplace)
+ else:
+ _check_trace([example_inputs], func, executor_options, module, check_tolerance, _force_outplace)
+ return module
+[INST] Get the WeakScriptModuleProxy corresponding to mod if it exists. [/INST] def _try_get_weak_module(mod):
+ if not isinstance(mod, Module):
+ return None
+ return _weak_modules.get(mod)
+[INST] Check if a type has been annotated with `weak_module`. [/INST] def _is_weak_type(cls):
+ return cls in _weak_types
+[INST] Calls script_method for each method on the type of the object passed in and
+returns the generated ScriptMethodStubs. [/INST] def _get_weak_stubs(cls):
+ stubs = []
+ for name in dir(cls):
+ func = get_function_from_type(cls, name)
+ if func in _weak_script_methods:
+ entry = _weak_script_methods[func]
+ stub = script_method(entry["original_method"], entry["rcb"])
+ stubs.append(stub)
+ return stubs
+[INST] Converts a weak module into a subclass of ScriptModule. [/INST] def _make_strong(mod):
+ if mod in _weak_modules:
+ return _weak_modules[mod]
+ stubs = _weak_types.get(type(mod))["method_stubs"]
+ if stubs is None:
+ stubs = _get_weak_stubs(type(mod))
+ _weak_types[type(mod)]["method_stubs"] = stubs
+ proxy = WeakScriptModuleProxy(mod, stubs)
+ _weak_modules[mod] = proxy
+ return proxy
+[INST] Gets a parameter from the request by name and casts by the desired type. [/INST] def param(name, cast=None):
+ value = None
+ if flask.request.json:
+ return flask.request.json.get(name, None)
+ if value is None:
+ value = flask.request.args.get(name, None)
+ if value is None and flask.request.form:
+ value = flask.request.form.get(name, None)
+ if cast and value is not None:
+ if cast is bool:
+ return value.lower() in ['true', 'yes', 'y', '1', '']
+ if cast is list:
+ return value.split(',') if len(value) > 0 else []
+ if cast is ndb.Key:
+ return ndb.Key(urlsafe=value)
+ return cast(value)
+ return value
+[INST] Helper function to make sure lambda functions are cached and not lost.. [/INST] def make_callback(func, *param):
+ return lambda: func(*param)
+[INST] Configure and compile the action value network. For now, just an
+MLP, but CNNs will be supported soon.. [/INST] def configure(self, config=None):
+ batch_shape = tuple([None] + self._state_dim)
+ if self._state_type == "vector":
+ config_ = {
+ "dense_dims": list(range(self._action_dim+1, self._state_dim[0]+1))[::-1],
+ "dense_activation": "relu",
+ "output_dim": self._action_dim,
+ "optimizer": {"Adam": {"learning_rate": 0.001}},
+ "loss": {"MeanSquaredError": {}},
+ "output_activation": "linear"
+ }
+ if config is not None:
+ config_.update(config)
+ network = dense.MLP(config_)
+ elif "tensor" in self._state_type:
+ raise NotImplementedError("Only supporting MLPs right now.")
+ else:
+ raise ValueError("Unrecognized state type.")
+ return train.model_init(network, config_["loss"], config_["optimizer"], batch_shape)
+[INST] Evaluate the q network (or frozen q network) at the current state array. [/INST] def q_eval(self, state, freeze_flag=False, reshape=None):
+ if freeze_flag:
+ assert self._frozen_q_network is not None, \
+ "Asked for the frozen q network evaluation, but no frozen q network available."
+ preds = self._frozen_q_network.predict(state)
+ else:
+ preds = self._q_network.predict(state)
+ if reshape is not None:
+ preds = preds.reshape(reshape)
+ return preds
+[INST] Greedy action for the provided state. [/INST] def greedy_policy(self, state):
+ return np.argmax(self.q_eval(state, reshape=self._action_dim))
+[INST] Return the total discounted reward for a given sequential
+list of rewards and a discount rate. [/INST] def _return(self, rewards):
+ return sum([rewards[i]*pow(self._discount, i) for i in range(len(rewards))])
+[INST] Parse the raw state, action, reward episode data into a batch
+for updating the action value network.. [/INST] def _batch(self, data):
+ states, actions, rewards = zip(*data)
+ state_array = np.array(states).reshape([len(data)] + self._state_dim)
+ target_array = np.zeros((len(data), self._action_dim))
+ for i, s in enumerate(states):
+ target_vector = self.q_eval(s, reshape=self._action_dim)
+ target_vector[actions[i]] = self._return(rewards[i:])
+ target_array[i,:] = target_vector
+ return state_array, target_array
+[INST] For each episode, play with the epsilon-greedy policy and record
+the states, actions, and rewards. Once the episode is up, use the
+true reward to prep a batch and update the action value network.. [/INST] def train(self, n_episodes, max_steps=1000, epsilon=0.01, epsilon_schedule=False, network=None):
+ if self._q_network is None:
+ self._q_network = self.configure(network)
+ max_reward = 0.0
+ for i in range(n_episodes):
+ if epsilon_schedule is not None:
+ if i % epsilon_schedule == 0:
+ schedule_step = int(i/epsilon_schedule)
+ epsilon = epsilon*pow(0.9, schedule_step)
+ obs = self.reset()
+ ep_reward = 0.0
+ tuple_batch = []
+ for j in range(max_steps):
+ self.render()
+ obs = obs.reshape([1] + self._state_dim)
+ action = self.policy(obs, epsilon=epsilon)
+ new_obs, reward, done, info = self.observe(action)
+ ep_reward += reward
+ tuple_batch.append((obs, action, reward))
+ if done:
+ max_reward = max(max_reward, ep_reward)
+ print("Current max reward: {}".format(max_reward))
+ break
+ else:
+ obs = new_obs
+ batch_x, batch_y = self._batch(tuple_batch)
+ loss, grads = train.grad(self._q_network, batch_x, batch_y)
+ updates = zip(grads, self._q_network.trainable_variables)
+ self._q_network.optimizer.apply_gradients(updates)
+ print("Current loss: {}".format(loss))
+[INST] Construct a batch for learning using the provided tuples and
+the action value function.. [/INST] def _batch(self, data, q_freeze):
+ states, actions, rewards, states_prime = zip(*data)
+ state_shape = [len(data)] + self._state_dim
+ action_shape = [len(data), self._action_dim]
+ state_array = np.array(states).reshape(state_shape)
+ state_prime_array = np.array(states_prime).reshape(state_shape)
+ target_array = self.q_eval(state_array, freeze_flag=q_freeze is not None,
+ reshape=action_shape)
+ target_prime_array = self.q_eval(state_prime_array, freeze_flag=q_freeze is not None,
+ reshape=action_shape)
+ for i in range(target_prime_array.shape[0]):
+ target_array[i, actions[i]] = rewards[i] + self._discount*np.max(target_prime_array[i,:])
+ return state_array, target_array
+[INST] For each episode, play with the epsilon-greedy policy and record
+the states, actions, and rewards. At each step, use the action value
+function and a set of random 4-tuples from the replay buffer to bootstrap
+the q-learning targets and update the network.. [/INST] def train(self, n_episodes, max_steps=1000, epsilon=0.01, epsilon_schedule=10, buffer_size=128,
+ batch_size=16, weight_freeze=None, network=None):
+ if self._q_network is None:
+ self._q_network = self.configure(network)
+ self._buffer_size = buffer_size
+ if weight_freeze is not None:
+ self._frozen_q_network = self.configure(network)
+ self._set_frozen_weights()
+ max_reward = 0.0
+ total_steps = 0
+ freeze_step = 0
+ loss = 0.0
+ for i in range(n_episodes):
+ if epsilon_schedule is not None:
+ if i % epsilon_schedule == 0:
+ epsilon = epsilon*0.999
+ obs = self.reset()
+ obs = obs.reshape([1] + self._state_dim)
+ ep_reward = 0.0
+ for j in range(max_steps):
+ total_steps += 1
+ self.render()
+ action = self.policy(obs, epsilon=epsilon)
+ new_obs, reward, done, info = self.observe(action)
+ new_obs = new_obs.reshape([1] + self._state_dim)
+ ep_reward += reward
+ self._buffer_add((obs, action, reward, new_obs))
+ tuple_batch = self._buffer_batch(batch_size)
+ batch_x, batch_y = self._batch(tuple_batch, weight_freeze)
+ loss, grads = train.grad(self._q_network, batch_x, batch_y)
+ updates = zip(grads, self._q_network.trainable_variables)
+ self._q_network.optimizer.apply_gradients(updates)
+ if done:
+ max_reward = max(max_reward, ep_reward)
+ print("Current max reward: {}".format(max_reward))
+ break
+ else:
+ obs = new_obs
+ print("Finished game {}...".format(i+1))
+ print("Current loss: {}".format(loss))
+ if weight_freeze is not None:
+ current_step = int(total_steps / weight_freeze)
+ if current_step > freeze_step:
+ self._set_frozen_weights()
+ freeze_step = current_step
+[INST] Image name format: 0001001.png, where first four digits represent identity
+and last four digits represent cameras. Camera 1&2 are considered the same
+view and camera 3&4 are considered the same view.. [/INST] def prepare_split(self):
+ if not osp.exists(self.split_path):
+ print('Creating 10 random splits of train ids and test ids')
+ img_paths = sorted(glob.glob(osp.join(self.dataset_dir, '*.jpg')))
+ img_list = []
+ pid_container = set()
+ camid_container = set()
+ for img_path in img_paths:
+ img_name = osp.basename(img_path)
+ pid, camid, datetime, sn = img_name.split('-')
+ camid = int(camid)
+ if '_' in pid:
+ pid, vid = pid.split('_')
+ pid = int(pid) - 1
+ img_list.append((img_path, pid, vid, camid))
+ else:
+ pid = int(pid) - 1
+ img_list.append((img_path, pid, camid))
+ pid_container.add(pid)
+ camid_container.add(camid)
+ num_pids = len(pid_container)
+ num_train_pids, num_train_camids = num_pids // 2, len(camid_container) // 2
+ splits = []
+ for _ in range(10):
+ order = np.arange(num_pids)
+ np.random.shuffle(order)
+ train_idxs = np.sort(order[:num_train_pids])
+ idx2label = {idx: label for label, idx in enumerate(train_idxs)}
+ order = np.array(list(camid_container))
+ np.random.shuffle(order)
+ cam_a_ids = np.sort(order[:num_train_camids])
+ train, test_a, test_b = [], [], []
+ for attr in img_list:
+ if attr[1] in train_idxs:
+ attr = list(attr)
+ attr[1] = idx2label[attr[1]]
+ train.append(tuple(attr))
+ else:
+ if attr[-1] in cam_a_ids:
+ test_a.append(attr)
+ else:
+ test_b.append(attr)
+ split = {
+ 'train': train,
+ 'query': test_a,
+ 'gallery': test_b,
+ 'num_train_pids': num_train_pids,
+ 'num_query_pids': num_pids - num_train_pids,
+ 'num_gallery_pids': num_pids - num_train_pids
+ }
+ splits.append(split)
+ split = {
+ 'train': train,
+ 'query': test_b,
+ 'gallery': test_a,
+ 'num_train_pids': num_train_pids,
+ 'num_query_pids': num_pids - num_train_pids,
+ 'num_gallery_pids': num_pids - num_train_pids
+ }
+ splits.append(split)
+ print('Totally {} splits are created'.format(len(splits)))
+ write_json(splits, self.split_path)
+ print('Split file saved to {}'.format(self.split_path))
+[INST] A function wrapper for building a model.. [/INST] def build_model(name, num_classes, loss='softmax', pretrained=True, use_gpu=True, backbone='resnet50'):
+ avai_models = list(__model_factory.keys())
+ if name not in avai_models:
+ raise KeyError('Unknown model: {}. Must be one of {}'.format(name, avai_models))
+ return __model_factory[name](
+ num_classes=num_classes,
+ loss=loss,
+ pretrained=pretrained,
+ use_gpu=use_gpu,
+ backbone=backbone,
+ )
+[INST] Initializes model with pretrained weights.
+Layers that don't match with pretrained layers in name or size are kept unchanged.. [/INST] def init_pretrained_weights(model, key=''):
+ import os
+ import errno
+ import gdown
+ from collections import OrderedDict
+ def _get_torch_home():
+ ENV_TORCH_HOME = 'TORCH_HOME'
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+ DEFAULT_CACHE_DIR = '~/.cache'
+ torch_home = os.path.expanduser(
+ os.getenv(
+ ENV_TORCH_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+ )
+ )
+ )
+ return torch_home
+ torch_home = _get_torch_home()
+ model_dir = os.path.join(torch_home, 'checkpoints')
+ try:
+ os.makedirs(model_dir)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ pass
+ else:
+ raise
+ filename = key + '_imagenet.pth'
+ cached_file = os.path.join(model_dir, filename)
+ if not os.path.exists(cached_file):
+ gdown.download(pretrained_urls[key], cached_file, quiet=False)
+ state_dict = torch.load(cached_file)
+ model_dict = model.state_dict()
+ new_state_dict = OrderedDict()
+ matched_layers, discarded_layers = [], []
+ for k, v in state_dict.items():
+ if k.startswith('module.'):
+ k = k[7:]
+ if k in model_dict and model_dict[k].size() == v.size():
+ new_state_dict[k] = v
+ matched_layers.append(k)
+ else:
+ discarded_layers.append(k)
+ model_dict.update(new_state_dict)
+ model.load_state_dict(model_dict)
+ if len(matched_layers) == 0:
+ warnings.warn(
+ 'The pretrained weights from "{}" cannot be loaded, '
+ 'please check the key names manually '
+ '(** ignored and continue **)'.format(cached_file)
+ )
+ else:
+ print(
+ 'Successfully loaded imagenet pretrained weights from "{}"'.
+ format(cached_file)
+ )
+ if len(discarded_layers) > 0:
+ print(
+ '** The following layers are discarded '
+ 'due to unmatched keys or layer size: {}'.
+ format(discarded_layers)
+ )
+[INST] A unified pipeline for training and evaluating a model.. [/INST] def run(self, aim_sess, save_dir='log', max_epoch=0, start_epoch=0, fixbase_epoch=0, open_layers=None,
+ start_eval=0, eval_freq=-1, test_only=False, print_freq=10,
+ dist_metric='euclidean', normalize_feature=False, visrank=False, visrankactiv=False, visrankactivthr=False,
+ maskthr=0.7, visrank_topk=10, use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False, visactmap=False,
+ vispartmap=False, visdrop=False, visdroptype='random'):
+ if visrank and not test_only:
+ raise ValueError('visrank=True is valid only if test_only=True')
+ if visrankactiv and not test_only:
+ raise ValueError('visrankactiv=True is valid only if test_only=True')
+ if visrankactivthr and not test_only:
+ raise ValueError('visrankactivthr=True is valid only if test_only=True')
+ if visdrop and not test_only:
+ raise ValueError('visdrop=True is valid only if test_only=True')
+ if test_only:
+ self.test(
+ 0,
+ dist_metric=dist_metric,
+ normalize_feature=normalize_feature,
+ visrank=visrank,
+ visrankactiv=visrankactiv,
+ visrank_topk=visrank_topk,
+ save_dir=save_dir,
+ use_metric_cuhk03=use_metric_cuhk03,
+ ranks=ranks,
+ rerank=rerank,
+ maskthr=maskthr,
+ visrankactivthr=visrankactivthr,
+ visdrop=visdrop,
+ visdroptype=visdroptype
+ )
+ return
+ if self.writer is None:
+ self.writer = SummaryWriter(log_dir=save_dir)
+ if visactmap:
+ self.visactmap(self.test_loader, save_dir, self.datamanager.width, self.datamanager.height, print_freq)
+ return
+ if vispartmap:
+ self.vispartmap(self.test_loader, save_dir, self.datamanager.width, self.datamanager.height, print_freq)
+ return
+ time_start = time.time()
+ self.max_epoch = max_epoch
+ print('=> Start training')
+ for epoch in range(start_epoch, max_epoch):
+ losses = self.train(print_freq=print_freq, fixbase_epoch=fixbase_epoch,
+ open_layers=open_layers, epoch=epoch)
+ for key in losses.meters.keys():
+ aim_sess.track(losses.meters[key].avg, name=key, epoch=epoch, subset='train')
+ if (epoch+1) >= start_eval and eval_freq > 0 and (epoch+1) % eval_freq == 0 and (epoch+1) != max_epoch:
+ rank1 = self.test(
+ epoch,
+ aim_sess=aim_sess,
+ dist_metric=dist_metric,
+ normalize_feature=normalize_feature,
+ visrank=visrank,
+ visrankactiv=visrankactiv,
+ visrank_topk=visrank_topk,
+ save_dir=save_dir,
+ use_metric_cuhk03=use_metric_cuhk03,
+ ranks=ranks,
+ rerank=rerank,
+ maskthr=maskthr,
+ visrankactivthr=visrankactivthr
+ )
+ self._save_checkpoint(epoch, rank1, save_dir)
+ if max_epoch > 0:
+ print('=> Final test')
+ rank1 = self.test(
+ epoch,
+ dist_metric=dist_metric,
+ normalize_feature=normalize_feature,
+ visrank=visrank,
+ visrankactiv=visrankactiv,
+ visrank_topk=visrank_topk,
+ save_dir=save_dir,
+ use_metric_cuhk03=use_metric_cuhk03,
+ ranks=ranks,
+ rerank=rerank,
+ maskthr=maskthr,
+ visrankactivthr=visrankactivthr
+ )
+ self._save_checkpoint(epoch, rank1, save_dir)
+ elapsed = round(time.time() - time_start)
+ elapsed = str(datetime.timedelta(seconds=elapsed))
+ print('Elapsed {}'.format(elapsed))
+ if self.writer is None:
+ self.writer.close()
+[INST] Visualizes CNN activation maps to see where the CNN focuses on to extract features.
+This function takes as input the query images of target datasets
+
+Zagoruyko and Komodakis. Paying more attention to attention: Improving the
+performance of convolutional neural networks via attention transfer. ICLR, 2017
+Zhou et al. Omni-Scale Feature Learning for Person Re-Identification.. [/INST] def visactmap(self, testloader, save_dir, width, height, print_freq):
+ self.model.eval()
+ imagenet_mean = [0.485, 0.456, 0.406]
+ imagenet_std = [0.229, 0.224, 0.225]
+ for target in list(testloader.keys()):
+ queryloader = testloader[target]['query']
+ actmap_dir = osp.join(save_dir, 'actmap_'+target)
+ mkdir_if_missing(actmap_dir)
+ print('Visualizing activation maps for {} ...'.format(target))
+ for batch_idx, data in enumerate(queryloader):
+ imgs, paths = data[0], data[3]
+ if self.use_gpu:
+ imgs = imgs.cuda()
+ try:
+ outputs = self.model(imgs, return_featuremaps=True)
+ except TypeError:
+ raise TypeError('forward() got unexpected keyword argument "return_featuremaps". '
+ 'Please add return_featuremaps as an input argument to forward(). When '
+ 'return_featuremaps=True, return feature maps only.')
+ if outputs.dim() != 4:
+ raise ValueError('The model output is supposed to have '
+ 'shape of (b, c, h, w), i.e. 4 dimensions, but got {} dimensions. '
+ 'Please make sure you set the model output at eval mode '
+ 'to be the last convolutional feature maps'.format(outputs.dim()))
+ outputs = (outputs**2).sum(1)
+ b, h, w = outputs.size()
+ outputs = outputs.view(b, h*w)
+ outputs = F.normalize(outputs, p=2, dim=1)
+ outputs = outputs.view(b, h, w)
+ if self.use_gpu:
+ imgs, outputs = imgs.cpu(), outputs.cpu()
+ for j in range(outputs.size(0)):
+ path = paths[j]
+ imname = osp.basename(osp.splitext(path)[0])
+ img = imgs[j, ...]
+ for t, m, s in zip(img, imagenet_mean, imagenet_std):
+ t.mul_(s).add_(m).clamp_(0, 1)
+ img_np = np.uint8(np.floor(img.numpy() * 255))
+ img_np = img_np.transpose((1, 2, 0))
+ am = outputs[j, ...].numpy()
+ am = cv2.resize(am, (width, height))
+ am = 255 * (am - np.max(am)) / (np.max(am) - np.min(am) + 1e-12)
+ am = np.uint8(np.floor(am))
+ am = cv2.applyColorMap(am, cv2.COLORMAP_JET)
+ overlapped = img_np * 0.4 + am * 0.6
+ overlapped[overlapped>255] = 255
+ overlapped = overlapped.astype(np.uint8)
+ grid_img = 255 * np.ones((height, 3*width+2*GRID_SPACING, 3), dtype=np.uint8)
+ grid_img[:, :width, :] = img_np[:, :, ::-1]
+ grid_img[:, width+GRID_SPACING: 2*width+GRID_SPACING, :] = am
+ grid_img[:, 2*width+2*GRID_SPACING:, :] = overlapped
+ cv2.imwrite(osp.join(actmap_dir, imname+'.jpg'), grid_img)
+ if (batch_idx+1) % print_freq == 0:
+ print('- done batch {}/{}'.format(batch_idx+1, len(queryloader)))
+[INST] Two-stepped transfer learning.
+The idea is to freeze base layers for a certain number of epochs
+and then open all layers for training.. [/INST] def two_stepped_transfer_learning(self, epoch, fixbase_epoch, open_layers, model=None):
+ model = self.model if model is None else model
+ if model is None:
+ return
+ if (epoch + 1) <= fixbase_epoch and open_layers is not None:
+ print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch + 1, fixbase_epoch))
+ open_specified_layers(model, open_layers)
+ else:
+ open_all_layers(model)
+[INST] Lemmatizes tokens using NLTK's lemmatizer tool.. [/INST] def lemmatize_lyrics(tokens):
+ lemmatized = [nltk.stem.WordNetLemmatizer().lemmatize(t) for t in tokens]
+ return lemmatized
+[INST] Create a new study from a name.
+If no name is specified, the storage class generates a name.
+The returned study ID is unique among all current and deleted studies.. [/INST] def create_new_study(self, study_name: Optional[str] = None) -> int:
+ return 1
+[INST] Create and add a new trial to a study.
+The returned trial ID is unique among all current and deleted trials.. [/INST] def create_new_trial(self, study_id: int, template_trial: Optional[FrozenTrial] = None) -> int:
+ return 1
+[INST] process data before fitting into machine learning models.
+Parameters
+
+collection : list
+The out-of-sample set.
+
+order : int, optional
+Order of the signature.
+Default is 2.
+
+int
+the length of data considered for each patient.
+Default is 20.
+
+True or False
+whether or not the piece of data being standardised
+
+True or False
+whether or not the piece of data being counted for missing value
+
+True or False
+whether or not the missing values in the piece of data is filled by feedforward function
+
+True or False
+whether or not the missing values in the piece of data is to be cleaned
+
+True or False
+whether or not the missing values at the beginning of the piece of data is filled by the average
+
+True or False
+whether or not the navid method (by calculating the mean) is implemented
+
+True or False
+whether or not the time is included
+
+True or False
+whether or not the piece of data is being cumulated
+
+Returns
+
+x,y in appropriate form. [/INST] def data_model(collection, order=2,minlen=20, standardise=True, count=True, feedforward=True,\
+ missing_clean=False,start_average=False, naive=False,time=True,cumsum=True):
+ x=[]
+ y=[]
+ for participant in collection:
+ par_data=participant.data
+ if missing_clean:
+ participant_data=clean_missing1(par_data)
+ elif count:
+ if standardise:
+ participant_data=normalise(feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum),\
+ cumsum=cumsum,\
+ count=count,\
+ time=time)
+ else:
+ participant_data=feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum)
+ else:
+ if standardise:
+ if feedforward:
+ participant_data=normalise(feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum),\
+ cumsum=cumsum,\
+ count=count,\
+ time=time)
+ else:
+ participant_data=normalise(list_to_num(par_data),\
+ count=count,\
+ time=time,\
+ cumsum=cumsum)
+ else:
+ if feedforward:
+ participant_data=feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum)
+ else:
+ participant_data=list_to_num(par_data)
+ if participant_data is not False:
+ if naive:
+ if missing_clean:
+ x.append(np.sum(participant_data,axis=0)/minlen)
+ else:
+ x.append(np.sum(list_to_num(par_data),axis=0)/minlen)
+ else:
+ x.append(iisignature.sig(participant_data, order))
+ y.append(participant.diagnosis)
+ return x,y
+[INST] trying models with different parameters in len(set) or order in one-go.
+Parameters
+
+class of participants for the corresponding 2 tests
+
+num
+size of each participant data.
+training : scalar
+Training set proportional.
+
+Returns
+
+average accuracy for each case. [/INST] def model_onego(Participants, minlen, training=0.7, sample_size=50, \
+ start_average=False, cv=False,cumsum=True, class_=None):
+ random.seed(42)
+ standardise_set=[False,True, True]
+ count_set=[False, True,True]
+ feedforward_set=[False,True,True]
+ naive_set=[True, False,False]
+ time_set=[False, True,True]
+ missingclean_set=[False, False, False]
+ order_set=[None, 2,3]
+ col_num=len(standardise_set)
+ mean_accuracy=np.zeros(col_num)
+ mean_CM=[np.zeros((3,3)) for j in range(col_num)]
+ accuracy_collections=[[] for j in range(col_num)]
+ for i in range(sample_size):
+ train_set, test_set = buildData(Participants,\
+ training=training,\
+ minlen=minlen,\
+ class_=class_)
+ for j in range(col_num):
+ X_train,y_train=data_model(train_set,minlen=minlen,\
+ order=order_set[j],\
+ standardise=standardise_set[j],\
+ count=count_set[j],\
+ missing_clean=missingclean_set[j], \
+ start_average=start_average,\
+ feedforward=feedforward_set[j],\
+ naive=naive_set[j],\
+ time=time_set[j],\
+ cumsum=cumsum)
+ X_test,y_test=data_model(test_set,minlen=minlen, \
+ order=order_set[j],\
+ standardise=standardise_set[j], \
+ count=count_set[j],\
+ missing_clean=missingclean_set[j], \
+ start_average=start_average,\
+ feedforward=feedforward_set[j],\
+ naive=naive_set[j],\
+ time=time_set[j],\
+ cumsum=cumsum)
+ if cv:
+ accuracy= rf_cv(X_train,y_train,X_test,y_test,threshold)
+ else:
+ CM,accuracy= rf_model(X_train,y_train,X_test,y_test,threshold,regression=regression)
+ mean_CM[j]+=CM
+ mean_accuracy[j]+=accuracy/sample_size
+ accuracy_collections[j].append(accuracy)
+ return mean_CM, mean_accuracy,accuracy_collections
+[INST] Normalises the data of the patient with missing count.
+Parameters
+
+data : two dim data, consisting of ALTMAN and QIDS scores
+
+Returns
+
+data that are normalised and cumulated.. [/INST] def normalise(data,cumsum=True,count=True,time=True):
+ normalised_data=np.zeros((data.shape[0],data.shape[1]))
+ scoreMAX=[20,27]
+ scoreMIN=[0,0]
+ if count:
+ if time:
+ len_data=data.shape[1]-2
+ else:
+ len_data=data.shape[1]-1
+ else:
+ len_data=data.shape[1]
+ for i in range(len_data):
+ for j in range(data.shape[0]):
+ normalised_data[j][i]=standardise1(data[j][i],scoreMIN[i],scoreMAX[i])
+ if j>0:
+ normalised_data[j][i]+=normalised_data[j-1][i]
+ if count:
+ if time:
+ for kk in range(len(data[:,-1])):
+ normalised_data[kk][-1]=standardise(data[kk][-1],0,data[-1][-1])
+ if cumsum and data[-1][-2]!=0:
+ for jj in range(len(data[:,-2])):
+ normalised_data[jj][-2]=standardise(data[jj][-2],0,data[-1][-2])
+ else:
+ if cumsum and data[-1][-1]!=0:
+ for jjj in range(len(data[:,-1])):
+ normalised_data[jjj][-1]=standardise(data[jjj][-1],0,data[-1][-1])
+ return normalised_data
+[INST] Tests the model against a particular participant.
+Parameters
+
+path : str
+Path of the pickle file containing the streams
+of data from the participant.
+order : int, optional
+Order of the signature.
+Default is 2.
+
+int; the length of data considered for each patient.
+Default is 20.
+
+data whether or not standardised
+Default True
+count: missing data count or not
+Default True
+missing_clean: rulling out missing data or not
+Default False
+start_average: if the firt element is missing, replace it with average or 0
+Default False
+naive: using merely mean value of each dimension
+
+whether or not cumulate the data with time
+Default True
+
+Returns
+
+list
+3-dimensional vector indicating how often the participant
+has buckets that were classified in each clinical group.. [/INST] def test(self,\
+ path,\
+ order=2,\
+ standardise=True,\
+ count=True,\
+ feedforward=True,\
+ missing_clean=False,\
+ start_average=False,\
+ naive=False,\
+ time=False,\
+ cumsum=True):
+ file = open(path,'rb')
+ collection = pickle.load(file)
+ file.close()
+ x,y =data_model(collection,\
+ order=order,\
+ minlen=minlen,\
+ standardise=standardise,\
+ count=count,\
+ missing_clean=missing_clean,\
+ start_average=start_average, \
+ feedforward=feedforward,\
+ naive=naive,\
+ time=time,\
+ cumsum=cumsum)
+ predicted = self.rf.predict(x)
+ vector = np.zeros(3)
+ for i in range(len(x)):
+ vector[int(predicted[i])] += 1
+ vector /= float(len(x))
+ return vector
+[INST] Trains the model, as specified in the original paper.
+Parameters
+
+path : str
+Path of the pickle file containing the streams
+of data from the participant.
+order : int, optional
+Order of the signature.
+Default is 2.
+
+int; the length of data considered for each patient.
+Default is 20.
+
+data whether or not standardised
+Default True
+count: missing data count or not
+Default True
+missing_clean: rulling out missing data or not
+Default False
+start_average: if the firt element is missing, replace it with average or 0
+Default False
+naive: using merely mean value of each dimension
+
+whether or not cumulate the data with time
+Default True
+
+Returns
+
+Model
+Trained model.. [/INST] def train(path,\
+ order=2,\
+ minlen=20,\
+ standardise=True,\
+ count=True,\
+ feedforward=True,\
+ missing_clean=False,\
+ start_average=False, \
+ naive=False,\
+ time=False,\
+ cumsum=True):
+ file = open(path,'rb')
+ collection = pickle.load(file)
+ file.close()
+ random_state=42
+ x,y=data_model(collection,\
+ order=order,\
+ minlen=minlen,\
+ standardise=standardise,\
+ count=count,\
+ missing_clean=missing_clean,\
+ start_average=start_average, \
+ feedforward=feedforward,\
+ naive=naive,\
+ time=time,\
+ cumsum=cumsum)
+ reg = OneVsRestClassifier(RandomForestClassifier(n_estimators=1500,random_state=random_state))
+ reg.fit(x, y)
+ return Model(rf)
+[INST] Saves as a pickle file the training or testing sets.
+Parameters
+
+coll : list
+List of participants that should be exported. If the
+length of the list is 1, the set is the out-of-sample
+set. Otherwise, it is the training set.
+ID : int
+A random ID that will be used to export the file.
+
+Number of observations of each stream of data
+
+
+
+directory to save the dataset. [/INST] def export(coll,\
+ ID,\
+ sample_length=20,\
+ test_size=5,\
+ path_save="./dataset_spectrum/"):
+ try:
+ os.mkdir(path_save)
+ print("Directory " , path_save , " Created ")
+ except FileExistsError:
+ continue
+ if not os.path.exists(path_save+str(ID)):
+ os.makedirs(path_save+str(ID))
+ dataset=[]
+ l=copy.deepcopy(coll)
+ if len(coll)==1:
+ setType="test_set"
+ for participant in l:
+ min2=np.minimum(len(participant.data[0]),\
+ len(participant.data[1]))
+ random_start=np.random.randint(min2-sample_length,\
+ size=test_size)
+ for i in range(test_size):
+ p = Participant([participant.data[0][random_start[i]:random_start[i]+sample_length],\
+ participant.data[1][random_start[i]:random_start[i]+sample_length]],\
+ [participant.time[0][random_start[i]:random_start[i]+sample_length],\
+ participant.time[1][random_start[i]:random_start[i]+sample_length]],\
+ participant.idNumber,\
+ participant.diagnosis,\
+ None)
+ dataset.append(p)
+ else:
+ setType="train_set"
+ for participant in l:
+ min2=np.minimum(len(participant.data[0]),\
+ len(participant.data[1]))
+ random_start=random.randint(0,min2-sample_length)
+ p = Participant([participant.data[0][random_start:random_start+sample_length],\
+ participant.data[1][random_start:random_start+sample_length]],\
+ [participant.time[0][random_start:random_start+sample_length],\
+ participant.time[1][random_start:random_start+sample_length]],\
+ participant.idNumber,\
+ participant.diagnosis,\
+ None)
+ dataset.append(p)
+ filehandler = open(path_save+str(ID)+"/"+setType+".obj","wb")
+ pickle.dump(dataset,filehandler)
+ filehandler.close()
+[INST] trim healthy data such that plot can be seen.
+Parameters
+
+col : a collection of healthy data
+
+Returns
+
+list of str
+List of data can has been trim by threshold 0.03.. [/INST] def trim_triangle(col,index=1):
+ try1=copy.deepcopy(col)
+ for md in try1:
+ if md[0]==0.0:
+ if md[int(index)]<0.95:
+ md[0]=0.03
+ md[2]-=0.03
+ return try1
+[INST] Plots, given a set of scores, the density map on a triangle.
+Parameters
+
+scores : list
+List of scores, where each score is a 3-dimensional list.. [/INST] def plotDensityMap(scores):
+ TRIANGLE = np.array([[math.cos(math.pi*0.5), math.sin(math.pi*0.5)],
+ [math.cos(math.pi*1.166), math.sin(math.pi*1.166)],
+ [math.cos(math.pi*1.833), math.sin(math.pi*1.833)]])
+ pointsX = [score.dot(TRIANGLE)[0] for score in scores]
+ pointsY = [score.dot(TRIANGLE)[1] for score in scores]
+ vertices = []
+ vertices.append(np.array([1,0,0]).dot(TRIANGLE))
+ vertices.append(np.array([0,1,0]).dot(TRIANGLE))
+ vertices.append(np.array([0,0,1]).dot(TRIANGLE))
+ for i in range(3):
+ p1 = vertices[i]
+ if i == 2:
+ p2 = vertices[0]
+ else:
+ p2 = vertices[i+1]
+ c = 0.5 * (p1 + p2)
+ plt.plot([p1[0], p2[0]], [p1[1], p2[1]], color='k', linestyle='-', linewidth=2)
+ plt.plot([0, c[0]], [0, c[1]], color='k', linestyle='-', linewidth=1)
+ ax = plt.gca()
+ ax.set_xlim([-1.2, 1.32])
+ ax.set_ylim([-0.7,1.3])
+ ax.text(0.8, -0.6, 'Bipolar')
+ ax.text(-1.1, -0.6, 'Healthy')
+ ax.text(-0.15, 1.05, 'Borderline')
+ data = [[pointsX[i], pointsY[i]] for i in range(len(pointsX))]
+ H, _, _=np.histogram2d(pointsX,pointsY,bins=40,normed=True)
+ norm=H.sum()
+ contour1=0.75
+ target1=norm*contour1
+ def objective(limit, target):
+ w = np.where(H>limit)
+ count = H[w]
+ return count.sum() - target
+ level1 = scipy.optimize.bisect(objective, H.min(), H.max(), args=(target1,))
+ levels = [level1]
+ data = np.array(data)
+ sns.kdeplot(np.array(pointsX), np.array(pointsY), shade=True, ax=ax)
+ sns.kdeplot(np.array(pointsX), np.array(pointsY), n_levels=3, ax=ax, cmap="Reds")
+ plt.show()
+[INST] process data before fitting into machine learning models.
+Parameters
+
+collection : list
+The out-of-sample set.
+order : int, optional
+Order of the signature.
+Default is 2.
+
+int
+the length of data considered for each patient.
+Default is 20.
+
+True or False
+whether or not the piece of data being standardised
+count: True or False
+whether or not the piece of data being counted for missing value
+
+True or False
+whether or not the missing values in the piece of data is filled by feedforward function
+
+True or False
+whether or not the missing values in the piece of data is to be cleaned
+
+True or False
+whether or not the missing values at the beginning of the piece of data is filled by the average
+naive: True or False
+whether or not the navid method (by calculating the mean) is implemented
+time: True or False
+whether or not the time is included
+
+True or False
+whether or not the piece of data is being cumulated
+
+
+
+Returns
+
+x,y in appropriate form. [/INST] def data_model(collection,\
+ minlen=20,\
+ order=2,\
+ standardise=True,\
+ count=True,\
+ feedforward=True,\
+ missing_clean=False,\
+ start_average=False, \
+ naive=False,\
+ time=True,\
+ cumsum=True,\
+ feature=int(0)):
+ x=[]
+ y=[]
+ for participant in collection:
+ if dual:
+ par_data=participant.data
+ else:
+ par_data=[participant.data[feature]]
+ if missing_clean:
+ participant_data=clean_missing1(par_data)
+ elif count:
+ if standardise:
+ participant_data=normalise(feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum),\
+ count=count,\
+ time=time)
+ else:
+ participant_data=feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum)
+ else:
+ if standardise:
+ if feedforward:
+ participant_data=normalise(feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum),\
+ count=count,\
+ time=time)
+ else:
+ participant_data=normalise(list_to_num(par_data),\
+ count=count,\
+ time=time)
+ else:
+ if feedforward:
+ participant_data=feed_forward(par_data,\
+ count=count,\
+ time=time,\
+ start_average=start_average,\
+ cumsum=cumsum)
+ else:
+ participant_data=list_to_num(par_data)
+ if participant_data is not False:
+ if naive:
+ if len(participant_data)==0:
+ x.append([-1])
+ else:
+ x.append([np.sum(participant_data)/minlen])
+ else:
+ x.append(iisignature.sig(participant_data, order))
+ y.append(participant.nextdata[feature])
+ return x,y
+[INST] Computing the mean absolute error for two lists of lists c and d. [/INST] def MAE(c,d,feature=int(0),scaling=False):
+ a = [item for sublist in c for item in sublist]
+ b= [item for sublist in d for item in sublist]
+ if not scaling:
+ a=scaling_list(a1,feature=feature)
+ b=scaling_list(b1,feature=feature)
+ if len(a)!=len(b):
+ print("something is wrong.")
+ else:
+ sd=np.array([np.abs(a[i]-b[i]) for i in range(len(a))])
+ return len(np.where(sd<1)[0])/len(sd), np.mean(sd)
+[INST] trying models (stateMRSPM, level 2, naive model) with different parameters in len(set) or order in one-go.
+Parameters
+
+class of participants for the corresponding 2 tests
+
+which class we are working on (0/1/2)
+
+minlen_set : list
+size of each participant data.
+
+training : scalar
+Training set proportional.
+
+number for loop
+Default is 50
+
+a parameter in data-transforming step, can be set to True for all models
+
+Returns
+
+average accuracy for each case. [/INST] def comprehensive_model(Participants,\
+ class_,\
+ minlen=10,\
+ training=0.7,\
+ sample_size=10,\
+ cumsum=True):
+ random.seed(42)
+ random_state=42
+ standardise_set=[False, True]
+ count_set=[False, True]
+ feedforward_set=[False, True]
+ naive_set=[True, False]
+ time_set=[False, True]
+ missing_clean_set=[False,False]
+ order_set=[None,int(3)]
+ feature_set=[int(0),int(1)]
+ y_collections=[[] for i in range(int(len(feature_set)*len(standardise_set)))]
+ y_pred_collections=[[] for i in range(int(len(feature_set)*len(standardise_set)))]
+ accuracy=np.zeros((len(feature_set),len(standardise_set)))
+ for i in range(sample_size):
+ train_set, test_set = buildData_prediction(Participants,\
+ training=training,\
+ minlen=minlen,\
+ class_=class_,\
+ regression=regression)
+ for j in range(len(standardise_set)):
+ for ii in range(len(feature_set)):
+ X_train,y_train=data_model(train_set, \
+ order=order_set[j],\
+ standardise=standardise_set[j],\
+ count=count_set[j],\
+ feedforward=feedforward_set[j],\
+ missing_clean=missing_clean_set[j],\
+ time=time_set[j],\
+ naive=naive_set[j],\
+ cumsum=cumsum,\
+ feature=feature_set[ii])
+ X_test,y_test=data_model(test_set,\
+ order=order_set[j],\
+ standardise=standardise_set[j],\
+ count=count_set[j],\
+ feedforward=feedforward_set[j],\
+ missing_clean=missing_clean_set[j],\
+ time=time_set[j],\
+ naive=naive_set[j],\
+ cumsum=cumsum,\
+ feature=feature_set[ii])
+ current_index=int(j*len(feature_set)+ii)
+ y_test_pred_=rf_nextdaymodel(X_train,y_train,X_test,\
+ regression=regression,\
+ feature=feature_set[ii] )
+ y_pred_collections[current_index].append(y_test_pred_)
+ y_collections[current_index].append(y_test)
+ accuracy[ii,j]+=accuracy_(y_test, y_test_pred_)/sample_size
+ return accuracy
+[INST] trying models (stateMRSPM, level 2, naive model) with different parameters in len(set) or order in one-go.
+Parameters
+
+class of participants for the corresponding 2 tests
+
+which class we are working on (0/1/2)
+
+minlen_set : list
+size of each participant data.
+
+training : scalar
+Training set proportional.
+
+number for loop
+Default is 50
+
+if severity of symptoms being considered
+Defalt: False (directly predicting raw scores),
+otherwise True (scaling the predicted score to severity of symptoms)
+
+a parameter in data-transforming step, can be set to True for all models
+
+Returns
+
+average accuracy for each case
+mae: average MAE for each case. [/INST] def comprehensive_nomissing_model(Participants,\
+ class_,\
+ minlen=10,\
+ training=0.7,\
+ sample_size=10,\
+ scaling=False,\
+ cumsum=True):
+ random.seed(42)
+ random_state=42
+ standardise_set=[False, True, True]
+ count_set=[False, True, True]
+ feedforward_set=[False, True, True]
+ naive_set=[True, False, False]
+ time_set=[False, True, True]
+ missing_clean_set=[False,False,False]
+ order_set=[None,int(2),int(3)]
+ feature_set=[int(0),int(1)]
+ y_collections=[[] for i in range(int(len(feature_set)*len(standardise_set)))]
+ y_pred_collections=[[] for i in range(int(len(feature_set)*len(standardise_set)))]
+ accuracy=np.zeros((len(feature_set),len(standardise_set)))
+ mse=np.zeros((len(feature_set),len(standardise_set)))
+ mae=np.zeros((len(feature_set),len(standardise_set)))
+ r2=np.zeros((len(feature_set),len(standardise_set)))
+ for i in range(sample_size):
+ train_set, test_set = buildData_prediction_nomissing(Participants,\
+ training=training,\
+ minlen=minlen,\
+ class_=class_)
+ for j in range(len(standardise_set)):
+ for ii in range(len(feature_set)):
+ X_train,y_train=data_model(train_set,\
+ order=order_set[j],\
+ standardise=standardise_set[j],\
+ count=count_set[j],\
+ feedforward=feedforward_set[j],\
+ missing_clean=missing_clean_set[j],\
+ time=time_set[j],\
+ naive=naive_set[j],\
+ feature=feature_set[ii],\
+ cumsum=cumsum)
+ X_test,y_test=data_model(test_set,\
+ order=order_set[j],\
+ standardise=standardise_set[j],\
+ count=count_set[j],\
+ feedforward=feedforward_set[j],\
+ missing_clean=missing_clean_set[j],\
+ time=time_set[j],\
+ naive=naive_set[j],\
+ feature=feature_set[ii],\
+ cumsum=cumsum)
+ current_index=int(j*len(feature_set)+ii)
+ y_test_pred_=rf_nextdaymodel(X_train,y_train,X_test,\
+ regression=True,\
+ feature=feature_set[ii])
+ if scaling:
+ y_test_pred_=scaling_list(y_test_pred_,\
+ feature=feature_set[ii])
+ y_test=scaling_list(y_test,\
+ feature=feature_set[ii])
+ y_pred_collections[current_index].append(y_test_pred_)
+ y_collections[current_index].append(y_test)
+ for j in range(len(standardise_set)):
+ for ii in range(len(feature_set)):
+ current_index=int(j*len(feature_set)+ii)
+ accuracy[ii,j],mae[ii,j]=MAE(y_pred_collections[current_index],\
+ y_collections[current_index],\
+ feature=feature_set[ii])
+ return accuracy, mae
+[INST] data process to make class Participant
+Parameters
+
+List of corresponding test1 & test2, test1 time & test 2 time, id list
+
+Returns
+
+class of participants for the corresponding 2 tests. [/INST] def make_classes(participants_data_list,participants_time_list,participants_list):
+ num=len(participants_list)
+ Participants=sorted(list(csv.reader(open("./source_data/patients.csv"))))
+ participants=[]
+ t=0
+ for i in range(num):
+ n=int(participants_list[i])
+ for l in Participants:
+ if int(l[0])==n:
+ if not l[1].isdigit():
+ print(n)
+ break
+ data=[participants_data_list[j][i].astype(int) for j in range(len(participants_data_list)) ]
+ time=[participants_time_list[j][i].astype(int) for j in range(len(participants_time_list)) ]
+ bp0 = int(l[1])
+ bp = {1: 2, 2: 0, 3: 1}[bp0]
+ participant=Participant(data, time, n, bp,None)
+ participants.append(participant)
+ t+=1
+ break
+ return participants
+[INST] cleaning redundant data: if two data are stored in the same day,
+and scores are the same, keep one; if one score is recorded as missing, then keep the
+other one.
+
+Parameters
+
+class participant data
+
+Returns
+
+shortened participant data. [/INST] def cleaning_same_data(Participants):
+ Pars=copy.deepcopy(Participants)
+ n=len(Pars[0].data)
+ for par in Pars:
+ for i in range(n):
+ t=0
+ total_len=len(par.time[i])
+ for j in range(total_len)[:-1]:
+ while int(par.time[i][j+1])==int(par.time[i][j]):
+ if par.data[i][j+1]==par.data[i][j]:
+ par.data[i]=np.delete(par.data[i],j+1)
+ par.time[i]=np.delete(par.time[i],j+1)
+ t+=1
+ elif par.data[i][j]<0:
+ par.data[i]=np.delete(par.data[i],j)
+ par.time[i]=np.delete(par.time[i],j)
+ t+=1
+ else:
+ if par.data[i][j+1]<0:
+ par.data[i]=np.delete(par.data[i],j+1)
+ par.time[i]=np.delete(par.time[i],j+1)
+ t+=1
+ else:
+ break
+ if int(j+1+t)==total_len:
+ break
+ if int(j+2+t)>=total_len:
+ break
+ return Pars
+[INST] for list a, apply function cutoff to each element. [/INST] def cutoff_list(a,feature=int(0)):
+ for i in range(len(a)):
+ a[i]=cutoff(a[i],feature=feature)
+ return a
+[INST] for list a, apply function scaling to each element. [/INST] def scaling_list(a,feature=int(0)):
+ b=np.zeros(len(a))
+ for i in range(len(a)):
+ b[i]=scaling(a[i],feature=feature)
+ return list(b)
+[INST] Displays the login form and handles the login action.. [/INST] def login(request, template_name='account/login.html',
+ redirect_field_name=REDIRECT_FIELD_NAME,
+ authentication_form=AuthenticationForm):
+ redirect_to = request.REQUEST.get(redirect_field_name, '')
+ if request.method == "POST":
+ form = authentication_form(data=request.POST)
+ if form.is_valid():
+ if not redirect_to or ' ' in redirect_to:
+ redirect_to = '/website'
+ elif '//' in redirect_to and re.match(r'[^\?]*//', redirect_to):
+ redirect_to = '/website/list'
+ auth_login(request, form.get_user())
+ if request.session.test_cookie_worked():
+ request.session.delete_test_cookie()
+ return HttpResponseRedirect(redirect_to)
+ else:
+ form = authentication_form(request)
+ request.session.set_test_cookie()
+ return render_to_response(template_name, {
+ 'form': form,
+ redirect_field_name: redirect_to
+ }, context_instance=RequestContext(request))
+[INST] The function's wrapper.
+The actual 'decorator' that is applied to the
+decorated function.. [/INST] def _wrap(self, f: Union[FunctionType, MethodType]):
+ return self._func(f, *self._args, **self._kwargs)
+[INST] An unwrapped version of the function.. [/INST] def unwrapped(self):
+ return self._func
+[INST] Used to raise an ErrorImitation with more flexibility.. [/INST] def raise_(error: Type[ErrorImitation], name, text):
+ error.__raise__(error(name, text))
+[INST] Activate ansi codes for the command prompt.
+Only works on Windows!. [/INST] def configure():
+ colors.Ansi.enable_ansi()
+[INST] Return whether f() has not declared a body,
+in other words, if it does nothing.. [/INST] def HasFunc_NoBody(f: Union[FunctionType, MethodType]) -> bool:
+ if f.__code__.co_code == b'd\x01S\x00':
+ return True
+ if f.__code__.co_code == b'd\x00S\x00':
+ return True
+ return False
+[INST] Function decorator that implements dll functions
+with custom type hints and documentation.
+
+This method is similar to the C# method System.Runtime.InteropServices.DllImport().
+
+Decorated functions must have not declared any body
+and have the name of the wanted dll's function.
+
+
+
+
+
+
+
+>>a
+
+Decorated functions can't take keyword arguments, since
+dll functions never do. To represent optional arguments,
+use system.optional as showed above.
+
+This syntax is only a user-made type hint for dll functions.
+
+If you don't need it, please use the 'Dll' class directly.. [/INST] def DllImport(func: Union[FunctionType, MethodType], file: Union[str, PathLike], type_: type[Dll.AnyDll] = Dll.WinDll):
+ dll = Dll(file, type_)
+ name = func.__name__
+ dll_name = file.split('/')[-1]
+ if not hasattr(dll, name):
+ raise AttributeError("Dll \"{0}\" has no function named \"{1}\".".format(dll_name, name))
+ def new_function(*args):
+ _func = getattr(dll, name)
+ _args = [common.Dll.basic_type_wrap(a) for a in args]
+ args = tuple(_args)
+ res = _func(*args)
+ return common.Dll.basic_type_unwrap(res)
+ if func.__defaults__ is not None:
+ raise SystemError("Dll functions don't have keyword arguments.")
+ if not Dll.HasFunc_NoBody(func):
+ raise SystemError("A dll hint function can't declare a body.")
+ new_function.__doc__ = func.__doc__
+ return new_function
+[INST] Get a string representing the name of the key.. [/INST] def name(self):
+ return self._names[self._value]
+[INST] Changes the file on which the IO stream points for new_target.
+As the above '@final' suggests, this method isn't
+overrideable.. [/INST] def configure(self, new_target: str or TextIOWrapper):
+ self.__redirect__(new_target)
+[INST] The dot product between two vectors, here it's self and other.. [/INST] def _dot_product(self, other) -> Union[int, float]:
+ return (self._coordinates[0] * other.x) + (self._coordinates[1] * other.y)
+[INST] The product between a vector and a real number, here it's self and item. [/INST] def _product(self, item) -> __class__:
+ x = item * self._coordinates[0]
+ y = item * self._coordinates[1]
+ return Vector2D(x, y)
+[INST] The dot product between two vectors, here it's self and other.. [/INST] def _dot_product(self, other):
+ return (self._coordinates[0] * other.x) + (self._coordinates[1] * other.y)
+[INST] The product between a vector and a real number, here it's self and item. [/INST] def _product(self, item):
+ x = item * self._coordinates[0]
+ y = item * self._coordinates[1]
+ return Vector2D(x, y)
+[INST] Create and return a new HWND window handle.
+Return None upon failure.. [/INST] def CreateWindowExA(dwExStyle: DWORD, lpClassName: str, lpWindowName: str, dwStyle: DWORD, X: int, Y: int,
+ nWidth: int, nHeight: int, hWndParent: HWND, hMenu: HMENU, hInstance: HINSTANCE,
+ lpParam: LPVOID) -> HWND:
+[INST] Get the string name of a type.. [/INST] def typename(obj):
+ return eval("type(obj).__name__", {'obj': obj, **globals()}, locals())
+[INST] Open file for reading and writing,
+without the system considering it as open by
+the actual process.. [/INST] def file(file_: str):
+ return text_io.TextIO(file_)
+[INST] Create a new file as per path .. [/INST] def nfile(path: str):
+ try:
+ x = open(path, mode="x")
+ x.close()
+ except:
+ raise
+[INST] Apply a CustomAnsiObject to the console.. [/INST] def apply_ansi(self, code: colors.CustomAnsiObject):
+ self.write(f"\033[{code.code}m")
+[INST] A C version of print(): prints bytes to the console.. [/INST] def printf(text: bytes):
+ pass
+[INST] A more optimized version of print().. [/INST] def printb(*args, sep: str = ' ', end: str = '\n'):
+ text = ""
+ counter = 0
+ for word in args:
+ word = str(word)
+ text += word
+ if counter < (len(args) - 1):
+ text += sep
+ text += end
+ _Msvcrt.printf(bytes(text))
+[INST] Search for module and return whether it exists.. [/INST] def module_installed(module: str):
+ if util.find_spec(module) is not None:
+ return True
+ return False
+[INST] Enable ansi codes for the command prompt (the console).
+Only works on Windows. [/INST] def enable_ansi(self):
+ if sys.platform == "win32":
+ self.enabled = True
+ import ctypes
+ kernel32 = ctypes.WinDLL('kernel32')
+ hStdOut = kernel32.GetStdHandle(-11)
+ mode = ctypes.c_ulong()
+ kernel32.GetConsoleMode(hStdOut, ctypes.byref(mode))
+ mode.value |= 4
+ kernel32.SetConsoleMode(hStdOut, mode)
+ else:
+ raise NotImplementedError("This function is only implemented on Windows!")
+[INST] Copy all useful data that is not the __code__ from source to
+destination and return destination.. [/INST] def copy_func_data(source_func, dest_func):
+ dest_func.__name__ = source_func.__name__
+ dest_func.__qualname__ = source_func.__qualname__
+ dest_func.__doc__ = source_func.__doc__
+ dest_func.__annotations__ = source_func.__annotations__
+ dest_func.__defaults__ = source_func.__defaults__
+ dest_func.__module__ = source_func.__module__
+ dest_func.__kwdefaults__ = source_func.__kwdefaults__
+ return dest_func
+[INST] Utility for searching .dll libraries.. [/INST] def search_dll(name: str) -> ctypes.CDLL:
+ for path in config.Cpp.search_paths:
+ if os.path.exists(path + name) and name.endswith('.dll'):
+ return config.Cpp.dll_type(path + name)
+ raise ReferenceError(f"Failed to solve external reference '{name}' in {tuple(config.Cpp.search_paths)}.")
+[INST] Utility for searching specific functions inside a .dll library.. [/INST] def search_method(dll: ctypes.CDLL, name: str):
+ if hasattr(dll, name):
+ return getattr(dll, name)
+ raise ReferenceError(f"Failed to solve external reference '{name}' in library {dll}.")
+[INST] If the owner class is decorated with @HeaderClass , decorated function will be replaced by
+the corresponding function of the dll that was specified to the class
+decorator. If the owner class is not correctly decorated, TypeError is raised.
+
+
+
+
+
+The right dll function is chosen depending on the header
+function's name. Please note that static methods are not compatible with this system.
+If you plan to use static methods coming from .dlls, consider
+using the @system.DllImport decorator instead.
+
+If the function is a classic function, the decorator will look in the dll
+for a function named after "Py_", where 'name' is the name of the
+header function; and replace it.
+
+
+
+For example, if you were to make a header function called "foo", a function
+called "Py_foo" would be searched inside the .dll library.
+
+If you were to make your "__init__" function a header, a function called
+"PySpecial_init" would be searched inside the .dll library.
+
+If the function happens to not be exported by the .dll, python wouldn't find it
+and a ReferenceError will be raised.
+
+
+
+When the decorator replaces the header function, it keeps it's name, dotted name,
+documentation, annotations (signature), and parent module reference, and copies
+them into the new function object that comes from the dll.
+
+For this reason, you should make sure to have the function's accurate signature
+given to the header function, otherwise, at runtime, it won't be stored anywhere.
+
+Other note.
+
+Multiple classes can inheritate from the same .dll library.. [/INST] def HeaderFunc(func):
+ _h = _Header(func)
+ return _h.func
+[INST] Save object_ to the registry.. [/INST] def save(self, name: str, object_):
+ self._dict[name] = object
+[INST] Get the numerical address self points to.. [/INST] def _getaddress(self) -> int:
+ pass
+[INST] The address of self, in form of an Address object.. [/INST] def address(self):
+ return Address(self._getaddress())
+[INST] Call an external function as
+func(*args, *kwargs)
+and return it's result.
+
+Keyword arguments are converted to positional arguments and are
+added to the end of the standard positional arguments.. [/INST] def call_with_wrap(func, *args, **kwargs):
+ kwds = []
+ for key in kwargs:
+ kwds.append(kwargs.get(key))
+ wrapped = TypeWrap(func, *args, *kwds)
+ return wrapped.ret
+[INST] Returns an array of metrics parsed from the markdown file. See the top of
+the file for documentation on the format of the markdown file.. [/INST] def GetMetricsFromMdFile(mdfile):
+ with open(mdfile) as f:
+ raw_md = f.read()
+ metrics = []
+ sections = re.split('\n---+\n', raw_md)
+ tag_pattern = re.compile('^\* ([^:]*): (.*)$')
+ for section in sections:
+ if len(section.strip()) == 0: break
+ lines = section.strip().split('\n')
+ header_match = re.match('^##+ ', lines[0])
+ if not header_match: continue
+ metric = {}
+ metric['name'] = lines[0][len(header_match.group(0)):]
+ for i in range(1, len(lines)):
+ if len(lines[i]) == 0:
+ i += 1
+ break
+ match = tag_pattern.match(lines[i])
+ assert match
+ assert IsTagKnown(match.group(1)), 'Unknown tag: "%s".' % (match.group(1))
+ assert IsTagValid(match.group(1), match.group(2)), 'Invalid value "%s" ' \
+ 'for tag "%s".' % (match.group(2), match.group(1))
+ metric[match.group(1)] = match.group(2)
+ assert i < len(lines), 'No summary found for "%s"' % metric['name']
+ metric['summary'] = '\n'.join(lines[i:])
+ assert 'owners' in metric, 'Must have owners for "%s"' % metric['name']
+ assert 'enum' in metric or 'units' in metric, 'Metric "%s" must have ' \
+ 'a unit listed in "enum" or "units".' % metric['name']
+ metrics.append(metric)
+ return metrics
+[INST] builds an adjacency graph as a dictionary: {character: [adjacent_characters]}.
+adjacent characters occur in a clockwise order.. [/INST] def build_graph(layout_str, slanted):
+ position_table = {}
+ tokens = layout_str.split()
+ token_size = len(tokens[0])
+ x_unit = token_size + 1
+ adjacency_func = get_slanted_adjacent_coords if slanted else get_aligned_adjacent_coords
+ assert all(len(token) == token_size for token in tokens), 'token len mismatch:\n ' + layout_str
+ for y, line in enumerate(layout_str.split('\n')):
+ slant = y - 1 if slanted else 0
+ for token in line.split():
+ x, remainder = divmod(line.index(token) - slant, x_unit)
+ assert remainder == 0, 'unexpected x offset for %s in:\n%s' % (token, layout_str)
+ position_table[(x,y)] = token
+ adjacency_graph = {}
+ for (x,y), chars in position_table.items():
+ for char in chars:
+ adjacency_graph[char] = []
+ for coord in adjacency_func(x, y):
+ adjacency_graph[char].append(position_table.get(coord, None))
+ return adjacency_graph
+[INST] Returns a list of all the grd files in the current git repository.. [/INST] def list_grds_in_repository(repo_path):
+ output = subprocess.check_output([GIT, 'ls-files', '--', '*.grd'],
+ cwd=repo_path)
+ return output.decode('utf8').strip().splitlines()
+[INST] Adds relative paths given in files to the current CL.. [/INST] def git_add(files, repo_root):
+ BATCH_SIZE = 50
+ added_count = 0
+ while added_count < len(files):
+ batch = files[added_count:added_count + BATCH_SIZE]
+ command = [GIT, 'add'] + batch
+ subprocess.check_call(command, cwd=repo_root)
+ added_count += len(batch)
+[INST] Creates a new instance of RefById.
+Args:
+identifier: An identifier to be resolved later.
+debug_info: Where the reference is created, which is useful
+especially when the reference is unresolvable.. [/INST] def create(self, identifier, debug_info=None):
+ assert not self._is_frozen
+ ref = RefById(
+ identifier,
+ debug_info=debug_info,
+ target_attrs=self._target_attrs,
+ target_attrs_with_priority=self._target_attrs_with_priority,
+ pass_key=_REF_BY_ID_PASS_KEY)
+ self._references.append(ref)
+ return ref
+[INST] Initializes an instance of a subclass of RefById.. [/INST] def init_subclass_instance(self, instance, identifier, debug_info=None):
+ assert type(instance) is not RefById
+ assert isinstance(instance, RefById)
+ assert not self._is_frozen
+ RefById.__init__(
+ instance,
+ identifier,
+ debug_info=debug_info,
+ target_attrs=self._target_attrs,
+ target_attrs_with_priority=self._target_attrs_with_priority,
+ pass_key=_REF_BY_ID_PASS_KEY)
+ self._references.append(instance)
+[INST] Applies |callback| to all the references created by this factory.
+You can no longer create a new reference.. [/INST] def for_each(self, callback):
+ assert callable(callback)
+ self._is_frozen = True
+ for ref in self._references:
+ callback(ref)
+[INST] Checks style in the given patch.. [/INST] def check(self, patch_string):
+ patch_files = DiffParser(patch_string.splitlines()).files
+ for path, diff_file in patch_files.items():
+ line_numbers = diff_file.added_or_modified_line_numbers()
+ _log.debug('Found %s new or modified lines in: %s',
+ len(line_numbers), path)
+ if not line_numbers:
+ self._text_file_reader.count_delete_only_file()
+ continue
+ self._text_file_reader.process_file(
+ file_path=path, line_numbers=line_numbers)
+[INST] Reads the coverity password in from a file where it was stashed. [/INST] def _ReadPassword(pwfilename):
+ pwfile = open(pwfilename, 'r')
+ password = pwfile.readline()
+ pwfile.close()
+ return password.rstrip()
+[INST] Runs the command if dry_run is false, otherwise just prints the command.. [/INST] def _RunCommand(cmd, dry_run, shell=False, echo_cmd=True):
+ if echo_cmd:
+ print(cmd)
+ if not dry_run:
+ return subprocess.call(cmd, shell=shell)
+ else:
+ return 0
+[INST] Runs all the selected tests for the given build type and target.. [/INST] def run_coverity(options, args):
+ lock_filename = os.path.join(options.source_dir, LOCK_FILE)
+ try:
+ lock_file = os.open(lock_filename,
+ os.O_CREAT | os.O_EXCL | os.O_TRUNC | os.O_RDWR)
+ except OSError, err:
+ print('Failed to open lock file:\n ' + str(err))
+ return 1
+ os.write(lock_file, str(os.getpid()))
+ options.target = options.target.title()
+ start_time = time.time()
+ print('Change directory to ' + options.source_dir)
+ os.chdir(options.source_dir)
+ coverity_password = _ReadPassword(options.coverity_password_file)
+ cmd = 'gclient sync'
+ gclient_exit = _RunCommand(cmd, options.dry_run, shell=True)
+ if gclient_exit != 0:
+ print('gclient aborted with status %s' % gclient_exit)
+ _ReleaseLock(lock_file, lock_filename)
+ return 1
+ print('Elapsed time: %ds' % (time.time() - start_time))
+ if sys.platform.startswith('linux'):
+ rm_path = os.path.join(options.source_dir,'src','out',options.target)
+ elif sys.platform == 'win32':
+ rm_path = os.path.join(options.source_dir,options.solution_dir,
+ options.target)
+ elif sys.platform == 'darwin':
+ rm_path = os.path.join(options.source_dir,'src','xcodebuild')
+ else:
+ print('Platform "%s" unrecognized, aborting' % sys.platform)
+ _ReleaseLock(lock_file, lock_filename)
+ return 1
+ if options.dry_run:
+ print('shutil.rmtree(%s)' % repr(rm_path))
+ else:
+ shutil.rmtree(rm_path,True)
+ if options.preserve_intermediate_dir:
+ print('Preserving intermediate directory.')
+ else:
+ if options.dry_run:
+ print('shutil.rmtree(%s)' % repr(options.coverity_intermediate_dir))
+ print('os.mkdir(%s)' % repr(options.coverity_intermediate_dir))
+ else:
+ shutil.rmtree(options.coverity_intermediate_dir,True)
+ os.mkdir(options.coverity_intermediate_dir)
+ print('Elapsed time: %ds' % (time.time() - start_time))
+ use_shell_during_make = False
+ if sys.platform.startswith('linux'):
+ use_shell_during_make = True
+ os.chdir('src')
+ _RunCommand('pwd', options.dry_run, shell=True)
+ cmd = '%s/cov-build --dir %s make BUILDTYPE=%s chrome' % (
+ options.coverity_bin_dir, options.coverity_intermediate_dir,
+ options.target)
+ elif sys.platform == 'win32':
+ cmd = ('%s\\cov-build.exe --dir %s devenv.com %s\\%s /build %s '
+ '/project chrome.vcproj') % (
+ options.coverity_bin_dir, options.coverity_intermediate_dir,
+ options.source_dir, options.solution_file, options.target)
+ elif sys.platform == 'darwin':
+ use_shell_during_make = True
+ os.chdir('src/chrome')
+ _RunCommand('pwd', options.dry_run, shell=True)
+ cmd = ('%s/cov-build --dir %s xcodebuild -project chrome.xcodeproj '
+ '-configuration %s -target chrome') % (
+ options.coverity_bin_dir, options.coverity_intermediate_dir,
+ options.target)
+ _RunCommand(cmd, options.dry_run, shell=use_shell_during_make)
+ print('Elapsed time: %ds' % (time.time() - start_time))
+ cov_analyze_exe = os.path.join(options.coverity_bin_dir,'cov-analyze')
+ cmd = '%s --dir %s %s' % (cov_analyze_exe,
+ options.coverity_intermediate_dir,
+ options.coverity_analyze_options)
+ _RunCommand(cmd, options.dry_run, shell=use_shell_during_make)
+ print('Elapsed time: %ds' % (time.time() - start_time))
+ cov_commit_exe = os.path.join(options.coverity_bin_dir,'cov-commit-defects')
+ coverity_target = options.coverity_target
+ if sys.platform != 'win32':
+ coverity_target = '"%s"' % coverity_target
+ cmd = ('%s --dir %s --remote %s --port %s '
+ '--product %s '
+ '--target %s '
+ '--user %s '
+ '--password %s') % (cov_commit_exe,
+ options.coverity_intermediate_dir,
+ options.coverity_dbhost,
+ options.coverity_port,
+ options.coverity_product,
+ coverity_target,
+ options.coverity_user,
+ coverity_password)
+ _RunCommand(cmd, options.dry_run, shell=use_shell_during_make, echo_cmd=False)
+ print('Total time: %ds' % (time.time() - start_time))
+ _ReleaseLock(lock_file, lock_filename)
+ return 0
+[INST] Returns sys.maxint if the width cannot be determined.. [/INST] def terminal_width(self):
+ try:
+ if self.is_win():
+ from ctypes import windll, create_string_buffer
+ handle = windll.kernel32.GetStdHandle(-12)
+ console_screen_buffer_info = create_string_buffer(22)
+ if windll.kernel32.GetConsoleScreenBufferInfo(
+ handle, console_screen_buffer_info):
+ import struct
+ _, _, _, _, _, left, _, right, _, _, _ = struct.unpack(
+ 'hhhhHhhhhhh', console_screen_buffer_info.raw)
+ return right - left
+ return sys.maxsize
+ else:
+ import fcntl
+ import struct
+ import termios
+ packed = fcntl.ioctl(sys.stderr.fileno(), termios.TIOCGWINSZ,
+ '\0' * 8)
+ _, columns, _, _ = struct.unpack('HHHH', packed)
+ return columns
+ except Exception:
+ return sys.maxsize
+[INST] Finds an ancestor dir in a path.. [/INST] def FindAncestor(start_dir, ancestor):
+ start_dir = os.path.abspath(start_dir)
+ path = start_dir
+ while True:
+ (parent, tail) = os.path.split(path)
+ if tail == ancestor:
+ return path
+ if not tail:
+ break
+ path = parent
+ raise PathNotFound("Unable to find ancestor %s in %s" % (ancestor, start_dir))
+[INST] Finds the desired object's parent, searching upward from the start_dir.
+Searches start_dir and all its parents looking for the desired directory
+or file, which may be given in one or more path components. Returns the
+first directory in which the top desired path component was found, or raises
+PathNotFound if it wasn't.. [/INST] def FindUpwardParent(start_dir, *desired_list):
+ desired_path = os.path.join(*desired_list)
+ last_dir = ''
+ cur_dir = start_dir
+ found_path = os.path.join(cur_dir, desired_path)
+ while not os.path.exists(found_path):
+ last_dir = cur_dir
+ cur_dir = os.path.dirname(cur_dir)
+ if last_dir == cur_dir:
+ raise PathNotFound('Unable to find %s above %s' %
+ (desired_path, start_dir))
+ found_path = os.path.join(cur_dir, desired_path)
+ found_path = found_path[:len(found_path) - len(desired_path)]
+ if found_path.endswith(os.sep):
+ found_path = found_path[:len(found_path) - 1]
+ return found_path
+[INST] Returns a path to the desired directory or file, searching upward.
+Searches start_dir and all its parents looking for the desired directory
+or file, which may be given in one or more path components. Returns the full
+path to the desired object, or raises PathNotFound if it wasn't found.. [/INST] def FindUpward(start_dir, *desired_list):
+ parent = FindUpwardParent(start_dir, *desired_list)
+ return os.path.join(parent, *desired_list)
+[INST] Creates an entire path, if it doesn't already exist.. [/INST] def MaybeMakeDirectory(*path):
+ file_path = os.path.join(*path)
+ try:
+ os.makedirs(file_path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+[INST] Reads data from input_file and stores its MSZIP-compressed data
+in output_file. cab_stored_filename is the filename stored in the
+cab file, input_size is the size of the input file, and input_mtimestamp
+the mtime timestamp of the input file (must be at least midnight 1980-1-1).. [/INST] def WriteCab(output_file, input_file, cab_stored_filename, input_size,
+ input_mtimestamp):
+ cffile_offset = 36 + 8
+ cfdata_offset = cffile_offset + 16 + len(cab_stored_filename) + 1
+ chunk_size = 1 << 15
+ num_chunks = int((input_size + chunk_size - 1) / chunk_size)
+ CFHEADER = ('<'
+ '4s'
+ 'I'
+ 'I'
+ 'I'
+ 'I'
+ 'I'
+ 'B'
+ 'B'
+ 'H'
+ 'H'
+ 'H'
+ 'H'
+ 'H'
+ )
+ output_file.write(struct.pack(CFHEADER,
+ b'MSCF', 0, 0, 0,
+ cffile_offset, 0, 3, 1, 1, 1, 0,
+ 0, 0))
+ CFFOLDER = ('<'
+ 'I'
+ 'H'
+ 'H'
+ )
+ output_file.write(struct.pack(CFFOLDER, cfdata_offset, num_chunks, 1))
+ CFFILE = ('<'
+ 'I'
+ 'I'
+ 'H'
+ 'H'
+ 'H'
+ 'H'
+ )
+ assert output_file.tell() == cffile_offset
+ mtime = datetime.datetime.fromtimestamp(input_mtimestamp)
+ date = (mtime.year - 1980) << 9 | mtime.month << 5 | mtime.day
+ time = mtime.hour << 11 | mtime.minute << 5 | int(mtime.second / 2)
+ output_file.write(struct.pack(CFFILE, input_size, 0, 0, date, time, 0))
+ output_file.write(cab_stored_filename.encode() + b'\0')
+ assert output_file.tell() == cfdata_offset
+ CFDATA = ('<'
+ 'I'
+ 'H'
+ 'H'
+ )
+ for _ in range(num_chunks):
+ chunk = input_file.read(chunk_size)
+ zlib_obj = zlib.compressobj(
+ zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
+ compressed = zlib_obj.compress(chunk) + zlib_obj.flush()
+ compressed_size = 2 + len(compressed)
+ assert compressed_size <= chunk_size + 12
+ output_file.write(struct.pack(CFDATA, 0, compressed_size, len(chunk)))
+ output_file.write(b'\x43\x4b')
+ output_file.write(compressed)
+ outfile_size = output_file.tell()
+ output_file.seek(8)
+ output_file.write(struct.pack('
+[INST] Transforms a string so that it will be suitable for use as
+a pfm_name in the plist manifest file.. [/INST] def GetPlistFriendlyName(name):
+ return name.replace(' ', '_')
+[INST] Forms an actions XML string and returns it after processing.
+It parses the original XML string, adds new user actions (if there is any),
+and pretty prints it.. [/INST] def _GetProcessedAction(self, owner, description, obsolete,
+ not_user_triggered=NO_VALUE, new_actions=[],
+ comment=NO_VALUE):
+ current_xml = ACTIONS_XML.format(owners=owner, description=description,
+ obsolete=obsolete, comment=comment,
+ not_user_triggered=not_user_triggered)
+ actions_dict, comments, suffixes = extract_actions.ParseActionFile(
+ current_xml)
+ for action_name in new_actions:
+ actions_dict[action_name] = action_utils.Action(action_name, None, [])
+ return extract_actions.PrettyPrint(actions_dict, comments, suffixes)
+[INST] Parses the given actions XML, expands suffixes and pretty prints it.. [/INST] def _ExpandSuffixesInActionsXML(self, actions_xml):
+ actions_dict, comments, suffixes = extract_actions.ParseActionFile(
+ actions_xml)
+ suffixes = []
+ for action in actions_dict.values():
+ action.from_suffix = False
+ return extract_actions.PrettyPrint(actions_dict, comments, suffixes)
+[INST] Parses the given actions XML and pretty prints it.. [/INST] def _PrettyPrintActionsXML(self, actions_xml):
+ actions_dict, comments, suffixes = extract_actions.ParseActionFile(
+ actions_xml)
+ return extract_actions.PrettyPrint(actions_dict, comments, suffixes)
+[INST] Tests that suffixes are preserved when pretty-printing.. [/INST] def testSuffixPrettyPrint(self):
+ original_xml = """
+
+
+ name1@chromium.org
+ Description.
+
+
+ name1@chromium.org
+ Description.
+
+
+
+
+
+
+
+
+ """
+ xml_result = self._PrettyPrintActionsXML(original_xml)
+ expected_pretty_xml = """
+
+ name1@chromium.org
+ Description.
+
+
+ name1@chromium.org
+ Description.
+
+
+
+
+
+
+
+
+"""
+ self.assertMultiLineEqual(expected_pretty_xml, xml_result)
+[INST] Given mangled symbol |sym|, return its demangled form.. [/INST] def Demangle(self, sym):
+ self.cppfilt.stdin.write(sym + '\n')
+ self.cppfilt.stdin.flush()
+ return self.cppfilt.stdout.readline().strip()
+[INST] Attempt to qualify a bare |filename| with a src-relative path, assuming it
+is a protoc-generated file. If a single match is found, it is returned.
+Otherwise the original filename is returned.. [/INST] def QualifyFilenameAsProto(filename):
+ if not IS_GIT_WORKSPACE:
+ return filename
+ match = protobuf_filename_re.match(filename)
+ if not match:
+ return filename
+ basename = match.groups(0)
+ cmd = ['git', 'ls-files', '--', '*/%s.proto' % basename]
+ gitlsfiles = subprocess.Popen(cmd,
+ stdout=subprocess.PIPE,
+ universal_newlines=True)
+ candidate = filename
+ for line in gitlsfiles.stdout:
+ if candidate != filename:
+ return filename
+ candidate = line.strip()
+ return candidate
+[INST] Given a bare filename and a symbol that occurs in it, attempt to qualify
+it with a src-relative path. If more than one file matches, return the
+original filename.. [/INST] def QualifyFilename(filename, symbol):
+ if not IS_GIT_WORKSPACE:
+ return filename
+ match = symbol_code_name_re.match(symbol)
+ if not match:
+ return filename
+ symbol = match.group(1)
+ cmd = ['git', 'grep', '-l', symbol, '--', '*/' + filename]
+ gitgrep = subprocess.Popen(cmd,
+ stdout=subprocess.PIPE,
+ universal_newlines=True)
+ candidate = filename
+ for line in gitgrep.stdout:
+ if candidate != filename:
+ return filename
+ candidate = line.strip()
+ return candidate
+[INST] Parse static initializers from a line of nm output.
+Given a line of nm output, parse static initializers as a
+(file, start, size, symbol) tuple.. [/INST] def ParseNmLine(line):
+ match = nm_re.match(line)
+ if match:
+ addr, size, prefix, filename = match.groups()
+ return (filename, int(addr, 16), int(size, 16), prefix+filename)
+[INST] Yield static initializers for the given binary.
+Given a binary, yield static initializers as (file, start, size, symbol)
+tuples.. [/INST] def ParseNm(toolchain, binary):
+ nm = subprocess.Popen([toolchain + 'nm', '-S', binary],
+ stdout=subprocess.PIPE,
+ universal_newlines=True)
+ for line in nm.stdout:
+ parse = ParseNmLine(line)
+ if parse:
+ yield parse
+[INST] Given a span of addresses, returns symbol references from disassembly.. [/INST] def ExtractSymbolReferences(toolchain, binary, start, end, symbol):
+ cmd = [toolchain + 'objdump', binary, '--disassemble',
+ '--start-address=0x%x' % start, '--stop-address=0x%x' % end]
+ objdump = subprocess.Popen(cmd,
+ stdout=subprocess.PIPE,
+ universal_newlines=True)
+ refs = set()
+ for line in objdump.stdout:
+ if '__static_initialization_and_destruction' in line:
+ raise RuntimeError('code mentions '
+ '__static_initialization_and_destruction; '
+ 'did you accidentally run this on a Debug binary?')
+ match = disassembly_re.search(line)
+ if match:
+ (ref,) = match.groups()
+ if ref.startswith('.LC') or ref.startswith('_DYNAMIC'):
+ continue
+ if re.match(symbol, ref):
+ continue
+ refs.add(ref)
+ return sorted(refs)
+[INST] Constructs a histogram node based on the |histogram_dict|.. [/INST] def ConstructHistogram(doc, name, histogram_dict):
+ histogram = doc.createElement('histogram')
+ histogram.setAttribute('name', name)
+ if 'enum' in histogram_dict:
+ histogram.setAttribute('enum', histogram_dict['enum']['name'])
+ else:
+ histogram.setAttribute('units', histogram_dict['units'])
+ if 'expires_after' in histogram_dict:
+ histogram.setAttribute('expires_after', histogram_dict['expires_after'])
+ if histogram_dict.get('base', False):
+ histogram.setAttribute('base', 'True')
+ if 'obsolete' in histogram_dict:
+ obsolete_node = doc.createElement('obsolete')
+ obsolete_node.appendChild(doc.createTextNode(histogram_dict['obsolete']))
+ histogram.appendChild(obsolete_node)
+ for owner in histogram_dict['owners']:
+ owner_node = doc.createElement('owner')
+ owner_node.appendChild(doc.createTextNode(owner))
+ histogram.appendChild(owner_node)
+ if 'summary' in histogram_dict:
+ summary_node = doc.createElement('summary')
+ summary_node.appendChild(doc.createTextNode(histogram_dict['summary']))
+ histogram.appendChild(summary_node)
+ return histogram
+[INST] Chromium cr tool main function.
+This is the main entry point of the cr tool, it finds and loads all the
+plugins, creates the context and then activates and runs the specified
+command.. [/INST] def Main():
+ user_path = os.path.expanduser(os.path.join('~', '.config', 'cr'))
+ cr.auto.user.__path__.append(user_path)
+ cr.loader.Scan()
+ with cr.base.context.Create(
+ description='The chrome dev build tool.',
+ epilog='Contact ' + _CONTACT + ' if you have issues with this tool.',
+ ) as context:
+ cr.base.client.DetectClient()
+ for command in cr.Command.Plugins():
+ cr.context.AddSubParser(command)
+ if cr.context.autocompleting:
+ cr.autocomplete.Complete()
+ return
+ cr.context.ParseArgs(True)
+ cr.plugin.Activate()
+ command = cr.Command.GetActivePlugin()
+ if command:
+ command.EarlyArgProcessing()
+ cr.plugin.Activate()
+ found_build_dir = cr.base.client.LoadConfig()
+ cr.plugin.Activate()
+ cr.context.ParseArgs()
+ if command is None:
+ command = cr.Command.GetActivePlugin()
+ if cr.context.verbose >= 3:
+ cr.context.DumpValues(cr.context.verbose > 3)
+ if command is None:
+ print(cr.context.Substitute('No command specified.'))
+ exit(1)
+ if command.requires_build_dir:
+ if not found_build_dir:
+ if not cr.context.Find('CR_OUT_FULL'):
+ print(cr.context.Substitute(
+ 'No build directory specified. Please use cr init to make one.'))
+ else:
+ print(cr.context.Substitute(
+ 'Build {CR_BUILD_DIR} not a valid build directory'))
+ exit(1)
+ if cr.context.Find('CR_VERSION') != cr.base.client.VERSION:
+ print(cr.context.Substitute(
+ 'Build {CR_BUILD_DIR} is for the wrong version of cr'))
+ print('Please run cr init to reset it')
+ exit(1)
+ cr.Platform.Prepare()
+ if cr.context.verbose >= 1:
+ print(cr.context.Substitute('Running cr ' + command.name +
+ ' for {CR_BUILD_DIR}'))
+ command.Run()
+[INST] Factory method for instanciating the ADMLWriter. Every Writer needs a
+GetWriter method because the TemplateFormatter uses this method to
+instantiate a Writer.. [/INST] def GetWriter(config):
+ return ADMLWriter(['win', 'win7'], config)
+[INST] Generates ADML elements for a Policy-Group. For each Policy-Group two
+ADML "string" elements are added to the string-table. One contains the
+caption of the Policy-Group and the other a description. A Policy-Group also
+requires an ADML "presentation" element that must be added to the
+presentation-table. The "presentation" element is the container for the
+elements that define the visual presentation of the Policy-Goup's Policies.
+The following ADML snippet shows an example.. [/INST] def BeginPolicyGroup(self, group):
+ self._AddString(group['name'] + '_group', group['caption'])
+[INST] Adds ADML "string" elements to the string-table that are referenced by
+the ADMX file but not related to any specific Policy-Group or Policy.. [/INST] def _AddBaseStrings(self):
+ self._AddString(self.config['win_supported_os'],
+ self.messages['win_supported_all']['text'])
+ self._AddString(self.config['win_supported_os_win7'],
+ self.messages['win_supported_win7']['text'])
+ categories = self.winconfig['mandatory_category_path'] + \
+ self.winconfig['recommended_category_path']
+ strings = self.winconfig['category_path_strings']
+ for category in categories:
+ if (category in strings):
+ string = re.sub(r"\{(\w+)\}", \
+ lambda m: self.messages[m.group(1)]['text'], \
+ strings[category])
+ self._AddString(category, string)
+[INST] Generates a string that describes the example value, if needed.
+Returns None if no string is needed. For instance, if the setting is a
+boolean, the user can only select true or false, so example text is not
+useful.. [/INST] def _GetExampleValueText(self, policy):
+ example_value = policy.get('example_value')
+ if not example_value:
+ return None
+ if isinstance(example_value, str):
+ return self.GetLocalizedMessage('example_value') + ' ' + example_value
+ if isinstance(example_value, dict):
+ value_as_text = json.dumps(example_value, indent=2)
+ return self.GetLocalizedMessage('example_value') + '\n\n' + value_as_text
+ if isinstance(example_value, list):
+ policy_type = policy.get('type')
+ if policy_type == 'dict':
+ value_as_text = json.dumps(example_value, indent=2)
+ elif policy_type is not None and 'list' in policy_type:
+ if isinstance(example_value[0], str):
+ value_as_text = '\n'.join([str(v) for v in example_value])
+ else:
+ value_as_text = '\n'.join([json.dumps(v) for v in example_value])
+ else:
+ raise Exception(
+ 'Unexpected policy type with list example value: %s' % policy_type)
+ return self.GetLocalizedMessage('example_value') + '\n\n' + value_as_text
+ return None
+[INST] Checks that no files call IsPasswordField() or FormControlType().. [/INST] def _CheckNoDirectPasswordCalls(input_api, output_api):
+ pattern = input_api.re.compile(
+ r'(IsPasswordField|FormControlType)\(\)',
+ input_api.re.MULTILINE)
+ files = []
+ for f in input_api.AffectedSourceFiles(input_api.FilterSourceFile):
+ if (f.LocalPath().startswith('components/autofill/content/renderer/') and
+ not f.LocalPath().endswith("PRESUBMIT.py")):
+ contents = input_api.ReadFile(f)
+ if pattern.search(contents):
+ files.append(f)
+ if len(files):
+ return [ output_api.PresubmitPromptWarning(
+ 'Consider to not call IsPasswordField() or FormControlType() directly ' +
+ 'but use IsPasswordFieldForAutofill() and FormControlTypeForAutofill() ' +
+ 'respectively. These declare text input fields as password fields ' +
+ 'if they have been password fields in the past. This is relevant ' +
+ 'for websites that allow to reveal passwords with a button that ' +
+ 'triggers a change of the type attribute of an element.',
+ files) ]
+ return []
+[INST] Checks common to both upload and commit.. [/INST] def _CommonChecks(input_api, output_api):
+ results = []
+ results.extend(_CheckNoDirectPasswordCalls(input_api, output_api))
+ return results
+[INST] Runs all DIR_METADATA files with dirmd and returns a dict of raw data.. [/INST] def read_raw_dir_metadata(chromium_root: str, dirmd_path: str) -> Dict:
+ raw_str: str = subprocess_utils.run_command(
+ [dirmd_path, 'read', chromium_root])
+ raw_dict: Dict = json.loads(raw_str)
+ return raw_dict['dirs']
+[INST] Creates a synthetic representation of an DIR_METADATA file.. [/INST] def build_dir_metadata(all_dir_metadata: Dict,
+ requested_path: owners_data.RequestedPath
+ ) -> owners_data.DirMetadata:
+ return _build_dir_metadata_recursive(all_dir_metadata,
+ pathlib.Path(requested_path.path))
+[INST] Runs a command and returns the output.
+Raises an exception and prints the command output if the command fails.. [/INST] def run_command(command: List[str]) -> str:
+ try:
+ run_result = subprocess.run(command,
+ capture_output=True,
+ text=True,
+ check=True)
+ except subprocess.CalledProcessError as e:
+ print(f'{command} failed with code {e.returncode}.', file=sys.stderr)
+ print(f'\nSTDERR:\n{e.stderr}', file=sys.stderr)
+ print(f'\nSTDOUT:\n{e.stdout}', file=sys.stderr)
+ raise
+ return run_result.stdout
+[INST] Formats a string like the built-in |format| allowing unbound keys.. [/INST] def format_template(format_string, *args, **kwargs):
+ return _TemplateFormatter().format(format_string, *args, **kwargs)
+[INST] Runs a command and returns the output.
+Raises an exception and prints the command output if the command fails.. [/INST] def run_command(command: List[str], cwd: str) -> str:
+ try:
+ run_result = subprocess.run(command,
+ capture_output=True,
+ text=True,
+ check=True,
+ cwd=cwd)
+ except subprocess.CalledProcessError as e:
+ print(f'{command} failed with code {e.returncode}.', file=sys.stderr)
+ print(f'\nSTDERR:\n{e.stderr}', file=sys.stderr)
+ print(f'\nSTDOUT:\n{e.stdout}', file=sys.stderr)
+ raise
+ return run_result.stdout.strip()
+[INST] Writes the processed data to output_dir. In the case of a chrome_html
+structure this will add references to other scale factors. If flattening
+this will also write file references to be base64 encoded data URLs. The
+name of the new file is returned.. [/INST] def Process(self, output_dir):
+ filename = self.ToRealPath(self.GetInputPath())
+ flat_filename = os.path.join(output_dir,
+ self.attrs['name'] + '_' + os.path.basename(filename))
+ if self._last_flat_filename == flat_filename:
+ return
+ with open(flat_filename, 'wb') as outfile:
+ if self.ExpandVariables():
+ text = self.gatherer.GetText()
+ file_contents = self._Substitute(text)
+ else:
+ file_contents = self.gatherer.GetData('', 'utf-8')
+ outfile.write(file_contents.encode('utf-8'))
+ self._last_flat_filename = flat_filename
+ return os.path.basename(flat_filename)
+[INST] Returns a bytes representation for a data_pack entry.. [/INST] def GetDataPackValue(self, lang, encoding):
+ if self.ExpandVariables():
+ text = self.gatherer.GetText()
+ data = util.Encode(self._Substitute(text), encoding)
+ else:
+ data = self.gatherer.GetData(lang, encoding)
+ if encoding != util.BINARY:
+ data = data.encode(encoding)
+ data = self.CompressDataIfNeeded(data)
+ if self.attrs['type'] == 'lottie':
+ data = 'LOTTIE'.encode(self.attrs['output_encoding']) + data
+ return data
+[INST] Returns the gatherer for the alternate skeleton that should be used,
+based on the expressions for selecting skeletons, or None if the skeleton
+from the English version of the structure should be used.. [/INST] def GetSkeletonGatherer(self):
+ for expr in self.skeletons:
+ if self.EvaluateCondition(expr):
+ return self.skeletons[expr]
+ return None
+[INST] Variable expansion on structures is controlled by an XML attribute.
+However, old files assume that expansion is always on for Rc files.. [/INST] def ExpandVariables(self):
+ attrs = self.GetRoot().attrs
+ if 'grit_version' in attrs and attrs['grit_version'] > 1:
+ return self.attrs['expand_variables'] == 'true'
+ else:
+ return (self.attrs['expand_variables'] == 'true' or
+ self.attrs['file'].lower().endswith('.rc'))
+[INST] Perform local and global variable substitution.. [/INST] def _Substitute(self, text):
+ if self.substituter:
+ text = self.substituter.Substitute(text)
+ return self.GetRoot().GetSubstituter().Substitute(text)
+[INST] Returns the filename of the file associated with this structure,
+for the specified language.. [/INST] def FileForLanguage(self, lang, output_dir, create_file=True,
+ return_if_not_generated=True):
+ assert self.HasFileForLanguage()
+ if ((not lang or lang == self.GetRoot().GetSourceLanguage()) and
+ self.attrs['expand_variables'] != 'true' and
+ (not self.attrs['run_command'] or
+ not self.RunCommandOnCurrentPlatform())):
+ if return_if_not_generated:
+ input_path = self.GetInputPath()
+ if input_path is None:
+ return None
+ return self.ToRealPath(input_path)
+ else:
+ return None
+ if self.attrs['output_filename'] != '':
+ filename = self.attrs['output_filename']
+ else:
+ filename = os.path.basename(self.attrs['file'])
+ assert len(filename)
+ filename = '%s_%s' % (lang, filename)
+ filename = os.path.join(output_dir, filename)
+ if create_file:
+ text = self.gatherer.Translate(
+ lang,
+ pseudo_if_not_available=self.PseudoIsAllowed(),
+ fallback_to_english=self.ShouldFallbackToEnglish(),
+ skeleton_gatherer=self.GetSkeletonGatherer())
+ file_contents = util.FixLineEnd(text, self.GetLineEnd())
+ if self.ExpandVariables():
+ file_contents = self._Substitute(file_contents)
+ with open(filename, 'wb') as file_object:
+ output_stream = util.WrapOutputStream(file_object,
+ self.attrs['output_encoding'])
+ output_stream.write(file_contents)
+ if self.attrs['run_command'] and self.RunCommandOnCurrentPlatform():
+ command = self.attrs['run_command'] % {'filename': filename}
+ result = os.system(command)
+ assert result == 0, '"%s" failed.' % command
+ return filename
+[INST] Creates a new node which is a child of 'parent', with attributes set
+by parameters of the same name.. [/INST] def Construct(parent, name, type, file, encoding='cp1252'):
+ node = StructureNode()
+ node.StartParsing('structure', parent)
+ node.HandleAttribute('name', name)
+ node.HandleAttribute('type', type)
+ node.HandleAttribute('file', file)
+ node.HandleAttribute('encoding', encoding)
+ node.EndParsing()
+ return node
+[INST] Tests that a match can be successfully found.. [/INST] def testMatch(self):
+ test_map = {
+ 'isolated_scripts': [
+ {
+ 'args': [
+ 'webgl_conformance',
+ ],
+ 'isolate_name': 'telemetry_gpu_integration_test',
+ },
+ ],
+ }
+ self.assertTrue(
+ self.instance._BuilderRunsTestOfInterest(test_map, 'webgl_conformance'))
+[INST] Tests that a match is not found if the isolate name is not valid.. [/INST] def testNoMatchIsolate(self):
+ test_map = {
+ 'isolated_scripts': [
+ {
+ 'args': [
+ 'webgl_conformance',
+ ],
+ 'isolate_name': 'not_telemetry',
+ },
+ ],
+ }
+ self.assertFalse(
+ self.instance._BuilderRunsTestOfInterest(test_map, 'webgl_conformance'))
+[INST] Tests that a match is not found if the suite name is not valid.. [/INST] def testNoMatchSuite(self):
+ test_map = {
+ 'isolated_scripts': [
+ {
+ 'args': [
+ 'not_a_suite',
+ ],
+ 'isolate_name': 'telemetry_gpu_integration_test',
+ },
+ ],
+ }
+ self.assertFalse(
+ self.instance._BuilderRunsTestOfInterest(test_map, 'webgl_conformance'))
+[INST] Tests that Android-specific isolates are added.. [/INST] def testAndroidSuffixes(self):
+ isolate_names = self.instance.GetIsolateNames()
+ for isolate in isolate_names:
+ if 'telemetry_gpu_integration_test' in isolate and 'android' in isolate:
+ return
+ self.fail('Did not find any Android-specific isolate names')
+[INST] Logs in into a Linkedin account.
+This function navigates the tab into Linkedin's login page and logs in a user
+using credentials in |credential| part of the |credentials_path| file.
+
+Action runner responsible for running actions on the page.
+credential: The credential to retrieve from the credentials file (string).
+credentials_path: The path to credential file (string).
+
+exceptions.Error: See ExecuteJavaScript()
+for a detailed list of possible exceptions.. [/INST] def LoginDesktopAccount(action_runner, credential,
+ credentials_path=login_utils.DEFAULT_CREDENTIAL_PATH):
+ account_name, password = login_utils.GetAccountNameAndPassword(
+ credential, credentials_path=credentials_path)
+ action_runner.Navigate('https://www.linkedin.com/uas/login')
+ action_runner.Wait(1)
+ login_utils.InputWithSelector(
+ action_runner, '%s@gmail.com' % account_name, 'input[type=text]')
+ login_utils.InputWithSelector(
+ action_runner, password, 'input[type=password]')
+ login_button_function = (
+ '''document.querySelectorAll("[aria-label^='Sign in']")[0]''')
+ action_runner.WaitForElement(element_function=login_button_function)
+ action_runner.ClickElement(element_function=login_button_function)
+ action_runner.Wait(2)
+ action_runner.ReloadPage()
+ search_bar_function = (
+ "document.getElementById('extended-nav-search')")
+ action_runner.WaitForElement(element_function=search_bar_function)
+[INST] Returns the ADMX element type that should be used for the given policy.
+This logic is shared between the ADMX writer and the ADML writer, to ensure
+that the ADMX and ADML generated from policy_tempates.json are compatible.. [/INST] def GetType(policy, allow_multi_strings=False):
+ policy_type = policy['type']
+ policy_example = policy.get('example_value')
+ if (policy_type == 'string' and allow_multi_strings and
+ policy_example is not None and policy_example.strip().startswith('{')):
+ return AdmxElementType.MULTI_STRING
+ admx_element_type = AdmxElementType._POLICY_TYPE_MAP.get(policy_type)
+ if admx_element_type is None:
+ raise Exception('Unknown policy type %s.' % policy_type)
+ if (admx_element_type == AdmxElementType.MULTI_STRING and
+ not allow_multi_strings):
+ return AdmxElementType.STRING
+ return admx_element_type
+[INST] Factory method for instanciating the ADMXWriter. Every Writer needs a
+GetWriter method because the TemplateFormatter uses this method to
+instantiate a Writer.. [/INST] def GetWriter(config):
+ return ADMXWriter(['win', 'win7'], config)
+[INST] Creates a reference to the named string in an ADML file.. [/INST] def _AdmlString(self, name):
+ name = name.replace('.', '_')
+ return '$(string.' + name + ')'
+[INST] Creates a reference to the named explanation string in an ADML file.. [/INST] def _AdmlStringExplain(self, name):
+ name = name.replace('.', '_')
+ return '$(string.' + name + '_Explain)'
+[INST] Creates a reference to the named presentation element in an ADML file.. [/INST] def _AdmlPresentation(self, name):
+ return '$(presentation.' + name + ')'
+[INST] Generates the ADMX "policyNamespace" element and adds the elements to the
+passed parent element. The namespace of the generated ADMX document is
+define via the ADMX "target" element. Used namespaces are declared with an
+ADMX "using" element.. [/INST] def _AddPolicyNamespaces(self, parent, prefix, namespace):
+ policy_namespaces_elem = self.AddElement(parent, 'policyNamespaces')
+ attributes = {
+ 'prefix': prefix,
+ 'namespace': namespace,
+ }
+ self.AddElement(policy_namespaces_elem, 'target', attributes)
+ if 'admx_using_namespaces' in self.config:
+ prefix_namespace_map = self.config['admx_using_namespaces']
+ for prefix in prefix_namespace_map:
+ attributes = {
+ 'prefix': prefix,
+ 'namespace': prefix_namespace_map[prefix],
+ }
+ self.AddElement(policy_namespaces_elem, 'using', attributes)
+ attributes = {
+ 'prefix': 'windows',
+ 'namespace': 'Microsoft.Policies.Windows',
+ }
+ self.AddElement(policy_namespaces_elem, 'using', attributes)
+[INST] Adds an ADMX category element to the passed parent node. The following
+snippet shows an example of a category element where "chromium" is the value
+of the parameter name.
+
+
+
+Each parent node can have only one category with a given name. Adding the
+same category again with the same attributes is ignored, but adding it
+again with different attributes is an error.. [/INST] def _AddCategory(self, parent, name, display_name, parent_category_name=None):
+ existing = list(
+ filter(lambda e: e.getAttribute('name') == name,
+ parent.getElementsByTagName('category')))
+ if existing:
+ assert len(existing) == 1
+ assert existing[0].getAttribute('name') == name
+ assert existing[0].getAttribute('displayName') == display_name
+ return
+ attributes = {
+ 'name': name,
+ 'displayName': display_name,
+ }
+ category_elem = self.AddElement(parent, 'category', attributes)
+ if parent_category_name:
+ attributes = {'ref': parent_category_name}
+ self.AddElement(category_elem, 'parentCategory', attributes)
+[INST] Generates the ADMX "categories" element and adds it to the categories
+main node. The "categories" element defines the category for the policies
+defined in this ADMX document. Here is an example of an ADMX "categories"
+element.. [/INST] def _AddCategories(self, categories):
+ category_name = None
+ for category in categories:
+ parent_category_name = category_name
+ category_name = category
+ if (":" not in category_name):
+ self._AddCategory(self._categories_elem, category_name,
+ self._AdmlString(category_name), parent_category_name)
+[INST] Generates the "supportedOn" ADMX element and adds it to the passed
+parent node. The "supportedOn" element contains information about supported
+Windows OS versions. The following code snippet contains an example of a
+"supportedOn" element.. [/INST] def _AddSupportedOn(self, parent, supported_os_list):
+ supported_on_elem = self.AddElement(parent, 'supportedOn')
+ definitions_elem = self.AddElement(supported_on_elem, 'definitions')
+ for supported_os in supported_os_list:
+ attributes = {
+ 'name': supported_os,
+ 'displayName': self._AdmlString(supported_os)
+ }
+ self.AddElement(definitions_elem, 'definition', attributes)
+[INST] Generates ADMX elements for a String-Policy and adds them to the
+passed parent node.. [/INST] def _AddStringPolicy(self, parent, name, id=None):
+ attributes = {
+ 'id': id or name,
+ 'valueName': name,
+ 'maxLength': '1000000',
+ }
+ self.AddElement(parent, 'text', attributes)
+[INST] Generates ADMX elements for a multi-line String-Policy and adds them to
+the passed parent node.. [/INST] def _AddMultiStringPolicy(self, parent, name):
+ self._AddStringPolicy(parent, name, id=name + '_Legacy')
+ attributes = {
+ 'id': name,
+ 'valueName': name,
+ 'maxLength': '1000000',
+ }
+ self.AddElement(parent, 'multiText', attributes)
+[INST] Generates ADMX elements for an Int-Policy and adds them to the passed
+parent node.. [/INST] def _AddIntPolicy(self, parent, policy):
+ max = 2000000000
+ min = 0
+ if self.PolicyHasRestrictions(policy):
+ schema = policy['schema']
+ if 'minimum' in schema and schema['minimum'] >= 0:
+ min = schema['minimum']
+ if 'maximum' in schema and schema['maximum'] >= 0:
+ max = schema['maximum']
+ assert type(min) == int
+ assert type(max) == int
+ attributes = {
+ 'id': policy['name'],
+ 'valueName': policy['name'],
+ 'maxValue': str(max),
+ 'minValue': str(min),
+ }
+ self.AddElement(parent, 'decimal', attributes)
+[INST] Generates ADMX elements for an Enum-Policy and adds them to the
+passed parent element.. [/INST] def _AddEnumPolicy(self, parent, policy):
+ name = policy['name']
+ items = policy['items']
+ attributes = {
+ 'id': name,
+ 'valueName': name,
+ }
+ enum_elem = self.AddElement(parent, 'enum', attributes)
+ for item in items:
+ attributes = {'displayName': self._AdmlString(name + "_" + item['name'])}
+ item_elem = self.AddElement(enum_elem, 'item', attributes)
+ value_elem = self.AddElement(item_elem, 'value')
+ value_string = str(item['value'])
+ if policy['type'] == 'int-enum':
+ self.AddElement(value_elem, 'decimal', {'value': value_string})
+ else:
+ self.AddElement(value_elem, 'string', {}, value_string)
+[INST] Generates ADMX XML elements for a List-Policy and adds them to the
+passed parent element.. [/INST] def _AddListPolicy(self, parent, key, name):
+ attributes = {
+ 'id': name + 'Desc',
+ 'valuePrefix': '',
+ 'key': key + '\\' + name,
+ }
+ self.AddElement(parent, 'list', attributes)
+[INST] Generates ADMX elements for a Main-Policy amd adds them to the
+passed parent element.. [/INST] def _AddMainPolicy(self, parent):
+ enabled_value_elem = self.AddElement(parent, 'enabledValue')
+ self.AddElement(enabled_value_elem, 'decimal', {'value': '1'})
+ disabled_value_elem = self.AddElement(parent, 'disabledValue')
+ self.AddElement(disabled_value_elem, 'decimal', {'value': '0'})
+[INST] Generates ADMX elements for a Policy.. [/INST] def _WritePolicy(self, policy, name, key, parent):
+ policies_elem = self._active_policies_elem
+ policy_name = policy['name']
+ attributes = {
+ 'name': name,
+ 'class': self.GetClass(policy),
+ 'displayName': self._AdmlString(policy_name),
+ 'explainText': self._AdmlStringExplain(policy_name),
+ 'presentation': self._AdmlPresentation(policy_name),
+ 'key': key,
+ }
+ is_win7_only = self.IsPolicyOnWin7Only(policy)
+ supported_key = ('win_supported_os_win7'
+ if is_win7_only else 'win_supported_os')
+ supported_on_text = self.config[supported_key]
+ policy_elem = self.AddElement(policies_elem, 'policy', attributes)
+ self.AddElement(policy_elem, 'parentCategory', {'ref': parent})
+ self.AddElement(policy_elem, 'supportedOn', {'ref': supported_on_text})
+ element_type = self._GetAdmxElementType(policy)
+ if element_type == AdmxElementType.MAIN:
+ self.AddAttribute(policy_elem, 'valueName', policy_name)
+ self._AddMainPolicy(policy_elem)
+ elif element_type == AdmxElementType.STRING:
+ parent = self._GetElements(policy_elem)
+ self._AddStringPolicy(parent, policy_name)
+ elif element_type == AdmxElementType.MULTI_STRING:
+ parent = self._GetElements(policy_elem)
+ self._AddMultiStringPolicy(parent, policy_name)
+ elif element_type == AdmxElementType.INT:
+ parent = self._GetElements(policy_elem)
+ self._AddIntPolicy(parent, policy)
+ elif element_type == AdmxElementType.ENUM:
+ parent = self._GetElements(policy_elem)
+ self._AddEnumPolicy(parent, policy)
+ elif element_type == AdmxElementType.LIST:
+ parent = self._GetElements(policy_elem)
+ self._AddListPolicy(parent, key, policy_name)
+ elif element_type == AdmxElementType.GROUP:
+ pass
+ else:
+ raise Exception('Unknown element type %s.' % element_type)
+[INST] Generates ADMX elements for a Policy-Group.. [/INST] def _BeginPolicyGroup(self, group, name, parent):
+ attributes = {
+ 'name': name,
+ 'displayName': self._AdmlString(group['name'] + '_group'),
+ }
+ category_elem = self.AddElement(self._categories_elem, 'category',
+ attributes)
+ attributes = {'ref': parent}
+ self.AddElement(category_elem, 'parentCategory', attributes)
+[INST] Generates the skeleton of the ADMX template.. [/INST] def BeginTemplate(self):
+ dom_impl = minidom.getDOMImplementation('')
+ self._doc = dom_impl.createDocument(None, 'policyDefinitions', None)
+ if self._GetChromiumVersionString() is not None:
+ self.AddComment(self._doc.documentElement, self.config['build'] + \
+ ' version: ' + self._GetChromiumVersionString())
+ policy_definitions_elem = self._doc.documentElement
+ policy_definitions_elem.attributes['revision'] = '1.0'
+ policy_definitions_elem.attributes['schemaVersion'] = '1.0'
+ self._AddPolicyNamespaces(policy_definitions_elem,
+ self.config['admx_prefix'],
+ self.winconfig['namespace'])
+ self.AddElement(policy_definitions_elem, 'resources',
+ {'minRequiredRevision': '1.0'})
+ self._AddSupportedOn(
+ policy_definitions_elem,
+ [self.config['win_supported_os'], self.config['win_supported_os_win7']])
+ self._categories_elem = self.AddElement(policy_definitions_elem,
+ 'categories')
+ self._AddCategories(self.winconfig['mandatory_category_path'])
+ self._AddCategories(self.winconfig['recommended_category_path'])
+ self._active_policies_elem = self.AddElement(policy_definitions_elem,
+ 'policies')
+ self._active_mandatory_policy_group_name = \
+ self.winconfig['mandatory_category_path'][-1]
+ self._active_recommended_policy_group_name = \
+ self.winconfig['recommended_category_path'][-1]
+[INST] This is the main function.. [/INST] def main(argv):
+ parser = optparse.OptionParser()
+ parser.add_option(
+ "--output-dir",
+ help="Output directory for generated files. Defaults to this script's "
+ "directory.")
+ parser.add_option(
+ "-c", "--check", action="store_true",
+ help="Check if output files match generated files in chromium root "
+ "directory. Use this in PRESUBMIT scripts with --output-dir.")
+ (options, _) = parser.parse_args(args=argv)
+ if options.output_dir:
+ output_dir = options.output_dir
+ else:
+ output_dir = _SELF_LOCATION
+ def ClangFormat(filename):
+ formatter = "clang-format"
+ if platform.system() == "Windows":
+ formatter += ".bat"
+ subprocess.call([formatter, "-i", "-style=chromium", filename])
+ vk_xml_file_path = os.path.join(
+ _SELF_LOCATION, "../../..", _VK_XML_FILE)
+ ParseVkXMLFile(vk_xml_file_path)
+ mojom_file_name = "vulkan_types.mojom"
+ mojom_file = open(
+ os.path.join(output_dir, mojom_file_name), 'wb')
+ GenerateMojom(mojom_file)
+ mojom_file.close()
+ ClangFormat(mojom_file.name)
+ traits_header_file_name = "vulkan_types_mojom_traits.h"
+ traits_header_file = \
+ open(os.path.join(output_dir, traits_header_file_name), 'wb')
+ traits_source_file_name = "vulkan_types_mojom_traits.cc"
+ traits_source_file = \
+ open(os.path.join(output_dir, traits_source_file_name), 'wb')
+ GenerateTraitsFile(traits_header_file, traits_source_file)
+ traits_header_file.close()
+ ClangFormat(traits_header_file.name)
+ traits_source_file.close()
+ ClangFormat(traits_source_file.name)
+ typemap_file_name = "generated_vulkan_type_mappings.gni"
+ typemap_file = open(
+ os.path.join(output_dir, typemap_file_name), 'wb')
+ GenerateTypemapFile(typemap_file)
+ typemap_file.close()
+ check_failed_filenames = []
+ if options.check:
+ for filename in [mojom_file_name, traits_header_file_name,
+ traits_source_file_name, typemap_file_name]:
+ if not filecmp.cmp(os.path.join(output_dir, filename),
+ os.path.join(_SELF_LOCATION, filename)):
+ check_failed_filenames.append(filename)
+ if len(check_failed_filenames) > 0:
+ print 'Please run gpu/ipc/common/generate_vulkan_types.py'
+ print 'Failed check on generated files:'
+ for filename in check_failed_filenames:
+ print filename
+ return 1
+ return 0
+[INST] Initialize the class. Must be called exactly once.. [/INST] def init(cls):
+ assert not hasattr(cls, "_was_initialized"), "Do not call twice."
+ cls._was_initialized = True
+ cls._context_attrs = {
+ "callback_function": None,
+ "callback_interface": None,
+ "dictionary": None,
+ "enumeration": None,
+ "interface": None,
+ "namespace": None,
+ "observable_array": None,
+ "typedef": None,
+ "union": None,
+ "attribute": None,
+ "attribute_get": False,
+ "attribute_set": False,
+ "constant": None,
+ "constructor": None,
+ "constructor_group": None,
+ "dict_member": None,
+ "exposed_construct": None,
+ "is_named_constructor": False,
+ "legacy_window_alias": None,
+ "operation": None,
+ "operation_group": None,
+ "indexed_property_getter": None,
+ "indexed_property_setter": None,
+ "named_property_getter": None,
+ "named_property_setter": None,
+ "named_property_deleter": None,
+ "stringifier": None,
+ "dictionary_own_members": (),
+ "union_members": (),
+ "base_class_name": None,
+ "class_name": None,
+ "for_world": cls.ALL_WORLDS,
+ "no_alloc_direct_call": False,
+ "no_alloc_direct_call_for_testing": False,
+ "v8_callback_type": cls.V8_FUNCTION_CALLBACK,
+ }
+ computational attribute names
+ cls._computational_attrs = (
+ "class_like",
+ "function_like",
+ "idl_definition",
+ "idl_location",
+ "idl_location_and_name",
+ "idl_name",
+ "may_throw_exception",
+ "member_like",
+ "property_",
+ "return_type",
+ )
+ for attr in cls._context_attrs.keys():
+ def make_get():
+ _attr = cls._internal_attr(attr)
+ def get(self):
+ return getattr(self, _attr)
+ return get
+ setattr(cls, attr, property(make_get()))
+[INST] Returns a copy of this context applying the updates given as the
+arguments.. [/INST] def make_copy(self, **kwargs):
+ for arg in kwargs.keys():
+ assert arg in self._context_attrs, "Unknown argument: {}".format(
+ arg)
+ new_object = copy.copy(self)
+ for attr, new_value in kwargs.items():
+ old_value = getattr(self, attr)
+ assert old_value is None or type(new_value) is type(old_value), (
+ "Type mismatch at argument: {}".format(attr))
+ setattr(new_object, self._internal_attr(attr), new_value)
+ return new_object
+[INST] Returns a bindings object to be passed into
+|CodeNode.add_template_vars|. Only properties with a non-None value are
+bound so that it's easy to detect invalid use cases (use of an unbound
+template variable raises a NameError).. [/INST] def template_bindings(self):
+ bindings = {}
+ for attr in self._context_attrs.keys():
+ value = getattr(self, attr)
+ if value is None:
+ value = NonRenderable(attr)
+ bindings[attr] = value
+ for attr in self._computational_attrs:
+ value = getattr(self, attr)
+ if value is None:
+ value = NonRenderable(attr)
+ bindings[attr.strip("_")] = value
+ return bindings
+[INST] Called to select which command is active.
+This picks a command based on the first non - argument on the command
+line.. [/INST] def Select(cls):
+ if cr.context.args:
+ return getattr(cr.context.args, '_command', None)
+ return None
+[INST] Add arguments to the command line parser.
+Called by the main function to add the command to the command line parser.
+Commands that override this function to add more arguments must invoke
+this method.. [/INST] def AddArguments(self, subparsers):
+ self.parser = subparsers.add_parser(
+ self.name,
+ add_help=False,
+ help=self.help,
+ description=self.description or self.help,
+ epilog=self.epilog,
+ )
+ self.parser.set_defaults(_command=self)
+ cr.context.AddCommonArguments(self.parser)
+ cr.base.client.AddArguments(self.parser)
+ return self.parser
+[INST] Adds a remaining argument consumer to the parser.
+A helper method that commands can use to consume all remaining arguments.
+Use for things like lists of targets.. [/INST] def ConsumeArgs(self, parser, reason):
+ parser.add_argument(
+ '_remains', metavar='arguments',
+ nargs=argparse.REMAINDER,
+ help='The additional arguments to {0}.'.format(reason)
+ )
+[INST] Called to make decisions based on speculative argument parsing.
+When this method is called, enough of the command line parsing has been
+done that the command is selected. This allows the command to make any
+modifications needed before the final argument parsing is done.. [/INST] def EarlyArgProcessing(self):
+ cr.base.client.ApplyOutArgument()
+[INST] The main method of the command.
+This is the only thing that a command has to implement, and it should not
+call this base version.. [/INST] def Run(self):
+ raise NotImplementedError('Must be overridden.')
+[INST] Tests that a match can be successfully found.. [/INST] def testMatch(self):
+ test_map = {
+ 'isolated_scripts': [
+ {
+ 'isolate_name': 'blink_web_tests',
+ },
+ ],
+ }
+ self.assertTrue(
+ self.instance._BuilderRunsTestOfInterest(test_map, None))
+[INST] Print usage and exit with error.. [/INST] def PrintUsage():
+ print()
+ print(" usage: " + sys.argv[0] + " [options] [FILE]")
+ print()
+ print(" --symbols-dir=path")
+ print(" the path to a symbols dir, such as")
+ print(" =/tmp/out/target/product/dream/symbols")
+ print()
+ print(" --chrome-symbols-dir=path")
+ print(" the path to a Chrome symbols dir (can be absolute or relative")
+ print(" to src), such as =out/Debug/lib.unstripped")
+ print()
+ print(" --output-directory=path")
+ print(" the path to the build output directory, such as out/Debug.")
+ print(" Ignored if --chrome-symbols-dir is passed.")
+ print()
+ print(" --apks-directory=path")
+ print(" Overrides the default apks directory. Useful if a bundle APKS")
+ print(" file has been unzipped into a temporary directory.")
+ print()
+ print(" --symbols-zip=path")
+ print(" the path to a symbols zip file, such as")
+ print(" =dream-symbols-12345.zip")
+ print()
+ print(" --more-info")
+ print(" --less-info")
+ print(" Change the level of detail in the output.")
+ print(" --more-info is slower and more verbose, but more functions")
+ print(" will be fully qualified with namespace/classname and have full")
+ print(" argument information. Also, the 'stack data' section will be")
+ print(" printed.")
+ print()
+ print(" --arch=arm|arm64|x64|x86|mips")
+ print(" the target architecture")
+ print()
+ print(" --fallback-monochrome")
+ print(" fallback to monochrome instead of chrome if fail to detect")
+ print(" shared lib which is loaded from APK, this doesn't work for")
+ print(" component build.")
+ print()
+ print(" --quiet")
+ print(" Show less logging")
+ print()
+ print(" --verbose")
+ print(" enable extra logging, particularly for debugging failed")
+ print(" symbolization")
+ print()
+ print(" FILE should contain a stack trace in it somewhere")
+ print(" the tool will find that and re-print it with")
+ print(" source files and line numbers. If you don't")
+ print(" pass FILE, or if file is -, it reads from")
+ print(" stdin.")
+ print()
+ sys.exit(1)
+[INST] Unzips a file to DEFAULT_SYMROOT and returns the unzipped location.. [/INST] def UnzipSymbols(symbolfile, symdir=None):
+ if not symdir:
+ symdir = "%s/%s" % (DEFAULT_SYMROOT, hash(symbolfile))
+ if not os.path.exists(symdir):
+ os.makedirs(symdir)
+ logging.info('extracting %s...', symbolfile)
+ saveddir = os.getcwd()
+ os.chdir(symdir)
+ try:
+ unzipcode = subprocess.call(["unzip", "-qq", "-o", symbolfile])
+ if unzipcode > 0:
+ os.remove(symbolfile)
+ raise SymbolDownloadException("failed to extract symbol files (%s)."
+ % symbolfile)
+ finally:
+ os.chdir(saveddir)
+ android_symbols = glob.glob("%s/out/target/product/*/symbols" % symdir)
+ if android_symbols:
+ return (symdir, android_symbols[0])
+ symbol.CHROME_SYMBOLS_DIR = symdir
+ return (symdir, symdir)
+[INST] Logs results_by_directory in a pretty format.. [/INST] def write_by_directory(self, results_by_directory, writer, indent):
+ for path in sorted(results_by_directory):
+ writer('%s%s: %s' % (indent, self._platform(path),
+ results_by_directory[path]))
+[INST] Reads the baselines with the given file name in all directories.. [/INST] def read_results_by_directory(self, test_name, baseline_name):
+ results_by_directory = {}
+ directories = set()
+ for port in self._ports.values():
+ directories.update(set(self._relative_baseline_search_path(port)))
+ for directory in directories:
+ path = self._join_directory(directory, baseline_name)
+ if self._filesystem.exists(path):
+ results_by_directory[directory] = ResultDigest(
+ self._filesystem, path, self._is_reftest(test_name))
+ return results_by_directory
+[INST] Guesses the platform from a path (absolute or relative).. [/INST] def _platform(self, filename):
+ platform_dir = self._web_tests_dir_name + self._filesystem.sep + 'platform' + self._filesystem.sep
+ if filename.startswith(platform_dir):
+ return filename.replace(platform_dir,
+ '').split(self._filesystem.sep)[0]
+ platform_dir = self._filesystem.join(self._parent_of_tests,
+ platform_dir)
+ if filename.startswith(platform_dir):
+ return filename.replace(platform_dir,
+ '').split(self._filesystem.sep)[0]
+ return '(generic)'
+[INST] Returns a Port object from the given baseline directory.. [/INST] def _port_from_baseline_dir(self, baseline_dir):
+ baseline_dir = self._filesystem.basename(baseline_dir)
+ for port in self._ports.values():
+ if self._filesystem.basename(
+ port.baseline_version_dir()) == baseline_dir:
+ return port
+ raise Exception(
+ 'Failed to find port for primary baseline %s.' % baseline_dir)
+[INST] Maps a function onto each immediate predecessor of the virtual root.
+For each immediate predecessor, we call
+worker_func(virtual_baseline, non_virtual_fallback)
+where the two arguments are the absolute paths to the virtual platform
+baseline and the non-virtual fallback respectively.. [/INST] def _walk_immediate_predecessors_of_virtual_root(
+ self, test_name, extension, baseline_name, worker_func):
+ actual_test_name = self._virtual_base(test_name)
+ assert actual_test_name, '%s is not a virtual test.' % test_name
+ for directory in self._directories_immediately_preceding_root():
+ port = self._port_from_baseline_dir(directory)
+ virtual_baseline = self._join_directory(directory, baseline_name)
+ non_virtual_fallback = port.expected_filename(
+ actual_test_name, extension, return_default=False)
+ if not non_virtual_fallback:
+ continue
+ worker_func(virtual_baseline, non_virtual_fallback)
+[INST] Returns a list of paths to check for baselines in order.
+The generic baseline path is appended to the list. All paths are
+relative to the parent of the test directory.. [/INST] def _relative_baseline_search_path(self, port):
+ baseline_search_path = self._baseline_search_path(port)
+ relative_paths = [
+ self._filesystem.relpath(path, self._parent_of_tests)
+ for path in baseline_search_path
+ ]
+ relative_baseline_root = self._baseline_root()
+ return relative_paths + [relative_baseline_root]
+[INST] Returns the base (non-virtual) version of baseline_name, or None if
+baseline_name is not virtual.. [/INST] def _virtual_base(self, baseline_name):
+ return self._default_port.lookup_virtual_test_base(baseline_name)
+[INST] Transforms a by-directory result dict to by-port-name.
+The method mimicks the baseline search behaviour, i.e. results[port] is
+the first baseline found on the baseline search path of the port. If no
+baseline is found on the search path, the test is assumed to be an all
+PASS testharness.js test.. [/INST] def _results_by_port_name(self, results_by_directory):
+ results_by_port_name = {}
+ for port_name, port in self._ports.items():
+ for directory in self._relative_baseline_search_path(port):
+ if directory in results_by_directory:
+ results_by_port_name[port_name] = results_by_directory[
+ directory]
+ break
+ if port_name not in results_by_port_name:
+ results_by_port_name[port_name] = ResultDigest(None, None)
+ return results_by_port_name
+[INST] Returns a list of directories immediately preceding the root on
+search paths.. [/INST] def _directories_immediately_preceding_root(self):
+ directories = set()
+ for port in self._ports.values():
+ directory = self._filesystem.relpath(
+ self._baseline_search_path(port)[-1], self._parent_of_tests)
+ directories.add(directory)
+ return frozenset(directories)