pub mod offchain;
pub mod onchain;
mod shared;
use codec::{Decode, Encode};
use sp_runtime::{
	traits::{Convert, OpaqueKeys},
	KeyTypeId,
};
use sp_session::{MembershipProof, ValidatorCount};
use sp_staking::SessionIndex;
use sp_std::prelude::*;
use sp_trie::{
	trie_types::{TrieDBBuilder, TrieDBMutBuilderV0},
	LayoutV0, MemoryDB, Recorder, Trie, TrieMut, EMPTY_PREFIX,
};
use frame_support::{
	print,
	traits::{KeyOwnerProofSystem, ValidatorSet, ValidatorSetWithIdentification},
	Parameter,
};
use crate::{self as pallet_session, Pallet as Session};
pub use pallet::*;
#[frame_support::pallet]
pub mod pallet {
	use super::*;
	use frame_support::pallet_prelude::*;
	/// The in-code storage version of this pallet.
	const STORAGE_VERSION: StorageVersion = StorageVersion::new(1);
	/// Historical-session pallet: keeps, per session, the root of a trie over that session's
	/// validator set so that key-ownership proofs for past sessions can still be checked.
	#[pallet::pallet]
	#[pallet::storage_version(STORAGE_VERSION)]
	pub struct Pallet<T>(_);
	/// Configuration trait, layered on top of `pallet_session::Config`.
	#[pallet::config]
	pub trait Config: pallet_session::Config + frame_system::Config {
		/// Full identification of a validator, recorded next to its `ValidatorId` in each
		/// session's proving trie.
		type FullIdentification: Parameter;
		/// Maps a `ValidatorId` to its full identification. Validators for which this returns
		/// `None` are left out of proofs generated by `prove`, and a current-session
		/// `check_proof` for their keys yields `None`.
		type FullIdentificationOf: Convert<Self::ValidatorId, Option<Self::FullIdentification>>;
	}
	/// Session index -> (root of that session's proving trie, number of validators it covers).
	#[pallet::storage]
	#[pallet::getter(fn historical_root)]
	pub type HistoricalSessions<T: Config> =
		StorageMap<_, Twox64Concat, SessionIndex, (T::Hash, ValidatorCount), OptionQuery>;
	/// Half-open range `[start, end)` of session indices whose roots may be present in
	/// `HistoricalSessions`; `None` when nothing is tracked.
	#[pallet::storage]
	pub type StoredRange<T> = StorageValue<_, (SessionIndex, SessionIndex), OptionQuery>;
}
impl<T: Config> Pallet<T> {
	/// Remove every stored historical root for sessions in `[start, up_to)`, where `start`
	/// is the beginning of `StoredRange`.
	///
	/// `up_to` is clamped to the end of the stored range. If nothing is stored, or `up_to`
	/// lies below the current start, storage is left untouched. When pruning reaches the end
	/// of the range, `StoredRange` is cleared to `None`.
	pub fn prune_up_to(up_to: SessionIndex) {
		StoredRange::<T>::mutate(|stored| {
			if let Some((first, past_last)) = *stored {
				// Never walk beyond what has actually been recorded.
				let cutoff = up_to.min(past_last);
				if cutoff >= first {
					for session in first..cutoff {
						HistoricalSessions::<T>::remove(session);
					}
					*stored = if cutoff == past_last { None } else { Some((cutoff, past_last)) };
				}
			}
		})
	}
}
impl<T: Config> ValidatorSet<T::AccountId> for Pallet<T> {
	type ValidatorId = T::ValidatorId;
	type ValidatorIdOf = T::ValidatorIdOf;
	fn session_index() -> sp_staking::SessionIndex {
		super::Pallet::<T>::current_index()
	}
	fn validators() -> Vec<Self::ValidatorId> {
		super::Pallet::<T>::validators()
	}
}
/// Exposes full validator identification: values of `T::FullIdentification`, obtained
/// through the `T::FullIdentificationOf` converter.
impl<T: Config> ValidatorSetWithIdentification<T::AccountId> for Pallet<T> {
	type IdentificationOf = T::FullIdentificationOf;
	type Identification = T::FullIdentification;
}
/// Session manager that, in addition to the plain `pallet_session::SessionManager`
/// contract, hands back each new validator together with its full identification.
///
/// `NoteHistoricalRoot` wraps an implementor of this trait and uses the identified set to
/// record a historical trie root per session.
pub trait SessionManager<ValidatorId, FullIdentification>:
	pallet_session::SessionManager<ValidatorId>
{
	/// Notification that session `start_index` has started.
	fn start_session(start_index: SessionIndex);

	/// Notification that session `end_index` has ended.
	fn end_session(end_index: SessionIndex);

	/// Plan session `new_index`, returning the validators paired with their full
	/// identifications, or `None` when there is no new set (in which case
	/// `NoteHistoricalRoot` carries the previous session's root forward).
	fn new_session(new_index: SessionIndex) -> Option<Vec<(ValidatorId, FullIdentification)>>;

	/// Same as `new_session`, but invoked for genesis. Defaults to `new_session`.
	fn new_session_genesis(
		new_index: SessionIndex,
	) -> Option<Vec<(ValidatorId, FullIdentification)>> {
		// Fully qualified: the supertrait also declares a `new_session` with a different
		// return type, so an unqualified call would be ambiguous.
		<Self as SessionManager<ValidatorId, FullIdentification>>::new_session(new_index)
	}
}
pub struct NoteHistoricalRoot<T, I>(sp_std::marker::PhantomData<(T, I)>);
impl<T: Config, I: SessionManager<T::ValidatorId, T::FullIdentification>> NoteHistoricalRoot<T, I> {
	fn do_new_session(new_index: SessionIndex, is_genesis: bool) -> Option<Vec<T::ValidatorId>> {
		<StoredRange<T>>::mutate(|range| {
			range.get_or_insert_with(|| (new_index, new_index)).1 = new_index + 1;
		});
		let new_validators_and_id = if is_genesis {
			<I as SessionManager<_, _>>::new_session_genesis(new_index)
		} else {
			<I as SessionManager<_, _>>::new_session(new_index)
		};
		let new_validators_opt = new_validators_and_id
			.as_ref()
			.map(|new_validators| new_validators.iter().map(|(v, _id)| v.clone()).collect());
		if let Some(new_validators) = new_validators_and_id {
			let count = new_validators.len() as ValidatorCount;
			match ProvingTrie::<T>::generate_for(new_validators) {
				Ok(trie) => <HistoricalSessions<T>>::insert(new_index, &(trie.root, count)),
				Err(reason) => {
					print("Failed to generate historical ancestry-inclusion proof.");
					print(reason);
				},
			};
		} else {
			let previous_index = new_index.saturating_sub(1);
			if let Some(previous_session) = <HistoricalSessions<T>>::get(previous_index) {
				<HistoricalSessions<T>>::insert(new_index, previous_session);
			}
		}
		new_validators_opt
	}
}
impl<T: Config, I> pallet_session::SessionManager<T::ValidatorId> for NoteHistoricalRoot<T, I>
where
	I: SessionManager<T::ValidatorId, T::FullIdentification>,
{
	fn new_session(new_index: SessionIndex) -> Option<Vec<T::ValidatorId>> {
		Self::do_new_session(new_index, false)
	}
	fn new_session_genesis(new_index: SessionIndex) -> Option<Vec<T::ValidatorId>> {
		Self::do_new_session(new_index, true)
	}
	fn start_session(start_index: SessionIndex) {
		<I as SessionManager<_, _>>::start_session(start_index)
	}
	fn end_session(end_index: SessionIndex) {
		onchain::store_session_validator_set_to_offchain::<T>(end_index);
		<I as SessionManager<_, _>>::end_session(end_index)
	}
}
/// A validator's `ValidatorId` paired with its full identification; this is the value
/// stored in (and recovered from) a session's `ProvingTrie`.
pub type IdentificationTuple<T> =
	(<T as pallet_session::Config>::ValidatorId, <T as Config>::FullIdentification);
/// In-memory trie (layout V0) over one session's validator set. It maps each
/// `(key type id, raw session key)` to the owning validator's index, and each index to the
/// validator's `IdentificationTuple`.
pub struct ProvingTrie<T: Config> {
	// Backing node store; either built from scratch or seeded from proof nodes.
	db: MemoryDB<T::Hashing>,
	// Root hash of the trie held in `db`.
	root: T::Hash,
}
impl<T: Config> ProvingTrie<T> {
	/// Build a trie over `validators`.
	///
	/// For each validator whose session keys can be loaded, inserts
	/// `(key_id, raw_key) -> index` for every key type in `T::Keys::key_ids()`, plus
	/// `index -> (validator, full_id)`. Validators without loaded keys are skipped, but they
	/// still consume an index because indices come from `enumerate()` before the skip.
	///
	/// Returns `Err("failed to insert into trie")` if any insertion fails.
	fn generate_for<I>(validators: I) -> Result<Self, &'static str>
	where
		I: IntoIterator<Item = (T::ValidatorId, T::FullIdentification)>,
	{
		let mut db = MemoryDB::default();
		let mut root = Default::default();
		// Scope the mutable trie so its borrows of `db` and `root` end before they are moved.
		{
			let mut trie = TrieDBMutBuilderV0::new(&mut db, &mut root).build();
			for (i, (validator, full_id)) in validators.into_iter().enumerate() {
				// NOTE(review): `as u32` truncates; assumes fewer than 2^32 validators.
				let i = i as u32;
				let keys = match <Session<T>>::load_keys(&validator) {
					None => continue,
					Some(k) => k,
				};
				let full_id = (validator, full_id);
				// Map every session key of this validator to its index.
				for key_id in T::Keys::key_ids() {
					let key = keys.get_raw(*key_id);
					let res =
						(key_id, key).using_encoded(|k| i.using_encoded(|v| trie.insert(k, v)));
					let _ = res.map_err(|_| "failed to insert into trie")?;
				}
				// Map the index to the full identification tuple.
				let _ = i
					.using_encoded(|k| full_id.using_encoded(|v| trie.insert(k, v)))
					.map_err(|_| "failed to insert into trie")?;
			}
		}
		Ok(ProvingTrie { db, root })
	}
	/// Rebuild a trie view from raw proof `nodes` under the claimed `root`.
	///
	/// Nodes are inserted into a fresh `MemoryDB` under `EMPTY_PREFIX`. Nothing here checks
	/// that they are consistent with `root`; lookups simply fail (`query` returns `None`) when
	/// a needed node is missing.
	fn from_nodes(root: T::Hash, nodes: &[Vec<u8>]) -> Self {
		use sp_trie::HashDBT;
		let mut memory_db = MemoryDB::default();
		for node in nodes {
			HashDBT::insert(&mut memory_db, EMPTY_PREFIX, &node[..]);
		}
		ProvingTrie { db: memory_db, root }
	}
	/// Produce the trie nodes needed to resolve `(key_id, key_data)` to its owner's
	/// `IdentificationTuple`.
	///
	/// Both lookups (key -> index, index -> tuple) run with a recorder attached, and every
	/// node they touch is returned. Returns `None` if either lookup misses, errors or fails
	/// to decode.
	pub fn prove(&self, key_id: KeyTypeId, key_data: &[u8]) -> Option<Vec<Vec<u8>>> {
		let mut recorder = Recorder::<LayoutV0<T::Hashing>>::new();
		// The trie borrows `recorder` mutably; it must be dropped before `drain()` below.
		{
			let trie =
				TrieDBBuilder::new(&self.db, &self.root).with_recorder(&mut recorder).build();
			let val_idx = (key_id, key_data).using_encoded(|s| {
				trie.get(s).ok()?.and_then(|raw| u32::decode(&mut &*raw).ok())
			})?;
			// Value is discarded: this lookup is performed only so its nodes get recorded.
			val_idx.using_encoded(|s| {
				trie.get(s)
					.ok()?
					.and_then(|raw| <IdentificationTuple<T>>::decode(&mut &*raw).ok())
			})?;
		}
		Some(recorder.drain().into_iter().map(|r| r.data).collect())
	}
	/// The trie's root hash.
	pub fn root(&self) -> &T::Hash {
		&self.root
	}
	/// Resolve `(key_id, key_data)` to its `IdentificationTuple` using only the nodes in
	/// `self.db`. Returns `None` if a node is missing, the key is absent, or decoding fails.
	fn query(&self, key_id: KeyTypeId, key_data: &[u8]) -> Option<IdentificationTuple<T>> {
		let trie = TrieDBBuilder::new(&self.db, &self.root).build();
		let val_idx = (key_id, key_data)
			.using_encoded(|s| trie.get(s))
			.ok()?
			.and_then(|raw| u32::decode(&mut &*raw).ok())?;
		val_idx
			.using_encoded(|s| trie.get(s))
			.ok()?
			.and_then(|raw| <IdentificationTuple<T>>::decode(&mut &*raw).ok())
	}
}
/// Key-ownership proofs against the current session, checkable later via historical roots.
impl<T: Config, D: AsRef<[u8]>> KeyOwnerProofSystem<(KeyTypeId, D)> for Pallet<T> {
	type Proof = MembershipProof;
	type IdentificationTuple = IdentificationTuple<T>;

	/// Build a membership proof for `key` in the current session.
	///
	/// Only validators with a full identification are included in the trie (and counted).
	/// Returns `None` if trie construction fails or the key is not found.
	fn prove(key: (KeyTypeId, D)) -> Option<Self::Proof> {
		let session = <Session<T>>::current_index();

		let mut identified = Vec::new();
		for validator in <Session<T>>::validators() {
			if let Some(full_id) = T::FullIdentificationOf::convert(validator.clone()) {
				identified.push((validator, full_id));
			}
		}
		let validator_count = identified.len() as ValidatorCount;

		let trie = ProvingTrie::<T>::generate_for(identified).ok()?;
		let (key_type, key_data) = key;
		let trie_nodes = trie.prove(key_type, key_data.as_ref())?;
		Some(MembershipProof { session, trie_nodes, validator_count })
	}

	/// Verify `proof` for `key`, returning the owner's identification tuple on success.
	///
	/// For the current session, live session state is consulted and the trie nodes are
	/// ignored. For other sessions, the stored root and validator count must match and the
	/// proof nodes must resolve the key.
	fn check_proof(key: (KeyTypeId, D), proof: Self::Proof) -> Option<IdentificationTuple<T>> {
		let (key_type, key_data) = key;

		if proof.session != <Session<T>>::current_index() {
			let (root, stored_count) = <HistoricalSessions<T>>::get(&proof.session)?;
			if stored_count != proof.validator_count {
				return None
			}
			return ProvingTrie::<T>::from_nodes(root, &proof.trie_nodes)
				.query(key_type, key_data.as_ref())
		}

		let owner = <Session<T>>::key_owner(key_type, key_data.as_ref())?;
		let full_id = T::FullIdentificationOf::convert(owner.clone())?;
		// NOTE(review): this compares against the unfiltered validator count, whereas `prove`
		// counts only validators with a full identification. If `FullIdentificationOf` ever
		// returns `None` for a current validator, same-session checks will fail. Confirm
		// this is intended.
		let live_count = <Session<T>>::validators().len() as ValidatorCount;
		if live_count != proof.validator_count {
			return None
		}
		Some((owner, full_id))
	}
}
#[cfg(test)]
pub(crate) mod tests {
	use super::*;
	use crate::mock::{
		force_new_session, set_next_validators, NextValidators, Session, System, Test,
	};
	use sp_runtime::{key_types::DUMMY, testing::UintAuthorityId, BuildStorage};
	use sp_state_machine::BasicExternalities;
	use frame_support::traits::{KeyOwnerProofSystem, OnInitialize};
	type Historical = Pallet<Test>;
	/// Build test externalities: default system genesis, one provider reference per account
	/// in `NextValidators`, and session genesis keys `(i, i, UintAuthorityId(i))` for each.
	pub(crate) fn new_test_ext() -> sp_io::TestExternalities {
		let mut t = frame_system::GenesisConfig::<Test>::default().build_storage().unwrap();
		let keys: Vec<_> = NextValidators::get()
			.iter()
			.cloned()
			.map(|i| (i, i, UintAuthorityId(i).into()))
			.collect();
		BasicExternalities::execute_with_storage(&mut t, || {
			// Each key owner needs a provider reference before session keys are set.
			for (ref k, ..) in &keys {
				frame_system::Pallet::<Test>::inc_providers(k);
			}
		});
		pallet_session::GenesisConfig::<Test> { keys }
			.assimilate_storage(&mut t)
			.unwrap();
		sp_io::TestExternalities::new(t)
	}
	/// A proof made in one session checks in that session and, via the stored historical
	/// root, after the session has rotated.
	#[test]
	fn generated_proof_is_good() {
		new_test_ext().execute_with(|| {
			set_next_validators(vec![1, 2]);
			force_new_session();
			System::set_block_number(1);
			Session::on_initialize(1);
			let encoded_key_1 = UintAuthorityId(1).encode();
			let proof = Historical::prove((DUMMY, &encoded_key_1[..])).unwrap();
			// Same-session check goes through the live-state path.
			assert!(Historical::check_proof((DUMMY, &encoded_key_1[..]), proof.clone()).is_some());
			set_next_validators(vec![1, 2, 4]);
			force_new_session();
			System::set_block_number(2);
			Session::on_initialize(2);
			assert!(Historical::historical_root(proof.session).is_some());
			assert!(Session::current_index() > proof.session);
			// Later-session check goes through the historical-root path.
			assert!(Historical::check_proof((DUMMY, &encoded_key_1[..]), proof.clone()).is_some());
			set_next_validators(vec![1, 2, 5]);
			force_new_session();
			System::set_block_number(3);
			Session::on_initialize(3);
		});
	}
	/// `prune_up_to` removes roots below the cutoff, ignores cutoffs below the current start,
	/// clamps to the end of the range, and clears `StoredRange` once everything is pruned.
	#[test]
	fn prune_up_to_works() {
		new_test_ext().execute_with(|| {
			for i in 1..99u64 {
				set_next_validators(vec![i]);
				force_new_session();
				System::set_block_number(i);
				Session::on_initialize(i);
			}
			assert_eq!(<StoredRange<Test>>::get(), Some((0, 100)));
			for i in 0..100 {
				assert!(Historical::historical_root(i).is_some())
			}
			Historical::prune_up_to(10);
			assert_eq!(<StoredRange<Test>>::get(), Some((10, 100)));
			// Cutoff below the current start is a no-op.
			Historical::prune_up_to(9);
			assert_eq!(<StoredRange<Test>>::get(), Some((10, 100)));
			for i in 10..100 {
				assert!(Historical::historical_root(i).is_some())
			}
			Historical::prune_up_to(99);
			assert_eq!(<StoredRange<Test>>::get(), Some((99, 100)));
			// Pruning up to the end clears the range entirely.
			Historical::prune_up_to(100);
			assert_eq!(<StoredRange<Test>>::get(), None);
			// New sessions start a fresh range.
			for i in 99..199u64 {
				set_next_validators(vec![i]);
				force_new_session();
				System::set_block_number(i);
				Session::on_initialize(i);
			}
			assert_eq!(<StoredRange<Test>>::get(), Some((100, 200)));
			for i in 100..200 {
				assert!(Historical::historical_root(i).is_some())
			}
			// A cutoff far past the end is clamped and removes everything.
			Historical::prune_up_to(9999);
			assert_eq!(<StoredRange<Test>>::get(), None);
			for i in 100..200 {
				assert!(Historical::historical_root(i).is_none())
			}
		});
	}
}